├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── assets ├── overview_simplified.png └── panorama_pipeline.png ├── baselines ├── da_v2.py ├── da_v2_metric.py ├── metric3d_v2.py └── moge.py ├── configs ├── eval │ ├── all_benchmarks.json │ └── benchmarks │ │ ├── ddad.json │ │ ├── diode.json │ │ ├── eth3d.json │ │ ├── gso.json │ │ ├── hammer.json │ │ ├── ibims-1.json │ │ ├── kitti.json │ │ ├── nyu.json │ │ ├── sintel.json │ │ └── spring.json └── train │ └── v1.json ├── docs ├── eval.md └── train.md ├── example_images ├── BooksCorridor.png ├── Braunschweig_Panoram.jpg ├── BunnyCake.jpg └── MaitreyaBuddha.png ├── moge ├── __init__.py ├── model │ ├── __init__.py │ ├── dinov2 │ │ ├── __init__.py │ │ ├── hub │ │ │ ├── __init__.py │ │ │ ├── backbones.py │ │ │ └── utils.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── block.py │ │ │ ├── dino_head.py │ │ │ ├── drop_path.py │ │ │ ├── layer_scale.py │ │ │ ├── mlp.py │ │ │ ├── patch_embed.py │ │ │ └── swiglu_ffn.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── vision_transformer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── cluster.py │ │ │ ├── config.py │ │ │ ├── dtype.py │ │ │ ├── param_groups.py │ │ │ └── utils.py │ ├── utils.py │ └── v1.py ├── scripts │ ├── __init__.py │ ├── app.py │ ├── cli.py │ ├── eval_baseline.py │ ├── infer.py │ ├── infer_baseline.py │ ├── infer_panorama.py │ ├── train.py │ └── vis_data.py ├── test │ ├── __init__.py │ ├── baseline.py │ ├── dataloader.py │ └── metrics.py ├── train │ ├── __init__.py │ ├── dataloader.py │ ├── losses.py │ └── utils.py └── utils │ ├── __init__.py │ ├── alignment.py │ ├── download.py │ ├── geometry_numpy.py │ ├── geometry_torch.py │ ├── io.py │ ├── panorama.py │ ├── pipeline.py │ ├── tools.py │ ├── vis.py │ ├── webfile.py │ └── webzipfile.py ├── pyproject.toml ├── pyrightconfig.json └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) 
will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
298 | *.vbp 299 | 300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 301 | *.dsw 302 | *.dsp 303 | 304 | # Visual Studio 6 technical files 305 | *.ncb 306 | *.aps 307 | 308 | # Visual Studio LightSwitch build output 309 | **/*.HTMLClient/GeneratedArtifacts 310 | **/*.DesktopClient/GeneratedArtifacts 311 | **/*.DesktopClient/ModelManifest.xml 312 | **/*.Server/GeneratedArtifacts 313 | **/*.Server/ModelManifest.xml 314 | _Pvt_Extensions 315 | 316 | # Paket dependency manager 317 | .paket/paket.exe 318 | paket-files/ 319 | 320 | # FAKE - F# Make 321 | .fake/ 322 | 323 | # CodeRush personal settings 324 | .cr/personal 325 | 326 | # Python Tools for Visual Studio (PTVS) 327 | __pycache__/ 328 | *.pyc 329 | 330 | # Cake - Uncomment if you are using it 331 | # tools/** 332 | # !tools/packages.config 333 | 334 | # Tabs Studio 335 | *.tss 336 | 337 | # Telerik's JustMock configuration file 338 | *.jmconfig 339 | 340 | # BizTalk build output 341 | *.btp.cs 342 | *.btm.cs 343 | *.odx.cs 344 | *.xsd.cs 345 | 346 | # OpenCover UI analysis results 347 | OpenCover/ 348 | 349 | # Azure Stream Analytics local run output 350 | ASALocalRun/ 351 | 352 | # MSBuild Binary and Structured Log 353 | *.binlog 354 | 355 | # NVidia Nsight GPU debugger configuration file 356 | *.nvuser 357 | 358 | # MFractors (Xamarin productivity tool) working folder 359 | .mfractor/ 360 | 361 | # Local History for Visual Studio 362 | .localhistory/ 363 | 364 | # Visual Studio History (VSHistory) files 365 | .vshistory/ 366 | 367 | # BeatPulse healthcheck temp database 368 | healthchecksdb 369 | 370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 371 | MigrationBackup/ 372 | 373 | # Ionide (cross platform F# VS Code tools) working folder 374 | .ionide/ 375 | 376 | # Fody - auto-generated XML schema 377 | FodyWeavers.xsd 378 | 379 | # VS Code files for those working on multiple tools 380 | .vscode/* 381 | !.vscode/settings.json 382 | !.vscode/tasks.json 383 | !.vscode/launch.json 384 | !.vscode/extensions.json 385 | *.code-workspace 386 | 387 | # Local History for Visual Studio Code 388 | .history/ 389 | 390 | # Windows Installer files from build outputs 391 | *.cab 392 | *.msi 393 | *.msix 394 | *.msm 395 | *.msp 396 | 397 | # JetBrains Rider 398 | *.sln.iml 399 | 400 | # Python 401 | *.egg-info/ 402 | /build 403 | 404 | # MoGe 405 | /data* 406 | /download 407 | /extract 408 | /debug 409 | /workspace 410 | /mlruns 411 | /infer_output 412 | /video_output 413 | /eval_output 414 | /.blobcache 415 | /test_images 416 | /test_videos 417 | /vis 418 | /videos 419 | /blobmnt 420 | /eval_dump 421 | /pretrained 422 | /.gradio 423 | /tmp -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 2024-11-28 2 | ### Added 3 | - Supported user-provided camera FOV. See [scripts/infer.py](scripts/infer.py) --fov_x. 4 | - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24). 5 | - Added inference scripts for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py). 6 | - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19). 7 | 8 | ### Fixed 9 | - Suppressed unnecessary numpy runtime warnings. 10 | - Specified recommended versions of requirements. 11 | - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21). 
12 | 13 | ### Changed 14 | - Moved `app.py` and `infer.py` to [scripts/](scripts/) 15 | - Improved edge removal. 16 | 17 | ## 2025-03-18 18 | ### Added 19 | - Training and evaluation code. See [docs/train.md](docs/train.md) and [docs/eval.md](docs/eval.md). 20 | - Supported installation via pip. Thanks to @fabiencastan and @jgoueslard 21 | for commits in the [#47](https://github.com/microsoft/MoGe/pull/47) 22 | - Supported command-line usage when installed. 23 | 24 | ### Changed 25 | - Moved `scripts/` into `moge/` for package installation and command-line usage. 26 | - Renamed `moge.model.moge_model` to `moge.model.v1` for version management. 27 | Now you can import the model class through `from moge.model.v1 import MoGeModel` or `from moge.model import import_model_class_by_version; MoGeModel = import_model_class_by_version('v1')`. 28 | - Exposed `num_tokens` parameter in MoGe model. -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | # MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision
4 | 
5 | [arXiv](https://arxiv.org/abs/2410.19115)
6 | [Project Page](https://wangrc.site/MoGePage)
7 | 
8 | 
9 | 
10 | 11 | Method overview 12 | 13 | MoGe is a powerful model for recovering 3D geometry from monocular open-domain images. The model consists of a ViT encoder and a convolutional decoder. It directly predicts an affine-invariant point map as well as a mask that excludes regions with undefined geometry (e.g., sky), from which the camera shift, camera focal length and depth map can be further derived. 14 | 15 | ***Check our [website](https://wangrc.site/MoGePage) for videos and interactive results!*** 16 | 17 | ## Features 18 | 19 | * **Accurate 3D geometry estimation**: Estimate point maps from single images with high precision. Capable of capturing depth variations up to 1000×, ensuring a comprehensive scene representation. 20 | * **Optional ground-truth FOV input**: Enhance model accuracy further by providing the true field of view. 21 | * **Flexible resolution support**: Works seamlessly with various resolutions and aspect ratios, from 2:1 to 1:2. 22 | * **Optimized for speed**: Achieves <0.1s latency per image on an A100 / RTX 3090 GPU with fp16, and 0.2s with fp32. 23 | 24 | ## TODO List 25 | 26 | - [x] Release inference code & ViT-Large model. 27 | - [x] Release evaluation and training code. 28 | - [ ] Release ViT-Base and ViT-Giant models. 29 | 30 | 🌟*Updated on 2025/03/18* [CHANGELOG](CHANGELOG.md) 31 | - **Training and evaluation code released!** 32 | - Installation via pip and CLI usage supported. 33 | 34 | ## Installation 35 | 36 | ### Install via pip 37 | 38 | ```bash 39 | pip install git+https://github.com/microsoft/MoGe.git 40 | ``` 41 | 42 | ### Or clone this repository 43 | 44 | ```bash 45 | git clone https://github.com/microsoft/MoGe.git 46 | cd MoGe 47 | ``` 48 | 49 | and install the requirements 50 | 51 | ```bash 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | MoGe should be compatible with most requirements versions. Please check the `requirements.txt` for more details if you have concerns. 56 | 57 | ## Usage 58 | 59 | ### Pretrained model 60 | 61 | The ViT-Large model has been uploaded to Hugging Face hub at [Ruicheng/moge-vitl](https://huggingface.co/Ruicheng/moge-vitl). 62 | You may load the model via `MoGeModel.from_pretrained("Ruicheng/moge-vitl")` without manually downloading. 63 | 64 | If loading the model from a local file is preferred, you may manually download the model from the huggingface hub and load it via `MoGeModel.from_pretrained("PATH_TO_LOCAL_MODEL.pt")`. 65 | 66 | ### Minimal code example 67 | 68 | Here is a minimal example for loading the model and inferring on a single image. 69 | 70 | ```python 71 | import cv2 72 | import torch 73 | from moge.model.v1 import MoGeModel 74 | 75 | device = torch.device("cuda") 76 | 77 | # Load the model from huggingface hub (or load from local). 78 | model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) 79 | 80 | # Read the input image and convert to tensor (3, H, W) and normalize to [0, 1] 81 | input_image = cv2.cvtColor(cv2.imread("PATH_TO_IMAGE.jpg"), cv2.COLOR_BGR2RGB) 82 | input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=device).permute(2, 0, 1) 83 | 84 | # Infer 85 | output = model.infer(input_image) 86 | # `output` has keys "points", "depth", "mask" and "intrinsics", 87 | # The maps are in the same size as the input image. 88 | # { 89 | # "points": (H, W, 3), # scale-invariant point map in OpenCV camera coordinate system (x right, y down, z forward) 90 | # "depth": (H, W), # scale-invariant depth map 91 | # "mask": (H, W), # a binary mask for valid pixels. 
92 | # "intrinsics": (3, 3), # normalized camera intrinsics
93 | # }
94 | # For more usage details, see the `MoGeModel.infer` docstring.
95 | ```
96 | 
97 | ### Gradio demo | `moge app`
98 | 
99 | The demo is also available at our [Hugging Face space](https://huggingface.co/spaces/Ruicheng/MoGe).
100 | 
101 | ```bash
102 | # Using the command line tool
103 | moge app
104 | 
105 | # In this repo
106 | python moge/scripts/app.py # --share for Gradio public sharing
107 | ```
108 | 
109 | See also [`moge/scripts/app.py`](moge/scripts/app.py)
110 | 
111 | 
112 | ### Inference | `moge infer`
113 | 
114 | Run the script `moge/scripts/infer.py` via the following command:
115 | 
116 | ```bash
117 | # Save the output [maps], [glb] and [ply] files
118 | moge infer -i IMAGES_FOLDER_OR_IMAGE_PATH -o OUTPUT_FOLDER --maps --glb --ply
119 | 
120 | # Show the result in a window (requires pyglet < 2.0, e.g. pip install pyglet==1.5.29)
121 | moge infer -i IMAGES_FOLDER_OR_IMAGE_PATH -o OUTPUT_FOLDER --show
122 | ```
123 | 
124 | For detailed options, run `moge infer --help`:
125 | 
126 | ```
127 | Usage: moge infer [OPTIONS]
128 | 
129 |   Inference script for the MoGe model.
130 | 
131 | Options:
132 |   -i, --input PATH            Input image or folder path. "jpg" and "png" are
133 |                               supported.
134 |   --fov_x FLOAT               If camera parameters are known, set the
135 |                               horizontal field of view in degrees. Otherwise,
136 |                               MoGe will estimate it.
137 |   -o, --output PATH           Output folder path
138 |   --pretrained TEXT           Pretrained model name or path. Defaults to
139 |                               "Ruicheng/moge-vitl"
140 |   --device TEXT               Device name (e.g. "cuda", "cuda:0", "cpu").
141 |                               Defaults to "cuda"
142 |   --fp16                      Use fp16 precision for 2x faster inference.
143 |   --resize INTEGER            Resize the image(s) & output maps to a specific
144 |                               size. Defaults to None (no resizing).
145 |   --resolution_level INTEGER  An integer [0-9] for the resolution level for
146 |                               inference. Higher value means more tokens and
147 |                               the finer details will be captured, but
148 |                               inference can be slower. Defaults to 9. Note
149 |                               that it does not affect the output size, which
150 |                               is always the same as the input size.
151 |                               `resolution_level` actually controls
152 |                               `num_tokens`. See `num_tokens` for more details.
153 |   --num_tokens INTEGER        Number of tokens used for inference. An integer
154 |                               in the (suggested) range of `[1200, 2500]`.
155 |                               `resolution_level` will be ignored if
156 |                               `num_tokens` is provided. Default: None
157 |   --threshold FLOAT           Threshold for removing edges. Defaults to 0.03.
158 |                               Smaller value removes more edges. "inf" means no
159 |                               thresholding.
160 |   --maps                      Whether to save the output maps and FOV (image,
161 |                               depth, mask, points, fov).
162 |   --glb                       Whether to save the output as a .glb file. The
163 |                               color will be saved as a texture.
164 |   --ply                       Whether to save the output as a .ply file. The
165 |                               color will be saved as vertex colors.
166 |   --show                      Whether to show the output in a window. Note
167 |                               that this requires pyglet<2 installed as
168 |                               required by trimesh.
169 |   --help                      Show this message and exit.
170 | ```
171 | 
172 | See also [`moge/scripts/infer.py`](moge/scripts/infer.py)
173 | 
174 | ### 360° panorama images | `moge infer_panorama`
175 | 
176 | > *NOTE: This is an experimental extension of MoGe.*
177 | 
178 | The script will split the 360-degree panorama image into multiple perspective views and infer on each view separately.
179 | The output maps will be combined to produce a panorama depth map and point map.
180 | 181 | Note that the panorama image must have spherical parameterization (e.g., environment maps or equirectangular images). Other formats must be converted to spherical format before using this script. Run `moge infer_panorama --help` for detailed options. 182 | 183 | 184 |
185 | 186 | 187 | The photo is from [this URL](https://commons.wikimedia.org/wiki/Category:360%C2%B0_panoramas_with_equirectangular_projection#/media/File:Braunschweig_Sankt-%C3%84gidien_Panorama_02.jpg) 188 |
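For reference, the splitting step described above can be pictured with a short, self-contained sketch of the equirectangular-to-perspective remapping. This is illustrative only and is **not** the implementation used by [`moge/utils/panorama.py`](moge/utils/panorama.py); the view size, field of view and choice of viewing directions below are arbitrary.

```python
import cv2
import numpy as np

def equirectangular_to_perspective(pano: np.ndarray, yaw: float, pitch: float,
                                   fov_deg: float = 90.0, size: int = 512) -> np.ndarray:
    """Remap an equirectangular panorama to a pinhole (perspective) view. Illustrative sketch only."""
    h = w = size
    f = 0.5 * w / np.tan(np.deg2rad(fov_deg) / 2)  # focal length in pixels
    # Pixel grid -> unit camera rays (OpenCV convention: x right, y down, z forward)
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    rays = np.stack([(u - w / 2) / f, (v - h / 2) / f, np.ones_like(u, dtype=np.float64)], axis=-1)
    rays /= np.linalg.norm(rays, axis=-1, keepdims=True)
    # Rotate the rays to the requested viewing direction (yaw about y, then pitch about x)
    cy, sy, cp, sp = np.cos(yaw), np.sin(yaw), np.cos(pitch), np.sin(pitch)
    R = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]]) @ np.array([[1, 0, 0], [0, cp, -sp], [0, sp, cp]])
    rays = rays @ R.T
    # Rays -> longitude/latitude -> pixel coordinates in the equirectangular image
    lon = np.arctan2(rays[..., 0], rays[..., 2])        # [-pi, pi]
    lat = np.arcsin(np.clip(rays[..., 1], -1.0, 1.0))   # [-pi/2, pi/2]
    map_x = ((lon / (2 * np.pi) + 0.5) * pano.shape[1]).astype(np.float32)
    map_y = ((lat / np.pi + 0.5) * pano.shape[0]).astype(np.float32)
    # The longitude seam is ignored here for simplicity
    return cv2.remap(pano, map_x, map_y, cv2.INTER_LINEAR)

# Example: four horizontal views that could each be fed to MoGeModel.infer
# pano = cv2.imread("example_images/Braunschweig_Panoram.jpg")
# views = [equirectangular_to_perspective(pano, yaw, 0.0) for yaw in np.linspace(0, 2 * np.pi, 4, endpoint=False)]
```

The script automates this splitting, runs MoGe on each view, and combines the per-view outputs back into a panorama depth map and point map.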
189 | 190 | See also [`moge/scripts/infer_panorama.py`](moge/scripts/infer_panorama.py) 191 | 192 | ## Training & Finetuning 193 | 194 | See [docs/train.md](docs/train.md) 195 | 196 | ## Evaluation 197 | 198 | See [docs/eval.md](docs/eval.md) 199 | 200 | ## License 201 | 202 | MoGe code is released under the MIT license, except for DINOv2 code in `moge/model/dinov2` which is released by Meta AI under the Apache 2.0 license. 203 | See [LICENSE](LICENSE) for more details. 204 | 205 | 206 | ## Citation 207 | 208 | If you find our work useful in your research, we gratefully request that you consider citing our paper: 209 | 210 | ``` 211 | @misc{wang2024moge, 212 | title={MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision}, 213 | author={Wang, Ruicheng and Xu, Sicheng and Dai, Cassie and Xiang, Jianfeng and Deng, Yu and Tong, Xin and Yang, Jiaolong}, 214 | year={2024}, 215 | eprint={2410.19115}, 216 | archivePrefix={arXiv}, 217 | primaryClass={cs.CV}, 218 | url={https://arxiv.org/abs/2410.19115}, 219 | } 220 | ``` 221 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 
30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /assets/overview_simplified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/assets/overview_simplified.png -------------------------------------------------------------------------------- /assets/panorama_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/assets/panorama_pipeline.png -------------------------------------------------------------------------------- /baselines/da_v2.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/DepthAnything/Depth-Anything-V2 2 | import os 3 | import sys 4 | from typing import * 5 | from pathlib import Path 6 | 7 | import click 8 | import torch 9 | import torch.nn.functional as F 10 | import torchvision.transforms as T 11 | import torchvision.transforms.functional as TF 12 | 13 | from moge.test.baseline import MGEBaselineInterface 14 | 15 | 16 | class Baseline(MGEBaselineInterface): 17 | def __init__(self, repo_path: str, backbone: str, num_tokens: int, device: Union[torch.device, str]): 18 | # Create from repo 19 | repo_path = os.path.abspath(repo_path) 20 | if repo_path not in sys.path: 21 | sys.path.append(repo_path) 22 | if not Path(repo_path).exists(): 23 | raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.') 24 | from depth_anything_v2.dpt import DepthAnythingV2 25 | 26 | device = torch.device(device) 27 | 28 | # Instantiate model 29 | model = DepthAnythingV2(encoder=backbone, features=256, out_channels=[256, 512, 1024, 1024]) 30 | 31 | # Load checkpoint 32 | checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_{backbone}.pth') 33 | if not os.path.exists(checkpoint_path): 34 | raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. 
Please download the checkpoint file and place it in the checkpoints directory.') 35 | checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) 36 | model.load_state_dict(checkpoint) 37 | 38 | model.to(device).eval() 39 | self.model = model 40 | self.num_tokens = num_tokens 41 | self.device = device 42 | 43 | @click.command() 44 | @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.') 45 | @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Encoder architecture.') 46 | @click.option('--num_tokens', type=int, default=None, help='Number of tokens to use for the input image.') 47 | @click.option('--device', type=str, default='cuda', help='Device to use for inference.') 48 | @staticmethod 49 | def load(repo_path: str, backbone, num_tokens: int, device: torch.device = 'cuda'): 50 | return Baseline(repo_path, backbone, num_tokens, device) 51 | 52 | @torch.inference_mode() 53 | def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: 54 | original_height, original_width = image.shape[-2:] 55 | 56 | assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input" 57 | 58 | if image.ndim == 3: 59 | image = image.unsqueeze(0) 60 | omit_batch_dim = True 61 | else: 62 | omit_batch_dim = False 63 | 64 | if self.num_tokens is None: 65 | resize_factor = 518 / min(original_height, original_width) 66 | expected_width = round(original_width * resize_factor / 14) * 14 67 | expected_height = round(original_height * resize_factor / 14) * 14 68 | else: 69 | aspect_ratio = original_width / original_height 70 | tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5) 71 | tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5) 72 | expected_width = tokens_cols * 14 73 | expected_height = tokens_rows * 14 74 | image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True) 75 | 76 | image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 77 | 78 | disparity = self.model(image) 79 | 80 | disparity = F.interpolate(disparity[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0] 81 | 82 | if omit_batch_dim: 83 | disparity = disparity.squeeze(0) 84 | 85 | return { 86 | 'disparity_affine_invariant': disparity 87 | } 88 | 89 | -------------------------------------------------------------------------------- /baselines/da_v2_metric.py: -------------------------------------------------------------------------------- 1 | # Reference https://github.com/DepthAnything/Depth-Anything-V2/metric_depth 2 | import os 3 | import sys 4 | from typing import * 5 | from pathlib import Path 6 | 7 | import click 8 | import torch 9 | import torch.nn.functional as F 10 | import torchvision.transforms as T 11 | import torchvision.transforms.functional as TF 12 | import cv2 13 | 14 | from moge.test.baseline import MGEBaselineInterface 15 | 16 | 17 | class Baseline(MGEBaselineInterface): 18 | 19 | def __init__(self, repo_path: str, backbone: str, domain: str, num_tokens: int, device: str): 20 | device = torch.device(device) 21 | repo_path = os.path.abspath(repo_path) 22 | if not Path(repo_path).exists(): 23 | raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. 
Please clone the repository and provide the path to it using the --repo option.') 24 | sys.path.append(os.path.join(repo_path, 'metric_depth')) 25 | from depth_anything_v2.dpt import DepthAnythingV2 26 | 27 | model_configs = { 28 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 29 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 30 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]} 31 | } 32 | 33 | if domain == 'indoor': 34 | dataset = 'hypersim' 35 | max_depth = 20 36 | elif domain == 'outdoor': 37 | dataset = 'vkitti' 38 | max_depth = 80 39 | else: 40 | raise ValueError(f"Invalid domain: {domain}") 41 | 42 | model = DepthAnythingV2(**model_configs[backbone], max_depth=max_depth) 43 | checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_metric_{dataset}_{backbone}.pth') 44 | if not os.path.exists(checkpoint_path): 45 | raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.') 46 | model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True)) 47 | model.eval().to(device) 48 | 49 | self.model = model 50 | self.num_tokens = num_tokens 51 | self.device = device 52 | 53 | @click.command() 54 | @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.') 55 | @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Backbone architecture.') 56 | @click.option('--domain', type=click.Choice(['indoor', 'outdoor']), help='Domain of the dataset.') 57 | @click.option('--num_tokens', type=int, default=None, help='Number of tokens for the ViT model') 58 | @click.option('--device', type=str, default='cuda', help='Device to use for inference.') 59 | @staticmethod 60 | def load(repo_path: str, backbone: str, domain: str, num_tokens: int, device: str): 61 | return Baseline(repo_path, backbone, domain, num_tokens, device) 62 | 63 | @torch.inference_mode() 64 | def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: 65 | original_height, original_width = image.shape[-2:] 66 | 67 | assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input" 68 | 69 | if image.ndim == 3: 70 | image = image.unsqueeze(0) 71 | omit_batch_dim = True 72 | else: 73 | omit_batch_dim = False 74 | 75 | if self.num_tokens is None: 76 | resize_factor = 518 / min(original_height, original_width) 77 | expected_width = round(original_width * resize_factor / 14) * 14 78 | expected_height = round(original_height * resize_factor / 14) * 14 79 | else: 80 | aspect_ratio = original_width / original_height 81 | tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5) 82 | tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5) 83 | expected_width = tokens_cols * 14 84 | expected_height = tokens_rows * 14 85 | image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True) 86 | 87 | image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 88 | 89 | depth = self.model(image) 90 | 91 | depth = F.interpolate(depth[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0] 92 | 93 | if omit_batch_dim: 94 | depth = depth.squeeze(0) 95 | 96 | return { 97 | 'depth_metric': 
depth 98 | } 99 | 100 | -------------------------------------------------------------------------------- /baselines/metric3d_v2.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/YvanYin/Metric3D 2 | import os 3 | import sys 4 | from typing import * 5 | 6 | import click 7 | import torch 8 | import torch.nn.functional as F 9 | import cv2 10 | 11 | from moge.test.baseline import MGEBaselineInterface 12 | 13 | 14 | class Baseline(MGEBaselineInterface): 15 | def __init__(self, backbone: Literal['vits', 'vitl', 'vitg'], device): 16 | backbone_map = { 17 | 'vits': 'metric3d_vit_small', 18 | 'vitl': 'metric3d_vit_large', 19 | 'vitg': 'metric3d_vit_giant2' 20 | } 21 | 22 | device = torch.device(device) 23 | model = torch.hub.load('yvanyin/metric3d', backbone_map[backbone], pretrain=True) 24 | model.to(device).eval() 25 | 26 | self.model = model 27 | self.device = device 28 | 29 | @click.command() 30 | @click.option('--backbone', type=click.Choice(['vits', 'vitl', 'vitg']), default='vitl', help='Encoder architecture.') 31 | @click.option('--device', type=str, default='cuda', help='Device to use.') 32 | @staticmethod 33 | def load(backbone: str = 'vitl', device: torch.device = 'cuda'): 34 | return Baseline(backbone, device) 35 | 36 | @torch.inference_mode() 37 | def inference_one_image(self, image: torch.Tensor, intrinsics: torch.Tensor = None): 38 | # Reference: https://github.com/YvanYin/Metric3D/blob/main/mono/utils/do_test.py 39 | 40 | # rgb_origin: RGB, 0-255, uint8 41 | rgb_origin = image.cpu().numpy().transpose((1, 2, 0)) * 255 42 | 43 | # keep ratio resize 44 | input_size = (616, 1064) # for vit model 45 | h, w = rgb_origin.shape[:2] 46 | scale = min(input_size[0] / h, input_size[1] / w) 47 | rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) 48 | if intrinsics is not None: 49 | focal = intrinsics[0, 0] * int(w * scale) 50 | 51 | # padding to input_size 52 | padding = [123.675, 116.28, 103.53] 53 | h, w = rgb.shape[:2] 54 | pad_h = input_size[0] - h 55 | pad_w = input_size[1] - w 56 | pad_h_half = pad_h // 2 57 | pad_w_half = pad_w // 2 58 | rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding) 59 | pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] 60 | 61 | # normalize rgb 62 | mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] 63 | std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None] 64 | rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() 65 | rgb = torch.div((rgb - mean), std) 66 | rgb = rgb[None, :, :, :].cuda() 67 | 68 | # inference 69 | pred_depth, confidence, output_dict = self.model.inference({'input': rgb}) 70 | 71 | # un pad 72 | pred_depth = pred_depth.squeeze() 73 | pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]] 74 | pred_depth = pred_depth.clamp_min(0.5) # clamp to 0.5m, since metric3d could yield very small depth values, resulting in crashed the scale shift alignment. 
75 | 76 | # upsample to original size 77 | pred_depth = F.interpolate(pred_depth[None, None, :, :], image.shape[-2:], mode='bilinear').squeeze() 78 | 79 | if intrinsics is not None: 80 | # de-canonical transform 81 | canonical_to_real_scale = focal / 1000.0 # 1000.0 is the focal length of canonical camera 82 | pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric 83 | pred_depth = torch.clamp(pred_depth, 0, 300) 84 | 85 | pred_normal, normal_confidence = output_dict['prediction_normal'].split([3, 1], dim=1) # see https://arxiv.org/abs/2109.09881 for details 86 | 87 | # un pad and resize to some size if needed 88 | pred_normal = pred_normal.squeeze(0) 89 | pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]] 90 | 91 | # you can now do anything with the normal 92 | pred_normal = F.interpolate(pred_normal[None, :, :, :], image.shape[-2:], mode='bilinear').squeeze(0) 93 | pred_normal = F.normalize(pred_normal, p=2, dim=0) 94 | 95 | return pred_depth, pred_normal.permute(1, 2, 0) 96 | 97 | @torch.inference_mode() 98 | def infer(self, image: torch.Tensor, intrinsics: torch.Tensor = None): 99 | # image: (B, H, W, 3) or (H, W, 3) 100 | if image.ndim == 3: 101 | pred_depth, pred_normal = self.inference_one_image(image, intrinsics) 102 | else: 103 | for i in range(image.shape[0]): 104 | pred_depth_i, pred_normal_i = self.inference_one_image(image[i], intrinsics[i] if intrinsics is not None else None) 105 | pred_depth.append(pred_depth_i) 106 | pred_normal.append(pred_normal_i) 107 | pred_depth = torch.stack(pred_depth, dim=0) 108 | pred_normal = torch.stack(pred_normal, dim=0) 109 | 110 | if intrinsics is not None: 111 | return { 112 | "depth_metric": pred_depth, 113 | } 114 | else: 115 | return { 116 | "depth_scale_invariant": pred_depth, 117 | } 118 | -------------------------------------------------------------------------------- /baselines/moge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import * 4 | import importlib 5 | 6 | import click 7 | import torch 8 | import utils3d 9 | 10 | from moge.test.baseline import MGEBaselineInterface 11 | 12 | 13 | class Baseline(MGEBaselineInterface): 14 | 15 | def __init__(self, num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'): 16 | super().__init__() 17 | from moge.model import import_model_class_by_version 18 | MoGeModel = import_model_class_by_version(version) 19 | self.version = version 20 | 21 | self.model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() 22 | 23 | self.device = torch.device(device) 24 | self.num_tokens = num_tokens 25 | self.resolution_level = resolution_level 26 | self.use_fp16 = use_fp16 27 | 28 | @click.command() 29 | @click.option('--num_tokens', type=int, default=None) 30 | @click.option('--resolution_level', type=int, default=9) 31 | @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl') 32 | @click.option('--fp16', 'use_fp16', is_flag=True) 33 | @click.option('--device', type=str, default='cuda:0') 34 | @click.option('--version', type=str, default='v1') 35 | @staticmethod 36 | def load(num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'): 37 | return Baseline(num_tokens, resolution_level, 
pretrained_model_name_or_path, use_fp16, device, version) 38 | 39 | # Implementation for inference 40 | @torch.inference_mode() 41 | def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.FloatTensor] = None): 42 | if intrinsics is not None: 43 | fov_x, _ = utils3d.torch.intrinsics_to_fov(intrinsics) 44 | fov_x = torch.rad2deg(fov_x) 45 | else: 46 | fov_x = None 47 | output = self.model.infer(image, fov_x=fov_x, apply_mask=True, num_tokens=self.num_tokens) 48 | 49 | if self.version == 'v1': 50 | return { 51 | 'points_scale_invariant': output['points'], 52 | 'depth_scale_invariant': output['depth'], 53 | 'intrinsics': output['intrinsics'], 54 | } 55 | else: 56 | return { 57 | 'points_metric': output['points'], 58 | 'depth_metric': output['depth'], 59 | 'intrinsics': output['intrinsics'], 60 | } 61 | 62 | @torch.inference_mode() 63 | def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: torch.FloatTensor = None): 64 | if intrinsics is not None: 65 | fov_x, _ = utils3d.torch.intrinsics_to_fov(intrinsics) 66 | fov_x = torch.rad2deg(fov_x) 67 | else: 68 | fov_x = None 69 | output = self.model.infer(image, fov_x=fov_x, apply_mask=False, num_tokens=self.num_tokens, use_fp16=self.use_fp16) 70 | 71 | if self.version == 'v1': 72 | return { 73 | 'points_scale_invariant': output['points'], 74 | 'depth_scale_invariant': output['depth'], 75 | 'intrinsics': output['intrinsics'], 76 | } 77 | else: 78 | return { 79 | 'points_metric': output['points'], 80 | 'depth_metric': output['depth'], 81 | 'intrinsics': output['intrinsics'], 82 | } 83 | 84 | -------------------------------------------------------------------------------- /configs/eval/all_benchmarks.json: -------------------------------------------------------------------------------- 1 | { 2 | "NYUv2": { 3 | "path": "data/eval/NYUv2", 4 | "width": 640, 5 | "height": 480, 6 | "split": ".index.txt", 7 | "depth_unit": 1.0 8 | }, 9 | "KITTI": { 10 | "path": "data/eval/KITTI", 11 | "width": 750, 12 | "height": 375, 13 | "split": ".index.txt", 14 | "depth_unit": 1 15 | }, 16 | "ETH3D": { 17 | "path": "data/eval/ETH3D", 18 | "width": 2048, 19 | "height": 1365, 20 | "split": ".index.txt", 21 | "include_segmentation": true, 22 | "depth_unit": 1 23 | }, 24 | "iBims-1": { 25 | "path": "data/eval/iBims-1", 26 | "width": 640, 27 | "height": 480, 28 | "split": ".index.txt", 29 | "has_sharp_boundary": true, 30 | "include_segmentation": true, 31 | "depth_unit": 1.0 32 | }, 33 | "GSO": { 34 | "path": "data/eval/GSO", 35 | "width": 512, 36 | "height": 512, 37 | "split": ".index.txt" 38 | }, 39 | "Sintel": { 40 | "path": "data/eval/Sintel", 41 | "width": 872, 42 | "height": 436, 43 | "split": ".index.txt", 44 | "has_sharp_boundary": true, 45 | "include_segmentation": true 46 | }, 47 | "DDAD": { 48 | "path": "data/eval/DDAD", 49 | "width": 1400, 50 | "height": 700, 51 | "include_segmentation": true, 52 | "split": ".index.txt", 53 | "depth_unit": 1.0 54 | }, 55 | "DIODE": { 56 | "path": "data/eval/DIODE", 57 | "width": 1024, 58 | "height": 768, 59 | "split": ".index.txt", 60 | "include_segmentation": true, 61 | "depth_unit": 1.0 62 | }, 63 | "Spring": { 64 | "path": "data/eval/Spring", 65 | "width": 1920, 66 | "height": 1080, 67 | "split": ".index.txt", 68 | "has_sharp_boundary": true 69 | }, 70 | "HAMMER": { 71 | "path": "data/eval/HAMMER", 72 | "width": 1664, 73 | "height": 832, 74 | "split": ".index.txt", 75 | "depth_unit": 1, 76 | "has_sharp_boundary": true 77 | } 78 | } 
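The benchmark configurations above are plain name-to-settings mappings. As a quick illustration of that structure (not part of the evaluation pipeline itself), one could inspect a configuration like this:

```python
import json
from pathlib import Path

# Illustrative only: list the benchmarks configured in configs/eval/all_benchmarks.json
config = json.loads(Path("configs/eval/all_benchmarks.json").read_text())
for name, settings in config.items():
    print(f"{name}: {settings['width']}x{settings['height']} "
          f"(split: {settings['split']}, depth_unit: {settings.get('depth_unit', 'n/a')})")
```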
-------------------------------------------------------------------------------- /configs/eval/benchmarks/ddad.json: -------------------------------------------------------------------------------- 1 | { 2 | "DDAD": { 3 | "path": "data/eval/DDAD", 4 | "width": 1400, 5 | "height": 700, 6 | "include_segmentation": true, 7 | "split": ".index.txt" 8 | } 9 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/diode.json: -------------------------------------------------------------------------------- 1 | { 2 | "DIODE": { 3 | "path": "data/eval/DIODE", 4 | "width": 1024, 5 | "height": 768, 6 | "split": ".index.txt", 7 | "include_segmentation": true 8 | } 9 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/eth3d.json: -------------------------------------------------------------------------------- 1 | { 2 | "ETH3D": { 3 | "path": "data/eval/ETH3D", 4 | "width": 2048, 5 | "height": 1365, 6 | "split": ".index.txt", 7 | "include_segmentation": true, 8 | "depth_unit": 1 9 | } 10 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/gso.json: -------------------------------------------------------------------------------- 1 | { 2 | "GSO": { 3 | "path": "data/eval/GSO", 4 | "width": 512, 5 | "height": 512, 6 | "split": ".index.txt" 7 | } 8 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/hammer.json: -------------------------------------------------------------------------------- 1 | { 2 | "HAMMER": { 3 | "path": "data/eval/HAMMER", 4 | "width": 1664, 5 | "height": 832, 6 | "split": ".index.txt", 7 | "depth_unit": 1, 8 | "has_sharp_boundary": true 9 | } 10 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/ibims-1.json: -------------------------------------------------------------------------------- 1 | { 2 | "iBims-1": { 3 | "path": "data/eval/iBims-1", 4 | "width": 640, 5 | "height": 480, 6 | "split": ".index.txt", 7 | "include_segmentation": true, 8 | "has_sharp_boundary": true 9 | } 10 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/kitti.json: -------------------------------------------------------------------------------- 1 | { 2 | "KITTI": { 3 | "path": "data/eval/KITTI", 4 | "width": 750, 5 | "height": 375, 6 | "split": ".index.txt", 7 | "depth_unit": 1 8 | } 9 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/nyu.json: -------------------------------------------------------------------------------- 1 | { 2 | "NYUv2": { 3 | "path": "data/eval/NYUv2", 4 | "width": 640, 5 | "height": 480, 6 | "split": ".test.txt" 7 | } 8 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/sintel.json: -------------------------------------------------------------------------------- 1 | { 2 | "Sintel": { 3 | "path": "data/eval/Sintel", 4 | "width": 872, 5 | "height": 436, 6 | "split": ".index.txt", 7 | "include_segmentation": true, 8 | "has_sharp_boundary": true 9 | } 10 | } -------------------------------------------------------------------------------- /configs/eval/benchmarks/spring.json: -------------------------------------------------------------------------------- 1 | { 2 | "Spring": { 3 | "path": "data/eval/Spring", 4 | "width": 1920, 5 | 
"height": 1080, 6 | "split": ".test.txt", 7 | "has_sharp_boundary": true 8 | } 9 | } -------------------------------------------------------------------------------- /configs/train/v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "aspect_ratio_range": [0.5, 2.0], 4 | "area_range": [250000, 1000000], 5 | "clamp_max_depth": 1000.0, 6 | "center_augmentation": 0.5, 7 | "fov_range_absolute": [1, 179], 8 | "fov_range_relative": [0.01, 1.0], 9 | "image_augmentation": ["jittering", "jpeg_loss", "blurring"], 10 | "datasets": [ 11 | { 12 | "name": "TartanAir", 13 | "path": "blobmnt/data_v3/TartanAir", 14 | "label_type": "synthetic", 15 | "index": ".index.txt", 16 | "depth": "depth.png", 17 | "weight": 4.8, 18 | "center_augmentation": 0.25, 19 | "fov_range_absolute": [30, 150], 20 | "fov_range_relative": [0.5, 1.0], 21 | "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"] 22 | } 23 | ] 24 | }, 25 | "model_version": "v1", 26 | "model": { 27 | "encoder": "dinov2_vitl14", 28 | "remap_output": "exp", 29 | "intermediate_layers": 4, 30 | "dim_upsample": [256, 128, 64], 31 | "dim_times_res_block_hidden": 2, 32 | "num_res_blocks": 2, 33 | "num_tokens_range": [1200, 2500], 34 | "last_conv_channels": 32, 35 | "last_conv_size": 1 36 | }, 37 | "optimizer": { 38 | "type": "AdamW", 39 | "params": [ 40 | {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4}, 41 | {"params": {"include": ["*backbone.*"]}, "lr": 1e-5} 42 | ] 43 | }, 44 | "lr_scheduler": { 45 | "type": "SequentialLR", 46 | "params": { 47 | "schedulers": [ 48 | {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}}, 49 | {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}} 50 | ], 51 | "milestones": [2000] 52 | } 53 | }, 54 | "low_resolution_training_steps": 50000, 55 | "loss": { 56 | "invalid": {}, 57 | "synthetic": { 58 | "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, 59 | "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, 60 | "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, 61 | "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}}, 62 | "normal": {"function": "normal_loss", "weight": 1.0}, 63 | "mask": {"function": "mask_l2_loss", "weight": 1.0} 64 | }, 65 | "sfm": { 66 | "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, 67 | "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, 68 | "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, 69 | "mask": {"function": "mask_l2_loss", "weight": 1.0} 70 | }, 71 | "lidar": { 72 | "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, 73 | "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, 74 | "mask": {"function": "mask_l2_loss", "weight": 1.0} 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /docs/eval.md: 
--------------------------------------------------------------------------------
1 | # Evaluation
2 | 
3 | We provide a unified evaluation script that runs baselines on multiple benchmarks. It takes a baseline model and evaluation configurations, evaluates on-the-fly, and reports results in a JSON file.
4 | 
5 | ## Benchmarks
6 | 
7 | Download the processed datasets from [Huggingface Datasets](https://huggingface.co/datasets/Ruicheng/monocular-geometry-evaluation) and put them in the `data/eval` directory using `huggingface-cli`:
8 | 
9 | ```bash
10 | mkdir -p data/eval
11 | huggingface-cli download Ruicheng/monocular-geometry-evaluation --repo-type dataset --local-dir data/eval --local-dir-use-symlinks False
12 | ```
13 | 
14 | Then unzip the downloaded files:
15 | 
16 | ```bash
17 | cd data/eval
18 | unzip '*.zip'
19 | # rm *.zip # if you don't want to keep the zip files
20 | ```
21 | 
22 | ## Configuration
23 | 
24 | See [`configs/eval/all_benchmarks.json`](../configs/eval/all_benchmarks.json) for an example of evaluation configurations on all benchmarks. You can modify this file to evaluate on different benchmarks or different baselines.
25 | 
26 | ## Baseline
27 | 
28 | Some examples of baselines are provided in [`baselines/`](../baselines/). Pass the path to the baseline model's Python code to the `--baseline` argument of the evaluation script.
29 | 
30 | ## Run Evaluation
31 | 
32 | Run the script [`moge/scripts/eval_baseline.py`](../moge/scripts/eval_baseline.py).
33 | For example,
34 | 
35 | ```bash
36 | # Evaluate MoGe on the 10 benchmarks
37 | python moge/scripts/eval_baseline.py --baseline baselines/moge.py --config configs/eval/all_benchmarks.json --output eval_output/moge.json --pretrained Ruicheng/moge-vitl --resolution_level 9
38 | 
39 | # Evaluate Depth Anything V2 on the 10 benchmarks. (NOTE: affine disparity)
40 | python moge/scripts/eval_baseline.py --baseline baselines/da_v2.py --config configs/eval/all_benchmarks.json --output eval_output/da_v2.json
41 | ```
42 | 
43 | The `--baseline`, `--config` and `--output` arguments are for the evaluation script. The remaining arguments, e.g. `--pretrained` and `--resolution_level`, are customized for loading the baseline model.
44 | 
45 | Details of the arguments:
46 | 
47 | ```
48 | Usage: eval_baseline.py [OPTIONS]
49 | 
50 |   Evaluation script.
51 | 
52 | Options:
53 |   --baseline PATH  Path to the baseline model python code.
54 |   --config PATH    Path to the evaluation configurations. Defaults to
55 |                    "configs/eval/all_benchmarks.json".
56 |   --output PATH    Path to the output json file.
57 |   --oracle         Use oracle mode for evaluation, i.e., use the GT intrinsics
58 |                    input.
59 |   --dump_pred      Dump prediction results.
60 |   --dump_gt        Dump ground truth.
61 |   --help           Show this message and exit.
62 | ```
63 | 
64 | 
65 | 
66 | ## Wrap a Customized Baseline
67 | 
68 | Wrap any baseline method with [`moge.test.baseline.MGEBaselineInterface`](../moge/test/baseline.py).
69 | See [`baselines/`](../baselines/) for more examples.
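To illustrate the expected shape of such a wrapper, here is a minimal skeleton that follows the pattern of the bundled baselines (a `click`-decorated `load` plus an `infer` returning a dictionary of prediction maps). The constant-depth "model" is a placeholder, and the exact required methods and supported output keys should be checked against [`moge/test/baseline.py`](../moge/test/baseline.py):

```python
from typing import Dict, Optional

import click
import torch

from moge.test.baseline import MGEBaselineInterface


class Baseline(MGEBaselineInterface):
    def __init__(self, device: str = 'cuda'):
        self.device = torch.device(device)

    @click.command()
    @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
    @staticmethod
    def load(device: str = 'cuda'):
        return Baseline(device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # image: (3, H, W) or (B, 3, H, W) float tensor in [0, 1], as in the baselines above.
        # Replace this constant-depth placeholder with your model, and return the keys your
        # method supports, e.g. 'depth_scale_invariant', 'depth_metric',
        # 'disparity_affine_invariant', 'points_scale_invariant' or 'intrinsics'.
        depth = torch.ones(image.shape[-2:], device=image.device)
        return {'depth_scale_invariant': depth}
```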
70 | 
71 | It is a good idea to check the correctness of the baseline implementation by running inference on a small set of images via [`moge/scripts/infer_baseline.py`](../moge/scripts/infer_baseline.py):
72 | 
73 | ```bash
74 | python moge/scripts/infer_baseline.py --baseline baselines/moge.py --input example_images/ --output infer_output/moge --pretrained Ruicheng/moge-vitl --maps --ply
75 | ```
76 | 
77 | 
78 | 
--------------------------------------------------------------------------------
/docs/train.md:
--------------------------------------------------------------------------------
1 | 
2 | # Training
3 | 
4 | This document provides instructions for training and finetuning the MoGe model.
5 | 
6 | ## Additional Requirements
7 | 
8 | The following packages, in addition to those listed in [`pyproject.toml`](../pyproject.toml), are required for training and finetuning the MoGe model:
9 | 
10 | ```
11 | accelerate
12 | sympy
13 | mlflow
14 | ```
15 | 
16 | ## Data preparation
17 | 
18 | ### Dataset format
19 | 
20 | Each dataset should be organized as follows:
21 | 
22 | ```
23 | somedataset
24 | ├── .index.txt          # A list of instance paths
25 | ├── folder1
26 | │   ├── instance1       # Each instance is in a folder
27 | │   │   ├── image.jpg   # RGB image.
28 | │   │   ├── depth.png   # 16-bit depth. See moge/utils/io.py for details
29 | │   │   ├── meta.json   # Stores "intrinsics" as a 3x3 matrix
30 | │   │   └── ...         # Other components such as segmentation mask, normal map etc.
31 | ...
32 | ```
33 | 
34 | * `.index.txt` is placed at the top directory to store a list of instance paths in this dataset. The dataloader will look for instances in this list. You may also use a custom split, e.g. `.train.txt` or `.val.txt`, and specify it in the configuration file.
35 | 
36 | * For depth images, it is recommended to use `read_depth()` and `write_depth()` in [`moge/utils/io.py`](../moge/utils/io.py) to read and write depth images. The depth is stored in logarithmic scale in 16-bit PNG format, offering a good balance of precision, dynamic range and compression ratio compared to 16-bit and 32-bit EXR and linear-depth formats. It also encodes `NaN` and `Inf` values for invalid depth values.
37 | 
38 | * The `meta.json` should be a dictionary containing the key `intrinsics`, which holds the **normalized** camera parameters. You may put more metadata in it.
39 | 
40 | * We also support reading and storing segmentation masks for evaluation data (see the paper's evaluation of local points), which are saved in PNG format with semantic labels stored in the PNG metadata as JSON strings. See `read_segmentation()` and `write_segmentation()` in [`moge/utils/io.py`](../moge/utils/io.py) for details.
41 | 
42 | 
43 | ### Visual inspection
44 | 
45 | We provide a script to visualize the data and check the data quality. It exports the instance as a PLY file for point cloud visualization.
46 | 
47 | ```bash
48 | python moge/scripts/vis_data.py PATH_TO_INSTANCE --ply [-o SOMEWHERE_ELSE_TO_SAVE_VIS]
49 | ```
50 | 
51 | ### DataLoader
52 | 
53 | Our training dataloader is customized to handle data loading, perspective cropping, and augmentation in a multithreaded pipeline. Please refer to [`moge/train/dataloader.py`](../moge/train/dataloader.py) if you have any concerns.
54 | 
55 | 
56 | ## Configuration
57 | 
58 | See [`configs/train/v1.json`](../configs/train/v1.json) for an example configuration file. The configuration file defines the hyperparameters for training the MoGe model.
59 | Here is a commented configuration for reference: 60 | 61 | ```json 62 | { 63 | "data": { 64 | "aspect_ratio_range": [0.5, 2.0], # Range of aspect ratio of sampled images 65 | "area_range": [250000, 1000000], # Range of sampled image area in pixels 66 | "clamp_max_depth": 1000.0, # Maximum far/near 67 | "center_augmentation": 0.5, # Ratio of center crop augmentation 68 | "fov_range_absolute": [1, 179], # Absolute range of FOV in degrees 69 | "fov_range_relative": [0.01, 1.0], # Relative range of FOV to the original FOV 70 | "image_augmentation": ["jittering", "jpeg_loss", "blurring"], # List of image augmentation techniques 71 | "datasets": [ 72 | { 73 | "name": "TartanAir", # Name of the dataset. Name it as you like. 74 | "path": "data/TartanAir", # Path to the dataset 75 | "label_type": "synthetic", # Label type for this dataset. Losses will be applied accordingly. see "loss" config 76 | "weight": 4.8, # Probability of sampling this dataset 77 | "index": ".index.txt", # File name of the index file. Defaults to .index.txt 78 | "depth": "depth.png", # File name of depth images. Defaults to depth.png 79 | "center_augmentation": 0.25, # Below are dataset-specific hyperparameters. Overriding the global ones above. 80 | "fov_range_absolute": [30, 150], 81 | "fov_range_relative": [0.5, 1.0], 82 | "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"] 83 | } 84 | ] 85 | }, 86 | "model_version": "v1", # Model version. If you have multiple model variants, you can use this to switch between them. 87 | "model": { # Model hyperparameters. Will be passed to Model __init__() as kwargs. 88 | "encoder": "dinov2_vitl14", 89 | "remap_output": "exp", 90 | "intermediate_layers": 4, 91 | "dim_upsample": [256, 128, 64], 92 | "dim_times_res_block_hidden": 2, 93 | "num_res_blocks": 2, 94 | "num_tokens_range": [1200, 2500], 95 | "last_conv_channels": 32, 96 | "last_conv_size": 1 97 | }, 98 | "optimizer": { # Reflection-like optimizer configurations. See moge.train.utils.py build_optimizer() for details. 99 | "type": "AdamW", 100 | "params": [ 101 | {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4}, 102 | {"params": {"include": ["*backbone.*"]}, "lr": 1e-5} 103 | ] 104 | }, 105 | "lr_scheduler": { # Reflection-like lr_scheduler configurations. See moge.train.utils.py build_lr_scheduler() for details. 106 | "type": "SequentialLR", 107 | "params": { 108 | "schedulers": [ 109 | {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}}, 110 | {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}} 111 | ], 112 | "milestones": [2000] 113 | } 114 | }, 115 | "low_resolution_training_steps": 50000, # Total number of low-resolution training steps. It makes the early stage training faster. Later stage training on varying size images will be slower. 
116 | "loss": { 117 | "invalid": {}, # Invalid instances, e.g. due to runtime errors when loading data 118 | "synthetic": { # Below are loss hyperparameters 119 | "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, 120 | "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, 121 | "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, 122 | "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}}, 123 | "normal": {"function": "normal_loss", "weight": 1.0}, 124 | "mask": {"function": "mask_l2_loss", "weight": 1.0} 125 | }, 126 | "sfm": { 127 | "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, 128 | "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, 129 | "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}}, 130 | "mask": {"function": "mask_l2_loss", "weight": 1.0} 131 | }, 132 | "lidar": { 133 | "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}}, 134 | "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}}, 135 | "mask": {"function": "mask_l2_loss", "weight": 1.0} 136 | } 137 | } 138 | } 139 | ``` 140 | 141 | ## Run Training 142 | 143 | Launch the training script [`moge/scripts/train.py`](../moge/scripts/train.py). Note that we use [`accelerate`](https://github.com/huggingface/accelerate) for distributed training. 144 | 145 | ```bash 146 | accelerate launch \ 147 | --num_processes 8 \ 148 | moge/scripts/train.py \ 149 | --config configs/train/v1.json \ 150 | --workspace workspace/debug \ 151 | --gradient_accumulation_steps 2 \ 152 | --batch_size_forward 2 \ 153 | --checkpoint latest \ 154 | --enable_gradient_checkpointing True \ 155 | --vis_every 1000 \ 156 | --enable_mlflow True 157 | ``` 158 | 159 | 160 | ## Finetuning 161 | 162 | To finetune the pre-trained MoGe model, download the model checkpoint and save it to a local path, e.g. `pretrained/moge-vitl.pt`. 163 | 164 | > NOTE: when finetuning the pre-trained MoGe model, a much lower learning rate is required. 165 | The suggested learning rates for finetuning are no greater than 1e-5 for the head and 1e-6 for the backbone. 166 | A total batch size of at least 32 is also recommended. 167 | The settings in the default configuration are not optimal for specific datasets and may require further tuning.
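As a reference, the learning rates in the `optimizer` section of the configuration could be lowered along these lines (a sketch following the suggestion above, not a tuned recipe; the exact values may still need adjustment for your dataset):

```json
"optimizer": {
    "type": "AdamW",
    "params": [
        {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-5},
        {"params": {"include": ["*backbone.*"]}, "lr": 1e-6}
    ]
}
```

Then launch finetuning from the pre-trained checkpoint: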
168 | 169 | ```bash 170 | accelerate launch \ 171 | --num_processes 8 \ 172 | moge/scripts/train.py \ 173 | --config configs/train/v1.json \ 174 | --workspace workspace/debug \ 175 | --gradient_accumulation_steps 2 \ 176 | --batch_size_forward 2 \ 177 | --checkpoint pretrained/moge-vitl.pt \ 178 | --enable_gradient_checkpointing True \ 179 | --vis_every 1000 \ 180 | --enable_mlflow True 181 | ``` 182 | -------------------------------------------------------------------------------- /example_images/BooksCorridor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/example_images/BooksCorridor.png -------------------------------------------------------------------------------- /example_images/Braunschweig_Panoram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/example_images/Braunschweig_Panoram.jpg -------------------------------------------------------------------------------- /example_images/BunnyCake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/example_images/BunnyCake.jpg -------------------------------------------------------------------------------- /example_images/MaitreyaBuddha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/example_images/MaitreyaBuddha.png -------------------------------------------------------------------------------- /moge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/__init__.py -------------------------------------------------------------------------------- /moge/model/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import * 3 | 4 | if TYPE_CHECKING: 5 | from .v1 import MoGeModel as MoGeModelV1 6 | 7 | 8 | def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1']]: 9 | assert version in ['v1'], f'Unsupported model version: {version}' 10 | 11 | try: 12 | module = importlib.import_module(f'.{version}', __package__) 13 | except ModuleNotFoundError: 14 | raise ValueError(f'Model version "{version}" not found.') 15 | 16 | cls = getattr(module, 'MoGeModel') 17 | return cls 18 | -------------------------------------------------------------------------------- /moge/model/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | __version__ = "0.0.1" 7 | -------------------------------------------------------------------------------- /moge/model/dinov2/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | -------------------------------------------------------------------------------- /moge/model/dinov2/hub/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | 11 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name 12 | 13 | 14 | class Weights(Enum): 15 | LVD142M = "LVD142M" 16 | 17 | 18 | def _make_dinov2_model( 19 | *, 20 | arch_name: str = "vit_large", 21 | img_size: int = 518, 22 | patch_size: int = 14, 23 | init_values: float = 1.0, 24 | ffn_layer: str = "mlp", 25 | block_chunks: int = 0, 26 | num_register_tokens: int = 0, 27 | interpolate_antialias: bool = False, 28 | interpolate_offset: float = 0.1, 29 | pretrained: bool = True, 30 | weights: Union[Weights, str] = Weights.LVD142M, 31 | **kwargs, 32 | ): 33 | from ..models import vision_transformer as vits 34 | 35 | if isinstance(weights, str): 36 | try: 37 | weights = Weights[weights] 38 | except KeyError: 39 | raise AssertionError(f"Unsupported weights: {weights}") 40 | 41 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 42 | vit_kwargs = dict( 43 | img_size=img_size, 44 | patch_size=patch_size, 45 | init_values=init_values, 46 | ffn_layer=ffn_layer, 47 | block_chunks=block_chunks, 48 | num_register_tokens=num_register_tokens, 49 | interpolate_antialias=interpolate_antialias, 50 | interpolate_offset=interpolate_offset, 51 | ) 52 | vit_kwargs.update(**kwargs) 53 | model = vits.__dict__[arch_name](**vit_kwargs) 54 | 55 | if pretrained: 56 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 57 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" 58 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 59 | model.load_state_dict(state_dict, strict=True) 60 | 61 | return model 62 | 63 | 64 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 65 | """ 66 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. 67 | """ 68 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 69 | 70 | 71 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 72 | """ 73 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 74 | """ 75 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 76 | 77 | 78 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 79 | """ 80 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 81 | """ 82 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 83 | 84 | 85 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 86 | """ 87 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. 
88 | """ 89 | return _make_dinov2_model( 90 | arch_name="vit_giant2", 91 | ffn_layer="swiglufused", 92 | weights=weights, 93 | pretrained=pretrained, 94 | **kwargs, 95 | ) 96 | 97 | 98 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 99 | """ 100 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 101 | """ 102 | return _make_dinov2_model( 103 | arch_name="vit_small", 104 | pretrained=pretrained, 105 | weights=weights, 106 | num_register_tokens=4, 107 | interpolate_antialias=True, 108 | interpolate_offset=0.0, 109 | **kwargs, 110 | ) 111 | 112 | 113 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 114 | """ 115 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 116 | """ 117 | return _make_dinov2_model( 118 | arch_name="vit_base", 119 | pretrained=pretrained, 120 | weights=weights, 121 | num_register_tokens=4, 122 | interpolate_antialias=True, 123 | interpolate_offset=0.0, 124 | **kwargs, 125 | ) 126 | 127 | 128 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 129 | """ 130 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. 131 | """ 132 | return _make_dinov2_model( 133 | arch_name="vit_large", 134 | pretrained=pretrained, 135 | weights=weights, 136 | num_register_tokens=4, 137 | interpolate_antialias=True, 138 | interpolate_offset=0.0, 139 | **kwargs, 140 | ) 141 | 142 | 143 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 144 | """ 145 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 146 | """ 147 | return _make_dinov2_model( 148 | arch_name="vit_giant2", 149 | ffn_layer="swiglufused", 150 | weights=weights, 151 | pretrained=pretrained, 152 | num_register_tokens=4, 153 | interpolate_antialias=True, 154 | interpolate_offset=0.0, 155 | **kwargs, 156 | ) 157 | -------------------------------------------------------------------------------- /moge/model/dinov2/hub/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | # warnings.warn("xFormers is available (Attention)") 28 | else: 29 | # warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | # warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | import os 12 | from typing import Callable, List, Any, Tuple, Dict 13 | import warnings 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | 18 | from .attention import Attention, MemEffAttention 19 | from .drop_path import DropPath 20 | from .layer_scale import LayerScale 21 | from .mlp import Mlp 22 | 23 | 24 | logger = logging.getLogger("dinov2") 25 | 26 | 27 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 28 | try: 29 | if XFORMERS_ENABLED: 30 | from xformers.ops import fmha, scaled_index_add, index_select_cat 31 | 32 | XFORMERS_AVAILABLE = True 33 | # warnings.warn("xFormers is available (Block)") 34 | else: 35 | # warnings.warn("xFormers is disabled (Block)") 36 | raise ImportError 37 | except ImportError: 38 | XFORMERS_AVAILABLE = False 39 | # warnings.warn("xFormers is not available (Block)") 40 | 41 | 42 | class Block(nn.Module): 43 | def __init__( 44 | self, 45 | dim: int, 46 | num_heads: int, 47 | mlp_ratio: float = 4.0, 48 | qkv_bias: bool = False, 49 | proj_bias: bool = True, 50 | ffn_bias: bool = True, 51 | drop: float = 0.0, 52 | attn_drop: float = 0.0, 53 | init_values=None, 54 | drop_path: float = 0.0, 55 | act_layer: Callable[..., nn.Module] = nn.GELU, 56 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 57 | attn_class: Callable[..., nn.Module] = Attention, 58 | ffn_layer: Callable[..., nn.Module] = Mlp, 59 | ) -> None: 60 | super().__init__() 61 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 62 | self.norm1 = norm_layer(dim) 63 | self.attn = attn_class( 64 | dim, 65 | num_heads=num_heads, 66 | qkv_bias=qkv_bias, 67 | proj_bias=proj_bias, 68 | attn_drop=attn_drop, 69 | proj_drop=drop, 70 | ) 71 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 72 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 73 | 74 | self.norm2 = norm_layer(dim) 75 | mlp_hidden_dim = int(dim * mlp_ratio) 76 | self.mlp = ffn_layer( 77 | in_features=dim, 78 | hidden_features=mlp_hidden_dim, 79 | act_layer=act_layer, 80 | drop=drop, 81 | bias=ffn_bias, 82 | ) 83 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 84 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 85 | 86 | self.sample_drop_ratio = drop_path 87 | 88 | def forward(self, x: Tensor) -> Tensor: 89 | def attn_residual_func(x: Tensor) -> Tensor: 90 | return self.ls1(self.attn(self.norm1(x))) 91 | 92 | def ffn_residual_func(x: Tensor) -> Tensor: 93 | return self.ls2(self.mlp(self.norm2(x))) 94 | 95 | if self.training and self.sample_drop_ratio > 0.1: 96 | # the overhead is compensated only for a drop path rate larger than 0.1 97 | x = drop_add_residual_stochastic_depth( 98 | x, 99 | residual_func=attn_residual_func, 100 | sample_drop_ratio=self.sample_drop_ratio, 101 | ) 102 | x = drop_add_residual_stochastic_depth( 103 | x, 104 | residual_func=ffn_residual_func, 105 | sample_drop_ratio=self.sample_drop_ratio, 106 | ) 107 | elif self.training and self.sample_drop_ratio > 0.0: 108 | x = x + self.drop_path1(attn_residual_func(x)) 109 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 110 | else: 111 | x = x + attn_residual_func(x) 112 | x = x + ffn_residual_func(x) 113 | return x 114 | 115 | 116 | def 
drop_add_residual_stochastic_depth( 117 | x: Tensor, 118 | residual_func: Callable[[Tensor], Tensor], 119 | sample_drop_ratio: float = 0.0, 120 | ) -> Tensor: 121 | # 1) extract subset using permutation 122 | b, n, d = x.shape 123 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 124 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 125 | x_subset = x[brange] 126 | 127 | # 2) apply residual_func to get residual 128 | residual = residual_func(x_subset) 129 | 130 | x_flat = x.flatten(1) 131 | residual = residual.flatten(1) 132 | 133 | residual_scale_factor = b / sample_subset_size 134 | 135 | # 3) add the residual 136 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 137 | return x_plus_residual.view_as(x) 138 | 139 | 140 | def get_branges_scales(x, sample_drop_ratio=0.0): 141 | b, n, d = x.shape 142 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 143 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 144 | residual_scale_factor = b / sample_subset_size 145 | return brange, residual_scale_factor 146 | 147 | 148 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 149 | if scaling_vector is None: 150 | x_flat = x.flatten(1) 151 | residual = residual.flatten(1) 152 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 153 | else: 154 | x_plus_residual = scaled_index_add( 155 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 156 | ) 157 | return x_plus_residual 158 | 159 | 160 | attn_bias_cache: Dict[Tuple, Any] = {} 161 | 162 | 163 | def get_attn_bias_and_cat(x_list, branges=None): 164 | """ 165 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 166 | """ 167 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 168 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 169 | if all_shapes not in attn_bias_cache.keys(): 170 | seqlens = [] 171 | for b, x in zip(batch_sizes, x_list): 172 | for _ in range(b): 173 | seqlens.append(x.shape[1]) 174 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 175 | attn_bias._batch_sizes = batch_sizes 176 | attn_bias_cache[all_shapes] = attn_bias 177 | 178 | if branges is not None: 179 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 180 | else: 181 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 182 | cat_tensors = torch.cat(tensors_bs1, dim=1) 183 | 184 | return attn_bias_cache[all_shapes], cat_tensors 185 | 186 | 187 | def drop_add_residual_stochastic_depth_list( 188 | x_list: List[Tensor], 189 | residual_func: Callable[[Tensor, Any], Tensor], 190 | sample_drop_ratio: float = 0.0, 191 | scaling_vector=None, 192 | ) -> Tensor: 193 | # 1) generate random set of indices for dropping samples in the batch 194 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 195 | branges = [s[0] for s in branges_scales] 196 | residual_scale_factors = [s[1] for s in branges_scales] 197 | 198 | # 2) get attention bias and index+concat the tensors 199 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 200 | 201 | # 3) apply residual_func to get residual, and split the result 202 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 203 
| 204 | outputs = [] 205 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 206 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 207 | return outputs 208 | 209 | 210 | class NestedTensorBlock(Block): 211 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 212 | """ 213 | x_list contains a list of tensors to nest together and run 214 | """ 215 | assert isinstance(self.attn, MemEffAttention) 216 | 217 | if self.training and self.sample_drop_ratio > 0.0: 218 | 219 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 220 | return self.attn(self.norm1(x), attn_bias=attn_bias) 221 | 222 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 223 | return self.mlp(self.norm2(x)) 224 | 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=attn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | x_list = drop_add_residual_stochastic_depth_list( 232 | x_list, 233 | residual_func=ffn_residual_func, 234 | sample_drop_ratio=self.sample_drop_ratio, 235 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 236 | ) 237 | return x_list 238 | else: 239 | 240 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 241 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 242 | 243 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 244 | return self.ls2(self.mlp(self.norm2(x))) 245 | 246 | attn_bias, x = get_attn_bias_and_cat(x_list) 247 | x = x + attn_residual_func(x, attn_bias=attn_bias) 248 | x = x + ffn_residual_func(x) 249 | return attn_bias.split(x) 250 | 251 | def forward(self, x_or_x_list): 252 | if isinstance(x_or_x_list, Tensor): 253 | return super().forward(x_or_x_list) 254 | elif isinstance(x_or_x_list, list): 255 | if not XFORMERS_AVAILABLE: 256 | raise AssertionError("xFormers is required for using nested tensors") 257 | return self.forward_nested(x_or_x_list) 258 | else: 259 | raise AssertionError 260 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /moge/model/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | # warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | # warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | # warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /moge/model/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from . 
import vision_transformer as vits 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def build_model(args, only_teacher=False, img_size=224): 15 | args.arch = args.arch.removesuffix("_memeff") 16 | if "vit" in args.arch: 17 | vit_kwargs = dict( 18 | img_size=img_size, 19 | patch_size=args.patch_size, 20 | init_values=args.layerscale, 21 | ffn_layer=args.ffn_layer, 22 | block_chunks=args.block_chunks, 23 | qkv_bias=args.qkv_bias, 24 | proj_bias=args.proj_bias, 25 | ffn_bias=args.ffn_bias, 26 | num_register_tokens=args.num_register_tokens, 27 | interpolate_offset=args.interpolate_offset, 28 | interpolate_antialias=args.interpolate_antialias, 29 | ) 30 | teacher = vits.__dict__[args.arch](**vit_kwargs) 31 | if only_teacher: 32 | return teacher, teacher.embed_dim 33 | student = vits.__dict__[args.arch]( 34 | **vit_kwargs, 35 | drop_path_rate=args.drop_path_rate, 36 | drop_path_uniform=args.drop_path_uniform, 37 | ) 38 | embed_dim = student.embed_dim 39 | return student, teacher, embed_dim 40 | 41 | 42 | def build_model_from_cfg(cfg, only_teacher=False): 43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 44 | -------------------------------------------------------------------------------- /moge/model/dinov2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /moge/model/dinov2/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | from enum import Enum 7 | import os 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional 10 | 11 | 12 | class ClusterType(Enum): 13 | AWS = "aws" 14 | FAIR = "fair" 15 | RSC = "rsc" 16 | 17 | 18 | def _guess_cluster_type() -> ClusterType: 19 | uname = os.uname() 20 | if uname.sysname == "Linux": 21 | if uname.release.endswith("-aws"): 22 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" 23 | return ClusterType.AWS 24 | elif uname.nodename.startswith("rsc"): 25 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" 26 | return ClusterType.RSC 27 | 28 | return ClusterType.FAIR 29 | 30 | 31 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: 32 | if cluster_type is None: 33 | return _guess_cluster_type() 34 | 35 | return cluster_type 36 | 37 | 38 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 39 | cluster_type = get_cluster_type(cluster_type) 40 | if cluster_type is None: 41 | return None 42 | 43 | CHECKPOINT_DIRNAMES = { 44 | ClusterType.AWS: "checkpoints", 45 | ClusterType.FAIR: "checkpoint", 46 | ClusterType.RSC: "checkpoint/dino", 47 | } 48 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] 49 | 50 | 51 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 52 | checkpoint_path = get_checkpoint_path(cluster_type) 53 | if checkpoint_path is None: 54 | return None 55 | 56 | username = os.environ.get("USER") 57 | assert username is not None 58 | return checkpoint_path / username 59 | 60 | 61 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: 62 | cluster_type = get_cluster_type(cluster_type) 63 | if cluster_type is None: 64 | return None 65 | 66 | SLURM_PARTITIONS = { 67 | ClusterType.AWS: "learnlab", 68 | ClusterType.FAIR: "learnlab", 69 | ClusterType.RSC: "learn", 70 | } 71 | return SLURM_PARTITIONS[cluster_type] 72 | 73 | 74 | def get_slurm_executor_parameters( 75 | nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs 76 | ) -> Dict[str, Any]: 77 | # create default parameters 78 | params = { 79 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html 80 | "gpus_per_node": num_gpus_per_node, 81 | "tasks_per_node": num_gpus_per_node, # one task per GPU 82 | "cpus_per_task": 10, 83 | "nodes": nodes, 84 | "slurm_partition": get_slurm_partition(cluster_type), 85 | } 86 | # apply cluster-specific adjustments 87 | cluster_type = get_cluster_type(cluster_type) 88 | if cluster_type == ClusterType.AWS: 89 | params["cpus_per_task"] = 12 90 | del params["mem_gb"] 91 | elif cluster_type == ClusterType.RSC: 92 | params["cpus_per_task"] = 12 93 | # set additional parameters / apply overrides 94 | params.update(kwargs) 95 | return params 96 | -------------------------------------------------------------------------------- /moge/model/dinov2/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | import logging 8 | import os 9 | 10 | from omegaconf import OmegaConf 11 | 12 | import dinov2.distributed as distributed 13 | from dinov2.logging import setup_logging 14 | from dinov2.utils import utils 15 | from dinov2.configs import dinov2_default_config 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | def apply_scaling_rules_to_cfg(cfg): # to fix 22 | if cfg.optim.scaling_rule == "sqrt_wrt_1024": 23 | base_lr = cfg.optim.base_lr 24 | cfg.optim.lr = base_lr 25 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) 26 | logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") 27 | else: 28 | raise NotImplementedError 29 | return cfg 30 | 31 | 32 | def write_config(cfg, output_dir, name="config.yaml"): 33 | logger.info(OmegaConf.to_yaml(cfg)) 34 | saved_cfg_path = os.path.join(output_dir, name) 35 | with open(saved_cfg_path, "w") as f: 36 | OmegaConf.save(config=cfg, f=f) 37 | return saved_cfg_path 38 | 39 | 40 | def get_cfg_from_args(args): 41 | args.output_dir = os.path.abspath(args.output_dir) 42 | args.opts += [f"train.output_dir={args.output_dir}"] 43 | default_cfg = OmegaConf.create(dinov2_default_config) 44 | cfg = OmegaConf.load(args.config_file) 45 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) 46 | return cfg 47 | 48 | 49 | def default_setup(args): 50 | distributed.enable(overwrite=True) 51 | seed = getattr(args, "seed", 0) 52 | rank = distributed.get_global_rank() 53 | 54 | global logger 55 | setup_logging(output=args.output_dir, level=logging.INFO) 56 | logger = logging.getLogger("dinov2") 57 | 58 | utils.fix_random_seeds(seed + rank) 59 | logger.info("git:\n {}\n".format(utils.get_sha())) 60 | logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 61 | 62 | 63 | def setup(args): 64 | """ 65 | Create configs and perform basic setups. 66 | """ 67 | cfg = get_cfg_from_args(args) 68 | os.makedirs(args.output_dir, exist_ok=True) 69 | default_setup(args) 70 | apply_scaling_rules_to_cfg(cfg) 71 | write_config(cfg, args.output_dir) 72 | return cfg 73 | -------------------------------------------------------------------------------- /moge/model/dinov2/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from typing import Dict, Union 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | TypeSpec = Union[str, np.dtype, torch.dtype] 14 | 15 | 16 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { 17 | np.dtype("bool"): torch.bool, 18 | np.dtype("uint8"): torch.uint8, 19 | np.dtype("int8"): torch.int8, 20 | np.dtype("int16"): torch.int16, 21 | np.dtype("int32"): torch.int32, 22 | np.dtype("int64"): torch.int64, 23 | np.dtype("float16"): torch.float16, 24 | np.dtype("float32"): torch.float32, 25 | np.dtype("float64"): torch.float64, 26 | np.dtype("complex64"): torch.complex64, 27 | np.dtype("complex128"): torch.complex128, 28 | } 29 | 30 | 31 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: 32 | if isinstance(dtype, torch.dtype): 33 | return dtype 34 | if isinstance(dtype, str): 35 | dtype = np.dtype(dtype) 36 | assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" 37 | return _NUMPY_TO_TORCH_DTYPE[dtype] 38 | -------------------------------------------------------------------------------- /moge/model/dinov2/utils/param_groups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from collections import defaultdict 7 | import logging 8 | 9 | 10 | logger = logging.getLogger("dinov2") 11 | 12 | 13 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): 14 | """ 15 | Calculate lr decay rate for different ViT blocks. 16 | Args: 17 | name (string): parameter name. 18 | lr_decay_rate (float): base lr decay rate. 19 | num_layers (int): number of ViT blocks. 20 | Returns: 21 | lr decay rate for the given parameter. 22 | """ 23 | layer_id = num_layers + 1 24 | if name.startswith("backbone") or force_is_backbone: 25 | if ( 26 | ".pos_embed" in name 27 | or ".patch_embed" in name 28 | or ".mask_token" in name 29 | or ".cls_token" in name 30 | or ".register_tokens" in name 31 | ): 32 | layer_id = 0 33 | elif force_is_backbone and ( 34 | "pos_embed" in name 35 | or "patch_embed" in name 36 | or "mask_token" in name 37 | or "cls_token" in name 38 | or "register_tokens" in name 39 | ): 40 | layer_id = 0 41 | elif ".blocks." in name and ".residual." not in name: 42 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 43 | elif chunked_blocks and "blocks." in name and "residual." not in name: 44 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 45 | elif "blocks." in name and "residual." 
not in name: 46 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 47 | 48 | return lr_decay_rate ** (num_layers + 1 - layer_id) 49 | 50 | 51 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): 52 | chunked_blocks = False 53 | if hasattr(model, "n_blocks"): 54 | logger.info("chunked fsdp") 55 | n_blocks = model.n_blocks 56 | chunked_blocks = model.chunked_blocks 57 | elif hasattr(model, "blocks"): 58 | logger.info("first code branch") 59 | n_blocks = len(model.blocks) 60 | elif hasattr(model, "backbone"): 61 | logger.info("second code branch") 62 | n_blocks = len(model.backbone.blocks) 63 | else: 64 | logger.info("else code branch") 65 | n_blocks = 0 66 | all_param_groups = [] 67 | 68 | for name, param in model.named_parameters(): 69 | name = name.replace("_fsdp_wrapped_module.", "") 70 | if not param.requires_grad: 71 | continue 72 | decay_rate = get_vit_lr_decay_rate( 73 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks 74 | ) 75 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} 76 | 77 | if "last_layer" in name: 78 | d.update({"is_last_layer": True}) 79 | 80 | if name.endswith(".bias") or "norm" in name or "gamma" in name: 81 | d.update({"wd_multiplier": 0.0}) 82 | 83 | if "patch_embed" in name: 84 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) 85 | 86 | all_param_groups.append(d) 87 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") 88 | 89 | return all_param_groups 90 | 91 | 92 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): 93 | fused_params_groups = defaultdict(lambda: {"params": []}) 94 | for d in all_params_groups: 95 | identifier = "" 96 | for k in keys: 97 | identifier += k + str(d[k]) + "_" 98 | 99 | for k in keys: 100 | fused_params_groups[identifier][k] = d[k] 101 | fused_params_groups[identifier]["params"].append(d["params"]) 102 | 103 | return fused_params_groups.values() 104 | -------------------------------------------------------------------------------- /moge/model/dinov2/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import logging 7 | import os 8 | import random 9 | import subprocess 10 | from urllib.parse import urlparse 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 21 | if urlparse(pretrained_weights).scheme: # If it looks like an URL 22 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") 23 | else: 24 | state_dict = torch.load(pretrained_weights, map_location="cpu") 25 | if checkpoint_key is not None and checkpoint_key in state_dict: 26 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 27 | state_dict = state_dict[checkpoint_key] 28 | # remove `module.` prefix 29 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 30 | # remove `backbone.` prefix induced by multicrop wrapper 31 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 32 | msg = model.load_state_dict(state_dict, strict=False) 33 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 34 | 35 | 36 | def fix_random_seeds(seed=31): 37 | """ 38 | Fix random seeds. 39 | """ 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed_all(seed) 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | 45 | 46 | def get_sha(): 47 | cwd = os.path.dirname(os.path.abspath(__file__)) 48 | 49 | def _run(command): 50 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 51 | 52 | sha = "N/A" 53 | diff = "clean" 54 | branch = "N/A" 55 | try: 56 | sha = _run(["git", "rev-parse", "HEAD"]) 57 | subprocess.check_output(["git", "diff"], cwd=cwd) 58 | diff = _run(["git", "diff-index", "HEAD"]) 59 | diff = "has uncommitted changes" if diff else "clean" 60 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 61 | except Exception: 62 | pass 63 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 64 | return message 65 | 66 | 67 | class CosineScheduler(object): 68 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): 69 | super().__init__() 70 | self.final_value = final_value 71 | self.total_iters = total_iters 72 | 73 | freeze_schedule = np.zeros((freeze_iters)) 74 | 75 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 76 | 77 | iters = np.arange(total_iters - warmup_iters - freeze_iters) 78 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 79 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) 80 | 81 | assert len(self.schedule) == self.total_iters 82 | 83 | def __getitem__(self, it): 84 | if it >= self.total_iters: 85 | return self.final_value 86 | else: 87 | return self.schedule[it] 88 | 89 | 90 | def has_batchnorms(model): 91 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 92 | for name, module in model.named_modules(): 93 | if isinstance(module, bn_types): 94 | return True 95 | return False 96 | -------------------------------------------------------------------------------- /moge/model/utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def wrap_module_with_gradient_checkpointing(module: nn.Module): 8 | from torch.utils.checkpoint import checkpoint 9 
| class _CheckpointingWrapper(module.__class__): 10 | _restore_cls = module.__class__ 11 | def forward(self, *args, **kwargs): 12 | return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) 13 | 14 | module.__class__ = _CheckpointingWrapper 15 | return module 16 | 17 | 18 | def unwrap_module_with_gradient_checkpointing(module: nn.Module): 19 | module.__class__ = module.__class__._restore_cls 20 | 21 | 22 | def wrap_dinov2_attention_with_sdpa(module: nn.Module): 23 | assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" 24 | class _AttentionWrapper(module.__class__): 25 | def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: 26 | B, N, C = x.shape 27 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) 28 | 29 | q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) 30 | 31 | x = F.scaled_dot_product_attention(q, k, v, attn_bias) 32 | x = x.permute(0, 2, 1, 3).reshape(B, N, C) 33 | 34 | x = self.proj(x) 35 | x = self.proj_drop(x) 36 | return x 37 | module.__class__ = _AttentionWrapper 38 | return module 39 | 40 | 41 | def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]: 42 | group_to_use = torch.distributed.group.WORLD 43 | world_size = group_to_use.size() 44 | grad = bucket.buffer() 45 | grad.div_(world_size) 46 | torch.distributed.all_reduce(grad, group=group_to_use) 47 | fut = torch.futures.Future() 48 | fut.set_result(grad) 49 | return fut 50 | -------------------------------------------------------------------------------- /moge/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/scripts/__init__.py -------------------------------------------------------------------------------- /moge/scripts/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 5 | sys.path.insert(0, _package_root) 6 | import time 7 | import uuid 8 | import tempfile 9 | from typing import * 10 | import atexit 11 | from concurrent.futures import ThreadPoolExecutor 12 | 13 | import click 14 | 15 | 16 | @click.command(help='Web demo') 17 | @click.option('--share', is_flag=True, help='Whether to run the app in shared mode.') 18 | @click.option('--max_size', default=800, type=int, help='The maximum size of the input image.') 19 | @click.option('--pretrained', 'pretrained_model_name_or_path', default='Ruicheng/moge-vitl', help='The name or path of the pre-trained model.') 20 | def main(share: bool, max_size: int, pretrained_model_name_or_path: str): 21 | # Lazy import 22 | import cv2 23 | import torch 24 | import numpy as np 25 | import trimesh 26 | import trimesh.visual 27 | from PIL import Image 28 | import gradio as gr 29 | try: 30 | import spaces # This is for deployment at huggingface.co/spaces 31 | HUGGINFACE_SPACES_INSTALLED = True 32 | except ImportError: 33 | HUGGINFACE_SPACES_INSTALLED = False 34 | 35 | import utils3d 36 | from moge.utils.vis import colorize_depth 37 | from moge.model.v1 import MoGeModel 38 | 39 | 40 | model = MoGeModel.from_pretrained(pretrained_model_name_or_path).cuda().eval() 41 | thread_pool_executor = ThreadPoolExecutor(max_workers=1) 42 | 43 | def delete_later(path: Union[str, os.PathLike], delay: int = 300): 44 | def 
_delete(): 45 | try: 46 | os.remove(path) 47 | except: 48 | pass 49 | def _wait_and_delete(): 50 | time.sleep(delay) 51 | _delete(path) 52 | thread_pool_executor.submit(_wait_and_delete) 53 | atexit.register(_delete) 54 | 55 | # Inference on GPU. 56 | @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else lambda x: x) 57 | def run_with_gpu(image: np.ndarray) -> Dict[str, np.ndarray]: 58 | image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cuda')).permute(2, 0, 1) / 255 59 | output = model.infer(image_tensor, apply_mask=True, resolution_level=9) 60 | output = {k: v.cpu().numpy() for k, v in output.items()} 61 | return output 62 | 63 | # Full inference pipeline 64 | def run(image: np.ndarray, remove_edge: bool = True): 65 | run_id = str(uuid.uuid4()) 66 | 67 | larger_size = max(image.shape[:2]) 68 | if larger_size > max_size: 69 | scale = max_size / larger_size 70 | image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA) 71 | 72 | height, width = image.shape[:2] 73 | 74 | output = run_with_gpu(image) 75 | points, depth, mask = output['points'], output['depth'], output['mask'] 76 | normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask) 77 | fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(output['intrinsics']) 78 | fov_x, fov_y = np.rad2deg([fov_x, fov_y]) 79 | 80 | faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( 81 | points, 82 | image.astype(np.float32) / 255, 83 | utils3d.numpy.image_uv(width=width, height=height), 84 | mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=0.03, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), 85 | tri=True 86 | ) 87 | vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] 88 | 89 | tempdir = Path(tempfile.gettempdir(), 'moge') 90 | tempdir.mkdir(exist_ok=True) 91 | 92 | output_glb_path = Path(tempdir, f'{run_id}.glb') 93 | output_glb_path.parent.mkdir(exist_ok=True) 94 | trimesh.Trimesh( 95 | vertices=vertices * [-1, 1, -1], # No idea why Gradio 3D Viewer' default camera is flipped 96 | faces=faces, 97 | visual = trimesh.visual.texture.TextureVisuals( 98 | uv=vertex_uvs, 99 | material=trimesh.visual.material.PBRMaterial( 100 | baseColorTexture=Image.fromarray(image), 101 | metallicFactor=0.5, 102 | roughnessFactor=1.0 103 | ) 104 | ), 105 | process=False 106 | ).export(output_glb_path) 107 | 108 | output_ply_path = Path(tempdir, f'{run_id}.ply') 109 | output_ply_path.parent.mkdir(exist_ok=True) 110 | trimesh.Trimesh( 111 | vertices=vertices, 112 | faces=faces, 113 | vertex_colors=vertex_colors, 114 | process=False 115 | ).export(output_ply_path) 116 | 117 | colorized_depth = colorize_depth(depth) 118 | 119 | delete_later(output_glb_path, delay=300) 120 | delete_later(output_ply_path, delay=300) 121 | 122 | return ( 123 | colorized_depth, 124 | output_glb_path, 125 | output_ply_path.as_posix(), 126 | f'Horizontal FOV: {fov_x:.2f}, Vertical FOV: {fov_y:.2f}' 127 | ) 128 | 129 | gr.Interface( 130 | fn=run, 131 | inputs=[ 132 | gr.Image(type="numpy", image_mode="RGB"), 133 | gr.Checkbox(True, label="Remove edges"), 134 | ], 135 | outputs=[ 136 | gr.Image(type="numpy", label="Depth map (colorized)", format='png'), 137 | gr.Model3D(display_mode="solid", clear_color=[1.0, 1.0, 1.0, 1.0], label="3D Viewer"), 138 | gr.File(type="filepath", label="Download the model as .ply file"), 139 | gr.Textbox('--', label="FOV (Horizontal, Vertical)") 140 | ], 141 | title=None, 142 | description=f""" 143 | ## Turn a 2D image into a 3D point 
map with [MoGe](https://wangrc.site/MoGePage/) 144 | 145 | NOTE: 146 | * The maximum size is set to {max_size:d}px for efficiency purpose. Oversized images will be downsampled. 147 | * The color in the 3D viewer may look dark due to rendering of 3D viewer. You may download the 3D model as .glb or .ply file to view it in other 3D viewers. 148 | """, 149 | clear_btn=None, 150 | allow_flagging="never", 151 | theme=gr.themes.Soft() 152 | ).launch(share=share) 153 | 154 | 155 | if __name__ == '__main__': 156 | main() -------------------------------------------------------------------------------- /moge/scripts/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | from pathlib import Path 4 | import sys 5 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 6 | sys.path.insert(0, _package_root) 7 | 8 | import click 9 | 10 | 11 | @click.group(help='MoGe command line interface.') 12 | def cli(): 13 | pass 14 | 15 | def main(): 16 | from moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data 17 | cli.add_command(app.main, name='app') 18 | cli.add_command(infer.main, name='infer') 19 | cli.add_command(infer_baseline.main, name='infer_baseline') 20 | cli.add_command(infer_panorama.main, name='infer_panorama') 21 | cli.add_command(eval_baseline.main, name='eval_baseline') 22 | cli.add_command(vis_data.main, name='vis_data') 23 | cli() 24 | 25 | 26 | if __name__ == '__main__': 27 | main() -------------------------------------------------------------------------------- /moge/scripts/eval_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 5 | sys.path.insert(0, _package_root) 6 | import json 7 | from typing import * 8 | import importlib 9 | import importlib.util 10 | 11 | import click 12 | 13 | 14 | @click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Evaluation script.') 15 | @click.option('--baseline', 'baseline_code_path', type=click.Path(), required=True, help='Path to the baseline model python code.') 16 | @click.option('--config', 'config_path', type=click.Path(), default='configs/eval/all_benchmarks.json', help='Path to the evaluation configurations. 
' 17 | 'Defaults to "configs/eval/all_benchmarks.json".') 18 | @click.option('--output', '-o', 'output_path', type=click.Path(), required=True, help='Path to the output json file.') 19 | @click.option('--oracle', 'oracle_mode', is_flag=True, help='Use oracle mode for evaluation, i.e., use the GT intrinsics input.') 20 | @click.option('--dump_pred', is_flag=True, help='Dump predition results.') 21 | @click.option('--dump_gt', is_flag=True, help='Dump ground truth.') 22 | @click.pass_context 23 | def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool): 24 | # Lazy import 25 | import cv2 26 | import numpy as np 27 | from tqdm import tqdm 28 | import torch 29 | import torch.nn.functional as F 30 | import utils3d 31 | 32 | from moge.test.baseline import MGEBaselineInterface 33 | from moge.test.dataloader import EvalDataLoaderPipeline 34 | from moge.test.metrics import compute_metrics 35 | from moge.utils.geometry_torch import intrinsics_to_fov 36 | from moge.utils.vis import colorize_depth, colorize_normal 37 | from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module 38 | 39 | # Load the baseline model 40 | module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) 41 | baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline') 42 | baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False) 43 | 44 | # Load the evaluation configurations 45 | with open(config_path, 'r') as f: 46 | config = json.load(f) 47 | 48 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 49 | all_metrics = {} 50 | # Iterate over the dataset 51 | for benchmark_name, benchmark_config in tqdm(list(config.items()), desc='Benchmarks'): 52 | filenames, metrics_list = [], [] 53 | with ( 54 | EvalDataLoaderPipeline(**benchmark_config) as eval_data_pipe, 55 | tqdm(total=len(eval_data_pipe), desc=benchmark_name, leave=False) as pbar 56 | ): 57 | # Iterate over the samples in the dataset 58 | for i in range(len(eval_data_pipe)): 59 | sample = eval_data_pipe.get() 60 | sample = {k: v.to(baseline.device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()} 61 | image = sample['image'] 62 | gt_intrinsics = sample['intrinsics'] 63 | 64 | # Inference 65 | torch.cuda.synchronize() 66 | with torch.inference_mode(), timeit('_inference_timer', verbose=False) as timer: 67 | if oracle_mode: 68 | pred = baseline.infer_for_evaluation(image, gt_intrinsics) 69 | else: 70 | pred = baseline.infer_for_evaluation(image) 71 | torch.cuda.synchronize() 72 | 73 | # Compute metrics 74 | metrics, misc = compute_metrics(pred, sample, vis=dump_pred or dump_gt) 75 | metrics['inference_time'] = timer.time 76 | metrics_list.append(metrics) 77 | 78 | # Dump results 79 | dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', '')) 80 | if dump_pred: 81 | dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True) 82 | cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) 83 | 84 | with Path(dump_path, 'pred', 'metrics.json').open('w') as f: 85 | json.dump(metrics, f, indent=4) 86 | 87 | if 'pred_points' in misc: 88 | points = misc['pred_points'].cpu().numpy() 89 | cv2.imwrite(str(dump_path / 'pred' / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), 
[cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 90 | 91 | if 'pred_depth' in misc: 92 | depth = misc['pred_depth'].cpu().numpy() 93 | if 'mask' in pred: 94 | mask = pred['mask'].cpu().numpy() 95 | depth = np.where(mask, depth, np.inf) 96 | cv2.imwrite(str(dump_path / 'pred' / 'depth.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) 97 | 98 | if 'mask' in pred: 99 | mask = pred['mask'].cpu().numpy() 100 | cv2.imwrite(str(dump_path / 'pred' / 'mask.png'), (mask * 255).astype(np.uint8)) 101 | 102 | if 'normal' in pred: 103 | normal = pred['normal'].cpu().numpy() 104 | cv2.imwrite(str(dump_path / 'pred' / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR)) 105 | 106 | if 'intrinsics' in pred: 107 | intrinsics = pred['intrinsics'] 108 | fov_x, fov_y = intrinsics_to_fov(intrinsics) 109 | with open(dump_path / 'pred' / 'fov.json', 'w') as f: 110 | json.dump({ 111 | 'fov_x': np.rad2deg(fov_x.item()), 112 | 'fov_y': np.rad2deg(fov_y.item()), 113 | 'intrinsics': intrinsics.cpu().numpy().tolist(), 114 | }, f) 115 | 116 | if dump_gt: 117 | dump_path.joinpath('gt').mkdir(parents=True, exist_ok=True) 118 | cv2.imwrite(str(dump_path / 'gt' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) 119 | 120 | if 'points' in sample: 121 | points = sample['points'] 122 | cv2.imwrite(str(dump_path / 'gt' / 'points.exr'), cv2.cvtColor(points.cpu().numpy().astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 123 | 124 | if 'depth' in sample: 125 | depth = sample['depth'] 126 | mask = sample['depth_mask'] 127 | cv2.imwrite(str(dump_path / 'gt' / 'depth.png'), cv2.cvtColor(colorize_depth(depth.cpu().numpy(), mask=mask.cpu().numpy()), cv2.COLOR_RGB2BGR)) 128 | 129 | if 'normal' in sample: 130 | normal = sample['normal'] 131 | cv2.imwrite(str(dump_path / 'gt' / 'normal.png'), cv2.cvtColor(colorize_normal(normal.cpu().numpy()), cv2.COLOR_RGB2BGR)) 132 | 133 | if 'depth_mask' in sample: 134 | mask = sample['depth_mask'] 135 | cv2.imwrite(str(dump_path / 'gt' /'mask.png'), (mask.cpu().numpy() * 255).astype(np.uint8)) 136 | 137 | if 'intrinsics' in sample: 138 | intrinsics = sample['intrinsics'] 139 | fov_x, fov_y = intrinsics_to_fov(intrinsics) 140 | with open(dump_path / 'gt' / 'info.json', 'w') as f: 141 | json.dump({ 142 | 'fov_x': np.rad2deg(fov_x.item()), 143 | 'fov_y': np.rad2deg(fov_y.item()), 144 | 'intrinsics': intrinsics.cpu().numpy().tolist(), 145 | }, f) 146 | 147 | # Save intermediate results 148 | if i % 100 == 0 or i == len(eval_data_pipe) - 1: 149 | Path(output_path).write_text( 150 | json.dumps({ 151 | **all_metrics, 152 | benchmark_name: key_average(metrics_list) 153 | }, indent=4) 154 | ) 155 | pbar.update(1) 156 | 157 | all_metrics[benchmark_name] = key_average(metrics_list) 158 | 159 | # Save final results 160 | all_metrics['mean'] = key_average(list(all_metrics.values())) 161 | Path(output_path).write_text(json.dumps(all_metrics, indent=4)) 162 | 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /moge/scripts/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | from pathlib import Path 4 | import sys 5 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 6 | sys.path.insert(0, _package_root) 7 | from typing import * 8 | import itertools 9 | 
import json 10 | import warnings 11 | 12 | import cv2 13 | import numpy as np 14 | import torch 15 | from PIL import Image 16 | from tqdm import tqdm 17 | import trimesh 18 | import trimesh.visual 19 | import click 20 | 21 | from moge.model.v1 import MoGeModel 22 | from moge.utils.io import save_glb, save_ply 23 | from moge.utils.vis import colorize_depth, colorize_normal 24 | import utils3d 25 | 26 | 27 | @click.command(help='Inference script') 28 | @click.option('--input', '-i', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.') 29 | @click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.') 30 | @click.option('--output', '-o', 'output_path', default='./output', type=click.Path(), help='Output folder path') 31 | @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"') 32 | @click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') 33 | @click.option('--fp16', 'use_fp16', is_flag=True, help='Use fp16 precision for 2x faster inference.') 34 | @click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).') 35 | @click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level for inference. \ 36 | Higher value means more tokens and the finer details will be captured, but inference can be slower. \ 37 | Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. \ 38 | `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.') 39 | @click.option('--num_tokens', type=int, default=None, help='number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. \ 40 | `resolution_level` will be ignored if `num_tokens` is provided. Default: None') 41 | @click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.') 42 | @click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') 43 | @click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') 44 | @click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') 45 | @click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. 
Note that this requires pyglet<2 installed as required by trimesh.') 46 | def main( 47 | input_path: str, 48 | fov_x_: float, 49 | output_path: str, 50 | pretrained_model_name_or_path: str, 51 | device_name: str, 52 | use_fp16: bool, 53 | resize_to: int, 54 | resolution_level: int, 55 | num_tokens: int, 56 | threshold: float, 57 | save_maps_: bool, 58 | save_glb_: bool, 59 | save_ply_: bool, 60 | show: bool, 61 | ): 62 | device = torch.device(device_name) 63 | 64 | include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] 65 | if Path(input_path).is_dir(): 66 | image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) 67 | else: 68 | image_paths = [Path(input_path)] 69 | 70 | if len(image_paths) == 0: 71 | raise FileNotFoundError(f'No image files found in {input_path}') 72 | 73 | model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() 74 | 75 | 76 | if not any([save_maps_, save_glb_, save_ply_]): 77 | warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.') 78 | save_maps_ = save_glb_ = save_ply_ = True 79 | 80 | for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): 81 | image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) 82 | height, width = image.shape[:2] 83 | if resize_to is not None: 84 | height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) 85 | image = cv2.resize(image, (width, height), cv2.INTER_AREA) 86 | image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1) 87 | 88 | # Inference 89 | output = model.infer(image_tensor, fov_x=fov_x_, resolution_level=resolution_level, num_tokens=num_tokens, use_fp16=use_fp16) 90 | points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy() 91 | normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask) 92 | 93 | save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) 94 | save_path.mkdir(exist_ok=True, parents=True) 95 | 96 | # Save images / maps 97 | if save_maps_: 98 | cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 99 | cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) 100 | cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 101 | cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8)) 102 | cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 103 | fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) 104 | with open(save_path / 'fov.json', 'w') as f: 105 | json.dump({ 106 | 'fov_x': round(float(np.rad2deg(fov_x)), 2), 107 | 'fov_y': round(float(np.rad2deg(fov_y)), 2), 108 | }, f) 109 | 110 | # Export mesh & visulization 111 | if save_glb_ or save_ply_ or show: 112 | faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( 113 | points, 114 | image.astype(np.float32) / 255, 115 | utils3d.numpy.image_uv(width=width, height=height), 116 | mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), 117 | tri=True 118 | ) 119 | # When exporting the model, 
follow the OpenGL coordinate conventions: 120 | # - world coordinate system: x right, y up, z backward. 121 | # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. 122 | vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] 123 | 124 | if save_glb_: 125 | save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) 126 | 127 | if save_ply_: 128 | save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) 129 | 130 | if show: 131 | trimesh.Trimesh( 132 | vertices=vertices, 133 | vertex_colors=vertex_colors, 134 | faces=faces, 135 | process=False 136 | ).show() 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /moge/scripts/infer_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | from pathlib import Path 4 | import sys 5 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 6 | sys.path.insert(0, _package_root) 7 | import json 8 | from pathlib import Path 9 | from typing import * 10 | import itertools 11 | import warnings 12 | 13 | import click 14 | 15 | 16 | @click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Inference script for wrapped baselines methods') 17 | @click.option('--baseline', 'baseline_code_path', required=True, type=click.Path(), help='Path to the baseline model python code.') 18 | @click.option('--input', '-i', 'input_path', type=str, required=True, help='Input image or folder') 19 | @click.option('--output', '-o', 'output_path', type=str, default='./output', help='Output folder') 20 | @click.option('--size', 'image_size', type=int, default=None, help='Resize input image') 21 | @click.option('--skip', is_flag=True, help='Skip existing output') 22 | @click.option('--maps', 'save_maps_', is_flag=True, help='Save output point / depth maps') 23 | @click.option('--ply', 'save_ply_', is_flag=True, help='Save mesh in PLY format') 24 | @click.option('--glb', 'save_glb_', is_flag=True, help='Save mesh in GLB format') 25 | @click.option('--threshold', type=float, default=0.03, help='Depth edge detection threshold for saving mesh') 26 | @click.pass_context 27 | def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_, save_ply_: bool, save_glb_: bool, threshold: float): 28 | # Lazy import 29 | import cv2 30 | import numpy as np 31 | from tqdm import tqdm 32 | import torch 33 | import utils3d 34 | 35 | from moge.utils.io import save_ply, save_glb 36 | from moge.utils.geometry_numpy import intrinsics_to_fov_numpy 37 | from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity 38 | from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module 39 | from moge.test.baseline import MGEBaselineInterface 40 | 41 | # Load the baseline model 42 | module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) 43 | baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline') 44 | baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False) 45 | 46 | # Input images list 47 | include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] 48 | if Path(input_path).is_dir(): 49 | image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in 
include_suffices))) 50 | else: 51 | image_paths = [Path(input_path)] 52 | 53 | if not any([save_maps_, save_glb_, save_ply_]): 54 | warnings.warn('No output format specified. Defaults to saving maps only. Please use "--maps", "--glb", or "--ply" to specify the output.') 55 | save_maps_ = True 56 | 57 | for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): 58 | # Load one image at a time 59 | image_np = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) 60 | height, width = image_np.shape[:2] 61 | if image_size is not None and max(image_np.shape[:2]) > image_size: 62 | height, width = min(image_size, int(image_size * height / width)), min(image_size, int(image_size * width / height)) 63 | image_np = cv2.resize(image_np, (width, height), cv2.INTER_AREA) 64 | image = torch.from_numpy(image_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(baseline.device) 65 | 66 | # Inference 67 | torch.cuda.synchronize() 68 | with torch.inference_mode(), (timer := timeit('Inference', verbose=False, average=True)): 69 | output = baseline.infer(image) 70 | torch.cuda.synchronize() 71 | 72 | inference_time = timer.average_time 73 | pbar.set_postfix({'average inference time': f'{inference_time:.3f}s'}) 74 | 75 | # Save the output 76 | save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) 77 | if skip and save_path.exists(): 78 | continue 79 | save_path.mkdir(parents=True, exist_ok=True) 80 | 81 | if save_maps_: 82 | cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) 83 | 84 | if 'mask' in output: 85 | mask = output['mask'].cpu().numpy() 86 | cv2.imwrite(str(save_path /'mask.png'), (mask * 255).astype(np.uint8)) 87 | 88 | for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']: 89 | if k in output: 90 | points = output[k].cpu().numpy() 91 | cv2.imwrite(str(save_path / f'{k}.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 92 | 93 | for k in ['depth_metric', 'depth_scale_invariant', 'depth_affine_invariant', 'disparity_affine_invariant']: 94 | if k in output: 95 | depth = output[k].cpu().numpy() 96 | cv2.imwrite(str(save_path / f'{k}.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 97 | if k in ['depth_metric', 'depth_scale_invariant']: 98 | depth_vis = colorize_depth(depth) 99 | elif k == 'depth_affine_invariant': 100 | depth_vis = colorize_depth_affine(depth) 101 | elif k == 'disparity_affine_invariant': 102 | depth_vis = colorize_disparity(depth) 103 | cv2.imwrite(str(save_path / f'{k}_vis.png'), cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR)) 104 | 105 | if 'intrinsics' in output: 106 | intrinsics = output['intrinsics'].cpu().numpy() 107 | fov_x, fov_y = intrinsics_to_fov_numpy(intrinsics) 108 | with open(save_path / 'fov.json', 'w') as f: 109 | json.dump({ 110 | 'fov_x': float(np.rad2deg(fov_x)), 111 | 'fov_y': float(np.rad2deg(fov_y)), 112 | 'intrinsics': intrinsics.tolist() 113 | }, f, indent=4) 114 | 115 | # Export mesh & visulization 116 | if save_ply_ or save_glb_: 117 | assert any(k in output for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']), 'No point map found in output' 118 | points = next(output[k] for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant'] if k in output).cpu().numpy() 119 | mask = output['mask'] if 'mask' in output else np.ones_like(points[..., 0], dtype=bool) 120 | normals, normals_mask = 
utils3d.numpy.points_to_normals(points, mask=mask) 121 | faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( 122 | points, 123 | image_np.astype(np.float32) / 255, 124 | utils3d.numpy.image_uv(width=width, height=height), 125 | mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), 126 | tri=True 127 | ) 128 | # When exporting the model, follow the OpenGL coordinate conventions: 129 | # - world coordinate system: x right, y up, z backward. 130 | # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. 131 | vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] 132 | 133 | if save_glb_: 134 | save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image_np) 135 | 136 | if save_ply_: 137 | save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /moge/scripts/infer_panorama.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | from pathlib import Path 4 | import sys 5 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 6 | sys.path.insert(0, _package_root) 7 | from typing import * 8 | import itertools 9 | import json 10 | import warnings 11 | 12 | import click 13 | 14 | 15 | @click.command(help='Inference script for panorama images') 16 | @click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.') 17 | @click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path') 18 | @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"') 19 | @click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') 20 | @click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).') 21 | @click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.') 22 | @click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.') 23 | @click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.') 24 | @click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.') 25 | @click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') 26 | @click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') 27 | @click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') 28 | @click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. 
Note that this requires pyglet<2 installed as required by trimesh.') 29 | def main( 30 | input_path: str, 31 | output_path: str, 32 | pretrained_model_name_or_path: str, 33 | device_name: str, 34 | resize_to: int, 35 | resolution_level: int, 36 | threshold: float, 37 | batch_size: int, 38 | save_splitted: bool, 39 | save_maps_: bool, 40 | save_glb_: bool, 41 | save_ply_: bool, 42 | show: bool, 43 | ): 44 | # Lazy import 45 | import cv2 46 | import numpy as np 47 | from numpy import ndarray 48 | import torch 49 | from PIL import Image 50 | from tqdm import tqdm, trange 51 | import trimesh 52 | import trimesh.visual 53 | from scipy.sparse import csr_array, hstack, vstack 54 | from scipy.ndimage import convolve 55 | from scipy.sparse.linalg import lsmr 56 | 57 | import utils3d 58 | from moge.model.v1 import MoGeModel 59 | from moge.utils.io import save_glb, save_ply 60 | from moge.utils.vis import colorize_depth 61 | from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth 62 | 63 | 64 | device = torch.device(device_name) 65 | 66 | include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] 67 | if Path(input_path).is_dir(): 68 | image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) 69 | else: 70 | image_paths = [Path(input_path)] 71 | 72 | if len(image_paths) == 0: 73 | raise FileNotFoundError(f'No image files found in {input_path}') 74 | 75 | # Write outputs 76 | if not any([save_maps_, save_glb_, save_ply_]): 77 | warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.') 78 | save_maps_ = save_glb_ = save_ply_ = True 79 | 80 | model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() 81 | 82 | for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)): 83 | image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) 84 | height, width = image.shape[:2] 85 | if resize_to is not None: 86 | height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) 87 | image = cv2.resize(image, (width, height), cv2.INTER_AREA) 88 | 89 | splitted_extrinsics, splitted_intriniscs = get_panorama_cameras() 90 | splitted_resolution = 512 91 | splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution) 92 | 93 | # Infer each view 94 | print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring') 95 | 96 | splitted_distance_maps, splitted_masks = [], [] 97 | for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False): 98 | image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2) 99 | fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size]))) 100 | fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device) 101 | output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False) 102 | distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy() 103 | splitted_distance_maps.extend(list(distance_map)) 104 | splitted_masks.extend(list(mask)) 105 | 106 | # Save splitted 107 | if save_splitted: 108 | splitted_save_path = Path(output_path, image_path.stem, 'splitted') 109 | 
splitted_save_path.mkdir(exist_ok=True, parents=True) 110 | for i in range(len(splitted_images)): 111 | cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR)) 112 | cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR)) 113 | 114 | # Merge 115 | print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging') 116 | 117 | merging_width, merging_height = min(1920, width), min(960, height) 118 | panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs) 119 | panorama_depth = panorama_depth.astype(np.float32) 120 | panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR) 121 | panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0 122 | points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height)) 123 | 124 | # Write outputs 125 | print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring') 126 | save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) 127 | save_path.mkdir(exist_ok=True, parents=True) 128 | if save_maps_: 129 | cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 130 | cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR)) 131 | cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 132 | cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) 133 | cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8)) 134 | 135 | # Export mesh & visulization 136 | if save_glb_ or save_ply_ or show: 137 | normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask) 138 | faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( 139 | points, 140 | image.astype(np.float32) / 255, 141 | utils3d.numpy.image_uv(width=width, height=height), 142 | mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), 143 | tri=True 144 | ) 145 | 146 | if save_glb_: 147 | save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) 148 | 149 | if save_ply_: 150 | save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) 151 | 152 | if show: 153 | trimesh.Trimesh( 154 | vertices=vertices, 155 | vertex_colors=vertex_colors, 156 | faces=faces, 157 | process=False 158 | ).show() 159 | 160 | 161 | if __name__ == '__main__': 162 | main() -------------------------------------------------------------------------------- /moge/scripts/vis_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | import sys 4 | from pathlib import Path 5 | if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: 6 | sys.path.insert(0, _package_root) 7 | 8 | import click 9 | 10 | 11 | @click.command() 12 | @click.argument('folder_or_path', type=click.Path(exists=True)) 13 | @click.option('--output', '-o', 'output_folder', type=click.Path(), help='Path to output folder') 14 | @click.option('--max_depth', '-m', 
type=float, default=float('inf'), help='max depth') 15 | @click.option('--fov', type=float, default=None, help='field of view in degrees') 16 | @click.option('--show', 'show', is_flag=True, help='show point cloud') 17 | @click.option('--depth', 'depth_filename', type=str, default='depth.png', help='depth image file name') 18 | @click.option('--ply', 'save_ply', is_flag=True, help='save point cloud as PLY file') 19 | @click.option('--depth_vis', 'save_depth_vis', is_flag=True, help='save depth image') 20 | @click.option('--inf', 'inf_mask', is_flag=True, help='use infinity mask') 21 | @click.option('--version', 'version', type=str, default='v3', help='version of rgbd data') 22 | def main( 23 | folder_or_path: str, 24 | output_folder: str, 25 | max_depth: float, 26 | fov: float, 27 | depth_filename: str, 28 | show: bool, 29 | save_ply: bool, 30 | save_depth_vis: bool, 31 | inf_mask: bool, 32 | version: str 33 | ): 34 | # Lazy import 35 | import cv2 36 | import numpy as np 37 | import utils3d 38 | from tqdm import tqdm 39 | import trimesh 40 | 41 | from moge.utils.io import read_image, read_depth, read_meta 42 | from moge.utils.vis import colorize_depth, colorize_normal 43 | 44 | filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json')) 45 | 46 | for filepath in tqdm(filepaths): 47 | image = read_image(Path(filepath, 'image.jpg')) 48 | depth, unit = read_depth(Path(filepath, depth_filename)) 49 | meta = read_meta(Path(filepath,'meta.json')) 50 | depth_mask = np.isfinite(depth) 51 | depth_mask_inf = (depth == np.inf) 52 | intrinsics = np.array(meta['intrinsics']) 53 | 54 | extrinsics = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float) # OpenGL's identity camera 55 | verts = utils3d.numpy.unproject_cv(utils3d.numpy.image_uv(*image.shape[:2]), depth, extrinsics=extrinsics, intrinsics=intrinsics) 56 | 57 | depth_mask_ply = depth_mask & (depth < depth[depth_mask].min() * max_depth) 58 | point_cloud = trimesh.PointCloud(verts[depth_mask_ply], image[depth_mask_ply] / 255) 59 | 60 | if show: 61 | point_cloud.show() 62 | 63 | if output_folder is None: 64 | output_path = filepath 65 | else: 66 | output_path = Path(output_folder, filepath.name) 67 | output_path.mkdir(exist_ok=True, parents=True) 68 | 69 | if inf_mask: 70 | depth = np.where(depth_mask_inf, np.inf, depth) 71 | depth_mask = depth_mask | depth_mask_inf 72 | 73 | if save_depth_vis: 74 | p = output_path.joinpath('depth_vis.png') 75 | cv2.imwrite(str(p), cv2.cvtColor(colorize_depth(depth, depth_mask), cv2.COLOR_RGB2BGR)) 76 | print(f"{p}") 77 | 78 | if save_ply: 79 | p = output_path.joinpath('pointcloud.ply') 80 | point_cloud.export(p) 81 | print(f"{p}") 82 | 83 | if __name__ == '__main__': 84 | main() -------------------------------------------------------------------------------- /moge/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/test/__init__.py -------------------------------------------------------------------------------- /moge/test/baseline.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import click 4 | import torch 5 | 6 | 7 | class MGEBaselineInterface: 8 | """ 9 | Abstract class for model wrapper to uniformize the interface of loading and inference across different models. 
10 | """ 11 | device: torch.device 12 | 13 | @click.command() 14 | @staticmethod 15 | def load(*args, **kwargs) -> "MGEBaselineInterface": 16 | """ 17 | Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()` 18 | """ 19 | raise NotImplementedError(f"{type(self).__name__} has not implemented the load method.") 20 | 21 | def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: 22 | """ 23 | ### Parameters 24 | `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1] 25 | `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional. 26 | 27 | ### Returns 28 | A dictionary containing: 29 | - `points_*`. point map output in OpenCV identity camera space. 30 | Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`. 31 | - `depth_*`. depth map output 32 | Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`. 33 | - `disparity_affine_invariant`. affine disparity map output 34 | """ 35 | raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.") 36 | 37 | def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: 38 | """ 39 | If the model has a special evaluation mode, override this method to provide the evaluation mode inference. 40 | 41 | By default, this method simply calls `infer()`. 42 | """ 43 | return self.infer(image, intrinsics) -------------------------------------------------------------------------------- /moge/test/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import * 3 | from pathlib import Path 4 | import math 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | import cv2 10 | import utils3d 11 | 12 | from ..utils import pipeline 13 | from ..utils.geometry_numpy import focal_to_fov_numpy, mask_aware_nearest_resize_numpy, norm3d 14 | from ..utils.io import * 15 | from ..utils.tools import timeit 16 | 17 | 18 | class EvalDataLoaderPipeline: 19 | 20 | def __init__( 21 | self, 22 | path: str, 23 | width: int, 24 | height: int, 25 | split: int = '.index.txt', 26 | drop_max_depth: float = 1000., 27 | num_load_workers: int = 4, 28 | num_process_workers: int = 8, 29 | include_segmentation: bool = False, 30 | include_normal: bool = False, 31 | depth_to_normal: bool = False, 32 | max_segments: int = 100, 33 | min_seg_area: int = 1000, 34 | depth_unit: str = None, 35 | has_sharp_boundary = False, 36 | subset: int = None, 37 | ): 38 | filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines() 39 | filenames = filenames[::subset] 40 | self.width = width 41 | self.height = height 42 | self.drop_max_depth = drop_max_depth 43 | self.path = Path(path) 44 | self.filenames = filenames 45 | self.include_segmentation = include_segmentation 46 | self.include_normal = include_normal 47 | self.max_segments = max_segments 48 | self.min_seg_area = min_seg_area 49 | self.depth_to_normal = depth_to_normal 50 | self.depth_unit = depth_unit 51 | self.has_sharp_boundary = has_sharp_boundary 52 | 53 | self.rng = np.random.default_rng(seed=0) 54 | 55 | self.pipeline = pipeline.Sequential([ 56 | self._generator, 57 | pipeline.Parallel([self._load_instance] * num_load_workers), 58 | pipeline.Parallel([self._process_instance] * num_process_workers), 59 | pipeline.Buffer(4) 60 | ]) 61 | 62 | def __len__(self): 63 | return 
math.ceil(len(self.filenames)) 64 | 65 | def _generator(self): 66 | for idx in range(len(self)): 67 | yield idx 68 | 69 | def _load_instance(self, idx): 70 | if idx >= len(self.filenames): 71 | return None 72 | 73 | path = self.path.joinpath(self.filenames[idx]) 74 | 75 | instance = { 76 | 'filename': self.filenames[idx], 77 | 'width': self.width, 78 | 'height': self.height, 79 | } 80 | instance['image'] = read_image(Path(path, 'image.jpg')) 81 | 82 | depth, _ = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead 83 | instance.update({ 84 | 'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1), 85 | 'depth_mask': np.isfinite(depth), 86 | 'depth_mask_inf': np.isinf(depth), 87 | }) 88 | 89 | if self.include_segmentation: 90 | segmentation_mask, segmentation_labels = read_segmentation(Path(path,'segmentation.png')) 91 | instance.update({ 92 | 'segmentation_mask': segmentation_mask, 93 | 'segmentation_labels': segmentation_labels, 94 | }) 95 | 96 | meta = read_meta(Path(path, 'meta.json')) 97 | instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32) 98 | 99 | return instance 100 | 101 | def _process_instance(self, instance: dict): 102 | if instance is None: 103 | return None 104 | 105 | image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics'] 106 | segmentation_mask, segmentation_labels = instance.get('segmentation_mask', None), instance.get('segmentation_labels', None) 107 | 108 | raw_height, raw_width = image.shape[:2] 109 | raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) 110 | raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height 111 | tgt_width, tgt_height = instance['width'], instance['height'] 112 | tgt_aspect = tgt_width / tgt_height 113 | 114 | # set expected target view field 115 | tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect) 116 | tgt_vertical = tgt_horizontal / tgt_aspect 117 | 118 | # set target view direction 119 | cu, cv = 0.5, 0.5 120 | direction = utils3d.numpy.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0] 121 | R = utils3d.numpy.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32)) 122 | 123 | # restrict target view field within the raw view 124 | corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32) 125 | corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane 126 | corners = corners[:, :2] / corners[:, 2:3] 127 | 128 | warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) 129 | for i in range(4): 130 | intersection, _ = utils3d.numpy.ray_intersection( 131 | np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]), 132 | corners[i - 1], corners[i] - corners[i - 1], 133 | ) 134 | warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min()) 135 | tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical) 136 | 137 | # get target view intrinsics 138 | fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical 139 | tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32) 140 | 141 | # do homogeneous transformation with the rotation and intrinsics 142 | # 4.1 
The image and depth is resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling 143 | tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes) 144 | rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h) 145 | image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS)) 146 | 147 | depth, depth_mask = mask_aware_nearest_resize_numpy(depth, depth_mask, (rescaled_w, rescaled_h)) 148 | distance = norm3d(utils3d.numpy.depth_to_points(depth, intrinsics=intrinsics)) 149 | segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None 150 | 151 | # 4.2 calculate homography warping 152 | transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics) 153 | uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height) 154 | pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T 155 | uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12) 156 | pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32) 157 | 158 | tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) 159 | tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) 160 | tgt_ray_length = utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics) 161 | tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5 162 | tgt_depth = tgt_distance / (tgt_ray_length + 1e-12) 163 | tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 164 | tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None 165 | 166 | # drop depth greater than drop_max_depth 167 | max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth 168 | tgt_depth_mask &= tgt_depth <= max_depth 169 | tgt_depth = np.nan_to_num(tgt_depth, nan=0.0) 170 | 171 | if self.depth_unit is not None: 172 | tgt_depth *= self.depth_unit 173 | 174 | if not np.any(tgt_depth_mask): 175 | # always make sure that mask is not empty, otherwise the loss calculation will crash 176 | tgt_depth_mask = np.ones_like(tgt_depth_mask) 177 | tgt_depth = np.ones_like(tgt_depth) 178 | instance['label_type'] = 'invalid' 179 | 180 | tgt_pts = utils3d.numpy.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics) 181 | 182 | # Process segmentation labels 183 | if self.include_segmentation and segmentation_mask is not None: 184 | for k in ['undefined', 'unannotated', 'background', 'sky']: 185 | if k in segmentation_labels: 186 | del segmentation_labels[k] 187 | seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True))) 188 | sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True) 189 | segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area} 190 | 191 | instance.update({ 192 | 'image': torch.from_numpy(tgt_image.astype(np.float32) / 
255.0).permute(2, 0, 1), 193 | 'depth': torch.from_numpy(tgt_depth).float(), 194 | 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(), 195 | 'intrinsics': torch.from_numpy(tgt_intrinsics).float(), 196 | 'points': torch.from_numpy(tgt_pts).float(), 197 | 'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None, 198 | 'segmentation_labels': segmentation_labels, 199 | 'is_metric': self.depth_unit is not None, 200 | 'has_sharp_boundary': self.has_sharp_boundary, 201 | }) 202 | 203 | instance = {k: v for k, v in instance.items() if v is not None} 204 | 205 | return instance 206 | 207 | def start(self): 208 | self.pipeline.start() 209 | 210 | def stop(self): 211 | self.pipeline.stop() 212 | 213 | def __enter__(self): 214 | self.start() 215 | return self 216 | 217 | def __exit__(self, exc_type, exc_value, traceback): 218 | self.stop() 219 | 220 | def get(self): 221 | return self.pipeline.get() -------------------------------------------------------------------------------- /moge/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/__init__.py -------------------------------------------------------------------------------- /moge/train/utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import fnmatch 3 | 4 | import sympy 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def any_match(s: str, patterns: List[str]) -> bool: 10 | return any(fnmatch.fnmatch(s, pat) for pat in patterns) 11 | 12 | 13 | def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer: 14 | named_param_groups = [ 15 | { 16 | k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', [])) 17 | } for param_group_config in optimizer_config['params'] 18 | ] 19 | excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)] 20 | assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}' 21 | optimizer_cls = getattr(torch.optim, optimizer_config['type']) 22 | optimizer = optimizer_cls([ 23 | { 24 | **param_group_config, 25 | 'params': list(params.values()), 26 | } for param_group_config, params in zip(optimizer_config['params'], named_param_groups) 27 | ]) 28 | return optimizer 29 | 30 | 31 | def parse_lr_lambda(s: str) -> Callable[[int], float]: 32 | epoch = sympy.symbols('epoch') 33 | lr_lambda = sympy.sympify(s) 34 | return sympy.lambdify(epoch, lr_lambda, 'math') 35 | 36 | 37 | def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler: 38 | if scheduler_config['type'] == "SequentialLR": 39 | child_schedulers = [ 40 | build_lr_scheduler(optimizer, child_scheduler_config) 41 | for child_scheduler_config in scheduler_config['params']['schedulers'] 42 | ] 43 | return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones']) 44 | elif scheduler_config['type'] == "LambdaLR": 45 | lr_lambda = scheduler_config['params']['lr_lambda'] 46 | if isinstance(lr_lambda, str): 47 | lr_lambda = parse_lr_lambda(lr_lambda) 48 | elif 
isinstance(lr_lambda, list): 49 | lr_lambda = [parse_lr_lambda(l) for l in lr_lambda] 50 | return torch.optim.lr_scheduler.LambdaLR( 51 | optimizer, 52 | lr_lambda=lr_lambda, 53 | ) 54 | else: 55 | scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type']) 56 | scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {})) 57 | return scheduler -------------------------------------------------------------------------------- /moge/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MoGe/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/utils/__init__.py -------------------------------------------------------------------------------- /moge/utils/download.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import * 3 | import requests 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | __all__ = ["download_file", "download_bytes"] 9 | 10 | 11 | def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None: 12 | # Ensure headers is a dict if not provided 13 | headers = headers or {} 14 | 15 | # Initialize local variables 16 | file_path = Path(filepath) 17 | downloaded_bytes = 0 18 | 19 | # Check if we should resume the download 20 | if resume and file_path.exists(): 21 | downloaded_bytes = file_path.stat().st_size 22 | headers['Range'] = f"bytes={downloaded_bytes}-" 23 | 24 | # Make a GET request to fetch the file 25 | with requests.get(url, stream=True, headers=headers) as response: 26 | response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx 27 | 28 | # Calculate the total size to download 29 | total_size = downloaded_bytes + int(response.headers.get('content-length', 0)) 30 | 31 | # Display a progress bar while downloading 32 | with ( 33 | tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar, 34 | open(file_path, 'ab') as file, 35 | ): 36 | # Set the initial position of the progress bar 37 | pbar.update(downloaded_bytes) 38 | 39 | # Write the content to the file in chunks 40 | for chunk in response.iter_content(chunk_size=4096): 41 | file.write(chunk) 42 | pbar.update(len(chunk)) 43 | 44 | 45 | def download_bytes(url: str, headers: dict = None) -> bytes: 46 | # Ensure headers is a dict if not provided 47 | headers = headers or {} 48 | 49 | # Make a GET request to fetch the file 50 | with requests.get(url, stream=True, headers=headers) as response: 51 | response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx 52 | 53 | # Read the content of the response 54 | return response.content 55 | -------------------------------------------------------------------------------- /moge/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | from typing import IO 4 | import zipfile 5 | import json 6 | import io 7 | from typing import * 8 | from pathlib import Path 9 | import re 10 | from PIL import Image, PngImagePlugin 11 | 12 | import numpy as np 13 | import cv2 14 | 15 | from .tools import timeit 16 | 17 | 18 | def save_glb( 19 | save_path: Union[str, os.PathLike], 20 | vertices: np.ndarray, 21 | faces: np.ndarray, 22 | vertex_uvs: np.ndarray, 23 | texture: np.ndarray, 24 | ): 25 | import trimesh 26 | import trimesh.visual 27 | from PIL import Image 28 | 29 | 
trimesh.Trimesh( 30 | vertices=vertices, 31 | faces=faces, 32 | visual = trimesh.visual.texture.TextureVisuals( 33 | uv=vertex_uvs, 34 | material=trimesh.visual.material.PBRMaterial( 35 | baseColorTexture=Image.fromarray(texture), 36 | metallicFactor=0.5, 37 | roughnessFactor=1.0 38 | ) 39 | ), 40 | process=False 41 | ).export(save_path) 42 | 43 | 44 | def save_ply( 45 | save_path: Union[str, os.PathLike], 46 | vertices: np.ndarray, 47 | faces: np.ndarray, 48 | vertex_colors: np.ndarray, 49 | ): 50 | import trimesh 51 | import trimesh.visual 52 | from PIL import Image 53 | 54 | trimesh.Trimesh( 55 | vertices=vertices, 56 | faces=faces, 57 | vertex_colors=vertex_colors, 58 | process=False 59 | ).export(save_path) 60 | 61 | 62 | 63 | def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray: 64 | """ 65 | Read a image, return uint8 RGB array of shape (H, W, 3). 66 | """ 67 | if isinstance(path, (str, os.PathLike)): 68 | data = Path(path).read_bytes() 69 | else: 70 | data = path.read() 71 | image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 72 | return image 73 | 74 | 75 | def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95): 76 | """ 77 | Write a image, input uint8 RGB array of shape (H, W, 3). 78 | """ 79 | data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes() 80 | if isinstance(path, (str, os.PathLike)): 81 | Path(path).write_bytes(data) 82 | else: 83 | path.write(data) 84 | 85 | 86 | def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]: 87 | """ 88 | Read a depth image, return float32 depth array of shape (H, W). 89 | """ 90 | if isinstance(path, (str, os.PathLike)): 91 | data = Path(path).read_bytes() 92 | else: 93 | data = path.read() 94 | pil_image = Image.open(io.BytesIO(data)) 95 | near = float(pil_image.info.get('near')) 96 | far = float(pil_image.info.get('far')) 97 | unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None 98 | depth = np.array(pil_image) 99 | mask_nan, mask_inf = depth == 0, depth == 65535 100 | depth = (depth.astype(np.float32) - 1) / 65533 101 | depth = near ** (1 - depth) * far ** depth 102 | depth[mask_nan] = np.nan 103 | depth[mask_inf] = np.inf 104 | return depth, unit 105 | 106 | 107 | def write_depth( 108 | path: Union[str, os.PathLike, IO], 109 | depth: np.ndarray, 110 | unit: float = None, 111 | max_range: float = 1e5, 112 | compression_level: int = 7, 113 | ): 114 | """ 115 | Encode and write a depth image as 16-bit PNG format. 116 | ### Parameters: 117 | - `path: Union[str, os.PathLike, IO]` 118 | The file path or file object to write to. 119 | - `depth: np.ndarray` 120 | The depth array, float32 array of shape (H, W). 121 | May contain `NaN` for invalid values and `Inf` for infinite values. 122 | - `unit: float = None` 123 | The unit of the depth values. 
124 | 125 | Depth values are encoded as follows: 126 | - 0: unknown 127 | - 1 ~ 65534: depth values in logarithmic 128 | - 65535: infinity 129 | 130 | metadata is stored in the PNG file as text fields: 131 | - `near`: the minimum depth value 132 | - `far`: the maximum depth value 133 | - `unit`: the unit of the depth values (optional) 134 | """ 135 | mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth) 136 | 137 | depth = depth.astype(np.float32) 138 | mask_finite = depth 139 | near = max(depth[mask_values].min(), 1e-5) 140 | far = max(near * 1.1, min(depth[mask_values].max(), near * max_range)) 141 | depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534 142 | depth[mask_nan] = 0 143 | depth[mask_inf] = 65535 144 | 145 | pil_image = Image.fromarray(depth) 146 | pnginfo = PngImagePlugin.PngInfo() 147 | pnginfo.add_text('near', str(near)) 148 | pnginfo.add_text('far', str(far)) 149 | if unit is not None: 150 | pnginfo.add_text('unit', str(unit)) 151 | pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) 152 | 153 | 154 | def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]: 155 | """ 156 | Read a segmentation mask 157 | ### Parameters: 158 | - `path: Union[str, os.PathLike, IO]` 159 | The file path or file object to read from. 160 | ### Returns: 161 | - `Tuple[np.ndarray, Dict[str, int]]` 162 | A tuple containing: 163 | - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W). 164 | - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}. 165 | """ 166 | if isinstance(path, (str, os.PathLike)): 167 | data = Path(path).read_bytes() 168 | else: 169 | data = path.read() 170 | pil_image = Image.open(io.BytesIO(data)) 171 | labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None 172 | mask = np.array(pil_image) 173 | return mask, labels 174 | 175 | 176 | def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7): 177 | """ 178 | Write a segmentation mask and label mapping, as PNG format. 179 | ### Parameters: 180 | - `path: Union[str, os.PathLike, IO]` 181 | The file path or file object to write to. 182 | - `mask: np.ndarray` 183 | The segmentation mask, uint8 or uint16 array of shape (H, W). 184 | - `labels: Dict[str, int] = None` 185 | The label mapping, a dictionary of {label_name: label_id}. 186 | - `compression_level: int = 7` 187 | The compression level for PNG compression. 188 | """ 189 | assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}" 190 | pil_image = Image.fromarray(mask) 191 | pnginfo = PngImagePlugin.PngInfo() 192 | if labels is not None: 193 | labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':')) 194 | pnginfo.add_text('labels', labels_json) 195 | pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) 196 | 197 | 198 | 199 | def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray: 200 | """ 201 | Read a normal image, return float32 normal array of shape (H, W, 3). 
202 | """ 203 | if isinstance(path, (str, os.PathLike)): 204 | data = Path(path).read_bytes() 205 | else: 206 | data = path.read() 207 | normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB) 208 | mask_nan = np.all(normal == 0, axis=-1) 209 | normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0] 210 | normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12) 211 | normal[mask_nan] = np.nan 212 | return normal 213 | 214 | 215 | def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray: 216 | """ 217 | Write a normal image, input float32 normal array of shape (H, W, 3). 218 | """ 219 | mask_nan = np.isnan(normal).any(axis=-1) 220 | normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16) 221 | normal[mask_nan] = 0 222 | data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes() 223 | if isinstance(path, (str, os.PathLike)): 224 | Path(path).write_bytes(data) 225 | else: 226 | path.write(data) 227 | 228 | 229 | def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]: 230 | return json.loads(Path(path).read_text()) 231 | 232 | def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]): 233 | Path(path).write_text(json.dumps(meta)) -------------------------------------------------------------------------------- /moge/utils/panorama.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' 3 | from pathlib import Path 4 | from typing import * 5 | import itertools 6 | import json 7 | import warnings 8 | 9 | import cv2 10 | import numpy as np 11 | from numpy import ndarray 12 | from tqdm import tqdm, trange 13 | from scipy.sparse import csr_array, hstack, vstack 14 | from scipy.ndimage import convolve 15 | from scipy.sparse.linalg import lsmr 16 | 17 | import utils3d 18 | 19 | 20 | def get_panorama_cameras(): 21 | vertices, _ = utils3d.numpy.icosahedron() 22 | intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90)) 23 | extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32) 24 | return extrinsics, [intrinsics] * len(vertices) 25 | 26 | 27 | def spherical_uv_to_directions(uv: np.ndarray): 28 | theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi 29 | directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1) 30 | return directions 31 | 32 | 33 | def directions_to_spherical_uv(directions: np.ndarray): 34 | directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True) 35 | u = 1 - np.arctan2(directions[..., 1], directions[..., 0]) / (2 * np.pi) % 1.0 36 | v = np.arccos(directions[..., 2]) / np.pi 37 | return np.stack([u, v], axis=-1) 38 | 39 | 40 | def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int): 41 | height, width = image.shape[:2] 42 | uv = utils3d.numpy.image_uv(width=resolution, height=resolution) 43 | splitted_images = [] 44 | for i in range(len(extrinsics)): 45 | spherical_uv = directions_to_spherical_uv(utils3d.numpy.unproject_cv(uv, extrinsics=extrinsics[i], intrinsics=intrinsics[i])) 46 | pixels = utils3d.numpy.uv_to_pixel(spherical_uv, width=width, height=height).astype(np.float32) 47 | 48 | 
splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR) 49 | splitted_images.append(splitted_image) 50 | return splitted_images 51 | 52 | 53 | def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, ndarray]: 54 | grid_index = np.arange(height * width).reshape(height, width) 55 | grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge') 56 | grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge') 57 | 58 | data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1) 59 | indices = np.stack([ 60 | grid_index[1:-1, 1:-1], 61 | grid_index[:-2, 1:-1], # up 62 | grid_index[2:, 1:-1], # down 63 | grid_index[1:-1, :-2], # left 64 | grid_index[1:-1, 2:] # right 65 | ], axis=-1).reshape(-1) 66 | indptr = np.arange(0, height * width * 5 + 1, 5) 67 | A = csr_array((data, indices, indptr), shape=(height * width, height * width)) 68 | 69 | return A 70 | 71 | 72 | def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, np.ndarray]: 73 | grid_index = np.arange(width * height).reshape(height, width) 74 | if wrap_x: 75 | grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap') 76 | if wrap_y: 77 | grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap') 78 | 79 | data = np.concatenate([ 80 | np.concatenate([ 81 | np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j] 82 | -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j-1] 83 | ], axis=1).reshape(-1), 84 | np.concatenate([ 85 | np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i,j] 86 | -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i-1,j] 87 | ], axis=1).reshape(-1), 88 | ]) 89 | indices = np.concatenate([ 90 | np.concatenate([ 91 | grid_index[:, :-1].reshape(-1, 1), 92 | grid_index[:, 1:].reshape(-1, 1), 93 | ], axis=1).reshape(-1), 94 | np.concatenate([ 95 | grid_index[:-1, :].reshape(-1, 1), 96 | grid_index[1:, :].reshape(-1, 1), 97 | ], axis=1).reshape(-1), 98 | ]) 99 | indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2) 100 | A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width)) 101 | 102 | return A 103 | 104 | 105 | def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]): 106 | if max(width, height) > 256: 107 | panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics) 108 | panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR) 109 | else: 110 | panorama_depth_init = None 111 | 112 | uv = utils3d.numpy.image_uv(width=width, height=height) 113 | spherical_directions = spherical_uv_to_directions(uv) 114 | 115 | # Warp each view to the panorama 116 | panorama_log_distance_grad_maps, panorama_grad_masks = [], [] 117 | panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], [] 118 | panorama_pred_masks = [] 119 | for i in range(len(distance_maps)): 120 | projected_uv, projected_depth = 
utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i]) 121 | projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1) 122 | 123 | projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32) 124 | 125 | log_splitted_distance = np.log(distance_maps[i]) 126 | panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0) 127 | panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0) 128 | 129 | # calculate gradient map 130 | padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap') 131 | grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :] 132 | 133 | padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap') 134 | mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :] 135 | 136 | panorama_log_distance_grad_maps.append((grad_x, grad_y)) 137 | panorama_grad_masks.append((mask_x, mask_y)) 138 | 139 | # calculate laplacian map 140 | padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge') 141 | padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') 142 | laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1] 143 | 144 | padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge') 145 | padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') 146 | mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5 147 | 148 | panorama_log_distance_laplacian_maps.append(laplacian) 149 | panorama_laplacian_masks.append(mask) 150 | 151 | panorama_pred_masks.append(panorama_pred_mask) 152 | 153 | panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0) 154 | panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0) 155 | panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0) 156 | panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0) 157 | 158 | panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3) 159 | panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3) 160 | 161 | panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0) 162 | panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0) 163 | panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3) 164 | 165 | grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1) 166 | grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1) 167 | grad_mask = np.concatenate([grad_x_mask, grad_y_mask]) 168 | laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1) 169 | 170 | # Solve overdetermined system 171 | A = vstack([ 172 | grad_equation(width, height, wrap_x=True, 
wrap_y=False)[grad_mask], 173 | poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask], 174 | ]) 175 | b = np.concatenate([ 176 | panorama_log_distance_grad_x.reshape(-1)[grad_x_mask], 177 | panorama_log_distance_grad_y.reshape(-1)[grad_y_mask], 178 | panorama_laplacian_map.reshape(-1)[laplacian_mask] 179 | ]) 180 | x, *_ = lsmr( 181 | A, b, 182 | atol=1e-5, btol=1e-5, 183 | x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None, 184 | show=False, 185 | ) 186 | 187 | panorama_depth = np.exp(x).reshape(height, width).astype(np.float32) 188 | panorama_mask = np.any(panorama_pred_masks, axis=0) 189 | 190 | return panorama_depth, panorama_mask 191 | 192 | -------------------------------------------------------------------------------- /moge/utils/tools.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import time 3 | from pathlib import Path 4 | from numbers import Number 5 | from functools import wraps 6 | import warnings 7 | import math 8 | import json 9 | import os 10 | import importlib 11 | import importlib.util 12 | 13 | 14 | def catch_exception(fn): 15 | @wraps(fn) 16 | def wrapper(*args, **kwargs): 17 | try: 18 | return fn(*args, **kwargs) 19 | except Exception as e: 20 | import traceback 21 | print(f"Exception in {fn.__name__}", end='r') 22 | # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())}) 23 | traceback.print_exc(chain=False) 24 | time.sleep(0.1) 25 | return None 26 | return wrapper 27 | 28 | 29 | class CallbackOnException: 30 | def __init__(self, callback: Callable, exception: type): 31 | self.exception = exception 32 | self.callback = callback 33 | 34 | def __enter__(self): 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_val, exc_tb): 38 | if isinstance(exc_val, self.exception): 39 | self.callback() 40 | return True 41 | return False 42 | 43 | def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: 44 | for k, v in d.items(): 45 | if isinstance(v, dict): 46 | for sub_key in traverse_nested_dict_keys(v): 47 | yield (k, ) + sub_key 48 | else: 49 | yield (k, ) 50 | 51 | 52 | def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): 53 | for k in keys: 54 | d = d.get(k, default) 55 | if d is None: 56 | break 57 | return d 58 | 59 | def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): 60 | for k in keys[:-1]: 61 | d = d.setdefault(k, {}) 62 | d[keys[-1]] = value 63 | 64 | 65 | def key_average(list_of_dicts: list) -> Dict[str, Any]: 66 | """ 67 | Returns a dictionary with the average value of each key in the input list of dictionaries. 68 | """ 69 | _nested_dict_keys = set() 70 | for d in list_of_dicts: 71 | _nested_dict_keys.update(traverse_nested_dict_keys(d)) 72 | _nested_dict_keys = sorted(_nested_dict_keys) 73 | result = {} 74 | for k in _nested_dict_keys: 75 | values = [] 76 | for d in list_of_dicts: 77 | v = get_nested_dict(d, k) 78 | if v is not None and not math.isnan(v): 79 | values.append(v) 80 | avg = sum(values) / len(values) if values else float('nan') 81 | set_nested_dict(result, k, avg) 82 | return result 83 | 84 | 85 | def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]: 86 | """ 87 | Flattens a nested dictionary into a single-level dictionary, with keys as tuples. 
88 | """ 89 | items = [] 90 | if parent_key is None: 91 | parent_key = () 92 | for k, v in d.items(): 93 | new_key = parent_key + (k, ) 94 | if isinstance(v, MutableMapping): 95 | items.extend(flatten_nested_dict(v, new_key).items()) 96 | else: 97 | items.append((new_key, v)) 98 | return dict(items) 99 | 100 | 101 | def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: 102 | """ 103 | Unflattens a single-level dictionary into a nested dictionary, with keys as tuples. 104 | """ 105 | result = {} 106 | for k, v in d.items(): 107 | sub_dict = result 108 | for k_ in k[:-1]: 109 | if k_ not in sub_dict: 110 | sub_dict[k_] = {} 111 | sub_dict = sub_dict[k_] 112 | sub_dict[k[-1]] = v 113 | return result 114 | 115 | 116 | def read_jsonl(file): 117 | import json 118 | with open(file, 'r') as f: 119 | data = f.readlines() 120 | return [json.loads(line) for line in data] 121 | 122 | 123 | def write_jsonl(data: List[dict], file): 124 | import json 125 | with open(file, 'w') as f: 126 | for item in data: 127 | f.write(json.dumps(item) + '\n') 128 | 129 | 130 | def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]): 131 | import pandas as pd 132 | data = [flatten_nested_dict(d) for d in data] 133 | df = pd.DataFrame(data) 134 | df = df.sort_index(axis=1) 135 | df.columns = pd.MultiIndex.from_tuples(df.columns) 136 | return df 137 | 138 | 139 | def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]): 140 | if isinstance(d, str): 141 | for old, new in mapping.items(): 142 | d = d.replace(old, new) 143 | elif isinstance(d, list): 144 | for i, item in enumerate(d): 145 | d[i] = recursive_replace(item, mapping) 146 | elif isinstance(d, dict): 147 | for k, v in d.items(): 148 | d[k] = recursive_replace(v, mapping) 149 | return d 150 | 151 | 152 | class timeit: 153 | _history: Dict[str, List['timeit']] = {} 154 | 155 | def __init__(self, name: str = None, verbose: bool = True, average: bool = False): 156 | self.name = name 157 | self.verbose = verbose 158 | self.start = None 159 | self.end = None 160 | self.average = average 161 | if average and name not in timeit._history: 162 | timeit._history[name] = [] 163 | 164 | def __call__(self, func: Callable): 165 | import inspect 166 | if inspect.iscoroutinefunction(func): 167 | async def wrapper(*args, **kwargs): 168 | with timeit(self.name or func.__qualname__): 169 | ret = await func(*args, **kwargs) 170 | return ret 171 | return wrapper 172 | else: 173 | def wrapper(*args, **kwargs): 174 | with timeit(self.name or func.__qualname__): 175 | ret = func(*args, **kwargs) 176 | return ret 177 | return wrapper 178 | 179 | def __enter__(self): 180 | self.start = time.time() 181 | return self 182 | 183 | @property 184 | def time(self) -> float: 185 | assert self.start is not None, "Time not yet started." 186 | assert self.end is not None, "Time not yet ended." 187 | return self.end - self.start 188 | 189 | @property 190 | def average_time(self) -> float: 191 | assert self.average, "Average time not available." 
192 | return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name]) 193 | 194 | @property 195 | def history(self) -> List['timeit']: 196 | return timeit._history.get(self.name, []) 197 | 198 | def __exit__(self, exc_type, exc_val, exc_tb): 199 | self.end = time.time() 200 | if self.average: 201 | timeit._history[self.name].append(self) 202 | if self.verbose: 203 | if self.average: 204 | avg = self.average_time 205 | print(f"{self.name or 'It'} took {avg:.6f} seconds in average.") 206 | else: 207 | print(f"{self.name or 'It'} took {self.time:.6f} seconds.") 208 | 209 | 210 | def strip_common_prefix_suffix(strings: List[str]) -> List[str]: 211 | first = strings[0] 212 | 213 | for start in range(len(first)): 214 | if any(s[start] != strings[0][start] for s in strings): 215 | break 216 | 217 | for end in range(1, min(len(s) for s in strings)): 218 | if any(s[-end] != first[-end] for s in strings): 219 | break 220 | 221 | return [s[start:len(s) - end + 1] for s in strings] 222 | 223 | 224 | def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): 225 | from concurrent.futures import ThreadPoolExecutor 226 | from contextlib import nullcontext 227 | from tqdm import tqdm 228 | 229 | if pbar is not None: 230 | pbar.total = len(inputs) if hasattr(inputs, '__len__') else None 231 | else: 232 | pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None) 233 | 234 | def decorator(fn: Callable): 235 | with ( 236 | ThreadPoolExecutor(max_workers=num_workers) as executor, 237 | pbar 238 | ): 239 | pbar.refresh() 240 | @catch_exception 241 | @suppress_traceback 242 | def _fn(input): 243 | ret = fn(input) 244 | pbar.update() 245 | return ret 246 | executor.map(_fn, inputs) 247 | executor.shutdown(wait=True) 248 | 249 | return decorator 250 | 251 | 252 | def suppress_traceback(fn): 253 | @wraps(fn) 254 | def wrapper(*args, **kwargs): 255 | try: 256 | return fn(*args, **kwargs) 257 | except Exception as e: 258 | e.__traceback__ = e.__traceback__.tb_next.tb_next 259 | raise 260 | return wrapper 261 | 262 | 263 | class no_warnings: 264 | def __init__(self, action: str = 'ignore', **kwargs): 265 | self.action = action 266 | self.filter_kwargs = kwargs 267 | 268 | def __call__(self, fn): 269 | @wraps(fn) 270 | def wrapper(*args, **kwargs): 271 | with warnings.catch_warnings(): 272 | warnings.simplefilter(self.action, **self.filter_kwargs) 273 | return fn(*args, **kwargs) 274 | return wrapper 275 | 276 | def __enter__(self): 277 | self.warnings_manager = warnings.catch_warnings() 278 | self.warnings_manager.__enter__() 279 | warnings.simplefilter(self.action, **self.filter_kwargs) 280 | 281 | def __exit__(self, exc_type, exc_val, exc_tb): 282 | self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) 283 | 284 | 285 | def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str): 286 | spec = importlib.util.spec_from_file_location(module_name, file_path) 287 | module = importlib.util.module_from_spec(spec) 288 | spec.loader.exec_module(module) 289 | return module -------------------------------------------------------------------------------- /moge/utils/vis.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import numpy as np 4 | import matplotlib 5 | 6 | 7 | def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: 8 | if mask is None: 9 | depth = np.where(depth > 0, depth, np.nan) 10 | else: 11 | depth = 
np.where((depth > 0) & mask, depth, np.nan) 12 | disp = 1 / depth 13 | if normalize: 14 | min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99) 15 | disp = (disp - min_disp) / (max_disp - min_disp) 16 | colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0) 17 | colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) 18 | return colored 19 | 20 | 21 | def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray: 22 | if mask is not None: 23 | depth = np.where(mask, depth, np.nan) 24 | 25 | min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999) 26 | depth = (depth - min_depth) / (max_depth - min_depth) 27 | colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], 0) 28 | colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) 29 | return colored 30 | 31 | 32 | def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: 33 | if mask is not None: 34 | disparity = np.where(mask, disparity, np.nan) 35 | 36 | if normalize: 37 | min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999) 38 | disparity = (disparity - min_disp) / (max_disp - min_disp) 39 | colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], 0) 40 | colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) 41 | return colored 42 | 43 | 44 | def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray: 45 | colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20)[..., :3] 46 | colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) 47 | return colored 48 | 49 | 50 | def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray: 51 | if mask is not None: 52 | normal = np.where(mask[..., None], normal, 0) 53 | normal = normal * [0.5, -0.5, -0.5] + 0.5 54 | normal = (normal.clip(0, 1) * 255).astype(np.uint8) 55 | return normal 56 | 57 | 58 | def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None): 59 | vmin, vmax = value_range if value_range is not None else (np.nanmin(error_map), np.nanmax(error_map)) 60 | cmap = matplotlib.colormaps[cmap] 61 | colorized_error_map = cmap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3] 62 | if mask is not None: 63 | colorized_error_map = np.where(mask[..., None], colorized_error_map, 0) 64 | colorized_error_map = np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8)) 65 | return colorized_error_map 66 | -------------------------------------------------------------------------------- /moge/utils/webfile.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from typing import * 3 | 4 | __all__ = ["WebFile"] 5 | 6 | 7 | class WebFile: 8 | def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None): 9 | self.url = url 10 | self.session = session or requests.Session() 11 | self.session.headers.update(headers or {}) 12 | self._offset = 0 13 | self.size = size if size is not None else self._fetch_size() 14 | 15 | def _fetch_size(self): 16 | with self.session.get(self.url, stream=True) as response: 17 | response.raise_for_status() 18 | content_length = response.headers.get("Content-Length") 
19 | if content_length is None: 20 | raise ValueError("Missing Content-Length in header") 21 | return int(content_length) 22 | 23 | def _fetch_data(self, offset: int, n: int) -> bytes: 24 | headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"} 25 | response = self.session.get(self.url, headers=headers) 26 | response.raise_for_status() 27 | return response.content 28 | 29 | def seekable(self) -> bool: 30 | return True 31 | 32 | def tell(self) -> int: 33 | return self._offset 34 | 35 | def available(self) -> int: 36 | return self.size - self._offset 37 | 38 | def seek(self, offset: int, whence: int = 0) -> None: 39 | if whence == 0: 40 | new_offset = offset 41 | elif whence == 1: 42 | new_offset = self._offset + offset 43 | elif whence == 2: 44 | new_offset = self.size + offset 45 | else: 46 | raise ValueError("Invalid value for whence") 47 | 48 | self._offset = max(0, min(new_offset, self.size)) 49 | 50 | def read(self, n: Optional[int] = None) -> bytes: 51 | if n is None or n < 0: 52 | n = self.available() 53 | else: 54 | n = min(n, self.available()) 55 | 56 | if n == 0: 57 | return b'' 58 | 59 | data = self._fetch_data(self._offset, n) 60 | self._offset += len(data) 61 | 62 | return data 63 | 64 | def close(self) -> None: 65 | pass 66 | 67 | def __enter__(self): 68 | return self 69 | 70 | def __exit__(self, exc_type, exc_value, traceback): 71 | pass 72 | 73 | -------------------------------------------------------------------------------- /moge/utils/webzipfile.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import io 3 | import os 4 | from zipfile import ( 5 | ZipInfo, BadZipFile, ZipFile, ZipExtFile, 6 | sizeFileHeader, structFileHeader, stringFileHeader, 7 | _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS, 8 | _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED 9 | ) 10 | import struct 11 | from requests import Session 12 | 13 | from .webfile import WebFile 14 | 15 | 16 | class _SharedWebFile(WebFile): 17 | def __init__(self, webfile: WebFile, pos: int): 18 | super().__init__(webfile.url, webfile.session, size=webfile.size) 19 | self.seek(pos) 20 | 21 | 22 | class WebZipFile(ZipFile): 23 | "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads." 24 | def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None): 25 | """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x', 26 | or append 'a'.""" 27 | webf = WebFile(url, session=session, headers=headers) 28 | super().__init__(webf, mode='r') 29 | 30 | def open(self, name, mode="r", pwd=None, *, force_zip64=False): 31 | """Return file-like object for 'name'. 32 | 33 | name is a string for the file name within the ZIP file, or a ZipInfo 34 | object. 35 | 36 | mode should be 'r' to read a file already in the ZIP file, or 'w' to 37 | write to a file newly added to the archive. 38 | 39 | pwd is the password to decrypt files (only used for reading). 40 | 41 | When writing, if the file size is not known in advance but may exceed 42 | 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large 43 | files. If the size is known in advance, it is best to pass a ZipInfo 44 | instance for name, with zinfo.file_size set. 
45 | """ 46 | if mode not in {"r", "w"}: 47 | raise ValueError('open() requires mode "r" or "w"') 48 | if pwd and (mode == "w"): 49 | raise ValueError("pwd is only supported for reading files") 50 | if not self.fp: 51 | raise ValueError( 52 | "Attempt to use ZIP archive that was already closed") 53 | 54 | assert mode == "r", "Only read mode is supported for now" 55 | 56 | # Make sure we have an info object 57 | if isinstance(name, ZipInfo): 58 | # 'name' is already an info object 59 | zinfo = name 60 | elif mode == 'w': 61 | zinfo = ZipInfo(name) 62 | zinfo.compress_type = self.compression 63 | zinfo._compresslevel = self.compresslevel 64 | else: 65 | # Get info object for name 66 | zinfo = self.getinfo(name) 67 | 68 | if mode == 'w': 69 | return self._open_to_write(zinfo, force_zip64=force_zip64) 70 | 71 | if self._writing: 72 | raise ValueError("Can't read from the ZIP file while there " 73 | "is an open writing handle on it. " 74 | "Close the writing handle before trying to read.") 75 | 76 | # Open for reading: 77 | self._fileRefCnt += 1 78 | zef_file = _SharedWebFile(self.fp, zinfo.header_offset) 79 | 80 | try: 81 | # Skip the file header: 82 | fheader = zef_file.read(sizeFileHeader) 83 | if len(fheader) != sizeFileHeader: 84 | raise BadZipFile("Truncated file header") 85 | fheader = struct.unpack(structFileHeader, fheader) 86 | if fheader[_FH_SIGNATURE] != stringFileHeader: 87 | raise BadZipFile("Bad magic number for file header") 88 | 89 | fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) 90 | if fheader[_FH_EXTRA_FIELD_LENGTH]: 91 | zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1) 92 | 93 | if zinfo.flag_bits & _MASK_COMPRESSED_PATCH: 94 | # Zip 2.7: compressed patched data 95 | raise NotImplementedError("compressed patched data (flag bit 5)") 96 | 97 | if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION: 98 | # strong encryption 99 | raise NotImplementedError("strong encryption (flag bit 6)") 100 | 101 | if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME: 102 | # UTF-8 filename 103 | fname_str = fname.decode("utf-8") 104 | else: 105 | fname_str = fname.decode(self.metadata_encoding or "cp437") 106 | 107 | if fname_str != zinfo.orig_filename: 108 | raise BadZipFile( 109 | 'File name in directory %r and header %r differ.' 
110 | % (zinfo.orig_filename, fname)) 111 | 112 | # check for encrypted flag & handle password 113 | is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED 114 | if is_encrypted: 115 | if not pwd: 116 | pwd = self.pwd 117 | if pwd and not isinstance(pwd, bytes): 118 | raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__) 119 | if not pwd: 120 | raise RuntimeError("File %r is encrypted, password " 121 | "required for extraction" % name) 122 | else: 123 | pwd = None 124 | 125 | return ZipExtFile(zef_file, mode, zinfo, pwd, True) 126 | except: 127 | zef_file.close() 128 | raise -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "moge" 7 | version = "1.0.0" 8 | description = "MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision" 9 | readme = "README.md" 10 | license = {text = "MIT"} 11 | dependencies = [ 12 | "click", 13 | "opencv-python", 14 | "scipy", 15 | "matplotlib", 16 | "trimesh", 17 | "pillow", 18 | "huggingface_hub", 19 | "numpy", 20 | "torch>=2.0.0", 21 | "torchvision", 22 | "gradio", 23 | "utils3d @ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900" 24 | ] 25 | requires-python = ">=3.9" 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/microsoft/MoGe" 29 | 30 | [tool.setuptools.packages.find] 31 | where = ["."] 32 | include = ["moge*"] 33 | 34 | [project.scripts] 35 | moge = "moge.scripts.cli:main" -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "include": [ 3 | "moge", 4 | "scripts", 5 | "baselines" 6 | ], 7 | "ignore": [ 8 | "**" 9 | ] 10 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # The versions are not specified since MoGe should be compatible with most versions of the packages. 2 | # If incompatibilities are found, consider upgrading to latest versions or installing the following recommended version of the package. 3 | torch # >= 2.0.0 4 | torchvision 5 | gradio # ==2.8.13 6 | click # ==8.1.7 7 | opencv-python # ==4.10.0.84 8 | scipy # ==1.14.1 9 | matplotlib # ==3.9.2 10 | trimesh # ==4.5.1 11 | pillow # ==10.4.0 12 | huggingface_hub # ==0.25.2 13 | git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900 14 | --------------------------------------------------------------------------------
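
A minimal usage sketch of the 16-bit log-depth PNG codec implemented in moge/utils/io.py above, assuming the package and its dependencies are installed; the output file name and the synthetic depth values are hypothetical. As the listing shows, write_depth stores finite depths logarithmically between per-image near/far bounds saved as PNG text metadata (0 marks invalid pixels, 65535 marks infinity), and read_depth inverts the mapping.

import numpy as np
from moge.utils.io import write_depth, read_depth

# Hypothetical metric depth map with a few invalid (NaN) and sky (inf) pixels.
depth = np.random.uniform(0.5, 20.0, size=(480, 640)).astype(np.float32)
depth[:8, :8] = np.nan    # encoded as 0 (unknown)
depth[-8:, -8:] = np.inf  # encoded as 65535 (infinity)

write_depth("depth_example.png", depth, unit=1.0)  # near/far/unit stored as PNG text chunks
decoded, unit = read_depth("depth_example.png")

valid = np.isfinite(depth)
# Quantization to 65533 log-spaced levels between near and far keeps relative error tiny.
assert np.allclose(decoded[valid], depth[valid], rtol=1e-3)
assert np.isnan(decoded[:8, :8]).all() and np.isinf(decoded[-8:, -8:]).all()
assert unit == 1.0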