├── mar3d
│   ├── utils
│   │   ├── __init__.py
│   │   ├── visualizers
│   │   │   ├── __init__.py
│   │   │   ├── html_util.py
│   │   │   └── color_util.py
│   │   ├── typing.py
│   │   ├── checkpoint.py
│   │   ├── base.py
│   │   ├── scheduler.py
│   │   ├── config.py
│   │   ├── misc.py
│   │   ├── callbacks.py
│   │   └── ops.py
│   ├── data
│   │   ├── __init__.py
│   │   └── objaverse.py
│   ├── systems
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── diffloss.py
│   │   └── mar_diffusion.py
│   ├── models
│   │   ├── geometry
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── conditional_encoders
│   │   │   ├── __init__.py
│   │   │   ├── clip
│   │   │   │   └── modeling_conditional_clip.py
│   │   │   ├── base.py
│   │   │   └── clip_encoder.py
│   │   ├── __init__.py
│   │   ├── autoencoders
│   │   │   ├── __init__.py
│   │   │   └── utils.py
│   │   ├── transformers
│   │   │   ├── utils.py
│   │   │   ├── perceiver_1d.py
│   │   │   └── attention.py
│   │   └── diffloss.py
│   └── __init__.py
├── configs
│   ├── shape-autoencoder
│   │   └── occ.yaml
│   └── mar-diffusion
│       └── mar.yaml
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   └── respace.py
├── train_diffusion.sh
├── README.md
└── launch.py
/mar3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 |
--------------------------------------------------------------------------------
/mar3d/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | objaverse
3 | )
--------------------------------------------------------------------------------
/mar3d/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/mar3d/systems/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | mar_diffusion
3 | )
--------------------------------------------------------------------------------
/mar3d/models/geometry/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | base
3 | )
4 |
--------------------------------------------------------------------------------
/mar3d/models/conditional_encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | clip_encoder,
3 | )
4 |
--------------------------------------------------------------------------------
/mar3d/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | autoencoders,
3 | conditional_encoders,
4 | )
--------------------------------------------------------------------------------
/mar3d/models/autoencoders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | michelangelo_autoencoder,
3 | )
4 |
--------------------------------------------------------------------------------
/train_diffusion.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0 #0,1,2,3 #,4,5,6,7 #
2 | python launch.py --config ./configs/mar-diffusion/mar.yaml --train --gpu 0
3 | # python launch.py --config ./configs/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6.yaml --train --gpu 0
--------------------------------------------------------------------------------
/mar3d/models/transformers/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def init_linear(l, stddev):
4 | nn.init.normal_(l.weight, std=stddev)
5 | if l.bias is not None:
6 | nn.init.constant_(l.bias, 0.0)
7 |
8 | class MLP(nn.Module):
9 | def __init__(self, *,
10 | width: int,
11 | init_scale: float):
12 | super().__init__()
13 | self.width = width
14 | self.c_fc = nn.Linear(width, width * 4)
15 | self.c_proj = nn.Linear(width * 4, width)
16 | self.gelu = nn.GELU()
17 | init_linear(self.c_fc, init_scale)
18 | init_linear(self.c_proj, init_scale)
19 |
20 | def forward(self, x):
21 | return self.c_proj(self.gelu(self.c_fc(x)))
22 |
--------------------------------------------------------------------------------
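A minimal usage sketch of the MLP block above (nothing below is part of the repo; the width and init scale simply mirror the values used in the configs):

```python
import torch
from mar3d.models.transformers.utils import MLP

mlp = MLP(width=768, init_scale=0.25)   # hidden width 4 * 768 with GELU in between
y = mlp(torch.randn(2, 256, 768))       # (batch, tokens, width) in -> same shape out
```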
/mar3d/utils/visualizers/html_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def to_html_frame(content):
9 |
10 |     html_frame = f"""
11 |     <html>
12 |     <body>
13 |     {content}
14 |     </body>
15 |     </html>
16 |     """
17 |
18 | return html_frame
19 |
20 |
21 | def to_single_row_table(caption: str, content: str):
22 |
23 |     table_html = f"""
24 |     <table>
25 |     <caption>{caption}</caption>
26 |     <tr>
27 |     <td>{content}</td>
28 |     </tr>
29 |     </table>
30 |     """
31 |
32 | return table_html
33 |
34 |
35 | def to_image_embed_tag(image: np.ndarray):
36 |
37 | # Convert np.ndarray to bytes
38 | img = Image.fromarray(image)
39 | raw_bytes = io.BytesIO()
40 | img.save(raw_bytes, "PNG")
41 |
42 | # Encode bytes to base64
43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
44 |
45 |     image_tag = f"""
46 |     <img src="data:image/png;base64,{image_base64}">
47 |     """
48 |
49 | return image_tag
50 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## [CVPR 2025 Highlight]
2 | ## **MAR-3D: Progressive Masked Auto-regressor for High-Resolution 3D Generation**
3 |
4 |
5 | Please prepare your preprocessed input points and supervision SDF/occupancy following https://github.com/wyysf-98/CraftsMan3D
6 | (We are unable to release the preprocessing code and data.)
7 |
8 | `root_dir` should contain the .npz files; the image file mapping is defined in objaverse.py.
9 |
10 | The training pipeline also follows CraftsMan3D's code base.
11 |
12 | ## Training
13 |
14 | bash train_diffusion.sh
15 |
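Before launching, point the `data` section of `configs/mar-diffusion/mar.yaml` at your preprocessed data; the paths below are placeholders:

```yaml
data:
  root_dir: ["/path/to/preprocessed_npz"]      # folders with sampled points + sdf/occ .npz files
  image_data_path: "/path/to/rendered_images"  # image renderings referenced by objaverse.py
```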
16 | ### Links
17 | 📄 **Paper**: [arXiv:2503.20519](https://arxiv.org/abs/2503.20519)
18 | 🌐 **Project Page**: [https://jinnan-chen.github.io/projects/MAR-3D/](https://jinnan-chen.github.io/projects/MAR-3D/)
19 |
20 | ### Authors
21 | Jinnan Chen, Lingting Zhu, Zeyu Hu, Shengju Qian, Yugang Chen, Xin Wang, Gim Hee Lee
22 |
23 | ### Citation
24 | ```bibtex
25 | @article{chen2025mar3d,
26 | title = {MAR-3D: Progressive Masked Auto-regressor for High-Resolution 3D Generation},
27 | author = {Chen, Jinnan and Zhu, Lingting and Hu, Zeyu and Qian, Shengju and
28 | Chen, Yugang and Wang, Xin and Lee, Gim Hee},
29 | journal = {arXiv preprint arXiv:2503.20519},
30 | year = {2025}
31 | }
32 | ```
33 |
--------------------------------------------------------------------------------
/mar3d/utils/typing.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains type annotations for the project, using
3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
5 |
6 | Two types of type checking can be used:
7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
9 | """
10 |
11 | # Basic types
12 | from typing import (
13 | Any,
14 | Callable,
15 | Dict,
16 | Iterable,
17 | List,
18 | Literal,
19 | NamedTuple,
20 | NewType,
21 | Optional,
22 | Sized,
23 | Tuple,
24 | Type,
25 | TypeVar,
26 | Union,
27 | Sequence,
28 | )
29 |
30 | # Tensor dtype
31 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
32 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
33 |
34 | # Config type
35 | from omegaconf import DictConfig
36 |
37 | # PyTorch Tensor type
38 | from torch import Tensor
39 |
40 | # Runtime type checking decorator
41 | from typeguard import typechecked as typechecker
42 |
--------------------------------------------------------------------------------
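A minimal sketch (hypothetical function, not part of the repo) of how these tensor annotations are typically written with the imports above:

```python
from torch import Tensor
from jaxtyping import Float

def scale_points(pts: Float[Tensor, "B N 3"], s: float) -> Float[Tensor, "B N 3"]:
    # shape and dtype are documented in the annotation: a float tensor of shape (B, N, 3)
    return pts * s
```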
/mar3d/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | __modules__ = {}
4 |
5 |
6 | def register(name):
7 | def decorator(cls):
8 | if name in __modules__:
9 | raise ValueError(
10 | f"Module {name} already exists! Names of extensions conflict!"
11 | )
12 | else:
13 | __modules__[name] = cls
14 | return cls
15 |
16 | return decorator
17 |
18 |
19 | def find(name):
20 | if name in __modules__:
21 | return __modules__[name]
22 | else:
23 | try:
24 | module_string = ".".join(name.split(".")[:-1])
25 | cls_name = name.split(".")[-1]
26 | module = importlib.import_module(module_string, package=None)
27 | return getattr(module, cls_name)
28 | except Exception as e:
29 |             raise ValueError(f"Module {name} not found!") from e
30 |
31 |
32 | ### grammar sugar for logging utilities ###
33 | import logging
34 |
35 | logger = logging.getLogger("pytorch_lightning")
36 |
37 | from pytorch_lightning.utilities.rank_zero import (
38 | rank_zero_debug,
39 | rank_zero_info,
40 | rank_zero_only,
41 | )
42 |
43 | debug = rank_zero_debug
44 | info = rank_zero_info
45 |
46 |
47 | @rank_zero_only
48 | def warn(*args, **kwargs):
49 |     logger.warning(*args, **kwargs)
50 |
51 |
52 | from . import data, models, systems
53 |
--------------------------------------------------------------------------------
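A hypothetical sketch of how this registry is used: classes register themselves under a type string (the same strings that appear as `*_type` fields in the configs, e.g. `"michelangelo-autoencoder"`), and `find()` looks them up by name or by dotted import path. The class and name below are made up:

```python
import mar3d

@mar3d.register("my-toy-system")
class MyToySystem:
    pass

cls = mar3d.find("my-toy-system")   # returns MyToySystem
# find() also falls back to dotted paths, e.g. mar3d.find("some.module.ClassName")
```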
/mar3d/models/transformers/perceiver_1d.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from mar3d.utils.typing import *
7 | from mar3d.utils.checkpoint import checkpoint
8 |
9 | from .utils import init_linear
10 | from .attention import ResidualAttentionBlock
11 |
12 |
13 | class Perceiver(nn.Module):
14 | def __init__(
15 | self,
16 | *,
17 | n_ctx: int,
18 | width: int,
19 | layers: int,
20 | heads: int,
21 | init_scale: float = 0.25,
22 | qkv_bias: bool = True,
23 | use_flash: bool = False,
24 | use_checkpoint: bool = False
25 | ):
26 | super().__init__()
27 | self.n_ctx = n_ctx
28 | self.width = width
29 | self.layers = layers
30 | self.resblocks = nn.ModuleList(
31 | [
32 | ResidualAttentionBlock(
33 | n_ctx=n_ctx,
34 | width=width,
35 | heads=heads,
36 | init_scale=init_scale,
37 | qkv_bias=qkv_bias,
38 | use_flash=use_flash,
39 | use_checkpoint=use_checkpoint
40 | )
41 | for _ in range(layers)
42 | ]
43 | )
44 |
45 | def forward(self, x: torch.Tensor):
46 | for block in self.resblocks:
47 | x = block(x)
48 | return x
--------------------------------------------------------------------------------
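A hypothetical instantiation of the Perceiver stack above, using dimensions that mirror the configs (256 latent tokens, width 768, 12 heads); it assumes `ResidualAttentionBlock` from attention.py accepts the keyword arguments shown in the constructor:

```python
import torch
from mar3d.models.transformers.perceiver_1d import Perceiver

blocks = Perceiver(n_ctx=256, width=768, layers=8, heads=12, use_flash=False)
x = torch.randn(2, 256, 768)   # (batch, n_ctx, width) latent tokens
y = blocks(x)                  # self-attention stack, shape preserved
```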
/mar3d/utils/visualizers/color_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | # Helper functions
6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
7 | colormap = plt.cm.get_cmap(colormap)
8 | if normalize:
9 | vmin = np.min(inp)
10 | vmax = np.max(inp)
11 |
12 | norm = plt.Normalize(vmin, vmax)
13 | return colormap(norm(inp))[:, :3]
14 |
15 |
16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
17 | # tex dims need to be power of two.
18 | array = np.ones((width, height, 3), dtype='float32')
19 |
20 | # width in texels of each checker
21 | checker_w = width / n_checkers_x
22 | checker_h = height / n_checkers_y
23 |
24 | for y in range(height):
25 | for x in range(width):
26 | color_key = int(x / checker_w) + int(y / checker_h)
27 | if color_key % 2 == 0:
28 | array[x, y, :] = [1., 0.874, 0.0]
29 | else:
30 | array[x, y, :] = [0., 0., 0.]
31 | return array
32 |
33 |
34 | def gen_circle(width=256, height=256):
35 | xx, yy = np.mgrid[:width, :height]
36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37 | array = np.ones((width, height, 4), dtype='float32')
38 | array[:, :, 0] = (circle <= width)
39 | array[:, :, 1] = (circle <= width)
40 | array[:, :, 2] = (circle <= width)
41 | array[:, :, 3] = circle <= width
42 | return array
43 |
44 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from DiT, which is modified from OpenAI's diffusion repos
2 | # DiT: https://github.com/facebookresearch/DiT/diffusion
3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6 |
7 | from . import gaussian_diffusion as gd
8 | from .respace import SpacedDiffusion, space_timesteps
9 |
10 |
11 | def create_diffusion(
12 | timestep_respacing,
13 | noise_schedule="linear",
14 | use_kl=False,
15 | sigma_small=False,
16 | predict_xstart=False,
17 | learn_sigma=True,
18 | rescale_learned_sigmas=False,
19 | diffusion_steps=1000
20 | ):
21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22 | if use_kl:
23 | loss_type = gd.LossType.RESCALED_KL
24 | elif rescale_learned_sigmas:
25 | loss_type = gd.LossType.RESCALED_MSE
26 | else:
27 | loss_type = gd.LossType.MSE
28 | if timestep_respacing is None or timestep_respacing == "":
29 | timestep_respacing = [diffusion_steps]
30 | return SpacedDiffusion(
31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32 | betas=betas,
33 | model_mean_type=(
34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35 | ),
36 | model_var_type=(
37 | (
38 | gd.ModelVarType.FIXED_LARGE
39 | if not sigma_small
40 | else gd.ModelVarType.FIXED_SMALL
41 | )
42 | if not learn_sigma
43 | else gd.ModelVarType.LEARNED_RANGE
44 | ),
45 | loss_type=loss_type
46 | # rescale_timesteps=rescale_timesteps,
47 | )
48 |
--------------------------------------------------------------------------------
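A minimal usage sketch of the factory above, assuming the standard ADM/DiT `GaussianDiffusion` interface provided by gaussian_diffusion.py and respace.py (the model and latent shapes in the comments are placeholders):

```python
from diffusion import create_diffusion

train_diffusion = create_diffusion(timestep_respacing="")          # full 1000-step chain for training
sample_diffusion = create_diffusion(timestep_respacing="ddim100")  # respaced 100-step DDIM sampling

# typical training step with a placeholder denoiser `model` and latents `x`:
# t = torch.randint(0, train_diffusion.num_timesteps, (x.shape[0],), device=x.device)
# losses = train_diffusion.training_losses(model, x, t, model_kwargs={"cond": cond})
# loss = losses["loss"].mean()
```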
/configs/shape-autoencoder/occ.yaml:
--------------------------------------------------------------------------------
1 | exp_root_dir: "outputs"
2 | name: "michelangelo-autoencoder/mul-l64-256-e64-ne8-nd16"
3 | tag: "${rmspace:n${data.n_samples}+${data.supervision_type}+rot${data.rotate}+noise${data.noise_sigma}+${system.shape_model.embed_type}+dsample${system.shape_model.use_downsample}+pfeat${system.shape_model.point_feats}+logits${system.loss.lambda_logits}+kl${system.loss.lambda_kl}+lr${system.optimizer.args.lr},_}"
4 | seed: 0
5 | # resume: model-64.ckpt
6 | data_type: "objaverse-datamodule"
7 | data:
8 | root_dir: ""
9 | data_type: "occupancy"
10 | n_samples: 20480
11 | noise_sigma: 0.
12 | rotate: False
13 | load_supervision: True
14 | supervision_type: "occupancy"
15 | n_supervision: 25600
16 | load_image: False # whether to load images
17 | load_caption: False # whether to load captions
18 | batch_size: 32
19 | num_workers: 8
20 |
21 | system_type: "shape-autoencoder-system"
22 | system:
23 | sample_posterior: true
24 | shape_model_type: "michelangelo-autoencoder"
25 | shape_model:
26 | num_latents: 256 # 1024
27 | embed_dim: 64
28 | point_feats: 3 # xyz + normal
29 | out_dim: 1 # only occupancy
30 | embed_type: "fourier"
31 | num_freqs: 8
32 | include_pi: false
33 | heads: 12
34 | width: 768
35 | num_encoder_layers: 8
36 | num_decoder_layers: 16
37 | use_ln_post: true
38 | init_scale: 0.25
39 | qkv_bias: true
40 | use_flash: true
41 | use_checkpoint: true
42 | use_downsample: false
43 |
44 | loggers:
45 | wandb:
46 | enable: false
47 | project: "mar3d"
48 | name: shape-autoencoder+${name}+${tag}
49 |
50 | loss:
51 | lambda_logits: 1.
52 | lambda_kl: 0.001
53 |
54 | optimizer:
55 | name: AdamW
56 | args:
57 | lr: 1.e-4
58 | betas: [0.9, 0.99]
59 | eps: 1.e-6
60 |
61 | scheduler:
62 | name: SequentialLR
63 | interval: step
64 | schedulers:
65 | - name: LinearLR
66 | interval: step
67 | args:
68 | start_factor: 1e-6
69 | end_factor: 1.0
70 | total_iters: 5000
71 | - name: CosineAnnealingLR
72 | interval: step
73 | args:
74 | T_max: 5000
75 | eta_min: 0.
76 | milestones: [5000]
77 |
78 | trainer:
79 | num_nodes: 1
80 | max_epochs: 100000
81 | log_every_n_steps: 10
82 | num_sanity_val_steps: 0
83 | # val_check_interval: 200
84 | check_val_every_n_epoch: 5
85 | enable_progress_bar: true
86 | precision: 16-mixed
87 | strategy: 'ddp_find_unused_parameters_true'
88 |
89 | checkpoint:
90 | save_last: true
91 | save_top_k: -1
92 | every_n_train_steps: 10000
93 |
--------------------------------------------------------------------------------
/mar3d/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4 | """
5 |
6 | import torch
7 | from mar3d.utils.typing import *
8 |
9 | def checkpoint(
10 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
11 | inputs: Sequence[torch.Tensor],
12 | params: Iterable[torch.Tensor],
13 | flag: bool,
14 | use_deepspeed: bool = False
15 | ):
16 | """
17 | Evaluate a function without caching intermediate activations, allowing for
18 | reduced memory at the expense of extra compute in the backward pass.
19 | :param func: the function to evaluate.
20 | :param inputs: the argument sequence to pass to `func`.
21 | :param params: a sequence of parameters `func` depends on but does not
22 | explicitly take as arguments.
23 | :param flag: if False, disable gradient checkpointing.
24 | :param use_deepspeed: if True, use deepspeed
25 | """
26 | if flag:
27 | if use_deepspeed:
28 | import deepspeed
29 | return deepspeed.checkpointing.checkpoint(func, *inputs)
30 |
31 | args = tuple(inputs) + tuple(params)
32 | return CheckpointFunction.apply(func, len(inputs), *args)
33 | else:
34 | return func(*inputs)
35 |
36 |
37 | class CheckpointFunction(torch.autograd.Function):
38 | @staticmethod
39 | @torch.cuda.amp.custom_fwd
40 | def forward(ctx, run_function, length, *args):
41 | ctx.run_function = run_function
42 | ctx.input_tensors = list(args[:length])
43 | ctx.input_params = list(args[length:])
44 |
45 | with torch.no_grad():
46 | output_tensors = ctx.run_function(*ctx.input_tensors)
47 | return output_tensors
48 |
49 | @staticmethod
50 | @torch.cuda.amp.custom_bwd
51 | def backward(ctx, *output_grads):
52 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
53 | with torch.enable_grad():
54 | # Fixes a bug where the first op in run_function modifies the
55 | # Tensor storage in place, which is not allowed for detach()'d
56 | # Tensors.
57 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
58 | output_tensors = ctx.run_function(*shallow_copies)
59 | input_grads = torch.autograd.grad(
60 | output_tensors,
61 | ctx.input_tensors + ctx.input_params,
62 | output_grads,
63 | allow_unused=True,
64 | )
65 | del ctx.input_tensors
66 | del ctx.input_params
67 | del output_tensors
68 | return (None, None) + input_grads
--------------------------------------------------------------------------------
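A minimal sketch (hypothetical module, not from the repo) of how `checkpoint()` is wrapped around a sub-module's forward pass, as the transformer blocks do when `use_checkpoint` is enabled:

```python
import torch
import torch.nn as nn
from mar3d.utils.checkpoint import checkpoint

class CheckpointedMLP(nn.Module):
    def __init__(self, width: int, use_checkpoint: bool = True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.net = nn.Sequential(nn.Linear(width, 4 * width), nn.GELU(), nn.Linear(4 * width, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # activations are recomputed in the backward pass when use_checkpoint is True
        return checkpoint(self.net, (x,), self.net.parameters(), self.use_checkpoint)
```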
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a Gaussian distribution discretizing to a
50 | given image.
51 | :param x: the target images. It is assumed that this was uint8 values,
52 | rescaled to the range [-1, 1].
53 | :param means: the Gaussian mean Tensor.
54 | :param log_scales: the Gaussian log stddev Tensor.
55 | :return: a tensor like x of log probabilities (in nats).
56 | """
57 | assert x.shape == means.shape == log_scales.shape
58 | centered_x = x - means
59 | inv_stdv = th.exp(-log_scales)
60 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
61 | cdf_plus = approx_standard_normal_cdf(plus_in)
62 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
63 | cdf_min = approx_standard_normal_cdf(min_in)
64 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
65 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
66 | cdf_delta = cdf_plus - cdf_min
67 | log_probs = th.where(
68 | x < -0.999,
69 | log_cdf_plus,
70 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
71 | )
72 | assert log_probs.shape == x.shape
73 | return log_probs
74 |
--------------------------------------------------------------------------------
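For reference, `normal_kl` above is the closed-form KL divergence between two diagonal Gaussians, written here in the same log-variance parameterization as the code ($v = \log \sigma^2$):

$$
\mathrm{KL}\!\left(\mathcal{N}(\mu_1, e^{v_1}) \,\|\, \mathcal{N}(\mu_2, e^{v_2})\right)
= \tfrac{1}{2}\left(v_2 - v_1 + e^{v_1 - v_2} + (\mu_1 - \mu_2)^2\, e^{-v_2} - 1\right)
$$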
/configs/mar-diffusion/mar.yaml:
--------------------------------------------------------------------------------
1 | exp_root_dir: "outputs"
2 | name: "image-to-shape-diffusion/mar"
3 | tag: "${rmspace:${system.shape_model_type}+n${data.n_samples}+noise${data.noise_sigma}+pfeat${system.shape_model.point_feats}+normemb${system.condition_model.normalize_embeds}+lr${system.optimizer.args.lr}+qkvbias${system.shape_model.qkv_bias}+nfreq${system.shape_model.num_freqs}+ln_post${system.shape_model.use_ln_post},_}"
4 | seed: 0
5 | data_type: "objaverse-datamodule"
6 | data:
7 | root_dir: [""] #sample points and sdf/occ.npz
8 | data_type: "sdf"
9 | n_samples: 20480
10 | noise_sigma: 0.
11 | load_supervision: false
12 | supervision_type: "occupancy"
13 | n_supervision: 4096
14 | load_image: True # whether to load images
15 | image_data_path: ""
16 | image_type: "rgb" # rgb, normal, mvrgb, mvnormal
17 | idx: [0,1,2,3,4,5,6,7]
18 | n_views: 1
19 | load_caption: false # whether to load captions
20 | batch_size: 12
21 | num_workers: 8
22 |
23 | system_type: "mar-diffusion-system"
24 | system:
25 | data_type: ${data_type}
26 | shape_model_type: "michelangelo-autoencoder"
27 | shape_model:
28 | num_latents: 256 #1024 #512 #256
29 | embed_dim: 64
30 | point_feats: 3
31 | out_dim: 1
32 | num_freqs: 8
33 | include_pi: false
34 | heads: 12
35 | width: 768
36 | num_encoder_layers: 16 #16 #8
37 | num_decoder_layers: 32 #32 #16
38 | use_ln_post: true
39 | init_scale: 0.25
40 | qkv_bias: true
41 | use_flash: true
42 | use_checkpoint: true
43 | use_downsample: false
44 |
45 | condition_model_type: "clip-embedder"
46 | condition_model:
47 | pretrained_model_name_or_path: "openai/clip-vit-large-patch14"
48 | encode_camera: false
49 | camera_embeds_dim: 32 # 16 * 2[sin, cos]
50 | n_views: ${data.n_views}
51 | empty_embeds_ratio: 0.1
52 | normalize_embeds: false
53 | dino: true
54 | zero_uncond_embeds: false
55 | loggers:
56 | wandb:
57 | enable: false
58 | project: "mar3d"
59 | name: image-to-shape-diffusion+${name}+${tag}
60 | loss:
61 | loss_type: "mse"
62 | lambda_diffusion: 1.
63 | optimizer:
64 | name: AdamW
65 | args:
66 | lr: 1.e-4
67 | betas: [0.9, 0.99]
68 | eps: 1.e-6
69 | scheduler:
70 | name: SequentialLR
71 | interval: step
72 | schedulers:
73 | - name: LinearLR
74 | interval: step
75 | args:
76 | start_factor: 1e-6
77 | end_factor: 1.0
78 | total_iters: 5000
79 | - name: CosineAnnealingLR
80 | interval: step
81 | args:
82 | T_max: 5000
83 | eta_min: 0.
84 | milestones: [5000]
85 |
86 | trainer:
87 | num_nodes: 1
88 | max_epochs: 100000
89 | log_every_n_steps: 1
90 | num_sanity_val_steps: 0
91 | check_val_every_n_epoch: 2
92 | enable_progress_bar: true
93 | precision: 16-mixed
94 | strategy: 'ddp_find_unused_parameters_true'
95 |
96 | checkpoint:
97 | save_last: true
98 | save_top_k: -1
99 | every_n_train_steps: 5000
--------------------------------------------------------------------------------
/mar3d/utils/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from mar3d.utils.config import parse_structured
7 | from mar3d.utils.misc import get_device, load_module_weights
8 | from mar3d.utils.typing import *
9 |
10 |
11 | class Configurable:
12 | @dataclass
13 | class Config:
14 | pass
15 |
16 | def __init__(self, cfg: Optional[dict] = None) -> None:
17 | super().__init__()
18 | self.cfg = parse_structured(self.Config, cfg)
19 |
20 |
21 | class Updateable:
22 | def do_update_step(
23 | self, epoch: int, global_step: int, on_load_weights: bool = False
24 | ):
25 | for attr in self.__dir__():
26 | if attr.startswith("_"):
27 | continue
28 | try:
29 | module = getattr(self, attr)
30 | except:
31 |                 continue  # ignore attributes like property, which can't be retrieved using getattr
32 | if isinstance(module, Updateable):
33 | module.do_update_step(
34 | epoch, global_step, on_load_weights=on_load_weights
35 | )
36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights)
37 |
38 | def do_update_step_end(self, epoch: int, global_step: int):
39 | for attr in self.__dir__():
40 | if attr.startswith("_"):
41 | continue
42 | try:
43 | module = getattr(self, attr)
44 | except:
45 |                 continue  # ignore attributes like property, which can't be retrieved using getattr
46 | if isinstance(module, Updateable):
47 | module.do_update_step_end(epoch, global_step)
48 | self.update_step_end(epoch, global_step)
49 |
50 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
51 | # override this method to implement custom update logic
52 | # if on_load_weights is True, you should be careful doing things related to model evaluations,
53 |         # as the models and tensors are not guaranteed to be on the same device
54 | pass
55 |
56 | def update_step_end(self, epoch: int, global_step: int):
57 | pass
58 |
59 |
60 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
61 | if isinstance(module, Updateable):
62 | module.do_update_step(epoch, global_step)
63 |
64 |
65 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
66 | if isinstance(module, Updateable):
67 | module.do_update_step_end(epoch, global_step)
68 |
69 |
70 | class BaseObject(Updateable):
71 | @dataclass
72 | class Config:
73 | pass
74 |
75 | cfg: Config # add this to every subclass of BaseObject to enable static type checking
76 |
77 | def __init__(
78 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
79 | ) -> None:
80 | super().__init__()
81 | self.cfg = parse_structured(self.Config, cfg)
82 | self.device = get_device()
83 | self.configure(*args, **kwargs)
84 |
85 | def configure(self, *args, **kwargs) -> None:
86 | pass
87 |
88 |
89 | class BaseModule(nn.Module, Updateable):
90 | @dataclass
91 | class Config:
92 | weights: Optional[str] = None
93 |
94 | cfg: Config # add this to every subclass of BaseModule to enable static type checking
95 |
96 | def __init__(
97 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
98 | ) -> None:
99 | super().__init__()
100 | self.cfg = parse_structured(self.Config, cfg)
101 | self.device = get_device()
102 | self.configure(*args, **kwargs)
103 | if self.cfg.weights is not None:
104 | # format: path/to/weights:module_name
105 | weights_path, module_name = self.cfg.weights.split(":")
106 | state_dict, epoch, global_step = load_module_weights(
107 | weights_path, module_name=module_name, map_location="cpu"
108 | )
109 | self.load_state_dict(state_dict)
110 | self.do_update_step(
111 | epoch, global_step, on_load_weights=True
112 | ) # restore states
113 | # dummy tensor to indicate model state
114 | self._dummy: Float[Tensor, "..."]
115 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False)
116 |
117 | def configure(self, *args, **kwargs) -> None:
118 | pass
119 |
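
A minimal sketch of how this Config/BaseModule pattern is typically used; the `Toy` module and its `hidden_dim` field below are hypothetical, not part of mar3d:

    from dataclasses import dataclass
    import torch.nn as nn
    from mar3d.utils.base import BaseModule

    class Toy(BaseModule):
        @dataclass
        class Config(BaseModule.Config):
            hidden_dim: int = 64  # hypothetical field; `weights` ("path/to/ckpt:module_name") is inherited

        cfg: Config  # re-declared for static type checking, as noted above

        def configure(self) -> None:
            # called once from BaseModule.__init__, after self.cfg has been parsed
            self.net = nn.Linear(self.cfg.hidden_dim, self.cfg.hidden_dim)

    toy = Toy({"hidden_dim": 128})  # plain dicts (e.g. YAML sub-configs) are parsed via parse_structured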
--------------------------------------------------------------------------------
/mar3d/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 | from bisect import bisect_right
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.optim import lr_scheduler
8 |
9 | import mar3d
10 |
11 |
12 | def get_scheduler(name):
13 | if hasattr(lr_scheduler, name):
14 | return getattr(lr_scheduler, name)
15 | else:
16 | raise NotImplementedError
17 |
18 |
19 | def getattr_recursive(m, attr):
20 | for name in attr.split("."):
21 | m = getattr(m, name)
22 | return m
23 |
24 |
25 | def get_parameters(model, name):
26 | module = getattr_recursive(model, name)
27 | if isinstance(module, nn.Module):
28 | return module.parameters()
29 | elif isinstance(module, nn.Parameter):
30 | return module
31 | return []
32 | from itertools import chain
33 |
34 | def parse_optimizer_control(config, model):
35 | # collect only the cross-attention parameters of the denoiser blocks,
36 | # chaining them together across every transformer block
37 | params = model.denoiser_model.blocks[0].cross.parameters()
38 | for num in range(len(model.denoiser_model.blocks)):
39 | if num > 0:
40 | params = chain(params, model.denoiser_model.blocks[num].cross.parameters())
41 |
42 | # if hasattr(config, "params"):
43 | # params = [
44 | # {"params": get_parameters(model, name), "name": name, **args}
45 | # for name, args in config.params.items()
46 | # ]
47 | # mar3d.debug(f"Specify optimizer params: {config.params}")
48 | # else:
49 | # params = model.parameters()
50 | if config.name in ["FusedAdam"]:
51 | import apex
52 | optim = getattr(apex.optimizers, config.name)(params, **config.args)
53 | elif config.name in ["Adan"]:
54 | from mar3d.systems import optimizers
55 |
56 | optim = getattr(optimizers, config.name)(params, **config.args)
57 | else:
58 | optim = getattr(torch.optim, config.name)(params, **config.args)
59 | return optim
60 |
61 | def parse_optimizer(config, model):
62 | if hasattr(config, "params"):
63 | params = [
64 | {"params": get_parameters(model, name), "name": name, **args}
65 | for name, args in config.params.items()
66 | ]
67 | mar3d.debug(f"Specify optimizer params: {config.params}")
68 | else:
69 | # params = model.parameters()
70 | params = [p for p in model.parameters() if p.requires_grad]
71 | # if config.name in ["FusedAdam"]:
72 | # import apex
73 |
74 | # optim = getattr(apex.optimizers, config.name)(params, **config.args)
75 | # elif config.name in ["Adan"]:
76 | # from mar3d.systems import optimizers
77 |
78 | # optim = getattr(optimizers, config.name)(params, **config.args)
79 | # else:
80 | optim = getattr(torch.optim, config.name)(params, **config.args)
81 | return optim
82 |
83 |
84 | def parse_scheduler_to_instance(config, optimizer):
85 | if config.name == "ChainedScheduler":
86 | schedulers = [
87 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
88 | ]
89 | scheduler = lr_scheduler.ChainedScheduler(schedulers)
90 | elif config.name == "Sequential":
91 | schedulers = [
92 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
93 | ]
94 | scheduler = lr_scheduler.SequentialLR(
95 | optimizer, schedulers, milestones=config.milestones
96 | )
97 | else:
98 | scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args)
99 | return scheduler
100 |
101 |
102 | def parse_scheduler(config, optimizer):
103 | interval = config.get("interval", "epoch")
104 | assert interval in ["epoch", "step"]
105 | if config.name == "SequentialLR":
106 | scheduler = {
107 | "scheduler": lr_scheduler.SequentialLR(
108 | optimizer,
109 | [
110 | parse_scheduler(conf, optimizer)["scheduler"]
111 | for conf in config.schedulers
112 | ],
113 | milestones=config.milestones,
114 | ),
115 | "interval": interval,
116 | }
117 | elif config.name == "ChainedScheduler":
118 | scheduler = {
119 | "scheduler": lr_scheduler.ChainedScheduler(
120 | [
121 | parse_scheduler(conf, optimizer)["scheduler"]
122 | for conf in config.schedulers
123 | ]
124 | ),
125 | "interval": interval,
126 | }
127 | else:
128 | scheduler = {
129 | "scheduler": get_scheduler(config.name)(optimizer, **config.args),
130 | "interval": interval,
131 | }
132 | return scheduler
133 |
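
A hedged sketch of how the `optimizer`/`scheduler` blocks of configs/mar-diffusion/mar.yaml (shown earlier) flow through parse_optimizer and parse_scheduler; the stand-in `nn.Linear` model is illustrative only:

    import torch
    from omegaconf import OmegaConf
    from mar3d.utils.scheduler import parse_optimizer, parse_scheduler

    model = torch.nn.Linear(8, 8)  # stand-in for the LightningModule's parameters

    cfg = OmegaConf.create({
        "optimizer": {
            "name": "AdamW",
            "args": {"lr": 1e-4, "betas": [0.9, 0.99], "eps": 1e-6},
        },
        "scheduler": {
            "name": "SequentialLR",
            "interval": "step",
            "milestones": [5000],
            "schedulers": [
                {"name": "LinearLR", "interval": "step",
                 "args": {"start_factor": 1e-6, "end_factor": 1.0, "total_iters": 5000}},
                {"name": "CosineAnnealingLR", "interval": "step",
                 "args": {"T_max": 5000, "eta_min": 0.0}},
            ],
        },
    })

    optim = parse_optimizer(cfg.optimizer, model)
    sched = parse_scheduler(cfg.scheduler, optim)
    # -> {"scheduler": SequentialLR(...), "interval": "step"}

This mirrors the warmup-then-cosine schedule configured above: LinearLR ramps the learning rate over the first 5000 steps, after which the milestone hands control to CosineAnnealingLR.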
--------------------------------------------------------------------------------
/mar3d/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from datetime import datetime
4 |
5 | from omegaconf import OmegaConf
6 |
7 | import mar3d
8 | from mar3d.utils.typing import *
9 |
10 | # ============ Register OmegaConf Resolvers ============= #
11 | OmegaConf.register_new_resolver(
12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
13 | )
14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b)
15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b)
18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)
19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: str(s).replace(" ", sub))
21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])
22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0)
24 | OmegaConf.register_new_resolver("not", lambda s: not s)
25 | OmegaConf.register_new_resolver(
26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0
27 | )
28 | # ======================================================= #
29 |
30 |
31 | def C_max(value: Any) -> float:
32 | if isinstance(value, int) or isinstance(value, float):
33 | pass
34 | else:
35 | value = config_to_primitive(value)
36 | if not isinstance(value, list):
37 | raise TypeError(f"Scalar specification only supports list, got {type(value)}")
38 | if len(value) >= 6:
39 | max_value = value[2]
40 | for i in range(4, len(value), 2):
41 | max_value = max(max_value, value[i])
42 | value = [value[0], value[1], max_value, value[3]]
43 | if len(value) == 3:
44 | value = [0] + value
45 | assert len(value) == 4
46 | start_step, start_value, end_value, end_step = value
47 | value = max(start_value, end_value)
48 | return value
49 |
50 |
51 | @dataclass
52 | class ExperimentConfig:
53 | name: str = "default"
54 | description: str = ""
55 | tag: str = ""
56 | seed: int = 0
57 | use_timestamp: bool = True
58 | timestamp: Optional[str] = None
59 | exp_root_dir: str = "outputs"
60 |
61 | ### these shouldn't be set manually
62 | exp_dir: str = "outputs/default"
63 | trial_name: str = "exp"
64 | trial_dir: str = "outputs/default/exp"
65 | n_gpus: int = 1
66 | ###
67 |
68 | resume: Optional[str] = None
69 |
70 | data_type: str = ""
71 | data: dict = field(default_factory=dict)
72 |
73 | system_type: str = ""
74 | system: dict = field(default_factory=dict)
75 |
76 | # accept pytorch-lightning trainer parameters
77 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
78 | trainer: dict = field(default_factory=dict)
79 |
80 | # accept pytorch-lightning checkpoint callback parameters
81 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
82 | checkpoint: dict = field(default_factory=dict)
83 |
84 | def __post_init__(self):
85 | if not self.tag and not self.use_timestamp:
86 | raise ValueError("Either tag is specified or use_timestamp is True.")
87 | self.trial_name = self.tag
88 | # if resume from an existing config, self.timestamp should not be None
89 | if self.timestamp is None:
90 | self.timestamp = ""
91 | if self.use_timestamp:
92 | if self.n_gpus > 1:
93 | mar3d.warn(
94 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag."
95 | )
96 | else:
97 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
98 | self.trial_name += self.timestamp
99 | self.exp_dir = os.path.join(self.exp_root_dir, self.name)
100 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name)
101 | os.makedirs(self.trial_dir, exist_ok=True)
102 |
103 |
104 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any:
105 | if from_string:
106 | yaml_confs = [OmegaConf.create(s) for s in yamls]
107 | else:
108 | yaml_confs = [OmegaConf.load(f) for f in yamls]
109 | cli_conf = OmegaConf.from_cli(cli_args)
110 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
111 | OmegaConf.resolve(cfg)
112 | assert isinstance(cfg, DictConfig)
113 | scfg = parse_structured(ExperimentConfig, cfg)
114 | return scfg
115 |
116 |
117 | def config_to_primitive(config, resolve: bool = True) -> Any:
118 | return OmegaConf.to_container(config, resolve=resolve)
119 |
120 |
121 | def dump_config(path: str, config) -> None:
122 | with open(path, "w") as fp:
123 | OmegaConf.save(config=config, f=fp)
124 |
125 |
126 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
127 | scfg = OmegaConf.structured(fields(**cfg))
128 | return scfg
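
A small sketch of how a launch script might call load_config; the CLI overrides are illustrative, and note that ExperimentConfig.__post_init__ creates the trial directory under exp_root_dir:

    from mar3d.utils.config import load_config

    cfg = load_config(
        "configs/mar-diffusion/mar.yaml",
        cli_args=["tag=debug", "trainer.max_epochs=10"],  # dotlist overrides, merged last
        n_gpus=1,
    )
    print(cfg.trial_dir)  # e.g. outputs/<name>/debug@YYYYMMDD-HHMMSS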
--------------------------------------------------------------------------------
/mar3d/utils/misc.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import re
4 |
5 | import torch
6 | import torch.distributed as dist
7 | from packaging import version
8 |
9 | from mar3d.utils.config import config_to_primitive
10 | from mar3d.utils.typing import *
11 |
12 |
13 |
14 | def parse_version(ver: str):
15 | return version.parse(ver)
16 |
17 |
18 | def get_rank():
19 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
20 | # therefore LOCAL_RANK needs to be checked first
21 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
22 | for key in rank_keys:
23 | rank = os.environ.get(key)
24 | if rank is not None:
25 | return int(rank)
26 | return 0
27 |
28 | def get_world_size():
29 | world_size_keys = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE")
30 | for key in world_size_keys:
31 | world_size = os.environ.get(key)
32 | if world_size is not None:
33 | return int(world_size)
34 | return 1
35 |
36 | def get_device():
37 | return torch.device(f"cuda:{get_rank()}")
38 |
39 |
40 | def load_module_weights(
41 | path, module_name=None, ignore_modules=None, map_location=None
42 | ) -> Tuple[dict, int, int]:
43 | if module_name is not None and ignore_modules is not None:
44 | raise ValueError("module_name and ignore_modules cannot be both set")
45 | if map_location is None:
46 | map_location = get_device()
47 |
48 | ckpt = torch.load(path, map_location=map_location)
49 | state_dict = ckpt["state_dict"]
50 | state_dict_to_load = state_dict
51 |
52 | if ignore_modules is not None:
53 | state_dict_to_load = {}
54 | for k, v in state_dict.items():
55 | ignore = any(
56 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules]
57 | )
58 | if ignore:
59 | continue
60 | state_dict_to_load[k] = v
61 |
62 | if module_name is not None:
63 | state_dict_to_load = {}
64 | for k, v in state_dict.items():
65 | m = re.match(rf"^{module_name}\.(.*)$", k)
66 | if m is None:
67 | continue
68 | state_dict_to_load[m.group(1)] = v
69 |
70 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
71 |
72 |
73 | def C(value: Any, epoch: int, global_step: int) -> float:
74 | if isinstance(value, int) or isinstance(value, float):
75 | pass
76 | else:
77 | value = config_to_primitive(value)
78 | if not isinstance(value, list):
79 | raise TypeError(f"Scalar specification only supports list, got {type(value)}")
80 | if len(value) == 3:
81 | value = [0] + value
82 | assert len(value) == 4
83 | start_step, start_value, end_value, end_step = value
84 | if isinstance(end_step, int):
85 | current_step = global_step
86 | value = start_value + (end_value - start_value) * max(
87 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
88 | )
89 | elif isinstance(end_step, float):
90 | current_step = epoch
91 | value = start_value + (end_value - start_value) * max(
92 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
93 | )
94 | return value
95 |
96 |
97 | def cleanup():
98 | gc.collect()
99 | torch.cuda.empty_cache()
100 | # tcnn.free_temporary_memory()  # tinycudann is not imported in this module; disabled so cleanup() does not raise a NameError
101 |
102 |
103 | def finish_with_cleanup(func: Callable):
104 | def wrapper(*args, **kwargs):
105 | out = func(*args, **kwargs)
106 | cleanup()
107 | return out
108 |
109 | return wrapper
110 |
111 |
112 | def _distributed_available():
113 | return torch.distributed.is_available() and torch.distributed.is_initialized()
114 |
115 |
116 | def barrier():
117 | if not _distributed_available():
118 | return
119 | else:
120 | torch.distributed.barrier()
121 |
122 |
123 | def broadcast(tensor, src=0):
124 | if not _distributed_available():
125 | return tensor
126 | else:
127 | torch.distributed.broadcast(tensor, src=src)
128 | return tensor
129 |
130 |
131 | def enable_gradient(model, enabled: bool = True) -> None:
132 | for param in model.parameters():
133 | param.requires_grad_(enabled)
134 |
135 |
136 | def all_gather_batch(tensors):
137 | """
138 | Performs all_gather operation on the provided tensors.
139 | """
140 | # Queue the gathered tensors
141 | world_size = get_world_size()
142 | # There is no need for reduction in the single-proc case
143 | if world_size == 1:
144 | if isinstance(tensors, list):
145 | return tensors
146 | return tensors
147 | if not isinstance(tensors, list):
148 | is_list = False
149 | tensors = [tensors]
150 | else:
151 | is_list = True
152 | output_tensor = []
153 | tensor_list = []
154 | for tensor in tensors:
155 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
156 | dist.all_gather(
157 | tensor_all,
158 | tensor,
159 | async_op=False # block until the gather completes
160 | )
161 |
162 | tensor_list.append(tensor_all)
163 |
164 | for tensor_all in tensor_list:
165 | output_tensor.append(torch.cat(tensor_all, dim=0))
166 | if not is_list:
167 | return output_tensor[0]
168 | return output_tensor
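
A short sketch of the scalar-schedule format accepted by C(): a 4-element list [start_step, start_value, end_value, end_step] is interpolated linearly (by global step when end_step is an int, by epoch when it is a float), a 3-element list gets start_step = 0, and plain numbers pass through unchanged:

    from mar3d.utils.misc import C

    C([0, 0.0, 1.0, 1000], epoch=0, global_step=500)   # -> 0.5  (step-based ramp)
    C([0.0, 1.0, 1000], epoch=0, global_step=250)      # -> 0.25 (start_step defaults to 0)
    C(0.1, epoch=3, global_step=1234)                  # -> 0.1  (constants are returned as-is)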
--------------------------------------------------------------------------------
/mar3d/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 |
5 | import pytorch_lightning
6 |
7 | from mar3d.utils.config import dump_config
8 | from mar3d.utils.misc import parse_version
9 |
10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
11 | from pytorch_lightning.callbacks import Callback
12 | else:
13 | from pytorch_lightning.callbacks.base import Callback
14 |
15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar
16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
17 |
18 |
19 | class VersionedCallback(Callback):
20 | def __init__(self, save_root, version=None, use_version=True):
21 | self.save_root = save_root
22 | self._version = version
23 | self.use_version = use_version
24 |
25 | @property
26 | def version(self) -> int:
27 | """Get the experiment version.
28 |
29 | Returns:
30 | The experiment version if specified else the next version.
31 | """
32 | if self._version is None:
33 | self._version = self._get_next_version()
34 | return self._version
35 |
36 | def _get_next_version(self):
37 | existing_versions = []
38 | if os.path.isdir(self.save_root):
39 | for f in os.listdir(self.save_root):
40 | bn = os.path.basename(f)
41 | if bn.startswith("version_"):
42 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
43 | existing_versions.append(int(dir_ver))
44 | if len(existing_versions) == 0:
45 | return 0
46 | return max(existing_versions) + 1
47 |
48 | @property
49 | def savedir(self):
50 | if not self.use_version:
51 | return self.save_root
52 | return os.path.join(
53 | self.save_root,
54 | self.version
55 | if isinstance(self.version, str)
56 | else f"version_{self.version}",
57 | )
58 |
59 |
60 | class CodeSnapshotCallback(VersionedCallback):
61 | def __init__(self, save_root, version=None, use_version=True):
62 | super().__init__(save_root, version, use_version)
63 |
64 | def get_file_list(self):
65 | return [
66 | b.decode()
67 | for b in set(
68 | subprocess.check_output(
69 | 'git ls-files -- ":!:load/*"', shell=True
70 | ).splitlines()
71 | )
72 | | set( # hard code, TODO: use config to exclude folders or files
73 | subprocess.check_output(
74 | "git ls-files --others --exclude-standard", shell=True
75 | ).splitlines()
76 | )
77 | ]
78 |
79 | @rank_zero_only
80 | def save_code_snapshot(self):
81 | os.makedirs(self.savedir, exist_ok=True)
82 | for f in self.get_file_list():
83 | if not os.path.exists(f) or os.path.isdir(f):
84 | continue
85 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
86 | shutil.copyfile(f, os.path.join(self.savedir, f))
87 |
88 | def on_fit_start(self, trainer, pl_module):
89 | try:
90 | self.save_code_snapshot()
91 | except Exception:
92 | rank_zero_warn(
93 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
94 | )
95 |
96 |
97 | class ConfigSnapshotCallback(VersionedCallback):
98 | def __init__(self, config_path, config, save_root, version=None, use_version=True):
99 | super().__init__(save_root, version, use_version)
100 | self.config_path = config_path
101 | self.config = config
102 |
103 | @rank_zero_only
104 | def save_config_snapshot(self):
105 | os.makedirs(self.savedir, exist_ok=True)
106 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config)
107 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml"))
108 |
109 | def on_fit_start(self, trainer, pl_module):
110 | self.save_config_snapshot()
111 |
112 |
113 | class CustomProgressBar(TQDMProgressBar):
114 | def get_metrics(self, *args, **kwargs):
115 | # don't show the version number
116 | items = super().get_metrics(*args, **kwargs)
117 | items.pop("v_num", None)
118 | return items
119 |
120 |
121 | class ProgressCallback(Callback):
122 | def __init__(self, save_path):
123 | super().__init__()
124 | self.save_path = save_path
125 | self._file_handle = None
126 |
127 | @property
128 | def file_handle(self):
129 | if self._file_handle is None:
130 | self._file_handle = open(self.save_path, "w")
131 | return self._file_handle
132 |
133 | @rank_zero_only
134 | def write(self, msg: str) -> None:
135 | self.file_handle.seek(0)
136 | self.file_handle.truncate()
137 | self.file_handle.write(msg)
138 | self.file_handle.flush()
139 |
140 | @rank_zero_only
141 | def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
142 | self.write(
143 | f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
144 | )
145 |
146 | @rank_zero_only
147 | def on_validation_start(self, trainer, pl_module):
148 | self.write(f"Rendering validation image ...")
149 |
150 | @rank_zero_only
151 | def on_test_start(self, trainer, pl_module):
152 | self.write(f"Rendering video ...")
153 |
154 | @rank_zero_only
155 | def on_predict_start(self, trainer, pl_module):
156 | self.write(f"Exporting mesh assets ...")
157 |
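
A hypothetical wiring sketch showing where these callbacks plug into a Lightning Trainer; the output paths below are illustrative:

    import pytorch_lightning as pl
    from omegaconf import OmegaConf
    from mar3d.utils.callbacks import (
        CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar,
    )

    cfg = OmegaConf.load("configs/mar-diffusion/mar.yaml")  # raw, unresolved config
    trainer = pl.Trainer(
        callbacks=[
            CodeSnapshotCallback("outputs/demo/code"),       # copies git-tracked sources on fit start
            ConfigSnapshotCallback("configs/mar-diffusion/mar.yaml", cfg, "outputs/demo/configs"),
            CustomProgressBar(),                             # progress bar without the v_num column
        ],
    )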
--------------------------------------------------------------------------------
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there are 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
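
A quick sketch of space_timesteps matching its docstring; the numbers are just examples:

    from diffusion.respace import space_timesteps

    # 300 original steps, three equal sections resampled to 10 / 15 / 20 steps
    steps = space_timesteps(300, [10, 15, 20])
    assert len(steps) == 45 and max(steps) < 300

    # DDIM-style fixed striding: 25 evenly spaced steps out of 1000 (stride 40)
    ddim_steps = space_timesteps(1000, "ddim25")
    assert len(ddim_steps) == 25

The resulting set is what SpacedDiffusion consumes as use_timesteps to build its shortened beta schedule.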
--------------------------------------------------------------------------------
/mar3d/utils/ops.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import mar3d
10 | from mar3d.utils.typing import *
11 |
12 |
13 | def dot(x, y):
14 | return torch.sum(x * y, -1, keepdim=True)
15 |
16 |
17 | def reflect(x, n):
18 | return 2 * dot(x, n) * n - x
19 |
20 |
21 | ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
22 |
23 |
24 | def scale_tensor(
25 | dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
26 | ):
27 | if inp_scale is None:
28 | inp_scale = (0, 1)
29 | if tgt_scale is None:
30 | tgt_scale = (0, 1)
31 | if isinstance(tgt_scale, Tensor):
32 | assert dat.shape[-1] == tgt_scale.shape[-1]
33 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
34 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
35 | return dat
36 |
37 |
38 | def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
39 | if chunk_size <= 0:
40 | return func(*args, **kwargs)
41 | B = None
42 | for arg in list(args) + list(kwargs.values()):
43 | if isinstance(arg, torch.Tensor):
44 | B = arg.shape[0]
45 | break
46 | assert (
47 | B is not None
48 | ), "No tensor found in args or kwargs, cannot determine batch size."
49 | out = defaultdict(list)
50 | out_type = None
51 | # max(1, B) to support B == 0
52 | for i in range(0, max(1, B), chunk_size):
53 | out_chunk = func(
54 | *[
55 | arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
56 | for arg in args
57 | ],
58 | **{
59 | k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
60 | for k, arg in kwargs.items()
61 | },
62 | )
63 | if out_chunk is None:
64 | continue
65 | out_type = type(out_chunk)
66 | if isinstance(out_chunk, torch.Tensor):
67 | out_chunk = {0: out_chunk}
68 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
69 | chunk_length = len(out_chunk)
70 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
71 | elif isinstance(out_chunk, dict):
72 | pass
73 | else:
74 | print(
75 | f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
76 | )
77 | exit(1)
78 | for k, v in out_chunk.items():
79 | v = v if torch.is_grad_enabled() else v.detach()
80 | out[k].append(v)
81 |
82 | if out_type is None:
83 | return None
84 |
85 | out_merged: Dict[Any, Optional[torch.Tensor]] = {}
86 | for k, v in out.items():
87 | if all([vv is None for vv in v]):
88 | # allow None in return value
89 | out_merged[k] = None
90 | elif all([isinstance(vv, torch.Tensor) for vv in v]):
91 | out_merged[k] = torch.cat(v, dim=0)
92 | else:
93 | raise TypeError(
94 | f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
95 | )
96 |
97 | if out_type is torch.Tensor:
98 | return out_merged[0]
99 | elif out_type in [tuple, list]:
100 | return out_type([out_merged[i] for i in range(chunk_length)])
101 | elif out_type is dict:
102 | return out_merged
103 |
104 |
105 | def randn_tensor(
106 | shape: Union[Tuple, List],
107 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
108 | device: Optional["torch.device"] = None,
109 | dtype: Optional["torch.dtype"] = None,
110 | layout: Optional["torch.layout"] = None,
111 | ):
112 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
113 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
114 | is always created on the CPU.
115 | """
116 | # device on which tensor is created defaults to device
117 | rand_device = device
118 | batch_size = shape[0]
119 |
120 | layout = layout or torch.strided
121 | device = device or torch.device("cpu")
122 |
123 | if generator is not None:
124 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
125 | if gen_device_type != device.type and gen_device_type == "cpu":
126 | rand_device = "cpu"
127 | if device != "mps":
128 | mar3d.info(
129 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
130 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
131 | f" slightly speed up this function by passing a generator that was created on the {device} device."
132 | )
133 | elif gen_device_type != device.type and gen_device_type == "cuda":
134 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
135 |
136 | # make sure generator list of length 1 is treated like a non-list
137 | if isinstance(generator, list) and len(generator) == 1:
138 | generator = generator[0]
139 |
140 | if isinstance(generator, list):
141 | shape = (1,) + shape[1:]
142 | latents = [
143 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
144 | for i in range(batch_size)
145 | ]
146 | latents = torch.cat(latents, dim=0).to(device)
147 | else:
148 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
149 |
150 | return latents
151 |
152 |
153 | def generate_dense_grid_points(
154 | bbox_min: np.ndarray,
155 | bbox_max: np.ndarray,
156 | octree_depth: int,
157 | indexing: str = "ij"
158 | ):
159 | length = bbox_max - bbox_min
160 | num_cells = np.exp2(octree_depth)
161 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
162 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
163 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
164 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
165 | xyz = np.stack((xs, ys, zs), axis=-1)
166 | xyz = xyz.reshape(-1, 3)
167 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
168 |
169 | return xyz, grid_size, length
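
A minimal sketch combining generate_dense_grid_points and chunk_batch; the toy occupancy function is illustrative, not part of mar3d:

    import numpy as np
    import torch
    from mar3d.utils.ops import chunk_batch, generate_dense_grid_points

    # a (2^7 + 1)^3 query grid over the [-1, 1]^3 box
    xyz, grid_size, length = generate_dense_grid_points(
        np.array([-1.0, -1.0, -1.0]), np.array([1.0, 1.0, 1.0]), octree_depth=7
    )
    assert grid_size == [129, 129, 129] and xyz.shape == (129 ** 3, 3)

    # evaluate a (potentially memory-hungry) field function in chunks of 65536 points
    occ = chunk_batch(
        lambda p: (p.norm(dim=-1, keepdim=True) < 0.5).float(),  # toy occupancy of a sphere
        65536,
        torch.from_numpy(xyz),
    )
    assert occ.shape == (129 ** 3, 1)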
--------------------------------------------------------------------------------
/mar3d/models/conditional_encoders/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from PIL import Image
6 | from dataclasses import dataclass
7 | from torchvision.transforms import Normalize
8 | from torchvision.transforms import InterpolationMode
9 | from torchvision.transforms.transforms import _interpolation_modes_from_int
10 |
11 | from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor
12 | from transformers.utils import ModelOutput
13 | from typing import Iterable, Optional, Union, List
14 |
15 | import mar3d
16 | from mar3d.utils.base import BaseModule
17 | from mar3d.utils.typing import *
18 |
19 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
20 |
21 |
22 | class BaseEmbedder(BaseModule):
23 | @dataclass
24 | class Config(BaseModule.Config):
25 | pretrained_model_name_or_path: Optional[str] = None # the pretrained model name or path
26 |
27 | encode_camera: bool = False # whether to encode camera
28 | camera_embeds_type: str = "sincos" # the type of camera embeds
29 | camera_embeds_dim: Optional[int] = None # the dimension of camera embeds
30 | n_views: int = 1 # the number of views
31 |
32 | dino: bool = False
33 | nor: bool = False
34 | empty_embeds_ratio: float = 0.1 # the ratio of empty embeds
35 | zero_uncond_embeds: bool = True
36 | bbox: bool = False
37 | normalize_embeds: bool = False # whether to normalize the embeds
38 |
39 | cfg: Config
40 |
41 | def configure(self) -> None:
42 | super().configure()
43 |
44 | if self.cfg.encode_camera:
45 | self.distance = 1.0
46 | self.register_buffer(
47 | "cameras",
48 | torch.as_tensor([
49 | [[1, 0, 0, 0],
50 | [0, 0, -1, -self.distance],
51 | [0, 1, 0, 0],
52 | [0, 0, 0, 1]], # front to back
53 |
54 | [[0, 0, 1, self.distance],
55 | [1, 0, 0, 0],
56 | [0, 1, 0, 0],
57 | [0, 0, 0, 1]], # right to left
58 |
59 | [[-1, 0, 0, 0],
60 | [0, 0, 1, self.distance],
61 | [0, 1, 0, 0],
62 | [0, 0, 0, 1]], # back to front
63 |
64 | [[0, 0, -1, -self.distance],
65 | [-1, 0, 0, 0],
66 | [0, 1, 0, 0],
67 | [0, 0, 0, 1]], # left to right
68 | ], dtype=torch.float32),
69 | )
70 |
71 | def encode_image(self, images: Iterable[Optional[ImageType]], camera_embeds: Optional[torch.Tensor] = None, **kwargs) -> torch.FloatTensor:
72 | pass
73 |
74 | def encode_text(self, texts: List[str], **kwargs) -> torch.FloatTensor:
75 | pass
76 |
77 | def encode_camera(self, c2ws: torch.Tensor):
78 | if self.cfg.camera_embeds_type == "sincos":
79 | assert c2ws.shape[-1] == 4 and c2ws.shape[-2] == 4, f"Invalid c2ws shape: {c2ws.shape}"
80 | c2ws = c2ws.view(-1, 16)
81 | return torch.cat([torch.sin(c2ws), torch.cos(c2ws)], dim=-1)
82 | else:
83 | raise NotImplementedError(f"Unknown camera_embeds_type: {self.cfg.camera_embeds_type}")
84 |
85 | def post_process_embeds(self, text_embeds, visual_embeds):
86 | bs = text_embeds.shape[0] if text_embeds is not None else visual_embeds.shape[0]
87 |
88 | if self.cfg.normalize_embeds:
89 | # post-process the text/visual embeds
90 | if text_embeds is not None:
91 | text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
92 | if visual_embeds is not None:
93 | visual_embeds = visual_embeds / visual_embeds.norm(dim=-1, keepdim=True)
94 |
95 | assert text_embeds is not None or visual_embeds is not None
96 |
97 | # return text_embeds, visual_embeds
98 | if text_embeds is not None and visual_embeds is not None:
99 | return [text_embeds, visual_embeds]
100 | # return torch.cat([text_embeds, visual_embeds], dim=1)
101 | elif text_embeds is not None:
102 | return text_embeds
103 | else:
104 | return visual_embeds
105 |
106 | def forward(self, batch):
107 | bs = batch["surface"].shape[0]
108 |
109 | text_embeds, visual_embeds = None, None
110 |
111 | bbox_drop=False
112 | if random.random() < self.cfg.empty_embeds_ratio:
113 | if "text_input_ids" in batch or "text_embeds" in batch:
114 | if self.empty_text_embeds is None:
115 | if not self.cfg.zero_uncond_embeds:
116 | self.empty_text_embeds = self.encode_text([""]).detach() # [1, 77, 768]
117 | text_embeds = self.empty_text_embeds.repeat(bs, 1, 1)
118 | if "image" in batch or "image_embeds" in batch:
119 | visual_embeds = self.empty_image_embeds.repeat(bs, 1, 1)
120 | if "normal" in batch:
121 | visual_embeds_normal=self.empty_image_embeds.repeat(bs, 1, 1)
122 | if self.cfg.bbox:
123 | bbox_drop=True
124 |
125 |
126 | elif "mvimages" in batch or "mvimage_embeds" in batch:
127 | visual_embeds = self.empty_image_embeds.unsqueeze(1).repeat(bs, 1, 1, 1)
128 | else:
129 | # for text inputs
130 | if "text_input_ids" in batch:
131 | # import ipdb
132 | # ipdb.set_trace()
133 | text_embeds = self.encode_text(batch["text_input_ids"])
134 |
135 | # for visual inputs
136 | if "image" in batch:
137 | if self.cfg.encode_camera:
138 | visual_embeds = self.encode_image(batch["image"], cameras=batch["c2w"])
139 | else:
140 | visual_embeds = self.encode_image(batch["image"])
141 | if "normal" in batch:
142 | # if self.cfg.encode_camera:
143 | # visual_embeds = self.encode_image(batch["image"], cameras=batch["c2w"])
144 | # else:
145 |
146 | visual_embeds_normal = self.encode_image(batch["normal"])
147 |
148 | elif "mvimages" in batch:
149 |
150 | n_views = batch["mvimages"].shape[1]
151 | if self.cfg.encode_camera:
152 | visual_embeds = self.encode_image(
153 | batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:]), \
154 | cameras=batch["c2ws"]).view(bs, n_views, *self.empty_image_embeds.shape[-2:])
155 | else:
156 | visual_embeds = self.encode_image(
157 | batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:])).view(bs, n_views, *self.empty_image_embeds.shape[-2:])
158 |
159 | return self.post_process_embeds(text_embeds, visual_embeds)
160 |
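
For reference, the "sincos" camera embedding above just flattens each 4x4 c2w matrix and concatenates its sine and cosine, giving 32 features per view; a standalone restatement of that branch (illustrative, using an identity pose):

    import torch

    c2ws = torch.eye(4).unsqueeze(0)                      # (1, 4, 4) hypothetical pose
    flat = c2ws.view(-1, 16)
    embed = torch.cat([torch.sin(flat), torch.cos(flat)], dim=-1)
    print(embed.shape)                                    # torch.Size([1, 32])

Presumably this is the value camera_embeds_dim should be set to when encode_camera is enabled.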
--------------------------------------------------------------------------------
/mar3d/models/geometry/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import mar3d
9 | from .utils import (
10 | Mesh,
11 | IsosurfaceHelper,
12 | MarchingCubeCPUHelper,
13 | MarchingTetrahedraHelper,
14 | )
15 |
16 | from mar3d.utils.base import BaseModule
17 | from mar3d.utils.ops import chunk_batch, scale_tensor
18 | from mar3d.utils.typing import *
19 |
20 | class BaseGeometry(BaseModule):
21 | @dataclass
22 | class Config(BaseModule.Config):
23 | pass
24 |
25 | cfg: Config
26 |
27 | @staticmethod
28 | def create_from(
29 | other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs
30 | ) -> "BaseGeometry":
31 | raise TypeError(
32 | f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}"
33 | )
34 |
35 | def export(self, *args, **kwargs) -> Dict[str, Any]:
36 | return {}
37 |
38 |
39 | class BaseImplicitGeometry(BaseGeometry):
40 | @dataclass
41 | class Config(BaseGeometry.Config):
42 | radius: float = 1.0
43 | isosurface: bool = True
44 | isosurface_method: str = "mt"
45 | isosurface_resolution: int = 128
46 | isosurface_threshold: Union[float, str] = 0.0
47 | isosurface_chunk: int = 0
48 | isosurface_coarse_to_fine: bool = True
49 | isosurface_deformable_grid: bool = False
50 | isosurface_remove_outliers: bool = True
51 | isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01
52 |
53 | cfg: Config
54 |
55 | def configure(self) -> None:
56 | self.bbox: Float[Tensor, "2 3"]
57 | self.register_buffer(
58 | "bbox",
59 | torch.as_tensor(
60 | [
61 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
62 | [self.cfg.radius, self.cfg.radius, self.cfg.radius],
63 | ],
64 | dtype=torch.float32,
65 | ),
66 | )
67 | self.isosurface_helper: Optional[IsosurfaceHelper] = None
68 | self.unbounded: bool = False
69 |
70 | def _initilize_isosurface_helper(self):
71 | if self.cfg.isosurface and self.isosurface_helper is None:
72 | if self.cfg.isosurface_method == "mc-cpu":
73 | self.isosurface_helper = MarchingCubeCPUHelper(
74 | self.cfg.isosurface_resolution
75 | ).to(self.device)
76 | elif self.cfg.isosurface_method == "mt":
77 | self.isosurface_helper = MarchingTetrahedraHelper(
78 | self.cfg.isosurface_resolution,
79 | f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
80 | ).to(self.device)
81 | else:
82 | raise AttributeError(
83 | "Unknown isosurface method {self.cfg.isosurface_method}"
84 | )
85 |
86 | def forward(
87 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False
88 | ) -> Dict[str, Float[Tensor, "..."]]:
89 | raise NotImplementedError
90 |
91 | def forward_field(
92 | self, points: Float[Tensor, "*N Di"]
93 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
94 | # return the value of the implicit field, could be density / signed distance
95 | # also return a deformation field if the grid vertices can be optimized
96 | raise NotImplementedError
97 |
98 | def forward_level(
99 | self, field: Float[Tensor, "*N 1"], threshold: float
100 | ) -> Float[Tensor, "*N 1"]:
101 | # return the value of the implicit field, where the zero level set represents the surface
102 | raise NotImplementedError
103 |
104 | def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
105 | def batch_func(x):
106 | # scale to bbox as the input vertices are in [0, 1]
107 | field, deformation = self.forward_field(
108 | scale_tensor(
109 | x.to(bbox.device), self.isosurface_helper.points_range, bbox
110 | ),
111 | )
112 | field = field.to(
113 | x.device
114 | ) # move to the same device as the input (could be CPU)
115 | if deformation is not None:
116 | deformation = deformation.to(x.device)
117 | return field, deformation
118 |
119 | assert self.isosurface_helper is not None
120 |
121 | field, deformation = chunk_batch(
122 | batch_func,
123 | self.cfg.isosurface_chunk,
124 | self.isosurface_helper.grid_vertices,
125 | )
126 |
127 | threshold: float
128 |
129 | if isinstance(self.cfg.isosurface_threshold, float):
130 | threshold = self.cfg.isosurface_threshold
131 | elif self.cfg.isosurface_threshold == "auto":
132 | eps = 1.0e-5
133 | threshold = field[field > eps].mean().item()
134 | mar3d.info(
135 | f"Automatically determined isosurface threshold: {threshold}"
136 | )
137 | else:
138 | raise TypeError(
139 | f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
140 | )
141 |
142 | level = self.forward_level(field, threshold)
143 | mesh: Mesh = self.isosurface_helper(level, deformation=deformation)
144 | mesh.v_pos = scale_tensor(
145 | mesh.v_pos, self.isosurface_helper.points_range, bbox
146 | ) # scale to bbox as the grid vertices are in [0, 1]
147 | mesh.add_extra("bbox", bbox)
148 |
149 | if self.cfg.isosurface_remove_outliers:
150 | # remove outliers components with small number of faces
151 | # only enabled when the mesh is not differentiable
152 | mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
153 |
154 | return mesh
155 |
156 | def isosurface(self) -> Mesh:
157 | if not self.cfg.isosurface:
158 | raise NotImplementedError(
159 | "Isosurface is not enabled in the current configuration"
160 | )
161 | self._initilize_isosurface_helper()
162 | if self.cfg.isosurface_coarse_to_fine:
163 | mar3d.debug("First run isosurface to get a tight bounding box ...")
164 | with torch.no_grad():
165 | mesh_coarse = self._isosurface(self.bbox)
166 | vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
167 | vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0])
168 | vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1])
169 | mar3d.debug("Run isosurface again with the tight bounding box ...")
170 | mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)
171 | else:
172 | mesh = self._isosurface(self.bbox)
173 | return mesh
174 |
175 |
176 | class BaseExplicitGeometry(BaseGeometry):
177 | @dataclass
178 | class Config(BaseGeometry.Config):
179 | radius: float = 1.0
180 |
181 | cfg: Config
182 |
183 | def configure(self) -> None:
184 | self.bbox: Float[Tensor, "2 3"]
185 | self.register_buffer(
186 | "bbox",
187 | torch.as_tensor(
188 | [
189 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
190 | [self.cfg.radius, self.cfg.radius, self.cfg.radius],
191 | ],
192 | dtype=torch.float32,
193 | ),
194 | )
--------------------------------------------------------------------------------
/mar3d/models/conditional_encoders/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch import nn
4 | import numpy as np
5 | from PIL import Image
6 | from einops import rearrange
7 | from dataclasses import dataclass
8 | from torchvision.transforms import Normalize
9 | from torchvision.transforms import InterpolationMode
10 | from torchvision.transforms.transforms import _interpolation_modes_from_int
11 | from torchvision import transforms
12 |
13 | from transformers import CLIPTokenizer, CLIPImageProcessor
14 | from transformers.utils import ModelOutput
15 | from typing import Iterable, Optional, Union, List
16 |
17 | import mar3d
18 | from mar3d.utils.typing import *
19 | from .clip.modeling_clip import CLIPModel
20 | from .clip.modeling_conditional_clip import ConditionalCLIPModel
21 | from .base import BaseEmbedder, ImageType
22 | @dataclass
23 | class CLIPEmbedOutput(ModelOutput):
24 | last_hidden_state: torch.FloatTensor = None
25 | pooler_output: torch.FloatTensor = None
26 | embeds: torch.FloatTensor = None
27 |
28 | @mar3d.register("clip-embedder")
29 | class CLIPEmbedder(BaseEmbedder):
30 |
31 | @dataclass
32 | class Config(BaseEmbedder.Config):
33 | freeze_modulation: bool = False
34 | config_path: str = ''
35 |
36 | cfg: Config
37 |
38 | def configure(self) -> None:
39 | super().configure()
40 | if not self.cfg.encode_camera:
41 | self.model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_model_name_or_path)
42 | else:
43 | if self.cfg.pretrained_model_name_or_path == '':
44 | assert self.cfg.config_path is not None, "The config path should be provided"
45 | conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path)
46 | conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
47 | self.model: CLIPModel = ConditionalCLIPModel(conditional_clip_config)
48 | else:
49 | conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
50 | self.cfg.pretrained_model_name_or_path,
51 | )
52 | conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
53 | self.model: CLIPModel = ConditionalCLIPModel.from_pretrained(
54 | self.cfg.pretrained_model_name_or_path,
55 | vision_config=conditional_clip_config.vision_config
56 | )
57 |
58 | self.tokenizer = None
59 | self.image_preprocess = CLIPImageProcessor()
60 | self.transform = transforms.Compose(
61 | [
62 | transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
63 | transforms.CenterCrop(224), # crop a (224, 224) square
64 | transforms.Normalize(
65 | mean=[0.48145466, 0.4578275, 0.40821073],
66 | std=[0.26862954, 0.26130258, 0.27577711],
67 | ),
68 | ]
69 | )
70 |
71 | self.logit_scale = self.model.logit_scale.exp()
72 |
73 | if self.cfg.zero_uncond_embeds:
74 | self.empty_text_embeds = torch.zeros((1, 77, 768)).detach()
75 | self.empty_image_embeds = torch.zeros((self.cfg.n_views, 257, 1024)).detach()
76 |
77 | else:
78 | try:
79 | self.empty_text_embeds = self.encode_text([""]).detach() # [1, 77, 768]
80 | except:
81 | self.empty_text_embeds = None
82 |
83 |
84 | if self.cfg.encode_camera:
85 | self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3), self.cameras[:self.cfg.n_views]).detach()
86 | else:
87 | self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3)).detach()
88 |
89 | # Freeze the model parameters
90 |
91 | self.model.eval()
92 | for k, p in self.model.named_parameters():
93 | ks = k.split('.')
94 | if ('mod_norm1' in ks or 'mod_norm2' in ks) and not self.cfg.freeze_modulation:
95 | p.requires_grad_(True)
96 | else:
97 | p.requires_grad_(False)
98 | # def encode_image_normal(self, images):
99 |
100 | # return normal_feats
101 | def encode_image(self, images: Iterable[Optional[ImageType]], cameras: Optional[torch.Tensor] = None, force_none_camera_embeds: bool = False, return_dict: bool = False, **kwargs) -> torch.FloatTensor:
102 | camera_embeds = None
103 |
104 | if isinstance(images, (np.ndarray, torch.Tensor)): # for training process
105 | assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]"
106 | do_rescale = False
107 | if self.cfg.encode_camera:
108 | assert cameras is not None, "The cameras should be provided"
109 | camera_embeds = self.encode_camera(cameras)
110 | pixel_values = self.transform(images.permute(0, 3, 1, 2))
111 | else: # for inference process
112 | do_rescale = True
113 | if self.cfg.encode_camera:
114 | if cameras is None:
115 | bs = len(images) // self.cfg.n_views
116 | cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.model.device)
117 | camera_embeds = self.encode_camera(cameras)
118 | pixel_values = self.image_preprocess.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values
119 |
120 | if force_none_camera_embeds:
121 | camera_embeds = None
122 |
123 |
124 | packed = False
125 |
126 |
127 | if pixel_values.ndim == 4:
128 | packed = True
129 | pixel_values = pixel_values.unsqueeze(1)
130 | if camera_embeds is not None:
131 | camera_embeds = camera_embeds.unsqueeze(1)
132 |
133 | if self.cfg.encode_camera and camera_embeds is not None:
134 | vision_outputs = self.model.vision_model(
135 | pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
136 | condition=rearrange(camera_embeds, "B N C -> (B N) C")
137 | )
138 | else:
139 | vision_outputs = self.model.vision_model(
140 | pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
141 | )
142 |
143 |
144 |
145 | if return_dict:
146 | pooler_output = vision_outputs[1] # pooled_output
147 | image_features = self.model.visual_projection(pooler_output)
148 | return CLIPEmbedOutput(
149 | last_hidden_state=vision_outputs.last_hidden_state,
150 | pooler_output=pooler_output,
151 | embeds=image_features
152 | )
153 |
154 | else:
155 | return vision_outputs.last_hidden_state
156 |
157 |
158 | @torch.no_grad()
159 | def encode_text(self, text_inputs: torch.Tensor, return_dict: bool = False) -> torch.FloatTensor:
160 | if self.tokenizer is None:
161 | self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.pretrained_model_name_or_path)
162 |
163 | if isinstance(text_inputs, list):
164 | text_inputs = self.tokenizer(
165 | text_inputs,
166 | max_length=self.tokenizer.model_max_length,
167 | padding="max_length",
168 | return_tensors="pt"
169 | ).input_ids
170 | text_outputs = self.model.text_model(input_ids=text_inputs.to(self.model.device))
171 |
172 | pooler_output = text_outputs[1] # pooled_output
173 | text_features = self.model.text_projection(pooler_output)
174 |
175 | if return_dict:
176 | return CLIPEmbedOutput(
177 | last_hidden_state=text_outputs.last_hidden_state,
178 | pooler_output=pooler_output,
179 | embeds=text_features
180 | )
181 | else:
182 | return text_outputs.last_hidden_state
--------------------------------------------------------------------------------
/mar3d/systems/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 |
4 | import pytorch_lightning as pl
5 | import torch.nn.functional as F
6 |
7 | import mar3d
8 | from mar3d.utils.base import (
9 | Updateable,
10 | update_end_if_possible,
11 | update_if_possible,
12 | )
13 | from mar3d.utils.scheduler import parse_optimizer, parse_scheduler, parse_optimizer_control
14 | from mar3d.utils.config import parse_structured
15 | from mar3d.utils.misc import C, cleanup, get_device, load_module_weights
16 | from mar3d.utils.saving import SaverMixin
17 | from mar3d.utils.typing import *
18 |
19 |
20 | class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
21 | @dataclass
22 | class Config:
23 | loggers: dict = field(default_factory=dict)
24 | loss: dict = field(default_factory=dict)
25 | optimizer: dict = field(default_factory=dict)
26 | scheduler: Optional[dict] = None
27 | weights: Optional[str] = None
28 | weights_ignore_modules: Optional[List[str]] = None
29 | cleanup_after_validation_step: bool = False
30 | cleanup_after_test_step: bool = False
31 |
32 | pretrained_model_path: Optional[str] = None
33 | strict_load: bool = True
34 | cfg: Config
35 |
36 | def __init__(self, cfg, resumed=False) -> None:
37 | super().__init__()
38 | self.cfg = parse_structured(self.Config, cfg)
39 | self._save_dir: Optional[str] = None
40 | self._resumed: bool = resumed
41 | self._resumed_eval: bool = False
42 | self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
43 | if "loggers" in cfg:
44 | self.create_loggers(cfg.loggers)
45 |
46 | self.configure()
47 | if self.cfg.weights is not None:
48 | self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules)
49 | self.post_configure()
50 |
51 |
52 |
53 |
54 | def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True):
55 | """自定义加载状态字典的逻辑"""
56 | model_state_dict = self.state_dict()
57 |
58 | # new state dict holding only the parameters with matching shapes
59 | new_state_dict = {}
60 |
61 | for name, param in state_dict.items():
62 | if name in model_state_dict:
63 | if param.shape == model_state_dict[name].shape:
64 | new_state_dict[name] = param
65 | else:
66 | print(f"Skipping parameter {name} due to shape mismatch. "
67 | f"Loaded shape: {param.shape}, "
68 | f"Model shape: {model_state_dict[name].shape}")
69 | else:
70 | print(f"Parameter {name} not found in current model")
71 |
72 | print(f"Loaded {len(new_state_dict)} parameters with matching shapes.")
73 |
74 | # delegate to the parent class's load_state_dict with the filtered state dict
75 | return super().load_state_dict(new_state_dict, strict=False)
76 |
77 |
78 | def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
79 | state_dict, epoch, global_step = load_module_weights(
80 | weights, ignore_modules=ignore_modules, map_location="cpu"
81 | )
82 | self.load_state_dict(state_dict, strict=False)
83 | # restore step-dependent states
84 | self.do_update_step(epoch, global_step, on_load_weights=True)
85 |
86 | def set_resume_status(self, current_epoch: int, global_step: int):
87 | # restore correct epoch and global step in eval
88 | self._resumed_eval = True
89 | self._resumed_eval_status["current_epoch"] = current_epoch
90 | self._resumed_eval_status["global_step"] = global_step
91 |
92 | @property
93 | def resumed(self):
94 | # whether from resumed checkpoint
95 | return self._resumed
96 |
97 | @property
98 | def true_global_step(self):
99 | if self._resumed_eval:
100 | return self._resumed_eval_status["global_step"]
101 | else:
102 | return self.global_step
103 |
104 | @property
105 | def true_current_epoch(self):
106 | if self._resumed_eval:
107 | return self._resumed_eval_status["current_epoch"]
108 | else:
109 | return self.current_epoch
110 |
111 | def configure(self) -> None:
112 | pass
113 |
114 | def post_configure(self) -> None:
115 | """
116 | executed after weights are loaded
117 | """
118 | pass
119 |
120 | def C(self, value: Any) -> float:
121 | return C(value, self.true_current_epoch, self.true_global_step)
122 |
123 | def configure_optimizers(self):
124 |
125 |
126 |
127 | # if self.cfg.control:
128 | # optim = parse_optimizer_control(self.cfg.optimizer, self)
129 | # else:
130 | optim = parse_optimizer(self.cfg.optimizer, self)
131 | ret = {
132 | "optimizer": optim,
133 | }
134 |
135 | if self.cfg.scheduler is not None:
136 | ret.update(
137 | {
138 | "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
139 | }
140 | )
141 | return ret
142 |
143 | def training_step(self, batch, batch_idx):
144 | raise NotImplementedError
145 |
146 | def validation_step(self, batch, batch_idx):
147 | raise NotImplementedError
148 |
149 | def on_train_batch_end(self, outputs, batch, batch_idx):
150 | self.dataset = self.trainer.train_dataloader.dataset
151 | update_end_if_possible(
152 | self.dataset, self.true_current_epoch, self.true_global_step
153 | )
154 | self.do_update_step_end(self.true_current_epoch, self.true_global_step)
155 |
156 | def on_validation_batch_end(self, outputs, batch, batch_idx):
157 | self.dataset = self.trainer.val_dataloaders.dataset
158 | update_end_if_possible(
159 | self.dataset, self.true_current_epoch, self.true_global_step
160 | )
161 | self.do_update_step_end(self.true_current_epoch, self.true_global_step)
162 | if self.cfg.cleanup_after_validation_step:
163 | # cleanup to save vram
164 | cleanup()
165 |
166 | def on_validation_epoch_end(self):
167 | raise NotImplementedError
168 |
169 | def test_step(self, batch, batch_idx):
170 | raise NotImplementedError
171 |
172 | def on_test_batch_end(self, outputs, batch, batch_idx):
173 | self.dataset = self.trainer.test_dataloaders.dataset
174 | update_end_if_possible(
175 | self.dataset, self.true_current_epoch, self.true_global_step
176 | )
177 | self.do_update_step_end(self.true_current_epoch, self.true_global_step)
178 | if self.cfg.cleanup_after_test_step:
179 | # cleanup to save vram
180 | cleanup()
181 |
182 | def on_test_epoch_end(self):
183 | pass
184 |
185 | def predict_step(self, batch, batch_idx):
186 | raise NotImplementedError
187 |
188 | def on_predict_batch_end(self, outputs, batch, batch_idx):
189 | self.dataset = self.trainer.predict_dataloaders.dataset
190 | update_end_if_possible(
191 | self.dataset, self.true_current_epoch, self.true_global_step
192 | )
193 | self.do_update_step_end(self.true_current_epoch, self.true_global_step)
194 | if self.cfg.cleanup_after_test_step:
195 | # cleanup to save vram
196 | cleanup()
197 |
198 | def on_predict_epoch_end(self):
199 | pass
200 |
201 | def preprocess_data(self, batch, stage):
202 | pass
203 |
204 | """
205 | Implementing on_after_batch_transfer of DataModule does the same.
206 | But on_after_batch_transfer does not support DP.
207 | """
208 |
209 | def on_train_batch_start(self, batch, batch_idx, unused=0):
210 | self.preprocess_data(batch, "train")
211 | self.dataset = self.trainer.train_dataloader.dataset
212 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
213 | self.do_update_step(self.true_current_epoch, self.true_global_step)
214 |
215 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
216 | self.preprocess_data(batch, "validation")
217 | self.dataset = self.trainer.val_dataloaders.dataset
218 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
219 | self.do_update_step(self.true_current_epoch, self.true_global_step)
220 |
221 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
222 | self.preprocess_data(batch, "test")
223 | self.dataset = self.trainer.test_dataloaders.dataset
224 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
225 | self.do_update_step(self.true_current_epoch, self.true_global_step)
226 |
227 | def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
228 | self.preprocess_data(batch, "predict")
229 | self.dataset = self.trainer.predict_dataloaders.dataset
230 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
231 | self.do_update_step(self.true_current_epoch, self.true_global_step)
232 |
233 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
234 | pass
235 |
236 | def on_before_optimizer_step(self, optimizer):
237 | """
238 | # some gradient-related debugging goes here, example:
239 | from lightning.pytorch.utilities import grad_norm
240 | norms = grad_norm(self.geometry, norm_type=2)
241 | print(norms)
242 | """
243 | pass
244 |
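The hooks above are designed to be overridden by concrete systems. A minimal sketch of such a subclass follows; the registration name "my-system", the tiny linear head, and the batch keys are purely illustrative and do not correspond to any module in this repository.

import torch
import torch.nn.functional as F

from mar3d import register
from mar3d.systems.base import BaseSystem


@register("my-system")
class MySystem(BaseSystem):
    def configure(self):
        # Build sub-modules here; checkpoint weights (if any) are loaded right after this hook.
        self.head = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        pred = self.head(batch["x"])
        loss = F.mse_loss(pred, batch["y"])
        self.log("train/loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pass

    def on_validation_epoch_end(self):
        pass

Scheduled hyper-parameters from the config can be resolved at the current epoch and step via self.C(...), which forwards true_current_epoch and true_global_step so that values remain correct after resuming from a checkpoint.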
--------------------------------------------------------------------------------
/mar3d/models/diffloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 |
6 | from diffusion import create_diffusion
7 |
8 |
9 | class DiffLoss(nn.Module):
10 | """Diffusion Loss"""
11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
12 | super(DiffLoss, self).__init__()
13 | self.in_channels = target_channels
14 | self.net = SimpleMLPAdaLN(
15 | in_channels=target_channels,
16 | model_channels=width,
17 | out_channels=target_channels * 2, # for vlb loss
18 | z_channels=z_channels,
19 | num_res_blocks=depth,
20 | grad_checkpointing=grad_checkpointing
21 | )
22 |
23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
25 |
26 | def forward(self, target, z, mask=None):
27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
28 | model_kwargs = dict(c=z)
29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
30 | loss = loss_dict["loss"]
31 | if mask is not None:
32 | loss = (loss * mask).sum() / mask.sum()
33 | return loss.mean()
34 |
35 | def sample(self, z, temperature=1.0, cfg=1.0):
36 | # diffusion loss sampling
37 | if not cfg == 1.0:
38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
39 | noise = torch.cat([noise, noise], dim=0)
40 | model_kwargs = dict(c=z, cfg_scale=cfg)
41 | sample_fn = self.net.forward_with_cfg
42 | else:
43 | noise = torch.randn(z.shape[0], self.in_channels).cuda()
44 | model_kwargs = dict(c=z)
45 | sample_fn = self.net.forward
46 |
47 | sampled_token_latent = self.gen_diffusion.p_sample_loop(
48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
49 | temperature=temperature
50 | )
51 |
52 | return sampled_token_latent
53 |
54 |
55 | def modulate(x, shift, scale):
56 | return x * (1 + scale) + shift
57 |
58 |
59 | class TimestepEmbedder(nn.Module):
60 | """
61 | Embeds scalar timesteps into vector representations.
62 | """
63 | def __init__(self, hidden_size, frequency_embedding_size=256):
64 | super().__init__()
65 | self.mlp = nn.Sequential(
66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
67 | nn.SiLU(),
68 | nn.Linear(hidden_size, hidden_size, bias=True),
69 | )
70 | self.frequency_embedding_size = frequency_embedding_size
71 |
72 | @staticmethod
73 | def timestep_embedding(t, dim, max_period=10000):
74 | """
75 | Create sinusoidal timestep embeddings.
76 | :param t: a 1-D Tensor of N indices, one per batch element.
77 | These may be fractional.
78 | :param dim: the dimension of the output.
79 | :param max_period: controls the minimum frequency of the embeddings.
80 | :return: an (N, D) Tensor of positional embeddings.
81 | """
82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
83 | half = dim // 2
84 | freqs = torch.exp(
85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
86 | ).to(device=t.device)
87 | args = t[:, None].float() * freqs[None]
88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
89 | if dim % 2:
90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91 | return embedding
92 |
93 | def forward(self, t):
94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
95 | t_emb = self.mlp(t_freq)
96 | return t_emb
97 |
98 |
99 | class ResBlock(nn.Module):
100 | """
101 | A residual block that can optionally change the number of channels.
102 | :param channels: the number of input channels.
103 | """
104 |
105 | def __init__(
106 | self,
107 | channels
108 | ):
109 | super().__init__()
110 | self.channels = channels
111 |
112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6)
113 | self.mlp = nn.Sequential(
114 | nn.Linear(channels, channels, bias=True),
115 | nn.SiLU(),
116 | nn.Linear(channels, channels, bias=True),
117 | )
118 |
119 | self.adaLN_modulation = nn.Sequential(
120 | nn.SiLU(),
121 | nn.Linear(channels, 3 * channels, bias=True)
122 | )
123 |
124 | def forward(self, x, y):
125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
127 | h = self.mlp(h)
128 | return x + gate_mlp * h
129 |
130 |
131 | class FinalLayer(nn.Module):
132 | """
133 | The final layer of DiT.
134 | """
135 | def __init__(self, model_channels, out_channels):
136 | super().__init__()
137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
138 | self.linear = nn.Linear(model_channels, out_channels, bias=True)
139 | self.adaLN_modulation = nn.Sequential(
140 | nn.SiLU(),
141 | nn.Linear(model_channels, 2 * model_channels, bias=True)
142 | )
143 |
144 | def forward(self, x, c):
145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
146 | x = modulate(self.norm_final(x), shift, scale)
147 | x = self.linear(x)
148 | return x
149 |
150 |
151 | class SimpleMLPAdaLN(nn.Module):
152 | """
153 | The MLP for Diffusion Loss.
154 | :param in_channels: channels in the input Tensor.
155 | :param model_channels: base channel count for the model.
156 | :param out_channels: channels in the output Tensor.
157 | :param z_channels: channels in the condition.
158 | :param num_res_blocks: number of residual blocks per downsample.
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels,
164 | model_channels,
165 | out_channels,
166 | z_channels,
167 | num_res_blocks,
168 | grad_checkpointing=False
169 | ):
170 | super().__init__()
171 |
172 | self.in_channels = in_channels
173 | self.model_channels = model_channels
174 | self.out_channels = out_channels
175 | self.num_res_blocks = num_res_blocks
176 | self.grad_checkpointing = grad_checkpointing
177 |
178 | self.time_embed = TimestepEmbedder(model_channels)
179 | self.cond_embed = nn.Linear(z_channels, model_channels)
180 |
181 | self.input_proj = nn.Linear(in_channels, model_channels)
182 |
183 | res_blocks = []
184 | for i in range(num_res_blocks):
185 | res_blocks.append(ResBlock(
186 | model_channels,
187 | ))
188 |
189 | self.res_blocks = nn.ModuleList(res_blocks)
190 | self.final_layer = FinalLayer(model_channels, out_channels)
191 |
192 | self.initialize_weights()
193 |
194 | def initialize_weights(self):
195 | def _basic_init(module):
196 | if isinstance(module, nn.Linear):
197 | torch.nn.init.xavier_uniform_(module.weight)
198 | if module.bias is not None:
199 | nn.init.constant_(module.bias, 0)
200 | self.apply(_basic_init)
201 |
202 | # Initialize timestep embedding MLP
203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
205 |
206 | # Zero-out adaLN modulation layers
207 | for block in self.res_blocks:
208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
210 |
211 | # Zero-out output layers
212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214 | nn.init.constant_(self.final_layer.linear.weight, 0)
215 | nn.init.constant_(self.final_layer.linear.bias, 0)
216 |
217 | def forward(self, x, t, c):
218 | """
219 | Apply the model to an input batch.
220 | :param x: an [N x C x ...] Tensor of inputs.
221 | :param t: a 1-D batch of timesteps.
222 | :param c: conditioning from AR transformer.
223 | :return: an [N x C x ...] Tensor of outputs.
224 | """
225 |
226 | x = self.input_proj(x)
227 | t = self.time_embed(t)
228 | c = self.cond_embed(c)
229 |
230 | y = t + c
231 |
232 | if self.grad_checkpointing and not torch.jit.is_scripting():
233 | for block in self.res_blocks:
234 | x = checkpoint(block, x, y)
235 | else:
236 | for block in self.res_blocks:
237 | x = block(x, y)
238 |
239 | return self.final_layer(x, y)
240 |
241 | def forward_with_cfg(self, x, t, c, cfg_scale):
242 | half = x[: len(x) // 2]
243 | combined = torch.cat([half, half], dim=0)
244 | model_out = self.forward(combined, t, c)
245 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
246 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
247 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
248 | eps = torch.cat([half_eps, half_eps], dim=0)
249 | return torch.cat([eps, rest], dim=1)
250 |
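A rough usage sketch of DiffLoss, assuming the `diffusion` package imported at the top of the file is available on the Python path; the channel sizes below are illustrative rather than the values used by the actual configs.

import torch

from mar3d.models.diffloss import DiffLoss

diffloss = DiffLoss(
    target_channels=16,        # dimension of each token latent
    z_channels=768,            # dimension of the AR-transformer conditioning
    depth=6,
    width=1024,
    num_sampling_steps="100",
)

target = torch.randn(32, 16)   # token latents to be denoised
z = torch.randn(32, 768)       # per-token conditioning
loss = diffloss(target, z)     # scalar diffusion training loss

Note that sample() draws its initial noise with .cuda(), so generation assumes a CUDA device is available.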
--------------------------------------------------------------------------------
/mar3d/systems/diffloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 |
6 | from diffusion import create_diffusion
7 |
8 |
9 | class DiffLoss(nn.Module):
10 | """Diffusion Loss"""
11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
12 | super(DiffLoss, self).__init__()
13 | self.in_channels = target_channels
14 | self.net = SimpleMLPAdaLN(
15 | in_channels=target_channels,
16 | model_channels=width,
17 | out_channels=target_channels * 2, # for vlb loss
18 | z_channels=z_channels,
19 | num_res_blocks=depth,
20 | grad_checkpointing=grad_checkpointing
21 | )
22 |
23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
25 |
26 | def forward(self, target, z, mask=None):
27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
28 | model_kwargs = dict(c=z)
29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
30 | loss = loss_dict["loss"]
31 | if mask is not None:
32 | loss = (loss * mask).sum() / mask.sum()
33 | return loss.mean()
34 |
35 | def sample(self, z, temperature=1.0, cfg=1.0):
36 | # diffusion loss sampling
37 | if not cfg == 1.0:
38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
39 | noise = torch.cat([noise, noise], dim=0)
40 | model_kwargs = dict(c=z, cfg_scale=cfg)
41 | sample_fn = self.net.forward_with_cfg
42 | else:
43 | noise = torch.randn(z.shape[0], self.in_channels).cuda()
44 | model_kwargs = dict(c=z)
45 | sample_fn = self.net.forward
46 |
47 | sampled_token_latent = self.gen_diffusion.p_sample_loop(
48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
49 | temperature=temperature
50 | )
51 |
52 | return sampled_token_latent
53 |
54 |
55 | def modulate(x, shift, scale):
56 | return x * (1 + scale) + shift
57 |
58 |
59 | class TimestepEmbedder(nn.Module):
60 | """
61 | Embeds scalar timesteps into vector representations.
62 | """
63 | def __init__(self, hidden_size, frequency_embedding_size=256):
64 | super().__init__()
65 | self.mlp = nn.Sequential(
66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
67 | nn.SiLU(),
68 | nn.Linear(hidden_size, hidden_size, bias=True),
69 | )
70 | self.frequency_embedding_size = frequency_embedding_size
71 |
72 | @staticmethod
73 | def timestep_embedding(t, dim, max_period=10000):
74 | """
75 | Create sinusoidal timestep embeddings.
76 | :param t: a 1-D Tensor of N indices, one per batch element.
77 | These may be fractional.
78 | :param dim: the dimension of the output.
79 | :param max_period: controls the minimum frequency of the embeddings.
80 | :return: an (N, D) Tensor of positional embeddings.
81 | """
82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
83 | half = dim // 2
84 | freqs = torch.exp(
85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
86 | ).to(device=t.device)
87 | args = t[:, None].float() * freqs[None]
88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
89 | if dim % 2:
90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91 | return embedding
92 |
93 | def forward(self, t):
94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
95 | t_emb = self.mlp(t_freq)
96 | return t_emb
97 |
98 |
99 | class ResBlock(nn.Module):
100 | """
101 | A residual block that can optionally change the number of channels.
102 | :param channels: the number of input channels.
103 | """
104 |
105 | def __init__(
106 | self,
107 | channels
108 | ):
109 | super().__init__()
110 | self.channels = channels
111 |
112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6)
113 | self.mlp = nn.Sequential(
114 | nn.Linear(channels, channels, bias=True),
115 | nn.SiLU(),
116 | nn.Linear(channels, channels, bias=True),
117 | )
118 |
119 | self.adaLN_modulation = nn.Sequential(
120 | nn.SiLU(),
121 | nn.Linear(channels, 3 * channels, bias=True)
122 | )
123 |
124 | def forward(self, x, y):
125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
127 | h = self.mlp(h)
128 | return x + gate_mlp * h
129 |
130 |
131 | class FinalLayer(nn.Module):
132 | """
133 | The final layer of DiT.
134 | """
135 | def __init__(self, model_channels, out_channels):
136 | super().__init__()
137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
138 | self.linear = nn.Linear(model_channels, out_channels, bias=True)
139 | self.adaLN_modulation = nn.Sequential(
140 | nn.SiLU(),
141 | nn.Linear(model_channels, 2 * model_channels, bias=True)
142 | )
143 |
144 | def forward(self, x, c):
145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
146 | x = modulate(self.norm_final(x), shift, scale)
147 | x = self.linear(x)
148 | return x
149 |
150 |
151 | class SimpleMLPAdaLN(nn.Module):
152 | """
153 | The MLP for Diffusion Loss.
154 | :param in_channels: channels in the input Tensor.
155 | :param model_channels: base channel count for the model.
156 | :param out_channels: channels in the output Tensor.
157 | :param z_channels: channels in the condition.
158 | :param num_res_blocks: number of residual blocks per downsample.
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels,
164 | model_channels,
165 | out_channels,
166 | z_channels,
167 | num_res_blocks,
168 | grad_checkpointing=False
169 | ):
170 | super().__init__()
171 |
172 | self.in_channels = in_channels
173 | self.model_channels = model_channels
174 | self.out_channels = out_channels
175 | self.num_res_blocks = num_res_blocks
176 | self.grad_checkpointing = grad_checkpointing
177 |
178 | self.time_embed = TimestepEmbedder(model_channels)
179 | self.cond_embed = nn.Linear(z_channels, model_channels)
180 |
181 | self.input_proj = nn.Linear(in_channels, model_channels)
182 |
183 | res_blocks = []
184 | for i in range(num_res_blocks):
185 | res_blocks.append(ResBlock(
186 | model_channels,
187 | ))
188 |
189 | self.res_blocks = nn.ModuleList(res_blocks)
190 | self.final_layer = FinalLayer(model_channels, out_channels)
191 |
192 | self.initialize_weights()
193 |
194 | def initialize_weights(self):
195 | def _basic_init(module):
196 | if isinstance(module, nn.Linear):
197 | torch.nn.init.xavier_uniform_(module.weight)
198 | if module.bias is not None:
199 | nn.init.constant_(module.bias, 0)
200 | self.apply(_basic_init)
201 |
202 | # Initialize timestep embedding MLP
203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
205 |
206 | # Zero-out adaLN modulation layers
207 | for block in self.res_blocks:
208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
210 |
211 | # Zero-out output layers
212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214 | nn.init.constant_(self.final_layer.linear.weight, 0)
215 | nn.init.constant_(self.final_layer.linear.bias, 0)
216 |
217 | def forward(self, x, t, c):
218 | """
219 | Apply the model to an input batch.
220 | :param x: an [N x C x ...] Tensor of inputs.
221 | :param t: a 1-D batch of timesteps.
222 | :param c: conditioning from AR transformer.
223 | :return: an [N x C x ...] Tensor of outputs.
224 | """
225 |
226 | x = self.input_proj(x)
227 | t = self.time_embed(t)
228 | c = self.cond_embed(c)
229 |
230 | y = t + c
231 |
232 | if self.grad_checkpointing and not torch.jit.is_scripting():
233 | for block in self.res_blocks:
234 | x = checkpoint(block, x, y)
235 | else:
236 | for block in self.res_blocks:
237 | x = block(x, y)
238 |
239 | return self.final_layer(x, y)
240 |
241 | def forward_with_cfg(self, x, t, c, cfg_scale):
242 | half = x[: len(x) // 2]
243 | combined = torch.cat([half, half], dim=0)
244 | model_out = self.forward(combined, t, c)
245 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
246 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
247 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
248 | eps = torch.cat([half_eps, half_eps], dim=0)
249 | return torch.cat([eps, rest], dim=1)
250 |
--------------------------------------------------------------------------------
/mar3d/models/transformers/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from mar3d.utils.typing import *
7 | from mar3d.utils.checkpoint import checkpoint
8 |
9 | from .utils import init_linear, MLP
10 |
11 | class MultiheadAttention(nn.Module):
12 | def __init__(
13 | self,
14 | *,
15 | n_ctx: int,
16 | width: int,
17 | heads: int,
18 | init_scale: float,
19 | qkv_bias: bool,
20 | use_flash: bool = False
21 | ):
22 | super().__init__()
23 | self.n_ctx = n_ctx
24 | self.width = width
25 | self.heads = heads
26 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
27 | self.c_proj = nn.Linear(width, width)
28 | self.attention = QKVMultiheadAttention(heads=heads, n_ctx=n_ctx, use_flash=use_flash)
29 | init_linear(self.c_qkv, init_scale)
30 | init_linear(self.c_proj, init_scale)
31 |
32 | def forward(self, x):
33 | x = self.c_qkv(x)
34 | x = checkpoint(self.attention, (x,), (), True)
35 | x = self.c_proj(x)
36 | return x
37 |
38 |
39 | class QKVMultiheadAttention(nn.Module):
40 | def __init__(self, *, heads: int, n_ctx: int, use_flash: bool = False):
41 | super().__init__()
42 | self.heads = heads
43 | self.n_ctx = n_ctx
44 | self.use_flash = use_flash
45 |
46 | def forward(self, qkv):
47 | bs, n_ctx, width = qkv.shape
48 | attn_ch = width // self.heads // 3
49 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
50 | qkv = qkv.view(bs, n_ctx, self.heads, -1)
51 | q, k, v = torch.split(qkv, attn_ch, dim=-1)
52 |
53 | if self.use_flash:
54 | q = q.permute(0, 2, 1, 3)
55 | k = k.permute(0, 2, 1, 3)
56 | v = v.permute(0, 2, 1, 3)
57 | out = F.scaled_dot_product_attention(q, k, v).permute(0, 2, 1, 3).reshape(bs, n_ctx, -1)
58 | else:
59 | weight = torch.einsum(
60 | "bthc,bshc->bhts", q * scale, k * scale
61 | ) # More stable with f16 than dividing afterwards
62 | wdtype = weight.dtype
63 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
64 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
65 |
66 | return out
67 |
68 | class ResidualAttentionBlock(nn.Module):
69 | def __init__(
70 | self,
71 | *,
72 | n_ctx: int,
73 | width: int,
74 | heads: int,
75 | init_scale: float = 1.0,
76 | qkv_bias: bool = True,
77 | use_flash: bool = False,
78 | use_checkpoint: bool = False
79 | ):
80 | super().__init__()
81 |
82 | self.use_checkpoint = use_checkpoint
83 |
84 | self.attn = MultiheadAttention(
85 | n_ctx=n_ctx,
86 | width=width,
87 | heads=heads,
88 | init_scale=init_scale,
89 | qkv_bias=qkv_bias,
90 | use_flash=use_flash
91 | )
92 | self.ln_1 = nn.LayerNorm(width)
93 | self.mlp = MLP(width=width, init_scale=init_scale)
94 | self.ln_2 = nn.LayerNorm(width)
95 |
96 | def _forward(self, x: torch.Tensor):
97 | x = x + self.attn(self.ln_1(x))
98 | x = x + self.mlp(self.ln_2(x))
99 | return x
100 |
101 | def forward(self, x: torch.Tensor):
102 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
103 |
104 |
105 | class MultiheadCrossAttention(nn.Module):
106 | def __init__(
107 | self,
108 | *,
109 | width: int,
110 | heads: int,
111 | init_scale: float,
112 | qkv_bias: bool = True,
113 | use_flash: bool = False,
114 | n_data: Optional[int] = None,
115 | data_width: Optional[int] = None,
116 | ):
117 | super().__init__()
118 | self.n_data = n_data
119 | self.width = width
120 | self.heads = heads
121 | self.data_width = width if data_width is None else data_width
122 | self.c_q = nn.Linear(width, width, bias=qkv_bias)
123 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
124 | self.c_proj = nn.Linear(width, width)
125 | self.attention = QKVMultiheadCrossAttention(
126 | heads=heads, n_data=n_data, use_flash=use_flash
127 | )
128 | init_linear(self.c_q, init_scale)
129 | init_linear(self.c_kv, init_scale)
130 | init_linear(self.c_proj, init_scale)
131 |
132 | def forward(self, x, data):
133 | x = self.c_q(x)
134 | data = self.c_kv(data)
135 | x = checkpoint(self.attention, (x, data), (), True)
136 | x = self.c_proj(x)
137 | return x
138 |
139 |
140 | class QKVMultiheadCrossAttention(nn.Module):
141 | def __init__(self, *, heads: int, use_flash: bool = False, n_data: Optional[int] = None):
142 |
143 | super().__init__()
144 | self.heads = heads
145 | self.n_data = n_data
146 | self.use_flash = use_flash
147 |
148 | def forward(self, q, kv):
149 | _, n_ctx, _ = q.shape
150 | bs, n_data, width = kv.shape
151 | attn_ch = width // self.heads // 2
152 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
153 | q = q.view(bs, n_ctx, self.heads, -1)
154 | kv = kv.view(bs, n_data, self.heads, -1)
155 | k, v = torch.split(kv, attn_ch, dim=-1)
156 |
157 | if self.use_flash:
158 | q = q.permute(0, 2, 1, 3)
159 | k = k.permute(0, 2, 1, 3)
160 | v = v.permute(0, 2, 1, 3)
161 | out = F.scaled_dot_product_attention(q, k, v).permute(0, 2, 1, 3).reshape(bs, n_ctx, -1)
162 | else:
163 | weight = torch.einsum(
164 | "bthc,bshc->bhts", q * scale, k * scale
165 | ) # More stable with f16 than dividing afterwards
166 | wdtype = weight.dtype
167 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
168 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
169 |
170 | return out
171 |
172 |
173 | class ResidualCrossAttentionBlock(nn.Module):
174 | def __init__(
175 | self,
176 | *,
177 | n_data: Optional[int] = None,
178 | width: int,
179 | heads: int,
180 | data_width: Optional[int] = None,
181 | init_scale: float = 0.25,
182 | qkv_bias: bool = True,
183 | use_flash: bool = False
184 | ):
185 | super().__init__()
186 |
187 | if data_width is None:
188 | data_width = width
189 |
190 | self.attn = MultiheadCrossAttention(
191 | n_data=n_data,
192 | width=width,
193 | heads=heads,
194 | data_width=data_width,
195 | init_scale=init_scale,
196 | qkv_bias=qkv_bias,
197 | use_flash=use_flash,
198 | )
199 | self.ln_1 = nn.LayerNorm(width)
200 | self.ln_2 = nn.LayerNorm(data_width)
201 | self.mlp = MLP(width=width, init_scale=init_scale)
202 | self.ln_3 = nn.LayerNorm(width)
203 |
204 | def forward(self, x: torch.Tensor, data: torch.Tensor):
205 | x = x + self.attn(self.ln_1(x), self.ln_2(data))
206 | x = x + self.mlp(self.ln_3(x))
207 | return x
208 |
209 | class MultiheadCrossAttention_(nn.Module):
210 | def __init__(
211 | self,
212 | *,
213 | widthq: int,
214 | width: int,
215 | heads: int,
216 | init_scale: float,
217 | qkv_bias: bool = True,
218 | use_flash: bool = False,
219 | n_data: Optional[int] = None,
220 | kv_width: Optional[int] = None,
221 | out_width: int,
222 | ):
223 | super().__init__()
224 | self.n_data = n_data
225 | self.width = width
226 | self.width1 = widthq
227 | self.heads = heads
228 | self.data_width = width if kv_width is None else kv_width
229 | self.c_q = nn.Linear(widthq, width, bias=qkv_bias)
230 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
231 | self.c_proj = nn.Linear(width, out_width)
232 | self.attention = QKVMultiheadCrossAttention(
233 | heads=heads, n_data=n_data, use_flash=use_flash
234 | )
235 | init_linear(self.c_q, init_scale)
236 | init_linear(self.c_kv, init_scale)
237 | init_linear(self.c_proj, init_scale)
238 |
239 | def forward(self, x, data):
240 |
241 | x = self.c_q(x)
242 | data = self.c_kv(data)
243 | x = checkpoint(self.attention, (x, data), (), True)
244 | x = self.c_proj(x)
245 | return x
246 | class CrossAttentionBlock(nn.Module):
247 | def __init__(
248 | self,
249 | *,
250 | n_data: Optional[int] = None,
251 | widq: int,
252 | width: int,
253 | heads: int,
254 | kv_width: Optional[int] = None,
255 | out_width: int,
256 | init_scale: float = 0.25,
257 | qkv_bias: bool = True,
258 | use_flash: bool = False
259 | ):
260 | super().__init__()
261 |
262 | # if data_width is None:
263 | # data_width = width
264 |
265 | self.attn = MultiheadCrossAttention_(
266 | n_data=n_data,
267 | widthq=widq,
268 | width=width,
269 | heads=heads,
270 | kv_width=kv_width,
271 | out_width=out_width,
272 | init_scale=init_scale,
273 | qkv_bias=qkv_bias,
274 | use_flash=use_flash,
275 | )
276 | self.ln_1 = nn.LayerNorm(widq)
277 | self.ln_2 = nn.LayerNorm(kv_width)
278 | self.mlp = MLP(width=out_width, init_scale=init_scale)
279 | self.ln_3 = nn.LayerNorm(out_width)
280 |
281 | def forward(self, x: torch.Tensor, data: torch.Tensor):
282 | x = self.attn(self.ln_1(x), self.ln_2(data))
283 | x = self.mlp(self.ln_3(x))
284 | return x
--------------------------------------------------------------------------------
/launch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import contextlib
3 | import importlib
4 | import logging
5 | import os
6 | import sys
7 | import time
8 | import traceback
9 | import pytorch_lightning as pl
10 | import torch
11 | from pytorch_lightning import Trainer
12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
13 | from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
14 | from pytorch_lightning.utilities.rank_zero import rank_zero_only
15 | import mar3d
16 | from mar3d.systems.base import BaseSystem
17 | from mar3d.utils.callbacks import (
18 | CodeSnapshotCallback,
19 | ConfigSnapshotCallback,
20 | CustomProgressBar,
21 | ProgressCallback,
22 | )
23 | from mar3d.utils.config import ExperimentConfig, load_config
24 | from mar3d.utils.misc import get_rank
25 | from mar3d.utils.typing import Optional
26 | class ColoredFilter(logging.Filter):
27 | """
28 | A logging filter to add color to certain log levels.
29 | """
30 |
31 | RESET = "\033[0m"
32 | RED = "\033[31m"
33 | GREEN = "\033[32m"
34 | YELLOW = "\033[33m"
35 | BLUE = "\033[34m"
36 | MAGENTA = "\033[35m"
37 | CYAN = "\033[36m"
38 |
39 | COLORS = {
40 | "WARNING": YELLOW,
41 | "INFO": GREEN,
42 | "DEBUG": BLUE,
43 | "CRITICAL": MAGENTA,
44 | "ERROR": RED,
45 | }
46 |
47 | RESET = "\x1b[0m"
48 |
49 | def __init__(self):
50 | super().__init__()
51 |
52 | def filter(self, record):
53 | if record.levelname in self.COLORS:
54 | color_start = self.COLORS[record.levelname]
55 | record.levelname = f"{color_start}[{record.levelname}]"
56 | record.msg = f"{record.msg}{self.RESET}"
57 | return True
58 |
59 |
60 | def load_custom_module(module_path):
61 | module_name = os.path.basename(module_path)
62 | if os.path.isfile(module_path):
63 | sp = os.path.splitext(module_path)
64 | module_name = sp[0]
65 | try:
66 | if os.path.isfile(module_path):
67 | module_spec = importlib.util.spec_from_file_location(
68 | module_name, module_path
69 | )
70 | else:
71 | module_spec = importlib.util.spec_from_file_location(
72 | module_name, os.path.join(module_path, "__init__.py")
73 | )
74 |
75 | module = importlib.util.module_from_spec(module_spec)
76 | sys.modules[module_name] = module
77 | module_spec.loader.exec_module(module)
78 | return True
79 | except Exception as e:
80 | print(traceback.format_exc())
81 | print(f"Cannot import {module_path} module for custom nodes:", e)
82 | return False
83 |
84 |
85 | def load_custom_modules():
86 | node_paths = ["custom"]
87 | node_import_times = []
88 | if not any(os.path.exists(p) for p in node_paths):
89 | return
90 | for custom_node_path in node_paths:
91 | possible_modules = os.listdir(custom_node_path)
92 | if "__pycache__" in possible_modules:
93 | possible_modules.remove("__pycache__")
94 |
95 | for possible_module in possible_modules:
96 | module_path = os.path.join(custom_node_path, possible_module)
97 | if (
98 | os.path.isfile(module_path)
99 | and os.path.splitext(module_path)[1] != ".py"
100 | ):
101 | continue
102 | if module_path.endswith(".disabled"):
103 | continue
104 | time_before = time.perf_counter()
105 | success = load_custom_module(module_path)
106 | node_import_times.append(
107 | (time.perf_counter() - time_before, module_path, success)
108 | )
109 |
110 | if len(node_import_times) > 0:
111 | print("\nImport times for custom modules:")
112 | for n in sorted(node_import_times):
113 | if n[2]:
114 | import_message = ""
115 | else:
116 | import_message = " (IMPORT FAILED)"
117 | print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
118 | print()
119 |
120 |
121 | def main(args, extras) -> None:
122 | # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
123 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
124 | env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
125 | env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
126 | selected_gpus = [0]
127 | torch.set_float32_matmul_precision("high")
128 |
129 | # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
130 | # As far as Pytorch Lightning is concerned, we always use all available GPUs
131 | # (possibly filtered by CUDA_VISIBLE_DEVICES).
132 | devices = -1
133 | if len(env_gpus) > 0:
134 | n_gpus = len(env_gpus)
135 | else:
136 | selected_gpus = list(args.gpu.split(","))
137 | n_gpus = len(selected_gpus)
138 | print(f"Using {n_gpus} GPUs: {selected_gpus}")
139 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
140 |
141 | if args.typecheck:
142 | from jaxtyping import install_import_hook
143 |
144 | install_import_hook("mar3d", "typeguard.typechecked")
145 |
146 | logger = logging.getLogger("pytorch_lightning")
147 | if args.verbose:
148 | logger.setLevel(logging.DEBUG)
149 |
150 | for handler in logger.handlers:
151 | if handler.stream == sys.stderr: # type: ignore
152 | if not args.gradio:
153 | handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
154 | handler.addFilter(ColoredFilter())
155 | else:
156 | handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
157 |
158 | load_custom_modules()
159 |
160 | # parse YAML config to OmegaConf
161 | cfg: ExperimentConfig
162 | cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)
163 |
164 | # set a different seed for each device
165 | pl.seed_everything(cfg.seed + get_rank(), workers=True)
166 |
167 | dm = mar3d.find(cfg.data_type)(cfg.data)
168 | system: BaseSystem = mar3d.find(cfg.system_type)(
169 | cfg.system, resumed=cfg.resume is not None
170 | )
171 |
172 | system.set_save_dir(os.path.join(cfg.trial_dir, "save"))
173 |
174 | if args.gradio:
175 | fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
176 | fh.setLevel(logging.INFO)
177 | if args.verbose:
178 | fh.setLevel(logging.DEBUG)
179 | fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
180 | logger.addHandler(fh)
181 |
182 | callbacks = []
183 | if args.train:
184 | callbacks += [
185 | ModelCheckpoint(
186 | dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
187 | ),
188 | LearningRateMonitor(logging_interval="step"),
189 | # CodeSnapshotCallback(
190 | # os.path.join(cfg.trial_dir, "code"), use_version=False
191 | # ),
192 | ConfigSnapshotCallback(
193 | args.config,
194 | cfg,
195 | os.path.join(cfg.trial_dir, "configs"),
196 | use_version=False,
197 | ),
198 | ]
199 | if args.gradio:
200 | callbacks += [
201 | ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
202 | ]
203 | else:
204 | callbacks += [CustomProgressBar(refresh_rate=1)]
205 |
206 | def write_to_text(file, lines):
207 | with open(file, "w") as f:
208 | for line in lines:
209 | f.write(line + "\n")
210 |
211 | loggers = []
212 | if args.train:
213 | # make tensorboard logging dir to suppress warning
214 | rank_zero_only(
215 | lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
216 | )()
217 | loggers += [
218 | TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
219 | CSVLogger(cfg.trial_dir, name="csv_logs"),
220 | ] + system.get_loggers()
221 | rank_zero_only(
222 | lambda: write_to_text(
223 | os.path.join(cfg.trial_dir, "cmd.txt"),
224 | ["python " + " ".join(sys.argv), str(args)],
225 | )
226 | )()
227 |
228 | trainer = Trainer(
229 | callbacks=callbacks,
230 | logger=loggers,
231 | inference_mode=False,
232 | accelerator="gpu",
233 | devices=devices,
234 | # profiler="advanced",
235 | **cfg.trainer,
236 | )
237 |
238 | def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
239 | if ckpt_path is None:
240 | return
241 | ckpt = torch.load(ckpt_path, map_location="cpu")
242 | system.set_resume_status(ckpt["epoch"], ckpt["global_step"])
243 | if args.train:
244 | trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
245 | trainer.test(system, datamodule=dm)
246 | if args.gradio:
247 | # also export assets if in gradio mode
248 | trainer.predict(system, datamodule=dm)
249 | elif args.validate:
250 | # manually set epoch and global_step as they cannot be automatically resumed
251 | set_system_status(system, cfg.resume)
252 | trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
253 | elif args.test:
254 | # manually set epoch and global_step as they cannot be automatically resumed
255 | set_system_status(system, cfg.resume)
256 | trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
257 | elif args.export:
258 | set_system_status(system, cfg.resume)
259 | trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
260 |
261 |
262 | if __name__ == "__main__":
263 | parser = argparse.ArgumentParser()
264 | parser.add_argument("--config", required=True, help="path to config file")
265 | parser.add_argument(
266 | "--gpu",
267 | default="0",
268 | help="GPU(s) to be used. 0 means use the 1st available GPU. "
269 | "1,2 means use the 2nd and 3rd available GPU. "
270 | "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
271 | "this argument is ignored and all available GPUs are always used.",
272 | )
273 |
274 | group = parser.add_mutually_exclusive_group(required=True)
275 | group.add_argument("--train", action="store_true")
276 | group.add_argument("--validate", action="store_true")
277 | group.add_argument("--test", action="store_true")
278 | group.add_argument("--export", action="store_true")
279 |
280 | parser.add_argument(
281 | "--gradio", action="store_true", help="if true, run in gradio mode"
282 | )
283 |
284 | parser.add_argument(
285 | "--verbose", action="store_true", help="if true, set logging level to DEBUG"
286 | )
287 |
288 | parser.add_argument(
289 | "--typecheck",
290 | action="store_true",
291 | help="whether to enable dynamic type checking",
292 | )
293 |
294 | args, extras = parser.parse_known_args()
295 |
296 | if args.gradio:
297 | with contextlib.redirect_stdout(sys.stderr):
298 | main(args, extras)
299 | else:
300 | main(args, extras)
301 |
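For reference, a training run is typically launched with something like `python launch.py --config <path-to-yaml> --train --gpu 0`, where the config path is a placeholder for an experiment YAML. The same invocation expressed from Python:

import subprocess

# Illustrative only; replace the config path with a real experiment config.
subprocess.run(
    ["python", "launch.py", "--config", "path/to/config.yaml", "--train", "--gpu", "0"],
    check=True,
)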
--------------------------------------------------------------------------------
/mar3d/models/autoencoders/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import distributed as tdist
6 | from torch.nn import functional as F
7 | import math
8 | import numpy as np
9 | from einops import repeat, rearrange
10 | from skimage import measure
11 |
12 | from mar3d.utils.base import BaseModule
13 | from mar3d.utils.typing import *
14 | from mar3d.utils.misc import get_world_size
15 | from mar3d.utils.ops import generate_dense_grid_points
16 |
17 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
18 |
19 | class FourierEmbedder(nn.Module):
20 | def __init__(self,
21 | num_freqs: int = 6,
22 | logspace: bool = True,
23 | input_dim: int = 3,
24 | include_input: bool = True,
25 | include_pi: bool = True) -> None:
26 | super().__init__()
27 |
28 | if logspace:
29 | frequencies = 2.0 ** torch.arange(
30 | num_freqs,
31 | dtype=torch.float32
32 | )
33 | else:
34 | frequencies = torch.linspace(
35 | 1.0,
36 | 2.0 ** (num_freqs - 1),
37 | num_freqs,
38 | dtype=torch.float32
39 | )
40 |
41 | if include_pi:
42 | frequencies *= torch.pi
43 |
44 | self.register_buffer("frequencies", frequencies, persistent=False)
45 | self.include_input = include_input
46 | self.num_freqs = num_freqs
47 |
48 | self.out_dim = self.get_dims(input_dim)
49 |
50 | def get_dims(self, input_dim):
51 | temp = 1 if self.include_input or self.num_freqs == 0 else 0
52 | out_dim = input_dim * (self.num_freqs * 2 + temp)
53 |
54 | return out_dim
55 |
56 | def forward(self, x: torch.Tensor) -> torch.Tensor:
57 | if self.num_freqs > 0:
58 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
59 | if self.include_input:
60 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
61 | else:
62 | return torch.cat((embed.sin(), embed.cos()), dim=-1)
63 | else:
64 | return x
65 |
66 |
67 | class LearnedFourierEmbedder(nn.Module):
68 | def __init__(self, input_dim, dim):
69 | super().__init__()
70 | assert (dim % 2) == 0
71 | half_dim = dim // 2
72 | per_channel_dim = half_dim // input_dim
73 | self.weights = nn.Parameter(torch.randn(per_channel_dim))
74 |
75 | self.out_dim = self.get_dims(input_dim)
76 |
77 | def forward(self, x):
78 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
79 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
80 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
81 | return fouriered
82 |
83 | def get_dims(self, input_dim):
84 | return input_dim * (self.weights.shape[0] * 2 + 1)
85 |
86 | class Sine(nn.Module):
87 | def __init__(self, w0 = 1.):
88 | super().__init__()
89 | self.w0 = w0
90 | def forward(self, x):
91 | return torch.sin(self.w0 * x)
92 |
93 | class Siren(nn.Module):
94 | def __init__(
95 | self,
96 | in_dim,
97 | out_dim,
98 | w0 = 1.,
99 | c = 6.,
100 | is_first = False,
101 | use_bias = True,
102 | activation = None,
103 | dropout = 0.
104 | ):
105 | super().__init__()
106 | self.in_dim = in_dim
107 | self.out_dim = out_dim
108 | self.is_first = is_first
109 |
110 | weight = torch.zeros(out_dim, in_dim)
111 | bias = torch.zeros(out_dim) if use_bias else None
112 | self.init_(weight, bias, c = c, w0 = w0)
113 |
114 | self.weight = nn.Parameter(weight)
115 | self.bias = nn.Parameter(bias) if use_bias else None
116 | self.activation = Sine(w0) if activation is None else activation
117 | self.dropout = nn.Dropout(dropout)
118 |
119 | def init_(self, weight, bias, c, w0):
120 | dim = self.in_dim
121 |
122 | w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
123 | weight.uniform_(-w_std, w_std)
124 |
125 | if bias is not None:
126 | bias.uniform_(-w_std, w_std)
127 |
128 | def forward(self, x):
129 | out = F.linear(x, self.weight, self.bias)
130 | out = self.activation(out)
131 | out = self.dropout(out)
132 | return out
133 |
134 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True):
135 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
136 | return nn.Identity(), input_dim
137 |
138 | elif embed_type == "fourier":
139 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
140 |
141 | elif embed_type == "learned_fourier":
142 | embedder_obj = LearnedFourierEmbedder(input_dim=input_dim, dim=num_freqs)
143 |
144 | elif embed_type == "siren":
145 | embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim)
146 |
147 | elif embed_type == "hashgrid":
148 | raise NotImplementedError
149 |
150 | elif embed_type == "sphere_harmonic":
151 | raise NotImplementedError
152 |
153 | else:
154 | raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
155 | return embedder_obj
156 |
157 |
158 | ###################### AutoEncoder
159 | class AutoEncoder(BaseModule):
160 | @dataclass
161 | class Config(BaseModule.Config):
162 | pretrained_model_name_or_path: str = ""
163 | num_latents: int = 256
164 | embed_dim: int = 64
165 | width: int = 768
166 | upsample: bool = False
167 | cfg: Config
168 |
169 | def configure(self) -> None:
170 | super().configure()
171 |
172 | def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
173 | raise NotImplementedError
174 |
175 | def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
176 | raise NotImplementedError
177 |
178 | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
179 | posterior = None
180 | if self.cfg.embed_dim > 0:
181 | moments = self.pre_kl(latents)
182 |
183 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
184 | if sample_posterior:
185 | kl_embed = posterior.sample()
186 | else:
187 | kl_embed = posterior.mode()
188 | else:
189 | kl_embed = latents
190 | return kl_embed, posterior
191 |
192 | def forward(self,
193 | surface: torch.FloatTensor,
194 | surface2: torch.FloatTensor = None,
195 | surface3: torch.FloatTensor = None,
196 | queries: torch.FloatTensor = None,
197 | sample_posterior: bool = True,
198 | res: torch.FloatTensor = None,
199 | ):
200 | shape_latents, kl_embed, posterior = self.encode(surface,surface2,surface3, sample_posterior=sample_posterior)
201 |
202 | # logs=[]
203 | # de_latents=[]
204 | # if type(kl_embed) == list:
205 |
206 | # # for kl_ in kl_embed:
207 |
208 | # latents_0 = self.decode(kl_embed[0]) # [B, num_latents, width]
209 | # latents_1 = self.decode(kl_embed[1])
210 | # # import ipdb
211 | # # ipdb.set_trace()
212 |
213 | # latents_0_up=self.cross_level(latents_1,latents_0 )
214 | # latents_2=latents_1+latents_0_up
215 |
216 | # de_latents=[latents_0_up,latents_2]
217 | # logits = self.query(queries, latents_0_up) # [B,]
218 | # logits_2 = self.query(queries, latents_2) # [B,]
219 | # logs=[logits,logits_2]
220 | # else:
221 | de_latents = self.decode(kl_embed) # [B, num_latents, width]
222 | logs = self.query(queries, de_latents) # [B,]
223 |
224 |
225 | return kl_embed, de_latents, posterior, logs
226 |
227 | def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor:
228 | raise NotImplementedError
229 |
230 | @torch.no_grad()
231 | def extract_geometry(self,
232 | latents: torch.FloatTensor,
233 | bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
234 | octree_depth: int = 8,
235 | num_chunks: int = 10000,
236 | ):
237 |
238 | if isinstance(bounds, float):
239 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
240 |
241 | bbox_min = np.array(bounds[0:3])
242 | bbox_max = np.array(bounds[3:6])
243 | bbox_size = bbox_max - bbox_min
244 |
245 | xyz_samples, grid_size, length = generate_dense_grid_points(
246 | bbox_min=bbox_min,
247 | bbox_max=bbox_max,
248 | octree_depth=octree_depth,
249 | indexing="ij"
250 | )
251 | xyz_samples = torch.FloatTensor(xyz_samples)
252 |
253 |
254 | batch_size = latents.shape[0]
255 |
256 | batch_logits = []
257 | for start in range(0, xyz_samples.shape[0], num_chunks):
258 | queries = xyz_samples[start: start + num_chunks, :].to(latents)
259 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
260 |
261 | logits = self.query(batch_queries, latents)
262 | batch_logits.append(logits.cpu())
263 |
264 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float().numpy()
265 |
266 | mesh_v_f = []
267 | has_surface = np.zeros((batch_size,), dtype=np.bool_)
268 | for i in range(batch_size):
269 | try:
270 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
271 | vertices = vertices / grid_size * bbox_size + bbox_min
272 | faces = faces[:, [2, 1, 0]]
273 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
274 | has_surface[i] = True
275 | except Exception:
276 | mesh_v_f.append((None, None))
277 | has_surface[i] = False
278 |
279 | return mesh_v_f, has_surface
280 |
281 | class DiagonalGaussianDistribution(object):
282 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
283 | self.feat_dim = feat_dim
284 | self.parameters = parameters
285 |
286 | if isinstance(parameters, list):
287 | self.mean = parameters[0]
288 | self.logvar = parameters[1]
289 | else:
290 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
291 |
292 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
293 | self.deterministic = deterministic
294 | self.std = torch.exp(0.5 * self.logvar)
295 | self.var = torch.exp(self.logvar)
296 | if self.deterministic:
297 | self.var = self.std = torch.zeros_like(self.mean)
298 |
299 | def sample(self):
300 | x = self.mean + self.std * torch.randn_like(self.mean)
301 | return x
302 |
303 | def kl(self, other=None, dims=(1, 2)):
304 | if self.deterministic:
305 | return torch.Tensor([0.])
306 | else:
307 | if other is None:
308 | return 0.5 * torch.mean(torch.pow(self.mean, 2)
309 | + self.var - 1.0 - self.logvar,
310 | dim=dims)
311 | else:
312 | return 0.5 * torch.mean(
313 | torch.pow(self.mean - other.mean, 2) / other.var
314 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
315 | dim=dims)
316 |
317 | def nll(self, sample, dims=(1, 2)):
318 | if self.deterministic:
319 | return torch.Tensor([0.])
320 | logtwopi = np.log(2.0 * np.pi)
321 | return 0.5 * torch.sum(
322 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
323 | dim=dims)
324 |
325 | def mode(self):
326 | return self.mean
327 |
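A small sketch exercising two standalone utilities from this file, the Fourier positional embedder and the diagonal Gaussian used for the KL latent space; all sizes are illustrative.

import torch

from mar3d.models.autoencoders.utils import DiagonalGaussianDistribution, FourierEmbedder

embedder = FourierEmbedder(num_freqs=8, input_dim=3)
xyz = torch.rand(4, 2048, 3)
feat = embedder(xyz)
# out_dim = input_dim * (2 * num_freqs + 1) = 3 * 17 = 51 when include_input=True
assert feat.shape == (4, 2048, embedder.out_dim) == (4, 2048, 51)

moments = torch.randn(4, 256, 2 * 64)            # concatenated [mean, logvar]
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
z = posterior.sample()                           # (4, 256, 64)
kl = posterior.kl()                              # per-sample KL, shape (4,)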
--------------------------------------------------------------------------------
/mar3d/data/objaverse.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import json
4 | from dataclasses import dataclass, field
5 |
6 | import random
7 | import imageio
8 | import numpy as np
9 | import pytorch_lightning as pl
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.data import DataLoader, Dataset
13 | from torchvision import transforms
14 | from PIL import Image
15 | from transformers import CLIPImageProcessor, CLIPTokenizer
16 | import pickle
17 | from mar3d import register
18 | from mar3d.utils.base import Updateable
19 | from mar3d.utils.config import parse_structured
20 | from mar3d.utils.typing import *
21 | from plyfile import PlyData, PlyElement
22 | import pandas as pd
23 | def save_ply_plyfile(points, filename):
24 | # Build a structured array holding the vertex coordinates
25 | vertex = np.zeros(len(points), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
26 | vertex['x'] = points[:, 0]
27 | vertex['y'] = points[:, 1]
28 | vertex['z'] = points[:, 2]
29 |
30 | # Create the PlyElement
31 | el = PlyElement.describe(vertex, 'vertex')
32 |
33 | # Write the PLY file
34 | PlyData([el], text=True).write(filename)
35 | def rot2eul(R):
36 | beta = -np.arcsin(R[2,0])
37 | alpha = np.arctan2(R[2,1]/np.cos(beta),R[2,2]/np.cos(beta))
38 | gamma = np.arctan2(R[1,0]/np.cos(beta),R[0,0]/np.cos(beta))
39 | return np.array((alpha, beta, gamma))
40 |
41 | def eul2rot(theta) :
42 | R = np.array([[np.cos(theta[1])*np.cos(theta[2]), np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2]) - np.sin(theta[2])*np.cos(theta[0]), np.sin(theta[1])*np.cos(theta[0])*np.cos(theta[2]) + np.sin(theta[0])*np.sin(theta[2])],
43 | [np.sin(theta[2])*np.cos(theta[1]), np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) + np.cos(theta[0])*np.cos(theta[2]), np.sin(theta[1])*np.sin(theta[2])*np.cos(theta[0]) - np.sin(theta[0])*np.cos(theta[2])],
44 | [-np.sin(theta[1]), np.sin(theta[0])*np.cos(theta[1]), np.cos(theta[0])*np.cos(theta[1])]])
45 | return R
46 |
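# Note (illustrative): rot2eul and eul2rot above are inverses of each other away from the
# gimbal-lock singularity at beta = +/- pi/2, e.g.
# np.allclose(rot2eul(eul2rot(np.array([0.1, 0.2, 0.3]))), np.array([0.1, 0.2, 0.3])) holds.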
47 | @dataclass
48 | class ObjaverseDataModuleConfig:
49 | root_dir: str = None
50 | data_type: str = "occupancy" # occupancy or sdf
51 | n_samples: int = 4096 # number of points in input point cloud
52 | scale: float = 1.0 # scale of the input point cloud and target supervision
53 | noise_sigma: float = 0.0 # noise level of the input point cloud
54 |
55 | load_supervision: bool = True # whether to load supervision
56 | supervision_type: str = "occupancy" # occupancy, sdf, tsdf, tsdf_w_surface
57 |
58 | n_supervision: int = 10000 # number of points in supervision
59 |
60 | load_image: bool = False # whether to load images
61 | image_data_path: str = "" # path to the image data
62 | image_type: str = "rgb" # rgb, normal
63 | background_color: Tuple[float, float, float] = field(
64 | default_factory=lambda: (1.0, 1.0, 1.0)
65 | )
66 | idx: Optional[List[int]] = None # index of the image to load
67 | n_views: int = 1 # number of views
68 | rotate: bool = False # whether to rotate the input point cloud and the supervision
69 |
70 | load_caption: bool = False # whether to load captions
71 | caption_type: str = "text" # text, clip_embeds
72 | tokenizer_pretrained_model_name_or_path: str = ""
73 |
74 | batch_size: int = 32
75 | num_workers: int = 0
76 |
77 |
78 | class ObjaverseDataset(Dataset):
79 | def __init__(self, cfg: Any, split: str) -> None:
80 | super().__init__()
81 | self.cfg = cfg
82 | self.split = split
83 |
84 | # make sure root_dir is list
85 | if isinstance(self.cfg.root_dir, str):
86 | self.cfg.root_dir = [self.cfg.root_dir]
87 |
88 | # cache file
89 | cache_file = ''
90 | if os.path.exists(cache_file):
91 |
92 | with open(cache_file, 'rb') as f:
93 | self.uids_og = pickle.load(f)
94 | else:
95 | self.uids_og = self._scan_files()
96 | with open(cache_file, 'wb') as f:
97 | pickle.dump(self.uids_og, f)
98 |
99 | self.background_color = torch.as_tensor(self.cfg.background_color)
100 |
101 | if self.cfg.load_image:
102 |
103 | mapping=''
104 |
105 |
106 | df = pd.read_csv(mapping, sep=',', header=None, names=['number', 'hash'], skiprows=1)
107 |
108 |
109 | self.mapping_dict = dict(zip(df['number'], df['hash']))
110 | self.uids= []
111 | for uid_idx, uid_tuple in enumerate(self.uids_og):
112 | # Extract the path and filename
113 | _, filename = uid_tuple
114 | # Get the part after the underscore
115 | if '_' in filename:
116 | search_str = filename.split('_', 1)[1]
117 |
118 | if search_str in self.mapping_dict.keys():
119 |
120 | self.uids.append(uid_tuple)
121 |
122 | else:
123 | self.uids=self.uids_og
124 | print(f"Loaded {len(self.uids)} {split} usable uids")
125 | def _scan_files(self):
126 | uids = []
127 | total_files = []
128 |
129 | for root_dir in self.cfg.root_dir:
130 | files = os.listdir(root_dir)
131 | # Attach the corresponding root directory to each file entry
132 | files = [(root_dir, file) for file in files]
133 | total_files.extend(files)
134 |
135 |
136 | if self.split == 'train':
137 | total_files = total_files[150:]
138 | else:
139 | total_files = total_files[:150]
140 |
141 |
142 | return [
143 | (root_dir, file) for root_dir, file in total_files
144 | if os.path.exists(
145 | f'{root_dir}/{file}/xxx.npz'
146 | )
147 | ]
148 |
149 | def __len__(self):
150 |
151 | return len(self.uids)
152 |
153 | def _load_shape(self, index: int) -> Dict[str, Any]:
154 |
155 |         if self.cfg.supervision_type == "sdf":
156 | 
157 |             sdfs3 = np.asarray(pointcloud['clean_surface_sdf'])
158 |             ind3 = rng.choice(surface_og_n.shape[0], self.cfg.n_supervision // 3, replace=False)
159 | 
160 |             rand_points3 = surface_og[ind3]
161 |             sdfs3 = sdfs3[ind3]
162 |             normal3 = surface_og_n[ind3]
163 | 
164 |             rand_points = np.concatenate((rand_points1, rand_points2, rand_points3), axis=0)
165 |             sdfs = np.concatenate((sdfs1, sdfs2, sdfs3), axis=0)
166 | 
167 | 
168 |         else:
169 |             rand_points = np.concatenate((rand_points1, rand_points2), axis=0)
170 |             sdfs = np.concatenate((sdfs1, sdfs2), axis=0)
171 | 
172 |         ret = {
173 |             "uid": self.uids[index][1],
174 |             "surface": surface.astype(np.float32),
175 | 
176 |         }
177 | 
178 | 
179 |         ret["rand_points"] = rand_points.astype(np.float32)
180 | 
181 |         if self.cfg.supervision_type == "sdf":
182 |             sdfs = np.nan_to_num(sdfs, nan=1.0, posinf=1.0, neginf=-1.0)
183 |             # ret["sdf"] = sdfs.flatten().astype(np.float32).clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold) / self.cfg.tsdf_threshold
184 |             # ret["sdf"] = sdfs.flatten().astype(np.float32).clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold)
185 |             ret["sdf"] = sdfs.flatten().astype(np.float32)
186 |             # ret["sdf"] = sdfs[ind2].flatten().astype(np.float32)
187 |             ret['surface_normal'] = normal3
188 |         elif self.cfg.supervision_type == "occupancy":
189 |             # ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(np.float32)
190 |             ret["occupancies"] = np.where(sdfs.flatten() < 0, 0, 1).astype(np.float32)
191 |         else:
192 |             raise NotImplementedError(f"Supervision type {self.cfg.supervision_type} not implemented")
193 | 
194 |         return ret
195 |
196 |     def _load_image(self, index: int) -> Dict[str, Any]:
197 |         name = self.uids[index][1].split('_')[1]
198 |         file_path = self.mapping_dict[name]
199 |         # image_paths=os.path.join(images_root,file_path,file_path,name)
200 |         def _load_single_image(img_path):
201 |             img = torch.from_numpy(
202 |                 np.asarray(
203 |                     Image.fromarray(imageio.v2.imread(img_path))
204 |                     .convert("RGBA")
205 |                 )
206 |                 / 255.0
207 |             ).float()
208 |             mask: Float[Tensor, "H W 1"] = img[:, :, -1:]
209 |             image: Float[Tensor, "H W 3"] = img[:, :, :3] * mask + self.background_color[
210 |                 None, None, :
211 |             ] * (1 - mask)
212 |             return image
213 |         ret = {}
214 |         if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal":
215 |             assert self.cfg.n_views == 1, "Only a single view is supported for image conditioning"
216 |             sel_idx = random.choice(self.cfg.idx)
217 |             ret["sel_image_idx"] = sel_idx
218 | 
219 | 
220 |             img_path = file_path
221 |             ret["image"] = _load_single_image(img_path)
222 |         else:
223 |             raise NotImplementedError(f"Image type {self.cfg.image_type} not implemented")
224 | 
225 |         return ret
226 |
227 |     def _load_caption(self, index: int, drop_text_embed: bool = False) -> Dict[str, Any]:
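        # NOTE (assumption): this branch is not called from get_data() below; as written it
        # expects string uids (the (root_dir, file) tuples built above would break the
        # split('/')) and a `self.tokenizer` attribute that is not set in the __init__ shown here.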
228 |         ret = {}
229 |         if self.cfg.caption_type == "text":
230 |             caption = eval(json.load(open(f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f'/annotation.json')))
231 |             texts = [v for k, v in caption.items()]
232 |             sel_idx = random.randint(0, len(texts) - 1)
233 |             ret["sel_caption_idx"] = sel_idx
234 |             ret['text_input_ids'] = self.tokenizer(
235 |                 texts[sel_idx] if not drop_text_embed else "",
236 |                 max_length=self.tokenizer.model_max_length,
237 |                 padding="max_length",
238 |                 truncation=True,
239 |                 return_tensors="pt"
240 |             ).input_ids.detach()
241 |         else:
242 |             raise NotImplementedError(f"Caption type {self.cfg.caption_type} not implemented")
243 | 
244 |         return ret
245 |
246 |     def get_data(self, index):
247 |         # load shape
248 |         ret = self._load_shape(index)
249 | 
250 | 
251 |         if self.cfg.load_image:
252 |             ret.update(self._load_image(index))
253 |         return ret
254 |
255 |     def __getitem__(self, index):
256 |         try:
257 |             return self.get_data(index)
258 |         except Exception as e:
259 |             print(f"Error in {self.uids[index]}: {e}")
260 |             return self.__getitem__(np.random.randint(len(self)))
261 |
262 |
263 |     def collate(self, batch):
264 |         batch = torch.utils.data.default_collate(batch)
265 |         return batch
266 |
267 |
268 |
269 | @register("objaverse-datamodule")
270 | class ObjaverseDataModule(pl.LightningDataModule):
271 |     cfg: ObjaverseDataModuleConfig
272 | 
273 |     def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
274 |         super().__init__()
275 |         self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg)
276 | 
277 |     def setup(self, stage=None) -> None:
278 |         if stage in [None, "fit"]:
279 |             self.train_dataset = ObjaverseDataset(self.cfg, "train")
280 |         if stage in [None, "fit", "validate"]:
281 |             self.val_dataset = ObjaverseDataset(self.cfg, "val")
282 |         if stage in [None, "test", "predict"]:
283 |             self.test_dataset = ObjaverseDataset(self.cfg, "test")
284 | 
285 |     def prepare_data(self):
286 |         pass
287 |
288 |     def general_loader(self, dataset, batch_size, shuffle=False, collate_fn=None, num_workers=0) -> DataLoader:
289 |         return DataLoader(
290 |             dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers
291 |         )
292 | 
293 |     def train_dataloader(self) -> DataLoader:
294 |         return self.general_loader(
295 |             self.train_dataset,
296 |             batch_size=self.cfg.batch_size,
297 |             shuffle=True,
298 |             collate_fn=self.train_dataset.collate,
299 |             num_workers=self.cfg.num_workers
300 |         )
301 | 
302 |     def val_dataloader(self) -> DataLoader:
303 |         return self.general_loader(self.val_dataset, batch_size=1)
304 | 
305 |     def test_dataloader(self) -> DataLoader:
306 |         return self.general_loader(self.test_dataset, batch_size=1)
307 | 
308 |     def predict_dataloader(self) -> DataLoader:
309 |         return self.general_loader(self.test_dataset, batch_size=1)
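

# Minimal usage sketch (added for illustration). The config keys below mirror fields the
# dataset above actually reads -- root_dir, background_color, load_image, supervision_type,
# n_supervision, batch_size, num_workers -- while the path is a placeholder and any
# remaining ObjaverseDataModuleConfig fields are assumed to have defaults.
if __name__ == "__main__":
    dm = ObjaverseDataModule({
        "root_dir": ["/path/to/preprocessed/objaverse"],  # hypothetical folder of per-object .npz dirs
        "background_color": [1.0, 1.0, 1.0],
        "load_image": False,                               # skip the csv mapping / image branch
        "supervision_type": "occupancy",
        "n_supervision": 4096,
        "batch_size": 2,
        "num_workers": 0,
    })
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    # expected keys: "uid", "surface", "rand_points", "occupancies"
    print({k: getattr(v, "shape", v) for k, v in batch.items()})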
--------------------------------------------------------------------------------
/mar3d/models/conditional_encoders/clip/modeling_conditional_clip.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # Reference:
16 | # * transformers/models/dinov2/modeling_dinov2.py
17 | # * https://github.com/facebookresearch/DiT/blob/main/models.py#L101
18 | # * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2
19 | """ PyTorch CLIP model."""
20 |
21 | from typing import Dict, List, Optional, Set, Tuple, Union
22 |
23 | import torch
24 | import torch.nn as nn
25 |
26 | from .modeling_clip import (
27 | CLIPConfig,
28 | CLIPTextConfig,
29 | CLIPVisionConfig,
30 | CLIPEncoderLayer,
31 | CLIPTextTransformer,
32 | CLIPVisionTransformer,
33 | CLIPModel,
34 | CLIPVisionEmbeddings,
35 | CLIPVisionModel,
36 | CLIPOutput,
37 | BaseModelOutput,
38 |     BaseModelOutputWithPooling, clip_loss  # clip_loss is needed by forward() when return_loss=True
39 | )
40 |
41 |
42 | class ModLN(nn.Module):
43 | def __init__(self, inner_dim: int, mod_dim: int = 32):
44 | super().__init__()
45 | self.mlp = nn.Sequential(
46 | nn.SiLU(),
47 | nn.Linear(mod_dim, inner_dim * 2),
48 | )
49 |
50 | for m in self.modules():
51 | if isinstance(m, nn.Linear):
52 | nn.init.zeros_(m.weight)
53 | nn.init.zeros_(m.bias)
54 |
55 |     def forward(self, x: torch.Tensor, condition: torch.Tensor):
56 | '''
57 | x: [N, M, C_in], M: num of tokens
58 | condition: [N, C_mod]
59 | '''
60 | shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
61 | return x * (1 + scale) + shift
62 |
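# Shape sketch (hypothetical sizes): ModLN(inner_dim=768, mod_dim=32) maps x: [N, M, 768]
# with condition: [N, 32] to [N, M, 768]. Because the Linear above is zero-initialized,
# scale and shift start at 0, so the modulation begins as the identity map.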
63 |
64 | class ConditionalCLIPVisionConfig(CLIPVisionConfig):
65 | def __init__(self, modulation_dim: int = 32, *args, **kwargs):
66 | super().__init__(*args, **kwargs)
67 | self.modulation_dim = modulation_dim
68 |
69 |
70 | class ConditionalCLIPEncoderLayer(CLIPEncoderLayer):
71 | """This corresponds to the Block class in the original implementation."""
72 |
73 | def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
74 | super().__init__(config)
75 | self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
76 | self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)
77 |
78 | def forward(
79 | self,
80 | hidden_states: torch.Tensor,
81 | attention_mask: torch.Tensor,
82 | causal_attention_mask: torch.Tensor,
83 | condition: Optional[torch.Tensor] = None,
84 | output_attentions: bool = False,
85 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
86 | residual = hidden_states
87 |
88 | hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition)
89 | hidden_states, attn_weights = self.self_attn(
90 | hidden_states=hidden_states,
91 | attention_mask=attention_mask,
92 | causal_attention_mask=causal_attention_mask,
93 | output_attentions=output_attentions,
94 | )
95 | hidden_states = residual + hidden_states
96 |
97 | residual = hidden_states
98 | hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition)
99 | hidden_states = self.mlp(hidden_states)
100 | hidden_states = residual + hidden_states
101 |
102 | outputs = (hidden_states,)
103 |
104 | if output_attentions:
105 | outputs += (attn_weights,)
106 |
107 | return outputs
108 |
109 |
110 | class ConditionalCLIPEncoder(nn.Module):
111 | def __init__(self, config: CLIPConfig) -> None:
112 | super().__init__()
113 | self.config = config
114 | self.layers = nn.ModuleList([ConditionalCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
115 | self.gradient_checkpointing = False
116 |
117 | def forward(
118 | self,
119 | inputs_embeds,
120 | attention_mask: Optional[torch.Tensor] = None,
121 | causal_attention_mask: Optional[torch.Tensor] = None,
122 | output_attentions: Optional[bool] = None,
123 | output_hidden_states: Optional[bool] = None,
124 | condition: Optional[torch.Tensor] = None,
125 | return_dict: Optional[bool] = None,
126 | ) -> Union[tuple, BaseModelOutput]:
127 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
128 | output_hidden_states = (
129 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
130 | )
131 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
132 |
133 | encoder_states = () if output_hidden_states else None
134 | all_attentions = () if output_attentions else None
135 |
136 | hidden_states = inputs_embeds
137 | for idx, encoder_layer in enumerate(self.layers):
138 | if output_hidden_states:
139 | encoder_states = encoder_states + (hidden_states,)
140 | if self.gradient_checkpointing and self.training:
141 | layer_outputs = self._gradient_checkpointing_func(
142 | encoder_layer.__call__,
143 | hidden_states,
144 | attention_mask,
145 | causal_attention_mask,
146 | condition=condition,
147 | output_attentions=output_attentions,
148 | )
149 | else:
150 | layer_outputs = encoder_layer(
151 | hidden_states,
152 | attention_mask,
153 | causal_attention_mask,
154 | condition=condition,
155 | output_attentions=output_attentions,
156 | )
157 |
158 | hidden_states = layer_outputs[0]
159 |
160 | if output_attentions:
161 | all_attentions = all_attentions + (layer_outputs[1],)
162 |
163 | if output_hidden_states:
164 | encoder_states = encoder_states + (hidden_states,)
165 |
166 | if not return_dict:
167 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
168 | return BaseModelOutput(
169 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
170 | )
171 |
172 |
173 | class ConditionalCLIPVisionTransformer(CLIPVisionTransformer):
174 | def __init__(self, config: ConditionalCLIPVisionConfig):
175 | super().__init__(config)
176 | self.config = config
177 | embed_dim = config.hidden_size
178 |
179 | self.embeddings = CLIPVisionEmbeddings(config)
180 | self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
181 | self.encoder = ConditionalCLIPEncoder(config)
182 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
183 |
184 | def forward(
185 | self,
186 | pixel_values: Optional[torch.FloatTensor] = None,
187 | condition: Optional[torch.Tensor] = None,
188 | output_attentions: Optional[bool] = None,
189 | output_hidden_states: Optional[bool] = None,
190 | return_dict: Optional[bool] = None,
191 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
192 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
193 | output_hidden_states = (
194 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
195 | )
196 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
197 |
198 | if pixel_values is None:
199 | raise ValueError("You have to specify pixel_values")
200 |
201 | hidden_states = self.embeddings(pixel_values)
202 | hidden_states = self.pre_layrnorm(hidden_states)
203 |
204 | encoder_outputs = self.encoder(
205 | inputs_embeds=hidden_states,
206 | output_attentions=output_attentions,
207 | output_hidden_states=output_hidden_states,
208 | condition=condition,
209 | return_dict=return_dict,
210 | )
211 |
212 | last_hidden_state = encoder_outputs[0]
213 | pooled_output = last_hidden_state[:, 0, :]
214 | pooled_output = self.post_layernorm(pooled_output)
215 |
216 | if not return_dict:
217 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
218 |
219 | return BaseModelOutputWithPooling(
220 | last_hidden_state=last_hidden_state,
221 | pooler_output=pooled_output,
222 | hidden_states=encoder_outputs.hidden_states,
223 | attentions=encoder_outputs.attentions,
224 | )
225 |
226 |
227 | class ConditionalCLIPVisionModel(CLIPVisionModel):
228 | config_class = ConditionalCLIPVisionConfig
229 |
230 | def __init__(self, config: ConditionalCLIPVisionConfig):
231 | super().__init__(config)
232 | self.vision_model = ConditionalCLIPVisionTransformer(config)
233 | # Initialize weights and apply final processing
234 | self.post_init()
235 |
236 | def forward(
237 | self,
238 | pixel_values: Optional[torch.FloatTensor] = None,
239 | condition: Optional[torch.Tensor] = None,
240 | output_attentions: Optional[bool] = None,
241 | output_hidden_states: Optional[bool] = None,
242 | return_dict: Optional[bool] = None,
243 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
244 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
245 |
246 | return self.vision_model(
247 | pixel_values=pixel_values,
248 | condition=condition,
249 | output_attentions=output_attentions,
250 | output_hidden_states=output_hidden_states,
251 | return_dict=return_dict,
252 | )
253 |
254 |
255 | class ConditionalCLIPModel(CLIPModel):
256 | config_class = CLIPConfig
257 |
258 | def __init__(self, config: CLIPConfig):
259 | super().__init__(config)
260 |
261 | if not isinstance(config.text_config, CLIPTextConfig):
262 | raise ValueError(
263 | "config.text_config is expected to be of type CLIPTextConfig but is of type"
264 | f" {type(config.text_config)}."
265 | )
266 |
267 | if not isinstance(config.vision_config, CLIPVisionConfig):
268 | raise ValueError(
269 | "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
270 | f" {type(config.vision_config)}."
271 | )
272 |
273 | text_config = config.text_config
274 | vision_config = config.vision_config
275 |
276 | self.projection_dim = config.projection_dim
277 | self.text_embed_dim = text_config.hidden_size
278 | self.vision_embed_dim = vision_config.hidden_size
279 |
280 | self.text_model = CLIPTextTransformer(text_config)
281 | self.vision_model = ConditionalCLIPVisionTransformer(vision_config)
282 |
283 | self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
284 | self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
285 | self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
286 |
287 | # Initialize weights and apply final processing
288 | self.post_init()
289 |
290 | def get_image_features(
291 | self,
292 | pixel_values: Optional[torch.FloatTensor] = None,
293 | condition: Optional[torch.Tensor] = None,
294 | output_attentions: Optional[bool] = None,
295 | output_hidden_states: Optional[bool] = None,
296 | return_dict: Optional[bool] = None,
297 | ) -> torch.FloatTensor:
298 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
299 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
300 | output_hidden_states = (
301 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
302 | )
303 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
304 |
305 | vision_outputs = self.vision_model(
306 | pixel_values=pixel_values,
307 | condition=condition,
308 | output_attentions=output_attentions,
309 | output_hidden_states=output_hidden_states,
310 | return_dict=return_dict,
311 | )
312 |
313 | pooled_output = vision_outputs[1] # pooled_output
314 | image_features = self.visual_projection(pooled_output)
315 |
316 | return image_features
317 |
318 | def forward(
319 | self,
320 | input_ids: Optional[torch.LongTensor] = None,
321 | pixel_values: Optional[torch.FloatTensor] = None,
322 | condition: Optional[torch.Tensor] = None,
323 | attention_mask: Optional[torch.Tensor] = None,
324 | position_ids: Optional[torch.LongTensor] = None,
325 | return_loss: Optional[bool] = None,
326 | output_attentions: Optional[bool] = None,
327 | output_hidden_states: Optional[bool] = None,
328 | return_dict: Optional[bool] = None,
329 | ) -> Union[Tuple, CLIPOutput]:
330 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
331 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
332 | output_hidden_states = (
333 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
334 | )
335 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
336 |
337 | vision_outputs = self.vision_model(
338 | pixel_values=pixel_values,
339 | condition=condition,
340 | output_attentions=output_attentions,
341 | output_hidden_states=output_hidden_states,
342 | return_dict=return_dict,
343 | )
344 |
345 | text_outputs = self.text_model(
346 | input_ids=input_ids,
347 | attention_mask=attention_mask,
348 | position_ids=position_ids,
349 | output_attentions=output_attentions,
350 | output_hidden_states=output_hidden_states,
351 | return_dict=return_dict,
352 | )
353 |
354 | image_embeds = vision_outputs[1]
355 | image_embeds = self.visual_projection(image_embeds)
356 |
357 | text_embeds = text_outputs[1]
358 | text_embeds = self.text_projection(text_embeds)
359 |
360 | # normalized features
361 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
362 | text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
363 |
364 | # cosine similarity as logits
365 | logit_scale = self.logit_scale.exp()
366 | logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
367 | logits_per_image = logits_per_text.t()
368 |
369 | loss = None
370 | if return_loss:
371 | loss = clip_loss(logits_per_text)
372 |
373 | if not return_dict:
374 | output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
375 | return ((loss,) + output) if loss is not None else output
376 |
377 | return CLIPOutput(
378 | loss=loss,
379 | logits_per_image=logits_per_image,
380 | logits_per_text=logits_per_text,
381 | text_embeds=text_embeds,
382 | image_embeds=image_embeds,
383 | text_model_output=text_outputs,
384 | vision_model_output=vision_outputs,
385 | )
386 |
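# Illustrative usage sketch (added; not part of the original file). Shapes and sizes are
# assumptions: the per-sample `condition` vector is injected into every encoder layer via
# the zero-initialized ModLN blocks, so a freshly initialized model matches the plain encoder.
if __name__ == "__main__":
    config = ConditionalCLIPVisionConfig(modulation_dim=32)
    model = ConditionalCLIPVisionModel(config)
    pixel_values = torch.randn(2, 3, config.image_size, config.image_size)
    condition = torch.randn(2, config.modulation_dim)
    out = model(pixel_values=pixel_values, condition=condition)
    print(out.last_hidden_state.shape)  # [2, num_patches + 1, hidden_size]
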
--------------------------------------------------------------------------------
/mar3d/systems/mar_diffusion.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import os
3 | import numpy as np
4 | import json
5 | import copy
6 | import torch
7 | import torch.nn.functional as F
8 | from skimage import measure
9 | from einops import repeat
10 | from tqdm import tqdm
11 | from PIL import Image
12 |
13 | from transformers import CLIPImageProcessor, CLIPTokenizer
14 | from diffusers import (
15 | DDPMScheduler,
16 | DDIMScheduler,
17 | UniPCMultistepScheduler,
18 | KarrasVeScheduler,
19 | DPMSolverMultistepScheduler
20 | )
21 | import scipy.stats as stats
22 |
23 | from plyfile import PlyData, PlyElement
24 | import mar3d
25 | from mar3d.systems.base import BaseSystem
26 | from mar3d.utils.ops import generate_dense_grid_points
27 | from mar3d.utils.typing import *
28 | import torchvision
29 | import pandas as pd
30 | import math
31 | from timm.models.vision_transformer import Block
32 | from mar3d.systems.diffloss import DiffLoss
33 | import torch.nn as nn
34 | from torch.utils.checkpoint import checkpoint  # used by the grad_checkpointing branches below
35 |
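# `mask_by_order` is called by MAR.sample_tokens below but is not defined or imported in
# this file; the following is a minimal sketch following the public MAR reference
# implementation (assumed semantics: mask the first `mask_len` positions of each order).
def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len).cuda()
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
                            src=torch.ones(bsz, seq_len).cuda()).bool()
    return masking
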
36 | class MAR(nn.Module):
37 | """ Masked Autoencoder with VisionTransformer backbone
38 | """
39 | def __init__(self, latent_size=256, patch_size=1,
40 | encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
41 | decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
42 | mlp_ratio=4., norm_layer=nn.LayerNorm,
43 | vae_embed_dim=64,
44 | mask_ratio_min=0.7,
45 | label_drop_prob=0.1,
46 | attn_dropout=0.1,
47 | proj_dropout=0.1,
48 | buffer_size=256,
49 | diffloss_d=6,
50 | diffloss_w=1024,
51 | num_sampling_steps='100',
52 | diffusion_batch_mul=4,
53 | grad_checkpointing=False,
54 | ):
55 | super().__init__()
56 |
57 | # --------------------------------------------------------------------------
58 | # VAE and patchify specifics
59 | self.vae_embed_dim = vae_embed_dim
60 | # self.vae_stride = vae_stride
61 | self.patch_size = patch_size
62 | # self.seq_h = self.seq_w = img_size // vae_stride // patch_size
63 |         self.seq_len = latent_size // patch_size
64 | self.token_embed_dim = vae_embed_dim * patch_size
65 | self.grad_checkpointing = grad_checkpointing
66 |
67 |
68 | # --------------------------------------------------------------------------
69 | # Class Embedding
70 | # self.num_classes = class_num
71 | # self.class_emb = nn.Embedding(1000, encoder_embed_dim)
72 | self.label_drop_prob = label_drop_prob
73 | # Fake class embedding for CFG's unconditional generation
74 | self.fake_latent = nn.Parameter(torch.zeros(buffer_size, encoder_embed_dim))
75 |
76 |
77 | # --------------------------------------------------------------------------
78 | # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
79 | self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
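        # scipy's truncnorm takes standardized bounds, so this draws mask ratios from
        # N(1.0, 0.25^2) truncated to the interval [mask_ratio_min, 1.0]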
80 |
81 | # --------------------------------------------------------------------------
82 | # MAR encoder specifics
83 | self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
84 | self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
85 | self.buffer_size = buffer_size
86 | self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
87 |
88 | self.encoder_blocks = nn.ModuleList([
89 | Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
90 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
91 |         self.encoder_norm = norm_layer(encoder_embed_dim)
92 |
93 | # --------------------------------------------------------------------------
94 | # MAR decoder specifics
95 | self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
96 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
97 | self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
98 |
99 | self.decoder_blocks = nn.ModuleList([
100 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
101 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
102 |
103 | self.decoder_norm = norm_layer(decoder_embed_dim)
104 | self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
105 |
106 | self.initialize_weights()
107 |
108 | # --------------------------------------------------------------------------
109 | # Diffusion Loss
110 | self.diffloss = DiffLoss(
111 | target_channels=self.token_embed_dim,
112 | z_channels=decoder_embed_dim,
113 | width=diffloss_w,
114 | depth=diffloss_d,
115 | num_sampling_steps=num_sampling_steps,
116 | grad_checkpointing=grad_checkpointing
117 | )
118 | self.diffusion_batch_mul = diffusion_batch_mul
119 |
120 | def initialize_weights(self):
121 | # parameters
122 | # torch.nn.init.normal_(self.class_emb.weight, std=.02)
123 | torch.nn.init.normal_(self.fake_latent, std=.02)
124 |
125 | torch.nn.init.normal_(self.mask_token, std=.02)
126 | torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
127 | torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
128 | torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
129 |
130 | # initialize nn.Linear and nn.LayerNorm
131 | self.apply(self._init_weights)
132 |
133 | def _init_weights(self, m):
134 | if isinstance(m, nn.Linear):
135 | # we use xavier_uniform following official JAX ViT:
136 | torch.nn.init.xavier_uniform_(m.weight)
137 | if isinstance(m, nn.Linear) and m.bias is not None:
138 | nn.init.constant_(m.bias, 0)
139 | elif isinstance(m, nn.LayerNorm):
140 | if m.bias is not None:
141 | nn.init.constant_(m.bias, 0)
142 | if m.weight is not None:
143 | nn.init.constant_(m.weight, 1.0)
144 | def sample_orders(self, bsz):
145 | # generate a batch of random generation orders
146 | orders = []
147 | for _ in range(bsz):
148 | order = np.array(list(range(self.seq_len)))
149 | np.random.shuffle(order)
150 | orders.append(order)
151 | orders = torch.Tensor(np.array(orders)).cuda().long()
152 | return orders
153 |
154 | def random_masking(self, x, orders):
155 | # generate token mask
156 | bsz, seq_len, embed_dim = x.shape
157 | mask_rate = self.mask_ratio_generator.rvs(1)[0]
158 | num_masked_tokens = int(np.ceil(seq_len * mask_rate))
159 | mask = torch.zeros(bsz, seq_len, device=x.device)
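        # sets mask[b, orders[b, j]] = 1 for the first num_masked_tokens entries of each
        # per-sample random order, i.e. the tokens earliest in that order are masked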
160 | mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
161 | src=torch.ones(bsz, seq_len, device=x.device))
162 | return mask
163 |
164 | def forward_mae_encoder(self, x, mask, class_embedding):
165 |
166 |
167 | x = self.z_proj(x)
168 | bsz, seq_len, embed_dim = x.shape
169 |
170 | # concat buffer
171 | x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
172 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
173 |
174 | # random drop class embedding during training
175 | if self.training:
176 |
177 | drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
178 |             drop_latent_mask = drop_latent_mask.to(x.device)
179 | drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(x.dtype)
180 |
181 |             class_embedding = drop_latent_mask[:, :, None] * self.fake_latent[None] + (1 - drop_latent_mask)[:, :, None] * class_embedding
182 |
183 |
184 | x[:, :self.buffer_size] = class_embedding
185 |
186 | # encoder position embedding
187 | x = x + self.encoder_pos_embed_learned
188 | x = self.z_proj_ln(x)
189 |
190 | # dropping
191 |         x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
192 |
193 | # apply Transformer blocks
194 | if self.grad_checkpointing and not torch.jit.is_scripting():
195 | for block in self.encoder_blocks:
196 | x = checkpoint(block, x)
197 | else:
198 | for block in self.encoder_blocks:
199 | x = block(x)
200 | x = self.encoder_norm(x)
201 |
202 | return x
203 |
204 | def forward_mae_decoder(self, x, mask):
205 |
206 | x = self.decoder_embed(x)
207 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
208 |
209 | # pad mask tokens
210 | mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
211 | x_after_pad = mask_tokens.clone()
212 | x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
213 |
214 | # decoder position embedding
215 | x = x_after_pad + self.decoder_pos_embed_learned
216 |
217 | # apply Transformer blocks
218 | if self.grad_checkpointing and not torch.jit.is_scripting():
219 | for block in self.decoder_blocks:
220 | x = checkpoint(block, x)
221 | else:
222 | for block in self.decoder_blocks:
223 | x = block(x)
224 | x = self.decoder_norm(x)
225 |
226 | x = x[:, self.buffer_size:]
227 | x = x + self.diffusion_pos_embed_learned
228 | return x
229 |
230 | def forward_loss(self, z, target, mask):
231 | bsz, seq_len, _ = target.shape
232 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
233 | z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
234 | mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
235 | loss = self.diffloss(z=z, target=target, mask=mask)
236 | return loss
237 |
238 |     def forward(self, x, labels, low_res=None):
239 |
240 |
241 | gt_latents = x.clone().detach()
242 | orders = self.sample_orders(bsz=x.size(0))
243 | mask = self.random_masking(x, orders)
244 |
245 | # mae encoder
246 |         x = self.forward_mae_encoder(x, mask, labels)
247 |
248 | # mae decoder
249 | z = self.forward_mae_decoder(x, mask)
250 | # print(z.shape)
251 | # diffloss
252 | loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
253 |
254 | return loss
255 |
256 | def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=True):
257 |
258 | # init and sample generation orders
259 | mask = torch.ones(bsz, self.seq_len).cuda()
260 | tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
261 | orders = self.sample_orders(bsz)
262 |
263 | indices = list(range(num_iter))
264 | if progress:
265 | indices = tqdm(indices)
266 | # generate latents
267 | for step in indices:
268 | cur_tokens = tokens.clone()
269 |             if cfg != 1.0:
270 | # ipdb.set_trace()
271 | tokens = torch.cat([tokens, tokens], dim=0)
272 |                 class_embedding = torch.cat([labels, self.fake_latent[None].repeat(bsz, 1, 1)], dim=0)
273 | mask = torch.cat([mask, mask], dim=0)
274 | # mae encoder
275 | x = self.forward_mae_encoder(tokens, mask, class_embedding)
276 |
277 | # mae decoder
278 | z = self.forward_mae_decoder(x, mask)
279 | # ipdb.set_trace()
280 |
281 | # mask ratio for the next round, following MaskGIT and MAGE.
282 | mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
283 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
284 |
285 | # masks out at least one for the next iteration
286 | mask_len = torch.maximum(torch.Tensor([1]).cuda(),
287 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
288 |
289 | # get masking for next iteration and locations to be predicted in this iteration
290 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
291 | if step >= num_iter - 1:
292 | mask_to_pred = mask[:bsz].bool()
293 | else:
294 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
295 | mask = mask_next
296 |             if cfg != 1.0:
297 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
298 |
299 | # sample token latents for this step
300 | z = z[mask_to_pred.nonzero(as_tuple=True)]
301 | # cfg schedule follow Muse
302 | if cfg_schedule == "linear":
303 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
304 | elif cfg_schedule == "constant":
305 | cfg_iter = cfg
306 | else:
307 | raise NotImplementedError
308 | # ipdb.set_trace()
309 | # print(z.shape)
310 | sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
311 |             if cfg != 1.0:
312 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
313 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
314 |
315 | cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
316 | tokens = cur_tokens.clone()
317 |
318 | return tokens
319 |
320 |
321 | @mar3d.register("mar-diffusion-system")
322 | class ShapeDiffusionSystem(BaseSystem):
323 | @dataclass
324 | class Config(BaseSystem.Config):
325 | data_type: str = None
326 | val_samples_json: str = None
327 | # shape vae model
328 | shape_model_type: str = None
329 | # condition model
330 | shape_model: dict = field(default_factory=dict)
331 | condition_model_type: str = None
332 | condition_model: dict = field(default_factory=dict)
333 |
334 | cfg: Config
335 |
336 | def configure(self):
337 | super().configure()
338 | self.shape_model = mar3d.find(self.cfg.shape_model_type)(self.cfg.shape_model)
339 | self.shape_model.eval()
340 | self.sigma_min = 0.0
341 | self.condition = mar3d.find(self.cfg.condition_model_type)(self.cfg.condition_model)
342 | self.condition.requires_grad_(False)
343 |         self.mar = MAR(patch_size=1, latent_size=256, buffer_size=257)
344 |
345 | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
346 |
347 | shape_embeds, kl_embed, posterior = self.shape_model.encode(
348 | batch["surface"][..., :3 + self.cfg.shape_model.point_feats],
349 | sample_posterior=True)
350 |
351 |         latents = kl_embed
352 | cond_latents = self.condition(batch).to(latents.device)
353 |
354 |         loss = self.mar(latents, cond_latents)
355 | return {
356 | "loss_diffusion": loss,
357 | "latents": latents,
358 | }
359 |
360 | def training_step(self, batch, batch_idx):
361 | out = self(batch)
362 | loss = 0.
363 | for name, value in out.items():
364 | if name.startswith("loss_"):
365 | self.log(f"train/{name}", value)
366 | loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
367 |
368 | for name, value in self.cfg.loss.items():
369 | if name.startswith("lambda_"):
370 | self.log(f"train_params/{name}", self.C(value))
371 |
372 | return {"loss": loss}
373 |
374 | @torch.no_grad()
375 | def validation_step(self, batch, batch_idx):
376 | self.eval()
377 |         oct_num = 9
378 | with torch.no_grad():
379 |
380 | cond = self.condition.encode_image(batch['image'])
381 |
382 |
383 |             latents = self.mar.sample_tokens(bsz=cond.shape[0], num_iter=64, cfg=3.0,
384 |                                              cfg_schedule="linear", labels=cond, temperature=1.0)
385 |             output = self.shape_model.decode(latents)
386 |
387 | mesh_v_f, has_surface = self.shape_model.extract_geometry(output, octree_depth=oct_num)
388 | for j in range(len(mesh_v_f)):
389 | self.save_mesh(
390 | f"it{self.true_global_step}/{batch['uid'][0]}_cfg{3.0}_oct{oct_num}.obj",
391 | mesh_v_f[j][0], mesh_v_f[j][1]
392 | )
393 |
394 |             self.save_image(f"it{self.true_global_step}/{batch['uid'][0]}_gt.jpg", (batch['image'] * 255).int())
395 |
396 | torch.cuda.empty_cache()
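
# How the pieces above are typically wired together (a sketch, not code from this repo;
# the registry names follow the decorators used above, and the Trainer setup is an assumption):
#   system = mar3d.find("mar-diffusion-system")(cfg.system)   # ShapeDiffusionSystem
#   data = ObjaverseDataModule(cfg.data)                       # or looked up via the same registry
#   trainer = pl.Trainer(...)
#   trainer.fit(system, datamodule=data)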
--------------------------------------------------------------------------------