├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.yaml
│       └── feature_request.md
├── .gitignore
├── LICENSE
├── activation.py
├── assets
│   ├── advanced.md
│   └── update_logs.md
├── config
│   ├── anya.csv
│   ├── car.csv
│   └── corgi.csv
├── data
│   ├── anya_back.webp
│   ├── anya_back_depth.png
│   ├── anya_back_normal.png
│   ├── anya_back_rgba.png
│   ├── anya_front.jpg
│   ├── anya_front.png
│   ├── anya_front_depth.png
│   ├── anya_front_normal.png
│   ├── anya_front_rgba.png
│   ├── baby_phoenix_on_ice.png
│   ├── beach_house_1.png
│   ├── beach_house_2.png
│   ├── bollywood_actress.png
│   ├── cactus.png
│   ├── cactus_depth.png
│   ├── cactus_normal.png
│   ├── cactus_rgba.png
│   ├── cake.png
│   ├── cake_depth.png
│   ├── cake_normal.png
│   ├── cake_rgba.png
│   ├── car_back.jpg
│   ├── car_front.jpg
│   ├── car_left.jpg
│   ├── car_right.jpg
│   ├── catstatue.png
│   ├── catstatue_depth.png
│   ├── catstatue_normal.png
│   ├── catstatue_rgba.png
│   ├── church_ruins.png
│   ├── corgi_puppy_sitting_looking_up.jpg
│   ├── firekeeper.jpg
│   ├── firekeeper_depth.png
│   ├── firekeeper_normal.png
│   ├── firekeeper_rgba.png
│   ├── futuristic_car.png
│   ├── hamburger.png
│   ├── hamburger_depth.png
│   ├── hamburger_normal.png
│   ├── hamburger_rgba.png
│   ├── mona_lisa.png
│   ├── teddy.png
│   ├── teddy_depth.png
│   ├── teddy_normal.png
│   └── teddy_rgba.png
├── docker
│   ├── Dockerfile
│   └── README.md
├── dpt.py
├── encoding.py
├── evaluation
│   ├── Prompt.py
│   ├── mesh_to_video.py
│   ├── r_precision.py
│   └── readme.md
├── freqencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── freq.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── freqencoder.cu
│       └── freqencoder.h
├── gridencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── grid.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── gridencoder.cu
│       └── gridencoder.h
├── guidance
│   ├── clip_utils.py
│   ├── if_utils.py
│   ├── perpneg_utils.py
│   ├── sd_utils.py
│   └── zero123_utils.py
├── ldm
│   ├── extras.py
│   ├── guidance.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── plms.py
│   │       └── sampling_util.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── evaluate
│   │   │   ├── adm_evaluator.py
│   │   │   ├── evaluate_perceptualsim.py
│   │   │   ├── frechet_video_distance.py
│   │   │   ├── ssim.py
│   │   │   └── torch_frechet_video_distance.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   ├── thirdp
│   │   └── psp
│   │       ├── helpers.py
│   │       ├── id_loss.py
│   │       └── model_irse.py
│   └── util.py
├── main.py
├── meshutils.py
├── nerf
│   ├── gui.py
│   ├── network.py
│   ├── network_grid.py
│   ├── network_grid_taichi.py
│   ├── network_grid_tcnn.py
│   ├── provider.py
│   ├── renderer.py
│   └── utils.py
├── optimizer.py
├── preprocess_image.py
├── pretrained
│   └── zero123
│       └── sd-objaverse-finetune-c_concat-256.yaml
├── raymarching
│   ├── __init__.py
│   ├── backend.py
│   ├── raymarching.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── raymarching.cu
│       └── raymarching.h
├── readme.md
├── requirements.txt
├── scripts
│   ├── install_ext.sh
│   ├── res64.args
│   ├── run.sh
│   ├── run2.sh
│   ├── run3.sh
│   ├── run4.sh
│   ├── run5.sh
│   ├── run6.sh
│   ├── run_if.sh
│   ├── run_if2.sh
│   ├── run_if2_perpneg.sh
│   ├── run_image.sh
│   ├── run_image_anya.sh
│   ├── run_image_hard_examples.sh
│   ├── run_image_procedure.sh
│   ├── run_image_text.sh
│   └── run_images.sh
├── shencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── setup.py
│   ├── sphere_harmonics.py
│   └── src
│       ├── bindings.cpp
│       ├── shencoder.cu
│       └── shencoder.h
├── taichi_modules
│   ├── __init__.py
│   ├── hash_encoder.py
│   ├── intersection.py
│   ├── ray_march.py
│   ├── utils.py
│   ├── volume_render_test.py
│   └── volume_train.py
└── tets
    ├── 128_tets.npz
    ├── 32_tets.npz
    ├── 64_tets.npz
    ├── README.md
    └── generate_tets.py
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: File a bug report
3 | title: "
"
4 | labels: ["bug"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Before filing a bug report, [search for an existing issue](https://github.com/ashawkey/stable-dreamfusion/issues).
10 |
11 | Also, ensure you are running the latest version.
12 | - type: textarea
13 | id: description
14 | attributes:
15 | label: Description
16 | description: Provide a clear and concise description of what the bug is.
17 | placeholder: Description
18 | validations:
19 | required: true
20 | - type: textarea
21 | id: steps
22 | attributes:
23 | label: Steps to Reproduce
24 | description: List the steps needed to reproduce the issue.
25 | placeholder: |
26 | 1. Go to '...'
27 | 2. Click on '...'
28 | validations:
29 | required: true
30 | - type: textarea
31 | id: expected-behavior
32 | attributes:
33 | label: Expected Behavior
34 | description: Describe what you expected to happen.
35 | placeholder: |
36 | The 'action' would do 'some amazing thing'.
37 | validations:
38 | required: true
39 | - type: textarea
40 | id: environment
41 | attributes:
42 | label: Environment
43 | description: Describe your environment.
44 | placeholder: |
45 | Ubuntu 22.04, PyTorch 1.13, CUDA 11.6
46 | validations:
47 | required: true
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | build/
3 | *.egg-info/
4 | *.so
5 | venv_*/
6 |
7 | tmp*
8 | # data/
9 | ldm/data/
10 | data2
11 | scripts2
12 | trial*/
13 | .vs/
14 |
15 | TOKEN
16 | *.ckpt
17 |
18 | densegridencoder
19 | tets/256_tets.npz
20 |
21 | .vscode/launch.json
22 |
23 | data2
24 | data/car*
25 | data/chair*
26 | data/warrior*
27 | data/wd*
28 | data/space*
29 | data/corgi*
30 | data/turtle*
31 |
32 | # Only keep the original image, not the automatically-generated depth, normals, rgba
33 | data/baby_phoenix_on_ice_*
34 | data/bollywood_actress_*
35 | data/beach_house_1_*
36 | data/beach_house_2_*
37 | data/mona_lisa_*
38 | data/futuristic_car_*
39 | data/church_ruins_*
40 |
41 |
--------------------------------------------------------------------------------
/activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.cuda.amp import custom_bwd, custom_fwd
4 |
5 | class _trunc_exp(Function):
6 | @staticmethod
7 | @custom_fwd(cast_inputs=torch.float)
8 | def forward(ctx, x):
9 | ctx.save_for_backward(x)
10 | return torch.exp(x)
11 |
12 | @staticmethod
13 | @custom_bwd
14 | def backward(ctx, g):
15 | x = ctx.saved_tensors[0]
16 | return g * torch.exp(x.clamp(max=15))
17 |
18 | trunc_exp = _trunc_exp.apply
19 |
20 | def biased_softplus(x, bias=0):
21 | return torch.nn.functional.softplus(x - bias)
--------------------------------------------------------------------------------
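A note on `activation.py` above: `trunc_exp` behaves like `torch.exp` in the forward pass but computes the backward gradient with the input clamped to 15, which keeps gradients finite for large density activations under mixed precision. A minimal sanity check (a sketch, assuming the repo root is on `PYTHONPATH`):

```python
import torch
from activation import trunc_exp

x = torch.tensor([20.0], requires_grad=True)
y = trunc_exp(x)   # forward: exp(20), not clamped
y.backward()
# backward uses exp(clamp(x, max=15)) instead of exp(20),
# so the gradient stays bounded for large activations.
print(y.item(), x.grad.item())
```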
/assets/advanced.md:
--------------------------------------------------------------------------------
1 |
2 | # Code organization & Advanced tips
3 |
4 | This is a simple description of the most important implementation details.
5 | If you are interested in improving this repo, this might be a starting point.
6 | Any contribution would be greatly appreciated!
7 |
8 | * The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`:
9 | ```python
10 | ## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE.
11 | pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
12 | ## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen.
13 | latents = self.encode_imgs(pred_rgb_512)
14 | ... # timestep sampling, noise adding and UNet noise predicting
15 | ## 3. the SDS loss
16 | w = (1 - self.alphas[t])
17 | grad = w * (noise_pred - noise)
18 | # since the UNet part is detached and we cannot simply autodiff through it, there are several ways to set the grad:
19 | # 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later)
20 | latents.backward(gradient=grad, retain_graph=True)
21 | return 0 # dummy loss
22 |
23 | # 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng)
24 | class SpecifyGradient(torch.autograd.Function):
25 | @staticmethod
26 | @custom_fwd
27 | def forward(ctx, input_tensor, gt_grad):
28 | ctx.save_for_backward(gt_grad)
29 | # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
30 | return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
31 |
32 | @staticmethod
33 | @custom_bwd
34 | def backward(ctx, grad_scale):
35 | gt_grad, = ctx.saved_tensors
36 | gt_grad = gt_grad * grad_scale
37 | return gt_grad, None
38 |
39 | loss = SpecifyGradient.apply(latents, grad)
40 | return loss # functional loss
41 |
42 | # 3.3. reparameterization (credits to @Xallt)
43 | # d(loss)/d(latents) = grad, since grad is already detached, it's this simple.
44 | loss = (grad * latents).sum()
45 | return loss
46 |
47 | # 3.4. reparameterization (credits to threestudio)
48 | # this is the same as 3.3, but the loss value only reflects the magnitude of grad, which is more informative.
49 | targets = (latents - grad).detach()
50 | loss = 0.5 * F.mse_loss(latents, targets, reduction='sum')
51 | return loss
52 | ```
53 | * Other regularizations are in `./nerf/utils.py > Trainer > train_step`.
54 | * The generation seems quite sensitive to regularizations on weights_sum (the accumulated alpha for each ray). The original opacity loss tends to make the NeRF disappear (zero density everywhere), so we replace it with an entropy loss for now, which encourages each alpha to be either 0 or 1 (an illustrative sketch follows this list).
55 | * NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`.
56 | * Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`.
57 | * Light direction: the current implementation uses a plane light source instead of a point light source.
58 | * View-dependent prompting: `./nerf/provider.py > get_view_direction`.
59 | * use `--angle_overhead, --angle_front` to set the border.
60 | * Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option.
61 | * Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`.
62 |
63 |
64 | # Debugging
65 |
66 | `debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one:
67 |
68 | ```bash
69 | python main.py --text "a hamburger" --workspace trial -O --vram_O
70 | ```
71 |
72 | ... with:
73 |
74 | ```bash
75 | debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O
76 | ```
77 |
78 | For more details: https://github.com/bulletmark/debugpy-run
79 |
80 | # Axes and directions of polar, azimuth, etc. in NeRF and Zero123
81 |
82 |
83 |
84 | This code uses theta for the polar angle and phi for the azimuth angle.
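
A hedged sketch of the camera-center convention this implies (assuming a y-up world as used by `./nerf/provider.py`; treat the exact axis assignment as an assumption and verify against the code):

```python
import numpy as np

def camera_center(radius, theta_deg, phi_deg):
    # theta: polar angle measured from the up (+y) axis; phi: azimuth around it.
    theta, phi = np.deg2rad(theta_deg), np.deg2rad(phi_deg)
    return radius * np.array([
        np.sin(theta) * np.sin(phi),  # x
        np.cos(theta),                # y (up)
        np.sin(theta) * np.cos(phi),  # z (front view at phi = 0)
    ])
```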
85 |
86 |
--------------------------------------------------------------------------------
/assets/update_logs.md:
--------------------------------------------------------------------------------
1 | ### 2023.4.19
2 | * Fix depth supervision, migrate depth estimation model to omnidata.
3 | * Add normal supervision (also by omnidata).
4 |
5 | https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4
6 |
7 | ### 2023.4.7
8 | Improvement on mesh quality & DMTet finetuning support.
9 |
10 | https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4
11 |
12 | ### 2023.3.30
13 | * Adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate the normal and mask as the latent code in a warm-up stage, which leads to faster convergence of the shape.
14 |
15 | https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4
16 |
17 | ### 2023.1.30
18 | * Use an MLP to predict the surface normals as in Magic3D to avoid finite differences / second-order gradients; generation quality is greatly improved.
19 | * More efficient two-pass raymarching in training inspired by nerfacc.
20 |
21 | https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4
22 |
23 | ### 2022.12.3
24 | * Support Stable-diffusion 2.0 base.
25 |
26 | ### 2022.11.15
27 | * Add a vanilla backbone that is pure PyTorch.
28 |
29 | ### 2022.10.9
30 | * Shading (partially) starts to work; at least it no longer makes the scene empty. For some prompts it shows better results (a less severe Janus problem). The textureless rendering mode is still disabled.
31 | * Enable shading by default (--latent_iter_ratio 1000).
32 |
33 | ### 2022.10.5
34 | * Basic reproduction finished.
35 | * Non --cuda_ray and --tcnn modes are not working; need to fix.
36 | * Shading is not working, disabled in utils.py for now. Surface normals are bad.
37 | * Use an entropy loss to regularize weights_sum (alpha); the original L2 regularization always leads to degenerate geometry...
38 |
39 | https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4
40 |
--------------------------------------------------------------------------------
/config/anya.csv:
--------------------------------------------------------------------------------
1 | zero123_weight, radius, polar, azimuth, image
2 | 1, 3, 90, 0, data/anya_front_rgba.png
3 | 1, 3, 90, 180, data/anya_back_rgba.png
--------------------------------------------------------------------------------
/config/car.csv:
--------------------------------------------------------------------------------
1 | zero123_weight, radius, polar, azimuth, image
2 | 4, 3.2, 90, 0, data/car_left_rgba.png
3 | 1, 3, 90, 90, data/car_front_rgba.png
4 | 4, 3.2, 90, 180, data/car_right_rgba.png
5 | 1, 3, 90, -90, data/car_back_rgba.png
--------------------------------------------------------------------------------
/config/corgi.csv:
--------------------------------------------------------------------------------
1 | zero123_weight, radius, polar, azimuth, image
2 | 1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png
--------------------------------------------------------------------------------
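The three CSVs above list the reference views used for Zero123 guidance: a per-view weight, camera radius, polar and azimuth angles (degrees), and the path to the RGBA image. A hypothetical parser, just to make the column meanings concrete (the repo's own loading code may differ):

```python
import csv

def load_reference_views(path):
    views = []
    with open(path, newline='') as f:
        for row in csv.DictReader(f, skipinitialspace=True):
            views.append({
                'zero123_weight': float(row['zero123_weight']),
                'radius': float(row['radius']),
                'polar': float(row['polar']),      # degrees
                'azimuth': float(row['azimuth']),  # degrees
                'image': row['image'],             # RGBA reference image
            })
    return views

# e.g. load_reference_views('config/anya.csv')
```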
/data/anya_back.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back.webp
--------------------------------------------------------------------------------
/data/anya_back_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back_depth.png
--------------------------------------------------------------------------------
/data/anya_back_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back_normal.png
--------------------------------------------------------------------------------
/data/anya_back_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back_rgba.png
--------------------------------------------------------------------------------
/data/anya_front.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front.jpg
--------------------------------------------------------------------------------
/data/anya_front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front.png
--------------------------------------------------------------------------------
/data/anya_front_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front_depth.png
--------------------------------------------------------------------------------
/data/anya_front_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front_normal.png
--------------------------------------------------------------------------------
/data/anya_front_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front_rgba.png
--------------------------------------------------------------------------------
/data/baby_phoenix_on_ice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/baby_phoenix_on_ice.png
--------------------------------------------------------------------------------
/data/beach_house_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/beach_house_1.png
--------------------------------------------------------------------------------
/data/beach_house_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/beach_house_2.png
--------------------------------------------------------------------------------
/data/bollywood_actress.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/bollywood_actress.png
--------------------------------------------------------------------------------
/data/cactus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus.png
--------------------------------------------------------------------------------
/data/cactus_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus_depth.png
--------------------------------------------------------------------------------
/data/cactus_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus_normal.png
--------------------------------------------------------------------------------
/data/cactus_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus_rgba.png
--------------------------------------------------------------------------------
/data/cake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake.png
--------------------------------------------------------------------------------
/data/cake_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake_depth.png
--------------------------------------------------------------------------------
/data/cake_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake_normal.png
--------------------------------------------------------------------------------
/data/cake_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake_rgba.png
--------------------------------------------------------------------------------
/data/car_back.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_back.jpg
--------------------------------------------------------------------------------
/data/car_front.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_front.jpg
--------------------------------------------------------------------------------
/data/car_left.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_left.jpg
--------------------------------------------------------------------------------
/data/car_right.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_right.jpg
--------------------------------------------------------------------------------
/data/catstatue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue.png
--------------------------------------------------------------------------------
/data/catstatue_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue_depth.png
--------------------------------------------------------------------------------
/data/catstatue_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue_normal.png
--------------------------------------------------------------------------------
/data/catstatue_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue_rgba.png
--------------------------------------------------------------------------------
/data/church_ruins.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/church_ruins.png
--------------------------------------------------------------------------------
/data/corgi_puppy_sitting_looking_up.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/corgi_puppy_sitting_looking_up.jpg
--------------------------------------------------------------------------------
/data/firekeeper.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper.jpg
--------------------------------------------------------------------------------
/data/firekeeper_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper_depth.png
--------------------------------------------------------------------------------
/data/firekeeper_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper_normal.png
--------------------------------------------------------------------------------
/data/firekeeper_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper_rgba.png
--------------------------------------------------------------------------------
/data/futuristic_car.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/futuristic_car.png
--------------------------------------------------------------------------------
/data/hamburger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger.png
--------------------------------------------------------------------------------
/data/hamburger_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger_depth.png
--------------------------------------------------------------------------------
/data/hamburger_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger_normal.png
--------------------------------------------------------------------------------
/data/hamburger_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger_rgba.png
--------------------------------------------------------------------------------
/data/mona_lisa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/mona_lisa.png
--------------------------------------------------------------------------------
/data/teddy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy.png
--------------------------------------------------------------------------------
/data/teddy_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy_depth.png
--------------------------------------------------------------------------------
/data/teddy_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy_normal.png
--------------------------------------------------------------------------------
/data/teddy_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy_rgba.png
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
2 |
3 | # Remove any third-party apt sources to avoid issues with expiring keys.
4 | RUN rm -f /etc/apt/sources.list.d/*.list
5 |
6 | RUN apt-get update
7 |
8 | RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/Madrid apt-get install -y tzdata
9 |
10 | # Install some basic utilities
11 | RUN apt-get install -y \
12 | curl \
13 | ca-certificates \
14 | sudo \
15 | git \
16 | bzip2 \
17 | libx11-6 \
18 | python3 \
19 | python3-pip \
20 | libglfw3-dev \
21 | libgles2-mesa-dev \
22 | libglib2.0-0 \
23 | && rm -rf /var/lib/apt/lists/*
24 |
25 |
26 | # Create a working directory
27 | RUN mkdir /app
28 | WORKDIR /app
29 |
30 | RUN cd /app
31 | RUN git clone https://github.com/ashawkey/stable-dreamfusion.git
32 |
33 |
34 | RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
35 |
36 | WORKDIR /app/stable-dreamfusion
37 |
38 | RUN pip3 install -r requirements.txt
39 | RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/
40 |
41 | # Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
42 | RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
43 |
44 | RUN pip3 install git+https://github.com/openai/CLIP.git
45 | RUN bash scripts/install_ext.sh
46 |
47 |
48 |
49 |
50 |
51 | # Set the default command to python3
52 | #CMD ["python3"]
53 |
54 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Docker installation
2 |
3 | ## Build image
4 | To build the docker image on your own machine, which may take 15-30 mins:
5 | ```
6 | docker build -t stable-dreamfusion:latest .
7 | ```
8 |
9 | If you get the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn, you need to set up the NVIDIA runtime for Docker.
10 | ```
11 | sudo apt-get install nvidia-container-runtime
12 | ```
13 | Then edit `/etc/docker/daemon.json` and add the default-runtime:
14 | ```
15 | {
16 | "runtimes": {
17 | "nvidia": {
18 | "path": "nvidia-container-runtime",
19 | "runtimeArgs": []
20 | }
21 | },
22 | "default-runtime": "nvidia"
23 | }
24 | ```
25 | And restart docker:
26 | ```
27 | sudo systemctl restart docker
28 | ```
29 | Now you can build tiny-cuda-nn inside docker.
30 |
31 | ## Download image
32 | To download the image (~6GB) instead:
33 | ```
34 | docker pull supercabb/stable-dreamfusion:3080_0.0.1
35 | docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
36 | ```
37 |
38 | ## Use image
39 |
40 | You can launch an interactive shell inside the container:
41 |
42 | ```
43 | docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
44 | ```
45 | From this shell, all the code in the repo should work.
46 |
47 | To run any single command `<command>` inside the docker container:
48 | ```
49 | docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command>"
50 | ```
51 | To train:
52 | ```
53 | export TOKEN="#HUGGING FACE ACCESS TOKEN#"
54 | docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
55 | && python3 main.py --text \"a hamburger\" --workspace trial -O"
56 |
57 | ```
58 | Run test without gui:
59 | ```
60 | export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
61 | docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
62 | -v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
63 | main.py --workspace trial -O --test"
64 | ```
65 | Run test with gui:
66 | ```
67 | export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
68 | xhost +
69 | docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
70 | -v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
71 | main.py --workspace trial -O --test --gui"
72 | xhost -
73 | ```
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
/encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class FreqEncoder_torch(nn.Module):
6 | def __init__(self, input_dim, max_freq_log2, N_freqs,
7 | log_sampling=True, include_input=True,
8 | periodic_fns=(torch.sin, torch.cos)):
9 |
10 | super().__init__()
11 |
12 | self.input_dim = input_dim
13 | self.include_input = include_input
14 | self.periodic_fns = periodic_fns
15 | self.N_freqs = N_freqs
16 |
17 | self.output_dim = 0
18 | if self.include_input:
19 | self.output_dim += self.input_dim
20 |
21 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
22 |
23 | if log_sampling:
24 | self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
25 | else:
26 | self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)
27 |
28 | self.freq_bands = self.freq_bands.numpy().tolist()
29 |
30 | def forward(self, input, max_level=None, **kwargs):
31 |
32 | if max_level is None:
33 | max_level = self.N_freqs
34 | else:
35 | max_level = int(max_level * self.N_freqs)
36 |
37 | out = []
38 | if self.include_input:
39 | out.append(input)
40 |
41 | for i in range(max_level):
42 | freq = self.freq_bands[i]
43 | for p_fn in self.periodic_fns:
44 | out.append(p_fn(input * freq))
45 |
46 | # append 0
47 | if self.N_freqs - max_level > 0:
48 | out.append(torch.zeros(*input.shape[:-1], (self.N_freqs - max_level) * 2 * input.shape[-1], device=input.device, dtype=input.dtype))
49 |
50 | out = torch.cat(out, dim=-1)
51 |
52 | return out
53 |
54 | def get_encoder(encoding, input_dim=3,
55 | multires=6,
56 | degree=4,
57 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',
58 | **kwargs):
59 |
60 | if encoding == 'None':
61 | return lambda x, **kwargs: x, input_dim
62 |
63 | elif encoding == 'frequency_torch':
64 | encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
65 |
66 | elif encoding == 'frequency': # CUDA implementation, faster than torch.
67 | from freqencoder import FreqEncoder
68 | encoder = FreqEncoder(input_dim=input_dim, degree=multires)
69 |
70 | elif encoding == 'sphere_harmonics':
71 | from shencoder import SHEncoder
72 | encoder = SHEncoder(input_dim=input_dim, degree=degree)
73 |
74 | elif encoding == 'hashgrid':
75 | from gridencoder import GridEncoder
76 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)
77 |
78 | elif encoding == 'tiledgrid':
79 | from gridencoder import GridEncoder
80 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)
81 |
82 | elif encoding == 'hashgrid_taichi':
83 | from taichi_modules.hash_encoder import HashEncoderTaichi
84 |         encoder = HashEncoderTaichi(batch_size=4096) # TODO: hard-coded batch size
85 |
86 | else:
87 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
88 |
89 | return encoder, encoder.output_dim
--------------------------------------------------------------------------------
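A small usage sketch for `get_encoder` above, using the pure-PyTorch frequency encoder so that no compiled extension is needed (the grid and spherical-harmonics variants require the CUDA extensions):

```python
import torch
from encoding import get_encoder

encoder, out_dim = get_encoder('frequency_torch', input_dim=3, multires=6)
x = torch.rand(1024, 3)          # arbitrary 3D coordinates
feats = encoder(x)
print(out_dim, feats.shape)      # 39, torch.Size([1024, 39]) = 3 + 3 * 6 * 2
```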
/evaluation/Prompt.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification
3 | from transformers import pipeline
4 | import argparse
5 | import sys
6 | import warnings
7 | warnings.filterwarnings("ignore", category=UserWarning)
8 |
9 |
10 | #python Prompt.py --text "a dog is in front of a rabbit" --model vlt5
11 |
12 |
13 | if __name__ == '__main__':
14 |
15 |     # Mimic the argument-parsing part of main.py
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--text', default="", type=str, help="text prompt")
18 | #parser.add_argument('--workspace', default="trial", type=str, help="workspace")
19 | parser.add_argument('--model', default='vlt5', type=str, help="model choices - vlt5, bert, XLNet")
20 |
21 | opt = parser.parse_args()
22 |
23 | if opt.model == "vlt5":
24 | tokenizer = AutoTokenizer.from_pretrained("Voicelab/vlt5-base-keywords")
25 | model = AutoModelForSeq2SeqLM.from_pretrained("Voicelab/vlt5-base-keywords")
26 |
27 | task_prefix = "Keywords: "
28 | inputs = [
29 | opt.text
30 | ]
31 |
32 | for sample in inputs:
33 | input_sequences = [task_prefix + sample]
34 | input_ids = tokenizer(
35 | input_sequences, return_tensors="pt", truncation=True
36 | ).input_ids
37 | output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4)
38 | output_text = tokenizer.decode(output[0], skip_special_tokens=True)
39 | #print(sample, "\n --->", output_text)
40 |
41 | elif opt.model == "bert":
42 | tokenizer = AutoTokenizer.from_pretrained("yanekyuk/bert-uncased-keyword-extractor")
43 | model = AutoModelForTokenClassification.from_pretrained("yanekyuk/bert-uncased-keyword-extractor")
44 |
45 | text = opt.text
46 | input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
47 |
48 | # Classify tokens
49 | outputs = model(input_ids)
50 | predictions = outputs.logits.detach().numpy()[0]
51 | labels = predictions.argmax(axis=1)
52 | labels = labels[1:-1]
53 |
54 | print(labels)
55 | tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
56 | tokens = tokens[1:-1]
57 | output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0]
58 | output_text = tokenizer.convert_tokens_to_string(output_tokens)
59 |
60 | #print(output_text)
61 |
62 |
63 | elif opt.model == "XLNet":
64 | tokenizer = AutoTokenizer.from_pretrained("jasminejwebb/KeywordIdentifier")
65 | model = AutoModelForTokenClassification.from_pretrained("jasminejwebb/KeywordIdentifier")
66 |
67 | text = opt.text
68 | input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
69 |
70 | # Classify tokens
71 | outputs = model(input_ids)
72 | predictions = outputs.logits.detach().numpy()[0]
73 | labels = predictions.argmax(axis=1)
74 | labels = labels[1:-1]
75 |
76 | print(labels)
77 | tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
78 | tokens = tokens[1:-1]
79 | output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0]
80 | output_text = tokenizer.convert_tokens_to_string(output_tokens)
81 |
82 | #print(output_text)
83 |
84 | wrapped_text = textwrap.fill(output_text, width=50)
85 |
86 |
87 | print('+' + '-'*52 + '+')
88 | for line in wrapped_text.split('\n'):
89 | print('| {} |'.format(line.ljust(50)))
90 | print('+' + '-'*52 + '+')
91 | #print(result)
92 |
--------------------------------------------------------------------------------
/evaluation/mesh_to_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import trimesh
4 | import argparse
5 | from pathlib import Path
6 | from tqdm import tqdm
7 | import pyvista as pv
8 |
9 | def render_video(anim_mesh):
10 | center = anim_mesh.center_mass
11 | plotter = pv.Plotter(off_screen=True)
12 | plotter.add_mesh(anim_mesh)
13 |
14 | radius = 10
15 | n_frames = 360
16 | angle_step = 2 * np.pi / n_frames
17 | for i in tqdm(range(n_frames)):
18 | camera_pos = [center[0] + radius * np.cos(i*angle_step),center[1] + radius *np.sin(i*angle_step),center[2]]
19 | plotter.camera_position = (camera_pos, center, (0, 0, 1))
20 | plotter.show(screenshot=f'frame_{i}.png', auto_close=False)
21 | plotter.close()
22 | os.system('ffmpeg -r 30 -f image2 -s 1920x1080 -i "result/frame_%d.png" -vcodec libx264 -crf 25 -pix_fmt yuv420p result/output.mp4')
23 |
24 |
25 |
26 | def generate_mesh(obj1,obj2,transform_vector):
27 |
28 | # Read 2 objects
29 | filename1 = obj1 # Central Object
30 | filename2 = obj2 # Surrounding Object
31 | mesh1 = trimesh.load_mesh(filename1)
32 | mesh2 = trimesh.load_mesh(filename2)
33 |
34 |     extents1 = mesh1.extents
35 |     extents2 = mesh2.extents
36 |
37 | radius1 = sum(extents1) / 3.0
38 | radius2 = sum(extents2) / 3.0
39 |
40 | center1 = mesh1.center_mass
41 | center2 = mesh2.center_mass
42 |
43 | # Move
44 | T1 = -center1
45 | new =[]
46 | for i in transform_vector:
47 | try:
48 |             new.append(float(i) * radius1)
49 | except:
50 | pass
51 | transform_vector = new
52 | print(T1, transform_vector, radius1)
53 | T2 = -center2 + transform_vector
54 |
55 | # Transform
56 | mesh1.apply_translation(T1)
57 | mesh2.apply_translation(T2)
58 |
59 | # merge mesh
60 | merged_mesh = trimesh.util.concatenate((mesh1, mesh2))
61 |
62 | # save mesh
63 | merged_mesh.export('merged_mesh.obj')
64 | print("----> merge mesh done")
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser(description='Generate rotating mesh animation.')
68 | parser.add_argument('--center_obj', type=str, help='Input OBJ1 file.')
69 | parser.add_argument('--surround_obj', type=str, help='Input OBJ2 file.')
70 | parser.add_argument('--transform_vector', help='Transform_vector.')
71 | parser.add_argument('--output_file', type=str, default="result/Demo.mp4", help='Output MP4 file.')
72 | parser.add_argument('--num_frames', type=int, default=100, help='Number of frames to render.')
73 | args = parser.parse_args()
74 |
75 | #mesh = obj.Obj("wr.obj")
76 | generate_mesh(args.center_obj,args.surround_obj,args.transform_vector)
77 |
78 | input_file = Path("merged_mesh.obj")
79 | output_file = Path(args.output_file)
80 |
81 | out_dir = output_file.parent.joinpath('frames')
82 | out_dir.mkdir(parents=True, exist_ok=True)
83 |
84 | anim_mesh = trimesh.load_mesh(str(input_file))
85 |
86 | render_video(anim_mesh)
87 |
88 |
--------------------------------------------------------------------------------
/evaluation/r_precision.py:
--------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer, util
2 | from PIL import Image
3 | import argparse
4 | import sys
5 |
6 |
7 | if __name__ == '__main__':
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--text', default="", type=str, help="text prompt")
11 |     parser.add_argument('--workspace', default="trial", type=str, help="workspace folder")
12 | parser.add_argument('--latest', default='ep0001', type=str, help="which epoch result you want to use for image path")
13 | parser.add_argument('--mode', default='rgb', type=str, help="mode of result, color(rgb) or textureless()")
14 | parser.add_argument('--clip', default="clip-ViT-B-32", type=str, help="CLIP model to encode the img and prompt")
15 |
16 | opt = parser.parse_args()
17 |
18 | #Load CLIP model
19 | model = SentenceTransformer(f'{opt.clip}')
20 |
21 | #Encode an image:
22 | img_emb = model.encode(Image.open(f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png'))
23 |
24 | #Encode text descriptions
25 | text_emb = model.encode([f'{opt.text}'])
26 |
27 | #Compute cosine similarities
28 | cos_scores = util.cos_sim(img_emb, text_emb)
29 | print("The final CLIP R-Precision is:", cos_scores[0][0].cpu().numpy())
30 |
31 |
--------------------------------------------------------------------------------
/evaluation/readme.md:
--------------------------------------------------------------------------------
1 | ### Improvement:
2 |
3 | - Usage
4 |
5 | - r_precision.py
6 | For computing the CLIP R-Precision of a generated result
7 | --text is the prompt, following the authors of stable-dreamfusion
8 | --workspace is the workspace folder that is created for every prompt fed into stable-dreamfusion
9 | --latest selects which checkpoint's results to use. stable-dreamfusion saves results every epoch; this is normally ep0100 unless training was stopped early or extended
10 | --mode has choices rgb and depth, corresponding to the color and depth results as in Figure 5 of the original paper (qualitative comparison with baselines)
11 | --clip has choices clip-ViT-B-32, CLIP B/16 and CLIP L/14, the same as the original paper
12 |
13 | ```bash
14 | python r_precision.py --text "matte painting of a castle made of cheesecake surrounded by a moat made of ice cream" --workspace ../castle --latest ep0100 --mode rgb --clip clip-ViT-B-32
15 | ```
16 |
17 | - Prompt.py (model names are case sensitive)
18 | For prompt separation (keyword extraction)
19 | --text is the prompt, following the authors of stable-dreamfusion
20 | --model chooses the pretrained model
21 |
22 | ```bash
23 | python Prompt.py --text "a dog is in front of a rabbit" --model vlt5
24 | python Prompt.py --text "a dog is in front of a rabbit" --model bert
25 | python Prompt.py --text "a dog is in front of a rabbit" --model XLNet
26 | ```
27 |
28 |
29 | - mesh_to_video.py
30 | --center_obj is the center object
31 | --surround_obj is the surrounding object (subject to change)
32 | --transform_vector is the x/y/z 3D translation vector
33 |
34 | ```bash
35 | python mesh_to_video.py --center_obj 'mesh_whiterabbit/mesh.obj' --surround_obj 'mesh_snake/mesh.obj' --transform_vector [1,0,0]
36 | ```
37 |
--------------------------------------------------------------------------------
/freqencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .freq import FreqEncoder
--------------------------------------------------------------------------------
/freqencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | '-use_fast_math'
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
23 | if paths:
24 | return paths[0]
25 |
26 | # If cl.exe is not on path, try to find it.
27 | if os.system("where cl.exe >nul 2>nul") != 0:
28 | cl_path = find_cl_path()
29 | if cl_path is None:
30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
31 | os.environ["PATH"] += ";" + cl_path
32 |
33 | _backend = load(name='_freqencoder',
34 | extra_cflags=c_flags,
35 | extra_cuda_cflags=nvcc_flags,
36 | sources=[os.path.join(_src_path, 'src', f) for f in [
37 | 'freqencoder.cu',
38 | 'bindings.cpp',
39 | ]],
40 | )
41 |
42 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/freqencoder/freq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | from torch.autograd.function import once_differentiable
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | try:
10 | import _freqencoder as _backend
11 | except ImportError:
12 | from .backend import _backend
13 |
14 |
15 | class _freq_encoder(Function):
16 | @staticmethod
17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
18 | def forward(ctx, inputs, degree, output_dim):
19 | # inputs: [B, input_dim], float
20 | # RETURN: [B, F], float
21 |
22 | if not inputs.is_cuda: inputs = inputs.cuda()
23 | inputs = inputs.contiguous()
24 |
25 | B, input_dim = inputs.shape # batch size, coord dim
26 |
27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
28 |
29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
30 |
31 | ctx.save_for_backward(inputs, outputs)
32 | ctx.dims = [B, input_dim, degree, output_dim]
33 |
34 | return outputs
35 |
36 | @staticmethod
37 | #@once_differentiable
38 | @custom_bwd
39 | def backward(ctx, grad):
40 | # grad: [B, C * C]
41 |
42 | grad = grad.contiguous()
43 | inputs, outputs = ctx.saved_tensors
44 | B, input_dim, degree, output_dim = ctx.dims
45 |
46 | grad_inputs = torch.zeros_like(inputs)
47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
48 |
49 | return grad_inputs, None, None
50 |
51 |
52 | freq_encode = _freq_encoder.apply
53 |
54 |
55 | class FreqEncoder(nn.Module):
56 | def __init__(self, input_dim=3, degree=4):
57 | super().__init__()
58 |
59 | self.input_dim = input_dim
60 | self.degree = degree
61 | self.output_dim = input_dim + input_dim * 2 * degree
62 |
63 | def __repr__(self):
64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
65 |
66 | def forward(self, inputs, **kwargs):
67 | # inputs: [..., input_dim]
68 | # return: [..., ]
69 |
70 | prefix_shape = list(inputs.shape[:-1])
71 | inputs = inputs.reshape(-1, self.input_dim)
72 |
73 | outputs = freq_encode(inputs, self.degree, self.output_dim)
74 |
75 | outputs = outputs.reshape(prefix_shape + [self.output_dim])
76 |
77 | return outputs
--------------------------------------------------------------------------------
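A usage sketch for the CUDA `FreqEncoder` above (assumes the `_freqencoder` extension has been built, e.g. via `pip install ./freqencoder` or `scripts/install_ext.sh`, and that a CUDA device is available):

```python
import torch
from freqencoder import FreqEncoder

enc = FreqEncoder(input_dim=3, degree=4)   # output_dim = 3 + 3 * 2 * 4 = 27
x = torch.rand(8, 3, device='cuda')
print(enc(x).shape)                        # torch.Size([8, 27])
```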
/freqencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | '-use_fast_math'
11 | ]
12 |
13 | if os.name == "posix":
14 | c_flags = ['-O3', '-std=c++14']
15 | elif os.name == "nt":
16 | c_flags = ['/O2', '/std:c++17']
17 |
18 | # find cl.exe
19 | def find_cl_path():
20 | import glob
21 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
22 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
23 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
24 | if paths:
25 | return paths[0]
26 |
27 | # If cl.exe is not on path, try to find it.
28 | if os.system("where cl.exe >nul 2>nul") != 0:
29 | cl_path = find_cl_path()
30 | if cl_path is None:
31 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
32 | os.environ["PATH"] += ";" + cl_path
33 |
34 | setup(
35 | name='freqencoder', # package name, import this to use python API
36 | ext_modules=[
37 | CUDAExtension(
38 | name='_freqencoder', # extension name, import this to use CUDA API
39 | sources=[os.path.join(_src_path, 'src', f) for f in [
40 | 'freqencoder.cu',
41 | 'bindings.cpp',
42 | ]],
43 | extra_compile_args={
44 | 'cxx': c_flags,
45 | 'nvcc': nvcc_flags,
46 | }
47 | ),
48 | ],
49 | cmdclass={
50 | 'build_ext': BuildExtension,
51 | }
52 | )
--------------------------------------------------------------------------------
/freqencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "freqencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
8 | }
--------------------------------------------------------------------------------
/freqencoder/src/freqencoder.cu:
--------------------------------------------------------------------------------
1 | #include <stdint.h>
2 |
3 | #include <cuda.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_runtime.h>
6 |
7 | #include <ATen/cuda/CUDAContext.h>
8 | #include <torch/torch.h>
9 |
10 | #include <algorithm>
11 | #include <stdexcept>
12 |
13 | #include <cstdio>
14 |
15 |
16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
20 |
21 | inline constexpr __device__ float PI() { return 3.141592653589793f; }
22 |
23 | template <typename T>
24 | __host__ __device__ T div_round_up(T val, T divisor) {
25 | return (val + divisor - 1) / divisor;
26 | }
27 |
28 | // inputs: [B, D]
29 | // outputs: [B, C], C = D + D * deg * 2
30 | __global__ void kernel_freq(
31 | const float * __restrict__ inputs,
32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
33 | float * outputs
34 | ) {
35 | // parallel on per-element
36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
37 | if (t >= B * C) return;
38 |
39 | // get index
40 | const uint32_t b = t / C;
41 | const uint32_t c = t - b * C; // t % C;
42 |
43 | // locate
44 | inputs += b * D;
45 | outputs += t;
46 |
47 | // write self
48 | if (c < D) {
49 | outputs[0] = inputs[c];
50 | // write freq
51 | } else {
52 | const uint32_t col = c / D - 1;
53 | const uint32_t d = c % D;
54 | const uint32_t freq = col / 2;
55 | const float phase_shift = (col % 2) * (PI() / 2);
56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
57 | }
58 | }
59 |
60 | // grad: [B, C], C = D + D * deg * 2
61 | // outputs: [B, C]
62 | // grad_inputs: [B, D]
63 | __global__ void kernel_freq_backward(
64 | const float * __restrict__ grad,
65 | const float * __restrict__ outputs,
66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
67 | float * grad_inputs
68 | ) {
69 | // parallel on per-element
70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
71 | if (t >= B * D) return;
72 |
73 | const uint32_t b = t / D;
74 | const uint32_t d = t - b * D; // t % D;
75 |
76 | // locate
77 | grad += b * C;
78 | outputs += b * C;
79 | grad_inputs += t;
80 |
81 | // register
82 | float result = grad[d];
83 | grad += D;
84 | outputs += D;
85 |
86 | for (uint32_t f = 0; f < deg; f++) {
87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
88 | grad += 2 * D;
89 | outputs += 2 * D;
90 | }
91 |
92 | // write
93 | grad_inputs[0] = result;
94 | }
95 |
96 |
97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
98 | CHECK_CUDA(inputs);
99 | CHECK_CUDA(outputs);
100 |
101 | CHECK_CONTIGUOUS(inputs);
102 | CHECK_CONTIGUOUS(outputs);
103 |
104 | CHECK_IS_FLOATING(inputs);
105 | CHECK_IS_FLOATING(outputs);
106 |
107 | static constexpr uint32_t N_THREADS = 128;
108 |
109 |     kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
110 | }
111 |
112 |
113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
114 | CHECK_CUDA(grad);
115 | CHECK_CUDA(outputs);
116 | CHECK_CUDA(grad_inputs);
117 |
118 | CHECK_CONTIGUOUS(grad);
119 | CHECK_CONTIGUOUS(outputs);
120 | CHECK_CONTIGUOUS(grad_inputs);
121 |
122 | CHECK_IS_FLOATING(grad);
123 | CHECK_IS_FLOATING(outputs);
124 | CHECK_IS_FLOATING(grad_inputs);
125 |
126 | static constexpr uint32_t N_THREADS = 128;
127 |
128 | kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
129 | }
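130 |
131 | // Layout sketch (illustrative, not from the upstream file): for D = 3 inputs (x, y, z) and deg = 2,
132 | // C = 3 + 3 * 2 * 2 = 15 and each output row is
133 | //   [x, y, z,
134 | //    sin(x), sin(y), sin(z),    cos(x), cos(y), cos(z),      // frequency 2^0
135 | //    sin(2x), sin(2y), sin(2z), cos(2x), cos(2y), cos(2z)]   // frequency 2^1
136 | // i.e. col = c / D - 1 walks (sin, cos) pairs per frequency, matching kernel_freq above.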
--------------------------------------------------------------------------------
/freqencoder/src/freqencoder.h:
--------------------------------------------------------------------------------
1 | # pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
8 |
9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
--------------------------------------------------------------------------------
/gridencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .grid import GridEncoder
--------------------------------------------------------------------------------
/gridencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for program_files in [r"C:\Program Files (x86)", r"C:\Program Files"]:
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"%s\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % (program_files, edition)), reverse=True)
22 | if paths:
23 | return paths[0]
24 | # If cl.exe is not on path, try to find it.
25 | if os.system("where cl.exe >nul 2>nul") != 0:
26 | cl_path = find_cl_path()
27 | if cl_path is None:
28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29 | os.environ["PATH"] += ";" + cl_path
30 |
31 | _backend = load(name='_grid_encoder',
32 | extra_cflags=c_flags,
33 | extra_cuda_cflags=nvcc_flags,
34 | sources=[os.path.join(_src_path, 'src', f) for f in [
35 | 'gridencoder.cu',
36 | 'bindings.cpp',
37 | ]],
38 | )
39 |
40 | __all__ = ['_backend']
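41 |
42 | # Note (illustrative): importing this module JIT-compiles the CUDA extension with nvcc on first
43 | # use, which can take several minutes; pre-building it once (e.g. via the setup.py in this
44 | # directory or scripts/install_ext.sh) avoids recompiling in every new environment.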
--------------------------------------------------------------------------------
/gridencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for program_files in [r"C:\Program Files (x86)", r"C:\Program Files"]:
21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
22 | paths = sorted(glob.glob(r"%s\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % (program_files, edition)), reverse=True)
23 | if paths:
24 | return paths[0]
25 |
26 | # If cl.exe is not on path, try to find it.
27 | if os.system("where cl.exe >nul 2>nul") != 0:
28 | cl_path = find_cl_path()
29 | if cl_path is None:
30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
31 | os.environ["PATH"] += ";" + cl_path
32 |
33 | setup(
34 | name='gridencoder', # package name, import this to use python API
35 | ext_modules=[
36 | CUDAExtension(
37 | name='_gridencoder', # extension name, import this to use CUDA API
38 | sources=[os.path.join(_src_path, 'src', f) for f in [
39 | 'gridencoder.cu',
40 | 'bindings.cpp',
41 | ]],
42 | extra_compile_args={
43 | 'cxx': c_flags,
44 | 'nvcc': nvcc_flags,
45 | }
46 | ),
47 | ],
48 | cmdclass={
49 | 'build_ext': BuildExtension,
50 | }
51 | )
--------------------------------------------------------------------------------
/gridencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "gridencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
9 | m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
10 | }
--------------------------------------------------------------------------------
/gridencoder/src/gridencoder.h:
--------------------------------------------------------------------------------
1 | #ifndef _HASH_ENCODE_H
2 | #define _HASH_ENCODE_H
3 |
4 | #include <stdint.h>
5 | #include <torch/torch.h>
6 |
7 | // inputs: [B, D], float, in [0, 1]
8 | // embeddings: [sO, C], float
9 | // offsets: [L + 1], uint32_t
10 | // outputs: [B, L * C], float
11 | // H: base resolution
12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
14 |
15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
16 | void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
17 |
18 | #endif
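19 |
20 | // Shape sketch (illustrative): with L = 16 levels and C = 2 features per level, B query points
21 | // in [0, 1]^D give outputs of shape [B, L * C] = [B, 32]; offsets has L + 1 entries and
22 | // offsets[l] is the first row of level l inside the flattened embeddings table [offsets[L], C].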
--------------------------------------------------------------------------------
/guidance/clip_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import torchvision.transforms as T
5 | import torchvision.transforms.functional as TF
6 |
7 | import clip
8 |
9 | class CLIP(nn.Module):
10 | def __init__(self, device, **kwargs):
11 | super().__init__()
12 |
13 | self.device = device
14 | self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
15 |
16 | self.aug = T.Compose([
17 | T.Resize((224, 224)),
18 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
19 | ])
20 |
21 | def get_text_embeds(self, prompt, **kwargs):
22 |
23 | text = clip.tokenize(prompt).to(self.device)
24 | text_z = self.clip_model.encode_text(text)
25 | text_z = text_z / text_z.norm(dim=-1, keepdim=True)
26 |
27 | return text_z
28 |
29 | def get_img_embeds(self, image, **kwargs):
30 |
31 | image_z = self.clip_model.encode_image(self.aug(image))
32 | image_z = image_z / image_z.norm(dim=-1, keepdim=True)
33 |
34 | return image_z
35 |
36 |
37 | def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
38 | """
39 | Args:
40 | grad_scale: scalar or 1-tensor of size [B], i.e. 1 grad_scale per batch item.
41 | """
42 | # TODO: resize the image from the NeRF-rendered resolution (e.g. 128x128) to the 224x224 that CLIP expects, to prevent the PyTorch warning about `antialias=None`.
43 | image_z = self.clip_model.encode_image(self.aug(pred_rgb))
44 | image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
45 |
46 | loss = 0
47 | if 'image' in clip_z:
48 | loss -= ((image_z * clip_z['image']).sum(-1) * grad_scale).mean()
49 |
50 | if 'text' in clip_z:
51 | loss -= ((image_z * clip_z['text']).sum(-1) * grad_scale).mean()
52 |
53 | return loss
54 |
55 |
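56 | # Usage sketch (illustrative; argument values are assumptions):
57 | #   guidance = CLIP(device)
58 | #   clip_z = {'text': guidance.get_text_embeds(['a corgi puppy'])}
59 | #   loss = guidance.train_step(clip_z, pred_rgb)  # pred_rgb: [B, 3, H, W] rendered by the NeRF
60 | #   loss.backward()                               # gradients reach the NeRF through pred_rgb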
--------------------------------------------------------------------------------
/guidance/perpneg_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm
4 | def get_perpendicular_component(x, y):
5 | assert x.shape == y.shape
6 | return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y
7 |
8 |
9 | def batch_get_perpendicular_component(x, y):
10 | assert x.shape == y.shape
11 | result = []
12 | for i in range(x.shape[0]):
13 | result.append(get_perpendicular_component(x[i], y[i]))
14 | return torch.stack(result)
15 |
16 |
17 | def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
18 | """
19 | Notes:
20 | - weights: an array with the weights for combining the noise predictions
21 | - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir
22 | """
23 | delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64]
24 | weights = weights.split(batch_size, dim=0) # K x [B]
25 | # print(f"{weights[0].shape = } {weights = }")
26 |
27 | assert torch.all(weights[0] == 1.0)
28 |
29 | main_positive = delta_noise_preds[0] # [B, 4, 64, 64]
30 |
31 | accumulated_output = torch.zeros_like(main_positive)
32 | for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
33 | # print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n")
34 |
35 | idx_non_zero = torch.abs(weights[i]) > 1e-4
36 |
37 | # print(f"{idx_non_zero.shape = }, {idx_non_zero = }")
38 | # print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }")
39 | # print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }")
40 | # print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }")
41 | if sum(idx_non_zero) == 0:
42 | continue
43 | accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero])
44 |
45 | assert accumulated_output.shape == main_positive.shape, f"{accumulated_output.shape = }, {main_positive.shape = }"
46 |
47 |
48 | return accumulated_output + main_positive
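49 |
50 | # Shape sketch (illustrative): with batch_size B = 4 and K = 3 prompts per view direction,
51 | # delta_noise_preds is [12, 4, 64, 64] and weights is [12]; after splitting, chunk 0 is the main
52 | # prompt (its weights must all be 1.0) and chunks 1..K-1 only contribute their components
53 | # perpendicular to it, each scaled by its weight.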
--------------------------------------------------------------------------------
/ldm/extras.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from omegaconf import OmegaConf
3 | import torch
4 | from ldm.util import instantiate_from_config
5 | import logging
6 | from contextlib import contextmanager
7 |
8 | from contextlib import contextmanager
9 | import logging
10 |
11 | @contextmanager
12 | def all_logging_disabled(highest_level=logging.CRITICAL):
13 | """
14 | A context manager that will prevent any logging messages
15 | triggered during the body from being processed.
16 |
17 | :param highest_level: the maximum logging level in use.
18 | This would only need to be changed if a custom level greater than CRITICAL
19 | is defined.
20 |
21 | https://gist.github.com/simon-weber/7853144
22 | """
23 | # two kind-of hacks here:
24 | # * can't get the highest logging level in effect => delegate to the user
25 | # * can't get the current module-level override => use an undocumented
26 | # (but non-private!) interface
27 |
28 | previous_level = logging.root.manager.disable
29 |
30 | logging.disable(highest_level)
31 |
32 | try:
33 | yield
34 | finally:
35 | logging.disable(previous_level)
36 |
37 | def load_training_dir(train_dir, device, epoch="last"):
38 | """Load a checkpoint and config from training directory"""
39 | train_dir = Path(train_dir)
40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
42 | config = list(train_dir.rglob(f"*-project.yaml"))
43 | assert len(config) > 0, f"didn't find any config in {train_dir}"
44 | if len(config) > 1:
45 | print(f"found {len(config)} matching config files")
46 | config = sorted(config)[-1]
47 | print(f"selecting {config}")
48 | else:
49 | config = config[0]
50 |
51 |
52 | config = OmegaConf.load(config)
53 | return load_model_from_config(config, ckpt[0], device)
54 |
55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False):
56 | """Loads a model from config and a ckpt
57 | if config is a path will use omegaconf to load
58 | """
59 | if isinstance(config, (str, Path)):
60 | config = OmegaConf.load(config)
61 |
62 | with all_logging_disabled():
63 | print(f"Loading model from {ckpt}")
64 | pl_sd = torch.load(ckpt, map_location="cpu")
65 | global_step = pl_sd["global_step"]
66 | sd = pl_sd["state_dict"]
67 | model = instantiate_from_config(config.model)
68 | m, u = model.load_state_dict(sd, strict=False)
69 | if len(m) > 0 and verbose:
70 | print("missing keys:")
71 | print(m)
72 | if len(u) > 0 and verbose:
73 | print("unexpected keys:")
74 | model.to(device)
75 | model.eval()
76 | model.cond_stage_model.device = device
77 | return model
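78 |
79 | # Usage sketch (illustrative; paths are assumptions):
80 | #   model = load_training_dir("logs/my_run", device="cuda")   # picks *last.ckpt and *-project.yaml
81 | #   model = load_model_from_config("config.yaml", "last.ckpt", device="cuda")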
--------------------------------------------------------------------------------
/ldm/guidance.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from scipy import interpolate
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from IPython.display import clear_output
7 | import abc
8 |
9 |
10 | class GuideModel(torch.nn.Module, abc.ABC):
11 | def __init__(self) -> None:
12 | super().__init__()
13 |
14 | @abc.abstractmethod
15 | def preprocess(self, x_img):
16 | pass
17 |
18 | @abc.abstractmethod
19 | def compute_loss(self, inp):
20 | pass
21 |
22 |
23 | class Guider(torch.nn.Module):
24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
25 | """Apply classifier guidance
26 |
27 | Specify the guidance scale either as a scalar,
28 | or as a schedule: a list of (t, scale) tuples with t in [0, 1], e.g.
29 | [(0, 10), (0.5, 20), (1, 50)]
30 | """
31 | super().__init__()
32 | self.sampler = sampler
33 | self.index = 0
34 | self.show = verbose
35 | self.guide_model = guide_model
36 | self.history = []
37 |
38 | if isinstance(scale, (Tuple, List)):
39 | times = np.array([x[0] for x in scale])
40 | values = np.array([x[1] for x in scale])
41 | self.scale_schedule = {"times": times, "values": values}
42 | else:
43 | self.scale_schedule = float(scale)
44 |
45 | self.ddim_timesteps = sampler.ddim_timesteps
46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
47 |
48 |
49 | def get_scales(self):
50 | if isinstance(self.scale_schedule, float):
51 | return len(self.ddim_timesteps)*[self.scale_schedule]
52 |
53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
55 | return interpolater(fractional_steps)
56 |
57 | def modify_score(self, model, e_t, x, t, c):
58 |
59 | # TODO look up index by t
60 | scale = self.get_scales()[self.index]
61 |
62 | if (scale == 0):
63 | return e_t
64 |
65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
66 | with torch.enable_grad():
67 | x_in = x.detach().requires_grad_(True)
68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
70 |
71 | inp = self.guide_model.preprocess(x_img)
72 | loss = self.guide_model.compute_loss(inp)
73 | grads = torch.autograd.grad(loss.sum(), x_in)[0]
74 | correction = grads * scale
75 |
76 | if self.show:
77 | clear_output(wait=True)
78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
81 | plt.axis('off')
82 | plt.show()
83 | plt.imshow(correction[0][0].detach().cpu())
84 | plt.axis('off')
85 | plt.show()
86 |
87 |
88 | e_t_mod = e_t - sqrt_1ma*correction
89 | if self.show:
90 | fig, axs = plt.subplots(1, 3)
91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
94 | plt.show()
95 | self.index += 1
96 | return e_t_mod
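97 |
98 | # Schedule sketch (illustrative): with scale=[(0, 10), (0.5, 20), (1, 50)] and 50 DDIM steps out
99 | # of 1000 DDPM steps, get_scales() returns 50 values, linearly interpolating the (t, scale) pairs
100 | # at t = ddim_timestep / ddpm_num_timesteps, so one guidance scale is consumed per modify_score call.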
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
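100 | # Usage sketch (illustrative; values are assumptions): these schedulers return an LR *multiplier*,
101 | # so wrap them in torch.optim.lr_scheduler.LambdaLR with the real base LR set on the optimizer, e.g.
102 | #   sched_fn = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=0.0, lr_max=1.0,
103 | #                                          lr_start=0.0, max_decay_steps=10000)
104 | #   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)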
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from einops import rearrange
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def renorm_thresholding(x0, value):
15 | # renorm
16 | pred_max = x0.max()
17 | pred_min = x0.min()
18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
20 |
21 | s = torch.quantile(
22 | rearrange(pred_x0, 'b ... -> b (...)').abs(),
23 | value,
24 | dim=-1
25 | )
26 | s.clamp_(min=1.0)
27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
28 |
29 | # clip by threshold
30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
31 |
32 | # temporary hack: numpy on cpu
33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
34 | pred_x0 = torch.tensor(pred_x0).to(x0.device)
35 |
36 | # re.renorm
37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
39 | return pred_x0
40 |
41 |
42 | def norm_thresholding(x0, value):
43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
44 | return x0 * (value / s)
45 |
46 |
47 | def spatial_norm_thresholding(x0, value):
48 | # b c h w
49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
50 | return x0 * (value / s)
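51 |
52 | # Example (illustrative): append_dims(torch.ones(4), 3) has shape [4, 1, 1], letting a per-sample
53 | # scalar broadcast against a [4, C, H] tensor; the thresholding helpers above use it to rescale
54 | # x0 per sample without changing its direction.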
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
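94 | # Usage sketch (illustrative): given encoder moments of shape [B, 2 * C, H, W],
95 | #   posterior = DiagonalGaussianDistribution(moments)
96 | #   z = posterior.sample()        # [B, C, H, W]
97 | #   kl = posterior.kl().mean()    # regulariser term of the KL-autoencoder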
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
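78 | # Usage sketch (illustrative): call ema(model) after each optimizer step, then for evaluation
79 | #   ema.store(model.parameters()); ema.copy_to(model)   # swap in the EMA weights
80 | #   ...
81 | #   ema.restore(model.parameters())                     # restore the training weights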
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/evaluate/frechet_video_distance.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python2, python3
17 | """Minimal Reference implementation for the Frechet Video Distance (FVD).
18 |
19 | FVD is a metric for the quality of video generation models. It is inspired by
20 | the FID (Frechet Inception Distance) used for images, but uses a different
21 | embedding to be better suited to videos.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 |
29 | import six
30 | import tensorflow.compat.v1 as tf
31 | import tensorflow_gan as tfgan
32 | import tensorflow_hub as hub
33 |
34 |
35 | def preprocess(videos, target_resolution):
36 | """Runs some preprocessing on the videos for I3D model.
37 |
38 | Args:
39 | videos: [batch_size, num_frames, height, width, depth] The videos to be
40 | preprocessed. We don't care about the specific dtype of the videos, it can
41 | be anything that tf.image.resize_bilinear accepts. Values are expected to
42 | be in the range 0-255.
43 | target_resolution: (width, height): target video resolution
44 |
45 | Returns:
46 | videos: [batch_size, num_frames, height, width, depth]
47 | """
48 | videos_shape = list(videos.shape)
49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
52 | output_videos = tf.reshape(resized_videos, target_shape)
53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
54 | return scaled_videos
55 |
56 |
57 | def _is_in_graph(tensor_name):
58 | """Checks whether a given tensor does exists in the graph."""
59 | try:
60 | tf.get_default_graph().get_tensor_by_name(tensor_name)
61 | except KeyError:
62 | return False
63 | return True
64 |
65 |
66 | def create_id3_embedding(videos, warmup=False, batch_size=16):
67 | """Embeds the given videos using the Inflated 3D Convolution network.
68 |
69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the
70 | first call.
71 |
72 | Args:
73 | videos: [batch_size, num_frames, height=224, width=224, depth=3].
74 | Expected range is [-1, 1].
75 |
76 | Returns:
77 | embedding: [batch_size, embedding_size]. embedding_size depends
78 | on the model used.
79 |
80 | Raises:
81 | ValueError: when a provided embedding_layer is not supported.
82 | """
83 |
84 | # batch_size = 16
85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
86 |
87 |
88 | # Making sure that we import the graph separately for
89 | # each different input video tensor.
90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
91 | videos.name).replace(":", "_")
92 |
93 |
94 |
95 | assert_ops = [
96 | tf.Assert(
97 | tf.reduce_max(videos) <= 1.001,
98 | ["max value in frame is > 1", videos]),
99 | tf.Assert(
100 | tf.reduce_min(videos) >= -1.001,
101 | ["min value in frame is < -1", videos]),
102 | tf.assert_equal(
103 | tf.shape(videos)[0],
104 | batch_size, ["invalid frame batch size: ",
105 | tf.shape(videos)],
106 | summarize=6),
107 | ]
108 | with tf.control_dependencies(assert_ops):
109 | videos = tf.identity(videos)
110 |
111 | module_scope = "%s_apply_default/" % module_name
112 |
113 | # To check whether the module has already been loaded into the graph, we look
114 | # for a given tensor name. If this tensor name exists, we assume the function
115 | # has been called before and the graph was imported. Otherwise we import it.
116 | # Note: in theory, the tensor could exist, but have wrong shapes.
117 | # This will happen if create_id3_embedding is called with a frames_placeholder
118 | # of wrong size/batch size, because even though that will throw a tf.Assert
119 | # on graph-execution time, it will insert the tensor (with wrong shape) into
120 | # the graph. This is why we need the following assert.
121 | if warmup:
122 | video_batch_size = int(videos.shape[0])
123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"
124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
125 | if not _is_in_graph(tensor_name):
126 | i3d_model = hub.Module(module_spec, name=module_name)
127 | i3d_model(videos)
128 |
129 | # gets the kinetics-i3d-400-logits layer
130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
132 | return tensor
133 |
134 |
135 | def calculate_fvd(real_activations,
136 | generated_activations):
137 | """Returns a list of ops that compute metrics as funcs of activations.
138 |
139 | Args:
140 | real_activations: [num_samples, embedding_size]
141 | generated_activations: [num_samples, embedding_size]
142 |
143 | Returns:
144 | A scalar that contains the requested FVD.
145 | """
146 | return tfgan.eval.frechet_classifier_distance_from_activations(
147 | real_activations, generated_activations)
148 |
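149 | # Usage sketch (illustrative): embed real and generated clips with
150 | #   create_id3_embedding(preprocess(videos, (224, 224)))
151 | # and evaluate calculate_fvd(real_emb, fake_emb) inside a tf.Session to obtain the scalar FVD.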
--------------------------------------------------------------------------------
/ldm/modules/evaluate/ssim.py:
--------------------------------------------------------------------------------
1 | # MIT Licence
2 |
3 | # Methods to predict the SSIM, taken from
4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
5 |
6 | from math import exp
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 |
12 | def gaussian(window_size, sigma):
13 | gauss = torch.Tensor(
14 | [
15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))
16 | for x in range(window_size)
17 | ]
18 | )
19 | return gauss / gauss.sum()
20 |
21 |
22 | def create_window(window_size, channel):
23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25 | window = Variable(
26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()
27 | )
28 | return window
29 |
30 |
31 | def _ssim(
32 | img1, img2, window, window_size, channel, mask=None, size_average=True
33 | ):
34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
36 |
37 | mu1_sq = mu1.pow(2)
38 | mu2_sq = mu2.pow(2)
39 | mu1_mu2 = mu1 * mu2
40 |
41 | sigma1_sq = (
42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
43 | - mu1_sq
44 | )
45 | sigma2_sq = (
46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
47 | - mu2_sq
48 | )
49 | sigma12 = (
50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
51 | - mu1_mu2
52 | )
53 |
54 | C1 = (0.01) ** 2
55 | C2 = (0.03) ** 2
56 |
57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
59 | )
60 |
61 | if not (mask is None):
62 | b = mask.size(0)
63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
65 | dim=1
66 | ).clamp(min=1)
67 | return ssim_map
68 |
69 | import pdb
70 |
71 | pdb.set_trace
72 |
73 | if size_average:
74 | return ssim_map.mean()
75 | else:
76 | return ssim_map.mean(1).mean(1).mean(1)
77 |
78 |
79 | class SSIM(torch.nn.Module):
80 | def __init__(self, window_size=11, size_average=True):
81 | super(SSIM, self).__init__()
82 | self.window_size = window_size
83 | self.size_average = size_average
84 | self.channel = 1
85 | self.window = create_window(window_size, self.channel)
86 |
87 | def forward(self, img1, img2, mask=None):
88 | (_, channel, _, _) = img1.size()
89 |
90 | if (
91 | channel == self.channel
92 | and self.window.data.type() == img1.data.type()
93 | ):
94 | window = self.window
95 | else:
96 | window = create_window(self.window_size, channel)
97 |
98 | if img1.is_cuda:
99 | window = window.cuda(img1.get_device())
100 | window = window.type_as(img1)
101 |
102 | self.window = window
103 | self.channel = channel
104 |
105 | return _ssim(
106 | img1,
107 | img2,
108 | window,
109 | self.window_size,
110 | channel,
111 | mask,
112 | self.size_average,
113 | )
114 |
115 |
116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True):
117 | (_, channel, _, _) = img1.size()
118 | window = create_window(window_size, channel)
119 |
120 | if img1.is_cuda:
121 | window = window.cuda(img1.get_device())
122 | window = window.type_as(img1)
123 |
124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average)
125 |
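126 | # Usage sketch (illustrative): score = ssim(img1, img2)  # img1, img2: [B, C, H, W], same value range
127 | # The score approaches 1.0 for identical images; passing a mask averages only over mask > 0.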
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
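113 | # Note (illustrative): calculate_adaptive_weight balances the GAN term against the reconstruction
114 | # term by matching gradient norms at the decoder's last layer:
115 | #   d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4), clamped to [0, 1e4],
116 | # which keeps the discriminator from overwhelming the NLL/KL objective early in training.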
--------------------------------------------------------------------------------
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
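169 | # Note (illustrative): measure_perplexity returns exp(entropy) of the average codebook usage;
170 | # with n_embed = 1024, perfectly uniform usage gives perplexity 1024, while a collapsed codebook
171 | # that always picks one code gives perplexity 1 (and cluster_use 1).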
--------------------------------------------------------------------------------
/ldm/thirdp/psp/helpers.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 |
3 | from collections import namedtuple
4 | import torch
5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
6 |
7 | """
8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
9 | """
10 |
11 |
12 | class Flatten(Module):
13 | def forward(self, input):
14 | return input.view(input.size(0), -1)
15 |
16 |
17 | def l2_norm(input, axis=1):
18 | norm = torch.norm(input, 2, axis, True)
19 | output = torch.div(input, norm)
20 | return output
21 |
22 |
23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
24 | """ A named tuple describing a ResNet block. """
25 |
26 |
27 | def get_block(in_channel, depth, num_units, stride=2):
28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
29 |
30 |
31 | def get_blocks(num_layers):
32 | if num_layers == 50:
33 | blocks = [
34 | get_block(in_channel=64, depth=64, num_units=3),
35 | get_block(in_channel=64, depth=128, num_units=4),
36 | get_block(in_channel=128, depth=256, num_units=14),
37 | get_block(in_channel=256, depth=512, num_units=3)
38 | ]
39 | elif num_layers == 100:
40 | blocks = [
41 | get_block(in_channel=64, depth=64, num_units=3),
42 | get_block(in_channel=64, depth=128, num_units=13),
43 | get_block(in_channel=128, depth=256, num_units=30),
44 | get_block(in_channel=256, depth=512, num_units=3)
45 | ]
46 | elif num_layers == 152:
47 | blocks = [
48 | get_block(in_channel=64, depth=64, num_units=3),
49 | get_block(in_channel=64, depth=128, num_units=8),
50 | get_block(in_channel=128, depth=256, num_units=36),
51 | get_block(in_channel=256, depth=512, num_units=3)
52 | ]
53 | else:
54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
55 | return blocks
56 |
57 |
58 | class SEModule(Module):
59 | def __init__(self, channels, reduction):
60 | super(SEModule, self).__init__()
61 | self.avg_pool = AdaptiveAvgPool2d(1)
62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
63 | self.relu = ReLU(inplace=True)
64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
65 | self.sigmoid = Sigmoid()
66 |
67 | def forward(self, x):
68 | module_input = x
69 | x = self.avg_pool(x)
70 | x = self.fc1(x)
71 | x = self.relu(x)
72 | x = self.fc2(x)
73 | x = self.sigmoid(x)
74 | return module_input * x
75 |
76 |
77 | class bottleneck_IR(Module):
78 | def __init__(self, in_channel, depth, stride):
79 | super(bottleneck_IR, self).__init__()
80 | if in_channel == depth:
81 | self.shortcut_layer = MaxPool2d(1, stride)
82 | else:
83 | self.shortcut_layer = Sequential(
84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
85 | BatchNorm2d(depth)
86 | )
87 | self.res_layer = Sequential(
88 | BatchNorm2d(in_channel),
89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
91 | )
92 |
93 | def forward(self, x):
94 |         shortcut = self.shortcut_layer(x)
95 |         res = self.res_layer(x)
96 |         return res + shortcut
97 |
98 |
99 | class bottleneck_IR_SE(Module):
100 | def __init__(self, in_channel, depth, stride):
101 | super(bottleneck_IR_SE, self).__init__()
102 | if in_channel == depth:
103 | self.shortcut_layer = MaxPool2d(1, stride)
104 | else:
105 | self.shortcut_layer = Sequential(
106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
107 | BatchNorm2d(depth)
108 | )
109 | self.res_layer = Sequential(
110 | BatchNorm2d(in_channel),
111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
112 | PReLU(depth),
113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
114 | BatchNorm2d(depth),
115 | SEModule(depth, 16)
116 | )
117 |
118 | def forward(self, x):
119 |         shortcut = self.shortcut_layer(x)
120 |         res = self.res_layer(x)
121 |         return res + shortcut
--------------------------------------------------------------------------------
/ldm/thirdp/psp/id_loss.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 | import torch
3 | from torch import nn
4 | from ldm.thirdp.psp.model_irse import Backbone
5 |
6 |
7 | class IDFeatures(nn.Module):
8 | def __init__(self, model_path):
9 | super(IDFeatures, self).__init__()
10 | print('Loading ResNet ArcFace')
11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14 | self.facenet.eval()
15 |
16 | def forward(self, x, crop=False):
17 | # Not sure of the image range here
18 | if crop:
19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
20 | x = x[:, :, 35:223, 32:220]
21 | x = self.face_pool(x)
22 | x_feats = self.facenet(x)
23 | return x_feats
24 |
--------------------------------------------------------------------------------
/ldm/thirdp/psp/model_irse.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 |
3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
5 |
6 | """
7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8 | """
9 |
10 |
11 | class Backbone(Module):
12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
13 | super(Backbone, self).__init__()
14 | assert input_size in [112, 224], "input_size should be 112 or 224"
15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
17 | blocks = get_blocks(num_layers)
18 | if mode == 'ir':
19 | unit_module = bottleneck_IR
20 | elif mode == 'ir_se':
21 | unit_module = bottleneck_IR_SE
22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
23 | BatchNorm2d(64),
24 | PReLU(64))
25 | if input_size == 112:
26 | self.output_layer = Sequential(BatchNorm2d(512),
27 | Dropout(drop_ratio),
28 | Flatten(),
29 | Linear(512 * 7 * 7, 512),
30 | BatchNorm1d(512, affine=affine))
31 | else:
32 | self.output_layer = Sequential(BatchNorm2d(512),
33 | Dropout(drop_ratio),
34 | Flatten(),
35 | Linear(512 * 14 * 14, 512),
36 | BatchNorm1d(512, affine=affine))
37 |
38 | modules = []
39 | for block in blocks:
40 | for bottleneck in block:
41 | modules.append(unit_module(bottleneck.in_channel,
42 | bottleneck.depth,
43 | bottleneck.stride))
44 | self.body = Sequential(*modules)
45 |
46 | def forward(self, x):
47 | x = self.input_layer(x)
48 | x = self.body(x)
49 | x = self.output_layer(x)
50 | return l2_norm(x)
51 |
52 |
53 | def IR_50(input_size):
54 | """Constructs a ir-50 model."""
55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
56 | return model
57 |
58 |
59 | def IR_101(input_size):
60 | """Constructs a ir-101 model."""
61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
62 | return model
63 |
64 |
65 | def IR_152(input_size):
66 | """Constructs a ir-152 model."""
67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
68 | return model
69 |
70 |
71 | def IR_SE_50(input_size):
72 | """Constructs a ir_se-50 model."""
73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
74 | return model
75 |
76 |
77 | def IR_SE_101(input_size):
78 | """Constructs a ir_se-101 model."""
79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
80 | return model
81 |
82 |
83 | def IR_SE_152(input_size):
84 | """Constructs a ir_se-152 model."""
85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
86 | return model
--------------------------------------------------------------------------------
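
A minimal usage sketch for the IR-SE backbone defined above (illustrative only, not part of the original file; the weights here are randomly initialized, whereas IDFeatures in id_loss.py loads a real ArcFace checkpoint):

import torch
from ldm.thirdp.psp.model_irse import IR_SE_50

# Build the IR-SE-50 backbone from this file (random weights for the demo).
model = IR_SE_50(input_size=112).eval()

with torch.no_grad():
    faces = torch.randn(2, 3, 112, 112)    # dummy batch of aligned 112x112 face crops
    feats = model(faces)                   # [2, 512], l2-normalized by l2_norm()
    print(feats.shape, feats.norm(dim=1))  # norms are ~1.0
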
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torchvision
4 | import torch
5 | from torch import optim
6 | import numpy as np
7 |
8 | from inspect import isfunction
9 | from PIL import Image, ImageDraw, ImageFont
10 |
11 | import os
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | from PIL import Image
15 | import torch
16 | import time
17 | import cv2
18 |
19 | import PIL
20 |
21 | def pil_rectangle_crop(im):
22 | width, height = im.size # Get dimensions
23 |
24 | if width <= height:
25 | left = 0
26 | right = width
27 | top = (height - width)/2
28 | bottom = (height + width)/2
29 | else:
30 |
31 | top = 0
32 | bottom = height
33 | left = (width - height) / 2
34 |         right = (width + height) / 2
35 |
36 | # Crop the center of the image
37 | im = im.crop((left, top, right, bottom))
38 | return im
39 |
40 |
41 | def log_txt_as_img(wh, xc, size=10):
42 | # wh a tuple of (width, height)
43 | # xc a list of captions to plot
44 | b = len(xc)
45 | txts = list()
46 | for bi in range(b):
47 | txt = Image.new("RGB", wh, color="white")
48 | draw = ImageDraw.Draw(txt)
49 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
50 | nc = int(40 * (wh[0] / 256))
51 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
52 |
53 | try:
54 | draw.text((0, 0), lines, fill="black", font=font)
55 | except UnicodeEncodeError:
56 |             print("Can't encode string for logging. Skipping.")
57 |
58 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
59 | txts.append(txt)
60 | txts = np.stack(txts)
61 | txts = torch.tensor(txts)
62 | return txts
63 |
64 |
65 | def ismap(x):
66 | if not isinstance(x, torch.Tensor):
67 | return False
68 | return (len(x.shape) == 4) and (x.shape[1] > 3)
69 |
70 |
71 | def isimage(x):
72 | if not isinstance(x,torch.Tensor):
73 | return False
74 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
75 |
76 |
77 | def exists(x):
78 | return x is not None
79 |
80 |
81 | def default(val, d):
82 | if exists(val):
83 | return val
84 | return d() if isfunction(d) else d
85 |
86 |
87 | def mean_flat(tensor):
88 | """
89 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
90 | Take the mean over all non-batch dimensions.
91 | """
92 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
93 |
94 |
95 | def count_params(model, verbose=False):
96 | total_params = sum(p.numel() for p in model.parameters())
97 | if verbose:
98 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
99 | return total_params
100 |
101 |
102 | def instantiate_from_config(config):
103 |     if "target" not in config:
104 | if config == '__is_first_stage__':
105 | return None
106 | elif config == "__is_unconditional__":
107 | return None
108 | raise KeyError("Expected key `target` to instantiate.")
109 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
110 |
111 |
112 | def get_obj_from_str(string, reload=False):
113 | module, cls = string.rsplit(".", 1)
114 | if reload:
115 | module_imp = importlib.import_module(module)
116 | importlib.reload(module_imp)
117 | return getattr(importlib.import_module(module, package=None), cls)
118 |
119 |
120 | class AdamWwithEMAandWings(optim.Optimizer):
121 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
122 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
123 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
124 | ema_power=1., param_names=()):
125 | """AdamW that saves EMA versions of the parameters."""
126 | if not 0.0 <= lr:
127 | raise ValueError("Invalid learning rate: {}".format(lr))
128 | if not 0.0 <= eps:
129 | raise ValueError("Invalid epsilon value: {}".format(eps))
130 | if not 0.0 <= betas[0] < 1.0:
131 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
132 | if not 0.0 <= betas[1] < 1.0:
133 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
134 | if not 0.0 <= weight_decay:
135 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
136 | if not 0.0 <= ema_decay <= 1.0:
137 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
138 | defaults = dict(lr=lr, betas=betas, eps=eps,
139 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
140 | ema_power=ema_power, param_names=param_names)
141 | super().__init__(params, defaults)
142 |
143 | def __setstate__(self, state):
144 | super().__setstate__(state)
145 | for group in self.param_groups:
146 | group.setdefault('amsgrad', False)
147 |
148 | @torch.no_grad()
149 | def step(self, closure=None):
150 | """Performs a single optimization step.
151 | Args:
152 | closure (callable, optional): A closure that reevaluates the model
153 | and returns the loss.
154 | """
155 | loss = None
156 | if closure is not None:
157 | with torch.enable_grad():
158 | loss = closure()
159 |
160 | for group in self.param_groups:
161 | params_with_grad = []
162 | grads = []
163 | exp_avgs = []
164 | exp_avg_sqs = []
165 | ema_params_with_grad = []
166 | state_sums = []
167 | max_exp_avg_sqs = []
168 | state_steps = []
169 | amsgrad = group['amsgrad']
170 | beta1, beta2 = group['betas']
171 | ema_decay = group['ema_decay']
172 | ema_power = group['ema_power']
173 |
174 | for p in group['params']:
175 | if p.grad is None:
176 | continue
177 | params_with_grad.append(p)
178 | if p.grad.is_sparse:
179 | raise RuntimeError('AdamW does not support sparse gradients')
180 | grads.append(p.grad)
181 |
182 | state = self.state[p]
183 |
184 | # State initialization
185 | if len(state) == 0:
186 | state['step'] = 0
187 | # Exponential moving average of gradient values
188 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
189 | # Exponential moving average of squared gradient values
190 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
191 | if amsgrad:
192 | # Maintains max of all exp. moving avg. of sq. grad. values
193 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
194 | # Exponential moving average of parameter values
195 | state['param_exp_avg'] = p.detach().float().clone()
196 |
197 | exp_avgs.append(state['exp_avg'])
198 | exp_avg_sqs.append(state['exp_avg_sq'])
199 | ema_params_with_grad.append(state['param_exp_avg'])
200 |
201 | if amsgrad:
202 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
203 |
204 | # update the steps for each param group update
205 | state['step'] += 1
206 | # record the step after step update
207 | state_steps.append(state['step'])
208 |
209 | optim._functional.adamw(params_with_grad,
210 | grads,
211 | exp_avgs,
212 | exp_avg_sqs,
213 | max_exp_avg_sqs,
214 | state_steps,
215 | amsgrad=amsgrad,
216 | beta1=beta1,
217 | beta2=beta2,
218 | lr=group['lr'],
219 | weight_decay=group['weight_decay'],
220 | eps=group['eps'],
221 | maximize=False)
222 |
223 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
224 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
225 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
226 |
227 | return loss
--------------------------------------------------------------------------------
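
A short illustration of how instantiate_from_config and get_obj_from_str resolve a dotted target path (the config values below are hypothetical, not from the repo):

from ldm.util import instantiate_from_config, get_obj_from_str

# "target" is a dotted import path; the optional "params" dict is passed as kwargs.
cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
layer = instantiate_from_config(cfg)          # -> torch.nn.Linear(8, 2)

# get_obj_from_str only resolves the object without instantiating it.
Linear = get_obj_from_str("torch.nn.Linear")
assert isinstance(layer, Linear)
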
/meshutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pymeshlab as pml
3 |
4 | def poisson_mesh_reconstruction(points, normals=None):
5 | # points/normals: [N, 3] np.ndarray
6 |
7 | import open3d as o3d
8 |
9 | pcd = o3d.geometry.PointCloud()
10 | pcd.points = o3d.utility.Vector3dVector(points)
11 |
12 | # outlier removal
13 | pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)
14 |
15 | # normals
16 | if normals is None:
17 | pcd.estimate_normals()
18 | else:
19 | pcd.normals = o3d.utility.Vector3dVector(normals[ind])
20 |
21 | # visualize
22 | o3d.visualization.draw_geometries([pcd], point_show_normal=False)
23 |
24 | mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)
25 | vertices_to_remove = densities < np.quantile(densities, 0.1)
26 | mesh.remove_vertices_by_mask(vertices_to_remove)
27 |
28 | # visualize
29 | o3d.visualization.draw_geometries([mesh])
30 |
31 | vertices = np.asarray(mesh.vertices)
32 | triangles = np.asarray(mesh.triangles)
33 |
34 | print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}')
35 |
36 | return vertices, triangles
37 |
38 |
39 | def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
40 |     # optimalplacement: defaults to True; set it to False for flat meshes to avoid spike artifacts.
41 |
42 | _ori_vert_shape = verts.shape
43 | _ori_face_shape = faces.shape
44 |
45 | if backend == 'pyfqmr':
46 | import pyfqmr
47 | solver = pyfqmr.Simplify()
48 | solver.setMesh(verts, faces)
49 | solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
50 | verts, faces, normals = solver.getMesh()
51 | else:
52 |
53 | m = pml.Mesh(verts, faces)
54 | ms = pml.MeshSet()
55 | ms.add_mesh(m, 'mesh') # will copy!
56 |
57 | # filters
58 | # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
59 | ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)
60 |
61 | if remesh:
62 | # ms.apply_coord_taubin_smoothing()
63 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))
64 |
65 | # extract mesh
66 | m = ms.current_mesh()
67 | verts = m.vertex_matrix()
68 | faces = m.face_matrix()
69 |
70 | print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
71 |
72 | return verts, faces
73 |
74 |
75 | def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01):
76 | # verts: [N, 3]
77 | # faces: [N, 3]
78 |
79 | _ori_vert_shape = verts.shape
80 | _ori_face_shape = faces.shape
81 |
82 | m = pml.Mesh(verts, faces)
83 | ms = pml.MeshSet()
84 | ms.add_mesh(m, 'mesh') # will copy!
85 |
86 | # filters
87 |     ms.meshing_remove_unreferenced_vertices() # verts not referenced by any faces
88 |
89 | if v_pct > 0:
90 | ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct)) # 1/10000 of bounding box diagonal
91 |
92 | ms.meshing_remove_duplicate_faces() # faces defined by the same verts
93 | ms.meshing_remove_null_faces() # faces with area == 0
94 |
95 | if min_d > 0:
96 | ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))
97 |
98 | if min_f > 0:
99 | ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
100 |
101 | if repair:
102 | # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
103 | ms.meshing_repair_non_manifold_edges(method=0)
104 | ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
105 |
106 | if remesh:
107 | # ms.apply_coord_taubin_smoothing()
108 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))
109 |
110 | # extract mesh
111 | m = ms.current_mesh()
112 | verts = m.vertex_matrix()
113 | faces = m.face_matrix()
114 |
115 | print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
116 |
117 | return verts, faces
--------------------------------------------------------------------------------
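
An illustrative round trip through the two pymeshlab helpers above, using a hypothetical toy tetrahedron (in practice these functions are called on much larger extracted meshes, so the size-based filters are disabled here):

import numpy as np
from meshutils import clean_mesh, decimate_mesh

# Tiny toy mesh just to exercise the API.
verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.int32)

# The default thresholds assume much larger meshes; disable them for this toy input.
verts, faces = clean_mesh(verts, faces, v_pct=0, min_f=0, min_d=0, remesh=False)
verts, faces = decimate_mesh(verts, faces, target=4)
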
/nerf/network_grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from activation import trunc_exp, biased_softplus
6 | from .renderer import NeRFRenderer
7 |
8 | import numpy as np
9 | from encoding import get_encoder
10 |
11 | from .utils import safe_normalize
12 |
13 | class MLP(nn.Module):
14 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
15 | super().__init__()
16 | self.dim_in = dim_in
17 | self.dim_out = dim_out
18 | self.dim_hidden = dim_hidden
19 | self.num_layers = num_layers
20 |
21 | net = []
22 | for l in range(num_layers):
23 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
24 |
25 | self.net = nn.ModuleList(net)
26 |
27 | def forward(self, x):
28 | for l in range(self.num_layers):
29 | x = self.net[l](x)
30 | if l != self.num_layers - 1:
31 | x = F.relu(x, inplace=True)
32 | return x
33 |
34 |
35 | class NeRFNetwork(NeRFRenderer):
36 | def __init__(self,
37 | opt,
38 | num_layers=3,
39 | hidden_dim=64,
40 | num_layers_bg=2,
41 | hidden_dim_bg=32,
42 | ):
43 |
44 | super().__init__(opt)
45 |
46 | self.num_layers = num_layers
47 | self.hidden_dim = hidden_dim
48 |
49 | self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
50 |
51 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
52 | # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)
53 |
54 | self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus
55 |
56 | # background network
57 | if self.opt.bg_radius > 0:
58 | self.num_layers_bg = num_layers_bg
59 | self.hidden_dim_bg = hidden_dim_bg
60 |
61 | # use a very simple network to avoid it learning the prompt...
62 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
63 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
64 |
65 | else:
66 | self.bg_net = None
67 |
68 | def common_forward(self, x):
69 |
70 | # sigma
71 | enc = self.encoder(x, bound=self.bound, max_level=self.max_level)
72 |
73 | h = self.sigma_net(enc)
74 |
75 | sigma = self.density_activation(h[..., 0] + self.density_blob(x))
76 | albedo = torch.sigmoid(h[..., 1:])
77 |
78 | return sigma, albedo
79 |
80 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
81 | def finite_difference_normal(self, x, epsilon=1e-2):
82 | # x: [N, 3]
83 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
84 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
85 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
86 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
87 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
88 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
89 |
90 | normal = torch.stack([
91 | 0.5 * (dx_pos - dx_neg) / epsilon,
92 | 0.5 * (dy_pos - dy_neg) / epsilon,
93 | 0.5 * (dz_pos - dz_neg) / epsilon
94 | ], dim=-1)
95 |
96 | return -normal
97 |
98 | def normal(self, x):
99 | normal = self.finite_difference_normal(x)
100 | normal = safe_normalize(normal)
101 | normal = torch.nan_to_num(normal)
102 | return normal
103 |
104 | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
105 | # x: [N, 3], in [-bound, bound]
106 |         # d: [N, 3], view direction, normalized in [-1, 1]
107 |         # l: [3], plane light direction, normalized in [-1, 1]
108 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
109 |
110 | sigma, albedo = self.common_forward(x)
111 |
112 | if shading == 'albedo':
113 | normal = None
114 | color = albedo
115 |
116 | else: # lambertian shading
117 |
118 | # normal = self.normal_net(enc)
119 | normal = self.normal(x)
120 |
121 | lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
122 |
123 | if shading == 'textureless':
124 | color = lambertian.unsqueeze(-1).repeat(1, 3)
125 | elif shading == 'normal':
126 | color = (normal + 1) / 2
127 | else: # 'lambertian'
128 | color = albedo * lambertian.unsqueeze(-1)
129 |
130 | return sigma, color, normal
131 |
132 |
133 | def density(self, x):
134 | # x: [N, 3], in [-bound, bound]
135 |
136 | sigma, albedo = self.common_forward(x)
137 |
138 | return {
139 | 'sigma': sigma,
140 | 'albedo': albedo,
141 | }
142 |
143 |
144 | def background(self, d):
145 |
146 | h = self.encoder_bg(d) # [N, C]
147 |
148 | h = self.bg_net(h)
149 |
150 | # sigmoid activation for rgb
151 | rgbs = torch.sigmoid(h)
152 |
153 | return rgbs
154 |
155 | # optimizer utils
156 | def get_params(self, lr):
157 |
158 | params = [
159 | {'params': self.encoder.parameters(), 'lr': lr * 10},
160 | {'params': self.sigma_net.parameters(), 'lr': lr},
161 | # {'params': self.normal_net.parameters(), 'lr': lr},
162 | ]
163 |
164 | if self.opt.bg_radius > 0:
165 | # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
166 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
167 |
168 | if self.opt.dmtet and not self.opt.lock_geo:
169 | params.append({'params': self.sdf, 'lr': lr})
170 | params.append({'params': self.deform, 'lr': lr})
171 |
172 | return params
--------------------------------------------------------------------------------
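
For intuition, finite_difference_normal above estimates -grad(sigma) with central differences. A small self-contained sketch of the same scheme on a hypothetical toy density (not the repo's network):

import torch

def toy_density(x):                                   # x: [N, 3]
    # a soft spherical shell of radius 0.5 around the origin
    return torch.exp(-(x.norm(dim=-1) - 0.5) ** 2 / 0.01)

def fd_normal(density_fn, x, eps=1e-2):
    # central differences along each axis, mirroring NeRFNetwork.finite_difference_normal
    offsets = eps * torch.eye(3)
    grad = torch.stack([(density_fn(x + o) - density_fn(x - o)) / (2 * eps) for o in offsets], dim=-1)
    n = -grad                                         # normals point from high to low density
    return n / n.norm(dim=-1, keepdim=True).clamp(min=1e-20)

x = torch.tensor([[0.6, 0.0, 0.0], [0.0, 0.36, 0.48]])
print(fd_normal(toy_density, x))                      # approximately radial (outward) directions
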
/nerf/network_grid_taichi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from activation import trunc_exp
6 | from .renderer import NeRFRenderer
7 |
8 | import numpy as np
9 | from encoding import get_encoder
10 |
11 | from .utils import safe_normalize
12 |
13 | class MLP(nn.Module):
14 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
15 | super().__init__()
16 | self.dim_in = dim_in
17 | self.dim_out = dim_out
18 | self.dim_hidden = dim_hidden
19 | self.num_layers = num_layers
20 |
21 | net = []
22 | for l in range(num_layers):
23 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
24 |
25 | self.net = nn.ModuleList(net)
26 |
27 | def forward(self, x):
28 | for l in range(self.num_layers):
29 | x = self.net[l](x)
30 | if l != self.num_layers - 1:
31 | x = F.relu(x, inplace=True)
32 | return x
33 |
34 |
35 | class NeRFNetwork(NeRFRenderer):
36 | def __init__(self,
37 | opt,
38 | num_layers=2,
39 | hidden_dim=32,
40 | num_layers_bg=2,
41 | hidden_dim_bg=16,
42 | ):
43 |
44 | super().__init__(opt)
45 |
46 | self.num_layers = num_layers
47 | self.hidden_dim = hidden_dim
48 |
49 | self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
50 |
51 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
52 | # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)
53 |
54 | self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else F.softplus
55 |
56 | # background network
57 | if self.opt.bg_radius > 0:
58 | self.num_layers_bg = num_layers_bg
59 | self.hidden_dim_bg = hidden_dim_bg
60 | # use a very simple network to avoid it learning the prompt...
61 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4) # TODO: freq encoder can be replaced by a Taichi implementation
62 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
63 |
64 | else:
65 | self.bg_net = None
66 |
67 | def common_forward(self, x):
68 |
69 | # sigma
70 | enc = self.encoder(x, bound=self.bound)
71 |
72 | h = self.sigma_net(enc)
73 |
74 | sigma = self.density_activation(h[..., 0] + self.density_blob(x))
75 | albedo = torch.sigmoid(h[..., 1:])
76 |
77 | return sigma, albedo
78 |
79 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
80 | def finite_difference_normal(self, x, epsilon=1e-2):
81 | # x: [N, 3]
82 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
83 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
84 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
85 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
86 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
87 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
88 |
89 | normal = torch.stack([
90 | 0.5 * (dx_pos - dx_neg) / epsilon,
91 | 0.5 * (dy_pos - dy_neg) / epsilon,
92 | 0.5 * (dz_pos - dz_neg) / epsilon
93 | ], dim=-1)
94 |
95 | return -normal
96 |
97 | def normal(self, x):
98 | normal = self.finite_difference_normal(x)
99 | normal = safe_normalize(normal)
100 | normal = torch.nan_to_num(normal)
101 | return normal
102 |
103 | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
104 | # x: [N, 3], in [-bound, bound]
105 |         # d: [N, 3], view direction, normalized in [-1, 1]
106 |         # l: [3], plane light direction, normalized in [-1, 1]
107 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
108 |
109 | sigma, albedo = self.common_forward(x)
110 |
111 | if shading == 'albedo':
112 | normal = None
113 | color = albedo
114 |
115 | else: # lambertian shading
116 | # normal = self.normal_net(enc)
117 | normal = self.normal(x)
118 |
119 | lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
120 |
121 | if shading == 'textureless':
122 | color = lambertian.unsqueeze(-1).repeat(1, 3)
123 | elif shading == 'normal':
124 | color = (normal + 1) / 2
125 | else: # 'lambertian'
126 | color = albedo * lambertian.unsqueeze(-1)
127 |
128 | return sigma, color, normal
129 |
130 |
131 | def density(self, x):
132 | # x: [N, 3], in [-bound, bound]
133 |
134 | sigma, albedo = self.common_forward(x)
135 |
136 | return {
137 | 'sigma': sigma,
138 | 'albedo': albedo,
139 | }
140 |
141 |
142 | def background(self, d):
143 |
144 | h = self.encoder_bg(d) # [N, C]
145 |
146 | h = self.bg_net(h)
147 |
148 | # sigmoid activation for rgb
149 | rgbs = torch.sigmoid(h)
150 |
151 | return rgbs
152 |
153 | # optimizer utils
154 | def get_params(self, lr):
155 |
156 | params = [
157 | {'params': self.encoder.parameters(), 'lr': lr * 10},
158 | {'params': self.sigma_net.parameters(), 'lr': lr},
159 | # {'params': self.normal_net.parameters(), 'lr': lr},
160 | ]
161 |
162 | if self.opt.bg_radius > 0:
163 | # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
164 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
165 |
166 | if self.opt.dmtet and not self.opt.lock_geo:
167 | params.append({'params': self.sdf, 'lr': lr})
168 | params.append({'params': self.deform, 'lr': lr})
169 |
170 | return params
--------------------------------------------------------------------------------
/nerf/network_grid_tcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from activation import trunc_exp, biased_softplus
6 | from .renderer import NeRFRenderer
7 |
8 | import numpy as np
9 | from encoding import get_encoder
10 |
11 | from .utils import safe_normalize
12 |
13 | import tinycudann as tcnn
14 |
15 | class MLP(nn.Module):
16 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
17 | super().__init__()
18 | self.dim_in = dim_in
19 | self.dim_out = dim_out
20 | self.dim_hidden = dim_hidden
21 | self.num_layers = num_layers
22 |
23 | net = []
24 | for l in range(num_layers):
25 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
26 |
27 | self.net = nn.ModuleList(net)
28 |
29 | def forward(self, x):
30 | for l in range(self.num_layers):
31 | x = self.net[l](x)
32 | if l != self.num_layers - 1:
33 | x = F.relu(x, inplace=True)
34 | return x
35 |
36 |
37 | class NeRFNetwork(NeRFRenderer):
38 | def __init__(self,
39 | opt,
40 | num_layers=3,
41 | hidden_dim=64,
42 | num_layers_bg=2,
43 | hidden_dim_bg=32,
44 | ):
45 |
46 | super().__init__(opt)
47 |
48 | self.num_layers = num_layers
49 | self.hidden_dim = hidden_dim
50 |
51 | self.encoder = tcnn.Encoding(
52 | n_input_dims=3,
53 | encoding_config={
54 | "otype": "HashGrid",
55 | "n_levels": 16,
56 | "n_features_per_level": 2,
57 | "log2_hashmap_size": 19,
58 | "base_resolution": 16,
59 | "interpolation": "Smoothstep",
60 | "per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)),
61 | },
62 | dtype=torch.float32, # ENHANCE: default float16 seems unstable...
63 | )
64 | self.in_dim = self.encoder.n_output_dims
65 |         # use torch MLP, as tcnn MLP doesn't implement second-order derivatives
66 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
67 |
68 | self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus
69 |
70 | # background network
71 | if self.opt.bg_radius > 0:
72 | self.num_layers_bg = num_layers_bg
73 | self.hidden_dim_bg = hidden_dim_bg
74 |
75 | # use a very simple network to avoid it learning the prompt...
76 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
77 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
78 |
79 | else:
80 | self.bg_net = None
81 |
82 | def common_forward(self, x):
83 |
84 | # sigma
85 | enc = self.encoder((x + self.bound) / (2 * self.bound)).float()
86 | h = self.sigma_net(enc)
87 |
88 | sigma = self.density_activation(h[..., 0] + self.density_blob(x))
89 | albedo = torch.sigmoid(h[..., 1:])
90 |
91 | return sigma, albedo
92 |
93 | def normal(self, x):
94 |
95 | with torch.enable_grad():
96 | with torch.cuda.amp.autocast(enabled=False):
97 | x.requires_grad_(True)
98 | sigma, albedo = self.common_forward(x)
99 | # query gradient
100 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
101 |
102 | # normal = self.finite_difference_normal(x)
103 | normal = safe_normalize(normal)
104 | normal = torch.nan_to_num(normal)
105 |
106 | return normal
107 |
108 | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
109 | # x: [N, 3], in [-bound, bound]
110 |         # d: [N, 3], view direction, normalized in [-1, 1]
111 |         # l: [3], plane light direction, normalized in [-1, 1]
112 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
113 |
114 |
115 | if shading == 'albedo':
116 | sigma, albedo = self.common_forward(x)
117 | normal = None
118 | color = albedo
119 |
120 | else: # lambertian shading
121 | with torch.enable_grad():
122 | with torch.cuda.amp.autocast(enabled=False):
123 | x.requires_grad_(True)
124 | sigma, albedo = self.common_forward(x)
125 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
126 | normal = safe_normalize(normal)
127 | normal = torch.nan_to_num(normal)
128 |
129 | lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
130 |
131 | if shading == 'textureless':
132 | color = lambertian.unsqueeze(-1).repeat(1, 3)
133 | elif shading == 'normal':
134 | color = (normal + 1) / 2
135 | else: # 'lambertian'
136 | color = albedo * lambertian.unsqueeze(-1)
137 |
138 | return sigma, color, normal
139 |
140 |
141 | def density(self, x):
142 | # x: [N, 3], in [-bound, bound]
143 |
144 | sigma, albedo = self.common_forward(x)
145 |
146 | return {
147 | 'sigma': sigma,
148 | 'albedo': albedo,
149 | }
150 |
151 |
152 | def background(self, d):
153 |
154 | h = self.encoder_bg(d) # [N, C]
155 |
156 | h = self.bg_net(h)
157 |
158 | # sigmoid activation for rgb
159 | rgbs = torch.sigmoid(h)
160 |
161 | return rgbs
162 |
163 | # optimizer utils
164 | def get_params(self, lr):
165 |
166 | params = [
167 | {'params': self.encoder.parameters(), 'lr': lr * 10},
168 | {'params': self.sigma_net.parameters(), 'lr': lr},
169 | ]
170 |
171 | if self.opt.bg_radius > 0:
172 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
173 |
174 | if self.opt.dmtet and not self.opt.lock_geo:
175 | params.append({'params': self.sdf, 'lr': lr})
176 | params.append({'params': self.deform, 'lr': lr})
177 |
178 | return params
--------------------------------------------------------------------------------
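
Unlike network_grid.py, the normal here is the autograd gradient of the density rather than a finite difference. A tiny standalone sketch of that pattern (hypothetical toy density, not the repo's network):

import torch

def toy_density(x):                        # x: [N, 3]
    return torch.exp(-x.norm(dim=-1) ** 2)

x = torch.tensor([[0.3, 0.0, 0.0], [0.0, 0.1, 0.2]], requires_grad=True)
sigma = toy_density(x)
# Same pattern as NeRFNetwork.normal: gradient of the summed density w.r.t. positions.
normal = -torch.autograd.grad(sigma.sum(), x, create_graph=False)[0]   # [N, 3]
normal = normal / normal.norm(dim=-1, keepdim=True).clamp(min=1e-20)
print(normal)                              # points away from the density peak at the origin
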
/preprocess_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from PIL import Image
13 |
14 | class BackgroundRemoval():
15 | def __init__(self, device='cuda'):
16 |
17 | from carvekit.api.high import HiInterface
18 | self.interface = HiInterface(
19 | object_type="object", # Can be "object" or "hairs-like".
20 | batch_size_seg=5,
21 | batch_size_matting=1,
22 | device=device,
23 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
24 | matting_mask_size=2048,
25 | trimap_prob_threshold=231,
26 | trimap_dilation=30,
27 | trimap_erosion_iters=5,
28 | fp16=True,
29 | )
30 |
31 | @torch.no_grad()
32 | def __call__(self, image):
33 | # image: [H, W, 3] array in [0, 255].
34 | image = Image.fromarray(image)
35 |
36 | image = self.interface([image])[0]
37 | image = np.array(image)
38 |
39 | return image
40 |
41 | class BLIP2():
42 | def __init__(self, device='cuda'):
43 | self.device = device
44 | from transformers import AutoProcessor, Blip2ForConditionalGeneration
45 | self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
46 | self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
47 |
48 | @torch.no_grad()
49 | def __call__(self, image):
50 | image = Image.fromarray(image)
51 | inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16)
52 |
53 | generated_ids = self.model.generate(**inputs, max_new_tokens=20)
54 | generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
55 |
56 | return generated_text
57 |
58 |
59 | class DPT():
60 | def __init__(self, task='depth', device='cuda'):
61 |
62 | self.task = task
63 | self.device = device
64 |
65 | from dpt import DPTDepthModel
66 |
67 | if task == 'depth':
68 | path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'
69 | self.model = DPTDepthModel(backbone='vitb_rn50_384')
70 | self.aug = transforms.Compose([
71 | transforms.Resize((384, 384)),
72 | transforms.ToTensor(),
73 | transforms.Normalize(mean=0.5, std=0.5)
74 | ])
75 |
76 | else: # normal
77 | path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'
78 | self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
79 | self.aug = transforms.Compose([
80 | transforms.Resize((384, 384)),
81 | transforms.ToTensor()
82 | ])
83 |
84 | # load model
85 | checkpoint = torch.load(path, map_location='cpu')
86 | if 'state_dict' in checkpoint:
87 | state_dict = {}
88 | for k, v in checkpoint['state_dict'].items():
89 | state_dict[k[6:]] = v
90 | else:
91 | state_dict = checkpoint
92 | self.model.load_state_dict(state_dict)
93 | self.model.eval().to(device)
94 |
95 |
96 | @torch.no_grad()
97 | def __call__(self, image):
98 | # image: np.ndarray, uint8, [H, W, 3]
99 | H, W = image.shape[:2]
100 | image = Image.fromarray(image)
101 |
102 | image = self.aug(image).unsqueeze(0).to(self.device)
103 |
104 | if self.task == 'depth':
105 | depth = self.model(image).clamp(0, 1)
106 | depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
107 | depth = depth.squeeze(1).cpu().numpy()
108 | return depth
109 | else:
110 | normal = self.model(image).clamp(0, 1)
111 | normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
112 | normal = normal.cpu().numpy()
113 | return normal
114 |
115 |
116 |
117 | if __name__ == '__main__':
118 |
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
121 | parser.add_argument('--size', default=256, type=int, help="output resolution")
122 | parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio")
123 | parser.add_argument('--recenter', type=bool, default=True, help="recenter, potentially not helpful for multiview zero123")
124 | parser.add_argument('--dont_recenter', dest='recenter', action='store_false')
125 | opt = parser.parse_args()
126 |
127 | out_dir = os.path.dirname(opt.path)
128 | out_rgba = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_rgba.png')
129 | out_depth = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_depth.png')
130 | out_normal = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_normal.png')
131 | out_caption = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_caption.txt')
132 |
133 | # load image
134 | print(f'[INFO] loading image...')
135 | image = cv2.imread(opt.path, cv2.IMREAD_UNCHANGED)
136 | if image.shape[-1] == 4:
137 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
138 | else:
139 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
140 |
141 | # carve background
142 | print(f'[INFO] background removal...')
143 | carved_image = BackgroundRemoval()(image) # [H, W, 4]
144 | mask = carved_image[..., -1] > 0
145 |
146 | # predict depth
147 | print(f'[INFO] depth estimation...')
148 | dpt_depth_model = DPT(task='depth')
149 | depth = dpt_depth_model(image)[0]
150 | depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)
151 | depth[~mask] = 0
152 | depth = (depth * 255).astype(np.uint8)
153 | del dpt_depth_model
154 |
155 | # predict normal
156 | print(f'[INFO] normal estimation...')
157 | dpt_normal_model = DPT(task='normal')
158 | normal = dpt_normal_model(image)[0]
159 | normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0)
160 | normal[~mask] = 0
161 | del dpt_normal_model
162 |
163 | # recenter
164 | if opt.recenter:
165 | print(f'[INFO] recenter...')
166 | final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8)
167 | final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8)
168 | final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8)
169 |
170 | coords = np.nonzero(mask)
171 | x_min, x_max = coords[0].min(), coords[0].max()
172 | y_min, y_max = coords[1].min(), coords[1].max()
173 | h = x_max - x_min
174 | w = y_max - y_min
175 | desired_size = int(opt.size * (1 - opt.border_ratio))
176 | scale = desired_size / max(h, w)
177 | h2 = int(h * scale)
178 | w2 = int(w * scale)
179 | x2_min = (opt.size - h2) // 2
180 | x2_max = x2_min + h2
181 | y2_min = (opt.size - w2) // 2
182 | y2_max = y2_min + w2
183 | final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
184 | final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
185 | final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
186 |
187 | else:
188 | final_rgba = carved_image
189 | final_depth = depth
190 | final_normal = normal
191 |
192 | # write output
193 | cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA))
194 | cv2.imwrite(out_depth, final_depth)
195 | cv2.imwrite(out_normal, final_normal)
196 |
197 | # predict caption (it's too slow... use your brain instead)
198 | # print(f'[INFO] captioning...')
199 | # blip2 = BLIP2()
200 | # caption = blip2(image)
201 | # with open(out_caption, 'w') as f:
202 | # f.write(caption)
203 |
204 |
--------------------------------------------------------------------------------
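
Typical invocation is "python preprocess_image.py path/to/image.png --size 256" (placeholder path), which writes *_rgba.png, *_depth.png and *_normal.png next to the input. The classes can also be used directly; a minimal sketch, assuming the omnidata checkpoint is in place and a CUDA device is available:

import cv2
from preprocess_image import BackgroundRemoval, DPT

# 'path/to/image.png' is a placeholder; any RGB image works.
image = cv2.cvtColor(cv2.imread('path/to/image.png'), cv2.COLOR_BGR2RGB)

rgba = BackgroundRemoval()(image)        # [H, W, 4], carvekit matting (downloads its weights on first use)
depth = DPT(task='depth')(image)[0]      # [H, W] in [0, 1]; needs pretrained/omnidata/omnidata_dpt_depth_v2.ckpt
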
/pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "image_target"
11 | cond_stage_key: "image_cond"
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 |
19 | scheduler_config: # 10000 warmup steps
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps: [ 100 ]
23 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
24 | f_start: [ 1.e-6 ]
25 | f_max: [ 1. ]
26 | f_min: [ 1. ]
27 |
28 | unet_config:
29 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
30 | params:
31 | image_size: 32 # unused
32 | in_channels: 8
33 | out_channels: 4
34 | model_channels: 320
35 | attention_resolutions: [ 4, 2, 1 ]
36 | num_res_blocks: 2
37 | channel_mult: [ 1, 2, 4, 4 ]
38 | num_heads: 8
39 | use_spatial_transformer: True
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: True
43 | legacy: False
44 |
45 | first_stage_config:
46 | target: ldm.models.autoencoder.AutoencoderKL
47 | params:
48 | embed_dim: 4
49 | monitor: val/rec_loss
50 | ddconfig:
51 | double_z: true
52 | z_channels: 4
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult:
58 | - 1
59 | - 2
60 | - 4
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 |
68 | cond_stage_config:
69 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
70 |
71 |
72 | # data:
73 | # target: ldm.data.simple.ObjaverseDataModuleFromConfig
74 | # params:
75 | # root_dir: 'views_whole_sphere'
76 | # batch_size: 192
77 | # num_workers: 16
78 | # total_view: 4
79 | # train:
80 | # validation: False
81 | # image_transforms:
82 | # size: 256
83 |
84 | # validation:
85 | # validation: True
86 | # image_transforms:
87 | # size: 256
88 |
89 |
90 | # lightning:
91 | # find_unused_parameters: false
92 | # metrics_over_trainsteps_checkpoint: True
93 | # modelcheckpoint:
94 | # params:
95 | # every_n_train_steps: 5000
96 | # callbacks:
97 | # image_logger:
98 | # target: main.ImageLogger
99 | # params:
100 | # batch_frequency: 500
101 | # max_images: 32
102 | # increase_log_steps: False
103 | # log_first_step: True
104 | # log_images_kwargs:
105 | # use_ema_scope: False
106 | # inpaint: False
107 | # plot_progressive_rows: False
108 | # plot_diffusion_rows: False
109 | # N: 32
110 | # unconditional_scale: 3.0
111 | # unconditional_label: [""]
112 |
113 | # trainer:
114 | # benchmark: True
115 | # val_check_interval: 5000000 # really sorry
116 | # num_sanity_val_steps: 0
117 | # accumulate_grad_batches: 1
118 |
--------------------------------------------------------------------------------
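
A minimal sketch of how a config like this is typically instantiated with the utilities in this repo (assumes omegaconf from requirements.txt; loading the finetuned zero123 state_dict is a separate step):

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load('pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml')
# Builds LatentDiffusion from the "target"/"params" above (this constructs the full
# UNet, autoencoder and CLIP image embedder, so submodule weights may be downloaded).
model = instantiate_from_config(config.model)
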
/raymarching/__init__.py:
--------------------------------------------------------------------------------
1 | from .raymarching import *
--------------------------------------------------------------------------------
/raymarching/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | _backend = load(name='_raymarching',
33 | extra_cflags=c_flags,
34 | extra_cuda_cflags=nvcc_flags,
35 | sources=[os.path.join(_src_path, 'src', f) for f in [
36 | 'raymarching.cu',
37 | 'bindings.cpp',
38 | ]],
39 | )
40 |
41 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/raymarching/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
23 | if paths:
24 | return paths[0]
25 |
26 | # If cl.exe is not on path, try to find it.
27 | if os.system("where cl.exe >nul 2>nul") != 0:
28 | cl_path = find_cl_path()
29 | if cl_path is None:
30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
31 | os.environ["PATH"] += ";" + cl_path
32 |
33 | '''
34 | Usage:
35 |
36 | python setup.py build_ext --inplace # build extensions locally, do not install (can only be used from the parent directory)
37 |
38 | python setup.py install # build extensions and install (copy) to PATH.
39 | pip install . # ditto but better (e.g., dependency & metadata handling)
40 |
41 | python setup.py develop # build extensions and install (symbolic) to PATH.
42 | pip install -e . # ditto but better (e.g., dependency & metadata handling)
43 |
44 | '''
45 | setup(
46 | name='raymarching', # package name, import this to use python API
47 | ext_modules=[
48 | CUDAExtension(
49 | name='_raymarching', # extension name, import this to use CUDA API
50 | sources=[os.path.join(_src_path, 'src', f) for f in [
51 | 'raymarching.cu',
52 | 'bindings.cpp',
53 | ]],
54 | extra_compile_args={
55 | 'cxx': c_flags,
56 | 'nvcc': nvcc_flags,
57 | }
58 | ),
59 | ],
60 | cmdclass={
61 | 'build_ext': BuildExtension,
62 | }
63 | )
--------------------------------------------------------------------------------
/raymarching/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "raymarching.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | // utils
7 | m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)");
8 | m.def("packbits", &packbits, "packbits (CUDA)");
9 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
10 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
11 | m.def("morton3D", &morton3D, "morton3D (CUDA)");
12 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
13 | // train
14 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
15 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
16 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
17 | // infer
18 | m.def("march_rays", &march_rays, "march rays (CUDA)");
19 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
20 | }
--------------------------------------------------------------------------------
/raymarching/src/raymarching.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 |
7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12 | void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res);
13 |
14 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises);
15 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
16 | void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
17 |
18 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises);
19 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | rich
3 | ninja
4 | numpy
5 | pandas
6 | scipy
7 | scikit-learn
8 | matplotlib
9 | opencv-python
10 | imageio
11 | imageio-ffmpeg
12 |
13 | torch
14 | torch-ema
15 | einops
16 | tensorboard
17 | tensorboardX
18 |
19 | # for gui
20 | dearpygui
21 |
22 | # for grid_tcnn
23 | # git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
24 |
25 | # for stable-diffusion
26 | huggingface_hub
27 | diffusers >= 0.9.0
28 | accelerate
29 | transformers
30 |
31 | # for dmtet and mesh export
32 | xatlas
33 | trimesh
34 | PyMCubes
35 | pymeshlab
36 | git+https://github.com/NVlabs/nvdiffrast/
37 |
38 | # for zero123
39 | carvekit-colab
40 | omegaconf
41 | pytorch-lightning
42 | taming-transformers-rom1504
43 | kornia
44 | git+https://github.com/openai/CLIP.git
45 |
46 | # for omnidata
47 | gdown
48 |
49 | # for dpt
50 | timm
51 |
52 | # for remote debugging
53 | debugpy-run
54 |
55 | # for deepfloyd if
56 | sentencepiece
57 |
--------------------------------------------------------------------------------
/scripts/install_ext.sh:
--------------------------------------------------------------------------------
1 | pip install ./raymarching
2 | pip install ./shencoder
3 | pip install ./freqencoder
4 | pip install ./gridencoder
--------------------------------------------------------------------------------
/scripts/res64.args:
--------------------------------------------------------------------------------
1 | -O --vram_O --w 64 --h 64
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a delicious hamburger" --workspace trial_hamburger --iters 5000
3 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a delicious hamburger" --workspace trial2_hamburger --dmtet --iters 5000 --init_with trial_hamburger/checkpoints/df.pth
4 |
5 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a highly detailed stone bust of Theodoros Kolokotronis" --workspace trial_stonehead --iters 5000
6 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a highly detailed stone bust of Theodoros Kolokotronis" --workspace trial2_stonehead --dmtet --iters 5000 --init_with trial_stonehead/checkpoints/df.pth
7 |
8 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "an astronaut, full body" --workspace trial_astronaut --iters 5000
9 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "an astronaut, full body" --workspace trial2_astronaut --dmtet --iters 5000 --init_with trial_astronaut/checkpoints/df.pth
10 |
11 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel-octopus hybrid" --workspace trial_squrrel_octopus --iters 5000
12 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel-octopus hybrid" --workspace trial2_squrrel_octopus --dmtet --iters 5000 --init_with trial_squrrel_octopus/checkpoints/df.pth
13 |
14 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_rabbit_pancake --iters 5000
15 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run2.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat" --workspace trial_shiba --iters 10000
4 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat" --workspace trial2_shiba --dmtet --iters 5000 --init_with trial_shiba/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a banana peeling itself" --workspace trial_banana --iters 10000
7 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a banana peeling itself" --workspace trial2_banana --dmtet --iters 5000 --init_with trial_banana/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a capybara wearing a top hat, low poly" --workspace trial_capybara --iters 10000
10 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a capybara wearing a top hat, low poly" --workspace trial2_capybara --dmtet --iters 5000 --init_with trial_capybara/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run3.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "ironman, full body" --workspace trial_ironman --iters 10000
4 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "ironman, full body" --workspace trial2_ironman --dmtet --iters 5000 --init_with trial_ironman/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of an ice cream sundae" --workspace trial_icecream --iters 10000
7 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of an ice cream sundae" --workspace trial2_icecream --dmtet --iters 5000 --init_with trial_icecream/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of a kingfisher bird" --workspace trial_bird --iters 10000
10 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of a kingfisher bird" --workspace trial2_bird --dmtet --iters 5000 --init_with trial_bird/checkpoints/df.pth
11 |
12 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a car made of sushi" --workspace trial_sushi --iters 10000
13 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a car made of sushi" --workspace trial2_sushi --dmtet --iters 5000 --init_with trial_sushi/checkpoints/df.pth
14 |
--------------------------------------------------------------------------------
/scripts/run4.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a rabbit, animated movie character, high detail 3d model" --workspace trial_rabbit2 --iters 10000
4 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a rabbit, animated movie character, high detail 3d model" --workspace trial2_rabbit2 --dmtet --iters 5000 --init_with trial_rabbit2/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a corgi dog, highly detailed 3d model" --workspace trial_corgi --iters 10000
7 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a corgi dog, highly detailed 3d model" --workspace trial2_corgi --dmtet --iters 5000 --init_with trial_corgi/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text " a small saguaro cactus planted in a clay pot" --workspace trial_cactus --iters 10000
10 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text " a small saguaro cactus planted in a clay pot" --workspace trial2_cactus --dmtet --iters 5000 --init_with trial_cactus/checkpoints/df.pth
11 |
12 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "the leaning tower of Pisa" --workspace trial_pisa --iters 10000
13 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "the leaning tower of Pisa" --workspace trial2_pisa --dmtet --iters 5000 --init_with trial_pisa/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run5.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Perched blue jay bird" --workspace trial_jay --iters 10000
4 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Perched blue jay bird" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "angel statue wings out" --workspace trial_angle --iters 10000
7 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "angel statue wings out" --workspace trial2_angle --dmtet --iters 5000 --init_with trial_angle/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "devil statue" --workspace trial_devil --iters 10000
10 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "devil statue" --workspace trial2_devil --dmtet --iters 5000 --init_with trial_devil/checkpoints/df.pth
11 |
12 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Einstein statue" --workspace trial_einstein --iters 10000
13 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Einstein statue" --workspace trial2_einstein --dmtet --iters 5000 --init_with trial_einstein/checkpoints/df.pth
14 |
--------------------------------------------------------------------------------
/scripts/run6.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_rabbit_pancake --iters 5000
3 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth
4 |
5 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_jay --iters 5000
6 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth
7 |
8 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_fox --iters 5000
9 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial2_fox --dmtet --iters 5000 --init_with trial_fox/checkpoints/df.pth
10 |
11 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_peacock --iters 5000
12 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial2_peacock --dmtet --iters 5000 --init_with trial_peacock/checkpoints/df.pth
13 |
14 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a flower made out of metal" --workspace trial_metal_flower --iters 5000
15 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a flower made out of metal" --workspace trial2_metal_flower --dmtet --iters 5000 --init_with trial_metal_flower/checkpoints/df.pth
16 |
17 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_chicken --iters 5000
18 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial2_chicken --dmtet --iters 5000 --init_with trial_chicken/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run_if.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_if_rabbit_pancake --iters 5000 --IF
3 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial_if2_rabbit_pancake --dmtet --iters 5000 --init_with trial_if_rabbit_pancake/checkpoints/df.pth
4 |
5 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_if_jay --iters 5000 --IF
6 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_if2_jay --dmtet --iters 5000 --init_with trial_if_jay/checkpoints/df.pth
7 |
8 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_if_fox --iters 5000 --IF
9 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_if2_fox --dmtet --iters 5000 --init_with trial_if_fox/checkpoints/df.pth
10 |
11 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_if_peacock --iters 5000 --IF
12 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_if2_peacock --dmtet --iters 5000 --init_with trial_if_peacock/checkpoints/df.pth
13 |
14 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a flower made out of metal" --workspace trial_if_metal_flower --iters 5000 --IF
15 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a flower made out of metal" --workspace trial_if2_metal_flower --dmtet --iters 5000 --init_with trial_if_metal_flower/checkpoints/df.pth
16 |
17 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_if_chicken --iters 5000 --IF
18 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_if2_chicken --dmtet --iters 5000 --init_with trial_if_chicken/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run_if2.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a corgi taking a selfie" --workspace trial_if_corgi --iters 5000 --IF
3 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a corgi taking a selfie" --workspace trial_if2_corgi --dmtet --iters 5000 --init_with trial_if_corgi/checkpoints/df.pth
4 |
5 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a ghost eating a hamburger" --workspace trial_if_ghost --iters 5000 --IF
6 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a ghost eating a hamburger" --workspace trial_if2_ghost --dmtet --iters 5000 --init_with trial_if_ghost/checkpoints/df.pth
7 |
8 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of an origami motorcycle" --workspace trial_if_motor --iters 5000 --IF
9 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of an origami motorcycle" --workspace trial_if2_motor --dmtet --iters 5000 --init_with trial_if_motor/checkpoints/df.pth
10 |
11 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a Space Shuttle" --workspace trial_if_spaceshuttle --iters 5000 --IF
12 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a Space Shuttle" --workspace trial_if2_spaceshuttle --dmtet --iters 5000 --init_with trial_if_spaceshuttle/checkpoints/df.pth
13 |
14 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a palm tree, low poly 3d model" --workspace trial_if_palm --iters 5000 --IF
15 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a palm tree, low poly 3d model" --workspace trial_if2_palm --dmtet --iters 5000 --init_with trial_if_palm/checkpoints/df.pth
16 |
17 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head" --workspace trial_if_cat_mouse --iters 5000 --IF
18 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head" --workspace trial_if2_cat_mouse --dmtet --iters 5000 --init_with trial_if_cat_mouse/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run_if2_perpneg.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # To avoid the Janus problem caused by the diffusion model's front view bias, utilize the Perp-Neg algorithm. To maximize its benefits,
3 | # increase the absolute value of "negative_w" for improved Janus problem mitigation. If you encounter flat faces or divergence, consider
4 | # reducing the absolute value of "negative_w". The value of "negative_w" should vary for each prompt due to the diffusion model's varying
5 | # bias towards generating front views for different objects. Vary the weights within the range of 0 to -4 (a commented sweep template is included at the end of this script).
6 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a lion bust" --workspace trial_perpneg_if_lion --iters 5000 --IF --batch_size 1 --perpneg
7 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a marble lion head" --workspace trial_perpneg_if2_lion_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_lion/checkpoints/df.pth
8 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a marble lion head" --workspace trial_perpneg_if2_lion_nop --dmtet --iters 5000 --init_with trial_perpneg_if_lion/checkpoints/df.pth
9 |
10 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a tiger cub" --workspace trial_perpneg_if_tiger --iters 5000 --IF --batch_size 1 --perpneg
11 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "tiger" --workspace trial_perpneg_if2_tiger_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_tiger/checkpoints/df.pth
12 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "tiger" --workspace trial_perpneg_if2_tiger_nop --dmtet --iters 5000 --init_with trial_perpneg_if_tiger/checkpoints/df.pth
13 |
14 | # A larger absolute value of negative_w is used for the following command because the default negative weight of -2 is not enough to make the diffusion model produce the views as desired.
15 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a shiba dog wearing sunglasses" --workspace trial_perpneg_if_shiba --iters 5000 --IF --batch_size 1 --perpneg --negative_w -3.0
16 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "shiba wearing sunglasses" --workspace trial_perpneg_if2_shiba_p --dmtet --iters 5000 --perpneg --negative_w -3.0 --init_with trial_perpneg_if_shiba/checkpoints/df.pth
17 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "shiba wearing sunglasses" --workspace trial_perpneg_if2_shiba_nop --dmtet --iters 5000 --init_with trial_perpneg_if_shiba/checkpoints/df.pth
18 |
19 |
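20 | # Untested sweep template (commented out; workspace names are placeholders): since the best
21 | # "negative_w" is prompt-dependent, one option is to run the same prompt with a few weights in
22 | # the suggested 0 to -4 range and compare the resulting workspaces.
23 | # for W in -1.0 -2.0 -3.0 -4.0; do
24 | #     CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a lion bust" --workspace trial_perpneg_sweep_w${W} \
25 | #         --iters 5000 --IF --batch_size 1 --perpneg --negative_w $W
26 | # done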
--------------------------------------------------------------------------------
/scripts/run_image.sh:
--------------------------------------------------------------------------------
1 | # zero123 backend (single object, input images that look like 3D model renderings)
2 |
3 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial_image_teddy --iters 5000
4 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial2_image_teddy --iters 5000 --dmtet --init_with trial_image_teddy/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial_image_catstatue --iters 5000
7 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial2_image_catstatue --iters 5000 --dmtet --init_with trial_image_catstatue/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial_image_firekeeper --iters 5000
10 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial2_image_firekeeper --iters 5000 --dmtet --init_with trial_image_firekeeper/checkpoints/df.pth
11 |
12 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial_image_hamburger --iters 5000
13 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial2_image_hamburger --iters 5000 --dmtet --init_with trial_image_hamburger/checkpoints/df.pth
14 |
15 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial_image_corgi --iters 5000
16 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial2_image_corgi --iters 5000 --dmtet --init_with trial_image_corgi/checkpoints/df.pth
17 |
18 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial_image_cactus --iters 5000
19 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial2_image_cactus --iters 5000 --dmtet --init_with trial_image_cactus/checkpoints/df.pth
20 |
21 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial_image_cake --iters 5000
22 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial2_image_cake --iters 5000 --dmtet --init_with trial_image_cake/checkpoints/df.pth
23 |
24 | # CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial_image_warrior --iters 5000
25 | # CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial2_image_warrior --iters 5000 --dmtet --init_with trial_image_warrior/checkpoints/df.pth
--------------------------------------------------------------------------------
/scripts/run_image_anya.sh:
--------------------------------------------------------------------------------
1 | # Phase 1 - barely fits in A100 40GB.
2 | # Conclusion: results in concave-ish face, no neck, excess hair in the back
3 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage \
4 | --iters 10000 --save_guidance --save_guidance_interval 10 --ckpt scratch --batch_size 2 --test_interval 2 \
5 | --h 128 --w 128 --zero123_grad_scale None
6 |
7 | # Phase 2 - barely fits in A100 40GB.
8 | # 20X smaller lambda_3d_normal_smooth, --known_view_interval 2, 3X LR
9 | # Much higher jitter to increase disparity (and eliminate some of the flatness)... not too high either (to avoid cropping the face)
10 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2 \
11 | --text "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" \
12 | --iters 12500 --ckpt trial_anya_1_refimage/checkpoints/df_ep0100.pth --save_guidance --save_guidance_interval 1 \
13 | --h 256 --w 256 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \
14 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \
15 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 \
16 | --exp_start_iter 10000 --exp_end_iter 12500
17 |
18 | # Phase 3 - increase resolution to 512
19 | # Disable textureless since they can cause catastrophic divergence
20 | # Since radius range is inconsistent, increase it, and reduce the jitter to avoid excessively cropped renders.
21 | # Learning rate may be set too high, since `--batch_size 1`.
22 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2 \
23 | --text "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" \
24 | --iters 25000 --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2/checkpoints/df_ep0125.pth --save_guidance --save_guidance_interval 1 \
25 | --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \
26 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \
27 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \
28 | --exp_start_iter 12500 --exp_end_iter 25000
29 |
30 | # Generate 6 views
31 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2/checkpoints/df_ep0250.pth --six_views
32 |
33 | # Phase 4 - untested, need to adjust
34 | # CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage --iters 5000 --dmtet --init_with trial_anya_1_refimage/checkpoints/df.pth
35 |
36 |
--------------------------------------------------------------------------------
/scripts/run_image_hard_examples.sh:
--------------------------------------------------------------------------------
1 | bash scripts/run_image_procedure.sh 0 30 90 anya_front "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )"
2 | bash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice "A DSLR 3D photo of an adorable baby phoenix made in Swarowski crystal highly detailed intricate concept art 8K ( unreal engine 5 trending on Artstation )"
3 | bash scripts/run_image_procedure.sh 2 30 90 bollywood_actress "A DSLR 3D photo of a beautiful bollywood indian actress, pretty eyes, full body shot composition, sunny outdoor, seen from far away ( highly detailed intricate 8K unreal engine 5 trending on Artstation )"
4 | bash scripts/run_image_procedure.sh 3 30 40 beach_house_1 "A DSLR 3D photo of a very beautiful small house on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )"
5 | bash scripts/run_image_procedure.sh 4 30 60 beach_house_2 "A DSLR 3D photo of a very beautiful high-tech small house with solar panels and wildflowers on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )"
6 | bash scripts/run_image_procedure.sh 5 30 90 mona_lisa "A DSLR 3D photo of a beautiful young woman dressed like Mona Lisa ( highly detailed intricate 8K unreal engine 5 trending on Artstation )"
7 | bash scripts/run_image_procedure.sh 6 30 80 futuristic_car "A DSLR 3D photo of a crazily futuristic electric car ( highly detailed intricate 8K unreal engine 5 trending on Artstation )"
8 | # the church ruins probably require a wider field of view... e.g. 90 degrees, maybe even more... so may not work with Zero123 etc.
9 | bash scripts/run_image_procedure.sh 7 30 90 church_ruins "A DSLR 3D photo of the remains of an isolated old church ruin covered in ivy ( highly detailed intricate 8K unreal engine 5 trending on Artstation )"
10 |
11 | # young woman dressed like mona lisa
--------------------------------------------------------------------------------
/scripts/run_image_procedure.sh:
--------------------------------------------------------------------------------
1 | # Perform a 2D-to-3D reconstruction, similar to the Anya case study: https://github.com/ashawkey/stable-dreamfusion/issues/263
2 | # Args:
3 | #   bash scripts/run_image_procedure.sh GPU_ID guidance_interval default_polar image_name "prompt"
4 | # e.g.:
5 | #   bash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice "An adorable baby phoenix made in Swarowski crystal highly detailed intricate concept art 8K"
6 | GPU_ID=$1
7 | GUIDANCE_INTERVAL=$2
8 | DEFAULT_POLAR=$3
9 | PREFIX=$4
10 | PROMPT=$5
11 | EPOCHS1=100
12 | EPOCHS2=200
13 | EPOCHS3=300
14 | IMAGE=data/$PREFIX.png
15 | IMAGE_RGBA=data/${PREFIX}_rgba.png
16 | WS_PH1=trial_$PREFIX-ph1
17 | WS_PH2=trial_$PREFIX-ph2
18 | WS_PH3=trial_$PREFIX-ph3
19 | CKPT1=$WS_PH1/checkpoints/df_ep0${EPOCHS1}.pth
20 | CKPT2=$WS_PH2/checkpoints/df_ep0${EPOCHS2}.pth
21 | CKPT3=$WS_PH3/checkpoints/df_ep0${EPOCHS3}.pth
22 |
23 | # Uncomment to clean up the trial folders. Be careful: a mistake here could erase important work!
24 | # rm -r $WS_PH1 $WS_PH2 $WS_PH3
25 |
26 | # Preprocess
27 | if [ ! -f $IMAGE_RGBA ]
28 | then
29 | python preprocess_image.py $IMAGE
30 | fi
31 |
32 | if [ ! -f $CKPT1 ]
33 | then
34 | # Phase 1 - zero123-guidance
35 | # WARNING: claforte: constantly runs out of VRAM with resolution of 128x128 and batch_size 2... no longer able to reproduce Anya result because of this...
36 | # I added these to try to reduce mem usage, but this might degrade the quality... `--lambda_depth 0 --lambda_3d_normal_smooth 0`
37 | # Remove: --ckpt scratch
38 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH1 --default_polar $DEFAULT_POLAR \
39 | --iters ${EPOCHS1}00 --save_guidance --save_guidance_interval $GUIDANCE_INTERVAL --batch_size 1 --test_interval 2 \
40 | --h 96 --w 96 --zero123_grad_scale None --lambda_3d_normal_smooth 0 --dont_override_stuff \
41 | --fovy_range 20 20 --guidance_scale 5
42 | fi
43 |
44 | GUIDANCE_INTERVAL=7
45 | if [ ! -f $CKPT2 ]
46 | then
47 | # Phase 2 - SD-guidance at 256x256
48 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH2 \
49 | --text "${PROMPT}" --default_polar $DEFAULT_POLAR \
50 | --iters ${EPOCHS2}00 --ckpt $CKPT1 --save_guidance --save_guidance_interval 7 \
51 | --h 128 --w 128 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \
52 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \
53 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --progressive_view_init_ratio 0.05 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \
54 | --exp_start_iter ${EPOCHS1}00 --exp_end_iter ${EPOCHS2}00
55 | fi
56 |
57 | if [ ! -f $CKPT3 ]
58 | then
59 | # # Phase 3 - increase resolution to 512
60 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH3 \
61 | --text "${PROMPT}" --default_polar $DEFAULT_POLAR \
62 | --iters ${EPOCHS3}00 --ckpt $CKPT2 --save_guidance --save_guidance_interval 7 \
63 | --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \
64 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \
65 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \
66 | --exp_start_iter ${EPOCHS2}00 --exp_end_iter ${EPOCHS3}00
67 | fi
68 |
69 | # Generate 6 views
70 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --ckpt $CKPT3 --six_views
71 |
72 |
--------------------------------------------------------------------------------
/scripts/run_image_text.sh:
--------------------------------------------------------------------------------
1 | # sd backend (realistic images)
2 |
3 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text "a brown teddy bear sitting on a ground" --workspace trial_imagetext_teddy --iters 5000
4 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text "a brown teddy bear sitting on a ground" --workspace trial2_imagetext_teddy --iters 10000 --dmtet --init_with trial_imagetext_teddy/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text "a corgi running" --workspace trial_imagetext_corgi --iters 5000
7 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text "a corgi running" --workspace trial2_imagetext_corgi --iters 10000 --dmtet --init_with trial_imagetext_corgi/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial_imagetext_hamburger --iters 5000
10 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial2_imagetext_hamburger --iters 10000 --dmtet --init_with trial_imagetext_hamburger/checkpoints/df.pth
11 |
12 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text "a potted cactus plant" --workspace trial_imagetext_cactus --iters 5000
13 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text "a potted cactus plant" --workspace trial2_imagetext_cactus --iters 10000 --dmtet --init_with trial_imagetext_cactus/checkpoints/df.pth
14 |
--------------------------------------------------------------------------------
/scripts/run_images.sh:
--------------------------------------------------------------------------------
1 | # zero123 backend (single object, input images that look like 3D model renderings)
2 |
3 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial_images_corgi --iters 5000
4 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial2_images_corgi --iters 10000 --dmtet --init_with trial_images_corgi/checkpoints/df.pth
5 |
6 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial_images_car --iters 5000
7 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial2_images_car --iters 10000 --dmtet --init_with trial_images_car/checkpoints/df.pth
8 |
9 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial_images_anya --iters 5000
10 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial2_images_anya --iters 10000 --dmtet --init_with trial_images_anya/checkpoints/df.pth
--------------------------------------------------------------------------------
/shencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .sphere_harmonics import SHEncoder
--------------------------------------------------------------------------------
/shencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | _backend = load(name='_sh_encoder',
33 | extra_cflags=c_flags,
34 | extra_cuda_cflags=nvcc_flags,
35 | sources=[os.path.join(_src_path, 'src', f) for f in [
36 | 'shencoder.cu',
37 | 'bindings.cpp',
38 | ]],
39 | )
40 |
41 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/shencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
23 | if paths:
24 | return paths[0]
25 |
26 | # If cl.exe is not on path, try to find it.
27 | if os.system("where cl.exe >nul 2>nul") != 0:
28 | cl_path = find_cl_path()
29 | if cl_path is None:
30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
31 | os.environ["PATH"] += ";" + cl_path
32 |
33 | setup(
34 | name='shencoder', # package name, import this to use python API
35 | ext_modules=[
36 | CUDAExtension(
37 | name='_shencoder', # extension name, import this to use CUDA API
38 | sources=[os.path.join(_src_path, 'src', f) for f in [
39 | 'shencoder.cu',
40 | 'bindings.cpp',
41 | ]],
42 | extra_compile_args={
43 | 'cxx': c_flags,
44 | 'nvcc': nvcc_flags,
45 | }
46 | ),
47 | ],
48 | cmdclass={
49 | 'build_ext': BuildExtension,
50 | }
51 | )
--------------------------------------------------------------------------------
/shencoder/sphere_harmonics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | from torch.autograd.function import once_differentiable
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | try:
10 | import _shencoder as _backend
11 | except ImportError:
12 | from .backend import _backend
13 |
14 | class _sh_encoder(Function):
15 | @staticmethod
16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
17 | def forward(ctx, inputs, degree, calc_grad_inputs=False):
18 | # inputs: [B, input_dim], float in [-1, 1]
19 | # RETURN: [B, F], float
20 |
21 | inputs = inputs.contiguous()
22 | B, input_dim = inputs.shape # batch size, coord dim
23 | output_dim = degree ** 2
24 |
25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
26 |
27 | if calc_grad_inputs:
28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
29 | else:
30 | dy_dx = None
31 |
32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)
33 |
34 | ctx.save_for_backward(inputs, dy_dx)
35 | ctx.dims = [B, input_dim, degree]
36 |
37 | return outputs
38 |
39 | @staticmethod
40 | #@once_differentiable
41 | @custom_bwd
42 | def backward(ctx, grad):
43 | # grad: [B, C * C]
44 |
45 | inputs, dy_dx = ctx.saved_tensors
46 |
47 | if dy_dx is not None:
48 | grad = grad.contiguous()
49 | B, input_dim, degree = ctx.dims
50 | grad_inputs = torch.zeros_like(inputs)
51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
52 | return grad_inputs, None, None
53 | else:
54 | return None, None, None
55 |
56 |
57 |
58 | sh_encode = _sh_encoder.apply
59 |
60 |
61 | class SHEncoder(nn.Module):
62 | def __init__(self, input_dim=3, degree=4):
63 | super().__init__()
64 |
65 | self.input_dim = input_dim # coord dims, must be 3
66 |         self.degree = degree # degree in [1, 8], typically 4
67 | self.output_dim = degree ** 2
68 |
69 |         assert self.input_dim == 3, "SH encoder only supports input dim == 3"
70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"
71 |
72 | def __repr__(self):
73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"
74 |
75 | def forward(self, inputs, size=1):
76 | # inputs: [..., input_dim], normalized real world positions in [-size, size]
77 | # return: [..., degree^2]
78 |
79 | inputs = inputs / size # [-1, 1]
80 |
81 | prefix_shape = list(inputs.shape[:-1])
82 | inputs = inputs.reshape(-1, self.input_dim)
83 |
84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
85 | outputs = outputs.reshape(prefix_shape + [self.output_dim])
86 |
87 | return outputs
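88 | 
89 | 
90 | if __name__ == '__main__':
91 |     # Minimal usage sketch (not part of the original module): encode a batch of unit
92 |     # directions with the compiled _shencoder CUDA extension. Requires a CUDA device.
93 |     encoder = SHEncoder(input_dim=3, degree=4)
94 |     dirs = torch.randn(8, 3, device='cuda')
95 |     dirs = dirs / dirs.norm(dim=-1, keepdim=True)  # normalize into [-1, 1]
96 |     feats = encoder(dirs)  # [8, 16], since output_dim = degree ** 2
97 |     print(encoder, feats.shape)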
--------------------------------------------------------------------------------
/shencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "shencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
8 | }
--------------------------------------------------------------------------------
/shencoder/src/shencoder.h:
--------------------------------------------------------------------------------
1 | # pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 | // inputs: [B, D], float, in [-1, 1]
7 | // outputs: [B, F], float
8 |
9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);
--------------------------------------------------------------------------------
/taichi_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .ray_march import RayMarcherTaichi, raymarching_test
2 | from .volume_train import VolumeRendererTaichi
3 | from .intersection import RayAABBIntersector
4 | from .volume_render_test import composite_test
5 | from .utils import packbits
--------------------------------------------------------------------------------
/taichi_modules/intersection.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import torch
3 | from taichi.math import vec3
4 | from torch.cuda.amp import custom_fwd
5 |
6 | from .utils import NEAR_DISTANCE
7 |
8 |
9 | @ti.kernel
10 | def simple_ray_aabb_intersec_taichi_forward(
11 | hits_t: ti.types.ndarray(ndim=2),
12 | rays_o: ti.types.ndarray(ndim=2),
13 | rays_d: ti.types.ndarray(ndim=2),
14 | centers: ti.types.ndarray(ndim=2),
15 | half_sizes: ti.types.ndarray(ndim=2)):
16 |
17 | for r in ti.ndrange(hits_t.shape[0]):
18 | ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]])
19 | ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]])
20 | inv_d = 1.0 / ray_d
21 |
22 | center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]])
23 | half_size = vec3(
24 |             [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 2]])
25 |
26 | t_min = (center - half_size - ray_o) * inv_d
27 | t_max = (center + half_size - ray_o) * inv_d
28 |
29 | _t1 = ti.min(t_min, t_max)
30 | _t2 = ti.max(t_min, t_max)
31 | t1 = _t1.max()
32 | t2 = _t2.min()
33 |
34 | if t2 > 0.0:
35 | hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE)
36 | hits_t[r, 0, 1] = t2
37 |
38 |
39 | class RayAABBIntersector(torch.autograd.Function):
40 | """
41 | Computes the intersections of rays and axis-aligned voxels.
42 |
43 | Inputs:
44 | rays_o: (N_rays, 3) ray origins
45 | rays_d: (N_rays, 3) ray directions
46 | centers: (N_voxels, 3) voxel centers
47 | half_sizes: (N_voxels, 3) voxel half sizes
48 | max_hits: maximum number of intersected voxels to keep for one ray
49 | (for a cubic scene, this is at most 3*N_voxels^(1/3)-2)
50 |
51 | Outputs:
52 | hits_cnt: (N_rays) number of hits for each ray
53 | (followings are from near to far)
54 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit)
55 | hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit)
56 | """
57 |
58 | @staticmethod
59 | @custom_fwd(cast_inputs=torch.float32)
60 | def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
61 | hits_t = (torch.zeros(
62 | rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) -
63 | 1).contiguous()
64 |
65 | simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center,
66 | half_size)
67 |
68 | return None, hits_t, None
69 |
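70 | if __name__ == '__main__':
71 |     # Minimal usage sketch (not part of the original module). Taichi must be initialized
72 |     # before the kernel runs; the arch and tensor device below are illustrative.
73 |     ti.init(arch=ti.cuda)
74 |     rays_o = torch.zeros(4, 3, device='cuda')
75 |     rays_d = torch.nn.functional.normalize(torch.randn(4, 3, device='cuda'), dim=-1)
76 |     center = torch.zeros(1, 3, device='cuda')         # unit cube centered at the origin
77 |     half_size = 0.5 * torch.ones(1, 3, device='cuda')
78 |     # Only hits_t is filled in; hits_cnt and hits_voxel_idx are returned as None here.
79 |     _, hits_t, _ = RayAABBIntersector.apply(rays_o, rays_d, center, half_size, 1)
80 |     print(hits_t.shape, hits_t)  # (4, 1, 2); entries stay -1 where a ray misses the box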
--------------------------------------------------------------------------------
/taichi_modules/utils.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import torch
3 | from taichi.math import uvec3
4 |
5 | taichi_block_size = 128
6 |
7 | data_type = ti.f32
8 | torch_type = torch.float32
9 |
10 | MAX_SAMPLES = 1024
11 | NEAR_DISTANCE = 0.01
12 | SQRT3 = 1.7320508075688772
13 | SQRT3_MAX_SAMPLES = SQRT3 / 1024
14 | SQRT3_2 = 1.7320508075688772 * 2
15 |
16 |
17 | @ti.func
18 | def scalbn(x, exponent):
19 | return x * ti.math.pow(2, exponent)
20 |
21 |
22 | @ti.func
23 | def calc_dt(t, exp_step_factor, grid_size, scale):
24 | return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES,
25 | SQRT3_2 * scale / grid_size)
26 |
27 |
28 | @ti.func
29 | def frexp_bit(x):
30 | exponent = 0
31 | if x != 0.0:
32 | # frac = ti.abs(x)
33 | bits = ti.bit_cast(x, ti.u32)
34 | exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127
35 | # exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127
36 | bits &= ti.u32(0x7fffff)
37 | bits |= ti.u32(0x3f800000)
38 | frac = ti.bit_cast(bits, ti.f32)
39 | if frac < 0.5:
40 | exponent -= 1
41 | elif frac > 1.0:
42 | exponent += 1
43 | return exponent
44 |
45 |
46 | @ti.func
47 | def mip_from_pos(xyz, cascades):
48 | mx = ti.abs(xyz).max()
49 | # _, exponent = _frexp(mx)
50 | exponent = frexp_bit(ti.f32(mx)) + 1
51 | # frac, exponent = ti.frexp(ti.f32(mx))
52 | return ti.min(cascades - 1, ti.max(0, exponent))
53 |
54 |
55 | @ti.func
56 | def mip_from_dt(dt, grid_size, cascades):
57 | # _, exponent = _frexp(dt*grid_size)
58 | exponent = frexp_bit(ti.f32(dt * grid_size))
59 | # frac, exponent = ti.frexp(ti.f32(dt*grid_size))
60 | return ti.min(cascades - 1, ti.max(0, exponent))
61 |
62 |
63 | @ti.func
64 | def __expand_bits(v):
65 | v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF)
66 | v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F)
67 | v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3)
68 | v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249)
69 | return v
70 |
71 |
72 | @ti.func
73 | def __morton3D(xyz):
74 | xyz = __expand_bits(xyz)
75 | return xyz[0] | (xyz[1] << 1) | (xyz[2] << 2)
76 |
77 |
78 | @ti.func
79 | def __morton3D_invert(x):
80 | x = x & (0x49249249)
81 | x = (x | (x >> 2)) & ti.uint32(0xc30c30c3)
82 | x = (x | (x >> 4)) & ti.uint32(0x0f00f00f)
83 | x = (x | (x >> 8)) & ti.uint32(0xff0000ff)
84 | x = (x | (x >> 16)) & ti.uint32(0x0000ffff)
85 | return ti.int32(x)
86 |
87 |
88 | @ti.kernel
89 | def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),
90 | coords: ti.types.ndarray(ndim=2)):
91 | for i in indices:
92 | ind = ti.uint32(indices[i])
93 | coords[i, 0] = __morton3D_invert(ind >> 0)
94 | coords[i, 1] = __morton3D_invert(ind >> 1)
95 | coords[i, 2] = __morton3D_invert(ind >> 2)
96 |
97 |
98 | def morton3D_invert(indices):
99 | coords = torch.zeros(indices.size(0),
100 | 3,
101 | device=indices.device,
102 | dtype=torch.int32)
103 | morton3D_invert_kernel(indices.contiguous(), coords)
104 | ti.sync()
105 | return coords
106 |
107 |
108 | @ti.kernel
109 | def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),
110 | indices: ti.types.ndarray(ndim=1)):
111 | for s in indices:
112 | xyz = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]])
113 | indices[s] = ti.cast(__morton3D(xyz), ti.int32)
114 |
115 |
116 | def morton3D(coords1):
117 | indices = torch.zeros(coords1.size(0),
118 | device=coords1.device,
119 | dtype=torch.int32)
120 | morton3D_kernel(coords1.contiguous(), indices)
121 | ti.sync()
122 | return indices
123 |
124 |
125 | @ti.kernel
126 | def packbits(density_grid: ti.types.ndarray(ndim=1),
127 | density_threshold: float,
128 | density_bitfield: ti.types.ndarray(ndim=1)):
129 |
130 | for n in density_bitfield:
131 | bits = ti.uint8(0)
132 |
133 | for i in ti.static(range(8)):
134 | bits |= (ti.uint8(1) << i) if (
135 | density_grid[8 * n + i] > density_threshold) else ti.uint8(0)
136 |
137 | density_bitfield[n] = bits
138 |
139 |
140 | @ti.kernel
141 | def torch2ti(field: ti.template(), data: ti.types.ndarray()):
142 | for I in ti.grouped(data):
143 | field[I] = data[I]
144 |
145 |
146 | @ti.kernel
147 | def ti2torch(field: ti.template(), data: ti.types.ndarray()):
148 | for I in ti.grouped(data):
149 | data[I] = field[I]
150 |
151 |
152 | @ti.kernel
153 | def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
154 | for I in ti.grouped(grad):
155 | grad[I] = field.grad[I]
156 |
157 |
158 | @ti.kernel
159 | def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
160 | for I in ti.grouped(grad):
161 | field.grad[I] = grad[I]
162 |
163 |
164 | @ti.kernel
165 | def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):
166 | for I in range(data.shape[0] // 2):
167 | field[I] = ti.Vector([data[I * 2], data[I * 2 + 1]])
168 |
169 |
170 | @ti.kernel
171 | def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):
172 | for i, j in ti.ndrange(data.shape[0], data.shape[1] // 2):
173 | data[i, j * 2] = field[i, j][0]
174 | data[i, j * 2 + 1] = field[i, j][1]
175 |
176 |
177 | @ti.kernel
178 | def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
179 | for I in range(grad.shape[0] // 2):
180 | grad[I * 2] = field.grad[I][0]
181 | grad[I * 2 + 1] = field.grad[I][1]
182 |
183 |
184 | @ti.kernel
185 | def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
186 | for i, j in ti.ndrange(grad.shape[0], grad.shape[1] // 2):
187 | field.grad[i, j][0] = grad[i, j * 2]
188 | field.grad[i, j][1] = grad[i, j * 2 + 1]
189 |
190 |
191 | def extract_model_state_dict(ckpt_path,
192 | model_name='model',
193 | prefixes_to_ignore=[]):
194 | checkpoint = torch.load(ckpt_path, map_location='cpu')
195 | checkpoint_ = {}
196 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint
197 | checkpoint = checkpoint['state_dict']
198 | for k, v in checkpoint.items():
199 | if not k.startswith(model_name):
200 | continue
201 | k = k[len(model_name) + 1:]
202 | for prefix in prefixes_to_ignore:
203 | if k.startswith(prefix):
204 | break
205 | else:
206 | checkpoint_[k] = v
207 | return checkpoint_
208 |
209 |
210 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
211 | if not ckpt_path:
212 | return
213 | model_dict = model.state_dict()
214 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name,
215 | prefixes_to_ignore)
216 | model_dict.update(checkpoint_)
217 | model.load_state_dict(model_dict)
218 |
219 | def depth2img(depth):
220 |     # cv2 and numpy are not imported at module level, so import them here.
221 |     import cv2
222 |     import numpy as np
223 |     depth = (depth - depth.min()) / (depth.max() - depth.min())
224 |     depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
225 |                                   cv2.COLORMAP_TURBO)
226 |     return depth_img
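227 | 
228 | 
229 | if __name__ == '__main__':
230 |     # Minimal self-test sketch (not part of the original module): pack a density grid into a
231 |     # bitfield and round-trip Morton indices. Arch and sizes below are illustrative.
232 |     ti.init(arch=ti.cuda)
233 |     density_grid = torch.rand(8 * 16, device='cuda')
234 |     bitfield = torch.zeros(16, dtype=torch.uint8, device='cuda')
235 |     packbits(density_grid, 0.5, bitfield)  # one uint8 covers 8 grid cells
236 |     coords = torch.randint(0, 128, (16, 3), dtype=torch.int32, device='cuda')
237 |     assert torch.equal(morton3D_invert(morton3D(coords)), coords)  # encode/decode round-trip
238 |     print(bitfield)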
--------------------------------------------------------------------------------
/taichi_modules/volume_render_test.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 |
3 |
4 | @ti.kernel
5 | def composite_test(
6 | sigmas: ti.types.ndarray(ndim=2), rgbs: ti.types.ndarray(ndim=3),
7 | deltas: ti.types.ndarray(ndim=2), ts: ti.types.ndarray(ndim=2),
8 | hits_t: ti.types.ndarray(ndim=2),
9 | alive_indices: ti.types.ndarray(ndim=1), T_threshold: float,
10 | N_eff_samples: ti.types.ndarray(ndim=1),
11 | opacity: ti.types.ndarray(ndim=1),
12 | depth: ti.types.ndarray(ndim=1), rgb: ti.types.ndarray(ndim=2)):
13 |
14 | for n in alive_indices:
15 | samples = N_eff_samples[n]
16 | if samples == 0:
17 | alive_indices[n] = -1
18 | else:
19 | r = alive_indices[n]
20 |
21 | T = 1 - opacity[r]
22 |
23 | rgb_temp_0 = 0.0
24 | rgb_temp_1 = 0.0
25 | rgb_temp_2 = 0.0
26 | depth_temp = 0.0
27 | opacity_temp = 0.0
28 |
29 | for s in range(samples):
30 | a = 1.0 - ti.exp(-sigmas[n, s] * deltas[n, s])
31 | w = a * T
32 |
33 | rgb_temp_0 += w * rgbs[n, s, 0]
34 | rgb_temp_1 += w * rgbs[n, s, 1]
35 | rgb_temp_2 += w * rgbs[n, s, 2]
36 | depth[r] += w * ts[n, s]
37 | opacity[r] += w
38 | T *= 1.0 - a
39 |
40 | if T <= T_threshold:
41 | alive_indices[n] = -1
42 | break
43 |
44 | rgb[r, 0] += rgb_temp_0
45 | rgb[r, 1] += rgb_temp_1
46 | rgb[r, 2] += rgb_temp_2
47 | depth[r] += depth_temp
48 | opacity[r] += opacity_temp
49 |
--------------------------------------------------------------------------------
/tets/128_tets.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/tets/128_tets.npz
--------------------------------------------------------------------------------
/tets/32_tets.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/tets/32_tets.npz
--------------------------------------------------------------------------------
/tets/64_tets.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/tets/64_tets.npz
--------------------------------------------------------------------------------
/tets/README.md:
--------------------------------------------------------------------------------
1 | Place the tet grid files in this folder.
2 | We provide a few example grids. See the main README.md for a download link.
3 |
4 | You can also generate your own grids using https://github.com/crawforddoran/quartet
5 | Please see the `generate_tets.py` script for an example.
6 |
7 |
--------------------------------------------------------------------------------
/tets/generate_tets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import os
11 | import numpy as np
12 |
13 |
14 | '''
15 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
16 | to generate a tet grid
17 | 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
18 | 2) Run the function below to generate a file `cube_32_tet.tet`
19 | '''
20 |
21 | def generate_tetrahedron_grid_file(res=32, root='..'):
22 | frac = 1.0 / res
23 | command = 'cd %s/quartet; ' % (root) + \
24 | './quartet_release meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res)
25 | os.system(command)
26 |
27 |
28 | '''
29 | This code segment shows how to convert from a quartet .tet file to compressed npz file
30 | '''
31 | def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets.npz'):
32 |
33 | file1 = open(quartetfile, 'r')
34 | header = file1.readline()
35 | numvertices = int(header.split(" ")[1])
36 | numtets = int(header.split(" ")[2])
37 | print(numvertices, numtets)
38 |
39 | # load vertices
40 | vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
41 | vertices = vertices - 0.5
42 | print(vertices.shape, vertices.min(), vertices.max())
43 |
44 | # load indices
45 | indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets)
46 | print(indices.shape)
47 |
48 | np.savez_compressed(npzfile, vertices=vertices, indices=indices)
49 |
50 | if __name__ == '__main__':
51 | import argparse
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument('--res', type=int, default=32)
54 | parser.add_argument('--root', type=str, default='..')
55 | args = parser.parse_args()
56 |
57 | generate_tetrahedron_grid_file(res=args.res, root=args.root)
58 | convert_from_quartet_to_npz(quartetfile=os.path.join(args.root, 'quartet', 'meshes', f'cube_{args.res}.000000_tet.tet'), npzfile=os.path.join('./tets', f'{args.res}_tets.npz'))
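59 | 
60 | # Verification sketch (not part of the original script): the compressed .npz written above can be
61 | # inspected directly, e.g.
62 | #   data = np.load('tets/32_tets.npz')
63 | #   print(data['vertices'].shape, data['indices'].shape)  # (num_vertices, 3), (num_tets, 4)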
--------------------------------------------------------------------------------