├── .github └── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature_request.md ├── .gitignore ├── LICENSE ├── activation.py ├── assets ├── advanced.md └── update_logs.md ├── config ├── anya.csv ├── car.csv └── corgi.csv ├── data ├── anya_back.webp ├── anya_back_depth.png ├── anya_back_normal.png ├── anya_back_rgba.png ├── anya_front.jpg ├── anya_front.png ├── anya_front_depth.png ├── anya_front_normal.png ├── anya_front_rgba.png ├── baby_phoenix_on_ice.png ├── beach_house_1.png ├── beach_house_2.png ├── bollywood_actress.png ├── cactus.png ├── cactus_depth.png ├── cactus_normal.png ├── cactus_rgba.png ├── cake.png ├── cake_depth.png ├── cake_normal.png ├── cake_rgba.png ├── car_back.jpg ├── car_front.jpg ├── car_left.jpg ├── car_right.jpg ├── catstatue.png ├── catstatue_depth.png ├── catstatue_normal.png ├── catstatue_rgba.png ├── church_ruins.png ├── corgi_puppy_sitting_looking_up.jpg ├── firekeeper.jpg ├── firekeeper_depth.png ├── firekeeper_normal.png ├── firekeeper_rgba.png ├── futuristic_car.png ├── hamburger.png ├── hamburger_depth.png ├── hamburger_normal.png ├── hamburger_rgba.png ├── mona_lisa.png ├── teddy.png ├── teddy_depth.png ├── teddy_normal.png └── teddy_rgba.png ├── docker ├── Dockerfile └── README.md ├── dpt.py ├── encoding.py ├── evaluation ├── Prompt.py ├── mesh_to_video.py ├── r_precision.py └── readme.md ├── freqencoder ├── __init__.py ├── backend.py ├── freq.py ├── setup.py └── src │ ├── bindings.cpp │ ├── freqencoder.cu │ └── freqencoder.h ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── guidance ├── clip_utils.py ├── if_utils.py ├── perpneg_utils.py ├── sd_utils.py └── zero123_utils.py ├── ldm ├── extras.py ├── guidance.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── evaluate │ │ ├── adm_evaluator.py │ │ ├── evaluate_perceptualsim.py │ │ ├── frechet_video_distance.py │ │ ├── ssim.py │ │ └── torch_frechet_video_distance.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py ├── thirdp │ └── psp │ │ ├── helpers.py │ │ ├── id_loss.py │ │ └── model_irse.py └── util.py ├── main.py ├── meshutils.py ├── nerf ├── gui.py ├── network.py ├── network_grid.py ├── network_grid_taichi.py ├── network_grid_tcnn.py ├── provider.py ├── renderer.py └── utils.py ├── optimizer.py ├── preprocess_image.py ├── pretrained └── zero123 │ └── sd-objaverse-finetune-c_concat-256.yaml ├── raymarching ├── __init__.py ├── backend.py ├── raymarching.py ├── setup.py └── src │ ├── bindings.cpp │ ├── raymarching.cu │ └── raymarching.h ├── readme.md ├── requirements.txt ├── scripts ├── install_ext.sh ├── res64.args ├── run.sh ├── run2.sh ├── run3.sh ├── run4.sh ├── run5.sh ├── run6.sh ├── run_if.sh ├── run_if2.sh ├── run_if2_perpneg.sh ├── run_image.sh ├── run_image_anya.sh ├── run_image_hard_examples.sh ├── run_image_procedure.sh ├── run_image_text.sh └── run_images.sh ├── shencoder ├── __init__.py ├── backend.py ├── 
setup.py ├── sphere_harmonics.py └── src │ ├── bindings.cpp │ ├── shencoder.cu │ └── shencoder.h ├── taichi_modules ├── __init__.py ├── hash_encoder.py ├── intersection.py ├── ray_march.py ├── utils.py ├── volume_render_test.py └── volume_train.py └── tets ├── 128_tets.npz ├── 32_tets.npz ├── 64_tets.npz ├── README.md └── generate_tets.py /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report 3 | title: "" 4 | labels: ["bug"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Before filing a bug report, [search for an existing issue](https://github.com/ashawkey/stable-dreamfusion/issues). 10 | 11 | Also, ensure you are running the latest version. 12 | - type: textarea 13 | id: description 14 | attributes: 15 | label: Description 16 | description: Provide a clear and concise description of what the bug is. 17 | placeholder: Description 18 | validations: 19 | required: true 20 | - type: textarea 21 | id: steps 22 | attributes: 23 | label: Steps to Reproduce 24 | description: List the steps needed to reproduce the issue. 25 | placeholder: | 26 | 1. Go to '...' 27 | 2. Click on '...' 28 | validations: 29 | required: true 30 | - type: textarea 31 | id: expected-behavior 32 | attributes: 33 | label: Expected Behavior 34 | description: Describe what you expected to happen. 35 | placeholder: | 36 | The 'action' would do 'some amazing thing'. 37 | validations: 38 | required: true 39 | - type: textarea 40 | id: environment 41 | attributes: 42 | label: Environment 43 | description: Describe your environment. 44 | placeholder: | 45 | Ubuntu 22.04, PyTorch 1.13, CUDA 11.6 46 | validations: 47 | required: true 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | venv_*/ 6 | 7 | tmp* 8 | # data/ 9 | ldm/data/ 10 | data2 11 | scripts2 12 | trial*/ 13 | .vs/ 14 | 15 | TOKEN 16 | *.ckpt 17 | 18 | densegridencoder 19 | tets/256_tets.npz 20 | 21 | .vscode/launch.json 22 | 23 | data2 24 | data/car* 25 | data/chair* 26 | data/warrior* 27 | data/wd* 28 | data/space* 29 | data/corgi* 30 | data/turtle* 31 | 32 | # Only keep the original image, not the automatically-generated depth, normals, rgba 33 | data/baby_phoenix_on_ice_* 34 | data/bollywood_actress_* 35 | data/beach_house_1_* 36 | data/beach_house_2_* 37 | data/mona_lisa_* 38 | data/futuristic_car_* 39 | data/church_ruins_* 40 | 41 | -------------------------------------------------------------------------------- /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.cuda.amp import custom_bwd, custom_fwd 4 | 5 | class _trunc_exp(Function): 6 | @staticmethod 7 | @custom_fwd(cast_inputs=torch.float) 8 | def forward(ctx, x): 9 | ctx.save_for_backward(x) 10 | return torch.exp(x) 11 | 12 | @staticmethod 13 | @custom_bwd 14 | def backward(ctx, g): 15 | x = ctx.saved_tensors[0] 16 | return g * torch.exp(x.clamp(max=15)) 17 | 18 | trunc_exp = _trunc_exp.apply 19 | 20 | def biased_softplus(x, bias=0): 21 | return torch.nn.functional.softplus(x - bias) -------------------------------------------------------------------------------- /assets/advanced.md: -------------------------------------------------------------------------------- 1 | 2 | # Code organization & Advanced tips 3 | 4 | This is a simple description of the most important implementation details. 5 | If you are interested in improving this repo, this might be a starting point. 6 | Any contribution would be greatly appreciated! 7 | 8 | * The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`: 9 | ```python 10 | ## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE. 11 | pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) 12 | ## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen. 13 | latents = self.encode_imgs(pred_rgb_512) 14 | ... # timestep sampling, noise adding and UNet noise predicting 15 | ## 3. the SDS loss 16 | w = (1 - self.alphas[t]) 17 | grad = w * (noise_pred - noise) 18 | # since UNet part is ignored and cannot simply audodiff, we have two ways to set the grad: 19 | # 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later) 20 | latents.backward(gradient=grad, retain_graph=True) 21 | return 0 # dummy loss 22 | 23 | # 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng) 24 | class SpecifyGradient(torch.autograd.Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, input_tensor, gt_grad): 28 | ctx.save_for_backward(gt_grad) 29 | # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. 
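# (why the dummy value works: amp's scaler multiplies the returned 1 by its loss scale before calling
#  backward, so the grad_scale received in backward() below is exactly that scale factor, and
#  multiplying gt_grad by it keeps the custom gradient consistent with the rest of the amp-scaled graph)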
30 | return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) 31 | 32 | @staticmethod 33 | @custom_bwd 34 | def backward(ctx, grad_scale): 35 | gt_grad, = ctx.saved_tensors 36 | gt_grad = gt_grad * grad_scale 37 | return gt_grad, None 38 | 39 | loss = SpecifyGradient.apply(latents, grad) 40 | return loss # functional loss 41 | 42 | # 3.3. reparameterization (credits to @Xallt) 43 | # d(loss)/d(latents) = grad, since grad is already detached, it's this simple. 44 | loss = (grad * latents).sum() 45 | return loss 46 | 47 | # 3.4. reparameterization (credits to threestudio) 48 | # this is the same as 3.3, but the loss value only reflects the magnitude of grad, which is more informative. 49 | targets = (latents - grad).detach() 50 | loss = 0.5 * F.mse_loss(latents, targets, reduction='sum') 51 | return loss 52 | ``` 53 | * Other regularizations are in `./nerf/utils.py > Trainer > train_step`. 54 | * The generation seems quite sensitive to regularizations on weights_sum (alphas for each ray). The original opacity loss tends to make NeRF disappear (zero density everywhere), so we use an entropy loss to replace it for now (encourages alpha to be either 0 or 1). 55 | * NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`. 56 | * Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`. 57 | * light direction: current implementation use a plane light source, instead of a point light source. 58 | * View-dependent prompting: `./nerf/provider.py > get_view_direction`. 59 | * use `--angle_overhead, --angle_front` to set the border. 60 | * Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option. 61 | * Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`. 62 | 63 | 64 | # Debugging 65 | 66 | `debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one: 67 | 68 | ```bash 69 | python main.py --text "a hamburger" --workspace trial -O --vram_O 70 | ``` 71 | 72 | ... with: 73 | 74 | ```bash 75 | debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O 76 | ``` 77 | 78 | For more details: https://github.com/bulletmark/debugpy-run 79 | 80 | # Axes and directions of polar, azimuth, etc. in NeRF and Zero123 81 | 82 | <img width="1119" alt="NeRF_Zero123" src="https://github.com/ashawkey/stable-dreamfusion/assets/22424247/a0f432ff-2d08-45a4-a390-bda64f5cbc94"> 83 | 84 | This code refers to theta for polar, phi for azimuth. 85 | 86 | -------------------------------------------------------------------------------- /assets/update_logs.md: -------------------------------------------------------------------------------- 1 | ### 2023.4.19 2 | * Fix depth supervision, migrate depth estimation model to omnidata. 3 | * Add normal supervision (also by omnidata). 4 | 5 | https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4 6 | 7 | ### 2023.4.7 8 | Improvement on mesh quality & DMTet finetuning support. 9 | 10 | https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4 11 | 12 | ### 2023.3.30 13 | * adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate normal and mask as the latent code in a warm up stage, which shows faster convergence of shape. 
14 | 15 | https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4 16 | 17 | ### 2023.1.30 18 | * Use an MLP to predict the surface normals as in Magic3D to avoid finite difference / second order gradient, generation quality is greatly improved. 19 | * More efficient two-pass raymarching in training inspired by nerfacc. 20 | 21 | https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4 22 | 23 | ### 2022.12.3 24 | * Support Stable-diffusion 2.0 base. 25 | 26 | ### 2022.11.15 27 | * Add the vanilla backbone that is pure-pytorch. 28 | 29 | ### 2022.10.9 30 | * The shading (partially) starts to work, at least it won't make scene empty. For some prompts, it shows better results (less severe Janus problem). The textureless rendering mode is still disabled. 31 | * Enable shading by default (--latent_iter_ratio 1000). 32 | 33 | ### 2022.10.5 34 | * Basic reproduction finished. 35 | * Non --cuda_ray, --tcnn are not working, need to fix. 36 | * Shading is not working, disabled in utils.py for now. Surface normals are bad. 37 | * Use an entropy loss to regularize weights_sum (alpha), the original L2 reg always leads to degenerated geometry... 38 | 39 | https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4 40 | -------------------------------------------------------------------------------- /config/anya.csv: -------------------------------------------------------------------------------- 1 | zero123_weight, radius, polar, azimuth, image 2 | 1, 3, 90, 0, data/anya_front_rgba.png 3 | 1, 3, 90, 180, data/anya_back_rgba.png -------------------------------------------------------------------------------- /config/car.csv: -------------------------------------------------------------------------------- 1 | zero123_weight, radius, polar, azimuth, image 2 | 4, 3.2, 90, 0, data/car_left_rgba.png 3 | 1, 3, 90, 90, data/car_front_rgba.png 4 | 4, 3.2, 90, 180, data/car_right_rgba.png 5 | 1, 3, 90, -90, data/car_back_rgba.png -------------------------------------------------------------------------------- /config/corgi.csv: -------------------------------------------------------------------------------- 1 | zero123_weight, radius, polar, azimuth, image 2 | 1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png -------------------------------------------------------------------------------- /data/anya_back.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back.webp -------------------------------------------------------------------------------- /data/anya_back_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back_depth.png -------------------------------------------------------------------------------- /data/anya_back_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back_normal.png -------------------------------------------------------------------------------- /data/anya_back_rgba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_back_rgba.png -------------------------------------------------------------------------------- /data/anya_front.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front.jpg -------------------------------------------------------------------------------- /data/anya_front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front.png -------------------------------------------------------------------------------- /data/anya_front_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front_depth.png -------------------------------------------------------------------------------- /data/anya_front_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front_normal.png -------------------------------------------------------------------------------- /data/anya_front_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/anya_front_rgba.png -------------------------------------------------------------------------------- /data/baby_phoenix_on_ice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/baby_phoenix_on_ice.png -------------------------------------------------------------------------------- /data/beach_house_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/beach_house_1.png -------------------------------------------------------------------------------- /data/beach_house_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/beach_house_2.png -------------------------------------------------------------------------------- /data/bollywood_actress.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/bollywood_actress.png -------------------------------------------------------------------------------- /data/cactus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus.png -------------------------------------------------------------------------------- /data/cactus_depth.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus_depth.png -------------------------------------------------------------------------------- /data/cactus_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus_normal.png -------------------------------------------------------------------------------- /data/cactus_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cactus_rgba.png -------------------------------------------------------------------------------- /data/cake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake.png -------------------------------------------------------------------------------- /data/cake_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake_depth.png -------------------------------------------------------------------------------- /data/cake_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake_normal.png -------------------------------------------------------------------------------- /data/cake_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/cake_rgba.png -------------------------------------------------------------------------------- /data/car_back.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_back.jpg -------------------------------------------------------------------------------- /data/car_front.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_front.jpg -------------------------------------------------------------------------------- /data/car_left.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_left.jpg -------------------------------------------------------------------------------- /data/car_right.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/car_right.jpg -------------------------------------------------------------------------------- /data/catstatue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue.png 
-------------------------------------------------------------------------------- /data/catstatue_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue_depth.png -------------------------------------------------------------------------------- /data/catstatue_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue_normal.png -------------------------------------------------------------------------------- /data/catstatue_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/catstatue_rgba.png -------------------------------------------------------------------------------- /data/church_ruins.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/church_ruins.png -------------------------------------------------------------------------------- /data/corgi_puppy_sitting_looking_up.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/corgi_puppy_sitting_looking_up.jpg -------------------------------------------------------------------------------- /data/firekeeper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper.jpg -------------------------------------------------------------------------------- /data/firekeeper_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper_depth.png -------------------------------------------------------------------------------- /data/firekeeper_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper_normal.png -------------------------------------------------------------------------------- /data/firekeeper_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/firekeeper_rgba.png -------------------------------------------------------------------------------- /data/futuristic_car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/futuristic_car.png -------------------------------------------------------------------------------- /data/hamburger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger.png 
-------------------------------------------------------------------------------- /data/hamburger_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger_depth.png -------------------------------------------------------------------------------- /data/hamburger_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger_normal.png -------------------------------------------------------------------------------- /data/hamburger_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/hamburger_rgba.png -------------------------------------------------------------------------------- /data/mona_lisa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/mona_lisa.png -------------------------------------------------------------------------------- /data/teddy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy.png -------------------------------------------------------------------------------- /data/teddy_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy_depth.png -------------------------------------------------------------------------------- /data/teddy_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy_normal.png -------------------------------------------------------------------------------- /data/teddy_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/data/teddy_rgba.png -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04 2 | 3 | # Remove any third-party apt sources to avoid issues with expiring keys. 
4 | RUN rm -f /etc/apt/sources.list.d/*.list 5 | 6 | RUN apt-get update 7 | 8 | RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/MADRID apt-get install -y tzdata 9 | 10 | # Install some basic utilities 11 | RUN apt-get install -y \ 12 | curl \ 13 | ca-certificates \ 14 | sudo \ 15 | git \ 16 | bzip2 \ 17 | libx11-6 \ 18 | python3 \ 19 | python3-pip \ 20 | libglfw3-dev \ 21 | libgles2-mesa-dev \ 22 | libglib2.0-0 \ 23 | && rm -rf /var/lib/apt/lists/* 24 | 25 | 26 | # Create a working directory 27 | RUN mkdir /app 28 | WORKDIR /app 29 | 30 | RUN cd /app 31 | RUN git clone https://github.com/ashawkey/stable-dreamfusion.git 32 | 33 | 34 | RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 35 | 36 | WORKDIR /app/stable-dreamfusion 37 | 38 | RUN pip3 install -r requirements.txt 39 | RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/ 40 | 41 | # Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer 42 | RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 43 | 44 | RUN pip3 install git+https://github.com/openai/CLIP.git 45 | RUN bash scripts/install_ext.sh 46 | 47 | 48 | 49 | 50 | 51 | # Set the default command to python3 52 | #CMD ["python3"] 53 | 54 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ### Docker installation 2 | 3 | ## Build image 4 | To build the docker image on your own machine, which may take 15-30 mins: 5 | ``` 6 | docker build -t stable-dreamfusion:latest . 7 | ``` 8 | 9 | If you have the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn you need to setup the nvidia-runtime for docker. 10 | ``` 11 | sudo apt-get install nvidia-container-runtime 12 | ``` 13 | Then edit `/etc/docker/daemon.json` and add the default-runtime: 14 | ``` 15 | { 16 | "runtimes": { 17 | "nvidia": { 18 | "path": "nvidia-container-runtime", 19 | "runtimeArgs": [] 20 | } 21 | }, 22 | "default-runtime": "nvidia" 23 | } 24 | ``` 25 | And restart docker: 26 | ``` 27 | sudo systemctl restart docker 28 | ``` 29 | Now you can build tiny-cuda-nn inside docker. 30 | 31 | ## Download image 32 | To download the image (~6GB) instead: 33 | ``` 34 | docker pull supercabb/stable-dreamfusion:3080_0.0.1 35 | docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion 36 | ``` 37 | 38 | ## Use image 39 | 40 | You can launch an interactive shell inside the container: 41 | 42 | ``` 43 | docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash 44 | ``` 45 | From this shell, all the code in the repo should work. 
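A quick sanity check from inside that shell is to confirm that PyTorch can see the GPU before launching a training run (this assumes the container was started with `--gpus all`, as in the `docker run` command above):

```python
# Run inside the container, e.g. by starting `python3` from the interactive shell.
import torch

print(torch.__version__)          # the image installs the cu116 wheels, so expect a "+cu116" build
print(torch.cuda.is_available())  # should print True when the NVIDIA runtime is set up correctly
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the visible GPU
```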
46 | 47 | To run any single command `<command...>` inside the docker container: 48 | ``` 49 | docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command...>" 50 | ``` 51 | To train: 52 | ``` 53 | export TOKEN="#HUGGING FACE ACCESS TOKEN#" 54 | docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \ 55 | && python3 main.py --text \"a hamburger\" --workspace trial -O" 56 | 57 | ``` 58 | Run test without gui: 59 | ``` 60 | export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#" 61 | docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \ 62 | -v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \ 63 | main.py --workspace trial -O --test" 64 | ``` 65 | Run test with gui: 66 | ``` 67 | export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#" 68 | xhost + 69 | docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \ 70 | -v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \ 71 | main.py --workspace trial -O --test --gui" 72 | xhost - 73 | ``` 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FreqEncoder_torch(nn.Module): 6 | def __init__(self, input_dim, max_freq_log2, N_freqs, 7 | log_sampling=True, include_input=True, 8 | periodic_fns=(torch.sin, torch.cos)): 9 | 10 | super().__init__() 11 | 12 | self.input_dim = input_dim 13 | self.include_input = include_input 14 | self.periodic_fns = periodic_fns 15 | self.N_freqs = N_freqs 16 | 17 | self.output_dim = 0 18 | if self.include_input: 19 | self.output_dim += self.input_dim 20 | 21 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 22 | 23 | if log_sampling: 24 | self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs) 25 | else: 26 | self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs) 27 | 28 | self.freq_bands = self.freq_bands.numpy().tolist() 29 | 30 | def forward(self, input, max_level=None, **kwargs): 31 | 32 | if max_level is None: 33 | max_level = self.N_freqs 34 | else: 35 | max_level = int(max_level * self.N_freqs) 36 | 37 | out = [] 38 | if self.include_input: 39 | out.append(input) 40 | 41 | for i in range(max_level): 42 | freq = self.freq_bands[i] 43 | for p_fn in self.periodic_fns: 44 | out.append(p_fn(input * freq)) 45 | 46 | # append 0 47 | if self.N_freqs - max_level > 0: 48 | out.append(torch.zeros(*input.shape[:-1], (self.N_freqs - max_level) * 2 * input.shape[-1], device=input.device, dtype=input.dtype)) 49 | 50 | out = torch.cat(out, dim=-1) 51 | 52 | return out 53 | 54 | def get_encoder(encoding, input_dim=3, 55 | multires=6, 56 | degree=4, 57 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear', 58 | **kwargs): 59 | 60 | if encoding == 'None': 61 | return lambda x, **kwargs: x, input_dim 62 | 63 | elif encoding == 'frequency_torch': 64 | encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 65 | 66 | elif encoding == 'frequency': # CUDA implementation, faster than torch. 
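# note: the import is kept local to this branch, so the CUDA extension (JIT-built by
# torch.utils.cpp_extension.load in freqencoder/backend.py) is only compiled/loaded when
# this encoding is actually selected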
67 | from freqencoder import FreqEncoder 68 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 69 | 70 | elif encoding == 'sphere_harmonics': 71 | from shencoder import SHEncoder 72 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 73 | 74 | elif encoding == 'hashgrid': 75 | from gridencoder import GridEncoder 76 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation) 77 | 78 | elif encoding == 'tiledgrid': 79 | from gridencoder import GridEncoder 80 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation) 81 | 82 | elif encoding == 'hashgrid_taichi': 83 | from taichi_modules.hash_encoder import HashEncoderTaichi 84 | encoder = HashEncoderTaichi(batch_size=4096) #TODO: hard encoded batch size 85 | 86 | else: 87 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') 88 | 89 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /evaluation/Prompt.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification 3 | from transformers import pipeline 4 | import argparse 5 | import sys 6 | import warnings 7 | warnings.filterwarnings("ignore", category=UserWarning) 8 | 9 | 10 | #python Prompt.py --text "a dog is in front of a rabbit" --model vlt5 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | # Mimic the calling part of the main, using 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--text', default="", type=str, help="text prompt") 18 | #parser.add_argument('--workspace', default="trial", type=str, help="workspace") 19 | parser.add_argument('--model', default='vlt5', type=str, help="model choices - vlt5, bert, XLNet") 20 | 21 | opt = parser.parse_args() 22 | 23 | if opt.model == "vlt5": 24 | tokenizer = AutoTokenizer.from_pretrained("Voicelab/vlt5-base-keywords") 25 | model = AutoModelForSeq2SeqLM.from_pretrained("Voicelab/vlt5-base-keywords") 26 | 27 | task_prefix = "Keywords: " 28 | inputs = [ 29 | opt.text 30 | ] 31 | 32 | for sample in inputs: 33 | input_sequences = [task_prefix + sample] 34 | input_ids = tokenizer( 35 | input_sequences, return_tensors="pt", truncation=True 36 | ).input_ids 37 | output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4) 38 | output_text = tokenizer.decode(output[0], skip_special_tokens=True) 39 | #print(sample, "\n --->", output_text) 40 | 41 | elif opt.model == "bert": 42 | tokenizer = AutoTokenizer.from_pretrained("yanekyuk/bert-uncased-keyword-extractor") 43 | model = AutoModelForTokenClassification.from_pretrained("yanekyuk/bert-uncased-keyword-extractor") 44 | 45 | text = opt.text 46 | input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt") 47 | 48 | # Classify tokens 49 | outputs = model(input_ids) 50 | predictions = outputs.logits.detach().numpy()[0] 51 | labels = predictions.argmax(axis=1) 52 | labels = labels[1:-1] 53 | 54 | print(labels) 55 | tokens = 
tokenizer.convert_ids_to_tokens(input_ids[0]) 56 | tokens = tokens[1:-1] 57 | output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0] 58 | output_text = tokenizer.convert_tokens_to_string(output_tokens) 59 | 60 | #print(output_text) 61 | 62 | 63 | elif opt.model == "XLNet": 64 | tokenizer = AutoTokenizer.from_pretrained("jasminejwebb/KeywordIdentifier") 65 | model = AutoModelForTokenClassification.from_pretrained("jasminejwebb/KeywordIdentifier") 66 | 67 | text = opt.text 68 | input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt") 69 | 70 | # Classify tokens 71 | outputs = model(input_ids) 72 | predictions = outputs.logits.detach().numpy()[0] 73 | labels = predictions.argmax(axis=1) 74 | labels = labels[1:-1] 75 | 76 | print(labels) 77 | tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) 78 | tokens = tokens[1:-1] 79 | output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0] 80 | output_text = tokenizer.convert_tokens_to_string(output_tokens) 81 | 82 | #print(output_text) 83 | 84 | wrapped_text = textwrap.fill(output_text, width=50) 85 | 86 | 87 | print('+' + '-'*52 + '+') 88 | for line in wrapped_text.split('\n'): 89 | print('| {} |'.format(line.ljust(50))) 90 | print('+' + '-'*52 + '+') 91 | #print(result) 92 | -------------------------------------------------------------------------------- /evaluation/mesh_to_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import trimesh 4 | import argparse 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | import pyvista as pv 8 | 9 | def render_video(anim_mesh): 10 | center = anim_mesh.center_mass 11 | plotter = pv.Plotter(off_screen=True) 12 | plotter.add_mesh(anim_mesh) 13 | 14 | radius = 10 15 | n_frames = 360 16 | angle_step = 2 * np.pi / n_frames 17 | for i in tqdm(range(n_frames)): 18 | camera_pos = [center[0] + radius * np.cos(i*angle_step),center[1] + radius *np.sin(i*angle_step),center[2]] 19 | plotter.camera_position = (camera_pos, center, (0, 0, 1)) 20 | plotter.show(screenshot=f'frame_{i}.png', auto_close=False) 21 | plotter.close() 22 | os.system('ffmpeg -r 30 -f image2 -s 1920x1080 -i "result/frame_%d.png" -vcodec libx264 -crf 25 -pix_fmt yuv420p result/output.mp4') 23 | 24 | 25 | 26 | def generate_mesh(obj1,obj2,transform_vector): 27 | 28 | # Read 2 objects 29 | filename1 = obj1 # Central Object 30 | filename2 = obj2 # Surrounding Object 31 | mesh1 = trimesh.load_mesh(filename1) 32 | mesh2 = trimesh.load_mesh(filename2) 33 | 34 | extents1 = mesh1.extents 35 | extents2 = mesh1.extents 36 | 37 | radius1 = sum(extents1) / 3.0 38 | radius2 = sum(extents2) / 3.0 39 | 40 | center1 = mesh1.center_mass 41 | center2 = mesh2.center_mass 42 | 43 | # Move 44 | T1 = -center1 45 | new =[] 46 | for i in transform_vector: 47 | try: 48 | new.append(float(i))*radius1 49 | except: 50 | pass 51 | transform_vector = new 52 | print(T1, transform_vector, radius1) 53 | T2 = -center2 + transform_vector 54 | 55 | # Transform 56 | mesh1.apply_translation(T1) 57 | mesh2.apply_translation(T2) 58 | 59 | # merge mesh 60 | merged_mesh = trimesh.util.concatenate((mesh1, mesh2)) 61 | 62 | # save mesh 63 | merged_mesh.export('merged_mesh.obj') 64 | print("----> merge mesh done") 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser(description='Generate rotating mesh animation.') 68 | parser.add_argument('--center_obj', type=str, help='Input OBJ1 file.') 69 | parser.add_argument('--surround_obj', 
type=str, help='Input OBJ2 file.') 70 | parser.add_argument('--transform_vector', help='Transform_vector.') 71 | parser.add_argument('--output_file', type=str, default="result/Demo.mp4", help='Output MP4 file.') 72 | parser.add_argument('--num_frames', type=int, default=100, help='Number of frames to render.') 73 | args = parser.parse_args() 74 | 75 | #mesh = obj.Obj("wr.obj") 76 | generate_mesh(args.center_obj,args.surround_obj,args.transform_vector) 77 | 78 | input_file = Path("merged_mesh.obj") 79 | output_file = Path(args.output_file) 80 | 81 | out_dir = output_file.parent.joinpath('frames') 82 | out_dir.mkdir(parents=True, exist_ok=True) 83 | 84 | anim_mesh = trimesh.load_mesh(str(input_file)) 85 | 86 | render_video(anim_mesh) 87 | 88 | -------------------------------------------------------------------------------- /evaluation/r_precision.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer, util 2 | from PIL import Image 3 | import argparse 4 | import sys 5 | 6 | 7 | if __name__ == '__main__': 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--text', default="", type=str, help="text prompt") 11 | parser.add_argument('--workspace', default="trial", type=str, help="text prompt") 12 | parser.add_argument('--latest', default='ep0001', type=str, help="which epoch result you want to use for image path") 13 | parser.add_argument('--mode', default='rgb', type=str, help="mode of result, color(rgb) or textureless()") 14 | parser.add_argument('--clip', default="clip-ViT-B-32", type=str, help="CLIP model to encode the img and prompt") 15 | 16 | opt = parser.parse_args() 17 | 18 | #Load CLIP model 19 | model = SentenceTransformer(f'{opt.clip}') 20 | 21 | #Encode an image: 22 | img_emb = model.encode(Image.open(f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png')) 23 | 24 | #Encode text descriptions 25 | text_emb = model.encode([f'{opt.text}']) 26 | 27 | #Compute cosine similarities 28 | cos_scores = util.cos_sim(img_emb, text_emb) 29 | print("The final CLIP R-Precision is:", cos_scores[0][0].cpu().numpy()) 30 | 31 | -------------------------------------------------------------------------------- /evaluation/readme.md: -------------------------------------------------------------------------------- 1 | ### Improvement: 2 | 3 | - Usage 4 | 5 | - r_precision.py <br> 6 | For prompt seperation <br> 7 | --text is for the prompt following the author of stable dream fusion <br> 8 | --workspace is the workspace folder which will be created for every prompt fed into stable dreamfusion <br> 9 | --latest is which ckpt is used. Stable dream fusion record every epoch data. Normally is ep0100 unless the training is not finished or we further extend the training <br> 10 | --mode has choices of rgb and depth which is correspondent to color and texture result as original paper Figure 5: Qualitative comparison with baselines. 
<br> 11 | --clip has choices of clip-ViT-B-32, CLIP B/16, CLIP L/14, same as original paper <br> 12 | 13 | ```bash 14 | python Prompt.py --text "matte painting of a castle made of cheesecake surrounded by a moat made of ice cream" --workspace ../castle --latest ep0100 --mode rgb --clip clip-ViT-B-32 15 | ``` 16 | 17 | - Prompt.py (model name case sensitive) <br> 18 | For prompt seperation <br> <br> 19 | --text is for the prompt following the author of stable dream fusion <br> 20 | --model is for choose the pretrain models <br> 21 | 22 | ```bash 23 | python Prompt.py --text "a dog is in front of a rabbit" --model vlt5 24 | python Prompt.py --text "a dog is in front of a rabbit" --model bert 25 | python Prompt.py --text "a dog is in front of a rabbit" --model XLNet 26 | ``` 27 | 28 | 29 | - mesh_to_video.py <br> 30 | --center_obj IS THE CENTER OBJECT <br> 31 | --surround_obj IS THE SURROUNDING OBJECT SUBJECT TO CHANGE <br> 32 | --transform_vector THE X Y Z 3d vector for transform <br> 33 | 34 | ```bash 35 | python mesh_to_video.py --center_obj 'mesh_whiterabbit/mesh.obj' --surround_obj 'mesh_snake/mesh.obj' --transform_vector [1,0,0] 36 | ``` 37 | -------------------------------------------------------------------------------- /freqencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .freq import FreqEncoder -------------------------------------------------------------------------------- /freqencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | '-use_fast_math' 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 
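# (this cl.exe lookup is only relevant for Windows builds; on Linux the extension below is
#  JIT-compiled with the system compiler)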
27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | _backend = load(name='_freqencoder', 34 | extra_cflags=c_flags, 35 | extra_cuda_cflags=nvcc_flags, 36 | sources=[os.path.join(_src_path, 'src', f) for f in [ 37 | 'freqencoder.cu', 38 | 'bindings.cpp', 39 | ]], 40 | ) 41 | 42 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /freqencoder/freq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _freqencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | class _freq_encoder(Function): 16 | @staticmethod 17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 18 | def forward(ctx, inputs, degree, output_dim): 19 | # inputs: [B, input_dim], float 20 | # RETURN: [B, F], float 21 | 22 | if not inputs.is_cuda: inputs = inputs.cuda() 23 | inputs = inputs.contiguous() 24 | 25 | B, input_dim = inputs.shape # batch size, coord dim 26 | 27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 28 | 29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 30 | 31 | ctx.save_for_backward(inputs, outputs) 32 | ctx.dims = [B, input_dim, degree, output_dim] 33 | 34 | return outputs 35 | 36 | @staticmethod 37 | #@once_differentiable 38 | @custom_bwd 39 | def backward(ctx, grad): 40 | # grad: [B, C * C] 41 | 42 | grad = grad.contiguous() 43 | inputs, outputs = ctx.saved_tensors 44 | B, input_dim, degree, output_dim = ctx.dims 45 | 46 | grad_inputs = torch.zeros_like(inputs) 47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 48 | 49 | return grad_inputs, None, None 50 | 51 | 52 | freq_encode = _freq_encoder.apply 53 | 54 | 55 | class FreqEncoder(nn.Module): 56 | def __init__(self, input_dim=3, degree=4): 57 | super().__init__() 58 | 59 | self.input_dim = input_dim 60 | self.degree = degree 61 | self.output_dim = input_dim + input_dim * 2 * degree 62 | 63 | def __repr__(self): 64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" 65 | 66 | def forward(self, inputs, **kwargs): 67 | # inputs: [..., input_dim] 68 | # return: [..., ] 69 | 70 | prefix_shape = list(inputs.shape[:-1]) 71 | inputs = inputs.reshape(-1, self.input_dim) 72 | 73 | outputs = freq_encode(inputs, self.degree, self.output_dim) 74 | 75 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 76 | 77 | return outputs -------------------------------------------------------------------------------- /freqencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | '-use_fast_math' 11 | ] 12 | 13 | if os.name == "posix": 14 | c_flags = 
['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 22 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 23 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 24 | if paths: 25 | return paths[0] 26 | 27 | # If cl.exe is not on path, try to find it. 28 | if os.system("where cl.exe >nul 2>nul") != 0: 29 | cl_path = find_cl_path() 30 | if cl_path is None: 31 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 32 | os.environ["PATH"] += ";" + cl_path 33 | 34 | setup( 35 | name='freqencoder', # package name, import this to use python API 36 | ext_modules=[ 37 | CUDAExtension( 38 | name='_freqencoder', # extension name, import this to use CUDA API 39 | sources=[os.path.join(_src_path, 'src', f) for f in [ 40 | 'freqencoder.cu', 41 | 'bindings.cpp', 42 | ]], 43 | extra_compile_args={ 44 | 'cxx': c_flags, 45 | 'nvcc': nvcc_flags, 46 | } 47 | ), 48 | ], 49 | cmdclass={ 50 | 'build_ext': BuildExtension, 51 | } 52 | ) -------------------------------------------------------------------------------- /freqencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "freqencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); 7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.cu: -------------------------------------------------------------------------------- 1 | #include <stdint.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_fp16.h> 5 | #include <cuda_runtime.h> 6 | 7 | #include <ATen/cuda/CUDAContext.h> 8 | #include <torch/torch.h> 9 | 10 | #include <algorithm> 11 | #include <stdexcept> 12 | 13 | #include <cstdio> 14 | 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 20 | 21 | inline constexpr __device__ float PI() { return 3.141592653589793f; } 22 | 23 | template <typename T> 24 | __host__ __device__ T div_round_up(T val, T divisor) { 25 | return (val + divisor - 1) / divisor; 26 | } 27 | 28 | // inputs: [B, D] 29 | // outputs: [B, C], C = D + D * deg * 2 30 | __global__ void kernel_freq( 31 | const float * __restrict__ inputs, 32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 33 | float * outputs 34 | ) { 35 | // parallel on per-element 36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 37 | if (t >= B * C) return; 38 | 39 | // get index 40 | const uint32_t b = t / C; 41 | const uint32_t c = t - b * C; // t % C; 42 | 43 | // locate 44 | inputs += b * D; 45 | outputs += t; 46 | 47 | // write self 48 | if (c < D) { 49 | outputs[0] = inputs[c]; 50 | // write freq 51 
| } else { 52 | const uint32_t col = c / D - 1; 53 | const uint32_t d = c % D; 54 | const uint32_t freq = col / 2; 55 | const float phase_shift = (col % 2) * (PI() / 2); 56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); 57 | } 58 | } 59 | 60 | // grad: [B, C], C = D + D * deg * 2 61 | // outputs: [B, C] 62 | // grad_inputs: [B, D] 63 | __global__ void kernel_freq_backward( 64 | const float * __restrict__ grad, 65 | const float * __restrict__ outputs, 66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 67 | float * grad_inputs 68 | ) { 69 | // parallel on per-element 70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 71 | if (t >= B * D) return; 72 | 73 | const uint32_t b = t / D; 74 | const uint32_t d = t - b * D; // t % D; 75 | 76 | // locate 77 | grad += b * C; 78 | outputs += b * C; 79 | grad_inputs += t; 80 | 81 | // register 82 | float result = grad[d]; 83 | grad += D; 84 | outputs += D; 85 | 86 | for (uint32_t f = 0; f < deg; f++) { 87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); 88 | grad += 2 * D; 89 | outputs += 2 * D; 90 | } 91 | 92 | // write 93 | grad_inputs[0] = result; 94 | } 95 | 96 | 97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { 98 | CHECK_CUDA(inputs); 99 | CHECK_CUDA(outputs); 100 | 101 | CHECK_CONTIGUOUS(inputs); 102 | CHECK_CONTIGUOUS(outputs); 103 | 104 | CHECK_IS_FLOATING(inputs); 105 | CHECK_IS_FLOATING(outputs); 106 | 107 | static constexpr uint32_t N_THREADS = 128; 108 | 109 | kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>()); 110 | } 111 | 112 | 113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { 114 | CHECK_CUDA(grad); 115 | CHECK_CUDA(outputs); 116 | CHECK_CUDA(grad_inputs); 117 | 118 | CHECK_CONTIGUOUS(grad); 119 | CHECK_CONTIGUOUS(outputs); 120 | CHECK_CONTIGUOUS(grad_inputs); 121 | 122 | CHECK_IS_FLOATING(grad); 123 | CHECK_IS_FLOATING(outputs); 124 | CHECK_IS_FLOATING(grad_inputs); 125 | 126 | static constexpr uint32_t N_THREADS = 128; 127 | 128 | kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>()); 129 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); 8 | 9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /gridencoder/backend.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 22 | if paths: 23 | return paths[0] 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 
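# Building the extension on Windows requires MSVC's host compiler (cl.exe) to be reachable
# on PATH so nvcc can invoke it. The check below shells out via `where cl.exe`; if that
# fails, find_cl_path() globs the standard Visual Studio install locations and returns the
# lexicographically greatest match (reverse=True), which favors newer MSVC versions, and
# that directory is appended to PATH.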
27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | setup( 34 | name='gridencoder', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='_gridencoder', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | 'gridencoder.cu', 40 | 'bindings.cpp', 41 | ]], 42 | extra_compile_args={ 43 | 'cxx': c_flags, 44 | 'nvcc': nvcc_flags, 45 | } 46 | ), 47 | ], 48 | cmdclass={ 49 | 'build_ext': BuildExtension, 50 | } 51 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)"); 10 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include <stdint.h> 5 | #include <torch/torch.h> 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L); 17 | 18 | #endif -------------------------------------------------------------------------------- /guidance/clip_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torchvision.transforms as T 5 | import torchvision.transforms.functional as TF 6 | 7 | import clip 8 | 9 | class CLIP(nn.Module): 10 | def 
__init__(self, device, **kwargs): 11 | super().__init__() 12 | 13 | self.device = device 14 | self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False) 15 | 16 | self.aug = T.Compose([ 17 | T.Resize((224, 224)), 18 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 19 | ]) 20 | 21 | def get_text_embeds(self, prompt, **kwargs): 22 | 23 | text = clip.tokenize(prompt).to(self.device) 24 | text_z = self.clip_model.encode_text(text) 25 | text_z = text_z / text_z.norm(dim=-1, keepdim=True) 26 | 27 | return text_z 28 | 29 | def get_img_embeds(self, image, **kwargs): 30 | 31 | image_z = self.clip_model.encode_image(self.aug(image)) 32 | image_z = image_z / image_z.norm(dim=-1, keepdim=True) 33 | 34 | return image_z 35 | 36 | 37 | def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs): 38 | """ 39 | Args: 40 | grad_scale: scalar or 1-tensor of size [B], i.e. 1 grad_scale per batch item. 41 | """ 42 | # TODO: resize the image from NeRF-rendered resolution (e.g. 128x128) to what CLIP expects (512x512), to prevent Pytorch warning about `antialias=None`. 43 | image_z = self.clip_model.encode_image(self.aug(pred_rgb)) 44 | image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features 45 | 46 | loss = 0 47 | if 'image' in clip_z: 48 | loss -= ((image_z * clip_z['image']).sum(-1) * grad_scale).mean() 49 | 50 | if 'text' in clip_z: 51 | loss -= ((image_z * clip_z['text']).sum(-1) * grad_scale).mean() 52 | 53 | return loss 54 | 55 | -------------------------------------------------------------------------------- /guidance/perpneg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm 4 | def get_perpendicular_component(x, y): 5 | assert x.shape == y.shape 6 | return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y 7 | 8 | 9 | def batch_get_perpendicular_component(x, y): 10 | assert x.shape == y.shape 11 | result = [] 12 | for i in range(x.shape[0]): 13 | result.append(get_perpendicular_component(x[i], y[i])) 14 | return torch.stack(result) 15 | 16 | 17 | def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size): 18 | """ 19 | Notes: 20 | - weights: an array with the weights for combining the noise predictions 21 | - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir 22 | """ 23 | delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64] 24 | weights = weights.split(batch_size, dim=0) # K x [B] 25 | # print(f"{weights[0].shape = } {weights = }") 26 | 27 | assert torch.all(weights[0] == 1.0) 28 | 29 | main_positive = delta_noise_preds[0] # [B, 4, 64, 64] 30 | 31 | accumulated_output = torch.zeros_like(main_positive) 32 | for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1): 33 | # print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n") 34 | 35 | idx_non_zero = torch.abs(weights[i]) > 1e-4 36 | 37 | # print(f"{idx_non_zero.shape = }, {idx_non_zero = }") 38 | # print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }") 39 | # print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }") 40 | # print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }") 41 | if sum(idx_non_zero) == 0: 42 | continue 43 | accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * 
batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero]) 44 | 45 | assert accumulated_output.shape == main_positive.shape, f"{accumulated_output.shape = }, {main_positive.shape = }" 46 | 47 | 48 | return accumulated_output + main_positive -------------------------------------------------------------------------------- /ldm/extras.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from ldm.util import instantiate_from_config 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | from contextlib import contextmanager 9 | import logging 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | def load_training_dir(train_dir, device, epoch="last"): 38 | """Load a checkpoint and config from training directory""" 39 | train_dir = Path(train_dir) 40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 42 | config = list(train_dir.rglob(f"*-project.yaml")) 43 | assert len(ckpt) > 0, f"didn't find any config in {train_dir}" 44 | if len(config) > 1: 45 | print(f"found {len(config)} matching config files") 46 | config = sorted(config)[-1] 47 | print(f"selecting {config}") 48 | else: 49 | config = config[0] 50 | 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 56 | """Loads a model from config and a ckpt 57 | if config is a path will use omegaconf to load 58 | """ 59 | if isinstance(config, (str, Path)): 60 | config = OmegaConf.load(config) 61 | 62 | with all_logging_disabled(): 63 | print(f"Loading model from {ckpt}") 64 | pl_sd = torch.load(ckpt, map_location="cpu") 65 | global_step = pl_sd["global_step"] 66 | sd = pl_sd["state_dict"] 67 | model = instantiate_from_config(config.model) 68 | m, u = model.load_state_dict(sd, strict=False) 69 | if len(m) > 0 and verbose: 70 | print("missing keys:") 71 | print(m) 72 | if len(u) > 0 and verbose: 73 | print("unexpected keys:") 74 | model.to(device) 75 | model.eval() 76 | model.cond_stage_model.device = device 77 | return model -------------------------------------------------------------------------------- /ldm/guidance.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from scipy import interpolate 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import abc 8 | 9 | 10 | class GuideModel(torch.nn.Module, abc.ABC): 11 | def __init__(self) -> None: 12 | 
super().__init__() 13 | 14 | @abc.abstractmethod 15 | def preprocess(self, x_img): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def compute_loss(self, inp): 20 | pass 21 | 22 | 23 | class Guider(torch.nn.Module): 24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 25 | """Apply classifier guidance 26 | 27 | Specify a guidance scale as either a scalar 28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 29 | [(0, 10), (0.5, 20), (1, 50)] 30 | """ 31 | super().__init__() 32 | self.sampler = sampler 33 | self.index = 0 34 | self.show = verbose 35 | self.guide_model = guide_model 36 | self.history = [] 37 | 38 | if isinstance(scale, (Tuple, List)): 39 | times = np.array([x[0] for x in scale]) 40 | values = np.array([x[1] for x in scale]) 41 | self.scale_schedule = {"times": times, "values": values} 42 | else: 43 | self.scale_schedule = float(scale) 44 | 45 | self.ddim_timesteps = sampler.ddim_timesteps 46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 47 | 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps)*[self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) 54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps 55 | return interpolater(fractional_steps) 56 | 57 | def modify_score(self, model, e_t, x, t, c): 58 | 59 | # TODO look up index by t 60 | scale = self.get_scales()[self.index] 61 | 62 | if (scale == 0): 63 | return e_t 64 | 65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 66 | with torch.enable_grad(): 67 | x_in = x.detach().requires_grad_(True) 68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) 70 | 71 | inp = self.guide_model.preprocess(x_img) 72 | loss = self.guide_model.compute_loss(inp) 73 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 74 | correction = grads * scale 75 | 76 | if self.show: 77 | clear_output(wait=True) 78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) 79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) 80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) 81 | plt.axis('off') 82 | plt.show() 83 | plt.imshow(correction[0][0].detach().cpu()) 84 | plt.axis('off') 85 | plt.show() 86 | 87 | 88 | e_t_mod = e_t - sqrt_1ma*correction 89 | if self.show: 90 | fig, axs = plt.subplots(1, 3) 91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 94 | plt.show() 95 | self.index += 1 96 | return e_t_mod -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
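# last_lr caches the most recently computed multiplier so that schedule() can report it
# when verbosity_interval > 0.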
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- 
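# Usage sketch (not taken from this repository's training code; the wiring below is an
# assumption about how these scheduler classes are typically consumed). They return a
# learning-rate *multiplier* per step, so they pair with torch's LambdaLR and a base lr
# of 1.0, as the docstrings above note. All hyperparameter values are illustrative.
import torch
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)                                # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)    # base lr 1.0: multiplier == lr

sched_fn = LambdaLinearScheduler(
    warm_up_steps=[100], f_start=[1e-6], f_min=[1e-6], f_max=[1e-4],
    cycle_lengths=[10_000],                                  # a single cycle; lists allow repeats
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)

for step in range(200):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()                                         # LambdaLR evaluates sched_fn at the current step count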
/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... -> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if 
self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | 
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/evaluate/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow_gan as tfgan 32 | import tensorflow_hub as hub 33 | 34 | 35 | def preprocess(videos, target_resolution): 36 | """Runs some preprocessing on the videos for I3D model. 37 | 38 | Args: 39 | videos: <T>[batch_size, num_frames, height, width, depth] The videos to be 40 | preprocessed. We don't care about the specific dtype of the videos, it can 41 | be anything that tf.image.resize_bilinear accepts. 
Values are expected to 42 | be in the range 0-255. 43 | target_resolution: (width, height): target video resolution 44 | 45 | Returns: 46 | videos: <float32>[batch_size, num_frames, height, width, depth] 47 | """ 48 | videos_shape = list(videos.shape) 49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 52 | output_videos = tf.reshape(resized_videos, target_shape) 53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 54 | return scaled_videos 55 | 56 | 57 | def _is_in_graph(tensor_name): 58 | """Checks whether a given tensor does exists in the graph.""" 59 | try: 60 | tf.get_default_graph().get_tensor_by_name(tensor_name) 61 | except KeyError: 62 | return False 63 | return True 64 | 65 | 66 | def create_id3_embedding(videos,warmup=False,batch_size=16): 67 | """Embeds the given videos using the Inflated 3D Convolution ne twork. 68 | 69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 70 | first call. 71 | 72 | Args: 73 | videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3]. 74 | Expected range is [-1, 1]. 75 | 76 | Returns: 77 | embedding: <float32>[batch_size, embedding_size]. embedding_size depends 78 | on the model used. 79 | 80 | Raises: 81 | ValueError: when a provided embedding_layer is not supported. 82 | """ 83 | 84 | # batch_size = 16 85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 86 | 87 | 88 | # Making sure that we import the graph separately for 89 | # each different input video tensor. 90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 91 | videos.name).replace(":", "_") 92 | 93 | 94 | 95 | assert_ops = [ 96 | tf.Assert( 97 | tf.reduce_max(videos) <= 1.001, 98 | ["max value in frame is > 1", videos]), 99 | tf.Assert( 100 | tf.reduce_min(videos) >= -1.001, 101 | ["min value in frame is < -1", videos]), 102 | tf.assert_equal( 103 | tf.shape(videos)[0], 104 | batch_size, ["invalid frame batch size: ", 105 | tf.shape(videos)], 106 | summarize=6), 107 | ] 108 | with tf.control_dependencies(assert_ops): 109 | videos = tf.identity(videos) 110 | 111 | module_scope = "%s_apply_default/" % module_name 112 | 113 | # To check whether the module has already been loaded into the graph, we look 114 | # for a given tensor name. If this tensor name exists, we assume the function 115 | # has been called before and the graph was imported. Otherwise we import it. 116 | # Note: in theory, the tensor could exist, but have wrong shapes. 117 | # This will happen if create_id3_embedding is called with a frames_placehoder 118 | # of wrong size/batch size, because even though that will throw a tf.Assert 119 | # on graph-execution time, it will insert the tensor (with wrong shape) into 120 | # the graph. This is why we need the following assert. 
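# The warmup branch below only validates that the placeholder's batch dimension matches
# batch_size (or is left unspecified); the I3D hub module itself is instantiated further
# down, and only when its Mean:0 output tensor is not already present in the default graph.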
121 | if warmup: 122 | video_batch_size = int(videos.shape[0]) 123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" 124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 125 | if not _is_in_graph(tensor_name): 126 | i3d_model = hub.Module(module_spec, name=module_name) 127 | i3d_model(videos) 128 | 129 | # gets the kinetics-i3d-400-logits layer 130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 132 | return tensor 133 | 134 | 135 | def calculate_fvd(real_activations, 136 | generated_activations): 137 | """Returns a list of ops that compute metrics as funcs of activations. 138 | 139 | Args: 140 | real_activations: <float32>[num_samples, embedding_size] 141 | generated_activations: <float32>[num_samples, embedding_size] 142 | 143 | Returns: 144 | A scalar that contains the requested FVD. 145 | """ 146 | return tfgan.eval.frechet_classifier_distance_from_activations( 147 | real_activations, generated_activations) 148 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not (mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = 
size_average 84 | self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, _, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
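# The wildcard import above supplies the discriminator/loss helpers referenced in this
# class: NLayerDiscriminator, weights_init, LPIPS, hinge_d_loss, vanilla_d_loss and
# adopt_weight. Based on how ldm/modules/losses/vqperceptual.py imports them below, an
# explicit equivalent would presumably be:
#   from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
#   from taming.modules.losses.lpips import LPIPS
#   from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss, adopt_weight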
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
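# Unlike the sum-over-elements / batch-size reduction kept above for reference, the active
# line below averages over every element (batch and spatial dims), so nll_loss here is on
# a different scale than in LPIPSWithDiscriminator (contperceptual.py).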
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | import torch 5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. 
""" 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut -------------------------------------------------------------------------------- /ldm/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import torch 3 | from torch 
import nn 4 | from ldm.thirdp.psp.model_irse import Backbone 5 | 6 | 7 | class IDFeatures(nn.Module): 8 | def __init__(self, model_path): 9 | super(IDFeatures, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def forward(self, x, crop=False): 17 | # Not sure of the image range here 18 | if crop: 19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 20 | x = x[:, :, 35:223, 32:220] 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 5 | 6 | """ 7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Backbone(Module): 12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 13 | super(Backbone, self).__init__() 14 | assert input_size in [112, 224], "input_size should be 112 or 224" 15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 17 | blocks = get_blocks(num_layers) 18 | if mode == 'ir': 19 | unit_module = bottleneck_IR 20 | elif mode == 'ir_se': 21 | unit_module = bottleneck_IR_SE 22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 23 | BatchNorm2d(64), 24 | PReLU(64)) 25 | if input_size == 112: 26 | self.output_layer = Sequential(BatchNorm2d(512), 27 | Dropout(drop_ratio), 28 | Flatten(), 29 | Linear(512 * 7 * 7, 512), 30 | BatchNorm1d(512, affine=affine)) 31 | else: 32 | self.output_layer = Sequential(BatchNorm2d(512), 33 | Dropout(drop_ratio), 34 | Flatten(), 35 | Linear(512 * 14 * 14, 512), 36 | BatchNorm1d(512, affine=affine)) 37 | 38 | modules = [] 39 | for block in blocks: 40 | for bottleneck in block: 41 | modules.append(unit_module(bottleneck.in_channel, 42 | bottleneck.depth, 43 | bottleneck.stride)) 44 | self.body = Sequential(*modules) 45 | 46 | def forward(self, x): 47 | x = self.input_layer(x) 48 | x = self.body(x) 49 | x = self.output_layer(x) 50 | return l2_norm(x) 51 | 52 | 53 | def IR_50(input_size): 54 | """Constructs a ir-50 model.""" 55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 56 | return model 57 | 58 | 59 | def IR_101(input_size): 60 | """Constructs a ir-101 model.""" 61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 62 | return model 63 | 64 | 65 | def IR_152(input_size): 66 | """Constructs a ir-152 model.""" 67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 68 | return model 69 | 70 | 71 | def IR_SE_50(input_size): 72 | """Constructs a ir_se-50 model.""" 73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 74 | return model 75 | 76 | 77 | def IR_SE_101(input_size): 78 | """Constructs a ir_se-101 model.""" 79 | model = 
Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 80 | return model 81 | 82 | 83 | def IR_SE_152(input_size): 84 | """Constructs a ir_se-152 model.""" 85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 86 | return model -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torchvision 4 | import torch 5 | from torch import optim 6 | import numpy as np 7 | 8 | from inspect import isfunction 9 | from PIL import Image, ImageDraw, ImageFont 10 | 11 | import os 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from PIL import Image 15 | import torch 16 | import time 17 | import cv2 18 | 19 | import PIL 20 | 21 | def pil_rectangle_crop(im): 22 | width, height = im.size # Get dimensions 23 | 24 | if width <= height: 25 | left = 0 26 | right = width 27 | top = (height - width)/2 28 | bottom = (height + width)/2 29 | else: 30 | 31 | top = 0 32 | bottom = height 33 | left = (width - height) / 2 34 | right = (width + height) / 2 35 | 36 | # Crop the center of the image 37 | im = im.crop((left, top, right, bottom)) 38 | return im 39 | 40 | 41 | def log_txt_as_img(wh, xc, size=10): 42 | # wh a tuple of (width, height) 43 | # xc a list of captions to plot 44 | b = len(xc) 45 | txts = list() 46 | for bi in range(b): 47 | txt = Image.new("RGB", wh, color="white") 48 | draw = ImageDraw.Draw(txt) 49 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 50 | nc = int(40 * (wh[0] / 256)) 51 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 52 | 53 | try: 54 | draw.text((0, 0), lines, fill="black", font=font) 55 | except UnicodeEncodeError: 56 | print("Can't encode string for logging. Skipping.") 57 | 58 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 59 | txts.append(txt) 60 | txts = np.stack(txts) 61 | txts = torch.tensor(txts) 62 | return txts 63 | 64 | 65 | def ismap(x): 66 | if not isinstance(x, torch.Tensor): 67 | return False 68 | return (len(x.shape) == 4) and (x.shape[1] > 3) 69 | 70 | 71 | def isimage(x): 72 | if not isinstance(x,torch.Tensor): 73 | return False 74 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 75 | 76 | 77 | def exists(x): 78 | return x is not None 79 | 80 | 81 | def default(val, d): 82 | if exists(val): 83 | return val 84 | return d() if isfunction(d) else d 85 | 86 | 87 | def mean_flat(tensor): 88 | """ 89 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 90 | Take the mean over all non-batch dimensions.
91 | """ 92 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 93 | 94 | 95 | def count_params(model, verbose=False): 96 | total_params = sum(p.numel() for p in model.parameters()) 97 | if verbose: 98 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 99 | return total_params 100 | 101 | 102 | def instantiate_from_config(config): 103 | if not "target" in config: 104 | if config == '__is_first_stage__': 105 | return None 106 | elif config == "__is_unconditional__": 107 | return None 108 | raise KeyError("Expected key `target` to instantiate.") 109 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 110 | 111 | 112 | def get_obj_from_str(string, reload=False): 113 | module, cls = string.rsplit(".", 1) 114 | if reload: 115 | module_imp = importlib.import_module(module) 116 | importlib.reload(module_imp) 117 | return getattr(importlib.import_module(module, package=None), cls) 118 | 119 | 120 | class AdamWwithEMAandWings(optim.Optimizer): 121 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 122 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 123 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 124 | ema_power=1., param_names=()): 125 | """AdamW that saves EMA versions of the parameters.""" 126 | if not 0.0 <= lr: 127 | raise ValueError("Invalid learning rate: {}".format(lr)) 128 | if not 0.0 <= eps: 129 | raise ValueError("Invalid epsilon value: {}".format(eps)) 130 | if not 0.0 <= betas[0] < 1.0: 131 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 132 | if not 0.0 <= betas[1] < 1.0: 133 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 134 | if not 0.0 <= weight_decay: 135 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 136 | if not 0.0 <= ema_decay <= 1.0: 137 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 138 | defaults = dict(lr=lr, betas=betas, eps=eps, 139 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 140 | ema_power=ema_power, param_names=param_names) 141 | super().__init__(params, defaults) 142 | 143 | def __setstate__(self, state): 144 | super().__setstate__(state) 145 | for group in self.param_groups: 146 | group.setdefault('amsgrad', False) 147 | 148 | @torch.no_grad() 149 | def step(self, closure=None): 150 | """Performs a single optimization step. 151 | Args: 152 | closure (callable, optional): A closure that reevaluates the model 153 | and returns the loss. 
154 | """ 155 | loss = None 156 | if closure is not None: 157 | with torch.enable_grad(): 158 | loss = closure() 159 | 160 | for group in self.param_groups: 161 | params_with_grad = [] 162 | grads = [] 163 | exp_avgs = [] 164 | exp_avg_sqs = [] 165 | ema_params_with_grad = [] 166 | state_sums = [] 167 | max_exp_avg_sqs = [] 168 | state_steps = [] 169 | amsgrad = group['amsgrad'] 170 | beta1, beta2 = group['betas'] 171 | ema_decay = group['ema_decay'] 172 | ema_power = group['ema_power'] 173 | 174 | for p in group['params']: 175 | if p.grad is None: 176 | continue 177 | params_with_grad.append(p) 178 | if p.grad.is_sparse: 179 | raise RuntimeError('AdamW does not support sparse gradients') 180 | grads.append(p.grad) 181 | 182 | state = self.state[p] 183 | 184 | # State initialization 185 | if len(state) == 0: 186 | state['step'] = 0 187 | # Exponential moving average of gradient values 188 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 189 | # Exponential moving average of squared gradient values 190 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 191 | if amsgrad: 192 | # Maintains max of all exp. moving avg. of sq. grad. values 193 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 194 | # Exponential moving average of parameter values 195 | state['param_exp_avg'] = p.detach().float().clone() 196 | 197 | exp_avgs.append(state['exp_avg']) 198 | exp_avg_sqs.append(state['exp_avg_sq']) 199 | ema_params_with_grad.append(state['param_exp_avg']) 200 | 201 | if amsgrad: 202 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 203 | 204 | # update the steps for each param group update 205 | state['step'] += 1 206 | # record the step after step update 207 | state_steps.append(state['step']) 208 | 209 | optim._functional.adamw(params_with_grad, 210 | grads, 211 | exp_avgs, 212 | exp_avg_sqs, 213 | max_exp_avg_sqs, 214 | state_steps, 215 | amsgrad=amsgrad, 216 | beta1=beta1, 217 | beta2=beta2, 218 | lr=group['lr'], 219 | weight_decay=group['weight_decay'], 220 | eps=group['eps'], 221 | maximize=False) 222 | 223 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 224 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 225 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 226 | 227 | return loss -------------------------------------------------------------------------------- /meshutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pymeshlab as pml 3 | 4 | def poisson_mesh_reconstruction(points, normals=None): 5 | # points/normals: [N, 3] np.ndarray 6 | 7 | import open3d as o3d 8 | 9 | pcd = o3d.geometry.PointCloud() 10 | pcd.points = o3d.utility.Vector3dVector(points) 11 | 12 | # outlier removal 13 | pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10) 14 | 15 | # normals 16 | if normals is None: 17 | pcd.estimate_normals() 18 | else: 19 | pcd.normals = o3d.utility.Vector3dVector(normals[ind]) 20 | 21 | # visualize 22 | o3d.visualization.draw_geometries([pcd], point_show_normal=False) 23 | 24 | mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9) 25 | vertices_to_remove = densities < np.quantile(densities, 0.1) 26 | mesh.remove_vertices_by_mask(vertices_to_remove) 27 | 28 | # visualize 29 | o3d.visualization.draw_geometries([mesh]) 30 | 31 | vertices = np.asarray(mesh.vertices) 32 | triangles = np.asarray(mesh.triangles) 
33 | 34 | print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}') 35 | 36 | return vertices, triangles 37 | 38 | 39 | def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True): 40 | # optimalplacement: default is True, but for a flat mesh it must be set to False to prevent spike artifacts. 41 | 42 | _ori_vert_shape = verts.shape 43 | _ori_face_shape = faces.shape 44 | 45 | if backend == 'pyfqmr': 46 | import pyfqmr 47 | solver = pyfqmr.Simplify() 48 | solver.setMesh(verts, faces) 49 | solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False) 50 | verts, faces, normals = solver.getMesh() 51 | else: 52 | 53 | m = pml.Mesh(verts, faces) 54 | ms = pml.MeshSet() 55 | ms.add_mesh(m, 'mesh') # will copy! 56 | 57 | # filters 58 | # ms.meshing_decimation_clustering(threshold=pml.Percentage(1)) 59 | ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement) 60 | 61 | if remesh: 62 | # ms.apply_coord_taubin_smoothing() 63 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1)) 64 | 65 | # extract mesh 66 | m = ms.current_mesh() 67 | verts = m.vertex_matrix() 68 | faces = m.face_matrix() 69 | 70 | print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 71 | 72 | return verts, faces 73 | 74 | 75 | def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01): 76 | # verts: [N, 3] 77 | # faces: [N, 3] 78 | 79 | _ori_vert_shape = verts.shape 80 | _ori_face_shape = faces.shape 81 | 82 | m = pml.Mesh(verts, faces) 83 | ms = pml.MeshSet() 84 | ms.add_mesh(m, 'mesh') # will copy! 85 | 86 | # filters 87 | ms.meshing_remove_unreferenced_vertices() # verts not referenced by any faces 88 | 89 | if v_pct > 0: 90 | ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct)) # 1/10000 of bounding box diagonal 91 | 92 | ms.meshing_remove_duplicate_faces() # faces defined by the same verts 93 | ms.meshing_remove_null_faces() # faces with area == 0 94 | 95 | if min_d > 0: 96 | ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d)) 97 | 98 | if min_f > 0: 99 | ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) 100 | 101 | if repair: 102 | # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) 103 | ms.meshing_repair_non_manifold_edges(method=0) 104 | ms.meshing_repair_non_manifold_vertices(vertdispratio=0) 105 | 106 | if remesh: 107 | # ms.apply_coord_taubin_smoothing() 108 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size)) 109 | 110 | # extract mesh 111 | m = ms.current_mesh() 112 | verts = m.vertex_matrix() 113 | faces = m.face_matrix() 114 | 115 | print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 116 | 117 | return verts, faces -------------------------------------------------------------------------------- /nerf/network_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from activation import trunc_exp, biased_softplus 6 | from .renderer import NeRFRenderer 7 | 8 | import numpy as np 9 | from encoding import get_encoder 10 | 11 | from .utils import safe_normalize 12 | 13 | class MLP(nn.Module): 14 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers,
bias=True): 15 | super().__init__() 16 | self.dim_in = dim_in 17 | self.dim_out = dim_out 18 | self.dim_hidden = dim_hidden 19 | self.num_layers = num_layers 20 | 21 | net = [] 22 | for l in range(num_layers): 23 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) 24 | 25 | self.net = nn.ModuleList(net) 26 | 27 | def forward(self, x): 28 | for l in range(self.num_layers): 29 | x = self.net[l](x) 30 | if l != self.num_layers - 1: 31 | x = F.relu(x, inplace=True) 32 | return x 33 | 34 | 35 | class NeRFNetwork(NeRFRenderer): 36 | def __init__(self, 37 | opt, 38 | num_layers=3, 39 | hidden_dim=64, 40 | num_layers_bg=2, 41 | hidden_dim_bg=32, 42 | ): 43 | 44 | super().__init__(opt) 45 | 46 | self.num_layers = num_layers 47 | self.hidden_dim = hidden_dim 48 | 49 | self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep') 50 | 51 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) 52 | # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True) 53 | 54 | self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus 55 | 56 | # background network 57 | if self.opt.bg_radius > 0: 58 | self.num_layers_bg = num_layers_bg 59 | self.hidden_dim_bg = hidden_dim_bg 60 | 61 | # use a very simple network to avoid it learning the prompt... 62 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6) 63 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) 64 | 65 | else: 66 | self.bg_net = None 67 | 68 | def common_forward(self, x): 69 | 70 | # sigma 71 | enc = self.encoder(x, bound=self.bound, max_level=self.max_level) 72 | 73 | h = self.sigma_net(enc) 74 | 75 | sigma = self.density_activation(h[..., 0] + self.density_blob(x)) 76 | albedo = torch.sigmoid(h[..., 1:]) 77 | 78 | return sigma, albedo 79 | 80 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 81 | def finite_difference_normal(self, x, epsilon=1e-2): 82 | # x: [N, 3] 83 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 84 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 85 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 86 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 87 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 88 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 89 | 90 | normal = torch.stack([ 91 | 0.5 * (dx_pos - dx_neg) / epsilon, 92 | 0.5 * (dy_pos - dy_neg) / epsilon, 93 | 0.5 * (dz_pos - dz_neg) / epsilon 94 | ], dim=-1) 95 | 96 | return -normal 97 | 98 | def normal(self, x): 99 | normal = self.finite_difference_normal(x) 100 | normal = safe_normalize(normal) 101 | normal = torch.nan_to_num(normal) 102 | return normal 103 | 104 | def forward(self, x, d, l=None, ratio=1, shading='albedo'): 105 | # x: [N, 3], in [-bound, bound] 106 | # d: [N, 3], view direction, nomalized in [-1, 1] 107 | # l: [3], plane light direction, nomalized in [-1, 
1] 108 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) 109 | 110 | sigma, albedo = self.common_forward(x) 111 | 112 | if shading == 'albedo': 113 | normal = None 114 | color = albedo 115 | 116 | else: # lambertian shading 117 | 118 | # normal = self.normal_net(enc) 119 | normal = self.normal(x) 120 | 121 | lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] 122 | 123 | if shading == 'textureless': 124 | color = lambertian.unsqueeze(-1).repeat(1, 3) 125 | elif shading == 'normal': 126 | color = (normal + 1) / 2 127 | else: # 'lambertian' 128 | color = albedo * lambertian.unsqueeze(-1) 129 | 130 | return sigma, color, normal 131 | 132 | 133 | def density(self, x): 134 | # x: [N, 3], in [-bound, bound] 135 | 136 | sigma, albedo = self.common_forward(x) 137 | 138 | return { 139 | 'sigma': sigma, 140 | 'albedo': albedo, 141 | } 142 | 143 | 144 | def background(self, d): 145 | 146 | h = self.encoder_bg(d) # [N, C] 147 | 148 | h = self.bg_net(h) 149 | 150 | # sigmoid activation for rgb 151 | rgbs = torch.sigmoid(h) 152 | 153 | return rgbs 154 | 155 | # optimizer utils 156 | def get_params(self, lr): 157 | 158 | params = [ 159 | {'params': self.encoder.parameters(), 'lr': lr * 10}, 160 | {'params': self.sigma_net.parameters(), 'lr': lr}, 161 | # {'params': self.normal_net.parameters(), 'lr': lr}, 162 | ] 163 | 164 | if self.opt.bg_radius > 0: 165 | # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) 166 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 167 | 168 | if self.opt.dmtet and not self.opt.lock_geo: 169 | params.append({'params': self.sdf, 'lr': lr}) 170 | params.append({'params': self.deform, 'lr': lr}) 171 | 172 | return params -------------------------------------------------------------------------------- /nerf/network_grid_taichi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from activation import trunc_exp 6 | from .renderer import NeRFRenderer 7 | 8 | import numpy as np 9 | from encoding import get_encoder 10 | 11 | from .utils import safe_normalize 12 | 13 | class MLP(nn.Module): 14 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): 15 | super().__init__() 16 | self.dim_in = dim_in 17 | self.dim_out = dim_out 18 | self.dim_hidden = dim_hidden 19 | self.num_layers = num_layers 20 | 21 | net = [] 22 | for l in range(num_layers): 23 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) 24 | 25 | self.net = nn.ModuleList(net) 26 | 27 | def forward(self, x): 28 | for l in range(self.num_layers): 29 | x = self.net[l](x) 30 | if l != self.num_layers - 1: 31 | x = F.relu(x, inplace=True) 32 | return x 33 | 34 | 35 | class NeRFNetwork(NeRFRenderer): 36 | def __init__(self, 37 | opt, 38 | num_layers=2, 39 | hidden_dim=32, 40 | num_layers_bg=2, 41 | hidden_dim_bg=16, 42 | ): 43 | 44 | super().__init__(opt) 45 | 46 | self.num_layers = num_layers 47 | self.hidden_dim = hidden_dim 48 | 49 | self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep') 50 | 51 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) 52 | # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True) 53 | 54 | self.density_activation = trunc_exp if 
self.opt.density_activation == 'exp' else F.softplus 55 | 56 | # background network 57 | if self.opt.bg_radius > 0: 58 | self.num_layers_bg = num_layers_bg 59 | self.hidden_dim_bg = hidden_dim_bg 60 | # use a very simple network to avoid it learning the prompt... 61 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4) # TODO: freq encoder can be replaced by a Taichi implementation 62 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) 63 | 64 | else: 65 | self.bg_net = None 66 | 67 | def common_forward(self, x): 68 | 69 | # sigma 70 | enc = self.encoder(x, bound=self.bound) 71 | 72 | h = self.sigma_net(enc) 73 | 74 | sigma = self.density_activation(h[..., 0] + self.density_blob(x)) 75 | albedo = torch.sigmoid(h[..., 1:]) 76 | 77 | return sigma, albedo 78 | 79 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 80 | def finite_difference_normal(self, x, epsilon=1e-2): 81 | # x: [N, 3] 82 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 83 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 84 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 85 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 86 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 87 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 88 | 89 | normal = torch.stack([ 90 | 0.5 * (dx_pos - dx_neg) / epsilon, 91 | 0.5 * (dy_pos - dy_neg) / epsilon, 92 | 0.5 * (dz_pos - dz_neg) / epsilon 93 | ], dim=-1) 94 | 95 | return -normal 96 | 97 | def normal(self, x): 98 | normal = self.finite_difference_normal(x) 99 | normal = safe_normalize(normal) 100 | normal = torch.nan_to_num(normal) 101 | return normal 102 | 103 | def forward(self, x, d, l=None, ratio=1, shading='albedo'): 104 | # x: [N, 3], in [-bound, bound] 105 | # d: [N, 3], view direction, nomalized in [-1, 1] 106 | # l: [3], plane light direction, nomalized in [-1, 1] 107 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) 108 | 109 | sigma, albedo = self.common_forward(x) 110 | 111 | if shading == 'albedo': 112 | normal = None 113 | color = albedo 114 | 115 | else: # lambertian shading 116 | # normal = self.normal_net(enc) 117 | normal = self.normal(x) 118 | 119 | lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] 120 | 121 | if shading == 'textureless': 122 | color = lambertian.unsqueeze(-1).repeat(1, 3) 123 | elif shading == 'normal': 124 | color = (normal + 1) / 2 125 | else: # 'lambertian' 126 | color = albedo * lambertian.unsqueeze(-1) 127 | 128 | return sigma, color, normal 129 | 130 | 131 | def density(self, x): 132 | # x: [N, 3], in [-bound, bound] 133 | 134 | sigma, albedo = self.common_forward(x) 135 | 136 | return { 137 | 'sigma': sigma, 138 | 'albedo': albedo, 139 | } 140 | 141 | 142 | def background(self, d): 143 | 144 | h = self.encoder_bg(d) # [N, C] 145 | 146 | h = self.bg_net(h) 147 | 148 | # sigmoid activation for rgb 149 | rgbs = torch.sigmoid(h) 150 | 151 | return rgbs 152 | 153 | # optimizer utils 154 | def get_params(self, lr): 155 | 
156 | params = [ 157 | {'params': self.encoder.parameters(), 'lr': lr * 10}, 158 | {'params': self.sigma_net.parameters(), 'lr': lr}, 159 | # {'params': self.normal_net.parameters(), 'lr': lr}, 160 | ] 161 | 162 | if self.opt.bg_radius > 0: 163 | # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) 164 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 165 | 166 | if self.opt.dmtet and not self.opt.lock_geo: 167 | params.append({'params': self.sdf, 'lr': lr}) 168 | params.append({'params': self.deform, 'lr': lr}) 169 | 170 | return params -------------------------------------------------------------------------------- /nerf/network_grid_tcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from activation import trunc_exp, biased_softplus 6 | from .renderer import NeRFRenderer 7 | 8 | import numpy as np 9 | from encoding import get_encoder 10 | 11 | from .utils import safe_normalize 12 | 13 | import tinycudann as tcnn 14 | 15 | class MLP(nn.Module): 16 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): 17 | super().__init__() 18 | self.dim_in = dim_in 19 | self.dim_out = dim_out 20 | self.dim_hidden = dim_hidden 21 | self.num_layers = num_layers 22 | 23 | net = [] 24 | for l in range(num_layers): 25 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) 26 | 27 | self.net = nn.ModuleList(net) 28 | 29 | def forward(self, x): 30 | for l in range(self.num_layers): 31 | x = self.net[l](x) 32 | if l != self.num_layers - 1: 33 | x = F.relu(x, inplace=True) 34 | return x 35 | 36 | 37 | class NeRFNetwork(NeRFRenderer): 38 | def __init__(self, 39 | opt, 40 | num_layers=3, 41 | hidden_dim=64, 42 | num_layers_bg=2, 43 | hidden_dim_bg=32, 44 | ): 45 | 46 | super().__init__(opt) 47 | 48 | self.num_layers = num_layers 49 | self.hidden_dim = hidden_dim 50 | 51 | self.encoder = tcnn.Encoding( 52 | n_input_dims=3, 53 | encoding_config={ 54 | "otype": "HashGrid", 55 | "n_levels": 16, 56 | "n_features_per_level": 2, 57 | "log2_hashmap_size": 19, 58 | "base_resolution": 16, 59 | "interpolation": "Smoothstep", 60 | "per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)), 61 | }, 62 | dtype=torch.float32, # ENHANCE: default float16 seems unstable... 63 | ) 64 | self.in_dim = self.encoder.n_output_dims 65 | # use torch MLP, as tcnn MLP doesn't impl second-order derivative 66 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) 67 | 68 | self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus 69 | 70 | # background network 71 | if self.opt.bg_radius > 0: 72 | self.num_layers_bg = num_layers_bg 73 | self.hidden_dim_bg = hidden_dim_bg 74 | 75 | # use a very simple network to avoid it learning the prompt... 
76 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6) 77 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) 78 | 79 | else: 80 | self.bg_net = None 81 | 82 | def common_forward(self, x): 83 | 84 | # sigma 85 | enc = self.encoder((x + self.bound) / (2 * self.bound)).float() 86 | h = self.sigma_net(enc) 87 | 88 | sigma = self.density_activation(h[..., 0] + self.density_blob(x)) 89 | albedo = torch.sigmoid(h[..., 1:]) 90 | 91 | return sigma, albedo 92 | 93 | def normal(self, x): 94 | 95 | with torch.enable_grad(): 96 | with torch.cuda.amp.autocast(enabled=False): 97 | x.requires_grad_(True) 98 | sigma, albedo = self.common_forward(x) 99 | # query gradient 100 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] 101 | 102 | # normal = self.finite_difference_normal(x) 103 | normal = safe_normalize(normal) 104 | normal = torch.nan_to_num(normal) 105 | 106 | return normal 107 | 108 | def forward(self, x, d, l=None, ratio=1, shading='albedo'): 109 | # x: [N, 3], in [-bound, bound] 110 | # d: [N, 3], view direction, nomalized in [-1, 1] 111 | # l: [3], plane light direction, nomalized in [-1, 1] 112 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) 113 | 114 | 115 | if shading == 'albedo': 116 | sigma, albedo = self.common_forward(x) 117 | normal = None 118 | color = albedo 119 | 120 | else: # lambertian shading 121 | with torch.enable_grad(): 122 | with torch.cuda.amp.autocast(enabled=False): 123 | x.requires_grad_(True) 124 | sigma, albedo = self.common_forward(x) 125 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] 126 | normal = safe_normalize(normal) 127 | normal = torch.nan_to_num(normal) 128 | 129 | lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] 130 | 131 | if shading == 'textureless': 132 | color = lambertian.unsqueeze(-1).repeat(1, 3) 133 | elif shading == 'normal': 134 | color = (normal + 1) / 2 135 | else: # 'lambertian' 136 | color = albedo * lambertian.unsqueeze(-1) 137 | 138 | return sigma, color, normal 139 | 140 | 141 | def density(self, x): 142 | # x: [N, 3], in [-bound, bound] 143 | 144 | sigma, albedo = self.common_forward(x) 145 | 146 | return { 147 | 'sigma': sigma, 148 | 'albedo': albedo, 149 | } 150 | 151 | 152 | def background(self, d): 153 | 154 | h = self.encoder_bg(d) # [N, C] 155 | 156 | h = self.bg_net(h) 157 | 158 | # sigmoid activation for rgb 159 | rgbs = torch.sigmoid(h) 160 | 161 | return rgbs 162 | 163 | # optimizer utils 164 | def get_params(self, lr): 165 | 166 | params = [ 167 | {'params': self.encoder.parameters(), 'lr': lr * 10}, 168 | {'params': self.sigma_net.parameters(), 'lr': lr}, 169 | ] 170 | 171 | if self.opt.bg_radius > 0: 172 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 173 | 174 | if self.opt.dmtet and not self.opt.lock_geo: 175 | params.append({'params': self.sdf, 'lr': lr}) 176 | params.append({'params': self.deform, 'lr': lr}) 177 | 178 | return params -------------------------------------------------------------------------------- /preprocess_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from PIL import Image 13 | 14 | class 
BackgroundRemoval(): 15 | def __init__(self, device='cuda'): 16 | 17 | from carvekit.api.high import HiInterface 18 | self.interface = HiInterface( 19 | object_type="object", # Can be "object" or "hairs-like". 20 | batch_size_seg=5, 21 | batch_size_matting=1, 22 | device=device, 23 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net 24 | matting_mask_size=2048, 25 | trimap_prob_threshold=231, 26 | trimap_dilation=30, 27 | trimap_erosion_iters=5, 28 | fp16=True, 29 | ) 30 | 31 | @torch.no_grad() 32 | def __call__(self, image): 33 | # image: [H, W, 3] array in [0, 255]. 34 | image = Image.fromarray(image) 35 | 36 | image = self.interface([image])[0] 37 | image = np.array(image) 38 | 39 | return image 40 | 41 | class BLIP2(): 42 | def __init__(self, device='cuda'): 43 | self.device = device 44 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 45 | self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 46 | self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device) 47 | 48 | @torch.no_grad() 49 | def __call__(self, image): 50 | image = Image.fromarray(image) 51 | inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16) 52 | 53 | generated_ids = self.model.generate(**inputs, max_new_tokens=20) 54 | generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 55 | 56 | return generated_text 57 | 58 | 59 | class DPT(): 60 | def __init__(self, task='depth', device='cuda'): 61 | 62 | self.task = task 63 | self.device = device 64 | 65 | from dpt import DPTDepthModel 66 | 67 | if task == 'depth': 68 | path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt' 69 | self.model = DPTDepthModel(backbone='vitb_rn50_384') 70 | self.aug = transforms.Compose([ 71 | transforms.Resize((384, 384)), 72 | transforms.ToTensor(), 73 | transforms.Normalize(mean=0.5, std=0.5) 74 | ]) 75 | 76 | else: # normal 77 | path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt' 78 | self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) 79 | self.aug = transforms.Compose([ 80 | transforms.Resize((384, 384)), 81 | transforms.ToTensor() 82 | ]) 83 | 84 | # load model 85 | checkpoint = torch.load(path, map_location='cpu') 86 | if 'state_dict' in checkpoint: 87 | state_dict = {} 88 | for k, v in checkpoint['state_dict'].items(): 89 | state_dict[k[6:]] = v 90 | else: 91 | state_dict = checkpoint 92 | self.model.load_state_dict(state_dict) 93 | self.model.eval().to(device) 94 | 95 | 96 | @torch.no_grad() 97 | def __call__(self, image): 98 | # image: np.ndarray, uint8, [H, W, 3] 99 | H, W = image.shape[:2] 100 | image = Image.fromarray(image) 101 | 102 | image = self.aug(image).unsqueeze(0).to(self.device) 103 | 104 | if self.task == 'depth': 105 | depth = self.model(image).clamp(0, 1) 106 | depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False) 107 | depth = depth.squeeze(1).cpu().numpy() 108 | return depth 109 | else: 110 | normal = self.model(image).clamp(0, 1) 111 | normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False) 112 | normal = normal.cpu().numpy() 113 | return normal 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)") 121 | parser.add_argument('--size', default=256, type=int, help="output resolution") 122 | 
parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio") 123 | parser.add_argument('--recenter', type=bool, default=True, help="recenter, potentially not helpful for multiview zero123") 124 | parser.add_argument('--dont_recenter', dest='recenter', action='store_false') 125 | opt = parser.parse_args() 126 | 127 | out_dir = os.path.dirname(opt.path) 128 | out_rgba = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_rgba.png') 129 | out_depth = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_depth.png') 130 | out_normal = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_normal.png') 131 | out_caption = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_caption.txt') 132 | 133 | # load image 134 | print(f'[INFO] loading image...') 135 | image = cv2.imread(opt.path, cv2.IMREAD_UNCHANGED) 136 | if image.shape[-1] == 4: 137 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) 138 | else: 139 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 140 | 141 | # carve background 142 | print(f'[INFO] background removal...') 143 | carved_image = BackgroundRemoval()(image) # [H, W, 4] 144 | mask = carved_image[..., -1] > 0 145 | 146 | # predict depth 147 | print(f'[INFO] depth estimation...') 148 | dpt_depth_model = DPT(task='depth') 149 | depth = dpt_depth_model(image)[0] 150 | depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9) 151 | depth[~mask] = 0 152 | depth = (depth * 255).astype(np.uint8) 153 | del dpt_depth_model 154 | 155 | # predict normal 156 | print(f'[INFO] normal estimation...') 157 | dpt_normal_model = DPT(task='normal') 158 | normal = dpt_normal_model(image)[0] 159 | normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0) 160 | normal[~mask] = 0 161 | del dpt_normal_model 162 | 163 | # recenter 164 | if opt.recenter: 165 | print(f'[INFO] recenter...') 166 | final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) 167 | final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8) 168 | final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8) 169 | 170 | coords = np.nonzero(mask) 171 | x_min, x_max = coords[0].min(), coords[0].max() 172 | y_min, y_max = coords[1].min(), coords[1].max() 173 | h = x_max - x_min 174 | w = y_max - y_min 175 | desired_size = int(opt.size * (1 - opt.border_ratio)) 176 | scale = desired_size / max(h, w) 177 | h2 = int(h * scale) 178 | w2 = int(w * scale) 179 | x2_min = (opt.size - h2) // 2 180 | x2_max = x2_min + h2 181 | y2_min = (opt.size - w2) // 2 182 | y2_max = y2_min + w2 183 | final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) 184 | final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) 185 | final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) 186 | 187 | else: 188 | final_rgba = carved_image 189 | final_depth = depth 190 | final_normal = normal 191 | 192 | # write output 193 | cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA)) 194 | cv2.imwrite(out_depth, final_depth) 195 | cv2.imwrite(out_normal, final_normal) 196 | 197 | # predict caption (it's too slow... 
use your brain instead) 198 | # print(f'[INFO] captioning...') 199 | # blip2 = BLIP2() 200 | # caption = blip2(image) 201 | # with open(out_caption, 'w') as f: 202 | # f.write(caption) 203 | 204 | -------------------------------------------------------------------------------- /pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "image_cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | 19 | scheduler_config: # 10000 warmup steps 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: [ 100 ] 23 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 24 | f_start: [ 1.e-6 ] 25 | f_max: [ 1. ] 26 | f_min: [ 1. ] 27 | 28 | unet_config: 29 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 30 | params: 31 | image_size: 32 # unused 32 | in_channels: 8 33 | out_channels: 4 34 | model_channels: 320 35 | attention_resolutions: [ 4, 2, 1 ] 36 | num_res_blocks: 2 37 | channel_mult: [ 1, 2, 4, 4 ] 38 | num_heads: 8 39 | use_spatial_transformer: True 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: True 43 | legacy: False 44 | 45 | first_stage_config: 46 | target: ldm.models.autoencoder.AutoencoderKL 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 70 | 71 | 72 | # data: 73 | # target: ldm.data.simple.ObjaverseDataModuleFromConfig 74 | # params: 75 | # root_dir: 'views_whole_sphere' 76 | # batch_size: 192 77 | # num_workers: 16 78 | # total_view: 4 79 | # train: 80 | # validation: False 81 | # image_transforms: 82 | # size: 256 83 | 84 | # validation: 85 | # validation: True 86 | # image_transforms: 87 | # size: 256 88 | 89 | 90 | # lightning: 91 | # find_unused_parameters: false 92 | # metrics_over_trainsteps_checkpoint: True 93 | # modelcheckpoint: 94 | # params: 95 | # every_n_train_steps: 5000 96 | # callbacks: 97 | # image_logger: 98 | # target: main.ImageLogger 99 | # params: 100 | # batch_frequency: 500 101 | # max_images: 32 102 | # increase_log_steps: False 103 | # log_first_step: True 104 | # log_images_kwargs: 105 | # use_ema_scope: False 106 | # inpaint: False 107 | # plot_progressive_rows: False 108 | # plot_diffusion_rows: False 109 | # N: 32 110 | # unconditional_scale: 3.0 111 | # unconditional_label: [""] 112 | 113 | # trainer: 114 | # benchmark: True 115 | # val_check_interval: 5000000 # really sorry 116 | # num_sanity_val_steps: 0 117 | # accumulate_grad_batches: 1 118 | -------------------------------------------------------------------------------- /raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching 
import * -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | _backend = load(name='_raymarching', 33 | extra_cflags=c_flags, 34 | extra_cuda_cflags=nvcc_flags, 35 | sources=[os.path.join(_src_path, 'src', f) for f in [ 36 | 'raymarching.cu', 37 | 'bindings.cpp', 38 | ]], 39 | ) 40 | 41 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | ''' 34 | Usage: 35 | 36 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) 37 | 38 | python setup.py install # build extensions and install (copy) to PATH. 39 | pip install . # ditto but better (e.g., dependency & metadata handling) 40 | 41 | python setup.py develop # build extensions and install (symbolic) to PATH. 42 | pip install -e . 
# ditto but better (e.g., dependency & metadata handling) 43 | 44 | ''' 45 | setup( 46 | name='raymarching', # package name, import this to use python API 47 | ext_modules=[ 48 | CUDAExtension( 49 | name='_raymarching', # extension name, import this to use CUDA API 50 | sources=[os.path.join(_src_path, 'src', f) for f in [ 51 | 'raymarching.cu', 52 | 'bindings.cpp', 53 | ]], 54 | extra_compile_args={ 55 | 'cxx': c_flags, 56 | 'nvcc': nvcc_flags, 57 | } 58 | ), 59 | ], 60 | cmdclass={ 61 | 'build_ext': BuildExtension, 62 | } 63 | ) -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)"); 8 | m.def("packbits", &packbits, "packbits (CUDA)"); 9 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 10 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); 11 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 12 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 13 | // train 14 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 15 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 16 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 17 | // infer 18 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 19 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 20 | } -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | 7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res); 13 | 14 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises); 15 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 16 | void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor 
grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 17 | 18 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises); 19 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | rich 3 | ninja 4 | numpy 5 | pandas 6 | scipy 7 | scikit-learn 8 | matplotlib 9 | opencv-python 10 | imageio 11 | imageio-ffmpeg 12 | 13 | torch 14 | torch-ema 15 | einops 16 | tensorboard 17 | tensorboardX 18 | 19 | # for gui 20 | dearpygui 21 | 22 | # for grid_tcnn 23 | # git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 24 | 25 | # for stable-diffusion 26 | huggingface_hub 27 | diffusers >= 0.9.0 28 | accelerate 29 | transformers 30 | 31 | # for dmtet and mesh export 32 | xatlas 33 | trimesh 34 | PyMCubes 35 | pymeshlab 36 | git+https://github.com/NVlabs/nvdiffrast/ 37 | 38 | # for zero123 39 | carvekit-colab 40 | omegaconf 41 | pytorch-lightning 42 | taming-transformers-rom1504 43 | kornia 44 | git+https://github.com/openai/CLIP.git 45 | 46 | # for omnidata 47 | gdown 48 | 49 | # for dpt 50 | timm 51 | 52 | # for remote debugging 53 | debugpy-run 54 | 55 | # for deepfloyd if 56 | sentencepiece 57 | -------------------------------------------------------------------------------- /scripts/install_ext.sh: -------------------------------------------------------------------------------- 1 | pip install ./raymarching 2 | pip install ./shencoder 3 | pip install ./freqencoder 4 | pip install ./gridencoder -------------------------------------------------------------------------------- /scripts/res64.args: -------------------------------------------------------------------------------- 1 | -O --vram_O --w 64 --h 64 -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a delicious hamburger" --workspace trial_hamburger --iters 5000 3 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a delicious hamburger" --workspace trial2_hamburger --dmtet --iters 5000 --init_with trial_hamburger/checkpoints/df.pth 4 | 5 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a highly detailed stone bust of Theodoros Kolokotronis" --workspace trial_stonehead --iters 5000 6 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a highly detailed stone bust of Theodoros Kolokotronis" --workspace trial2_stonehead --dmtet --iters 5000 --init_with trial_stonehead/checkpoints/df.pth 7 | 8 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "an astronaut, full body" --workspace trial_astronaut --iters 5000 9 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "an astronaut, full body" --workspace trial2_astronaut --dmtet --iters 5000 --init_with trial_astronaut/checkpoints/df.pth 10 | 11 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel-octopus hybrid" --workspace trial_squrrel_octopus --iters 5000 12 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel-octopus hybrid" --workspace trial2_squrrel_octopus --dmtet --iters 5000 --init_with trial_squrrel_octopus/checkpoints/df.pth 13 | 14 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_rabbit_pancake --iters 5000 15 | CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth -------------------------------------------------------------------------------- /scripts/run2.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat" --workspace trial_shiba --iters 10000 4 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat" --workspace trial2_shiba --dmtet --iters 5000 --init_with trial_shiba/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a banana peeling itself" --workspace trial_banana --iters 10000 7 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a banana peeling itself" --workspace trial2_banana --dmtet --iters 5000 --init_with trial_banana/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a capybara wearing a top hat, low poly" --workspace trial_capybara --iters 10000 10 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a capybara wearing a top hat, low poly" --workspace trial2_capybara --dmtet --iters 5000 --init_with trial_capybara/checkpoints/df.pth -------------------------------------------------------------------------------- /scripts/run3.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "ironman, full body" --workspace trial_ironman --iters 10000 4 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "ironman, full body" --workspace trial2_ironman --dmtet --iters 5000 --init_with trial_ironman/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of an ice cream sundae" --workspace trial_icecream --iters 10000 7 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of an ice cream sundae" --workspace trial2_icecream --dmtet --iters 5000 --init_with trial_icecream/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of a kingfisher bird" --workspace trial_bird --iters 10000 10 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of a kingfisher bird" --workspace trial2_bird --dmtet --iters 5000 --init_with trial_bird/checkpoints/df.pth 11 | 12 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a car made of sushi" --workspace trial_sushi --iters 10000 13 | CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a car made of sushi" --workspace trial2_sushi --dmtet --iters 5000 --init_with trial_sushi/checkpoints/df.pth 14 | -------------------------------------------------------------------------------- /scripts/run4.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a rabbit, animated movie character, high detail 3d model" --workspace trial_rabbit2 --iters 10000 4 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a rabbit, animated movie character, high detail 3d model" --workspace trial2_rabbit2 --dmtet --iters 5000 --init_with trial_rabbit2/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a corgi dog, highly detailed 3d model" --workspace trial_corgi --iters 10000 7 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a corgi dog, highly detailed 3d model" --workspace trial2_corgi --dmtet --iters 5000 --init_with trial_corgi/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text " a small saguaro cactus planted in a clay pot" --workspace trial_cactus --iters 10000 10 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text " a small saguaro cactus planted in a clay pot" --workspace trial2_cactus --dmtet --iters 5000 --init_with trial_cactus/checkpoints/df.pth 11 | 12 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "the leaning tower of Pisa" --workspace trial_pisa --iters 10000 13 | CUDA_VISIBLE_DEVICES=5 python main.py -O --text "the leaning tower of Pisa" --workspace trial2_pisa --dmtet --iters 5000 --init_with trial_pisa/checkpoints/df.pth -------------------------------------------------------------------------------- /scripts/run5.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Perched blue jay bird" --workspace trial_jay --iters 10000 4 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Perched blue jay bird" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "angel statue wings out" --workspace trial_angle --iters 10000 7 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "angel statue wings out" --workspace trial2_angle --dmtet --iters 5000 --init_with trial_angle/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "devil statue" --workspace trial_devil --iters 10000 10 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "devil statue" --workspace trial2_devil --dmtet --iters 5000 --init_with trial_devil/checkpoints/df.pth 11 | 12 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Einstein statue" --workspace trial_einstein --iters 10000 13 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Einstein statue" --workspace trial2_einstein --dmtet --iters 5000 --init_with trial_einstein/checkpoints/df.pth 14 | -------------------------------------------------------------------------------- /scripts/run6.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_rabbit_pancake --iters 5000 3 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth 4 | 5 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_jay --iters 5000 6 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth 7 | 8 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_fox --iters 5000 9 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial2_fox --dmtet --iters 5000 --init_with trial_fox/checkpoints/df.pth 10 | 11 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_peacock --iters 5000 12 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial2_peacock --dmtet --iters 5000 --init_with trial_peacock/checkpoints/df.pth 13 | 14 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a flower made out of metal" --workspace trial_metal_flower --iters 5000 15 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a flower made out of metal" --workspace trial2_metal_flower --dmtet --iters 5000 --init_with trial_metal_flower/checkpoints/df.pth 16 | 17 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_chicken --iters 5000 18 | CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial2_chicken --dmtet --iters 5000 --init_with trial_chicken/checkpoints/df.pth 
-------------------------------------------------------------------------------- /scripts/run_if.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_if_rabbit_pancake --iters 5000 --IF 3 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial_if2_rabbit_pancake --dmtet --iters 5000 --init_with trial_if_rabbit_pancake/checkpoints/df.pth 4 | 5 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_if_jay --iters 5000 --IF 6 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_if2_jay --dmtet --iters 5000 --init_with trial_if_jay/checkpoints/df.pth 7 | 8 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_if_fox --iters 5000 --IF 9 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_if2_fox --dmtet --iters 5000 --init_with trial_if_fox/checkpoints/df.pth 10 | 11 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_if_peacock --iters 5000 --IF 12 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_if2_peacock --dmtet --iters 5000 --init_with trial_if_peacock/checkpoints/df.pth 13 | 14 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a flower made out of metal" --workspace trial_if_metal_flower --iters 5000 --IF 15 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a flower made out of metal" --workspace trial_if2_metal_flower --dmtet --iters 5000 --init_with trial_if_metal_flower/checkpoints/df.pth 16 | 17 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_if_chicken --iters 5000 --IF 18 | CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_if2_chicken --dmtet --iters 5000 --init_with trial_if_chicken/checkpoints/df.pth -------------------------------------------------------------------------------- /scripts/run_if2.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a corgi taking a selfie" --workspace trial_if_corgi --iters 5000 --IF 3 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a corgi taking a selfie" --workspace trial_if2_corgi --dmtet --iters 5000 --init_with trial_if_corgi/checkpoints/df.pth 4 | 5 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a ghost eating a hamburger" --workspace trial_if_ghost --iters 5000 --IF 6 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a ghost eating a hamburger" --workspace trial_if2_ghost --dmtet --iters 5000 --init_with trial_if_ghost/checkpoints/df.pth 7 | 8 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of an origami motorcycle" --workspace trial_if_motor --iters 5000 --IF 9 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of an origami motorcycle" --workspace trial_if2_motor --dmtet --iters 5000 --init_with trial_if_motor/checkpoints/df.pth 10 | 11 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a Space Shuttle" --workspace trial_if_spaceshuttle --iters 5000 --IF 12 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a Space Shuttle" --workspace trial_if2_spaceshuttle --dmtet --iters 5000 --init_with trial_if_spaceshuttle/checkpoints/df.pth 13 | 14 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a palm tree, low poly 3d model" --workspace trial_if_palm --iters 5000 --IF 15 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a palm tree, low poly 3d model" --workspace trial_if2_palm --dmtet --iters 5000 --init_with trial_if_palm/checkpoints/df.pth 16 | 17 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head" --workspace trial_if_cat_mouse --iters 5000 --IF 18 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head" --workspace trial_if2_cat_mouse --dmtet --iters 5000 --init_with trial_if_cat_mouse/checkpoints/df.pth -------------------------------------------------------------------------------- /scripts/run_if2_perpneg.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # To avoid the Janus problem caused by the diffusion model's front view bias, utilize the Perp-Neg algorithm. To maximize its benefits, 3 | # increase the absolute value of "negative_w" for improved Janus problem mitigation. If you encounter flat faces or divergence, consider 4 | # reducing the absolute value of "negative_w". The value of "negative_w" should vary for each prompt due to the diffusion model's varying 5 | # bias towards generating front views for different objects. Vary the weights within the range of 0 to -4. 
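# A minimal sketch (untested, not one of the original runs; workspace name is illustrative only):
# one way to pick "negative_w" for a new prompt is to sweep a few values and keep the weakest
# weight that still removes the extra front faces, e.g.:
# for W in -1.0 -2.0 -3.0 -4.0; do
#     CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a lion bust" --workspace trial_perpneg_sweep_${W} \
#         --iters 5000 --IF --batch_size 1 --perpneg --negative_w ${W}
# done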
6 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a lion bust" --workspace trial_perpneg_if_lion --iters 5000 --IF --batch_size 1 --perpneg 7 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a marble lion head" --workspace trial_perpneg_if2_lion_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_lion/checkpoints/df.pth 8 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a marble lion head" --workspace trial_perpneg_if2_lion_nop --dmtet --iters 5000 --init_with trial_perpneg_if_lion/checkpoints/df.pth 9 | 10 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a tiger cub" --workspace trial_perpneg_if_tiger --iters 5000 --IF --batch_size 1 --perpneg 11 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "tiger" --workspace trial_perpneg_if2_tiger_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_tiger/checkpoints/df.pth 12 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "tiger" --workspace trial_perpneg_if2_tiger_nop --dmtet --iters 5000 --init_with trial_perpneg_if_tiger/checkpoints/df.pth 13 | 14 | # A larger absolute value of negative_w is used for the following command because the default negative weight of -2 is not enough to make the diffusion model produce the views as desired 15 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a shiba dog wearing sunglasses" --workspace trial_perpneg_if_shiba --iters 5000 --IF --batch_size 1 --perpneg --negative_w -3.0 16 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "shiba wearing sunglasses" --workspace trial_perpneg_if2_shiba_p --dmtet --iters 5000 --perpneg --negative_w -3.0 --init_with trial_perpneg_if_shiba/checkpoints/df.pth 17 | CUDA_VISIBLE_DEVICES=3 python main.py -O --text "shiba wearing sunglasses" --workspace trial_perpneg_if2_shiba_nop --dmtet --iters 5000 --init_with trial_perpneg_if_shiba/checkpoints/df.pth 18 | 19 | -------------------------------------------------------------------------------- /scripts/run_image.sh: -------------------------------------------------------------------------------- 1 | # zero123 backend (single object, images like 3d model rendering) 2 | 3 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial_image_teddy --iters 5000 4 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial2_image_teddy --iters 5000 --dmtet --init_with trial_image_teddy/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial_image_catstatue --iters 5000 7 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial2_image_catstatue --iters 5000 --dmtet --init_with trial_image_catstatue/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial_image_firekeeper --iters 5000 10 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial2_image_firekeeper --iters 5000 --dmtet --init_with trial_image_firekeeper/checkpoints/df.pth 11 | 12 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial_image_hamburger --iters 5000 13 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial2_image_hamburger --iters 5000 --dmtet --init_with trial_image_hamburger/checkpoints/df.pth 14 | 15 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial_image_corgi --iters 5000 16 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace
trial2_image_corgi --iters 5000 --dmtet --init_with trial_image_corgi/checkpoints/df.pth 17 | 18 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial_image_cactus --iters 5000 19 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial2_image_cactus --iters 5000 --dmtet --init_with trial_image_cactus/checkpoints/df.pth 20 | 21 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial_image_cake --iters 5000 22 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial2_image_cake --iters 5000 --dmtet --init_with trial_image_cake/checkpoints/df.pth 23 | 24 | # CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial_image_warrior --iters 5000 25 | # CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial2_image_warrior --iters 5000 --dmtet --init_with trial_image_warrior/checkpoints/df.pth -------------------------------------------------------------------------------- /scripts/run_image_anya.sh: -------------------------------------------------------------------------------- 1 | # Phase 1 - barely fits in A100 40GB. 2 | # Conclusion: results in concave-ish face, no neck, excess hair in the back 3 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage \ 4 | --iters 10000 --save_guidance --save_guidance_interval 10 --ckpt scratch --batch_size 2 --test_interval 2 \ 5 | --h 128 --w 128 --zero123_grad_scale None 6 | 7 | # Phase 2 - barely fits in A100 40GB. 8 | # 20X smaller lambda_3d_normal_smooth, --known_view_interval 2, 3X LR 9 | # Much higher jitter to increase disparity (and eliminate some of the flatness)... not too high either (to avoid cropping the face) 10 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2 \ 11 | --text "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" \ 12 | --iters 12500 --ckpt trial_anya_1_refimage/checkpoints/df_ep0100.pth --save_guidance --save_guidance_interval 1 \ 13 | --h 256 --w 256 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \ 14 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \ 15 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 \ 16 | --exp_start_iter 10000 --exp_end_iter 12500 17 | 18 | # Phase 3 - increase resolution to 512 19 | # Disable textureless since they can cause catastrophic divergence 20 | # Since radius range is inconsistent, increase it, and reduce the jitter to avoid excessively cropped renders. 21 | # Learning rate may be set too high, since `--batch_size 1`. 
22 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2 \ 23 | --text "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" \ 24 | --iters 25000 --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2/checkpoints/df_ep0125.pth --save_guidance --save_guidance_interval 1 \ 25 | --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \ 26 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \ 27 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \ 28 | --exp_start_iter 12500 --exp_end_iter 25000 29 | 30 | # Generate 6 views 31 | CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2/checkpoints/df_ep0250.pth --six_views 32 | 33 | # Phase 4 - untested, need to adjust 34 | # CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage --iters 5000 --dmtet --init_with trial_anya_1_refimage/checkpoints/df.pth 35 | 36 | -------------------------------------------------------------------------------- /scripts/run_image_hard_examples.sh: -------------------------------------------------------------------------------- 1 | bash scripts/run_image_procedure.sh 0 30 90 anya_front "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" 2 | bash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice "A DSLR 3D photo of an adorable baby phoenix made in Swarowski crystal highly detailed intricate concept art 8K ( unreal engine 5 trending on Artstation )" 3 | bash scripts/run_image_procedure.sh 2 30 90 bollywood_actress "A DSLR 3D photo of a beautiful bollywood indian actress, pretty eyes, full body shot composition, sunny outdoor, seen from far away ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" 4 | bash scripts/run_image_procedure.sh 3 30 40 beach_house_1 "A DSLR 3D photo of a very beautiful small house on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" 5 | bash scripts/run_image_procedure.sh 4 30 60 beach_house_2 "A DSLR 3D photo of a very beautiful high-tech small house with solar panels and wildflowers on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" 6 | bash scripts/run_image_procedure.sh 5 30 90 mona_lisa "A DSLR 3D photo of a beautiful young woman dressed like Mona Lisa ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" 7 | bash scripts/run_image_procedure.sh 6 30 80 futuristic_car "A DSLR 3D photo of a crazily futuristic electric car ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" 8 | # the church ruins probably require a wider field of view... e.g. 90 degrees, maybe even more... so may not work with Zero123 etc. 
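# (untested sketch, not one of the original runs) scripts/run_image_procedure.sh pins --fovy_range 20 20 in its first phase,
# so trying a wider field of view for the church ruins would mean editing that script or calling main.py directly,
# along the lines of (workspace name and the 60-90 degree range are illustrative, and the _rgba image is
# assumed to have been produced by preprocess_image.py):
# CUDA_VISIBLE_DEVICES=7 python main.py -O --image data/church_ruins_rgba.png --workspace trial_church_ruins_widefov \
#     --iters 5000 --fovy_range 60 90 --default_polar 90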
9 | bash scripts/run_image_procedure.sh 7 30 90 church_ruins "A DSLR 3D photo of the remains of an isolated old church ruin covered in ivy ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" 10 | 11 | # young woman dressed like mona lisa -------------------------------------------------------------------------------- /scripts/run_image_procedure.sh: -------------------------------------------------------------------------------- 1 | # Perform a 2D-to-3D reconstruction, similar to the Anya case study: https://github.com/ashawkey/stable-dreamfusion/issues/263 2 | # Args: 3 | # bash scripts/run_image_procedure.sh GPU_ID guidance_interval default_polar image_name "prompt" 4 | # e.g.: 5 | # bash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice "An adorable baby phoenix made in Swarowski crystal highly detailed intricated concept art 8K" 6 | GPU_ID=$1 7 | GUIDANCE_INTERVAL=$2 8 | DEFAULT_POLAR=$3 9 | PREFIX=$4 10 | PROMPT=$5 11 | EPOCHS1=100 12 | EPOCHS2=200 13 | EPOCHS3=300 14 | IMAGE=data/$PREFIX.png 15 | IMAGE_RGBA=data/${PREFIX}_rgba.png 16 | WS_PH1=trial_$PREFIX-ph1 17 | WS_PH2=trial_$PREFIX-ph2 18 | WS_PH3=trial_$PREFIX-ph3 19 | CKPT1=$WS_PH1/checkpoints/df_ep0${EPOCHS1}.pth 20 | CKPT2=$WS_PH2/checkpoints/df_ep0${EPOCHS2}.pth 21 | CKPT3=$WS_PH3/checkpoints/df_ep0${EPOCHS3}.pth 22 | 23 | # Can uncomment to clean up trial folders. Be careful - mistakes could erase important work! 24 | # rm -r $WS_PH1 $WS_PH2 $WS_PH3 25 | 26 | # Preprocess 27 | if [ ! -f $IMAGE_RGBA ] 28 | then 29 | python preprocess_image.py $IMAGE 30 | fi 31 | 32 | if [ ! -f $CKPT1 ] 33 | then 34 | # Phase 1 - zero123-guidance 35 | # WARNING: claforte: constantly runs out of VRAM with resolution of 128x128 and batch_size 2... no longer able to reproduce Anya result because of this... 36 | # I added these to try to reduce mem usage, but this might degrade the quality... `--lambda_depth 0 --lambda_3d_normal_smooth 0` 37 | # Remove: --ckpt scratch 38 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH1 --default_polar $DEFAULT_POLAR \ 39 | --iters ${EPOCHS1}00 --save_guidance --save_guidance_interval $GUIDANCE_INTERVAL --batch_size 1 --test_interval 2 \ 40 | --h 96 --w 96 --zero123_grad_scale None --lambda_3d_normal_smooth 0 --dont_override_stuff \ 41 | --fovy_range 20 20 --guidance_scale 5 42 | fi 43 | 44 | GUIDANCE_INTERVAL=7 45 | if [ ! -f $CKPT2 ] 46 | then 47 | # Phase 2 - SD-guidance at 128x128 48 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH2 \ 49 | --text "${PROMPT}" --default_polar $DEFAULT_POLAR \ 50 | --iters ${EPOCHS2}00 --ckpt $CKPT1 --save_guidance --save_guidance_interval 7 \ 51 | --h 128 --w 128 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \ 52 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \ 53 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --progressive_view_init_ratio 0.05 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \ 54 | --exp_start_iter ${EPOCHS1}00 --exp_end_iter ${EPOCHS2}00 55 | fi 56 | 57 | if [ !
-f $CKPT3 ] 58 | then 59 | # # Phase 3 - increase resolution to 512 60 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH3 \ 61 | --text "${PROMPT}" --default_polar $DEFAULT_POLAR \ 62 | --iters ${EPOCHS3}00 --ckpt $CKPT2 --save_guidance --save_guidance_interval 7 \ 63 | --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \ 64 | --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \ 65 | --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \ 66 | --exp_start_iter ${EPOCHS2}00 --exp_end_iter ${EPOCHS3}00 67 | fi 68 | 69 | # Generate 6 views 70 | CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --ckpt $CKPT3 --six_views 71 | 72 | -------------------------------------------------------------------------------- /scripts/run_image_text.sh: -------------------------------------------------------------------------------- 1 | # sd backend (realistic images) 2 | 3 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text "a brown teddy bear sitting on a ground" --workspace trial_imagetext_teddy --iters 5000 4 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text "a brown teddy bear sitting on a ground" --workspace trial2_imagetext_teddy --iters 10000 --dmtet --init_with trial_imagetext_teddy/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text "a corgi running" --workspace trial_imagetext_corgi --iters 5000 7 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text "a corgi running" --workspace trial2_imagetext_corgi --iters 10000 --dmtet --init_with trial_imagetext_corgi/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial_imagetext_hamburger --iters 5000 10 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial2_imagetext_hamburger --iters 10000 --dmtet --init_with trial_imagetext_hamburger/checkpoints/df.pth 11 | 12 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text "a potted cactus plant" --workspace trial_imagetext_cactus --iters 5000 13 | CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text "a potted cactus plant" --workspace trial2_imagetext_cactus --iters 10000 --dmtet --init_with trial_imagetext_cactus/checkpoints/df.pth 14 | -------------------------------------------------------------------------------- /scripts/run_images.sh: -------------------------------------------------------------------------------- 1 | # zero123 backend (single object, images like 3d model rendering) 2 | 3 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial_images_corgi --iters 5000 4 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial2_images_corgi --iters 10000 --dmtet --init_with trial_images_corgi/checkpoints/df.pth 5 | 6 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial_images_car --iters 5000 7 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial2_images_car --iters 10000 --dmtet --init_with 
trial_images_car/checkpoints/df.pth 8 | 9 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial_images_anya --iters 5000 10 | CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial2_images_anya --iters 10000 --dmtet --init_with trial_images_anya/checkpoints/df.pth -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | _backend = load(name='_sh_encoder', 33 | extra_cflags=c_flags, 34 | extra_cuda_cflags=nvcc_flags, 35 | sources=[os.path.join(_src_path, 'src', f) for f in [ 36 | 'shencoder.cu', 37 | 'bindings.cpp', 38 | ]], 39 | ) 40 | 41 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 
27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | setup( 34 | name='shencoder', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='_shencoder', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | 'shencoder.cu', 40 | 'bindings.cpp', 41 | ]], 42 | extra_compile_args={ 43 | 'cxx': c_flags, 44 | 'nvcc': nvcc_flags, 45 | } 46 | ), 47 | ], 48 | cmdclass={ 49 | 'build_ext': BuildExtension, 50 | } 51 | ) -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = None 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | ctx.dims = [B, input_dim, degree] 36 | 37 | return outputs 38 | 39 | @staticmethod 40 | #@once_differentiable 41 | @custom_bwd 42 | def backward(ctx, grad): 43 | # grad: [B, C * C] 44 | 45 | inputs, dy_dx = ctx.saved_tensors 46 | 47 | if dy_dx is not None: 48 | grad = grad.contiguous() 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 0 ~ 4 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only support input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 
| 87 | return outputs -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx); 10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /taichi_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .ray_march import RayMarcherTaichi, raymarching_test 2 | from .volume_train import VolumeRendererTaichi 3 | from .intersection import RayAABBIntersector 4 | from .volume_render_test import composite_test 5 | from .utils import packbits -------------------------------------------------------------------------------- /taichi_modules/intersection.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import torch 3 | from taichi.math import vec3 4 | from torch.cuda.amp import custom_fwd 5 | 6 | from .utils import NEAR_DISTANCE 7 | 8 | 9 | @ti.kernel 10 | def simple_ray_aabb_intersec_taichi_forward( 11 | hits_t: ti.types.ndarray(ndim=2), 12 | rays_o: ti.types.ndarray(ndim=2), 13 | rays_d: ti.types.ndarray(ndim=2), 14 | centers: ti.types.ndarray(ndim=2), 15 | half_sizes: ti.types.ndarray(ndim=2)): 16 | 17 | for r in ti.ndrange(hits_t.shape[0]): 18 | ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]]) 19 | ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]]) 20 | inv_d = 1.0 / ray_d 21 | 22 | center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]]) 23 | half_size = vec3( 24 | [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 1]]) 25 | 26 | t_min = (center - half_size - ray_o) * inv_d 27 | t_max = (center + half_size - ray_o) * inv_d 28 | 29 | _t1 = ti.min(t_min, t_max) 30 | _t2 = ti.max(t_min, t_max) 31 | t1 = _t1.max() 32 | t2 = _t2.min() 33 | 34 | if t2 > 0.0: 35 | hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE) 36 | hits_t[r, 0, 1] = t2 37 | 38 | 39 | class RayAABBIntersector(torch.autograd.Function): 40 | """ 41 | Computes the intersections of rays and axis-aligned voxels. 
42 | 43 | Inputs: 44 | rays_o: (N_rays, 3) ray origins 45 | rays_d: (N_rays, 3) ray directions 46 | centers: (N_voxels, 3) voxel centers 47 | half_sizes: (N_voxels, 3) voxel half sizes 48 | max_hits: maximum number of intersected voxels to keep for one ray 49 | (for a cubic scene, this is at most 3*N_voxels^(1/3)-2) 50 | 51 | Outputs: 52 | hits_cnt: (N_rays) number of hits for each ray 53 | (followings are from near to far) 54 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 55 | hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit) 56 | """ 57 | 58 | @staticmethod 59 | @custom_fwd(cast_inputs=torch.float32) 60 | def forward(ctx, rays_o, rays_d, center, half_size, max_hits): 61 | hits_t = (torch.zeros( 62 | rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) - 63 | 1).contiguous() 64 | 65 | simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center, 66 | half_size) 67 | 68 | return None, hits_t, None 69 | -------------------------------------------------------------------------------- /taichi_modules/utils.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import torch 3 | from taichi.math import uvec3 4 | 5 | taichi_block_size = 128 6 | 7 | data_type = ti.f32 8 | torch_type = torch.float32 9 | 10 | MAX_SAMPLES = 1024 11 | NEAR_DISTANCE = 0.01 12 | SQRT3 = 1.7320508075688772 13 | SQRT3_MAX_SAMPLES = SQRT3 / 1024 14 | SQRT3_2 = 1.7320508075688772 * 2 15 | 16 | 17 | @ti.func 18 | def scalbn(x, exponent): 19 | return x * ti.math.pow(2, exponent) 20 | 21 | 22 | @ti.func 23 | def calc_dt(t, exp_step_factor, grid_size, scale): 24 | return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES, 25 | SQRT3_2 * scale / grid_size) 26 | 27 | 28 | @ti.func 29 | def frexp_bit(x): 30 | exponent = 0 31 | if x != 0.0: 32 | # frac = ti.abs(x) 33 | bits = ti.bit_cast(x, ti.u32) 34 | exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127 35 | # exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127 36 | bits &= ti.u32(0x7fffff) 37 | bits |= ti.u32(0x3f800000) 38 | frac = ti.bit_cast(bits, ti.f32) 39 | if frac < 0.5: 40 | exponent -= 1 41 | elif frac > 1.0: 42 | exponent += 1 43 | return exponent 44 | 45 | 46 | @ti.func 47 | def mip_from_pos(xyz, cascades): 48 | mx = ti.abs(xyz).max() 49 | # _, exponent = _frexp(mx) 50 | exponent = frexp_bit(ti.f32(mx)) + 1 51 | # frac, exponent = ti.frexp(ti.f32(mx)) 52 | return ti.min(cascades - 1, ti.max(0, exponent)) 53 | 54 | 55 | @ti.func 56 | def mip_from_dt(dt, grid_size, cascades): 57 | # _, exponent = _frexp(dt*grid_size) 58 | exponent = frexp_bit(ti.f32(dt * grid_size)) 59 | # frac, exponent = ti.frexp(ti.f32(dt*grid_size)) 60 | return ti.min(cascades - 1, ti.max(0, exponent)) 61 | 62 | 63 | @ti.func 64 | def __expand_bits(v): 65 | v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF) 66 | v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F) 67 | v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3) 68 | v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249) 69 | return v 70 | 71 | 72 | @ti.func 73 | def __morton3D(xyz): 74 | xyz = __expand_bits(xyz) 75 | return xyz[0] | (xyz[1] << 1) | (xyz[2] << 2) 76 | 77 | 78 | @ti.func 79 | def __morton3D_invert(x): 80 | x = x & (0x49249249) 81 | x = (x | (x >> 2)) & ti.uint32(0xc30c30c3) 82 | x = (x | (x >> 4)) & ti.uint32(0x0f00f00f) 83 | x = (x | (x >> 8)) & ti.uint32(0xff0000ff) 84 | x = (x | (x >> 16)) & ti.uint32(0x0000ffff) 85 | return ti.int32(x) 86 | 87 | 88 | @ti.kernel 89 | def 
morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1), 90 | coords: ti.types.ndarray(ndim=2)): 91 | for i in indices: 92 | ind = ti.uint32(indices[i]) 93 | coords[i, 0] = __morton3D_invert(ind >> 0) 94 | coords[i, 1] = __morton3D_invert(ind >> 1) 95 | coords[i, 2] = __morton3D_invert(ind >> 2) 96 | 97 | 98 | def morton3D_invert(indices): 99 | coords = torch.zeros(indices.size(0), 100 | 3, 101 | device=indices.device, 102 | dtype=torch.int32) 103 | morton3D_invert_kernel(indices.contiguous(), coords) 104 | ti.sync() 105 | return coords 106 | 107 | 108 | @ti.kernel 109 | def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2), 110 | indices: ti.types.ndarray(ndim=1)): 111 | for s in indices: 112 | xyz = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]]) 113 | indices[s] = ti.cast(__morton3D(xyz), ti.int32) 114 | 115 | 116 | def morton3D(coords1): 117 | indices = torch.zeros(coords1.size(0), 118 | device=coords1.device, 119 | dtype=torch.int32) 120 | morton3D_kernel(coords1.contiguous(), indices) 121 | ti.sync() 122 | return indices 123 | 124 | 125 | @ti.kernel 126 | def packbits(density_grid: ti.types.ndarray(ndim=1), 127 | density_threshold: float, 128 | density_bitfield: ti.types.ndarray(ndim=1)): 129 | 130 | for n in density_bitfield: 131 | bits = ti.uint8(0) 132 | 133 | for i in ti.static(range(8)): 134 | bits |= (ti.uint8(1) << i) if ( 135 | density_grid[8 * n + i] > density_threshold) else ti.uint8(0) 136 | 137 | density_bitfield[n] = bits 138 | 139 | 140 | @ti.kernel 141 | def torch2ti(field: ti.template(), data: ti.types.ndarray()): 142 | for I in ti.grouped(data): 143 | field[I] = data[I] 144 | 145 | 146 | @ti.kernel 147 | def ti2torch(field: ti.template(), data: ti.types.ndarray()): 148 | for I in ti.grouped(data): 149 | data[I] = field[I] 150 | 151 | 152 | @ti.kernel 153 | def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()): 154 | for I in ti.grouped(grad): 155 | grad[I] = field.grad[I] 156 | 157 | 158 | @ti.kernel 159 | def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()): 160 | for I in ti.grouped(grad): 161 | field.grad[I] = grad[I] 162 | 163 | 164 | @ti.kernel 165 | def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()): 166 | for I in range(data.shape[0] // 2): 167 | field[I] = ti.Vector([data[I * 2], data[I * 2 + 1]]) 168 | 169 | 170 | @ti.kernel 171 | def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()): 172 | for i, j in ti.ndrange(data.shape[0], data.shape[1] // 2): 173 | data[i, j * 2] = field[i, j][0] 174 | data[i, j * 2 + 1] = field[i, j][1] 175 | 176 | 177 | @ti.kernel 178 | def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()): 179 | for I in range(grad.shape[0] // 2): 180 | grad[I * 2] = field.grad[I][0] 181 | grad[I * 2 + 1] = field.grad[I][1] 182 | 183 | 184 | @ti.kernel 185 | def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()): 186 | for i, j in ti.ndrange(grad.shape[0], grad.shape[1] // 2): 187 | field.grad[i, j][0] = grad[i, j * 2] 188 | field.grad[i, j][1] = grad[i, j * 2 + 1] 189 | 190 | 191 | def extract_model_state_dict(ckpt_path, 192 | model_name='model', 193 | prefixes_to_ignore=[]): 194 | checkpoint = torch.load(ckpt_path, map_location='cpu') 195 | checkpoint_ = {} 196 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 197 | checkpoint = checkpoint['state_dict'] 198 | for k, v in checkpoint.items(): 199 | if not k.startswith(model_name): 200 | continue 201 | k = k[len(model_name) + 1:] 202 | for prefix in prefixes_to_ignore: 203 | if k.startswith(prefix): 204 
| break 205 | else: 206 | checkpoint_[k] = v 207 | return checkpoint_ 208 | 209 | 210 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 211 | if not ckpt_path: 212 | return 213 | model_dict = model.state_dict() 214 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, 215 | prefixes_to_ignore) 216 | model_dict.update(checkpoint_) 217 | model.load_state_dict(model_dict) 218 | 219 | def depth2img(depth): 220 | depth = (depth - depth.min()) / (depth.max() - depth.min()) 221 | depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8), 222 | cv2.COLORMAP_TURBO) 223 | 224 | return depth_img -------------------------------------------------------------------------------- /taichi_modules/volume_render_test.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | 3 | 4 | @ti.kernel 5 | def composite_test( 6 | sigmas: ti.types.ndarray(ndim=2), rgbs: ti.types.ndarray(ndim=3), 7 | deltas: ti.types.ndarray(ndim=2), ts: ti.types.ndarray(ndim=2), 8 | hits_t: ti.types.ndarray(ndim=2), 9 | alive_indices: ti.types.ndarray(ndim=1), T_threshold: float, 10 | N_eff_samples: ti.types.ndarray(ndim=1), 11 | opacity: ti.types.ndarray(ndim=1), 12 | depth: ti.types.ndarray(ndim=1), rgb: ti.types.ndarray(ndim=2)): 13 | 14 | for n in alive_indices: 15 | samples = N_eff_samples[n] 16 | if samples == 0: 17 | alive_indices[n] = -1 18 | else: 19 | r = alive_indices[n] 20 | 21 | T = 1 - opacity[r] 22 | 23 | rgb_temp_0 = 0.0 24 | rgb_temp_1 = 0.0 25 | rgb_temp_2 = 0.0 26 | depth_temp = 0.0 27 | opacity_temp = 0.0 28 | 29 | for s in range(samples): 30 | a = 1.0 - ti.exp(-sigmas[n, s] * deltas[n, s]) 31 | w = a * T 32 | 33 | rgb_temp_0 += w * rgbs[n, s, 0] 34 | rgb_temp_1 += w * rgbs[n, s, 1] 35 | rgb_temp_2 += w * rgbs[n, s, 2] 36 | depth[r] += w * ts[n, s] 37 | opacity[r] += w 38 | T *= 1.0 - a 39 | 40 | if T <= T_threshold: 41 | alive_indices[n] = -1 42 | break 43 | 44 | rgb[r, 0] += rgb_temp_0 45 | rgb[r, 1] += rgb_temp_1 46 | rgb[r, 2] += rgb_temp_2 47 | depth[r] += depth_temp 48 | opacity[r] += opacity_temp 49 | -------------------------------------------------------------------------------- /tets/128_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/tets/128_tets.npz -------------------------------------------------------------------------------- /tets/32_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/tets/32_tets.npz -------------------------------------------------------------------------------- /tets/64_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/stable-dreamfusion/5550b91862a3af7842bb04875b7f1211e5095a63/tets/64_tets.npz -------------------------------------------------------------------------------- /tets/README.md: -------------------------------------------------------------------------------- 1 | Place the tet grid files in this folder. 2 | We provide a few example grids. See the main README.md for a download link. 3 | 4 | You can also generate your own grids using https://github.com/crawforddoran/quartet 5 | Please see the `generate_tets.py` script for an example. 
6 | 7 | -------------------------------------------------------------------------------- /tets/generate_tets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | import numpy as np 12 | 13 | 14 | ''' 15 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, 16 | to generate a tet grid 17 | 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet` 18 | 2) Run the function below to generate a file `cube_32_tet.tet` 19 | ''' 20 | 21 | def generate_tetrahedron_grid_file(res=32, root='..'): 22 | frac = 1.0 / res 23 | command = 'cd %s/quartet; ' % (root) + \ 24 | './quartet_release meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res) 25 | os.system(command) 26 | 27 | 28 | ''' 29 | This code segment shows how to convert from a quartet .tet file to compressed npz file 30 | ''' 31 | def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets.npz'): 32 | 33 | file1 = open(quartetfile, 'r') 34 | header = file1.readline() 35 | numvertices = int(header.split(" ")[1]) 36 | numtets = int(header.split(" ")[2]) 37 | print(numvertices, numtets) 38 | 39 | # load vertices 40 | vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices) 41 | vertices = vertices - 0.5 42 | print(vertices.shape, vertices.min(), vertices.max()) 43 | 44 | # load indices 45 | indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets) 46 | print(indices.shape) 47 | 48 | np.savez_compressed(npzfile, vertices=vertices, indices=indices) 49 | 50 | if __name__ == '__main__': 51 | import argparse 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('--res', type=int, default=32) 54 | parser.add_argument('--root', type=str, default='..') 55 | args = parser.parse_args() 56 | 57 | generate_tetrahedron_grid_file(res=args.res, root=args.root) 58 | convert_from_quartet_to_npz(quartetfile=os.path.join(args.root, 'quartet', 'meshes', f'cube_{args.res}.000000_tet.tet'), npzfile=os.path.join('./tets', f'{args.res}_tets.npz')) --------------------------------------------------------------------------------
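A minimal usage sketch for generate_tets.py above (untested; it assumes Quartet has been cloned and built in a sibling directory `../quartet` with its example `meshes/cube.obj` available, and that the command is run from the repository root):
# Hypothetical invocation: generate a 64-resolution tet grid with Quartet and convert it to .npz.
python tets/generate_tets.py --res 64 --root ..
# On success the compressed grid is written to tets/64_tets.npz (arrays: 'vertices', 'indices').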