├── LICENSE
├── README.md
├── callbacks
│   ├── callback.py
│   ├── callback_utils.py
│   ├── phase.py
│   ├── state_callback.py
│   ├── tensorboard_callback.py
│   ├── viewer_callback.py
│   └── wandb_callback.py
├── checkpoints
│   └── HOW_TO_DOWNLOAD.txt
├── configs
│   ├── config_scalp_texture_conditional.json
│   └── strand_vae_train.toml
├── data_loader
│   ├── dataloader.py
│   ├── difflocks_bodydata
│   │   ├── scalp.ply
│   │   └── smplx_base.ply
│   └── mesh_utils.py
├── data_processing
│   ├── create_chunked_strands.py
│   ├── create_latents.py
│   ├── create_scalp_textures.py
│   └── uncompress_data.py
├── dockerfile
│   ├── Dockerfile
│   ├── build.sh
│   ├── nvidia-container-runtime-script.sh
│   ├── nvidia-container-runtime-script_ubuntu24.sh
│   └── run.sh
├── download_checkpoints.sh
├── download_dataset.sh
├── download_validation.sh
├── extensions
│   ├── cuda
│   │   └── push_pull_inpaint.cu
│   ├── include
│   │   └── cuMat
│   │       ├── CMakeLists.txt
│   │       ├── Core
│   │       ├── Dense
│   │       ├── IterativeLinearSolvers
│   │       ├── Sparse
│   │       └── src
│   │           ├── Allocator.h
│   │           ├── BinaryOps.h
│   │           ├── BinaryOpsPlugin.inl
│   │           ├── CholeskyDecomposition.h
│   │           ├── ConjugateGradient.h
│   │           ├── Constants.h
│   │           ├── Context.h
│   │           ├── CublasApi.h
│   │           ├── CudaUtils.h
│   │           ├── CusolverApi.h
│   │           ├── CwiseOp.h
│   │           ├── DecompositionBase.h
│   │           ├── DenseLinAlgOps.h
│   │           ├── DenseLinAlgPlugin.inl
│   │           ├── DevicePointer.h
│   │           ├── DisableCompilerWarnings.h
│   │           ├── EigenInteropHelpers.h
│   │           ├── Errors.h
│   │           ├── ForwardDeclarations.h
│   │           ├── IO.h
│   │           ├── IterativeSolverBase.h
│   │           ├── Iterator.h
│   │           ├── LUDecomposition.h
│   │           ├── Logging.h
│   │           ├── Macros.h
│   │           ├── Matrix.h
│   │           ├── MatrixBase.h
│   │           ├── MatrixBlock.h
│   │           ├── MatrixBlockPluginLvalue.inl
│   │           ├── MatrixBlockPluginRvalue.inl
│   │           ├── MatrixNullaryOpsPlugin.inl
│   │           ├── NullaryOps.h
│   │           ├── NumTraits.h
│   │           ├── ProductOp.h
│   │           ├── Profiling.h
│   │           ├── ReductionAlgorithmSelection.h
│   │           ├── ReductionOps.h
│   │           ├── ReductionOpsPlugin.inl
│   │           ├── SimpleRandom.h
│   │           ├── SolverBase.h
│   │           ├── SparseEvaluation.h
│   │           ├── SparseExpressionOp.h
│   │           ├── SparseExpressionOpPlugin.inl
│   │           ├── SparseMatrix.h
│   │           ├── SparseMatrixBase.h
│   │           ├── SparseProductEvaluation.h
│   │           ├── TransposeOp.h
│   │           ├── UnaryOps.h
│   │           └── UnaryOpsPlugin.inl
│   └── push_pull_inpaint.cpp
├── imgs
│   ├── dataset.png
│   └── teaser.png
├── inference
│   ├── assets
│   │   ├── blender_vis_base_v26_with_shrinkwrap_full_base.blend
│   │   └── face_landmarker.task
│   ├── img2hair.py
│   └── npz2blender.py
├── inference_difflocks.py
├── k_diffusion
│   ├── __init__.py
│   ├── config.py
│   ├── layers.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── axial_rope.py
│   │   ├── flags.py
│   │   ├── flops.py
│   │   ├── image_transformer_v2_conditional.py
│   │   └── modules.py
│   ├── sampling.py
│   └── utils.py
├── losses
│   ├── loss.py
│   ├── loss_utils.py
│   └── losses.py
├── models
│   ├── rgb_to_material.py
│   └── strand_codec.py
├── modules
│   ├── __init__.py
│   ├── edm2_modules.py
│   └── networks.py
├── requirements.txt
├── samples
│   ├── buzzcut_3.jpg
│   ├── cooper_4.jpg
│   ├── fiennes_7.jpg
│   ├── freeman_2.png
│   ├── harris_5.jpeg
│   ├── hathaway_1.jpg
│   ├── medium_11.png
│   └── uggams_3.png
├── schedulers
│   ├── linearlr.py
│   ├── multisteplr.py
│   ├── pytorch_warmup
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── radam.py
│   │   └── untuned.py
│   └── warmup.py
├── train_rgb2material.py
├── train_scalp_diffusion.py
├── train_strandsVAE.py
└── utils
    ├── create_strand_latent_weights.py
    ├── diffusion_utils.py
    ├── general_util.py
    ├── resize_right
    │   ├── interp_methods.py
    │   └── resize_right.py
    ├── strand_util.py
    └── vis_util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | 1. DiffLocks License (Non-Commercial Scientific Research Use Only)
2 |
3 |
4 | Software Copyright License for non-commercial scientific research purposes
5 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the DiffLocks model, data and software, (the "Model & Software"). By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
6 |
7 | Ownership / Licensees
8 | The Software and the associated materials has been developed at the
9 |
10 | Meshcapade.
11 |
12 | Any copyright or patent right is owned by and proprietary material of the
13 |
14 | Meshcapade
15 |
16 | hereinafter the “Licensor”.
17 |
18 | License Grant
19 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
20 |
21 | To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
22 | To use the Model & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
23 | Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Meshcapade’s prior written permission.
24 |
25 | The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
26 |
27 | No Distribution
28 | The Model & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
29 |
30 | Disclaimer of Representations and Warranties
31 | You expressly acknowledge and agree that the Model & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model & Software, (ii) that the use of the Model & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model & Software will not cause any damage of any kind to you or a third party.
32 |
33 | Limitation of Liability
34 | Because this Model & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
35 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
36 | Patent claims generated through the usage of the Model & Software cannot be directed towards the copyright holders.
37 | The Model & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Model & Software and is not responsible for any problems such modifications cause.
38 |
39 | No Maintenance Services
40 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Model & Software at any time.
41 |
42 | Defects of the Model & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
43 |
44 | Publications using the Model & Software
45 | You acknowledge that the Model & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Model & Software.
46 |
47 | Citation:
48 |
49 |
50 | @inproceedings{difflocks2025,
51 | title = {DiffLocks: Generating 3D Hair from a Single Image using Diffusion Models},
52 | author = {Rosu, Radu Alexandru and Wu, Keyu and Feng, Yao and Zheng, Youyi and Black, Michael J.},
53 | booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
54 | year = {2025}
55 | }
56 | Commercial licensing opportunities
57 | For commercial uses of the Software, please send email to sales@meshcapade.com
58 |
59 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
60 |
61 |
62 |
63 | 2. Third-Party Components
64 |
65 | Portions of this code are derived from [k-diffusion](https://github.com/crowsonkb/k-diffusion), which is licensed under the MIT License.
66 | The original copyright and license are included below:
67 |
68 | Copyright (c) 2022 Katherine Crowson
69 |
70 | Permission is hereby granted, free of charge, to any person obtaining a copy
71 | of this software and associated documentation files (the "Software"), to deal
72 | in the Software without restriction, including without limitation the rights
73 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
74 | copies of the Software, and to permit persons to whom the Software is
75 | furnished to do so, subject to the following conditions:
76 |
77 | The above copyright notice and this permission notice shall be included in
78 | all copies or substantial portions of the Software.
79 |
80 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
81 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
82 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
83 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
84 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
85 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
86 | THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffLocks: Generating 3D Hair from a Single Image using Diffusion Models #
2 |
3 | [**Paper**](https://arxiv.org/abs/2505.06166) | [**Project Page**](https://radualexandru.github.io/difflocks/)
4 |
5 |
6 |
7 |
8 |
9 | This repository contains the official inference and training code for DiffLocks, which creates realistic strand-based hairstyles from a single image. It also contains the DiffLocks dataset, consisting of 40K synthetic 3D strand-based hairstyles generated in Blender.
10 |
11 | ## Requirements
12 |
13 | DiffLocks dependencies are listed in the provided `requirements.txt` and can be installed in a virtual environment:
14 |
15 | $ python3 -m venv ./difflocks_env
16 | $ source ./difflocks_env/bin/activate
17 | $ pip install -r ./requirements.txt
18 |
19 | Afterwards we need to install custom CUDA kernels for the diffusion model:
20 | * [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy.
21 | * [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention.
22 | Please double-check that you install NATTEN built for torch 2.5.0 (as per requirements.txt); see the example below.
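
For example, installing the two kernels might look like the following (a rough sketch, not the project's official instructions; the NATTEN version and CUDA tag are placeholders, so pick the wheel matching your torch 2.5.0 / CUDA setup from the NATTEN install page):

    $ pip install natten==<version>+torch250<cuda_tag> -f https://shi-labs.com/natten/wheels/
    $ pip install flash-attn --no-build-isolation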
23 |
24 | Finally if you want to perform inference, you need to download the checkpoints for the trained models.
25 | The pretrained checkpoints can be downloaded by following [this section](#download-pretrained-checkpoints):
26 |
27 |
28 |
29 | ## Dataset
30 | 
31 | The DiffLocks dataset consists of 40K hairstyles. Each sample includes the 3D hair (~100K strands), a corresponding rendered RGB image, and metadata about the hair.
32 |
33 | DiffLocks can be downloaded using:
34 |
35 | ./download_dataset.sh <download_path>
36 |
37 | After downloading, the dataset has to first be uncompressed:
38 |
39 | $ ./data_processing/uncompress_data.py --dataset_zipped_path <path_to_zipped_dataset> --out_path <dataset_path>
40 |
41 | After uncompressing, we create the processed dataset:
42 |
43 | $ ./data_processing/create_chunked_strands.py --dataset_path <dataset_path>
44 | $ ./data_processing/create_latents.py --dataset_path=<dataset_path> --out_path <processed_dataset_path>
45 | $ ./data_processing/create_scalp_textures.py --dataset_path=<dataset_path> --out_path <processed_dataset_path> --path_strand_vae_model ./checkpoints/strand_vae/strand_codec.pt
46 |
47 |
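For reference, an end-to-end processing run could look like this (the `/data/...` locations are hypothetical placeholders; in this example both create_latents and create_scalp_textures write into the same processed-dataset folder):

    $ ./data_processing/uncompress_data.py --dataset_zipped_path /data/difflocks_zipped --out_path /data/difflocks
    $ ./data_processing/create_chunked_strands.py --dataset_path /data/difflocks
    $ ./data_processing/create_latents.py --dataset_path=/data/difflocks --out_path /data/difflocks_processed
    $ ./data_processing/create_scalp_textures.py --dataset_path=/data/difflocks --out_path /data/difflocks_processed --path_strand_vae_model ./checkpoints/strand_vae/strand_codec.pt
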
48 | ## Download pretrained checkpoints
49 | You can download pretrained checkpoints by running:
50 |
51 | ./download_checkpoints.sh
52 |
53 | ## Inference
54 | To run inference on an RGB image and create 3D strands, use:
55 |
56 | $ ./inference_difflocks.py \
57 | --img_path=./samples/medium_11.png \
58 | --out_path=./outputs_inference/
59 |
60 | You can also export a `.blend` file and an Alembic file by specifying `--blender_path` and `--export-alembic` in the above script; see the example below.
61 | Note that the Blender path must point to a Blender executable of version 4.1.1. It will likely not work with other versions.
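
For example, a full invocation with both export options might look like this (a sketch; the Blender path is a placeholder for your local 4.1.1 executable):

    $ ./inference_difflocks.py \
        --img_path=./samples/medium_11.png \
        --out_path=./outputs_inference/ \
        --blender_path=<path_to_blender_4.1.1_executable> \
        --export-alembic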
62 |
63 |
64 | ## Train StrandVAE
65 | To train the strandVAE model:
66 |
67 | $ ./train_strandsVAE.py --dataset_path=<dataset_path> --exp_info=<experiment_name>
68 |
69 | It will start training and write TensorBoard logs to `./tensorboard_logs`.
70 |
71 |
72 | ## Train DiffLocks diffusion model
73 | To train the diffusion model:
74 |
75 | $ ./train_scalp_diffusion.py \
76 | --config ./configs/config_scalp_texture_conditional.json \
77 | --batch-size 4 \
78 | --grad-accum-steps 4 \
79 | --mixed-precision bf16 \
80 | --use-tensorboard \
81 | --save-checkpoints \
82 | --save-every 100000 \
83 | --compile \
84 | --dataset_path=<dataset_path> \
85 | --dataset_processed_path=<processed_dataset_path> \
86 | --name <experiment_name>
87 |
88 | It will start training and write TensorBoard logs to `./tensorboard_logs`.
89 | Start training on multiple GPUs by first running:
90 |
91 | $ accelerate config
92 |
93 | followed by prepending `accelerate launch` to the previous training command:
94 |
95 | $ accelerate launch ./train_scalp_diffusion.py \
96 | --config ./configs/config_scalp_texture_conditional.json \
97 | --batch-size 4 \
98 |
99 |
100 | You will probably need to adjust `batch-size` and `grad-accum-steps` depending on the number of GPUs you have.
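As a rule of thumb, the effective batch size is `batch-size` × `grad-accum-steps` × the number of GPUs; for example, `--batch-size 4 --grad-accum-steps 4` on 2 GPUs gives an effective batch size of 32, so if you double the GPU count you can halve one of the two flags to keep it constant (assuming you want to preserve the original training setup).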
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
/callbacks/callback.py:
--------------------------------------------------------------------------------
1 | #https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynb
2 |
3 | import re
4 | import torch
5 |
6 | def to_snake_case(string):
7 | """Converts CamelCase string into snake_case."""
8 |
9 | s = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string)
10 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s).lower()
11 |
12 | def classname(obj):
13 | return obj.__class__.__name__
14 |
15 |
16 | class Callback:
17 | """
18 | The base class inherited by callbacks.
19 |
20 | Provides a lot of hooks invoked on various stages of the training loop
21 | execution. The signature of functions is as broad as possible to allow
22 | flexibility and customization in descendant classes.
23 | """
24 | def training_started(self, **kwargs): pass
25 |
26 | def training_ended(self, **kwargs): pass
27 |
28 | def epoch_started(self, **kwargs): pass
29 |
30 | def phase_started(self, **kwargs): pass
31 |
32 | def phase_ended(self, **kwargs): pass
33 |
34 | def epoch_ended(self, **kwargs): pass
35 |
36 | def batch_started(self, **kwargs): pass
37 |
38 | def batch_ended(self, **kwargs): pass
39 |
40 | def before_forward_pass(self, **kwargs): pass
41 |
42 | def after_forward_pass(self, **kwargs): pass
43 |
44 | def before_backward_pass(self, **kwargs): pass
45 |
46 | def after_backward_pass(self, **kwargs): pass
47 |
48 |
49 | class CallbacksGroup(Callback):
50 | """
51 | Groups together several callbacks and delegates training loop
52 | notifications to the encapsulated objects.
53 | """
54 | def __init__(self, callbacks):
55 | self.callbacks = callbacks
56 | self.named_callbacks = {to_snake_case(classname(cb)): cb for cb in self.callbacks}
57 |
58 | def __getitem__(self, item):
59 | item = to_snake_case(item)
60 | if item in self.named_callbacks:
61 | return self.named_callbacks[item]
62 | raise KeyError(f'callback name is not found: {item}')
63 |
64 | def training_started(self, **kwargs): self.invoke('training_started', **kwargs)
65 |
66 | def training_ended(self, **kwargs): self.invoke('training_ended', **kwargs)
67 |
68 | def epoch_started(self, **kwargs): self.invoke('epoch_started', **kwargs)
69 |
70 | def phase_started(self, **kwargs): self.invoke('phase_started', **kwargs)
71 |
72 | def phase_ended(self, **kwargs): self.invoke('phase_ended', **kwargs)
73 |
74 | def epoch_ended(self, **kwargs): self.invoke('epoch_ended', **kwargs)
75 |
76 | def batch_started(self, **kwargs): self.invoke('batch_started', **kwargs)
77 |
78 | def batch_ended(self, **kwargs): self.invoke('batch_ended', **kwargs)
79 |
80 | def before_forward_pass(self, **kwargs): self.invoke('before_forward_pass', **kwargs)
81 |
82 | def after_forward_pass(self, **kwargs): self.invoke('after_forward_pass', **kwargs)
83 |
84 | def before_backward_pass(self, **kwargs): self.invoke('before_backward_pass', **kwargs)
85 |
86 | def after_backward_pass(self, **kwargs): self.invoke('after_backward_pass', **kwargs)
87 |
88 | def invoke(self, method, **kwargs):
89 | with torch.set_grad_enabled(False):
90 | for cb in self.callbacks:
91 | getattr(cb, method)(**kwargs)
--------------------------------------------------------------------------------
/callbacks/callback_utils.py:
--------------------------------------------------------------------------------
1 | # from permuto_sdf import TrainParams
2 | from callbacks.callback import *
3 | from callbacks.wandb_callback import *
4 | from callbacks.state_callback import *
5 | from callbacks.phase import *
6 |
7 |
8 | def create_callbacks(with_tensorboard, with_visualizer, experiment_name, viewer_config_path=None):
9 | cb_list = []
10 | if(with_tensorboard):
11 | from callbacks.tensorboard_callback import TensorboardCallback #we put it here in case we don't have tensorboard installed
12 | tensorboard_callback=TensorboardCallback(experiment_name)
13 | cb_list.append(tensorboard_callback)
14 | if(with_visualizer):
15 | from callbacks.viewer_callback import ViewerCallback #we put it here because we might not have the visualizer package installed
16 | viewer_callback=ViewerCallback(viewer_config_path, experiment_name)
17 | cb_list.append(viewer_callback)
18 | cb_list.append(StateCallback())
19 | cb = CallbacksGroup(cb_list)
20 |
21 | return cb
22 |
--------------------------------------------------------------------------------
/callbacks/phase.py:
--------------------------------------------------------------------------------
1 | #https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynb
2 |
3 | # from permuto_sdf_py.callbacks.scores import *
4 |
5 | class Phase:
6 | """
7 | Model training loop phase.
8 |
9 | Each model's training loop iteration could be separated into (at least) two
10 | phases: training and validation. The instances of this class track
11 | metrics and counters, related to the specific phase, and keep the reference
12 | to subset of data, used during phase.
13 | """
14 |
15 | def __init__(self, name, loader, grad):
16 | self.name = name
17 | self.loader = loader
18 | self.grad = grad
19 | self.iter_nr = 0
20 | self.epoch_nr = 0
21 | self.samples_processed_this_epoch = 0
22 | # self.scores= Scores()
23 | self.loss_acum_per_epoch=0.0
24 | self.loss_pos_acum_per_epoch=0.0
25 | self.loss_dir_acum_per_epoch=0.0
26 | self.loss_curv_acum_per_epoch=0.0 #also accumulated by StateCallback, so initialize it here like the other losses
27 | 
--------------------------------------------------------------------------------
/callbacks/state_callback.py:
--------------------------------------------------------------------------------
1 | from callbacks.callback import *
2 | import os
3 | import torch
4 |
5 |
6 | class StateCallback(Callback):
7 |
8 | def __init__(self):
9 | pass
10 |
11 | def after_forward_pass(self, phase, loss, loss_pos, loss_dir, loss_curv, **kwargs):
12 | phase.iter_nr+=1
13 | phase.samples_processed_this_epoch+=1
14 | phase.loss_acum_per_epoch+=loss
15 | phase.loss_pos_acum_per_epoch+=loss_pos
16 | phase.loss_dir_acum_per_epoch+=loss_dir
17 | phase.loss_curv_acum_per_epoch+=loss_curv
18 |
19 |
20 | def epoch_started(self, phase, **kwargs):
21 | phase.loss_acum_per_epoch=0.0
22 | phase.loss_pos_acum_per_epoch=0.0
23 | phase.loss_dir_acum_per_epoch=0.0
24 | phase.loss_curv_acum_per_epoch=0.0
25 |
26 | def epoch_ended(self, phase, **kwargs):
27 |
28 | phase.epoch_nr+=1
29 |
30 | def phase_started(self, phase, **kwargs):
31 | phase.samples_processed_this_epoch=0
32 |
33 | def phase_ended(self, phase, model, hyperparams, experiment_name, output_training_path, **kwargs):
34 |
35 | if (phase.epoch_nr%hyperparams.save_checkpoint_every_x_epoch==0) and hyperparams.save_checkpoint and phase.grad:
36 | model.save(output_training_path, experiment_name, hyperparams, phase.epoch_nr)
37 |
38 |
39 |
--------------------------------------------------------------------------------
/callbacks/tensorboard_callback.py:
--------------------------------------------------------------------------------
1 | from callbacks.callback import *
2 | from torch.utils.tensorboard import SummaryWriter
3 |
4 | class TensorboardCallback(Callback):
5 |
6 | def __init__(self, experiment_name):
7 | self.tensorboard_writer = SummaryWriter("tensorboard_logs/"+experiment_name)
8 | self.experiment_name=experiment_name
9 |
10 |
11 | def after_forward_pass(self, phase, loss=0, loss_pos=0, loss_dir=0, loss_curv=0, loss_kl=0, lr=0, z_deviation=None, z=None, z_no_eps=None, **kwargs):
12 |
13 | if phase.iter_nr%300==0 and phase.grad:
14 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss', loss.item(), phase.iter_nr)
15 | if loss_pos!=0:
16 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_pos', loss_pos.item(), phase.iter_nr)
17 | # if loss_l1!=0:
18 | # self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_l1', loss_l1.item(), phase.iter_nr)
19 | # if loss_l2!=0:
20 | # self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_l2', loss_l2.item(), phase.iter_nr)
21 | if loss_dir!=0:
22 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_dir', loss_dir.item(), phase.iter_nr)
23 | if loss_curv!=0:
24 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_curv', loss_curv.item(), phase.iter_nr)
25 | if loss_kl!=0:
26 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_kl', loss_kl.item(), phase.iter_nr)
27 |
28 | if lr!=0:
29 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/lr', lr, phase.iter_nr)
30 |
31 | if z_deviation is not None:
32 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_deviation', z_deviation.std(), phase.iter_nr)
33 |
34 | # if z is not None:
35 | # self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_max', z.max(), phase.iter_nr)
36 | if z_no_eps is not None:
37 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_no_eps_mean', z_no_eps.mean(), phase.iter_nr)
38 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_no_eps_std', z_no_eps.std(), phase.iter_nr)
39 |
40 |
41 | def epoch_ended(self, phase, **kwargs):
42 | avg_loss_pos=phase.loss_pos_acum_per_epoch/phase.samples_processed_this_epoch
43 | avg_loss_dir=phase.loss_dir_acum_per_epoch/phase.samples_processed_this_epoch
44 | avg_loss_curv=phase.loss_curv_acum_per_epoch/phase.samples_processed_this_epoch
45 |
46 | if phase.grad==False and phase.loss_pos_acum_per_epoch!=0:
47 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_pos_avg', avg_loss_pos.item(), phase.epoch_nr)
48 | if phase.grad==False and phase.loss_dir_acum_per_epoch!=0:
49 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_dir_avg', avg_loss_dir.item(), phase.epoch_nr)
50 | if phase.grad==False and phase.loss_curv_acum_per_epoch!=0:
51 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_curv_avg', avg_loss_curv.item(), phase.epoch_nr)
52 |
--------------------------------------------------------------------------------
/callbacks/viewer_callback.py:
--------------------------------------------------------------------------------
1 | from callbacks.callback import *
2 | import numpy as np
3 | from gloss import *
4 |
5 | class ViewerCallback(Callback):
6 |
7 | def __init__(self, viewer_config_path, experiment_name):
8 | gloss_setup_logger(log_level=LogLevel.Info)
9 | self.viewer=Viewer(viewer_config_path)
10 | self.scene=self.viewer.get_scene()
11 |
12 | self.first_time=True
13 |
14 | self.visualize_every_x_iters=100
15 | self.visualize_strand_at_idx=None
16 | # self.visualize_strand_at_idx=0
17 |
18 | def after_forward_pass(self, phase, gt_cloud=None, pred_cloud=None, **kwargs):
19 |
20 | # if phase.iter_nr==1000:
21 | # exit(1)
22 |
23 | if phase.iter_nr%self.visualize_every_x_iters==0:
24 | if self.visualize_strand_at_idx is not None:
25 | gt_cloud=gt_cloud[self.visualize_strand_at_idx,:,:]
26 | gt_ent = self.scene.get_or_spawn_renderable("gt_ent")
27 | gt_ent.insert(Verts(gt_cloud.reshape(-1,3).cpu().numpy()))
28 | gt_ent.insert(Colors(gt_cloud.reshape(-1,3).cpu().numpy())) #just to avoid verts being different than Colors when we switch nr_verts for the same entity
29 | if self.first_time:
30 | gt_ent.insert(VisPoints(show_points=True, \
31 | point_size=3.0, \
32 | color_type=PointColorType.Solid))
33 |
34 |
35 | #pred
36 | if self.visualize_strand_at_idx is not None:
37 | pred_cloud=pred_cloud[self.visualize_strand_at_idx,:,:]
38 | pred_ent = self.scene.get_or_spawn_renderable("pred_ent")
39 | pred_ent.insert(Verts(pred_cloud.reshape(-1,3).cpu().numpy()))
40 | pred_ent.insert(Colors(pred_cloud.reshape(-1,3).cpu().numpy())) #just to avoid verts being different than Colors when we switch nr_verts for the same entity
41 | if self.first_time:
42 | pred_ent.insert(VisPoints(show_points=True, \
43 | point_size=3.0, \
44 | point_color=[0.1, 0.2, 0.8, 1.0],
45 | color_type=PointColorType.Solid))
46 |
47 | #render
48 | self.viewer.start_frame()
49 | self.viewer.update()
50 |
51 |
52 | self.first_time=False
53 |
54 |
55 |
56 |
57 |
58 | # viewer=Viewer()
59 | # scene=viewer.get_scene()
60 |
61 | # mesh = scene.get_or_spawn_renderable("test")
62 | # mesh.insert(Verts(
63 | # np.array([
64 | # [0,0,0],
65 | # [0,1,0],
66 | # [1,0,0],
67 | # [1.5,1.5,-1]
68 | # ], dtype = "float32") ))
69 | # mesh.insert(Colors(
70 | # np.array([
71 | # [1,0,0],
72 | # [0,0,1],
73 | # [1,1,0],
74 | # [1,0,1]
75 | # ], dtype = "float32")))
76 |
77 | # mesh.insert(VisPoints(show_points=True, \
78 | # show_points_indices=True, \
79 | # point_size=10.0, \
80 | # color_type=PointColorType.PerVert))
81 |
82 | # while True:
83 | # viewer.start_frame()
84 | # viewer.update()
85 |
86 |
87 |
--------------------------------------------------------------------------------
/callbacks/wandb_callback.py:
--------------------------------------------------------------------------------
1 | from callbacks.callback import *
2 | import wandb
3 | import hjson
4 |
5 | class WandBCallback(Callback):
6 |
7 | def __init__(self, experiment_name, config_path, entity):
8 | self.experiment_name=experiment_name
9 | # loading the config file like this and giving it to wandb stores them on the website
10 | with open(config_path, 'r') as j:
11 | cfg = hjson.loads(j.read())
12 | # Before this init can be run, you have to use wandb login in the console you are starting the script from (https://docs.wandb.ai/ref/cli/wandb-login, https://docs.wandb.ai/ref/python/init)
13 | # entity= your username
14 | wandb.init(project=experiment_name, entity=entity,config = cfg)
15 |
16 |
17 | def after_forward_pass(self, phase, loss, loss_kl, **kwargs):
18 |
19 | # "/" acts as a separator. If you would like to log train and test separately you would log the test loss under test/loss
20 | wandb.log({'train/loss': loss}, step=phase.iter_nr)
21 | if loss_kl!=0:
22 | wandb.log({'train/loss_kl': loss_kl}, step=phase.iter_nr)
23 |
24 |
25 | def epoch_ended(self, phase, **kwargs):
26 | pass
--------------------------------------------------------------------------------
/checkpoints/HOW_TO_DOWNLOAD.txt:
--------------------------------------------------------------------------------
1 | If you want to download the checkpoints please use the download_checkpoints.sh script from the root of the repo.
2 | It will download all necessary checkpoints and unzip them in this folder
--------------------------------------------------------------------------------
/configs/config_scalp_texture_conditional.json:
--------------------------------------------------------------------------------
1 | {
2 |
3 | "model": {
4 | "type": "image_transformer_v2_conditional",
5 | "cross_cond": 1,
6 | "condition_dropout_rate": 0.1,
7 | "input_channels": 65,
8 | "input_size": [256, 256],
9 | "patch_size": [2, 2],
10 | "depths": [2, 2, 2],
11 | "widths": [1024, 2048, 3072],
12 | "d_ffs": [1024, 2048, 3072],
13 | "self_attns": [
14 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
15 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
16 | {"type": "global", "d_head": 64}
17 | ],
18 | "parametrization": "v",
19 | "loss_config": "karras",
20 | "loss_weighting": "snr",
21 | "dropout_rate": [0.0, 0.0, 0.1],
22 | "mapping_width": 768,
23 | "mapping_d_ff": 1024,
24 | "mapping_dropout_rate": 0.0,
25 | "augment_prob": 0.0,
26 | "sigma_data": 0.3,
27 | "sigma_min": 1e-2,
28 | "sigma_max": 160,
29 | "sigma_sample_density": {
30 | "type": "cosine-interpolated",
31 | "noise_d_low": 64
32 | },
33 |
34 | "rgb_condition":{
35 | "global_condition_shape": [1, 1024],
36 | "local_condition_shapes": [
37 | {"shape": [1, 1024, 55, 55]},
38 | {"shape": [1, 1024, 55, 55]},
39 | {"shape": [1, 1024, 55, 55]}
40 | ],
41 | "cross_condition_dim": [512,512,512],
42 | "self_attn": [false,true,true]
43 | },
44 |
45 | "loss_weight_per_channel":[
46 | 0.05580984428524971, 0.030566953122615814, 0.025709286332130432, 0.05716482549905777, 0.023201454430818558, 0.08562534302473068, 0.017659854143857956, 0.07521561533212662, 0.0483829528093338, 0.017475251108407974, 0.021301859989762306, 0.006980733945965767, 0.06842825561761856, 0.16614384949207306, 0.028791628777980804, 0.07338058203458786, 0.08613209426403046, 0.01631935127079487, 0.015131217427551746, 0.06258893758058548, 0.058801501989364624, 0.05222005024552345, 0.03300714120268822, 0.028619959950447083, 0.066630519926548, 0.10768849402666092, 0.07185778766870499, 0.08246062695980072, 0.08738923817873001, 0.5705997943878174, 0.059733662754297256, 0.1032959371805191, 0.008284210227429867, 0.028203167021274567, 0.015574352815747261, 0.031722474843263626, 0.11103591322898865, 1.0, 0.33545804023742676, 0.02750393934547901, 0.037128742784261703, 0.04466118663549423, 0.0203529242426157, 0.018273167312145233, 0.140573188662529, 0.08498218655586243, 0.05887111648917198, 0.022760389372706413, 0.1844673603773117, 0.052569273859262466, 0.12437944859266281, 0.011494286358356476, 0.3124508857727051, 0.0704188272356987, 0.21793445944786072, 0.04560207203030586, 0.14500388503074646, 0.01105482131242752, 0.028649816289544106, 0.015449784696102142, 0.06804649531841278, 0.01901610754430294, 0.036636028438806534, 0.04171859472990036,
47 |
48 | 1.0
49 | ]
50 | },
51 | "optimizer": {
52 | "lr": 5e-4,
53 | "betas": [0.9, 0.95],
54 | "eps": 1e-8,
55 | "weight_decay": 1e-3
56 | },
57 | "lr_sched": {
58 | "type": "constant",
59 | "warmup": 0.0
60 | },
61 | "ema_sched": {
62 | "type": "inverse",
63 | "power": 0.75,
64 | "max_value": 0.9999
65 | }
66 | }
67 |
68 |
--------------------------------------------------------------------------------
/configs/strand_vae_train.toml:
--------------------------------------------------------------------------------
1 | [core]
2 | floor_type= "grid"
3 | floor_origin = [0,0,0]
4 | floor_scale = 30.0
5 | auto_create_logger=false
6 | log_level = "off"
7 |
8 | [render]
9 | msaa_nr_samples = 4
10 |
--------------------------------------------------------------------------------
/data_loader/difflocks_bodydata/scalp.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/data_loader/difflocks_bodydata/scalp.ply
--------------------------------------------------------------------------------
/data_loader/difflocks_bodydata/smplx_base.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/data_loader/difflocks_bodydata/smplx_base.ply
--------------------------------------------------------------------------------
/data_processing/create_chunked_strands.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | #from the full_strands.npz it creates chunked versions so they can be more easily loaded by the trainer for the strand VAE
4 |
5 | # ./create_chunked_strands.py --dataset_path <dataset_path>
6 |
7 |
8 | import sys
9 | import os
10 | import argparse
11 | import numpy as np
12 | from tqdm import tqdm
13 | from torch.utils.data import DataLoader
14 | import math
15 | sys.path.append(
16 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
17 | from data_loader.dataloader import DiffLocksDataset
18 |
19 |
20 | def main():
21 |
22 | #argparse
23 | parser = argparse.ArgumentParser(description='Create chunked strands')
24 | parser.add_argument('--dataset_path', required=True, help='Path to the hair_synth dataset')
25 | args = parser.parse_args()
26 |
27 |
28 | difflocks_dataset = DiffLocksDataset(args.dataset_path,
29 | check_validity=False,
30 | load_full_strands=True
31 | )
32 | loader = DataLoader(difflocks_dataset, batch_size=1, num_workers=8, shuffle=False, pin_memory=True, persistent_workers=True)
33 |
34 |
35 |
36 | progress_bar = tqdm(range(0, len(difflocks_dataset)), desc="Chunking strands")
37 |
38 | for batch in loader:
39 | progress_bar.update()
40 |
41 | # print("batch",batch)
42 |
43 | path_hairstyle=batch["path"][0]
44 | print("path", path_hairstyle)
45 |
46 | positions=batch["full_strands"]["positions"].squeeze(0).cpu().numpy()
47 | root_uv=batch["full_strands"]["root_uv"].squeeze(0).cpu().numpy()
48 | root_normal=batch["full_strands"]["root_normal"].squeeze(0).cpu().numpy()
49 | # print("positions", positions.shape)
50 | # print("root_uv", root_uv.shape)
51 | # print("root_normal", root_normal.shape)
52 |
53 |
54 | select_nr_random_strands=10000
55 | nr_strands_per_chunk_list=[1000,100]
56 |
57 | #select random strands because the original 100K is too much
58 | if select_nr_random_strands:
59 | nr_strands_left = positions.shape[0]
60 | per_curve_keep_random = np.random.choice(nr_strands_left, select_nr_random_strands, replace=False)
61 | positions=positions[per_curve_keep_random,:,:]
62 | root_uv=root_uv[per_curve_keep_random,:]
63 | root_normal=root_normal[per_curve_keep_random,:]
64 |
65 |
66 | #break the strands into chunks if needed and write those too
67 | nr_strands_total = positions.shape[0]
68 | if nr_strands_per_chunk_list:
69 | for nr_strands_cur_chunk in nr_strands_per_chunk_list:
70 | nr_chunks = math.ceil(nr_strands_total/nr_strands_cur_chunk)
71 |
72 | positions_chunked = np.array_split(positions, nr_chunks,axis=0)
73 | root_uv_chunked = np.array_split(root_uv, nr_chunks,axis=0)
74 | root_normal_chunked = np.array_split(root_normal, nr_chunks,axis=0)
75 |
76 | #make path for this chunked data
77 | chunked_data_path=os.path.join(path_hairstyle,"full_strands_chunked","nr_strands_"+str(nr_strands_cur_chunk))
78 | os.makedirs(chunked_data_path, exist_ok=True)
79 |
80 | #write each chunk
81 | for idx_chunk in range(nr_chunks):
82 | npz_path = os.path.join(chunked_data_path,str(idx_chunk)+".npz")
83 | np.savez(npz_path, positions=positions_chunked[idx_chunk],\
84 | root_uv=root_uv_chunked[idx_chunk],\
85 | root_normal=root_normal_chunked[idx_chunk])
86 |
87 |
88 | return
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
94 |
--------------------------------------------------------------------------------
/data_processing/create_latents.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | #creates latent representations of all the rgb images in the dataset; these are later used to condition the diffusion model
4 |
5 | #./create_latents.py --subsample_factor 1 --dataset_path=<dataset_path> --out_path <processed_dataset_path>
6 |
7 |
8 | import sys
9 | import os
10 | import argparse
11 | import torch
12 | import torchvision
13 | sys.path.append(
14 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
15 | from models.strand_codec import StrandCodec
16 | from torch.utils.data import DataLoader
17 | import utils.resize_right.resize_right as resize_right
18 | import utils.resize_right.interp_methods as interp_methods
19 | from tqdm import tqdm
20 | import torchvision.transforms as T
21 | from data_loader.dataloader import DiffLocksDataset
22 |
23 |
24 |
25 | torch.set_grad_enabled(False)
26 |
27 |
28 | def horizontally_flip(batch):
29 |
30 | rgb_img=batch["rgb_img"]
31 | rgb_img=torchvision.transforms.functional.hflip(rgb_img)
32 | batch["rgb_img"]=rgb_img
33 |
34 | return batch
35 |
36 |
37 | def generate_latents_dinov2(args, batch, preprocessor, model, output_latents_path):
38 | #encode img
39 | rgb_img=batch["rgb_img"].cuda()
40 | rgb_img=rgb_img[:,0:3,:,:]
41 |
42 |
43 | rgb_input = preprocessor(rgb_img).to("cuda")
44 | ret = model.forward_features(rgb_input)
45 | patch_tok = ret["x_norm_patchtokens"].clone()
46 | cls_tok = ret["x_norm_clstoken"].clone()
47 |
48 | # print("outputs",outputs)
49 |
50 | #only makes sense to write the last layer because it's dinov2 and the other layers are not coarser representations
51 | cls_token=cls_tok
52 | # print("cls", cls_token.shape)
53 | patch_embeddings = patch_tok
54 | #reshape to [Batch_size, h, w, embedding]
55 | batch_size, num_patches, hidden_size = patch_embeddings.shape
56 | h = w = int(num_patches ** 0.5) # Assuming the number of patches is a perfect square (e.g., 14x14)
57 | patch_embeddings_reshaped = patch_embeddings.reshape(batch_size, h, w, hidden_size)
58 | patch_embeddings_reshaped=patch_embeddings_reshaped.permute(0,3,1,2).contiguous() #Make it bchw
59 | #write last layer
60 | out_path_final_latent=os.path.join(output_latents_path, "final_latent.pt")
61 | torch.save(patch_embeddings_reshaped, out_path_final_latent)
62 | #write cls token which is like an embedding for the whole image
63 | out_path_cls_token=os.path.join(output_latents_path, "cls_token.pt")
64 | #writing cls token of size
65 | # print("cls_token",cls_token.shape)
66 | torch.save(cls_token, out_path_cls_token)
67 |
68 | # feat_pca = img_2_pca(patch_embeddings_reshaped)
69 | # torchvision.utils.save_image(feat_pca.squeeze(0), os.path.join(output_latents_path, "final_latent.png"))
70 |
71 |
72 |
73 | #write a file to signify that we are done with this folder
74 | #start with x so that rsync copies it last if we copy to local
75 | open( os.path.join(output_latents_path, "x_done.txt"), 'a').close()
76 |
77 |
78 | def main():
79 |
80 | #argparse
81 | parser = argparse.ArgumentParser(description='Create latents')
82 | parser.add_argument('--dataset_path', required=True, help='Path to the hair_synth dataset')
83 | parser.add_argument('--out_path', required=True, type=str, help='Where to output the processed hair_synth dataset')
84 | parser.add_argument('--subsample_factor', default=1, type=int, help='Subsample factor for the RGB img')
85 | parser.add_argument('--skip_validity_check', dest='check_validity', action='store_false', help='Whether to check for the validity of each hairstyle we read from the dataset. Some older dataset versions might need this turned to false')
86 | args = parser.parse_args()
87 |
88 |
89 | #v2 from torch
90 | image_size = int(768/(2**(args.subsample_factor-1)))
91 | print("Selected dino with img size", image_size)
92 | #going to the nearest multiple of 14 because 14 is the patch size
93 | if image_size==768:
94 | image_size=770
95 | else:
96 | print("I haven't implemented the other ones yet")
97 | latents_preprocessor = T.Compose([
98 | T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC),
99 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
100 | ])
101 | latents_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
102 | latents_model.cuda()
103 |
104 | latents_model.eval()
105 |
106 |
107 |
108 |
109 | print("args.check_validity",args.check_validity)
110 |
111 | difflocks_dataset = DiffLocksDataset(args.dataset_path,
112 | check_validity=args.check_validity,
113 | load_rgb_imgs=True,
114 | processed_difflocks_path = args.out_path,
115 | subsample_factor=args.subsample_factor,
116 | )
117 | loader = DataLoader(difflocks_dataset, batch_size=1, num_workers=8, shuffle=False, pin_memory=True, persistent_workers=True)
118 |
119 |
120 |
121 | progress_bar = tqdm(range(0, len(difflocks_dataset)), desc="Creating latents")
122 |
123 | for batch in loader:
124 | progress_bar.update()
125 |
126 | #make the output path
127 | output_latents_path=os.path.join(args.out_path, "processed_hairstyles", batch["file"][0], "latents_"+"dinov2"+"_subsample_"+str(args.subsample_factor))
128 | os.makedirs(output_latents_path, exist_ok=True)
129 | #check if we already created this one
130 | if not os.path.isfile( os.path.join(output_latents_path,"x_done.txt")):
131 | # if True:
132 | #if it doesn't exist or we can't load it we create it
133 | generate_latents_dinov2(args, batch, latents_preprocessor, latents_model, output_latents_path)
134 |
135 |
136 | # #generate also a flipped texture, the reason being that just flipping the rgb does not necessarily result in flipped latents, so we have to horizontally flip the data in the batch and then encode a new flipped latent
137 | #make the output path
138 | output_latents_path=os.path.join(args.out_path, "processed_hairstyles", batch["file"][0], "latents_flipped_"+"dinov2"+"_subsample_"+str(args.subsample_factor))
139 | os.makedirs(output_latents_path, exist_ok=True)
140 | #check if we already created this one
141 | if not os.path.isfile( os.path.join(output_latents_path,"x_done.txt")):
142 | # if True:
143 | batch=horizontally_flip(batch)
144 | #if it doesn't exist or we can't load it we create it
145 | generate_latents_dinov2(args, batch, latents_preprocessor, latents_model, output_latents_path)
146 |
147 |
148 |
149 | #finished training
150 | return
151 |
152 |
153 | if __name__ == '__main__':
154 | main()
155 |
156 |
--------------------------------------------------------------------------------
/data_processing/uncompress_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | #given the difflocks dataset downloaded from mpi uncompress everything
4 |
5 | # ./uncompress_data.py --dataset_zipped_path <path_to_zipped_dataset> --out_path <dataset_path>
6 |
7 |
8 |
9 | import sys
10 | import os
11 | import argparse
12 | import subprocess
13 | from os import listdir
14 | from os.path import isfile, join
15 |
16 |
17 | def main():
18 |
19 | #argparse
20 | parser = argparse.ArgumentParser(description='Uncompress dataset')
21 | parser.add_argument('--dataset_zipped_path', required=True, help='Path to the difflocks folder containing all zip files')
22 | parser.add_argument('--out_path', required=True, type=str, help='Where to output the difflocks dataset')
23 | args = parser.parse_args()
24 |
25 |
26 | in_path=args.dataset_zipped_path
27 | onlyfiles = [f for f in listdir(in_path) if isfile(join(in_path, f))]
28 | for file_name in onlyfiles:
29 | filepath=os.path.join(in_path,file_name)
30 | cmd=["7zz","x", "-y", filepath, "-o"+args.out_path]
31 | # print("cmd",cmd)
32 | subprocess.run(cmd, capture_output=False)
33 | print("filename", file_name)
34 |
35 |
36 | if __name__ == '__main__':
37 | main()
38 |
39 |
--------------------------------------------------------------------------------
/dockerfile/build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Check args
4 | if [ "$#" -ne 0 ]; then
5 | echo "usage: ./build.sh"
6 | exit 1
7 | fi
8 |
9 |
10 | # Build the docker image
11 | docker build\
12 | --build-arg user=$USER\
13 | --build-arg uid=$UID\
14 | --build-arg home=$HOME\
15 | --build-arg workspace=$SCRIPTPATH\
16 | --build-arg shell=$SHELL\
17 | -t difflocks_dock \
18 | --progress=plain \
19 | -f Dockerfile .
20 |
21 |
--------------------------------------------------------------------------------
/dockerfile/nvidia-container-runtime-script.sh:
--------------------------------------------------------------------------------
1 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/gpgkey | \
2 | sudo apt-key add -
3 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
4 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/$distribution/nvidia-container-runtime.list | \
5 | sudo tee /etc/apt/sources.list.d/nvidia-container-runtime.list
6 | sudo apt-get update
7 | sudo apt-get install -y nvidia-container-toolkit
8 | sudo systemctl restart docker
--------------------------------------------------------------------------------
/dockerfile/nvidia-container-runtime-script_ubuntu24.sh:
--------------------------------------------------------------------------------
1 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
2 | && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
3 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
4 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
5 | sudo apt-get update
6 | sudo apt-get install -y nvidia-container-toolkit
7 | sudo systemctl restart docker
--------------------------------------------------------------------------------
/dockerfile/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Check args
4 | if [ "$#" -ne 0 ]; then
5 | echo "usage: ./run.sh"
6 | exit 1
7 | fi
8 |
9 | # Get this script's path
10 | pushd `dirname $0` > /dev/null
11 | SCRIPTPATH=`pwd`
12 | popd > /dev/null
13 |
14 | set -e
15 |
16 | # --volume $SSH_AUTH_SOCK:/ssh-agent \
17 | # --env SSH_AUTH_SOCK=/ssh-agent \
18 |
19 |
20 | # for more info see: https://medium.com/@benjamin.botto/opengl-and-cuda-applications-in-docker-af0eece000f1
21 | # for more info see: https://gist.github.com/RafaelPalomar/f594933bb5c07184408c480184c2afb4
22 | # Run the container with shared X11
23 | docker run\
24 | --rm \
25 | --shm-size 12G\
26 | --gpus all\
27 | --net host\
28 | --privileged\
29 | -e SHELL\
30 | -e DISPLAY\
31 | -e DOCKER=1\
32 | -v /dev:/dev\
33 | --volume=/run/user/${USER_UID}/pulse:/run/user/1000/pulse \
34 | -e PULSE_SERVER=unix:${XDG_RUNTIME_DIR}/pulse/native \
35 | -v ${XDG_RUNTIME_DIR}/pulse/native:${XDG_RUNTIME_DIR}/pulse/native \
36 | -v ~/.config/pulse/cookie:/root/.config/pulse/cookie \
37 | --group-add $(getent group audio | cut -d: -f3) \
38 | --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \
39 | --volume="/etc/group:/etc/group:ro" \
40 | --volume="/etc/passwd:/etc/passwd:ro" \
41 | --volume="/etc/shadow:/etc/shadow:ro" \
42 | -v "$HOME:$HOME:rw"\
43 | --name difflocks_dock\
44 | -it difflocks_dock:latest
45 |
46 |
--------------------------------------------------------------------------------
/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
3 |
4 |
5 | # username and password input
6 | echo -e "\nYou need to register at https://difflocks.is.tue.mpg.de/"
7 | read -p "Username (DIFFLOCKS):" username
8 | read -p "Password (DIFFLOCKS):" password
9 | username=$(urle $username)
10 | password=$(urle $password)
11 |
12 | echo -e "\nDownloading files..."
13 |
14 |
15 | #checkpoints
16 | wget --post-data "username=$username&password=$password" \
17 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=difflocks_checkpoints.zip" \
18 | -O "difflocks_checkpoints.zip" --no-check-certificate --continue
19 | # Check if wget failed (non-zero exit code)
20 | if [ $? -ne 0 ]; then
21 | echo "❌ Error downloading body data. Exiting script."
22 | exit 1
23 | fi
24 |
25 | #unzip
26 | unzip difflocks_checkpoints.zip
27 |
28 |
29 |
--------------------------------------------------------------------------------
/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
3 |
4 | # Get the download path from the first argument
5 | TARGET_PATH="$1"
6 |
7 | # Check if the path was given
8 | if [ -z "$TARGET_PATH" ]; then
9 | echo "Usage: $0 "
10 | exit 1
11 | fi
12 |
13 | # Make sure the directory exists
14 | mkdir -p "$TARGET_PATH"
15 |
16 | # username and password input
17 | echo -e "\nYou need to register at https://difflocks.is.tue.mpg.de/"
18 | read -p "Username (DIFFLOCKS):" username
19 | read -p "Password (DIFFLOCKS):" password
20 | username=$(urle $username)
21 | password=$(urle $password)
22 |
23 | echo -e "\nDownloading files..."
24 |
25 |
26 | #BODY DATA
27 | filename='difflocks_dataset_body_data.7z'
28 | wget --post-data "username=$username&password=$password" \
29 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=$filename" \
30 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue
31 | # Check if wget failed (non-zero exit code)
32 | if [ $? -ne 0 ]; then
33 | echo "❌ Error downloading body data. Exiting script."
34 | exit 1
35 | fi
36 |
37 | #IMGS
38 | for i in $(seq 0 20); do
39 | filename="difflocks_dataset_imgs_chunk_${i}.7z"
40 | # Run wget and capture status
41 | wget --post-data "username=$username&password=$password" \
42 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=${filename}" \
43 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue
44 | # Check if wget failed (non-zero exit code)
45 | if [ $? -ne 0 ]; then
46 | echo "❌ Error downloading $filename. Exiting script."
47 | exit 1
48 | fi
49 | done
50 |
51 | #IMGS v2
52 | for i in $(seq 0 20); do
53 | filename="difflocks_dataset_imgs_v2_chunk_${i}.7z"
54 | # Run wget and capture status
55 | wget --post-data "username=$username&password=$password" \
56 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=${filename}" \
57 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue
58 | # Check if wget failed (non-zero exit code)
59 | if [ $? -ne 0 ]; then
60 | echo "❌ Error downloading $filename. Exiting script."
61 | exit 1
62 | fi
63 | done
64 |
65 | #HAIRSTYLES
66 | for i in $(seq 0 4455); do
67 | filename="difflocks_dataset_hairstyles_chunk_${i}.7z"
68 | # Run wget and capture status
69 | wget --post-data "username=$username&password=$password" \
70 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=${filename}" \
71 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue
72 | # Check if wget failed (non-zero exit code)
73 | if [ $? -ne 0 ]; then
74 | echo "❌ Error downloading $filename. Exiting script."
75 | exit 1
76 | fi
77 | done
78 |
79 |
80 |
--------------------------------------------------------------------------------
/download_validation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
3 |
4 |
5 | # username and password input
6 | echo -e "\nYou need to register at https://difflocks.is.tue.mpg.de/"
7 | read -p "Username (DIFFLOCKS):" username
8 | read -p "Password (DIFFLOCKS):" password
9 | username=$(urle $username)
10 | password=$(urle $password)
11 |
12 | echo -e "\nDownloading files..."
13 |
14 |
15 | #checkpoints
16 | wget --post-data "username=$username&password=$password" \
17 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=difflocks_validation.7z" \
18 | -O "difflocks_validation.7z" --no-check-certificate --continue
19 | # Check if wget failed (non-zero exit code)
20 | if [ $? -ne 0 ]; then
21 | echo "❌ Error downloading body data. Exiting script."
22 | exit 1
23 | fi
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Header-only library
2 |
3 | # Define 'cuMat_headers' variable to list all the header files
4 | # Todo: maybe replace by GLOB
5 | set(CUMAT_HEADERS
6 | src/Macros.h
7 | src/ForwardDeclarations.h
8 | src/Profiling.h
9 | src/Logging.h
10 | src/Errors.h
11 | src/Constants.h
12 | src/Context.h
13 | src/Allocator.h
14 | src/NumTraits.h
15 | src/DevicePointer.h
16 | src/EigenInteropHelpers.h
17 | src/MatrixBase.h
18 | src/Matrix.h
19 | src/IO.h
20 | src/CwiseOp.h
21 | src/MatrixBlock.h
22 | src/MatrixBlockPluginLvalue.inl
23 | src/MatrixBlockPluginRvalue.inl
24 | src/NullaryOps.h
25 | src/MatrixNullaryOpsPlugin.inl
26 | src/UnaryOps.h
27 | src/UnaryOpsPlugin.inl
28 | src/CudaUtils.h
29 | src/TransposeOp.h
30 | src/BinaryOps.h
31 | src/BinaryOpsPlugin.inl
32 | src/ReductionOps.h
33 | src/ReductionOpsPlugin.inl
34 | src/ReductionAlgorithmSelection.h
35 | src/Iterator.h
36 | src/CublasApi.h
37 | src/SimpleRandom.h
38 | src/ProductOp.h
39 | Core
40 |
41 | src/CusolverApi.h
42 | src/SolverBase.h
43 | src/DecompositionBase.h
44 | src/LUDecomposition.h
45 | src/CholeskyDecomposition.h
46 | src/DenseLinAlgOps.h
47 | src/DenseLinAlgPlugin.inl
48 | Dense
49 |
50 | src/SparseMatrixBase.h
51 | src/SparseMatrix.h
52 | src/SparseEvaluation.h
53 | src/SparseProductEvaluation.h
54 | src/SparseExpressionOp.h
55 | src/SparseExpressionOpPlugin.inl
56 | Sparse
57 |
58 | src/IterativeSolverBase.h
59 | src/ConjugateGradient.h
60 | IterativeLinearSolvers
61 | )
62 |
63 | add_library(cuMat INTERFACE) # 'moduleA' is an INTERFACE pseudo target
64 |
65 | #
66 | # From here, the target 'moduleA' can be customised
67 | #
68 | target_include_directories(cuMat INTERFACE ${CMAKE_SOURCE_DIR}) # Transitively forwarded
69 | target_include_directories(cuMat INTERFACE ${CUDA_INCLUDE_DIRS})
70 | #install(TARGETS cuMat ...)
71 |
72 | #
73 | # HACK: have the files showing in the IDE, under the name 'moduleA_ide'
74 | #
75 | option(CUMAT_SOURCE_LIBRARY "show the source headers as a library in the IDE" OFF)
76 | if(CUMAT_SOURCE_LIBRARY)
77 | cuda_add_library(cuMat_ide ${CUMAT_HEADERS})
78 | target_include_directories(cuMat_ide INTERFACE ${CMAKE_SOURCE_DIR}) # Transitively forwarded
79 | target_include_directories(cuMat_ide INTERFACE ${CUDA_INCLUDE_DIRS})
80 | endif(CUMAT_SOURCE_LIBRARY)
81 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/Core:
--------------------------------------------------------------------------------
1 |
2 |
3 | /*
4 | Specifies the core modules.
5 | This enables all BLAS-related modules
6 | */
7 |
8 | //first thing: tell the compiler to ignore some stupid warnings
9 | #include "src/DisableCompilerWarnings.h"
10 |
11 | #include "src/Macros.h"
12 | #include "src/ForwardDeclarations.h"
13 | #include "src/NumTraits.h"
14 |
15 | #include "src/Logging.h"
16 | #include "src/Errors.h"
17 | #include "src/Context.h"
18 |
19 | #include "src/MatrixBase.h"
20 | #include "src/Matrix.h"
21 | #include "src/MatrixBlock.h"
22 | #include "src/EigenInteropHelpers.h"
23 | #include "src/IO.h"
24 |
25 | #include "src/Iterator.h"
26 | #include "src/NullaryOps.h"
27 | #include "src/UnaryOps.h"
28 | #include "src/TransposeOp.h"
29 | #include "src/BinaryOps.h"
30 | #include "src/ReductionOps.h"
31 | #include "src/ProductOp.h"
32 |
33 | #include "src/SimpleRandom.h"
34 |
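A minimal usage sketch (not part of the source tree), assuming extensions/include is on the compiler's include path; the Matrix constructor arguments are an assumed convention, not verified against Matrix.h:

    #include <cuMat/Core>   // assumes -I extensions/include

    void coreExample()
    {
        // A dynamically sized, single-batch float matrix (rows, cols, batches);
        // the constructor signature here is an assumption based on the forward declarations.
        cuMat::Matrix<float, cuMat::Dynamic, cuMat::Dynamic, 1, cuMat::ColumnMajor> m(4, 4, 1);
        (void)m;
    }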
--------------------------------------------------------------------------------
/extensions/include/cuMat/Dense:
--------------------------------------------------------------------------------
1 |
2 | /*
3 | * Specifies dense linear algebra routines (LAPACK)
4 | * This needs cuSolver.
5 | */
6 |
7 | // Core is always needed
8 | #include "Core"
9 |
10 | #include "src/CusolverApi.h"
11 | #include "src/LUDecomposition.h"
12 | #include "src/CholeskyDecomposition.h"
13 | #include "src/DenseLinAlgOps.h"
14 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/IterativeLinearSolvers:
--------------------------------------------------------------------------------
1 | /*
2 | * Specifies iterative linear solver routines.
3 | */
4 |
5 | // Core is always needed
6 | #include "Core"
7 |
8 | #include "src/IterativeSolverBase.h"
9 | #include "src/ConjugateGradient.h"
10 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/Sparse:
--------------------------------------------------------------------------------
1 | /*
2 | * Specifies sparse linear algebra routines
3 | * This needs cuSparse.
4 | */
5 |
6 | // Core is always needed
7 | #include "Core"
8 |
9 | // sparse classes
10 | #include "src/SparseMatrixBase.h"
11 | #include "src/SparseMatrix.h"
12 | #include "src/SparseExpressionOp.h"
13 | #include "src/SparseEvaluation.h"
14 | #include "src/SparseProductEvaluation.h"
15 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/Allocator.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_ALLOCATOR_H__
2 | #define __CUMAT_ALLOCATOR_H__
3 |
4 | #include "Macros.h"
5 |
6 | //TODO: add interface to support custom allocators
7 | //Up to now, the allocator from CUB is hard-coded
8 |
9 | CUMAT_NAMESPACE_BEGIN
10 |
11 | class AllocatorBase
12 | {
13 |
14 | };
15 |
16 | CUMAT_NAMESPACE_END
17 |
18 | #endif
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/BinaryOpsPlugin.inl:
--------------------------------------------------------------------------------
1 | //Included inside MatrixBase, define the accessors
2 |
3 | #define BINARY_OP_ACCESSOR(Name) \
4 |     template<typename _Right> \
5 |     BinaryOp<_Derived, _Right, functor::BinaryMathFunctor_ ## Name <Scalar> > Name (const MatrixBase<_Right>& rhs) const { \
6 |         CUMAT_ERROR_IF_NO_NVCC(Name) \
7 |         return BinaryOp<_Derived, _Right, functor::BinaryMathFunctor_ ## Name <Scalar> >(derived(), rhs.derived()); \
8 |     } \
9 |     template<typename _Right, typename T = typename std::enable_if<std::is_convertible<_Right, Scalar>::value, \
10 |         BinaryOp<_Derived, HostScalar<Scalar>, functor::BinaryMathFunctor_ ## Name <Scalar> > >::type > \
11 |     T Name(const _Right& rhs) const { \
12 |         CUMAT_ERROR_IF_NO_NVCC(Name) \
13 |         return BinaryOp<_Derived, HostScalar<Scalar>, functor::BinaryMathFunctor_ ## Name <Scalar> >(derived(), HostScalar<Scalar>(rhs)); \
14 |     }
15 | #define BINARY_OP_ACCESSOR_INV(Name) \
16 |     template<typename _Left> \
17 |     BinaryOp<_Left, _Derived, functor::BinaryMathFunctor_ ## Name <Scalar> > Name ## Inv(const MatrixBase<_Left>& lhs) const { \
18 |         CUMAT_ERROR_IF_NO_NVCC(Name) \
19 |         return BinaryOp<_Left, _Derived, functor::BinaryMathFunctor_ ## Name <Scalar> >(lhs.derived(), derived()); \
20 |     } \
21 |     template<typename _Left, typename T = typename std::enable_if<std::is_convertible<_Left, Scalar>::value, \
22 |         BinaryOp<HostScalar<Scalar>, _Derived, functor::BinaryMathFunctor_ ## Name <Scalar> > >::type > \
23 |     T Name ## Inv(const _Left& lhs) const { \
24 |         CUMAT_ERROR_IF_NO_NVCC(Name) \
25 |         return BinaryOp<HostScalar<Scalar>, _Derived, functor::BinaryMathFunctor_ ## Name <Scalar> >(HostScalar<Scalar>(lhs), derived()); \
26 |     }
27 |
28 | /**
29 | * \brief computes the component-wise multiplation (this*rhs)
30 | */
31 | BINARY_OP_ACCESSOR(cwiseMul)
32 |
33 | /**
34 | * \brief computes the component-wise dot-product.
35 | * The dot product of every individual element is computed,
36 | * for regular matrices/vectors with scalar entries,
37 | * this is exactly equal to the component-wise multiplication (\ref cwiseMul).
38 | * However, one can also use Matrix with submatrices/vectors as entries
39 | * and then this operation might have the real dot-product
40 | * if the respective functor \ref functor::BinaryMathFunctor_cwiseMDot is specialized.
41 | */
42 | BINARY_OP_ACCESSOR(cwiseDot)
43 |
44 | /**
45 | * \brief computes the component-wise division (this/rhs)
46 | */
47 | BINARY_OP_ACCESSOR(cwiseDiv)
48 |
49 | /**
50 | * \brief computes the inverted component-wise division (rhs/this)
51 | */
52 | BINARY_OP_ACCESSOR_INV(cwiseDiv)
53 |
54 | /**
55 | * \brief computes the component-wise exponent (this^rhs)
56 | */
57 | BINARY_OP_ACCESSOR(cwisePow)
58 |
59 | /**
60 | * \brief computes the inverted component-wise exponent (rhs^this)
61 | */
62 | BINARY_OP_ACCESSOR_INV(cwisePow)
63 |
64 | /**
65 | * \brief computes the component-wise binary AND (this & rhs).
66 | * Only available for integer matrices.
67 | */
68 | BINARY_OP_ACCESSOR(cwiseBinaryAnd)
69 |
70 | /**
71 | * \brief computes the component-wise binary OR (this | rhs).
72 | * Only available for integer matrices.
73 | */
74 | BINARY_OP_ACCESSOR(cwiseBinaryOr)
75 |
76 | /**
77 | * \brief computes the component-wise binary XOR (this ^ rhs).
78 | * Only available for integer matrices.
79 | */
80 | BINARY_OP_ACCESSOR(cwiseBinaryXor)
81 |
82 | /**
83 | * \brief computes the component-wise logical AND (this && rhs).
84 | * Only available for boolean matrices
85 | */
86 | BINARY_OP_ACCESSOR(cwiseLogicalAnd)
87 |
88 | /**
89 | * \brief computes the component-wise logical OR (this || rhs).
90 | * Only available for boolean matrices
91 | */
92 | BINARY_OP_ACCESSOR(cwiseLogicalOr)
93 |
94 | /**
95 | * \brief computes the component-wise logical XOR (this ^ rhs).
96 | * Only available for boolean matrices
97 | */
98 | BINARY_OP_ACCESSOR(cwiseLogicalXor)
99 |
100 | #undef BINARY_OP_ACCESSOR
101 | #undef BINARY_OP_ACCESSOR_INV
102 |
103 | /**
104 | * \brief Custom binary expression.
105 |  * The binary functor must look as follows:
106 | * \code
107 | * struct MyFunctor
108 | * {
109 | * typedef OutputType ReturnType;
110 | * __device__ CUMAT_STRONG_INLINE ReturnType operator()(const LeftType& x, const RightType& y, Index row, Index col, Index batch) const
111 | * {
112 | * return ...
113 | * }
114 | * };
115 | * \endcode
116 | * where \c LeftType is the type of this matrix expression,
117 | * \c RightType is the type of the rhs matrix,
118 | * and \c OutputType is the output type.
119 | *
120 | * \param rhs the matrix expression on the right hand side
121 | * \param functor the functor to apply component-wise
122 | * \return an expression of a component-wise binary expression with a custom functor applied per component.
123 | */
124 | template<typename Right, typename Functor>
125 | BinaryOp<_Derived, Right, Functor> binaryExpr(const Right& rhs, const Functor& functor = Functor()) const
126 | {
127 | CUMAT_ERROR_IF_NO_NVCC(binaryExpr)
128 | return BinaryOp<_Derived, Right, Functor>(derived(), rhs.derived(), functor);
129 | }
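A minimal sketch (not part of the source tree) of the custom-functor path documented above; the functor follows the required shape, while the surrounding matrix types and the final eval() call are illustrative assumptions.

    // A component-wise |x - y| functor, matching the documented interface.
    struct AbsDiffFunctor
    {
        typedef float ReturnType;
        __device__ CUMAT_STRONG_INLINE ReturnType operator()(const float& x, const float& y,
            cuMat::Index row, cuMat::Index col, cuMat::Index batch) const
        {
            return fabsf(x - y);
        }
    };

    // Hypothetical usage with two matrices A and B of matching size:
    //   auto C = A.binaryExpr(B, AbsDiffFunctor()).eval();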
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/Constants.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_CONSTANTS_H__
2 | #define __CUMAT_CONSTANTS_H__
3 |
4 |
5 | #include "Macros.h"
6 |
7 | CUMAT_NAMESPACE_BEGIN
8 |
9 | /**
10 | * \brief This value means that a positive quantity (e.g., a size) is not known at compile-time,
11 | * and that instead the value is stored in some runtime variable.
12 | */
13 | const int Dynamic = -1;
14 |
15 | /**
16 | * \brief The bit flags for the matrix expressions.
17 | */
18 | enum Flags
19 | {
20 | /**
21 | * \brief The storage is column major (the default).
22 | */
23 | ColumnMajor = 0x00,
24 | /**
25 | * \brief The storage is row major.
26 | */
27 | RowMajor = 0x01,
28 |
29 | };
30 | #define CUMAT_IS_COLUMN_MAJOR(flags) (((flags) & CUMAT_NAMESPACE Flags::RowMajor)==0)
31 | #define CUMAT_IS_ROW_MAJOR(flags) ((flags) & CUMAT_NAMESPACE Flags::RowMajor)
32 |
33 | /**
34 | * \brief Flags that specify how the data in a MatrixBase-expression can be accessed.
35 | */
36 | enum AccessFlags
37 | {
38 | /**
39 | * \brief component-wise read is available.
40 | * The following method must be provided:
41 | * \code
42 | * __device__ const Scalar& coeff(Index row, Index col, Index batch, Index index) const;
43 | * \endcode
44 | * The parameter \c index is the same as the linear index from the writing procedure and is used
45 | * by optimized routines only if the user explicitly enables them.
46 |      * (This is currently only supported by SparseMatrix.)
47 | * If some operation can't pass the linear index to the expressions, -1 might be used instead.
48 | */
49 | ReadCwise = 0x01,
50 | /**
51 |      * \brief direct read is available, the underlying memory is directly addressable.
52 | * The following method must be provided:
53 | * \code
54 | * __host__ __device__ const _Scalar* data() const;
55 | * __host__ bool isExclusiveUse() const;
56 | * \endcode
57 | */
58 | ReadDirect = 0x02,
59 | /**
60 |      * \brief Component-wise write is available.
61 | * To allow the implementation to specify the access order, the following methods have to be provided:
62 | * \code
63 | * __host__ Index size() const;
64 | * __device__ void index(Index index, Index& row, Index& col, Index& batch) const;
65 | * __device__ void setRawCoeff(Index index, const Scalar& newValue);
66 | * \endcode
67 | */
68 | WriteCwise = 0x10,
69 | /**
70 | * \brief Direct write is available, the underlying memory can be directly written to.
71 | * The following method must be provided:
72 | * \code
73 | * __host__ __device__ _Scalar* data()
74 | * \endcode
75 | */
76 | WriteDirect = 0x20,
77 | /**
78 | * \brief This extends \c WriteCwise and allows inplace modifications (compound operators) by additionally providing the function
79 | * \code __device__ const Scalar& getRawCoeff(Index index) const; \endcode .
80 | * To enable compound assignment with this as target type, either RWCwise or RWCwiseRef (or both) have to be defined.
81 | */
82 | RWCwise = 0x40,
83 | /**
84 | * \brief This extends \c WriteCwise and allows inplace modifications (compound operators) by additionally providing the function
85 | * \code __device__ Scalar& rawCoeff(Index index); \endcode for read-write access to that entry.
86 | * To enable compound assignment with this as target type, either RWCwise or RWCwiseRef (or both) have to be defined.
87 | */
88 | RWCwiseRef = 0x80,
89 | };
90 |
91 | /**
92 | * \brief The axis over which reductions are performed.
93 | */
94 | enum Axis
95 | {
96 | NoAxis = 0,
97 | Row = 1,
98 | Column = 2,
99 | Batch = 4,
100 | All = Row | Column | Batch
101 | };
102 |
103 | /**
104 | * \brief Tags for the different reduction algorithms
105 | */
106 | namespace ReductionAlg
107 | {
108 | /**
109 | * \brief reduction with cub::DeviceSegmentedReduce
110 | */
111 | struct Segmented {};
112 | /**
113 | * \brief Thread reduction. Each thread reduces a batch.
114 | */
115 | struct Thread {};
116 | /**
117 | * \brief Warp reduction. Each warp reduces a batch.
118 | */
119 | struct Warp {};
120 | /**
121 | * \brief Block reduction. Each block reduces a batch.
122 | * \tparam N the block size
123 | */
124 |     template<int N>
125 | struct Block {};
126 | /**
127 | * \brief Device reduction.
128 | * Each reduction per batch is computed with a separate call to cub::DeviceReduce,
129 | * parallelized over N cuda streams.
130 | * \tparam N the number of parallel streams
131 | */
132 |     template<int N>
133 | struct Device {};
134 | /**
135 | * \brief Automatic algorithm selection.
136 | * Chooses the algorithm during runtime based on the matrix sizes.
137 | */
138 | struct Auto {};
139 | }
140 |
141 | /**
142 | * \brief Specifies the assignment mode in \c Assignment::assign() .
143 |  * This is the difference between regular assignment (operator=, \c AssignmentMode::ASSIGN)
144 | * and inplace modifications like operator+= (\c AssignmentMode::ADD).
145 | *
146 | * Note that not all assignment modes have to be supported for all scalar types
147 | * and all right hand sides.
148 | * For example:
149 |  * - MUL (*=) and DIV (/=) are only supported for scalar right hand sides (broadcasting)
150 | * to avoid the ambiguity if component-wise or matrix operations are meant
151 | * - MOD (%=), AND (&=), OR (|=) are only supported for integer types
152 | */
153 | enum class AssignmentMode
154 | {
155 | ASSIGN,
156 | ADD,
157 | SUB,
158 | MUL,
159 | DIV,
160 | MOD,
161 | AND,
162 | OR,
163 | };
164 |
165 | /**
166 | * \brief Flags for the SparseMatrix class.
167 | */
168 | enum SparseFlags
169 | {
170 | /**
171 | * \brief Matrix stored in the Compressed Sparse Column format.
172 | */
173 | CSC = 1,
174 | /**
175 | * \brief Matrix stored in the Compressed Sparse Row format.
176 | */
177 | CSR = 2,
178 | /**
179 | * \brief Column-major ELLPACK format.
180 | * This format is optimized for matrices with uniform nnz per row.
181 | */
182 | ELLPACK = 3,
183 | };
184 |
185 | CUMAT_NAMESPACE_END
186 |
187 | #endif
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/CudaUtils.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_CUDA_UTILS_H__
2 | #define __CUMAT_CUDA_UTILS_H__
3 |
4 | #include "Macros.h"
5 |
6 | // SOME CUDA UTILITIES
7 |
8 | CUMAT_NAMESPACE_BEGIN
9 |
10 | namespace cuda
11 | {
12 | /**
13 | * \brief Loads the given scalar value from the specified address,
14 |  * possibly cached.
15 | * \tparam T the type of scalar
16 | * \param ptr the pointer to the scalar
17 |  * \return the value at that address
18 | */
19 | template<typename T>
20 | __device__ CUMAT_STRONG_INLINE T load(const T* ptr)
21 | {
22 | //#if __CUDA_ARCH__ >= 350
23 | // return __ldg(ptr);
24 | //#else
25 | return *ptr;
26 | //#endif
27 | }
28 | #if __CUDA_ARCH__ >= 350
29 | #define LOAD(T) \
30 | template<> \
31 | __device__ CUMAT_STRONG_INLINE T load(const T* ptr) \
32 | { \
33 | return __ldg(ptr); \
34 | }
35 | #else
36 | #define LOAD(T)
37 | #endif
38 |
39 | LOAD(char);
40 | LOAD(short);
41 | LOAD(int);
42 | LOAD(long);
43 | LOAD(long long);
44 | LOAD(unsigned char);
45 | LOAD(unsigned short);
46 | LOAD(unsigned int);
47 | LOAD(unsigned long);
48 | LOAD(unsigned long long);
49 | LOAD(int2);
50 | LOAD(int4);
51 | LOAD(uint2);
52 | LOAD(uint4);
53 | LOAD(float);
54 | LOAD(float2);
55 | LOAD(float4);
56 | LOAD(double);
57 | LOAD(double2);
58 |
59 | #undef LOAD
60 |
61 | }
62 |
63 | CUMAT_NAMESPACE_END
64 |
65 | #endif
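A minimal sketch (not part of the source tree): reading through cuda::load inside a kernel so that the __ldg specializations above can take effect on supported architectures; the kernel itself is illustrative and not part of cuMat.

    __global__ void ScaleKernel(const float* in, float* out, int n, float factor)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            out[i] = factor * cuMat::cuda::load(in + i); // cached read where available
    }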
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/CusolverApi.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/extensions/include/cuMat/src/CusolverApi.h
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/CwiseOp.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_CWISE_OP_H__
2 | #define __CUMAT_CWISE_OP_H__
3 |
4 | #include "Macros.h"
5 | #include "ForwardDeclarations.h"
6 | #include "MatrixBase.h"
7 | #include "Context.h"
8 | #include "Logging.h"
9 | #include "Profiling.h"
10 |
11 | CUMAT_NAMESPACE_BEGIN
12 |
13 | namespace internal {
14 | //compound assignment
15 | template
16 | struct CwiseAssignmentHandler
17 | {
18 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index)
19 | {
20 | matrix.setRawCoeff(index, value);
21 | }
22 | };
23 |
24 | #define CUMAT_DECLARE_ASSIGNMENT_HANDLER(mode, op1, op2) \
25 | template \
26 | struct CwiseAssignmentHandler \
27 | { \
28 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index, std::integral_constant) /*reference access*/ \
29 | { \
30 | matrix.rawCoeff(index) op1 value; \
31 | } \
32 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index, std::integral_constant) /*read-write access*/ \
33 | { \
34 |         matrix.setRawCoeff(index, matrix.getRawCoeff(index) op2 value); \
35 | } \
36 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index, std::integral_constant) /*both is possible, use reference*/ \
37 | { \
38 | matrix.rawCoeff(index) op1 value; \
39 | } \
40 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index) \
41 | { \
42 | assign(matrix, value, index, std::integral_constant::AccessFlags & (AccessFlags::RWCwiseRef | AccessFlags::RWCwise)>()); \
43 | } \
44 | };
45 |
46 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(ADD, +=, +)
47 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(SUB, -=, -)
48 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(MUL, *=, *)
49 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(DIV, /=, /)
50 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(MOD, %=, %)
51 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(AND, &=, &)
52 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(OR, |=, |)
53 |
54 | //partial specialization for normal assignment
55 | template
56 | struct CwiseAssignmentHandler
57 | {
58 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index)
59 | {
60 | matrix.setRawCoeff(index, value);
61 | }
62 | };
63 |
64 | namespace kernels
65 | {
66 | template
67 | __global__ void CwiseEvaluationKernel(dim3 virtual_size, const T expr, M matrix)
68 | {
69 | //By using a 1D-loop over the linear index,
70 | //the target matrix can determine the order of rows, columns and batches.
71 | //E.g. by storage order (row major / column major)
72 | //Later, this may come in hand if sparse matrices or diagonal matrices are allowed
73 | //that only evaluate certain elements.
74 | CUMAT_KERNEL_1D_LOOP(index, virtual_size)
75 |
76 | Index i, j, k;
77 | matrix.index(index, i, j, k);
78 |
79 | //there seems to be a bug in CUDA if the result of expr.coeff is directly passed to setRawCoeff.
80 | //By saving it in a local variable, this is prevented
81 | auto val = expr.coeff(i, j, k, index);
82 | internal::CwiseAssignmentHandler::assign(matrix, val, index);
83 |
84 | CUMAT_KERNEL_1D_LOOP_END
85 | }
86 | }
87 | }
88 |
89 | /**
90 | * \brief Base class of all component-wise expressions.
91 | * It defines the evaluation logic.
92 | *
93 | * A component-wise expression can be evaluated to any object that
94 | * - inherits MatrixBase
95 | * - defines a __host__ Index size() const method that returns the number of entries
96 | * - defines a __device__ void index(Index index, Index& row, Index& col, Index& batch) const
97 | * method to convert from raw index (from 0 to size()-1) to row, column and batch index
98 |  * - defines a __device__ void setRawCoeff(Index index, const Scalar& newValue) method
99 | * that is used to write the results back.
100 | *
101 | * Currently, the following classes support this interface and can therefore be used
102 | * as the left-hand-side of a component-wise expression:
103 | * - Matrix
104 | * - MatrixBlock
105 | *
106 | * \tparam _Derived the type of the derived expression
107 | */
108 | template<typename _Derived>
109 | class CwiseOp : public MatrixBase<_Derived>
110 | {
111 | public:
112 | typedef _Derived Type;
113 | typedef MatrixBase<_Derived> Base;
114 | CUMAT_PUBLIC_API
115 | using Base::rows;
116 | using Base::cols;
117 | using Base::batches;
118 | using Base::size;
119 |
120 | __device__ CUMAT_STRONG_INLINE const Scalar& coeff(Index row, Index col, Index batch, Index index) const
121 | {
122 | return derived().coeff(row, col, batch, index);
123 | }
124 |
125 | };
126 |
127 | namespace internal
128 | {
129 |     //General assignment for everything that fulfills CwiseSrcTag into DenseDstTag (cwise dense evaluation)
130 |     template<typename _Dst, typename _Src, AssignmentMode _Mode>
131 | struct Assignment<_Dst, _Src, _Mode, DenseDstTag, CwiseSrcTag>
132 | {
133 | static void assign(_Dst& dst, const _Src& src)
134 | {
135 | #if CUMAT_NVCC==1
136 | typedef typename _Dst::Type DstActual;
137 | typedef typename _Src::Type SrcActual;
138 | CUMAT_PROFILING_INC(EvalCwise);
139 | CUMAT_PROFILING_INC(EvalAny);
140 | if (dst.size() == 0) return;
141 | CUMAT_ASSERT(src.rows() == dst.rows());
142 | CUMAT_ASSERT(src.cols() == dst.cols());
143 | CUMAT_ASSERT(src.batches() == dst.batches());
144 |
145 | CUMAT_LOG_DEBUG("Evaluate component wise expression " << typeid(src.derived()).name()
146 | << "\n rows=" << src.rows() << ", cols=" << src.cols() << ", batches=" << src.batches());
147 |
148 | //here is now the real logic
149 | Context& ctx = Context::current();
150 | KernelLaunchConfig cfg = ctx.createLaunchConfig1D(static_cast(dst.size()), kernels::CwiseEvaluationKernel);
151 | kernels::CwiseEvaluationKernel <<>> (cfg.virtual_size, src.derived(), dst.derived());
152 | CUMAT_CHECK_ERROR();
153 | CUMAT_LOG_DEBUG("Evaluation done");
154 | #else
155 | CUMAT_ERROR_IF_NO_NVCC(general_component_wise_evaluation)
156 | #endif
157 | }
158 | };
159 | }
160 |
161 | CUMAT_NAMESPACE_END
162 |
163 | #endif
164 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/DecompositionBase.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_DECOMPOSITION_BASE_H__
2 | #define __CUMAT_DECOMPOSITION_BASE_H__
3 |
4 | #include "Macros.h"
5 | #include "SolverBase.h"
6 |
7 | CUMAT_NAMESPACE_BEGIN
8 |
9 | template<typename _DecompositionImpl>
10 | class DecompositionBase : public SolverBase<_DecompositionImpl>
11 | {
12 | public:
13 | using Base = SolverBase<_DecompositionImpl>;
14 | using typename Base::Scalar;
15 | using Base::Rows;
16 | using Base::Columns;
17 | using Base::Batches;
18 | using Base::impl;
19 |
20 | typedef SolveOp<_DecompositionImpl, NullaryOp > > InverseResultType;
21 | /**
22 | * \brief Computes the inverse of the input matrix.
23 | * \return The inverse matrix
24 | */
25 | InverseResultType inverse() const
26 | {
27 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(Rows > 0 && Columns > 0, Rows == Columns),
28 | "Static count of rows and columns must be equal (square matrix)");
29 | CUMAT_ASSERT(impl().rows() == impl().cols());
30 |
31 | return impl().solve(NullaryOp >(
32 | impl().rows(), impl().cols(), impl().batches(), functor::IdentityFunctor()));
33 | }
34 |
35 |     typedef Matrix<Scalar, 1, 1, Batches, ColumnMajor> DeterminantMatrix;
36 |
37 | /**
38 | * \brief Computes the determinant of this matrix
39 | * \return The determinant
40 | */
41 | DeterminantMatrix determinant() const
42 | {
43 | return impl().determinant();
44 | }
45 | /**
46 | * \brief Computes the log-determinant of this matrix.
47 | * \return The log-determinant
48 | */
49 | DeterminantMatrix logDeterminant() const
50 | {
51 | return impl().logDeterminant();
52 | }
53 | };
54 |
55 | CUMAT_NAMESPACE_END
56 |
57 | #endif
58 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/DenseLinAlgPlugin.inl:
--------------------------------------------------------------------------------
1 | //Included inside MatrixBase, define the accessors
2 |
3 | /**
4 | * \brief Computes and returns the LU decomposition with pivoting of this matrix.
5 | * The resulting decomposition can then be used to compute the determinant of the matrix,
6 | * invert the matrix and solve multiple linear equation systems.
7 | */
8 | LUDecomposition<_Derived> decompositionLU() const
9 | {
10 | return LUDecomposition<_Derived>(derived());
11 | }
12 |
13 | /**
14 |  * \brief Computes and returns the Cholesky decomposition of this matrix.
15 |  * The matrix must be Hermitian and positive definite.
16 | * The resulting decomposition can then be used to compute the determinant of the matrix,
17 | * invert the matrix and solve multiple linear equation systems.
18 | */
19 | CholeskyDecomposition<_Derived> decompositionCholesky() const
20 | {
21 | return CholeskyDecomposition<_Derived>(derived());
22 | }
23 |
24 | /**
25 | * \brief Computes the determinant of this matrix.
26 | * \return the determinant of this matrix
27 | */
28 | DeterminantOp<_Derived> determinant() const
29 | {
30 | return DeterminantOp<_Derived>(derived());
31 | }
32 |
33 | /**
34 | * \brief Computes the log-determinant of this matrix.
35 | * This is only supported for hermitian positive definite matrices, because no sign is computed.
36 |  * A negative determinant would result in a complex logarithm, which would require returning
37 |  * a complex result for real matrices. This is not desired.
38 | * \return the log-determinant of this matrix
39 | */
40 | Matrix<typename internal::traits<_Derived>::Scalar, 1, 1, internal::traits<_Derived>::BatchesAtCompileTime, ColumnMajor> logDeterminant() const
41 | {
42 | //TODO: implement direct methods for matrices up to 4x4.
43 | return decompositionLU().logDeterminant();
44 | }
45 |
46 | /**
47 |  * \brief Computes the inverse of this matrix.
48 | * For matrices of up to 4x4, an explicit formula is used. For larger matrices, this method falls back to a Cholesky Decomposition.
49 | * \return the inverse of this matrix
50 | */
51 | InverseOp<_Derived> inverse() const
52 | {
53 | return InverseOp<_Derived>(derived());
54 | }
55 |
56 | /**
57 | * \brief Computation of matrix inverse and determinant in one kernel call.
58 | *
59 | * This is only for fixed-size square matrices of size up to 4x4.
60 | *
61 | * \param inverse Reference to the matrix in which to store the inverse.
62 | * \param determinant Reference to the variable in which to store the determinant.
63 | *
64 | * \see inverse(), determinant()
65 | */
66 | template<typename InverseType, typename DetType>
67 | void computeInverseAndDet(InverseType& inverseOut, DetType& detOut) const
68 | {
69 | CUMAT_STATIC_ASSERT(Rows >= 1 && Rows <= 4, "This matrix must be a compile-time 1x1, 2x2, 3x3 or 4x4 matrix");
70 | CUMAT_STATIC_ASSERT(Columns >= 1 && Columns <= 4, "This matrix must be a compile-time 1x1, 2x2, 3x3 or 4x4 matrix");
71 |     CUMAT_STATIC_ASSERT(Rows == Columns, "This matrix must be square");
72 |     CUMAT_STATIC_ASSERT(Rows >= 1 && internal::traits<InverseType>::RowsAtCompileTime, "The output matrix must have the same compile-time size as this matrix");
73 |     CUMAT_STATIC_ASSERT(Columns >= 1 && internal::traits<InverseType>::ColsAtCompileTime, "The output matrix must have the same compile-time size as this matrix");
74 |     CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(Batches > 0 && internal::traits<InverseType>::BatchesAtCompileTime > 0, Batches == internal::traits<InverseType>::BatchesAtCompileTime),
75 |         "This matrix and the output matrix must have the same batch size");
76 |     CUMAT_ASSERT_DIMENSION(batches() == inverseOut.batches());
77 | 
78 |     CUMAT_STATIC_ASSERT(internal::traits<DetType>::RowsAtCompileTime == 1, "The determinant output must be a (batched) scalar, i.e. compile-time 1x1 matrix");
79 |     CUMAT_STATIC_ASSERT(internal::traits<DetType>::ColsAtCompileTime == 1, "The determinant output must be a (batched) scalar, i.e. compile-time 1x1 matrix");
80 |     CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(Batches > 0 && internal::traits<DetType>::BatchesAtCompileTime > 0, Batches == internal::traits<DetType>::BatchesAtCompileTime),
81 |         "This matrix and the determinant matrix must have the same batch size");
82 | CUMAT_ASSERT_DIMENSION(batches() == detOut.batches());
83 |
84 | CUMAT_NAMESPACE ComputeInverseWithDet<_Derived, Rows, InverseType, DetType>::run(derived(), inverseOut, detOut);
85 | }
86 |
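A minimal sketch (not part of the source tree) of the accessors above: LU decomposition, solving, and determinants. The fixed 4x4/4x1 matrix types are illustrative; solve() comes from the decomposition's SolverBase, which is not shown in this file.

    void denseLinAlgExample(const cuMat::Matrix<float, 4, 4, 1, cuMat::ColumnMajor>& A,
                            const cuMat::Matrix<float, 4, 1, 1, cuMat::ColumnMajor>& b)
    {
        auto lu = A.decompositionLU();      // pivoted LU decomposition
        auto x  = lu.solve(b).eval();       // solve A * x = b
        auto d  = A.determinant().eval();   // determinant as a (batched) 1x1 result
        (void)x; (void)d;
    }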
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/DevicePointer.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_DEVICE_POINTER_H__
2 | #define __CUMAT_DEVICE_POINTER_H__
3 |
4 | #include <type_traits>
5 |
6 | #include "Macros.h"
7 | #include "Context.h"
8 |
9 | CUMAT_NAMESPACE_BEGIN
10 |
11 | template<typename T>
12 | class DevicePointer
13 | {
14 | private:
15 | T* pointer_;
16 | int* counter_;
17 | CUMAT_NAMESPACE Context* context_;
18 |     friend class DevicePointer<typename std::remove_const<T>::type>;
19 |
20 | __host__ __device__
21 | void release()
22 | {
23 | #ifndef __CUDA_ARCH__
24 | //no decrement of the counter in CUDA-code, counter is in host-memory
25 |         assert(CUMAT_IMPLIES(counter_, (*counter_) > 0) && "Attempt to call release() twice");
26 | if ((counter_) && (--(*counter_) == 0))
27 | {
28 | delete counter_;
29 |
30 | //DEBUG
31 | if (&Context::current() != context_)
32 | {
33 | CUMAT_LOG_WARNING(
34 | "Freeing memory with a different context than the current context.\n"
35 | "This will likely crash with an invalid-resource-handle error due to different Cub-Allocators");
36 | }
37 |
38 | context_->freeDevice(pointer_);
39 | }
40 | #endif
41 | }
42 |
43 | public:
44 | DevicePointer(size_t size, CUMAT_NAMESPACE Context& ctx)
45 | : pointer_(nullptr)
46 | , counter_(nullptr)
47 | {
48 | context_ = &ctx;
49 |         pointer_ = static_cast<T*>(context_->mallocDevice(size * sizeof(T)));
50 | try {
51 | counter_ = new int(1);
52 | }
53 | catch (...)
54 | {
55 |             context_->freeDevice(const_cast<typename std::remove_const<T>::type*>(pointer_));
56 | throw;
57 | }
58 | }
59 | DevicePointer(size_t size)
60 | : DevicePointer(size, CUMAT_NAMESPACE Context::current())
61 | {}
62 |
63 | __host__ __device__
64 | DevicePointer()
65 | : pointer_(nullptr)
66 | , counter_(nullptr)
67 | , context_(nullptr)
68 | {}
69 |
70 | __host__ __device__
71 | DevicePointer(const DevicePointer& rhs)
72 | : pointer_(rhs.pointer_)
73 | , counter_(rhs.counter_)
74 | , context_(rhs.context_)
75 | {
76 | #ifndef __CUDA_ARCH__
77 | //no increment of the counter in CUDA-code, counter is in host-memory
78 | if (counter_) {
79 | ++(*counter_);
80 | }
81 | #endif
82 | }
83 |
84 | __host__ __device__
85 | DevicePointer(DevicePointer&& rhs) noexcept
86 | : pointer_(std::move(rhs.pointer_))
87 | , counter_(std::move(rhs.counter_))
88 | , context_(std::move(rhs.context_))
89 | {
90 | rhs.pointer_ = nullptr;
91 | rhs.counter_ = nullptr;
92 | rhs.context_ = nullptr;
93 | }
94 |
95 | __host__ __device__
96 | DevicePointer& operator=(const DevicePointer& rhs)
97 | {
98 | release();
99 | pointer_ = rhs.pointer_;
100 | counter_ = rhs.counter_;
101 | context_ = rhs.context_;
102 | #ifndef __CUDA_ARCH__
103 | //no increment of the counter in CUDA-code, counter is in host-memory
104 | if (counter_) {
105 | ++(*counter_);
106 | }
107 | #endif
108 | return *this;
109 | }
110 |
111 | __host__ __device__
112 | DevicePointer& operator=(DevicePointer&& rhs) noexcept
113 | {
114 | release();
115 | pointer_ = std::move(rhs.pointer_);
116 | counter_ = std::move(rhs.counter_);
117 | context_ = std::move(rhs.context_);
118 | rhs.pointer_ = nullptr;
119 | rhs.counter_ = nullptr;
120 | rhs.context_ = nullptr;
121 | return *this;
122 | }
123 |
124 | __host__ __device__
125 | void swap(DevicePointer& rhs) throw()
126 | {
127 | std::swap(pointer_, rhs.pointer_);
128 | std::swap(counter_, rhs.counter_);
129 | std::swap(context_, rhs.context_);
130 | }
131 |
132 | __host__ __device__
133 | ~DevicePointer()
134 | {
135 | release();
136 | }
137 |
138 | __host__ __device__ T* pointer() { return pointer_; }
139 | __host__ __device__ const T* pointer() const { return pointer_; }
140 |
141 | /**
142 | * \brief Returns the current value of the reference counter.
143 | * This can be used to determine if this memory is used uniquely
144 | * by an object.
145 | * \return the current number of references
146 | */
147 | size_t getCounter() const { return counter_ ? *counter_ : 0; }
148 | };
149 |
150 | CUMAT_NAMESPACE_END
151 |
152 | #endif
153 |
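A minimal sketch (not part of the source tree): DevicePointer is a reference-counted handle around Context-allocated device memory; copies share the allocation and it is freed when the last copy is released.

    void devicePointerExample()
    {
        cuMat::DevicePointer<float> a(256);        // 256 floats on the current context
        {
            cuMat::DevicePointer<float> b = a;     // shared ownership, a.getCounter() == 2
            (void)b;
        }
        // b destroyed: a.getCounter() == 1, allocation still alive until a goes away
    }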
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/DisableCompilerWarnings.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_DISABLE_COMPILER_WARNINGS_H__
2 | #define __CUMAT_DISABLE_COMPILER_WARNINGS_H__
3 |
4 |
5 | #if defined __GNUC__
6 |
7 | //This is a stupid warning in GCC that pops up every time storage qualifiers or similar are compared
8 | #pragma GCC diagnostic ignored "-Wenum-compare"
9 |
10 | #endif
11 |
12 |
13 | #endif
14 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/EigenInteropHelpers.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_EIGEN_INTEROP_HELPERS_H__
2 | #define __CUMAT_EIGEN_INTEROP_HELPERS_H__
3 |
4 |
5 | #include "Macros.h"
6 | #include "Constants.h"
7 |
8 | #if CUMAT_EIGEN_SUPPORT==1
9 | #include <Eigen/Core>
10 | #include <complex>
11 |
12 | CUMAT_NAMESPACE_BEGIN
13 |
14 | //forward-declare Matrix
15 | template<typename _Scalar, int _Rows, int _Columns, int _Batches, int _Flags>
16 | class Matrix;
17 |
18 | /**
19 | * Namespace for eigen interop.
20 | */
21 | namespace eigen
22 | {
23 | //There are now a lot of redundant qualifiers, but I want to make
24 | //it clear if we are in the Eigen world or in cuMat world.
25 |
26 | //Flag conversion
27 |
28 | template<int _Flags>
29 | struct StorageCuMatToEigen {};
30 | template<>
31 | struct StorageCuMatToEigen<::cuMat::Flags::ColumnMajor>
32 | {
33 | enum { value = ::Eigen::StorageOptions::ColMajor };
34 | };
35 | template<>
36 | struct StorageCuMatToEigen<::cuMat::Flags::RowMajor>
37 | {
38 | enum { value = ::Eigen::StorageOptions::RowMajor };
39 | };
40 |
41 | template<int _Storage>
42 | struct StorageEigenToCuMat {};
43 | template<>
44 | struct StorageEigenToCuMat<::Eigen::StorageOptions::RowMajor>
45 | {
46 | enum { value = ::cuMat::Flags::RowMajor };
47 | };
48 | template<>
49 | struct StorageEigenToCuMat<::Eigen::StorageOptions::ColMajor>
50 | {
51 | enum { value = ::cuMat::Flags::ColumnMajor };
52 | };
53 |
54 | //Size conversion (for dynamic tag)
55 |
56 | template<int _Size>
57 | struct SizeCuMatToEigen
58 | {
59 | enum { size = _Size };
60 | };
61 | template<>
62 | struct SizeCuMatToEigen<::cuMat::Dynamic>
63 | {
64 | enum {size = ::Eigen::Dynamic};
65 | };
66 |
67 | template<int _Size>
68 | struct SizeEigenToCuMat
69 | {
70 | enum {size = _Size};
71 | };
72 | template<>
73 | struct SizeEigenToCuMat<::Eigen::Dynamic>
74 | {
75 | enum{size = ::cuMat::Dynamic};
76 | };
77 |
78 | //Type conversion
79 | template<typename T>
80 | struct TypeCuMatToEigen
81 | {
82 | typedef T type;
83 | };
84 | template<>
85 | struct TypeCuMatToEigen<cfloat>
86 | {
87 |     typedef std::complex<float> type;
88 | };
89 | template<>
90 | struct TypeCuMatToEigen<cdouble>
91 | {
92 |     typedef std::complex<double> type;
93 | };
94 |
95 | template<typename T>
96 | struct TypeEigenToCuMat
97 | {
98 | typedef T type;
99 | };
100 | template<>
101 | struct TypeEigenToCuMat<std::complex<float> >
102 | {
103 | typedef cfloat type;
104 | };
105 | template<>
106 | struct TypeEigenToCuMat<std::complex<double> >
107 | {
108 | typedef cdouble type;
109 | };
110 |
111 | //Matrix type conversion
112 |
113 | template<typename _CuMatMatrixType>
114 | struct MatrixCuMatToEigen
115 | {
116 | using type = ::Eigen::Matrix<
117 | //typename TypeCuMatToEigen::type,
118 | typename _CuMatMatrixType::Scalar,
119 | SizeCuMatToEigen<_CuMatMatrixType::Rows>::size,
120 | SizeCuMatToEigen<_CuMatMatrixType::Columns>::size,
121 | //Eigen requires specific storage types for vector sizes
122 | ((_CuMatMatrixType::Rows==1) ? ::Eigen::StorageOptions::RowMajor
123 | : (_CuMatMatrixType::Columns==1) ? ::Eigen::StorageOptions::ColMajor
124 | : StorageCuMatToEigen<_CuMatMatrixType::Flags>::value)
125 | | ::Eigen::DontAlign //otherwise, toEigen() will produce strange errors because we access the native data pointer
126 | >;
127 | };
128 | template<typename _EigenMatrixType>
129 | struct MatrixEigenToCuMat
130 | {
131 | using type = ::cuMat::Matrix<
132 | //typename TypeEigenToCuMat::type,
133 | typename _EigenMatrixType::Scalar,
134 | SizeEigenToCuMat<_EigenMatrixType::RowsAtCompileTime>::size,
135 | SizeEigenToCuMat<_EigenMatrixType::ColsAtCompileTime>::size,
136 | 1, //batch size of 1
137 | StorageEigenToCuMat<_EigenMatrixType::Options>::value
138 | >;
139 | };
140 |
141 | }
142 |
143 | CUMAT_NAMESPACE_END
144 |
145 | //tell Eigen how to handle cfloat and cdouble
146 | namespace Eigen
147 | {
148 |     template<> struct NumTraits<cuMat::cfloat> : NumTraits<std::complex<float>> {};
149 |     template<> struct NumTraits<cuMat::cdouble> : NumTraits<std::complex<double>> {};
150 | }
151 |
152 |
153 | #endif
154 |
155 | #endif
156 |
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/Errors.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_ERRORS_H__
2 | #define __CUMAT_ERRORS_H__
3 |
4 | #include <cuda_runtime.h>
5 | #include <exception>
6 | #include <stdexcept>
7 | #include <string>
8 | #include <vector>
9 | #include <cstdio>
10 | #include <cstdarg>
11 |
12 | #include "Macros.h"
13 | #include "Logging.h"
14 |
15 | CUMAT_NAMESPACE_BEGIN
16 |
17 | class cuda_error : public std::exception
18 | {
19 | private:
20 | std::string message_;
21 | public:
22 | cuda_error(std::string message)
23 | : message_(message)
24 | {}
25 |
26 | const char* what() const throw() override
27 | {
28 | return message_.c_str();
29 | }
30 | };
31 |
32 | namespace internal {
33 |
34 | class ErrorHelpers
35 | {
36 | public:
37 | static std::string vformat(const char *fmt, va_list ap)
38 | {
39 | // Allocate a buffer on the stack that's big enough for us almost
40 | // all the time. Be prepared to allocate dynamically if it doesn't fit.
41 | size_t size = 1024;
42 | char stackbuf[1024];
43 |         std::vector<char> dynamicbuf;
44 | char *buf = &stackbuf[0];
45 | va_list ap_copy;
46 |
47 | while (1) {
48 | // Try to vsnprintf into our buffer.
49 | va_copy(ap_copy, ap);
50 |             int needed = vsnprintf(buf, size, fmt, ap_copy);
51 | va_end(ap_copy);
52 |
53 | // NB. C99 (which modern Linux and OS X follow) says vsnprintf
54 | // failure returns the length it would have needed. But older
55 | // glibc and current Windows return -1 for failure, i.e., not
56 | // telling us how much was needed.
57 |
58 | if (needed <= (int)size && needed >= 0) {
59 | // It fit fine so we're done.
60 | return std::string(buf, (size_t)needed);
61 | }
62 |
63 | // vsnprintf reported that it wanted to write more characters
64 | // than we allotted. So try again using a dynamic buffer. This
65 | // doesn't happen very often if we chose our initial size well.
66 | size = (needed > 0) ? (needed + 1) : (size * 2);
67 | dynamicbuf.resize(size);
68 | buf = &dynamicbuf[0];
69 | }
70 | }
71 | //Taken from https://stackoverflow.com/a/69911/4053176
72 |
73 | static std::string format(const char *fmt, ...)
74 | {
75 | va_list ap;
76 | va_start(ap, fmt);
77 | std::string buf = vformat(fmt, ap);
78 | va_end(ap);
79 | return buf;
80 | }
81 | //Taken from https://stackoverflow.com/a/69911/4053176
82 |
83 | // Taken from https://codeyarns.com/2011/03/02/how-to-do-error-checking-in-cuda/
84 |     // and adapted
85 |
86 | private:
87 | static bool evalError(cudaError err, const char* file, const int line)
88 | {
89 | if (cudaErrorCudartUnloading == err) {
90 | std::string msg = format("cudaCheckError() failed at %s:%i : %s\nThis error can happen in multi-threaded applications during shut-down and is ignored.\n",
91 | file, line, cudaGetErrorString(err));
92 | CUMAT_LOG_SEVERE(msg);
93 | return false;
94 | }
95 | else if (cudaSuccess != err) {
96 | std::string msg = format("cudaSafeCall() failed at %s:%i : %s\n",
97 | file, line, cudaGetErrorString(err));
98 | CUMAT_LOG_SEVERE(msg);
99 | throw cuda_error(msg);
100 | }
101 | return true;
102 | }
103 | static bool evalErrorNoThrow(cudaError err, const char* file, const int line)
104 | {
105 | if (cudaSuccess != err) {
106 | std::string msg = format("cudaSafeCall() failed at %s:%i : %s\n",
107 | file, line, cudaGetErrorString(err));
108 | CUMAT_LOG_SEVERE(msg);
109 | return false;
110 | }
111 | return true;
112 | }
113 | public:
114 | static void cudaSafeCall(cudaError err, const char *file, const int line)
115 | {
116 | if (!evalError(err, file, line)) return;
117 | #if CUMAT_VERBOSE_ERROR_CHECKING==1
118 | //insert a device-sync
119 | err = cudaDeviceSynchronize();
120 | evalError(err, file, line);
121 | #endif
122 | }
123 |
124 | static bool cudaSafeCallNoThrow(cudaError err, const char* file, const int line)
125 | {
126 | if (!evalErrorNoThrow(err, file, line)) return false;
127 | #if CUMAT_VERBOSE_ERROR_CHECKING==1
128 | //insert a device-sync
129 | err = cudaDeviceSynchronize();
130 | if (!evalError(err, file, line)) return false;
131 | #endif
132 | return true;
133 | }
134 |
135 | static void cudaCheckError(const char *file, const int line)
136 | {
137 | cudaError err = cudaGetLastError();
138 | if (!evalError(err, file, line)) return;
139 |
140 | #if CUMAT_VERBOSE_ERROR_CHECKING==1
141 | // More careful checking. However, this will affect performance.
142 | err = cudaDeviceSynchronize();
143 | evalError(err, file, line);
144 | #endif
145 | }
146 | };
147 |
148 | /**
149 | * \brief Tests if the cuda library call wrapped inside the bracets was executed successfully, aka returned cudaSuccess.
150 | * Throws an cuMat::cuda_error if unsuccessfull
151 | * \param err the error code
152 | */
153 | #define CUMAT_SAFE_CALL( err ) CUMAT_NAMESPACE internal::ErrorHelpers::cudaSafeCall( err, __FILE__, __LINE__ )
154 |
155 | /**
156 |  * \brief Tests if the cuda library call wrapped inside the brackets was executed successfully, aka returned cudaSuccess
157 |  * Returns false iff unsuccessful
158 | * \param err the error code
159 | */
160 | #define CUMAT_SAFE_CALL_NO_THROW( err ) CUMAT_NAMESPACE internal::ErrorHelpers::cudaSafeCallNoThrow( err, __FILE__, __LINE__ )
161 | /**
162 | * \brief Issue this after kernel launches to check for errors in the kernel.
163 | */
164 | #define CUMAT_CHECK_ERROR() CUMAT_NAMESPACE internal::ErrorHelpers::cudaCheckError( __FILE__, __LINE__ )
165 |
166 | //TODO: find a better place in some Utility header
167 |
168 | /**
169 | * \brief Numeric type conversion with overflow check.
170 | * If CUMAT_ENABLE_HOST_ASSERTIONS==1, this method
171 | * throws an std::runtime_error if the conversion results in
172 | * an overflow.
173 | *
174 | * If CUMAT_ENABLE_HOST_ASSERTIONS is not defined (default in release mode),
175 | * this method simply becomes static_cast.
176 | *
177 | * Source: The C++ Programming Language 4th Edition by Bjarne Stroustrup
178 | * https://stackoverflow.com/a/30114062/1786598
179 | *
180 | * \tparam Target the target type
181 | * \tparam Source the source type
182 | * \param v the source value
183 | * \return the casted target value
184 | */
185 | template<typename Target, typename Source>
186 | CUMAT_STRONG_INLINE Target narrow_cast(Source v)
187 | {
188 | #if CUMAT_ENABLE_HOST_ASSERTIONS==1
189 | auto r = static_cast(v); // convert the value to the target type
190 | if (static_cast(r) != v)
191 | throw std::runtime_error("narrow_cast<>() failed");
192 | return r;
193 | #else
194 | return static_cast(v);
195 | #endif
196 | }
197 |
198 | }
199 | CUMAT_NAMESPACE_END
200 |
201 | #endif
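A minimal sketch (not part of the source tree): wrapping raw CUDA runtime calls in the macros above; on failure CUMAT_SAFE_CALL logs the message and throws cuMat::cuda_error. cudaMalloc/cudaFree are standard CUDA runtime functions.

    void safeCallExample(size_t bytes)
    {
        void* ptr = nullptr;
        CUMAT_SAFE_CALL(cudaMalloc(&ptr, bytes));  // throws cuMat::cuda_error on failure
        // ... launch kernels, then:
        // CUMAT_CHECK_ERROR();                    // surfaces kernel launch errors
        CUMAT_SAFE_CALL(cudaFree(ptr));
    }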
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/Logging.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_LOGGING_H__
2 | #define __CUMAT_LOGGING_H__
3 |
4 | #include "Macros.h"
5 |
6 | /**
7 | * Defines the logging macros.
8 | * All logging messages are started with a call to
9 | * e.g. CUMAT_LOG_INFO(message)
10 | *
11 | * You can define the CUMAT_LOG and the related logging levels
12 | * to point to your own logging implementation.
13 | * If you don't overwrite these, a very trivial logger is used
14 | * that simply prints to std::cout.
15 | *
16 | * This is achieved by globally defining CUMAT_LOGGING_PLUGIN
17 | * that includes a file that then defines all the logging macros:
18 | * CUMAT_LOG_DEBUG, CUMAT_LOG_INFO, CUMAT_LOG_WARNING, CUMAT_LOG_SEVERE
19 | */
20 |
21 | #ifndef CUMAT_LOG
22 |
23 | #include <iostream>
24 | #include <string>
25 | #include <sstream>
26 |
27 | #ifdef CUMAT_LOGGING_PLUGIN
28 | #include CUMAT_LOGGING_PLUGIN
29 | #endif
30 |
31 | CUMAT_NAMESPACE_BEGIN
32 |
33 |
34 | class DummyLogger
35 | {
36 | private:
37 | std::ios_base::fmtflags flags_;
38 | bool enabled_;
39 |
40 | public:
41 | DummyLogger(const std::string& level)
42 | : enabled_(level != "[debug]") //to disable the many, many logs during kernel evaluations in the test suites
43 | {
44 | flags_ = std::cout.flags();
45 | if (enabled_) std::cout << level << " ";
46 | }
47 | ~DummyLogger()
48 | {
49 | if (enabled_) {
50 | std::cout << std::endl;
51 | std::cout.flags(flags_);
52 | }
53 | }
54 |
55 | template
56 | DummyLogger& operator<<(const T& t) {
57 | if (enabled_) std::cout << t;
58 | return *this;
59 | }
60 | };
61 |
62 | #ifndef CUMAT_LOG_DEBUG
63 | /**
64 | * Logs the message as a debug message
65 | */
66 | #define CUMAT_LOG_DEBUG(...) DummyLogger("[debug]") << __VA_ARGS__
67 | #endif
68 |
69 | #ifndef CUMAT_LOG_INFO
70 | /**
71 |  * Logs the message as an information message
72 | */
73 | #define CUMAT_LOG_INFO(...) DummyLogger("[info]") << __VA_ARGS__
74 | #endif
75 | #ifndef CUMAT_LOG_WARNING
76 | /**
77 | * Logs the message as a warning message
78 | */
79 | #define CUMAT_LOG_WARNING(...) DummyLogger("[warning]") << __VA_ARGS__
80 | #endif
81 | #ifndef CUMAT_LOG_SEVERE
82 | /**
83 | * Logs the message as a severe error message
84 | */
85 | #define CUMAT_LOG_SEVERE(...) DummyLogger("[SEVERE]") << __VA_ARGS__
86 | #endif
87 |
88 | CUMAT_NAMESPACE_END
89 |
90 | #endif
91 |
92 | #endif
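A minimal sketch (not part of the source tree) of the plugin mechanism described above: point CUMAT_LOGGING_PLUGIN at a header that defines the four macros before cuMat is included. The file and logger names here are hypothetical.

    // my_cumat_logging.h (hypothetical)
    //   #define CUMAT_LOG_DEBUG(...)   myLogger.debug()   << __VA_ARGS__
    //   #define CUMAT_LOG_INFO(...)    myLogger.info()    << __VA_ARGS__
    //   #define CUMAT_LOG_WARNING(...) myLogger.warning() << __VA_ARGS__
    //   #define CUMAT_LOG_SEVERE(...)  myLogger.error()   << __VA_ARGS__
    //
    // build flag (so every translation unit sees it before including cuMat):
    //   -DCUMAT_LOGGING_PLUGIN="\"my_cumat_logging.h\""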
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/Macros.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_MACROS_H__
2 | #define __CUMAT_MACROS_H__
3 |
4 | #include <cassert>
5 | #include <stdexcept>
6 | #include <string>
7 | #include <cuda_runtime.h>
8 |
9 | /*
10 | * This file contains global macros and type definitions used everywhere
11 | */
12 |
13 | #ifndef CUMAT_NAMESPACE
14 | /**
15 | * \brief The namespace of the library
16 | */
17 | #define CUMAT_NAMESPACE ::cuMat::
18 | #endif
19 |
20 | #ifndef CUMAT_NAMESPACE_BEGIN
21 | /**
22 | * \brief Defines the namespace in which everything of cuMat lives in
23 | */
24 | #define CUMAT_NAMESPACE_BEGIN namespace cuMat {
25 | #endif
26 |
27 | #ifndef CUMAT_NAMESPACE_END
28 | /**
29 | * \brief Closes the namespace opened with CUMAT_NAMESPACE_BEGIN
30 | */
31 | #define CUMAT_NAMESPACE_END }
32 | #endif
33 |
34 | #ifndef CUMAT_FUNCTION_NAMESPACE_BEGIN
35 | /**
36 |  * \brief Defines the namespace in which overloaded math functions are defined.
37 | * Examples are 'sin(x)' and 'pow(a,b)'.
38 | * This defaults to "cuMat::functions", but can be changed if needed.
39 | */
40 | #define CUMAT_FUNCTION_NAMESPACE_BEGIN namespace cuMat { namespace functions {
41 | #endif
42 |
43 | #ifndef CUMAT_FUNCTION_NAMESPACE_END
44 | /**
45 |  * \brief Closes the namespace opened with CUMAT_FUNCTION_NAMESPACE_BEGIN
46 | */
47 | #define CUMAT_FUNCTION_NAMESPACE_END }}
48 | #endif
49 |
50 | #ifndef CUMAT_FUNCTION_NAMESPACE
51 | /**
52 | * \brief The namespace in which overloaded math functions are defined.
53 | */
54 | #define CUMAT_FUNCTION_NAMESPACE ::cuMat::functions::
55 | #endif
56 |
57 | #ifdef _MSC_VER
58 | //running under MS Visual Studio -> no thread_local
59 | #define CUMAT_THREAD_LOCAL __declspec( thread )
60 | #else
61 | //C++11 compatible
62 | #define CUMAT_THREAD_LOCAL thread_local
63 | #endif
64 |
65 | #ifndef CUMAT_EIGEN_SUPPORT
66 | /**
67 |  * \brief Set CUMAT_EIGEN_SUPPORT to 1 to enable Eigen interop.
68 | * This enables the methods to convert between Eigen matrices and cuMat matrices.
69 | * Default: 0
70 | */
71 | #define CUMAT_EIGEN_SUPPORT 0
72 | #endif
73 |
74 | #ifdef __CUDACC__
75 | /**
76 | * If the current source file is compiled with the NVCC as a CUDA source file,
77 | * this macro is set to one, else to zero.
78 | */
79 | #define CUMAT_NVCC 1
80 | #else
81 | #define CUMAT_NVCC 0
82 | #endif
83 |
84 | #if CUMAT_NVCC==0
85 | #define CUMAT_ERROR_IF_NO_NVCC(name) {THIS_FUNCTION_REQUIRES_THE_FILE_TO_BE_COMPILED_WITH_NVCC name;}
86 | #else
87 | #define CUMAT_ERROR_IF_NO_NVCC(name)
88 | #endif
89 |
90 | /**
91 | * \brief Define this macro in a class that should not be copyable or assignable
92 | * \param TypeName the name of the class
93 | */
94 | #define CUMAT_DISALLOW_COPY_AND_ASSIGN(TypeName)\
95 | TypeName(const TypeName&) = delete; \
96 | void operator=(const TypeName&) = delete
97 |
98 | #define CUMAT_STR_DETAIL(x) #x
99 | #define CUMAT_STR(x) CUMAT_STR_DETAIL(x)
100 |
101 |
102 | /*
103 | * \brief enable verbose error checking after each kernel launch.
104 | * This implies a synchronization point after kernel.
105 | *
106 | * By default, this is only enabled in a debug build, but it can be
107 | * manually activated by defining CUMAT_VERBOSE_ERROR_CHECKING=1
108 | */
109 | #ifndef CUMAT_VERBOSE_ERROR_CHECKING
110 | #if defined(_DEBUG) || (!defined(NDEBUG) && !defined(_NDEBUG))
111 | #define CUMAT_VERBOSE_ERROR_CHECKING 1
112 | #else
113 | #define CUMAT_VERBOSE_ERROR_CHECKING 0
114 | #endif
115 | #endif
116 |
117 | /*
118 | * \brief enable host assertions.
119 | * The assertion macros (CUMAT_ASSERT* without CUMAT_ASSERT_CUDA)
120 | * will throw an exception if triggered.
121 | *
122 | * By default, this is only enabled in a debug build, but it can be
123 | * manually activated by defining CUMAT_ENABLE_HOST_ASSERTIONS=1
124 | */
125 | #ifndef CUMAT_ENABLE_HOST_ASSERTIONS
126 | #if defined(_DEBUG) || (!defined(NDEBUG) && !defined(_NDEBUG))
127 | #define CUMAT_ENABLE_HOST_ASSERTIONS 1
128 | #else
129 | #define CUMAT_ENABLE_HOST_ASSERTIONS 0
130 | #endif
131 | #endif
132 |
133 | /*
134 | * \brief enable device assertions.
135 | * The assertion macros CUMAT_ASSERT_CUDA
136 | * will throw an exception if triggered.
137 | *
138 | * By default, this is only enabled in a debug build, but it can be
139 | * manually activated by defining CUMAT_ENABLE_DEVICE_ASSERTIONS=1
140 | */
141 | #ifndef CUMAT_ENABLE_DEVICE_ASSERTIONS
142 | #if defined(_DEBUG) || (!defined(NDEBUG) && !defined(_NDEBUG))
143 | #define CUMAT_ENABLE_DEVICE_ASSERTIONS 1
144 | #else
145 | #define CUMAT_ENABLE_DEVICE_ASSERTIONS 0
146 | #endif
147 | #endif
148 |
149 | #if CUMAT_ENABLE_HOST_ASSERTIONS==1
150 |
151 | /**
152 | * \brief Runtime assertion, uses assert()
153 | * Only use for something that should never happen
154 | * \param x the expression that must be true
155 | */
156 | #define CUMAT_ASSERT(x) assert(x)
157 |
158 | #define CUMAT_ASSERT_ARGUMENT(x) \
159 | if (!(x)) throw std::invalid_argument(__FILE__ ":" CUMAT_STR(__LINE__) ": Invalid argument: " #x);
160 | #define CUMAT_ASSERT_BOUNDS(x) \
161 |     if (!(x)) throw std::out_of_range(__FILE__ ":" CUMAT_STR(__LINE__) ": Out of bounds: " #x);
162 | #define CUMAT_ASSERT_ERROR(x) \
163 |     if (!(x)) throw std::runtime_error(__FILE__ ":" CUMAT_STR(__LINE__) ": Runtime Error: " #x);
164 | #define CUMAT_ASSERT_DIMENSION(x) \
165 | if (!(x)) throw std::invalid_argument(__FILE__ ":" CUMAT_STR(__LINE__) ": Invalid dimensions: " #x);
166 |
167 | #else
168 |
169 | #define CUMAT_ASSERT(x)
170 | #define CUMAT_ASSERT_ARGUMENT(x)
171 | #define CUMAT_ASSERT_BOUNDS(x)
172 | #define CUMAT_ASSERT_ERROR(x)
173 | #define CUMAT_ASSERT_DIMENSION(x)
174 |
175 | #endif
176 |
177 | #if CUMAT_ENABLE_DEVICE_ASSERTIONS==1
178 | /**
179 | * \brief Assertions in device code (if supported)
180 | * \param x the expression that must be true
181 | */
182 | #define CUMAT_ASSERT_CUDA(x) assert(x)
183 | #else
184 | #define CUMAT_ASSERT_CUDA(x)
185 | #endif
186 |
187 | #define CUMAT_THROW_INVALID_ARGUMENT(msg) \
188 | throw std::invalid_argument(__FILE__ ":" CUMAT_STR(__LINE__) ": " msg);
189 | #define CUMAT_THROW_RUNTIME_ERROR(msg) \
190 | throw std::runtime_error(__FILE__ ":" CUMAT_STR(__LINE__) ": " msg);
191 | #define CUMAT_THROW_OUT_OF_RANGE(msg) \
192 | throw std::out_of_range(__FILE__ ":" CUMAT_STR(__LINE__) ": " msg);
193 |
194 | /**
195 | * \brief A static assertion
196 | * \param exp the compile-time boolean expression that must be true
197 | * \param msg an error message if exp is false
198 | */
199 | #define CUMAT_STATIC_ASSERT(exp, msg) static_assert(exp, msg)
200 |
201 | #define CUMAT_STRONG_INLINE __inline__
202 |
203 |
204 | /**
205 | * \brief Returns the integer division x/y rounded up.
206 | * Taken from https://stackoverflow.com/a/2745086/4053176
207 | */
208 | #define CUMAT_DIV_UP(x, y) (((x) + (y) - 1) / (y))
209 |
210 | /**
211 | * \brief Computes the logical implication (a -> b)
212 | */
213 | #define CUMAT_IMPLIES(a,b) (!(a) || (b))
214 |
215 |
216 | #define CUMAT_PUBLIC_API_NO_METHODS \
217 | enum \
218 | { \
219 |         Flags = internal::traits<Type>::Flags, \
220 |         Rows = internal::traits<Type>::RowsAtCompileTime, \
221 |         Columns = internal::traits<Type>::ColsAtCompileTime, \
222 |         Batches = internal::traits<Type>::BatchesAtCompileTime \
223 |     }; \
224 |     using Scalar = typename internal::traits<Type>::Scalar; \
225 |     using SrcTag = typename internal::traits<Type>::SrcTag; \
226 |     using DstTag = typename internal::traits<Type>::DstTag;
227 | /**
228 | * \brief Defines the basic API of each cumat class.
229 | * The typedefs and enums that have to be exposed.
230 | * But before that, you have to define the current class in \c Type
231 | * and the base class in \c Base.
232 | */
233 | #define CUMAT_PUBLIC_API \
234 | CUMAT_PUBLIC_API_NO_METHODS \
235 | using Base::derived; \
236 | using Base::eval_t;
237 |
238 | #endif
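A minimal sketch (not part of the source tree): the grid-sizing and implication helpers in use; the kernel launch is illustrative.

    void launchExample(int n)
    {
        const int threads = 256;
        const int blocks  = CUMAT_DIV_UP(n, threads);   // rounded-up integer division
        // MyKernel<<<blocks, threads>>>(...);           // hypothetical kernel launch
        static_assert(CUMAT_IMPLIES(false, true), "a false premise makes the implication true");
        (void)blocks;
    }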
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/MatrixBase.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_MATRIX_BASE_H__
2 | #define __CUMAT_MATRIX_BASE_H__
3 |
4 | #include <type_traits>
5 |
6 | #include "Macros.h"
7 | #include "ForwardDeclarations.h"
8 | #include "Constants.h"
9 |
10 | CUMAT_NAMESPACE_BEGIN
11 |
12 | /**
13 | * \brief The base class of all matrix types and matrix expressions.
14 | * \tparam _Derived
15 | */
16 | template<typename _Derived>
17 | class MatrixBase
18 | {
19 | public:
20 | typedef _Derived Type;
21 | typedef MatrixBase<_Derived> Base;
22 | CUMAT_PUBLIC_API_NO_METHODS
23 |
24 | /**
25 | * \returns a reference to the _Derived object
26 | */
27 | __host__ __device__ CUMAT_STRONG_INLINE _Derived& derived() { return *static_cast<_Derived*>(this); }
28 |
29 | /**
30 | * \returns a const reference to the _Derived object
31 | */
32 |     __host__ __device__ CUMAT_STRONG_INLINE const _Derived& derived() const { return *static_cast<const _Derived*>(this); }
33 |
34 | /**
35 | * \brief Returns the number of rows of this matrix.
36 | * \returns the number of rows.
37 | */
38 | __host__ __device__ CUMAT_STRONG_INLINE Index rows() const { return derived().rows(); }
39 | /**
40 | * \brief Returns the number of columns of this matrix.
41 | * \returns the number of columns.
42 | */
43 | __host__ __device__ CUMAT_STRONG_INLINE Index cols() const { return derived().cols(); }
44 | /**
45 | * \brief Returns the number of batches of this matrix.
46 | * \returns the number of batches.
47 | */
48 | __host__ __device__ CUMAT_STRONG_INLINE Index batches() const { return derived().batches(); }
49 | /**
50 | * \brief Returns the total number of entries in this matrix.
51 |      * This value is computed as \code rows()*cols()*batches() \endcode
52 | * \return the total number of entries
53 | */
54 | __host__ __device__ CUMAT_STRONG_INLINE Index size() const { return rows()*cols()*batches(); }
55 |
56 | // EVALUATION
57 |
58 | typedef Matrix<
59 | typename internal::traits<_Derived>::Scalar,
60 | internal::traits<_Derived>::RowsAtCompileTime,
61 | internal::traits<_Derived>::ColsAtCompileTime,
62 | internal::traits<_Derived>::BatchesAtCompileTime,
63 | internal::traits<_Derived>::Flags
64 | > eval_t;
65 |
66 | /**
67 | * \brief Evaluates this into a matrix.
68 | * This evaluates any expression template. If this is already a matrix, it is returned unchanged.
69 | * \return the evaluated matrix
70 | */
71 | eval_t eval() const
72 | {
73 | return eval_t(derived());
74 | }
75 |
76 | /**
77 | * \brief Conversion: Matrix of size 1-1-1 (scalar) in device memory to the host memory scalar.
78 | *
79 |      * This is especially useful to directly use the results of full reductions in host code.
80 | *
81 | * \tparam T
82 | */
83 | explicit operator Scalar () const
84 | {
85 | CUMAT_STATIC_ASSERT(
86 | internal::traits<_Derived>::RowsAtCompileTime == 1 &&
87 | internal::traits<_Derived>::ColsAtCompileTime == 1 &&
88 | internal::traits<_Derived>::BatchesAtCompileTime == 1,
89 | "Conversion only possible for compile-time scalars");
90 | eval_t m = eval();
91 | Scalar v;
92 | m.copyToHost(&v);
93 | return v;
94 | }
95 |
96 |
97 | // CWISE EXPRESSIONS
98 | #include "MatrixBlockPluginRvalue.inl"
99 | #include "UnaryOpsPlugin.inl"
100 | #include "BinaryOpsPlugin.inl"
101 | #include "ReductionOpsPlugin.inl"
102 | #include "DenseLinAlgPlugin.inl"
103 | #include "SparseExpressionOpPlugin.inl"
104 | };
105 |
106 |
107 |
108 | template <typename _Derived, int _AccessFlags>
109 | struct MatrixReadWrapper
110 | {
111 | private:
112 | enum
113 | {
114 | //the existing access flags
115 | flags = internal::traits<_Derived>::AccessFlags,
116 | //boolean if the access is sufficient
117 | sufficient = (flags & _AccessFlags)
118 | };
119 | public:
120 | /**
121 | * \brief The wrapped type: either the type itself, if the access is sufficient,
122 | * or the evaluated type if not.
123 | */
124 | using type = typename std::conditional<bool(sufficient), const _Derived&, typename _Derived::eval_t>::type;
125 |
126 | /*
127 | template>::type>
128 | static type wrap(const T& m)
129 | {
130 | return m.derived();
131 | }
132 | template>::type>
133 | static type wrap(const T& m)
134 | {
135 | return m.derived().eval();
136 | }
137 | */
138 |
139 | private:
140 | MatrixReadWrapper(){} //not constructible
141 | };
142 |
143 | CUMAT_NAMESPACE_END
144 |
145 | #endif
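
// --- Added illustrative sketch (not part of the original header) ---
// Demonstrates eval() and the explicit Scalar conversion defined above.
// The umbrella header <cuMat/Core>, the cuMat namespace name, the MatrixXf
// typedef and the sum() reduction are assumptions for illustration.
#include <cuMat/Core>

void matrixBaseExample(const cuMat::MatrixXf& a, const cuMat::MatrixXf& b)
{
    // eval(): force an expression template into a concrete device matrix
    auto summed = (a + b).eval();

    // operator Scalar: a full reduction is a compile-time 1x1x1 expression,
    // so its value can be copied straight back into a host-side float
    float total = static_cast<float>(summed.sum());
    (void)total;
}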
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/NullaryOps.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_NULLARY_OPS_H__
2 | #define __CUMAT_NULLARY_OPS_H__
3 |
4 | #include "Macros.h"
5 | #include "ForwardDeclarations.h"
6 | #include "CwiseOp.h"
7 |
8 | CUMAT_NAMESPACE_BEGIN
9 |
10 | namespace internal {
11 | template<typename _Scalar, int _Rows, int _Columns, int _Batches, int _Flags, typename _NullaryFunctor>
12 | struct traits<NullaryOp<_Scalar, _Rows, _Columns, _Batches, _Flags, _NullaryFunctor> >
13 | {
14 | typedef _Scalar Scalar;
15 | enum
16 | {
17 | Flags = _Flags,
18 | RowsAtCompileTime = _Rows,
19 | ColsAtCompileTime = _Columns,
20 | BatchesAtCompileTime = _Batches,
21 | AccessFlags = ReadCwise
22 | };
23 | typedef CwiseSrcTag SrcTag;
24 | typedef DeletedDstTag DstTag;
25 | };
26 | }
27 |
28 | /**
29 | * \brief A generic nullary operator.
30 | * This is used as a leaf in the expression template tree.
31 | * The nullary functor can be any class or structure that supports
32 | * the following method:
33 | * \code
34 | * __device__ const _Scalar& operator()(Index row, Index col, Index batch);
35 | * \endcode
36 | * \tparam _Scalar
37 | * \tparam _NullaryFunctor
38 | */
39 | template<typename _Scalar, int _Rows, int _Columns, int _Batches, int _Flags, typename _NullaryFunctor>
40 | class NullaryOp : public CwiseOp<NullaryOp<_Scalar, _Rows, _Columns, _Batches, _Flags, _NullaryFunctor> >
41 | {
42 | public:
43 | typedef CwiseOp<NullaryOp<_Scalar, _Rows, _Columns, _Batches, _Flags, _NullaryFunctor> > Base;
44 | typedef NullaryOp<_Scalar, _Rows, _Columns, _Batches, _Flags, _NullaryFunctor> Type;
45 | CUMAT_PUBLIC_API
46 |
47 | protected:
48 | const Index rows_;
49 | const Index cols_;
50 | const Index batches_;
51 | const _NullaryFunctor functor_;
52 |
53 | public:
54 | NullaryOp(Index rows, Index cols, Index batches, const _NullaryFunctor& functor)
55 | : rows_(rows), cols_(cols), batches_(batches), functor_(functor)
56 | {}
57 | explicit NullaryOp(const _NullaryFunctor& functor)
58 | : rows_(_Rows), cols_(_Columns), batches_(_Batches), functor_(functor)
59 | {
60 | CUMAT_STATIC_ASSERT(_Rows > 0, "The one-parameter constructor is only available for fully fixed-size instances");
61 | CUMAT_STATIC_ASSERT(_Columns > 0, "The one-parameter constructor is only available for fully fixed-size instances");
62 | CUMAT_STATIC_ASSERT(_Batches > 0, "The one-parameter constructor is only available for fully fixed-size instances");
63 | }
64 |
65 | __host__ __device__ CUMAT_STRONG_INLINE Index rows() const { return rows_; }
66 | __host__ __device__ CUMAT_STRONG_INLINE Index cols() const { return cols_; }
67 | __host__ __device__ CUMAT_STRONG_INLINE Index batches() const { return batches_; }
68 |
69 | __device__ CUMAT_STRONG_INLINE Scalar coeff(Index row, Index col, Index batch, Index linear) const
70 | {
71 | return functor_(row, col, batch);
72 | }
73 | };
74 |
75 | namespace functor
76 | {
77 | template<typename _Scalar>
78 | struct ConstantFunctor
79 | {
80 | private:
81 | const _Scalar value_;
82 | public:
83 | ConstantFunctor(_Scalar value)
84 | : value_(value)
85 | {}
86 |
87 | __device__ CUMAT_STRONG_INLINE const _Scalar& operator()(Index row, Index col, Index batch) const
88 | {
89 | return value_;
90 | }
91 | };
92 |
93 | template<typename _Scalar>
94 | struct IdentityFunctor
95 | {
96 | public:
97 | __device__ CUMAT_STRONG_INLINE _Scalar operator()(Index row, Index col, Index batch) const
98 | {
99 | return (row==col) ? _Scalar(1) : _Scalar(0);
100 | }
101 | };
102 | }
103 |
104 | template<typename _Scalar>
105 | HostScalar<_Scalar> make_host_scalar(const _Scalar& value)
106 | {
107 | return HostScalar<_Scalar>(1, 1, 1, functor::ConstantFunctor<_Scalar>(value));
108 | }
109 |
110 | CUMAT_NAMESPACE_END
111 |
112 | #endif
113 |
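
// --- Added illustrative sketch (not part of the original header) ---
// A custom functor that fulfils the documented interface, wrapped into a
// NullaryOp leaf with runtime sizes. The Dynamic and ColumnMajor constants
// and the cuMat namespace name are assumptions for illustration.
struct CheckerboardFunctor
{
    __device__ CUMAT_STRONG_INLINE float operator()(cuMat::Index row, cuMat::Index col, cuMat::Index batch) const
    {
        return ((row + col) % 2 == 0) ? 1.0f : 0.0f;   // alternating 0/1 pattern
    }
};

// 64x64 single-batch expression; it is only evaluated when assigned to a matrix.
cuMat::NullaryOp<float, cuMat::Dynamic, cuMat::Dynamic, cuMat::Dynamic, cuMat::ColumnMajor, CheckerboardFunctor>
    checkerboard(64, 64, 1, CheckerboardFunctor());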
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/NumTraits.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_NUM_TRAITS_H__
2 | #define __CUMAT_NUM_TRAITS_H__
3 |
4 | #include
5 |
6 | #include "Macros.h"
7 | #include "ForwardDeclarations.h"
8 | #include <limits>
9 | #include <type_traits>
10 |
11 | CUMAT_NAMESPACE_BEGIN
12 |
13 | namespace internal
14 | {
15 | /**
16 | * \brief General implementation of NumTraits
17 | * \tparam T
18 | */
19 | template <typename T>
20 | struct NumTraits
21 | {
22 | /**
23 | * \brief The type itself
24 | */
25 | typedef T Type;
26 | /**
27 | * \brief For complex types: the corresponding real type; equals to Type for non-complex types
28 | */
29 | typedef T RealType;
30 | /**
31 | * \brief For compound types (blocked types): the corresponding element type (e.g. cfloat3->cfloat, double4->double);
32 | * equals to Type for non-blocked types
33 | */
34 | typedef T ElementalType;
35 | enum
36 | {
37 | /**
38 | * \brief Equals one if cuBlas supports this type
39 | * \see CublasApi
40 | */
41 | IsCudaNumeric = 0,
42 | /**
43 | * \brief Equal to true if this type is a complex type, and hence Type!=RealType
44 | */
45 | IsComplex = false
46 | };
47 | /**
48 | * \brief The default epsilon for approximate comparisons
49 | */
50 | static constexpr CUMAT_STRONG_INLINE ElementalType epsilon() {return std::numeric_limits<T>::epsilon();}
51 | };
52 |
53 | template <>
54 | struct NumTraits<float>
55 | {
56 | typedef float Type;
57 | typedef float RealType;
58 | typedef float ElementalType;
59 | enum
60 | {
61 | IsCudaNumeric = 1,
62 | IsComplex = false,
63 | };
64 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits<float>::epsilon();}
65 | };
66 | template <>
67 | struct NumTraits<double>
68 | {
69 | typedef double Type;
70 | typedef double RealType;
71 | typedef double ElementalType;
72 | enum
73 | {
74 | IsCudaNumeric = 1,
75 | IsComplex = false,
76 | };
77 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits<double>::epsilon();}
78 | };
79 |
80 | template <>
81 | struct NumTraits<cfloat>
82 | {
83 | typedef cfloat Type;
84 | typedef float RealType;
85 | typedef cfloat ElementalType;
86 | enum
87 | {
88 | IsCudaNumeric = 1,
89 | IsComplex = true,
90 | };
91 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits<float>::epsilon();}
92 | };
93 |
94 | template <>
95 | struct NumTraits<cdouble>
96 | {
97 | typedef cdouble Type;
98 | typedef double RealType;
99 | typedef cdouble ElementalType;
100 | enum
101 | {
102 | IsCudaNumeric = 1,
103 | IsComplex = true,
104 | };
105 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits<double>::epsilon();}
106 | };
107 |
108 | template <typename T>
109 | struct isPrimitive : std::is_arithmetic<T> {};
110 | template <> struct isPrimitive<cfloat> : std::integral_constant<bool, true> {};
111 | template <> struct isPrimitive<cdouble> : std::integral_constant<bool, true> {};
112 |
113 | /**
114 | * \brief Can the type T be used for broadcasting when the scalar type of the other matrix is S?
115 | */
116 | template <typename T, typename S>
117 | struct canBroadcast : std::integral_constant<bool, (CUMAT_NAMESPACE internal::isPrimitive<T>::value && CUMAT_NAMESPACE internal::isPrimitive<S>::value) \
119 | || std::is_same<typename std::remove_cv<T>::type, typename std::remove_cv<S>::type>::value
120 | > {};
121 |
122 | template <typename T>
123 | struct NumOps //special functions for numbers
124 | {
125 | static __host__ __device__ CUMAT_STRONG_INLINE T conj(const T& v) {return v;}
126 | };
127 | template<>
129 | struct NumOps<cfloat>
129 | {
130 | static __host__ __device__ CUMAT_STRONG_INLINE cfloat conj(const cfloat& v) {return ::thrust::conj(v);}
131 | };
132 | template<>
134 | struct NumOps<cdouble>
134 | {
135 | static __host__ __device__ CUMAT_STRONG_INLINE cdouble conj(const cdouble& v) {return ::thrust::conj(v);}
136 | };
137 | }
138 |
139 | CUMAT_NAMESPACE_END
140 |
141 | #endif
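
// --- Added illustrative sketch (not part of the original header) ---
// A host-side approximate comparison built on NumTraits; the cuMat namespace
// name is an assumption for illustration.
#include <cmath>

template<typename T>
bool approxEqualHost(const T& a, const T& b)
{
    typedef typename cuMat::internal::NumTraits<T>::RealType Real;
    const Real eps = cuMat::internal::NumTraits<T>::epsilon();
    return std::abs(a - b) <= eps;   // real scalars; complex types would use thrust::abs instead
}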
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/Profiling.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_PROFILING_H__
2 | #define __CUMAT_PROFILING_H__
3 |
4 | #include "Macros.h"
5 |
6 | #ifndef CUMAT_PROFILING
7 | /**
8 | * \brief Define this macro as '1' to enable profiling.
9 | * For the created statistics, see the class Profiling.
10 | */
11 | #define CUMAT_PROFILING 0
12 | #endif
13 |
14 | CUMAT_NAMESPACE_BEGIN
15 |
16 |
17 |
18 | /**
19 | * \brief This class contains the counters used to profile the library.
20 | * See Counter for which statistics are available.
21 | * For easy access, several macros are available with the prefix CUMAT_PROFILING_ .
22 | */
23 | class Profiling
24 | {
25 | public:
26 | enum Counter
27 | {
28 | DeviceMemAlloc,
29 | DeviceMemFree,
30 | HostMemAlloc,
31 | HostMemFree,
32 |
33 | MemcpyDeviceToDevice,
34 | MemcpyHostToHost,
35 | MemcpyDeviceToHost,
36 | MemcpyHostToDevice,
37 |
38 | /**
39 | * \brief Any evaluation has happened
40 | */
41 | EvalAny,
42 | /**
43 | * \brief Component-wise evaluation
44 | */
45 | EvalCwise,
46 | /**
47 | * \brief Special transposition operation
48 | */
49 | EvalTranspose,
50 | /**
51 | * \brief Reduction operation with CUB
52 | */
53 | EvalReduction,
54 | /**
55 | * \brief Matrix-Matrix multiplication with cuBLAS
56 | */
57 | EvalMatmul,
58 | /**
59 | * \brief Sparse component-wise evaluation
60 | */
61 | EvalCwiseSparse,
62 | /**
63 | * \brief Sparse matrix-matrix multiplication
64 | */
65 | EvalMatmulSparse,
66 |
67 | _NumCounter_
68 | };
69 | private:
70 | size_t counters_[_NumCounter_];
71 |
72 | public:
73 | void resetAll()
74 | {
75 | for (size_t i = 0; i < _NumCounter_; ++i) counters_[i] = 0;
76 | }
77 | void reset(Counter counter)
78 | {
79 | counters_[counter] = 0;
80 | }
81 | void inc(Counter counter)
82 | {
83 | counters_[counter]++;
84 | }
85 | size_t get(Counter counter)
86 | {
87 | return counters_[counter];
88 | }
89 | size_t getReset(Counter counter)
90 | {
91 | size_t v = counters_[counter];
92 | counters_[counter] = 0;
93 | return v;
94 | }
95 |
96 | private:
97 | Profiling()
98 | {
99 | resetAll();
100 | }
101 | CUMAT_DISALLOW_COPY_AND_ASSIGN(Profiling);
102 |
103 | public:
104 | static Profiling& instance()
105 | {
106 | static Profiling p;
107 | return p;
108 | }
109 | };
110 |
111 | #ifdef CUMAT_PARSED_BY_DOXYGEN
112 |
113 | /**
114 | * \brief Increments the counter 'counter' (an element of Profiling::Counter).
115 | */
116 | #define CUMAT_PROFILING_INC(counter)
117 |
118 | /**
119 | * \brief Gets and resets the counter 'counter' (an element of Profiling::Counter).
120 | * If profiling is disabled (CUMAT_PROFILING is not defined or unequal to 1), the result
121 | * is not defined.
122 | */
123 | #define CUMAT_PROFILING_GET(counter)
124 |
125 | /**
126 | * \brief Resets all counters
127 | */
128 | #define CUMAT_PROFILING_RESET()
129 |
130 | #else
131 |
132 | #if CUMAT_PROFILING==1
133 | //Profiling enabled
134 | #define CUMAT_PROFILING_INC(counter) \
135 | CUMAT_NAMESPACE Profiling::instance().inc(CUMAT_NAMESPACE Profiling::Counter::counter)
136 | #define CUMAT_PROFILING_GET(counter) \
137 | CUMAT_NAMESPACE Profiling::instance().getReset(CUMAT_NAMESPACE Profiling::Counter::counter)
138 | #define CUMAT_PROFILING_RESET() \
139 | CUMAT_NAMESPACE Profiling::instance().resetAll()
140 | #else
141 | //Profiling disabled
142 | #define CUMAT_PROFILING_INC(counter) ((void)0)
143 | #define CUMAT_PROFILING_GET(counter) ((void)0)
144 | #define CUMAT_PROFILING_RESET() ((void)0)
145 | #endif
146 |
147 | #endif
148 |
149 |
150 |
151 | CUMAT_NAMESPACE_END
152 |
153 |
154 | #endif
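
// --- Added illustrative usage sketch (not part of the original header) ---
// Counting evaluations in a code section. CUMAT_PROFILING must be defined as 1
// before the library headers are included, otherwise the GET result is undefined.
// The umbrella header <cuMat/Core> is an assumption for illustration.
#define CUMAT_PROFILING 1
#include <cuMat/Core>
#include <cstdio>

void profiledSection()
{
    CUMAT_PROFILING_RESET();
    // ... run some cuMat expressions here ...
    size_t cwise = CUMAT_PROFILING_GET(EvalCwise);   // reads and resets the counter
    size_t total = CUMAT_PROFILING_GET(EvalAny);
    std::printf("cwise evaluations: %zu, total evaluations: %zu\n", cwise, total);
}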
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/ReductionAlgorithmSelection.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_REDUCTION_ALGORITHM_SELECTION_H__
2 | #define __CUMAT_REDUCTION_ALGORITHM_SELECTION_H__
3 |
4 | #include
5 | #include
6 |
7 | #include "Macros.h"
8 | #include "ForwardDeclarations.h"
9 |
10 | CUMAT_NAMESPACE_BEGIN
11 |
12 | namespace internal
13 | {
14 | /**
15 | * \brief Algorithms possible during dynamic selection.
16 | * These are a subset of the tags from namespace ReductionAlg.
17 | */
18 | enum class ReductionAlgorithm
19 | {
20 | Segmented,
21 | Thread,
22 | Warp,
23 | Block256,
24 | Device1,
25 | Device2,
26 | Device4
27 | };
28 |
29 | /**
30 | * \brief Selects the best reduction algorithm dynamically given
31 | * the reduction axis (inner, middle, outer).
32 | * Given a matrix in ColumnMajor order, these axes
33 | * correspond to the following: inner=Row, middle=Column, outer=Batch.
34 | *
35 | * The timings from which those selections were determined were evaluated
36 | * on a Nvidia RTX 2070.
37 | * For a different architecture, you might need to tweak the conditions
38 | * in the source code.
39 | */
40 | struct ReductionAlgorithmSelection
41 | {
42 | private:
43 | typedef std::tuple<double, double, double> condition;
44 | static constexpr int MAX_CONDITION = 3;
45 | typedef std::array<condition, MAX_CONDITION> conditions;
46 | typedef std::tuple<ReductionAlgorithm, int, conditions> choice;
47 |
48 | template<int N>
49 | static ReductionAlgorithm select(
50 | const choice(&conditions)[N], ReductionAlgorithm def,
51 | Index numBatches, Index batchSize)
52 | {
53 | const double nb = std::log2(numBatches);
54 | const double bs = std::log2(batchSize);
55 | for (const choice& c : conditions)
56 | {
57 | bool success = true;
58 | for (int i = 0; i < std::get<1>(c); ++i) {
59 | const auto& cond = std::get<2>(c)[i];
60 | if (std::get<0>(cond)*nb + std::get<1>(cond)*bs < std::get<2>(cond))
61 | success = false;
62 | }
63 | if (success)
64 | return std::get<0>(c);
65 | }
66 | return def;
67 | }
68 |
69 | public:
70 | static ReductionAlgorithm inner(Index numBatches, Index batchSize)
71 | {
72 | //adapt to new architectures
73 | static const choice CONDITONS[] = {
74 | choice{ReductionAlgorithm::Device1, 2, {condition{1.2, 1.0, 19.5}, condition{-1,0,-2.5}}},
75 | choice{ReductionAlgorithm::Device2, 3, {condition{0.42857142857142855,1,17.821428571428573}, condition{-1,0,-4.25}, condition{1,0,2.5}}},
76 | choice{ReductionAlgorithm::Device4, 3, {condition{0,1,16.25}, condition{-1,0,-5.5}, condition{1,0,4.25}}},
77 | choice{ReductionAlgorithm::Block256,2, {condition{-1.6, 1, 8}, condition{-1,0,-5}}},
78 | choice{ReductionAlgorithm::Thread, 2, {condition{0.475, -1, 2.01875}, condition{0, -1, -4.75}}}
79 | };
80 | static const ReductionAlgorithm DEFAULT = ReductionAlgorithm::Warp;
81 | return select(CONDITONS, DEFAULT, numBatches, batchSize);
82 | }
83 |
84 | static ReductionAlgorithm middle(Index numBatches, Index batchSize)
85 | {
86 | //adapt to new architectures
87 | static const choice CONDITONS[] = {
88 | choice{ReductionAlgorithm::Device1, 2, {condition{1.5,1,19.5}, condition{-1,0,-2.5}}},
89 | choice{ReductionAlgorithm::Device2, 3, {condition{0,1,15.5}, condition{1,0,2.5}, condition{-1,0,-4}}},
90 | choice{ReductionAlgorithm::Device4, 3, {condition{0,1,15.75}, condition{1,0,4}, condition{-1,0,-5.75}}},
91 | choice{ReductionAlgorithm::Block256,2, {condition{0,1,9}, condition{-1,0,-2.5}}},
92 | choice{ReductionAlgorithm::Warp, 2, {condition{0,1,4}, condition{-1,0,-11.75}}}
93 | };
94 | static const ReductionAlgorithm DEFAULT = ReductionAlgorithm::Thread;
95 | return select(CONDITONS, DEFAULT, numBatches, batchSize);
96 | }
97 |
98 | static ReductionAlgorithm outer(Index numBatches, Index batchSize)
99 | {
100 | //adapt to new architectures
101 | static const choice CONDITONS[] = {
102 | choice{ReductionAlgorithm::Device1, 2, {condition{-1,0,-2}, condition{1.875,1,19}}},
103 | choice{ReductionAlgorithm::Device4, 3, {condition{1,0,2}, condition{-1,0,-4.25}, condition{10, 9, 184.25}}},
104 | choice{ReductionAlgorithm::Device2, 3, {condition{1,0,2}, condition{-1,0,-4.25}, condition{-0.22222, 1, 14.085555}}},
105 | choice{ReductionAlgorithm::Segmented, 3, {condition{1,0,4}, condition{0,1,11.5}, condition{-1,0,-8.5}}},
106 | choice{ReductionAlgorithm::Block256,2, {condition{0,1,8}, condition{-1,0,-2}}},
107 | choice{ReductionAlgorithm::Warp, 2, {condition{0,1,2.75}, condition{-1,0,-11.75}}}
108 | };
109 | static const ReductionAlgorithm DEFAULT = ReductionAlgorithm::Thread;
110 | return select(CONDITONS, DEFAULT, numBatches, batchSize);
111 | }
112 | };
113 | }
114 |
115 | CUMAT_NAMESPACE_END
116 |
117 | #endif
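
// --- Added illustrative usage sketch (not part of the original header) ---
// Querying the heuristic directly for a row ("inner") reduction of a
// ColumnMajor matrix; the cuMat namespace name is an assumption.
void reductionSelectionExample()
{
    using cuMat::internal::ReductionAlgorithm;
    using cuMat::internal::ReductionAlgorithmSelection;

    // 512 independent reductions ("batches"), each over 4096 elements
    ReductionAlgorithm alg = ReductionAlgorithmSelection::inner(512, 4096);
    // The thresholds were tuned on an RTX 2070; on other GPUs the returned
    // algorithm may not be the fastest one, as the comment above notes.
    (void)alg;
}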
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/SolverBase.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_SOLVER_BASE_H__
2 | #define __CUMAT_SOLVER_BASE_H__
3 |
4 | #include "Macros.h"
5 | #include "ForwardDeclarations.h"
6 |
7 | CUMAT_NAMESPACE_BEGIN
8 |
9 | template<typename _DecompositionImpl>
10 | class SolverBase
11 | {
12 | protected:
13 | CUMAT_STRONG_INLINE _DecompositionImpl& impl() { return *static_cast<_DecompositionImpl*>(this); }
14 | CUMAT_STRONG_INLINE const _DecompositionImpl& impl() const { return *static_cast<const _DecompositionImpl*>(this); }
15 |
16 | public:
17 | using Scalar = typename internal::traits<_DecompositionImpl>::Scalar;
18 | using MatrixType = typename internal::traits<_DecompositionImpl>::MatrixType;
19 | enum
20 | {
21 | Flags = internal::traits<_DecompositionImpl>::Flags,
22 | Rows = internal::traits<_DecompositionImpl>::RowsAtCompileTime,
23 | Columns = internal::traits<_DecompositionImpl>::ColsAtCompileTime,
24 | Batches = internal::traits<_DecompositionImpl>::BatchesAtCompileTime,
25 | };
26 |
27 | /**
28 | * \brief Solves the system of linear equations
29 | * \tparam _RHS the type of the right hand side
30 | * \param rhs the right hand side matrix
31 | * \return The operation that computes the solution of the linear system
32 | */
33 | template<typename _RHS>
34 | SolveOp<_DecompositionImpl, _RHS> solve(const MatrixBase<_RHS>& rhs) const
35 | {
36 | return SolveOp<_DecompositionImpl, _RHS>(impl(), rhs.derived());
37 | }
38 | };
39 |
40 | namespace internal
41 | {
42 | struct SolveSrcTag {};
43 | template<typename _Solver, typename _RHS>
44 | struct traits<SolveOp<_Solver, _RHS> >
45 | {
46 | //using Scalar = typename internal::traits<_Child>::Scalar;
47 | using Scalar = typename internal::traits<_RHS>::Scalar;
48 | enum
49 | {
50 | Flags = ColumnMajor,
51 | RowsAtCompileTime = internal::traits<_RHS>::RowsAtCompileTime,
52 | ColsAtCompileTime = internal::traits<_RHS>::ColsAtCompileTime,
53 | BatchesAtCompileTime = internal::traits<_RHS>::BatchesAtCompileTime,
54 | AccessFlags = 0
55 | };
56 | typedef SolveSrcTag SrcTag;
57 | typedef DeletedDstTag DstTag;
58 | };
59 | }
60 |
61 | /**
62 | * \brief General solver operation.
63 | * Delegates to _solve_impl of the solver implementation.
64 | * \tparam _Solver
65 | * \tparam _RHS
66 | */
67 | template<typename _Solver, typename _RHS>
68 | class SolveOp : public MatrixBase<SolveOp<_Solver, _RHS>>
69 | {
70 | public:
71 | typedef MatrixBase<SolveOp<_Solver, _RHS>> Base;
72 | typedef SolveOp<_Solver, _RHS> Type;
73 | using Scalar = typename internal::traits<_RHS>::Scalar;
74 | enum
75 | {
76 | Flags = ColumnMajor,
77 | Rows = internal::traits<_RHS>::RowsAtCompileTime,
78 | Columns = internal::traits<_RHS>::ColsAtCompileTime,
79 | Batches = internal::traits<_RHS>::BatchesAtCompileTime
80 | };
81 |
82 | private:
83 | const _Solver& decomposition_;
84 | const _RHS rhs_;
85 |
86 | public:
87 | SolveOp(const _Solver& decomposition, const MatrixBase<_RHS>& rhs)
88 | : decomposition_(decomposition)
89 | , rhs_(rhs.derived())
90 | {
91 | CUMAT_STATIC_ASSERT((std::is_same<
92 | typename internal::NumTraits<typename internal::traits<_Solver>::Scalar>::ElementalType,
93 | typename internal::NumTraits<typename internal::traits<_RHS>::Scalar>::ElementalType>::value),
94 | "Datatype of left- and right hand side must match");
95 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(_Solver::Batches > 1 && _RHS::Batches > 0, _Solver::Batches == _RHS::Batches),
96 | "Static count of batches must match"); //note: _Solver::Batches>1 to allow broadcasting
97 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(_Solver::Rows > 0 && _Solver::Columns > 0, _Solver::Rows == _Solver::Columns),
98 | "Static count of rows and columns must be equal (square matrix)");
99 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(_Solver::Rows > 0 && _RHS::Rows > 0, _Solver::Rows == _RHS::Rows),
100 | "Left and right hand side are not compatible");
101 |
102 | CUMAT_ASSERT(CUMAT_IMPLIES(_Solver::Batches!=1, decomposition.batches() == rhs.batches()) && "batch size of the matrix and the right hand side does not match");
103 | CUMAT_ASSERT(decomposition.rows() == decomposition.cols() && "matrix must be square"); //TODO: relax for Least-Squares problems
104 | CUMAT_ASSERT(decomposition.cols() == rhs.rows() && "matrix size does not match right-hand-side");
105 | }
106 |
107 | __host__ __device__ CUMAT_STRONG_INLINE Index rows() const { return rhs_.rows(); }
108 | __host__ __device__ CUMAT_STRONG_INLINE Index cols() const { return rhs_.cols(); }
109 | __host__ __device__ CUMAT_STRONG_INLINE Index batches() const { return rhs_.batches(); }
110 |
111 | const _Solver& getDecomposition() const { return decomposition_; }
112 | const _RHS& getRhs() const { return rhs_; }
113 | };
114 |
115 | namespace internal
116 | {
117 | //Assignment for decompositions that call SolveOp::evalTo
118 | template<typename _Dst, typename _Src, AssignmentMode _Mode>
119 | struct Assignment<_Dst, _Src, _Mode, DenseDstTag, SolveSrcTag>
120 | {
121 | static void assign(_Dst& dst, const _Src& src)
122 | {
123 | static_assert(_Mode == AssignmentMode::ASSIGN, "Decompositions only support AssignmentMode::ASSIGN (operator=)");
124 | src.derived().getDecomposition()._solve_impl(src.derived().getRhs(), dst.derived());
125 | }
126 | };
127 | }
128 |
129 | CUMAT_NAMESPACE_END
130 |
131 | #endif
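
// --- Added illustrative sketch (not part of the original header) ---
// Any decomposition deriving from SolverBase exposes solve(); the assignment
// specialization above then calls its _solve_impl. The <cuMat/Dense> header,
// MatrixXf typedef and LUDecomposition signature are assumptions based on the
// headers listed in the include/ tree.
#include <cuMat/Dense>

void solveExample(const cuMat::MatrixXf& A, const cuMat::MatrixXf& b, cuMat::MatrixXf& x)
{
    cuMat::LUDecomposition<cuMat::MatrixXf> lu(A);   // factorize once
    x = lu.solve(b);                                 // builds a SolveOp, evaluated on assignment
}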
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/SparseEvaluation.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_SPARSE_EVALUATION__
2 | #define __CUMAT_SPARSE_EVALUATION__
3 |
4 | #include "Macros.h"
5 | #include "CwiseOp.h"
6 | #include "SparseMatrix.h"
7 |
8 | CUMAT_NAMESPACE_BEGIN
9 |
10 | namespace internal
11 | {
12 | namespace kernels
13 | {
14 | template<typename T, typename M, AssignmentMode _AssignmentMode>
15 | __global__ void CwiseCSREvaluationKernel(dim3 virtual_size, const T expr, M matrix)
16 | {
17 | const int* JA = matrix.getSparsityPattern().JA.data();
18 | const int* IA = matrix.getSparsityPattern().IA.data();
19 | Index batchStride = matrix.getSparsityPattern().nnz;
20 | //TODO: Profiling, what is the best way to loop over the batches?
21 | CUMAT_KERNEL_2D_LOOP(outer, batch, virtual_size)
22 | int start = JA[outer];
23 | int end = JA[outer + 1];
24 | for (int i = start; i < end; ++i)
25 | {
26 | int inner = IA[i];
27 | Index row = outer;
28 | Index col = inner;
29 | Index idx = i + batch * batchStride;
30 | auto val = expr.coeff(row, col, batch, idx);
31 | internal::CwiseAssignmentHandler::assign(matrix, val, idx);
32 | }
33 | CUMAT_KERNEL_2D_LOOP_END
34 | }
35 | template<typename T, typename M, AssignmentMode _AssignmentMode>
36 | __global__ void CwiseCSCEvaluationKernel(dim3 virtual_size, const T expr, M matrix)
37 | {
38 | const int* JA = matrix.getSparsityPattern().JA.data();
39 | const int* IA = matrix.getSparsityPattern().IA.data();
40 | Index batchStride = matrix.getSparsityPattern().nnz;
41 | //TODO: Profiling, what is the best way to loop over the batches?
42 | CUMAT_KERNEL_2D_LOOP(outer, batch, virtual_size)
43 | int start = JA[outer];
44 | int end = JA[outer + 1];
45 | for (int i = start; i < end; ++i)
46 | {
47 | int inner = IA[i];
48 | Index row = inner;
49 | Index col = outer;
50 | Index idx = i + batch * batchStride;
51 | auto val = expr.coeff(row, col, batch, idx);
52 | internal::CwiseAssignmentHandler::assign(matrix, val, idx);
53 | }
54 | CUMAT_KERNEL_2D_LOOP_END
55 | }
56 | template<typename T, typename M, AssignmentMode _AssignmentMode>
57 | __global__ void CwiseELLPACKEvaluationKernel(dim3 virtual_size, const T expr, M matrix)
58 | {
59 | const SparsityPattern::IndexMatrix indices = matrix.getSparsityPattern().indices;
60 | Index nnzPerRow = matrix.getSparsityPattern().nnzPerRow;
61 | Index rows = matrix.getSparsityPattern().rows;
62 | Index batchStride = rows * nnzPerRow;
63 | //TODO: Profiling, what is the best way to loop over the batches?
64 | CUMAT_KERNEL_2D_LOOP(row, batch, virtual_size)
65 | for (int i = 0; i < nnzPerRow; ++i)
66 | {
67 | Index col = indices.coeff(row, i, 0, -1);
68 | if (col < 0) continue; //TODO: test if it is faster to continue reading (and set col=0) and discard before assignment
69 | Index idx = row + i * rows + batch * batchStride;
70 | auto val = expr.coeff(row, col, batch, idx);
71 | internal::CwiseAssignmentHandler::assign(matrix, val, idx);
72 | }
73 | CUMAT_KERNEL_2D_LOOP_END
74 | }
75 | }
76 | }
77 |
78 | namespace internal {
79 |
80 | #if CUMAT_NVCC==1
81 | //General assignment for everything that fulfills CwiseSrcTag into SparseDstTag (cwise sparse evaluation)
82 | //The source expression is only evaluated at the non-zero entries of the target SparseMatrix
83 | template<typename _Dst, typename _Src, AssignmentMode _Mode>
84 | struct Assignment<_Dst, _Src, _Mode, SparseDstTag, CwiseSrcTag>
85 | {
86 | private:
87 | static void assign(_Dst& dst, const _Src& src, std::integral_constant)
88 | {
89 | //here is now the real logic
90 | Context& ctx = Context::current();
91 | KernelLaunchConfig cfg = ctx.createLaunchConfig2D(static_cast(dst.derived().outerSize()), static_cast(dst.derived().batches()),
92 | kernels::CwiseCSREvaluationKernel);
93 | kernels::CwiseCSREvaluationKernel
94 | <<< cfg.block_count, cfg.thread_per_block, 0, ctx.stream() >>>
95 | (cfg.virtual_size, src.derived(), dst.derived());
96 | CUMAT_CHECK_ERROR();
97 | }
98 | static void assign(_Dst& dst, const _Src& src, std::integral_constant)
99 | {
100 | //here is now the real logic
101 | Context& ctx = Context::current();
102 | KernelLaunchConfig cfg = ctx.createLaunchConfig2D(static_cast(dst.derived().outerSize()), static_cast(dst.derived().batches()),
103 | kernels::CwiseCSCEvaluationKernel);
104 | kernels::CwiseCSCEvaluationKernel
105 | <<< cfg.block_count, cfg.thread_per_block, 0, ctx.stream() >>>
106 | (cfg.virtual_size, src.derived(), dst.derived());
107 | CUMAT_CHECK_ERROR();
108 | }
109 | static void assign(_Dst& dst, const _Src& src, std::integral_constant)
110 | {
111 | //here is now the real logic
112 | Context& ctx = Context::current();
113 | KernelLaunchConfig cfg = ctx.createLaunchConfig2D(static_cast(dst.derived().outerSize()), static_cast(dst.derived().batches()),
114 | kernels::CwiseELLPACKEvaluationKernel);
115 | kernels::CwiseELLPACKEvaluationKernel
116 | <<< cfg.block_count, cfg.thread_per_block, 0, ctx.stream() >>>
117 | (cfg.virtual_size, src.derived(), dst.derived());
118 | CUMAT_CHECK_ERROR();
119 | }
120 |
121 | public:
122 | static void assign(_Dst& dst, const _Src& src)
123 | {
124 | typedef typename _Dst::Type DstActual;
125 | typedef typename _Src::Type SrcActual;
126 | CUMAT_PROFILING_INC(EvalCwiseSparse);
127 | CUMAT_PROFILING_INC(EvalAny);
128 | if (dst.size() == 0) return;
129 | CUMAT_ASSERT(src.rows() == dst.rows());
130 | CUMAT_ASSERT(src.cols() == dst.cols());
131 | CUMAT_ASSERT(src.batches() == dst.batches());
132 |
133 | CUMAT_LOG_DEBUG("Evaluate component wise sparse expression " << typeid(src.derived()).name()
134 | << "\n rows=" << src.rows() << ", cols=" << src.cols() << ", batches=" << src.batches());
135 | assign(dst, src, std::integral_constant());
136 | CUMAT_LOG_DEBUG("Evaluation done");
137 | }
138 | };
139 | #endif
140 | }
141 |
142 | CUMAT_NAMESPACE_END
143 |
144 | #endif
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/SparseExpressionOp.h:
--------------------------------------------------------------------------------
1 | #ifndef __CUMAT_SPARSE_EXPRESSION_OP__
2 | #define __CUMAT_SPARSE_EXPRESSION_OP__
3 |
4 | #include "Macros.h"
5 | #include "ForwardDeclarations.h"
6 | #include "MatrixBase.h"
7 | #include "SparseMatrixBase.h"
8 |
9 | CUMAT_NAMESPACE_BEGIN
10 |
11 | namespace internal
12 | {
13 | template<typename _Child, int _SparseFlags>
14 | struct traits<SparseExpressionOp<_Child, _SparseFlags> >
15 | {
16 | using Scalar = typename internal::traits<_Child>::Scalar;
17 | enum
18 | {
19 | SFlags = _SparseFlags,
20 | Flags = internal::traits<_Child>::Flags,
21 | RowsAtCompileTime = internal::traits<_Child>::RowsAtCompileTime,
22 | ColsAtCompileTime = internal::traits<_Child>::ColsAtCompileTime,
23 | BatchesAtCompileTime = internal::traits<_Child>::BatchesAtCompileTime,
24 | AccessFlags = ReadCwise
25 | };
26 | typedef CwiseSrcTag SrcTag;
27 | typedef DeletedDstTag DstTag;
28 | };
29 | }
30 |
31 | template<typename _Child, int _SparseFlags>
32 | class SparseExpressionOp : public SparseMatrixBase<SparseExpressionOp<_Child, _SparseFlags>>
33 | {
34 | public:
35 | typedef SparseExpressionOp<_Child, _SparseFlags> Type;
36 | typedef SparseMatrixBase<SparseExpressionOp<_Child, _SparseFlags>> Base;
37 | CUMAT_PUBLIC_API
38 | enum
39 | {
40 | SFlags = _SparseFlags
41 | };
42 |
43 | using typename Base::StorageIndex;
44 | using Base::rows;
45 | using Base::cols;
46 | using Base::batches;
47 | using Base::nnz;
48 | using Base::size;
49 |
50 | protected:
51 | typedef typename MatrixReadWrapper<_Child, AccessFlags::ReadCwise>::type child_wrapped_t;
52 | const child_wrapped_t child_;
53 |
54 | public:
55 | SparseExpressionOp(const MatrixBase<_Child>& child, const SparsityPattern<_SparseFlags>& sparsityPattern)
56 | : Base(sparsityPattern, child.batches())
57 | , child_(child.derived())
58 | {
59 | CUMAT_ASSERT_DIMENSION(child.rows() == sparsityPattern.rows);
60 | CUMAT_ASSERT_DIMENSION(child.cols() == sparsityPattern.cols);
61 | }
62 |
63 | __device__ CUMAT_STRONG_INLINE Scalar coeff(Index row, Index col, Index batch, Index index) const
64 | {
65 | return child_.derived().coeff(row, col, batch, index);
66 | }
67 |
68 | __device__ CUMAT_STRONG_INLINE Scalar getSparseCoeff(Index row, Index col, Index batch, Index index) const
69 | {
70 | return child_.derived().coeff(row, col, batch, index);
71 | }
72 | };
73 |
74 | CUMAT_NAMESPACE_END
75 |
76 |
77 | #endif
--------------------------------------------------------------------------------
/extensions/include/cuMat/src/SparseExpressionOpPlugin.inl:
--------------------------------------------------------------------------------
1 | //Included in MatrixBase
2 |
3 | /**
4 | * \brief Views this matrix expression as a sparse matrix.
5 | * This enforces the specified sparsity pattern and the coefficients
6 | * of this matrix expression are then only evaluated at these positions.
7 | *
8 | * For now, the only use case is the sparse matrix-vector product.
9 | * For example:
10 | * \code
11 | * SparseMatrix<...> m1, m2; //all initialized with the same sparsity pattern
12 | * VectorXf v1, v2 = ...;
13 | * v1 = (m1 + m2) * v2;
14 | * \endcode
15 | * In this form, this would trigger a dense matrix vector multiplication, which is
16 | * completely infeasible. This is because the addition expression
17 | * does not know anything about sparse matrices and the product operation
18 | * then only sees an addition expression on the left hand side. Thus,
19 | * because of lacking knowledge, it has to trigger a dense evaluation.
20 | *
21 | * Improvement:
22 | * \code
23 | * v1 = (m1 + m2).sparseView(m1.getSparsityPattern()) * v2;
24 | * \endcode
25 | * with Format being either CSR or CSC.
26 | * This enforces the sparsity pattern of m1 onto the addition expression.
27 | * Thus the (immediately following!) product expression sees this sparse expression
28 | * and can trigger a sparse matrix-vector multiplication.
29 | * But the sparse matrices \c m1 and \c m2 now have to search for their coefficients.
30 | * This is because they don't know that the parent operations (add+multiply) will call
31 | * their coefficients in order of their own entries. In other words, that the \c linear
32 | * index parameter in \ref coeff(Index row, Index col, Index batch, Index linear) matches
33 | * the linear index in their data array.
34 | * (This is a valid assumption if you take transpose and block operations that change the
35 | * access pattern into considerations. Further, this allows e.g. \c m2 to have a different
36 | * sparsity pattern from m1, but only the entries that are included in both are used.)
37 | *
38 | * To overcome the above problem, one has to make one last adjustment:
39 | * \code
40 | * v1 = (m1.direct() + m2.direct()).sparseView(m1.getSparsityPattern()) * v2;
41 | * \endcode
42 | * \ref SparseMatrix::direct() tells the matrix that the linear index in
43 | * \ref coeff(Index row, Index col, Index batch, Index linear) matches the linear index
44 | * in the data array and thus can be used directly. This skips the checks that
45 | * the row, column and batch index actually match, so use this with care:
46 | * only if you know that the access pattern is not changed in the operation.
47 | * (This holds true for all non-broadcasting component wise expressions)
48 | *
49 | * \param pattern the enforced sparsity pattern
50 | * \tparam _SparseFlags the sparse format: CSC or CSR
51 | */
52 | template<int _SparseFlags>
53 | SparseExpressionOp<_Derived, _SparseFlags>
54 | sparseView(const SparsityPattern<_SparseFlags>& pattern)
55 | {
56 | return SparseExpressionOp<_Derived, _SparseFlags>(derived(), pattern);
57 | }
--------------------------------------------------------------------------------
/extensions/push_pull_inpaint.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include
4 |
5 | // #include "cuda/SplatSliceGPU.cuh"
6 |
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_DIM(x, d) TORCH_CHECK((x.dim() == (d)), #x " must be a tensor with ", d, " dimensions, but has shape ", x.sizes())
11 | #define CHECK_SIZE(x, d, s) TORCH_CHECK((x.size(d) == (s)), #x " must have ", s, " entries at dimension ", d, ", but has ", x.size(d), " entries")
12 |
13 | #define MAX_CHANNELS 128
14 |
15 | // CUDA forward declarations
16 | std::tuple<torch::Tensor, torch::Tensor>
17 | push_pull_inpaint_recursion_cuda(
18 | const torch::Tensor& mask,
19 | const torch::Tensor& data);
20 |
21 |
22 |
23 |
24 | torch::Tensor push_pull_inpaint(
25 | const torch::Tensor& mask,
26 | const torch::Tensor& data)
27 | {
28 | //check input
29 | CHECK_CUDA(mask);
30 | CHECK_CONTIGUOUS(mask);
31 | CHECK_DIM(mask, 3);
32 | int64_t B = mask.size(0);
33 | int64_t H = mask.size(1);
34 | int64_t W = mask.size(2);
35 |
36 | CHECK_CUDA(data);
37 | CHECK_CONTIGUOUS(data);
38 | CHECK_DIM(data, 4);
39 | CHECK_SIZE(data, 0, B);
40 | int64_t C = data.size(1);
41 | CHECK_SIZE(data, 2, H);
42 | CHECK_SIZE(data, 3, W);
43 | TORCH_CHECK(C < MAX_CHANNELS, "Inpainting::fastInpaint only supports up to 128 channels, but got " + std::to_string(C));
44 |
45 | //inpaint recursively
46 | return std::get<1>(push_pull_inpaint_recursion_cuda(mask, data));
47 | }
48 |
49 |
50 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
51 | m.def("push_pull_inpaint", &push_pull_inpaint);
52 | }
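
// --- Added illustrative usage sketch (not part of the original file) ---
// Calling the binding from C++ with the layout the checks above require:
// mask is (B, H, W), data is (B, C, H, W), both CUDA and contiguous, C < 128.
// The "nonzero mask = known pixel" convention is an assumption.
#include <torch/torch.h>

torch::Tensor inpaint_example()
{
    auto opts = torch::TensorOptions().device(torch::kCUDA);
    torch::Tensor mask = torch::zeros({1, 256, 256}, opts);     // only one pixel is known
    torch::Tensor data = torch::zeros({1, 3, 256, 256}, opts);  // values to propagate outwards
    mask.index_put_({0, 128, 128}, 1.0);
    data.index_put_({0, torch::indexing::Slice(), 128, 128}, 1.0);
    return push_pull_inpaint(mask, data);   // returns the inpainted (B, C, H, W) tensor
}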
--------------------------------------------------------------------------------
/imgs/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/imgs/dataset.png
--------------------------------------------------------------------------------
/imgs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/imgs/teaser.png
--------------------------------------------------------------------------------
/inference/assets/blender_vis_base_v26_with_shrinkwrap_full_base.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/inference/assets/blender_vis_base_v26_with_shrinkwrap_full_base.blend
--------------------------------------------------------------------------------
/inference/assets/face_landmarker.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/inference/assets/face_landmarker.task
--------------------------------------------------------------------------------
/inference_difflocks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # ./inference_difflocks.py \
4 | # --img_path=./samples/medium_11.png \
5 | # --out_path=./outputs_inference/
6 |
7 | from inference.img2hair import DiffLocksInference
8 | import subprocess
9 | import os
10 | import argparse
11 |
12 | def run():
13 |
14 | #argparse
15 | parser = argparse.ArgumentParser(description='DiffLocks inference: reconstruct 3D hair strands from a single image')
16 | parser.add_argument('--strand_checkpoint_path', default="./checkpoints/strand_vae/strand_codec.pt", type=str, help='Path to the strandVAE checkpoint')
17 | parser.add_argument('--difflocks_checkpoint_path', default="./checkpoints/difflocks_diffusion/scalp_v9_40k_06730000.pth", type=str, help='Path to the difflocks checkpoint')
18 | parser.add_argument('--difflocks_config_path', default="./configs/config_scalp_texture_conditional.json", type=str, help='Path to the difflocks config')
19 | parser.add_argument('--rgb2mat_checkpoint_path', default="./checkpoints/rgb2material/rgb2material.pt", type=str, help='Path to the rgb2material checkpoint')
20 | parser.add_argument('--blender_path', type=str, default="", help='Path to the blender executable')
21 | parser.add_argument('--blender_nr_threads', default=8, type=int, help='Number of threads for blender to use')
22 | parser.add_argument('--blender_strands_subsample', default=1.0, type=float, help='Amount of subsample of the strands(1.0=full strands, 0.5=half strands)')
23 | parser.add_argument('--blender_vertex_subsample', default=1.0, type=float, help='Amount of subsample of the vertices(1.0=all vertex, 0.5=half number of vertices per strand)')
24 | parser.add_argument('--alembic_resolution', default=7, type=int, help='Resolution of the exported alembic')
25 | parser.add_argument('--export_alembic', action='store_true', help='whether to export alembic or not')
26 | parser.add_argument('--do_shrinkwrap', action='store_true', help="applies a shrinkwrap modifier in blender that pushes the strands away from the scalp so they don't pass through the head")
27 | parser.add_argument('--img_path', type=str, required=True, help='Path to the image to do inference on')
28 | parser.add_argument('--out_path', type=str, required=True, help='Path to the output directory')
29 | args = parser.parse_args()
30 |
31 | print("args is", args)
32 |
33 | out_path=args.out_path
34 |
35 | difflocks= DiffLocksInference(args.strand_checkpoint_path, args.difflocks_config_path, args.difflocks_checkpoint_path, args.rgb2mat_checkpoint_path)
36 |
37 |
38 | #run----
39 | # img_path="./samples/medium_11.png"
40 | strand_points_world, hair_material_dict=difflocks.file2hair(args.img_path, args.out_path)
41 | print("hair_material_dict",hair_material_dict)
42 |
43 |
44 | #create blender file and optionally an alembic file
45 | if args.blender_path!="":
46 | cmd=[args.blender_path, "-t", str(args.blender_nr_threads), "--background", "--python", "./inference/npz2blender.py", "--", "--input_npz", os.path.join(out_path,"difflocks_output_strands.npz"), "--out_path", args.out_path, "--strands_subsample", str(args.blender_strands_subsample), "--vertex_subsample", str(args.blender_vertex_subsample), "--alembic_resolution", str(args.alembic_resolution) ]
47 | if args.do_shrinkwrap:
48 | cmd.append("--shrinkwrap")
49 | if args.export_alembic:
50 | cmd.append("--export_alembic")
51 | subprocess.run(cmd, capture_output=False)
52 |
53 | print("Finished writing to ", args.out_path)
54 |
55 | if __name__ == '__main__':
56 |
57 | run()
58 |
--------------------------------------------------------------------------------
/k_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from . import config, layers, models, sampling, utils
2 | from .layers import Denoiser
3 |
--------------------------------------------------------------------------------
/k_diffusion/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import flops
2 | from .flags import checkpointing, get_checkpointing
3 | # from .image_v1 import ImageDenoiserModelV1
4 | # from .image_transformer_v2 import ImageTransformerDenoiserModelV2
5 | from .image_transformer_v2_conditional import ImageTransformerDenoiserModelV2Conditional
6 |
7 |
--------------------------------------------------------------------------------
/k_diffusion/models/attention.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from inspect import isfunction
3 | import math
4 | from k_diffusion.models.modules import AxialRoPE, apply_rotary_emb_
5 | from k_diffusion.models.modules import AdaRMSNorm, FeedForwardBlock, LinearGEGLU, RMSNorm, apply_wd, use_flash_2, zero_init
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn, einsum
9 | from einops import rearrange, repeat
10 | from typing import Optional, Any
11 |
12 | import flash_attn
13 |
14 |
15 |
16 | try:
17 | import xformers
18 | import xformers.ops
19 |
20 | XFORMERS_IS_AVAILABLE = True
21 | except ImportError:
22 | XFORMERS_IS_AVAILABLE = False
23 |
24 | # CrossAttn precision handling
25 | import os
26 |
27 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
28 |
29 |
30 | def zero_module(module):
31 | """
32 | Zero out the parameters of a module and return it.
33 | """
34 | for p in module.parameters():
35 | p.detach().zero_()
36 | return module
37 |
38 |
39 | def scale_for_cosine_sim(q, k, scale, eps):
40 | dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
41 | sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
42 | sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True)
43 | sqrt_scale = torch.sqrt(scale.to(dtype))
44 | scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
45 | scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps)
46 | return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)
47 |
48 | def scale_for_cosine_sim_qkv(qkv, scale, eps):
49 | q, k, v = qkv.unbind(2)
50 | q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
51 | return torch.stack((q, k, v), dim=2)
52 |
53 | def scale_for_cosine_sim_single(q, scale, eps):
54 | dtype = reduce(torch.promote_types, (q.dtype, scale.dtype, torch.float32))
55 | sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
56 | sqrt_scale = torch.sqrt(scale.to(dtype))
57 | scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
58 | return q * scale_q.to(q.dtype)
59 |
60 | class SpatialTransformerSimpleV2(nn.Module):
61 | """
62 | Transformer block for image-like data.
63 | First, project the input (aka embedding)
64 | and reshape to b, t, d.
65 | Then apply standard transformer action.
66 | Finally, reshape to image
67 | NEW: use_linear for more efficiency instead of the 1x1 convs
68 | """
69 |
70 | def __init__(self, in_channels, n_heads, d_head,
71 | global_cond_dim,
72 | do_self_attention=True,
73 | dropout=0.,
74 | context_dim=None,
75 | ):
76 | super().__init__()
77 |
78 |
79 | self.in_channels = in_channels
80 | inner_dim = n_heads * d_head
81 | self.n_heads = n_heads
82 | self.d_head = d_head
83 | self.do_self_attention=do_self_attention
84 |
85 |
86 | self.x_in_norm = AdaRMSNorm(in_channels, global_cond_dim)
87 |
88 | #x to qkv
89 | if self.do_self_attention:
90 | self.x_qkv_proj = apply_wd(torch.nn.Linear(in_channels, inner_dim * 3, bias=False))
91 | else:
92 | self.x_q_proj = apply_wd(torch.nn.Linear(in_channels, inner_dim, bias=False))
93 | self.x_scale = nn.Parameter(torch.full([self.n_heads], 10.0))
94 |
95 | self.x_pos_emb = AxialRoPE(d_head // 2, self.n_heads)
96 |
97 |
98 | #context to kv
99 | self.cond_kv_proj = apply_wd(torch.nn.Linear(context_dim, inner_dim * 2, bias=False))
100 | self.cond_scale = nn.Parameter(torch.full([self.n_heads], 10.0))
101 | self.cond_pos_emb = AxialRoPE(d_head // 2, self.n_heads)
102 |
103 | self.ff = FeedForwardBlock(in_channels, d_ff=int(in_channels*2), cond_features=global_cond_dim, dropout=dropout)
104 |
105 | self.dropout = nn.Dropout(dropout)
106 | self.proj_out = apply_wd(zero_module(nn.Linear(in_channels, inner_dim)))
107 |
108 |
109 | def forward(self, x, pos, global_cond, context=None, context_pos=None):
110 | b, c, h, w = x.shape
111 | x_in = x
112 | x = rearrange(x, 'b c h w -> b h w c')
113 | context = rearrange(context, 'b c h w -> b h w c')
114 | x = self.x_in_norm(x, global_cond)
115 |
116 | if self.do_self_attention:
117 | #x to qkv
118 | x_qkv = self.x_qkv_proj(x)
119 | pos = rearrange(pos, "... h w e -> ... (h w) e").to(x_qkv.dtype)
120 | x_theta = self.x_pos_emb(pos)
121 | if use_flash_2(x_qkv):
122 | x_qkv = rearrange(x_qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head)
123 | x_qkv = scale_for_cosine_sim_qkv(x_qkv, self.x_scale, 1e-6)
124 | x_theta = torch.stack((x_theta, x_theta, torch.zeros_like(x_theta)), dim=-3)
125 | x_qkv = apply_rotary_emb_(x_qkv, x_theta)
126 | x_q, x_k, x_v = x_qkv.chunk(3,dim=-3)
127 | else:
128 | print("we couldnt run flash2, maybe it's not installed or the input si not bfloat16")
129 | exit(1)
130 | else:
131 | #x to q
132 | x_q = self.x_q_proj(x)
133 | pos = rearrange(pos, "... h w e -> ... (h w) e").to(x_q.dtype)
134 | x_theta = self.x_pos_emb(pos)
135 | if use_flash_2(x_q):
136 | x_q = rearrange(x_q, "n h w (nh e) -> n (h w) nh e", e=self.d_head)
137 | x_q = scale_for_cosine_sim_single(x_q, self.x_scale[:, None], 1e-6)
138 | x_q=x_q.unsqueeze(2) #n (h w) 1 nh e
139 | x_theta=x_theta.unsqueeze(1)
140 | x_q = apply_rotary_emb_(x_q, x_theta)
141 | else:
142 | print("we couldnt run flash2, maybe it's not installed or the input si not bfloat16")
143 | exit(1)
144 |
145 |
146 | #context to kv
147 | cond_kv = self.cond_kv_proj(context)
148 | # print("cond_kv init",cond_kv.shape)
149 | context_pos = rearrange(context_pos, "... h w e -> ... (h w) e").to(cond_kv.dtype)
150 | cond_theta = self.cond_pos_emb(context_pos)
151 | if use_flash_2(cond_kv):
152 | cond_kv = rearrange(cond_kv, "n h w (t nh e) -> n (h w) t nh e", t=2, e=self.d_head)
153 | cond_k, cond_v = cond_kv.unbind(2) # makes each n (h w) nh e
154 | cond_k = scale_for_cosine_sim_single(cond_k, self.cond_scale[:, None], 1e-6)
155 | cond_k=cond_k.unsqueeze(2) #n (h w) 1 nh e
156 | cond_theta=cond_theta.unsqueeze(1)
157 | cond_k = apply_rotary_emb_(cond_k, cond_theta)
158 | cond_k=cond_k.squeeze(2)
159 | else:
160 | print("we couldnt run flash2, maybe it's not installed or the input si not bfloat16")
161 | exit(1)
162 |
163 | #doing self attention by concating K and V between X and cond
164 | if self.do_self_attention:
165 | k = torch.cat([x_k, cond_k.unsqueeze(2)], dim=1)
166 | v = torch.cat([x_v, cond_v.unsqueeze(2)], dim=1)
167 | else:
168 | # print("not doing self attention")
169 | k=cond_k.unsqueeze(2)
170 | v=cond_v.unsqueeze(2)
171 | q=x_q
172 |
173 |
174 | #rearange a bit
175 | q=q.squeeze(2)
176 | kv=torch.cat([k,v],2)
177 | # print("final q before giving to flash",q.shape)
178 | # print("final kv before giving to flash",kv.shape)
179 |
180 | x = flash_attn.flash_attn_kvpacked_func(q, kv, softmax_scale=1.0)
181 |
182 | x = rearrange(x, 'b (h w) nh e -> b (h w) (nh e)', nh=self.n_heads, e=self.d_head, h=h, w=w)
183 |
184 | #last ff
185 | x = self.dropout(x)
186 | x = self.proj_out(x)
187 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w, c=c)
188 |
189 | x= x + x_in
190 |
191 | #attention part finished------
192 |
193 | #linear feed forward
194 | x = rearrange(x, 'b c h w -> b h w c', h=h, w=w, c=c)
195 |
196 | # print("x before ff is ", x.shape)
197 | x = self.ff(x, global_cond)
198 |
199 |
200 | x = rearrange(x, 'b h w c -> b c h w', h=h, w=w, c=c)
201 |
202 | return x
203 |
204 |
205 |
--------------------------------------------------------------------------------
/k_diffusion/models/axial_rope.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch._dynamo
5 | from torch import nn
6 |
7 | from . import flags
8 |
9 | if flags.get_use_compile():
10 | torch._dynamo.config.suppress_errors = True
11 |
12 |
13 | def rotate_half(x):
14 | x1, x2 = x[..., 0::2], x[..., 1::2]
15 | x = torch.stack((-x2, x1), dim=-1)
16 | *shape, d, r = x.shape
17 | return x.view(*shape, d * r)
18 |
19 |
20 | @flags.compile_wrap
21 | def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
22 | freqs = freqs.to(t)
23 | rot_dim = freqs.shape[-1]
24 | end_index = start_index + rot_dim
25 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
26 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
27 | t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
28 | return torch.cat((t_left, t, t_right), dim=-1)
29 |
30 |
31 | def centers(start, stop, num, dtype=None, device=None):
32 | edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
33 | return (edges[:-1] + edges[1:]) / 2
34 |
35 |
36 | def make_grid(h_pos, w_pos):
37 | grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
38 | h, w, d = grid.shape
39 | return grid.view(h * w, d)
40 |
41 |
42 | def bounding_box(h, w, pixel_aspect_ratio=1.0):
43 | # Adjusted dimensions
44 | w_adj = w
45 | h_adj = h * pixel_aspect_ratio
46 |
47 | # Adjusted aspect ratio
48 | ar_adj = w_adj / h_adj
49 |
50 | # Determine bounding box based on the adjusted aspect ratio
51 | y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
52 | if ar_adj > 1:
53 | y_min, y_max = -1 / ar_adj, 1 / ar_adj
54 | elif ar_adj < 1:
55 | x_min, x_max = -ar_adj, ar_adj
56 |
57 | return y_min, y_max, x_min, x_max
58 |
59 |
60 | def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
61 | y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
62 | if align_corners:
63 | h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
64 | w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
65 | else:
66 | h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
67 | w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
68 | return make_grid(h_pos, w_pos)
69 |
70 |
71 | def freqs_pixel(max_freq=10.0):
72 | def init(shape):
73 | freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi
74 | return freqs.log().expand(shape)
75 | return init
76 |
77 |
78 | def freqs_pixel_log(max_freq=10.0):
79 | def init(shape):
80 | log_min = math.log(math.pi)
81 | log_max = math.log(max_freq * math.pi / 2)
82 | return torch.linspace(log_min, log_max, shape[-1]).expand(shape)
83 | return init
84 |
85 |
86 | class AxialRoPE(nn.Module):
87 | def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
88 | super().__init__()
89 | self.n_heads = n_heads
90 | self.start_index = start_index
91 | log_freqs = freqs_init((n_heads, dim // 4))
92 | self.freqs_h = nn.Parameter(log_freqs.clone())
93 | self.freqs_w = nn.Parameter(log_freqs.clone())
94 |
95 | def extra_repr(self):
96 | dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2
97 | return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}"
98 |
99 | def get_freqs(self, pos):
100 | if pos.shape[-1] != 2:
101 | raise ValueError("input shape must be (..., 2)")
102 | freqs_h = pos[..., None, None, 0] * self.freqs_h.exp()
103 | freqs_w = pos[..., None, None, 1] * self.freqs_w.exp()
104 | freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1)
105 | return freqs.transpose(-2, -3)
106 |
107 | def forward(self, x, pos):
108 | freqs = self.get_freqs(pos)
109 | return apply_rotary_emb(freqs, x, self.start_index)
110 |
--------------------------------------------------------------------------------
/k_diffusion/models/flags.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from functools import update_wrapper
3 | import os
4 | import threading
5 |
6 | import torch
7 |
8 |
9 | def get_use_compile():
10 | return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"
11 |
12 |
13 | def get_use_flash_attention_2():
14 | return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"
15 |
16 |
17 | state = threading.local()
18 | state.checkpointing = False
19 |
20 |
21 | @contextmanager
22 | def checkpointing(enable=True):
23 | try:
24 | old_checkpointing, state.checkpointing = state.checkpointing, enable
25 | yield
26 | finally:
27 | state.checkpointing = old_checkpointing
28 |
29 |
30 | def get_checkpointing():
31 | return getattr(state, "checkpointing", False)
32 |
33 |
34 | class compile_wrap:
35 | def __init__(self, function, *args, **kwargs):
36 | self.function = function
37 | self.args = args
38 | self.kwargs = kwargs
39 | self._compiled_function = None
40 | update_wrapper(self, function)
41 |
42 | @property
43 | def compiled_function(self):
44 | if self._compiled_function is not None:
45 | return self._compiled_function
46 | if get_use_compile():
47 | try:
48 | self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
49 | except RuntimeError:
50 | self._compiled_function = self.function
51 | else:
52 | self._compiled_function = self.function
53 | return self._compiled_function
54 |
55 | def __call__(self, *args, **kwargs):
56 | return self.compiled_function(*args, **kwargs)
57 |
--------------------------------------------------------------------------------
/k_diffusion/models/flops.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import math
3 | import threading
4 |
5 |
6 | state = threading.local()
7 | state.flop_counter = None
8 |
9 |
10 | @contextmanager
11 | def flop_counter(enable=True):
12 | try:
13 | old_flop_counter = state.flop_counter
14 | state.flop_counter = FlopCounter() if enable else None
15 | yield state.flop_counter
16 | finally:
17 | state.flop_counter = old_flop_counter
18 |
19 |
20 | class FlopCounter:
21 | def __init__(self):
22 | self.ops = []
23 |
24 | def op(self, op, *args, **kwargs):
25 | self.ops.append((op, args, kwargs))
26 |
27 | @property
28 | def flops(self):
29 | flops = 0
30 | for op, args, kwargs in self.ops:
31 | flops += op(*args, **kwargs)
32 | return flops
33 |
34 |
35 | def op(op, *args, **kwargs):
36 | if getattr(state, "flop_counter", None):
37 | state.flop_counter.op(op, *args, **kwargs)
38 |
39 |
40 | def op_linear(x, weight):
41 | return math.prod(x) * weight[0]
42 |
43 |
44 | def op_attention(q, k, v):
45 | *b, s_q, d_q = q
46 | *b, s_k, d_k = k
47 | *b, s_v, d_v = v
48 | return math.prod(b) * s_q * s_k * (d_q + d_v)
49 |
50 |
51 | def op_natten(q, k, v, kernel_size):
52 | *q_rest, d_q = q
53 | *_, d_v = v
54 | return math.prod(q_rest) * (d_q + d_v) * kernel_size**2
55 |
--------------------------------------------------------------------------------
/losses/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from losses.losses import compute_loss_dir_l1, compute_loss_curv_l1, compute_loss_l1, compute_loss_l2, compute_loss_kl
3 | import numpy as np
4 |
5 | class StrandVAELoss(torch.nn.Module):
6 | def __init__(
7 | self,
8 | eps: float = 1e-8,
9 | ):
10 | super().__init__()
11 |
12 | self.eps = eps
13 |
14 | def forward(self, phase, gt_dict, pred_dict, latent_dict, hyperparams):
15 | gt_strands= gt_dict["strand_positions"]
16 | pred_hair_strands= pred_dict["strand_positions"]
17 |
18 | # w_kl=map_range_val(phase.epoch_nr, hyperparams.loss_kl_anneal_epoch_nr_start, hyperparams.loss_kl_anneal_epoch_nr_finish, 0.0, hyperparams.loss_kl_weight)
19 |
20 |
21 | loss_dict = {}
22 |
23 | #we always predict the xyz loss just because it's fast
24 | # loss_l2 = compute_loss_l2(gt_strands, pred_hair_strands)
25 | loss_pos = compute_loss_l1(gt_strands, pred_hair_strands)
26 | loss_dir = compute_loss_dir_l1(gt_strands, pred_hair_strands)
27 | loss_curv = compute_loss_curv_l1(gt_strands, pred_hair_strands)
28 | loss_kl = 0.0
29 | if "z_logstd" in latent_dict:
30 | loss_kl = compute_loss_kl(latent_dict["z_mean"], latent_dict["z_logstd"])
31 | loss = loss_pos*hyperparams.loss_pos_weight + loss_dir*hyperparams.loss_dir_weight + loss_curv*hyperparams.loss_curv_weight + loss_kl*hyperparams.loss_kl_weight
32 | # loss = loss_pos*hyperparams.loss_pos_weight + loss_dir*hyperparams.loss_dir_weight + loss_curv*hyperparams.loss_curv_weight + loss_kl*w_kl
33 | loss_dict['loss'] = loss
34 | loss_dict['loss_pos'] = loss_pos
35 | loss_dict['loss_dir'] = loss_dir
36 | loss_dict['loss_curv'] = loss_curv
37 | loss_dict['loss_kl'] = loss_kl
38 |
39 |
40 | return loss_dict
41 |
42 |
43 |
44 |
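# Usage sketch (illustrative addition, not part of the repository file): the loss expects
# strand positions reshapeable to (nr_strands, 256, 3) and a hyperparams object exposing the
# four loss weights (the values below are the ones quoted later in
# utils/create_strand_latent_weights.py).
import torch
from types import SimpleNamespace

hyperparams = SimpleNamespace(loss_pos_weight=0.5, loss_dir_weight=1.0,
                              loss_curv_weight=20.0, loss_kl_weight=0.0)
gt_dict = {"strand_positions": torch.randn(10, 256, 3)}
pred_dict = {"strand_positions": torch.randn(10, 256, 3)}
loss_dict = StrandVAELoss()(None, gt_dict, pred_dict, {}, hyperparams)
print(loss_dict["loss"], loss_dict["loss_kl"])  # loss_kl stays 0.0 when no z_logstd is given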
--------------------------------------------------------------------------------
/losses/loss_utils.py:
--------------------------------------------------------------------------------
1 | import scipy.signal
2 | import torch
3 |
4 | def apply_reduction(losses, reduction="none"):
5 | """Apply reduction to collection of losses."""
6 | if reduction == "mean":
7 | losses = losses.mean()
8 | elif reduction == "sum":
9 | losses = losses.sum()
10 | return losses
11 |
12 |
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import numpy as np
4 | from typing import List, Any
5 | from losses.loss_utils import apply_reduction
6 | from utils.general_util import get_window
7 | from utils.strand_util import compute_dirs, compute_curv
8 |
9 |
10 | def compute_loss_l2(gt_hair_strands, pred_hair_strands):
11 | loss_l2 = ((pred_hair_strands - gt_hair_strands) ** 2).mean()
12 |
13 | return loss_l2
14 |
15 | def compute_loss_l1(gt_hair_strands, pred_hair_strands):
16 | loss_l1 = torch.nn.functional.l1_loss(pred_hair_strands, gt_hair_strands).mean()
17 |
18 | return loss_l1
19 |
20 | def compute_loss_dir_dot(gt_hair_strands, pred_hair_strands):
21 | nr_verts_per_strand=256
22 |
23 | pred_points = pred_hair_strands.view(-1, nr_verts_per_strand, 3)
24 | gt_hair_strands=gt_hair_strands.view(-1, nr_verts_per_strand, 3)
25 |
26 | # also get a loss for the direction; we need to compute the per-segment directions
27 | pred_deltas = compute_dirs(pred_points)
28 | pred_deltas = pred_deltas.view(-1, 3)
29 | pred_deltas = torch.nn.functional.normalize(pred_deltas, dim=-1)
30 |
31 | gt_dir = compute_dirs(gt_hair_strands)
32 | gt_dir = gt_dir.view(-1, 3)
33 | gt_dir = torch.nn.functional.normalize(gt_dir, dim=-1)
34 | # loss_dir = self.cosine_embed_loss(pred_deltas, gt_dir, torch.ones(gt_dir.shape[0]).cuda())
35 |
36 | dot = torch.sum(pred_deltas * gt_dir, dim=-1)
37 |
38 | loss_dir = (1.0 - dot).mean()
39 |
40 | return loss_dir
41 |
42 | def compute_loss_dir_l1(gt_hair_strands, pred_hair_strands):
43 | nr_verts_per_strand=256
44 |
45 | pred_points = pred_hair_strands.view(-1, nr_verts_per_strand, 3)
46 | gt_hair_strands=gt_hair_strands.view(-1, nr_verts_per_strand, 3)
47 |
48 | # also get a loss for the direction; we need to compute the per-segment directions
49 | pred_deltas = compute_dirs(pred_points)
50 | pred_deltas = pred_deltas.view(-1, 3)
51 | pred_deltas = pred_deltas*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss
52 |
53 | gt_dir = compute_dirs(gt_hair_strands)
54 | gt_dir = gt_dir.view(-1, 3)
55 | gt_dir = gt_dir*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss
56 | # loss_dir = self.cosine_embed_loss(pred_deltas, gt_dir, torch.ones(gt_dir.shape[0]).cuda())
57 |
58 | loss_l1 = torch.nn.functional.l1_loss(pred_deltas, gt_dir).mean()
59 | return loss_l1
60 |
61 | def compute_loss_curv_l1(gt_hair_strands, pred_hair_strands):
62 | nr_verts_per_strand=256
63 |
64 | pred_points = pred_hair_strands.view(-1, nr_verts_per_strand, 3)
65 | gt_hair_strands=gt_hair_strands.view(-1, nr_verts_per_strand, 3)
66 |
67 | # also get a loss for the curvature; we need the directions and then the curvatures
68 | pred_deltas = compute_dirs(pred_points)
69 | pred_curvs = compute_curv(pred_deltas)
70 | pred_curvs = pred_curvs.view(-1, 3)
71 | pred_curvs = pred_curvs*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss
72 |
73 | gt_dir = compute_dirs(gt_hair_strands)
74 | gt_curvs = compute_curv(gt_dir)
75 | gt_curvs = gt_curvs.view(-1, 3)
76 | gt_curvs = gt_curvs*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss
77 | # loss_dir = self.cosine_embed_loss(pred_deltas, gt_dir, torch.ones(gt_dir.shape[0]).cuda())
78 |
79 | loss_l1 = torch.nn.functional.l1_loss(pred_curvs, gt_curvs).mean()
80 | return loss_l1
81 |
82 | def compute_loss_kl(mean, logstd):
83 | #get input data
84 | kl_loss = 0
85 |
86 | #kl loss
87 |
88 | kl_shape = kl( mean, logstd)
89 | # free bits from IAF-VAE: if the KL drops below a certain value, we stop reducing it further
90 | kl_shape = torch.clamp(kl_shape, min=0.25)
91 |
92 | kl_loss = kl_shape.mean()
93 |
94 | return kl_loss
95 |
96 | def kl(mean, logstd):
97 | kl = (-0.5 - logstd + 0.5 * mean ** 2 + 0.5 * torch.exp(2 * logstd))
98 | return kl
99 |
100 |
101 |
102 |
103 |
104 |
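# Editorial check (not part of the repository file): kl() above is the closed-form
# KL( N(mean, exp(logstd)) || N(0, 1) ) per dimension, i.e. 0.5*(mu^2 + sigma^2 - 1) - log(sigma),
# which can be verified against torch.distributions:
import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(5)
logstd = 0.1 * torch.randn(5)
reference = kl_divergence(Normal(mean, logstd.exp()), Normal(0.0, 1.0))
assert torch.allclose(kl(mean, logstd), reference, atol=1e-5)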
--------------------------------------------------------------------------------
/models/rgb_to_material.py:
--------------------------------------------------------------------------------
1 | from modules.networks import kaiming_init
2 | import torch
3 | from torch import nn
4 | import os
5 | import json
6 | import numpy as np
7 |
8 |
9 | class RGB2MaterialModel(nn.Module):
10 |
11 | def __init__(self, input_dim, out_dim, hidden_dim):
12 | super().__init__()
13 |
14 | self.out_dim=out_dim
15 |
16 |
17 |
18 | #attempt 2
19 | self.dino2conf=nn.Sequential(
20 | nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=1, padding=0, bias=True),
21 | nn.SiLU(),
22 | nn.Conv2d(in_channels=hidden_dim, out_channels=1, kernel_size=1, padding=0, bias=True),
23 | nn.Sigmoid()
24 | )
25 | self.dino2mat=nn.Sequential(
26 | nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=1, padding=0, bias=True),
27 | nn.SiLU(),
28 | nn.Conv2d(in_channels=hidden_dim, out_channels=out_dim, kernel_size=1, padding=0, bias=True),
29 | nn.Sigmoid()
30 | )
31 |
32 |
33 |
34 |
35 | self.apply(lambda x: kaiming_init(x, False, nonlinearity="silu"))
36 |
37 | def save(self, root_folder, experiment_name, hyperparams, iter_nr, info=None):
38 | name=str(iter_nr)
39 | if info is not None:
40 | name+="_"+info
41 | models_path = os.path.join(root_folder, experiment_name, name, "models")
42 | if not os.path.exists(models_path):
43 | os.makedirs(models_path, exist_ok=True)
44 | torch.save(self.state_dict(), os.path.join(models_path, "rgb2material.pt"))
45 |
46 | hyperparams_params_path=os.path.join(models_path, "hyperparams.json")
47 | with open(hyperparams_params_path, 'w', encoding='utf-8') as f:
48 | json.dump(vars(hyperparams), f, ensure_ascii=False, indent=4)
49 |
50 |
51 | def forward(self, batch_dict):
52 |
53 | x=batch_dict["dinov2_latents"] #BCHW (1,1024,55,55)
54 |
55 |
56 | #attempt 2, each patch predicts a confidence and a material, then we average all the materials across all patches, weighted by confidence
57 | conf=self.dino2conf(x)
58 | mat=self.dino2mat(x)
59 |
60 |
61 | #average the mat across the pixels
62 | avg_mat = (mat*conf).sum((2,3)) / (conf.sum((2,3)) +1e-6) #sum across all H and W dimensions
63 | x=avg_mat
64 |
65 |
66 | #split the material into parameters, at least the ones that are meaningful and actually have a loss applied to them
67 | melanin=x[:,3]
68 | redness=x[:,4]
69 | root_darkness_start=x[:,8]
70 | root_darkness_end=x[:,9]
71 | root_darkness_strength=x[:,10]
72 |
73 |
74 |
75 | pred_dict={}
76 | pred_dict["material"]=x
77 | pred_dict["melanin"]=melanin
78 | pred_dict["redness"]=redness
79 | pred_dict["root_darkness_start"]=root_darkness_start
80 | pred_dict["root_darkness_end"]=root_darkness_end
81 | pred_dict["root_darkness_strength"]=root_darkness_strength
82 |
83 |
84 | return pred_dict
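# Usage sketch (illustrative addition, not part of the repository file): the model maps
# per-patch DINOv2 features to hair material parameters. Shapes follow the comment in
# forward() above; out_dim=11 and hidden_dim=256 are made-up values, not the DiffLocks config.
import torch

model = RGB2MaterialModel(input_dim=1024, out_dim=11, hidden_dim=256)
batch = {"dinov2_latents": torch.randn(1, 1024, 55, 55)}
pred = model(batch)
print(pred["material"].shape, pred["melanin"].shape)  # torch.Size([1, 11]), torch.Size([1])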
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/modules/__init__.py
--------------------------------------------------------------------------------
/modules/edm2_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class MPFourier(torch.nn.Module):
6 | def __init__(self, num_channels, bandwidth=1):
7 | super().__init__()
8 | self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
9 | self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))
10 |
11 | def forward(self, x):
12 | y = x.to(torch.float32)
13 | y = y.ger(self.freqs.to(torch.float32))
14 | y = y + self.phases.to(torch.float32)
15 | y = y.cos() * np.sqrt(2)
16 | return y.to(x.dtype)
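# Usage sketch (illustrative addition, not part of the repository file): MPFourier embeds a
# batch of scalars (e.g. noise levels) into num_channels magnitude-preserving Fourier features.
import torch

emb = MPFourier(num_channels=64)
sigmas = torch.rand(8)        # one scalar per batch element
features = emb(sigmas)        # shape (8, 64), roughly unit variance
print(features.shape)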
--------------------------------------------------------------------------------
/modules/networks.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | class LinearDummy(torch.nn.Module):
9 | def __init__(self, in_channels, out_channels, bias=True):
10 | super().__init__()
11 | self.out_channels = out_channels
12 | self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels))
13 | self.bias=None
14 | if bias:
15 | self.bias = torch.nn.Parameter(torch.zeros(out_channels))
16 |
17 | def forward(self, x, gain=1):
18 | return F.linear(x, self.weight, self.bias)
19 |
20 |
21 | class BlockSiren(torch.nn.Module):
22 | def __init__(self, in_channels, out_channels, use_bias=True, is_first_layer=False, scale_init=1.0):
23 | super(BlockSiren, self).__init__()
24 | self.use_bias = use_bias
25 | # self.activ = activ
26 | self.is_first_layer = is_first_layer
27 | self.scale_init = scale_init
28 | # self.freq_scaling = scale_init
29 |
30 | self.conv = torch.nn.Linear(in_channels, out_channels, bias=self.use_bias).cuda()
31 |
32 | with torch.no_grad():
33 |
34 | #following the official implementation from https://github.com/vsitzmann/siren/blob/master/explore_siren.ipynb
35 | if self.is_first_layer:
36 | self.conv.weight.uniform_(-1 / in_channels,
37 | 1 / in_channels)
38 | else:
39 | self.conv.weight.uniform_(-np.sqrt(6 / in_channels) / self.scale_init,
40 | np.sqrt(6 / in_channels) / self.scale_init)
41 |
42 | self.conv.bias.data.zero_()
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 |
47 | x = self.scale_init * x
48 |
49 | x = torch.sin(x)
50 |
51 | return x
52 |
53 |
54 | class LinearWN_v2(torch.nn.Linear):
55 | def __init__(self, in_features, out_features, bias=True):
56 | super(LinearWN_v2, self).__init__(in_features, out_features, bias)
57 | self.g = torch.nn.Parameter(torch.ones(out_features))
58 |
59 |
60 | self.in_features=in_features
61 | self.out_features=out_features
62 |
63 | def forward(self, input):
64 | w= torch._weight_norm(self.weight, self.g, 0)
65 | out=F.linear(input, w, self.bias)
66 | return out
67 |
68 |
69 |
70 | class Conv1dWN_v2(torch.nn.Conv1d):
71 | def __init__(
72 | self,
73 | in_channels,
74 | out_channels,
75 | kernel_size,
76 | stride=1,
77 | padding=0,
78 | padding_mode="zeros",
79 | dilation=1,
80 | groups=1,
81 | ):
82 | super(Conv1dWN_v2, self).__init__(
83 | in_channels,
84 | out_channels,
85 | kernel_size,
86 | stride,
87 | padding,
88 | dilation,
89 | groups,
90 | True,
91 | )
92 | self.g = torch.nn.Parameter(torch.ones(out_channels))
93 |
94 | self.padding_amount=padding
95 | self.padding_mode=padding_mode
96 |
97 | def forward(self, x):
98 | w= torch._weight_norm(self.weight, self.g, 0)
99 |
100 |
101 |
102 | if self.padding_mode != 'zeros':
103 | x= F.pad(x, (self.padding_amount,self.padding_amount), mode=self.padding_mode)
104 |
105 | x= F.conv1d(
106 | x,
107 | w,
108 | bias=self.bias,
109 | stride=self.stride,
110 | padding=0,
111 | dilation=self.dilation,
112 | groups=self.groups,
113 | )
114 | # print("x after conv", x.shape)
115 | return x
116 | else:
117 |
118 | return F.conv1d(
119 | x,
120 | w,
121 | bias=self.bias,
122 | stride=self.stride,
123 | padding=self.padding,
124 | dilation=self.dilation,
125 | groups=self.groups,
126 | )
127 |
128 |
129 | def kaiming_init(m, is_linear, nonlinearity="silu"):
130 | # gain = math.sqrt(2.0 / (1.0 + alpha**2))
131 |
132 | # gain=np.sqrt(10.5)
133 | if nonlinearity=="silu":
134 | gain=np.sqrt(2.3) #works fine with silu
135 | elif nonlinearity=="relu":
136 | gain=np.sqrt(2) #works fine with relu
137 | # gain=np.sqrt(2.15)
138 | # gain=np.sqrt(0.92) #for mpsilu
139 | scale=1.0
140 |
141 | if is_linear:
142 | gain = 1
143 |
144 | # print("effective scale", gain*scale)
145 |
146 | # print("m is ",m)
147 | # help(m)
148 |
149 | if isinstance(m, torch.nn.Conv2d):
150 | ksize = m.kernel_size[0] * m.kernel_size[1]
151 | n1 = m.in_channels
152 | n2 = m.out_channels
153 |
154 | # std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
155 | std = gain * math.sqrt(n1 * ksize)
156 | elif isinstance(m, torch.nn.ConvTranspose2d):
157 | ksize = m.kernel_size[0] * m.kernel_size[1] // 4
158 | n1 = m.in_channels
159 | n2 = m.out_channels
160 |
161 | # std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
162 | std = gain * math.sqrt(n1 * ksize)
163 | elif isinstance(m, torch.nn.Conv1d):
164 | ksize = m.kernel_size[0]
165 | n1 = m.in_channels
166 | n2 = m.out_channels
167 | # std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize))
168 | std = gain * math.sqrt(n1 * ksize)
169 | elif isinstance(m, torch.nn.Linear):
170 | n1 = m.in_features
171 | n2 = m.out_features
172 |
173 | # std = gain * math.sqrt(2.0 / (n1 + n2))
174 | std = gain * math.sqrt(n1)
175 | # elif isinstance(m, SMConv1d_v2):
176 | # print("SMConv1d_v2")
177 | # exit()
178 | else:
179 | return
180 |
181 | # help(m)
182 |
183 | # print("m is ", m)
184 |
185 | # print("std is", std)
186 |
187 | # m.weight.data.normal_(0, std)
188 | # m.weight.data.uniform_(-std * scale, std * scale)
189 | fan = torch.nn.init._calculate_correct_fan(m.weight, "fan_in")
190 | std = gain / math.sqrt(fan)
191 | # print("std is", std)
192 | with torch.no_grad():
193 | # m.weight.normal_(0, std)
194 | m.weight.data.uniform_(-std * math.sqrt(3.0) * scale, std * math.sqrt(3.0) * scale)
195 | if m.bias is not None:
196 | m.bias.data.zero_()
197 |
198 | if isinstance(m, torch.nn.ConvTranspose2d):
199 | # hardcoded for stride=2 for now
200 | m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2]
201 | m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2]
202 | m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2]
203 |
204 | # if isinstance(m, Conv2dWNUB) or isinstance(m, ConvTranspose2dWNUB) or isinstance(m, LinearWN):
205 | # print("m is ", m)
206 | if (
207 | # isinstance(m, torch.Conv2dWNUB)
208 | # isinstance(m, torch.Conv2dWN)
209 | # or isinstance(m, torch.ConvTranspose2dWN)
210 | # or isinstance(m, torch.ConvTranspose2dWNUB)
211 | isinstance(m, LinearWN_v2)
212 | or isinstance(m, Conv1dWN_v2)
213 | ):
214 | # print("selected m is ", m)
215 | # help(m)
216 | # norm = np.sqrt(torch.sum(m.weight.data[:] ** 2))
217 | dims = list(range(1, len(m.weight.shape)))
218 | norm = torch.norm(m.weight, 2, dim=dims, keepdim=True)
219 | # print("weight is ", m.weight.shape)
220 | # print("norm",norm.shape)
221 | # print("m.g.data",m.g.data.shape)
222 | # norm = torch.norm(m.weight, 2, dim=0, keepdim=True)
223 | m.g.data = norm
224 |
225 |
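# Usage sketch (illustrative addition, not part of the repository file): the weight-normalized
# layers are meant to be paired with kaiming_init, which also sets their per-output gain g to
# the norm of the freshly initialized weight rows.
import torch

layer = LinearWN_v2(32, 64)
kaiming_init(layer, is_linear=False, nonlinearity="silu")
out = layer(torch.randn(4, 32))
print(out.shape)  # torch.Size([4, 64])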
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.31.0
2 | torch==2.5.0
3 | torchvision==0.20.0
4 | opencv-python==4.9.0.80
5 | mediapipe==0.10.10
6 | trimesh==4.0.4
7 | jsonmerge==1.9.2
8 | dctorch==0.1.2
9 | einops==0.8.1
10 | torchdiffeq==0.2.5
11 | torchsde==0.2.6
12 | wheel==0.41.2
13 | setuptools==59.5.0
14 | libigl==2.5.1
15 |
16 |
17 |
--------------------------------------------------------------------------------
/samples/buzzcut_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/buzzcut_3.jpg
--------------------------------------------------------------------------------
/samples/cooper_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/cooper_4.jpg
--------------------------------------------------------------------------------
/samples/fiennes_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/fiennes_7.jpg
--------------------------------------------------------------------------------
/samples/freeman_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/freeman_2.png
--------------------------------------------------------------------------------
/samples/harris_5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/harris_5.jpeg
--------------------------------------------------------------------------------
/samples/hathaway_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/hathaway_1.jpg
--------------------------------------------------------------------------------
/samples/medium_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/medium_11.png
--------------------------------------------------------------------------------
/samples/uggams_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/uggams_3.png
--------------------------------------------------------------------------------
/schedulers/linearlr.py:
--------------------------------------------------------------------------------
1 | #from pytorch 1.10
2 |
3 | import types
4 | import math
5 | from torch import inf
6 | from functools import wraps
7 | import warnings
8 | import weakref
9 | from collections import Counter
10 | from bisect import bisect_right
11 |
12 | # from .optimizer import Optimizer
13 |
14 | import torch
15 | from torch.optim import Optimizer
16 |
17 | class LinearLR(torch.optim.lr_scheduler._LRScheduler):
18 | """Decays the learning rate of each parameter group by linearly changing small
19 | multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
20 | Notice that such decay can happen simultaneously with other changes to the learning rate
21 | from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
22 |
23 | Args:
24 | optimizer (Optimizer): Wrapped optimizer.
25 | start_factor (float): The number we multiply learning rate in the first epoch.
26 | The multiplication factor changes towards end_factor in the following epochs.
27 | Default: 1./3.
28 | end_factor (float): The number we multiply learning rate at the end of linear changing
29 | process. Default: 1.0.
30 | total_iters (int): The number of iterations over which the multiplicative factor reaches end_factor.
31 | Default: 5.
32 | last_epoch (int): The index of the last epoch. Default: -1.
33 | verbose (bool): If ``True``, prints a message to stdout for
34 | each update. Default: ``False``.
35 |
36 | Example:
37 | >>> # Assuming optimizer uses lr = 0.05 for all groups
38 | >>> # lr = 0.025 if epoch == 0
39 | >>> # lr = 0.03125 if epoch == 1
40 | >>> # lr = 0.0375 if epoch == 2
41 | >>> # lr = 0.04375 if epoch == 3
42 | >>> # lr = 0.05 if epoch >= 4
43 | >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
44 | >>> for epoch in range(100):
45 | >>> train(...)
46 | >>> validate(...)
47 | >>> scheduler.step()
48 | """
49 |
50 | def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
51 | verbose=False):
52 | if start_factor > 1.0 or start_factor < 0:
53 | raise ValueError('Starting multiplicative factor expected to be between 0 and 1.')
54 |
55 | if end_factor > 1.0 or end_factor < 0:
56 | raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
57 |
58 | self.start_factor = start_factor
59 | self.end_factor = end_factor
60 | self.total_iters = total_iters
61 | super(LinearLR, self).__init__(optimizer, last_epoch, verbose)
62 |
63 | def get_lr(self):
64 | if not self._get_lr_called_within_step:
65 | warnings.warn("To get the last learning rate computed by the scheduler, "
66 | "please use `get_last_lr()`.", UserWarning)
67 |
68 | if self.last_epoch == 0:
69 | return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
70 |
71 | if (self.last_epoch > self.total_iters):
72 | return [group['lr'] for group in self.optimizer.param_groups]
73 |
74 | return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
75 | (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
76 | for group in self.optimizer.param_groups]
77 |
78 | def _get_closed_form_lr(self):
79 | return [base_lr * (self.start_factor +
80 | (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
81 | for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
/schedulers/multisteplr.py:
--------------------------------------------------------------------------------
1 | #from pytorch 1.10
2 |
3 | import types
4 | import math
5 | from torch import inf
6 | from functools import wraps
7 | import warnings
8 | import weakref
9 | from collections import Counter
10 | from bisect import bisect_right
11 |
12 | # from .optimizer import Optimizer
13 |
14 | import torch
15 | from torch.optim import Optimizer
16 |
17 |
18 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
19 | """Decays the learning rate of each parameter group by gamma once the
20 | number of epoch reaches one of the milestones. Notice that such decay can
21 | happen simultaneously with other changes to the learning rate from outside
22 | this scheduler. When last_epoch=-1, sets initial lr as lr.
23 |
24 | Args:
25 | optimizer (Optimizer): Wrapped optimizer.
26 | milestones (list): List of epoch indices. Must be increasing.
27 | gamma (float): Multiplicative factor of learning rate decay.
28 | Default: 0.1.
29 | last_epoch (int): The index of last epoch. Default: -1.
30 | verbose (bool): If ``True``, prints a message to stdout for
31 | each update. Default: ``False``.
32 |
33 | Example:
34 | >>> # Assuming optimizer uses lr = 0.05 for all groups
35 | >>> # lr = 0.05 if epoch < 30
36 | >>> # lr = 0.005 if 30 <= epoch < 80
37 | >>> # lr = 0.0005 if epoch >= 80
38 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
39 | >>> for epoch in range(100):
40 | >>> train(...)
41 | >>> validate(...)
42 | >>> scheduler.step()
43 | """
44 |
45 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
46 | self.milestones = Counter(milestones)
47 | self.gamma = gamma
48 | super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)
49 |
50 | def get_lr(self):
51 | if not self._get_lr_called_within_step:
52 | warnings.warn("To get the last learning rate computed by the scheduler, "
53 | "please use `get_last_lr()`.", UserWarning)
54 |
55 | if self.last_epoch not in self.milestones:
56 | return [group['lr'] for group in self.optimizer.param_groups]
57 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
58 | for group in self.optimizer.param_groups]
59 |
60 | def _get_closed_form_lr(self):
61 | milestones = list(sorted(self.milestones.elements()))
62 | return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
63 | for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
/schedulers/pytorch_warmup/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseWarmup, LinearWarmup, ExponentialWarmup
2 | from .untuned import UntunedLinearWarmup, UntunedExponentialWarmup
3 | from .radam import RAdamWarmup, rho_fn, rho_inf_fn, get_offset
4 |
5 | __version__ = "0.2.0.dev0"
6 |
7 | __all__ = [
8 | 'BaseWarmup',
9 | 'LinearWarmup',
10 | 'ExponentialWarmup',
11 | 'UntunedLinearWarmup',
12 | 'UntunedExponentialWarmup',
13 | 'RAdamWarmup',
14 | 'rho_fn',
15 | 'rho_inf_fn',
16 | 'get_offset',
17 | ]
--------------------------------------------------------------------------------
/schedulers/pytorch_warmup/base.py:
--------------------------------------------------------------------------------
1 | ##from https://github.com/Tony-Y/pytorch_warmup/blob/master/pytorch_warmup/base.py
2 |
3 | import math
4 | from contextlib import contextmanager
5 | from torch.optim import Optimizer
6 |
7 |
8 | def _check_optimizer(optimizer):
9 | if not isinstance(optimizer, Optimizer):
10 | raise TypeError('{} ({}) is not an Optimizer.'.format(
11 | optimizer, type(optimizer).__name__))
12 |
13 |
14 | class BaseWarmup(object):
15 | """Base class for all warmup schedules
16 |
17 | Arguments:
18 | optimizer (Optimizer): an instance of a subclass of Optimizer
19 | warmup_params (list): warmup parameters
20 | last_step (int): The index of last step. (Default: -1)
21 | """
22 |
23 | def __init__(self, optimizer, warmup_params, last_step=-1):
24 | self.optimizer = optimizer
25 | self.warmup_params = warmup_params
26 | self.last_step = last_step
27 | self.lrs = [group['lr'] for group in self.optimizer.param_groups]
28 | self.dampen()
29 |
30 | def state_dict(self):
31 | """Returns the state of the warmup scheduler as a :class:`dict`.
32 |
33 | It contains an entry for every variable in self.__dict__ which
34 | is not the optimizer.
35 | """
36 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
37 |
38 | def load_state_dict(self, state_dict):
39 | """Loads the warmup scheduler's state.
40 |
41 | Arguments:
42 | state_dict (dict): warmup scheduler state. Should be an object returned
43 | from a call to :meth:`state_dict`.
44 | """
45 | self.__dict__.update(state_dict)
46 |
47 | def dampen(self, step=None):
48 | """Dampen the learning rates.
49 |
50 | Arguments:
51 | step (int): The index of current step. (Default: None)
52 | """
53 | if step is None:
54 | step = self.last_step + 1
55 | self.last_step = step
56 |
57 | for group, params in zip(self.optimizer.param_groups, self.warmup_params):
58 | omega = self.warmup_factor(step, **params)
59 | group['lr'] *= omega
60 |
61 | @contextmanager
62 | def dampening(self):
63 | for group, lr in zip(self.optimizer.param_groups, self.lrs):
64 | group['lr'] = lr
65 | yield
66 | self.lrs = [group['lr'] for group in self.optimizer.param_groups]
67 | self.dampen()
68 |
69 | def warmup_factor(self, step, **params):
70 | raise NotImplementedError
71 |
72 |
73 | def get_warmup_params(warmup_period, group_count):
74 | if isinstance(warmup_period, list):
75 | if len(warmup_period) != group_count:
76 | raise ValueError(
77 | 'The size of warmup_period ({}) does not match the size of param_groups ({}).'.format(
78 | len(warmup_period), group_count))
79 | for x in warmup_period:
80 | if not isinstance(x, int):
81 | raise TypeError(
82 | 'An element in warmup_period, {}, is not an int.'.format(
83 | type(x).__name__))
84 | if x <= 0:
85 | raise ValueError(
86 | 'An element in warmup_period must be a positive integer, but is {}.'.format(x))
87 | warmup_params = [dict(warmup_period=x) for x in warmup_period]
88 | elif isinstance(warmup_period, int):
89 | if warmup_period <= 0:
90 | raise ValueError(
91 | 'warmup_period must be a positive integer, but is {}.'.format(warmup_period))
92 | warmup_params = [dict(warmup_period=warmup_period)
93 | for _ in range(group_count)]
94 | else:
95 | raise TypeError('{} ({}) is not a list nor an int.'.format(
96 | warmup_period, type(warmup_period).__name__))
97 | return warmup_params
98 |
99 |
100 | class LinearWarmup(BaseWarmup):
101 | """Linear warmup schedule.
102 |
103 | Arguments:
104 | optimizer (Optimizer): an instance of a subclass of Optimizer
105 | warmup_period (int or list): Warmup period
106 | last_step (int): The index of last step. (Default: -1)
107 | """
108 |
109 | def __init__(self, optimizer, warmup_period, last_step=-1):
110 | _check_optimizer(optimizer)
111 | group_count = len(optimizer.param_groups)
112 | warmup_params = get_warmup_params(warmup_period, group_count)
113 | super(LinearWarmup, self).__init__(optimizer, warmup_params, last_step)
114 |
115 | def warmup_factor(self, step, warmup_period):
116 | return min(1.0, (step+1) / warmup_period)
117 |
118 |
119 | class ExponentialWarmup(BaseWarmup):
120 | """Exponential warmup schedule.
121 |
122 | Arguments:
123 | optimizer (Optimizer): an instance of a subclass of Optimizer
124 | warmup_period (int or list): Effective warmup period
125 | last_step (int): The index of last step. (Default: -1)
126 | """
127 |
128 | def __init__(self, optimizer, warmup_period, last_step=-1):
129 | _check_optimizer(optimizer)
130 | group_count = len(optimizer.param_groups)
131 | warmup_params = get_warmup_params(warmup_period, group_count)
132 | super(ExponentialWarmup, self).__init__(optimizer, warmup_params, last_step)
133 |
134 | def warmup_factor(self, step, warmup_period):
135 | return 1.0 - math.exp(-(step+1) / warmup_period)
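# Usage sketch (illustrative addition, not part of the repository file): the warmup is combined
# with a regular LR scheduler through the dampening() context manager, which restores the
# undampened lrs, lets the wrapped scheduler step, then re-applies the warmup factor.
import torch

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
warmup_scheduler = LinearWarmup(optimizer, warmup_period=100)

for step in range(5):
    optimizer.step()
    with warmup_scheduler.dampening():
        lr_scheduler.step()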
--------------------------------------------------------------------------------
/schedulers/pytorch_warmup/radam.py:
--------------------------------------------------------------------------------
1 | ###from https://github.com/Tony-Y/pytorch_warmup/blob/master/pytorch_warmup/radam.py
2 |
3 | import math
4 | from .base import BaseWarmup, _check_optimizer
5 |
6 |
7 | def rho_inf_fn(beta2):
8 | return 2.0 / (1 - beta2) - 1
9 |
10 |
11 | def rho_fn(t, beta2, rho_inf):
12 | b2t = beta2 ** t
13 | rho_t = rho_inf - 2 * t * b2t / (1 - b2t)
14 | return rho_t
15 |
16 |
17 | def get_offset(beta2, rho_inf):
18 | if not beta2 > 0.6:
19 | raise ValueError('beta2 ({}) must be greater than 0.6'.format(beta2))
20 | offset = 1
21 | while True:
22 | if rho_fn(offset, beta2, rho_inf) > 4:
23 | return offset
24 | offset += 1
25 |
26 |
27 | class RAdamWarmup(BaseWarmup):
28 | """RAdam warmup schedule.
29 |
30 | This warmup scheme is described in
31 | `On the adequacy of untuned warmup for adaptive optimization
32 | <https://arxiv.org/abs/1910.04209>`_.
33 |
34 | Arguments:
35 | optimizer (Optimizer): an Adam optimizer
36 | last_step (int): The index of last step. (Default: -1)
37 | """
38 |
39 | def __init__(self, optimizer, last_step=-1):
40 | _check_optimizer(optimizer)
41 | warmup_params = [
42 | dict(
43 | beta2=x['betas'][1],
44 | rho_inf=rho_inf_fn(x['betas'][1]),
45 | )
46 | for x in optimizer.param_groups
47 | ]
48 | for x in warmup_params:
49 | x['offset'] = get_offset(**x)
50 | super(RAdamWarmup, self).__init__(optimizer, warmup_params, last_step)
51 |
52 | def warmup_factor(self, step, beta2, rho_inf, offset):
53 | rho = rho_fn(step+offset, beta2, rho_inf)
54 | numerator = (rho - 4) * (rho - 2) * rho_inf
55 | denominator = (rho_inf - 4) * (rho_inf - 2) * rho
56 | return math.sqrt(numerator/denominator)
--------------------------------------------------------------------------------
/schedulers/pytorch_warmup/untuned.py:
--------------------------------------------------------------------------------
1 | from .base import LinearWarmup, ExponentialWarmup, _check_optimizer
2 |
3 |
4 | class UntunedLinearWarmup(LinearWarmup):
5 | """Untuned linear warmup schedule for Adam.
6 |
7 | This warmup scheme is described in
8 | `On the adequacy of untuned warmup for adaptive optimization
9 | <https://arxiv.org/abs/1910.04209>`_.
10 |
11 | Arguments:
12 | optimizer (Optimizer): an Adam optimizer
13 | last_step (int): The index of last step. (Default: -1)
14 | """
15 |
16 | def __init__(self, optimizer, last_step=-1):
17 | _check_optimizer(optimizer)
18 |
19 | def warmup_period_fn(beta2):
20 | return int(2.0 / (1.0-beta2))
21 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
22 | super(UntunedLinearWarmup, self).__init__(optimizer, warmup_period, last_step)
23 |
24 |
25 | class UntunedExponentialWarmup(ExponentialWarmup):
26 | """Untuned exponetial warmup schedule for Adam.
27 |
28 | This warmup scheme is described in
29 | `On the adequacy of untuned warmup for adaptive optimization
30 | <https://arxiv.org/abs/1910.04209>`_.
31 |
32 | Arguments:
33 | optimizer (Optimizer): an Adam optimizer
34 | last_step (int): The index of last step. (Default: -1)
35 | """
36 |
37 | def __init__(self, optimizer, last_step=-1):
38 | _check_optimizer(optimizer)
39 |
40 | def warmup_period_fn(beta2):
41 | return int(1.0 / (1.0-beta2))
42 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups]
43 | super(UntunedExponentialWarmup, self).__init__(optimizer, warmup_period, last_step)
--------------------------------------------------------------------------------
/schedulers/warmup.py:
--------------------------------------------------------------------------------
1 | #from https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
2 |
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | from torch.optim.lr_scheduler import ReduceLROnPlateau
5 |
6 |
7 | class GradualWarmupScheduler(_LRScheduler):
8 | """ Gradually warm-up(increasing) learning rate in optimizer.
9 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
10 | Args:
11 | optimizer (Optimizer): Wrapped optimizer.
12 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
13 | total_epoch: target learning rate is reached at total_epoch, gradually
14 | after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
15 | """
16 |
17 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
18 | self.multiplier = multiplier
19 | if self.multiplier < 1.:
20 | raise ValueError('multiplier should be greater than or equal to 1.')
21 | self.total_epoch = total_epoch
22 | self.after_scheduler = after_scheduler
23 | self.finished = False
24 | super(GradualWarmupScheduler, self).__init__(optimizer)
25 |
26 | def get_lr(self):
27 | if self.last_epoch > self.total_epoch:
28 | if self.after_scheduler:
29 | if not self.finished:
30 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
31 | self.finished = True
32 | return self.after_scheduler.get_last_lr()
33 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
34 |
35 | if self.multiplier == 1.0:
36 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
37 | else:
38 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
39 |
40 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
41 | if epoch is None:
42 | epoch = self.last_epoch + 1
43 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
44 | if self.last_epoch <= self.total_epoch:
45 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
46 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
47 | param_group['lr'] = lr
48 | else:
49 | if epoch is None:
50 | self.after_scheduler.step(metrics, None)
51 | else:
52 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
53 |
54 | def step(self, epoch=None, metrics=None):
55 | if type(self.after_scheduler) != ReduceLROnPlateau:
56 | if self.finished and self.after_scheduler:
57 | if epoch is None:
58 | self.after_scheduler.step(None)
59 | else:
60 | self.after_scheduler.step(epoch - self.total_epoch)
61 | self._last_lr = self.after_scheduler.get_last_lr()
62 | else:
63 | return super(GradualWarmupScheduler, self).step(epoch)
64 | else:
65 | self.step_ReduceLROnPlateau(metrics, epoch)
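# Usage sketch (illustrative addition, not part of the repository file): warm up over the first
# 5 epochs towards 10x the base lr, then hand control over to a cosine schedule.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.SGD(params, lr=0.01)
cosine = CosineAnnealingLR(optimizer, T_max=95)
scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=5, after_scheduler=cosine)

for epoch in range(100):
    optimizer.step()
    scheduler.step()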
--------------------------------------------------------------------------------
/utils/create_strand_latent_weights.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | #when training the diffusion model we apply an L2 loss on all the channels of the scalp texture. However, some of the channels are essentially noise.
4 | #This is because some dimensions of the strand VAE encode very little information and don't significantly modify the strand shape.
5 | #here we check how much the position, direction and curvature change when perturbing each dimension of the latent, and we use this delta as a per-channel weight so the diffusion model can downweight those channels.
6 |
7 |
8 |
9 | #python3 ./create_strand_latent_weights.py --checkpoint_path
10 |
11 |
12 |
13 |
14 | import argparse
15 | import torch
16 | from models.strand_codec import StrandCodec
17 | from utils.strand_util import compute_dirs
18 | from losses.loss import StrandVAELoss
19 | import json
20 | from data_loader.dataloader import DiffLocksDataset
21 | from data_loader.mesh_utils import world_to_tbn_space
22 |
23 |
24 |
25 |
26 | torch.set_grad_enabled(False)
27 |
28 |
29 | #transforms the data to a local space, puts it on the cuda device and reshapes it the way we expect it to be
30 | def prepare_gt_batch(batch):
31 | gt_dict = {}
32 |
33 | tbn=batch['full_strands']["tbn"].cuda()
34 | positions=batch['full_strands']["positions"].cuda()
35 | root_normal=batch['full_strands']["root_normal"].cuda()
36 |
37 | #get it on local space
38 | gt_strand_positions, gt_root_normals = world_to_tbn_space(tbn,
39 | positions,
40 | root_normal)
41 | gt_strand_positions=gt_strand_positions.cuda()
42 |
43 | #reshape it to be nr_strands, nr_points, dim
44 | gt_strand_positions=gt_strand_positions.reshape(-1,256,3)
45 |
46 | gt_dirs=compute_dirs(gt_strand_positions, append_last_dir=False) #nr_strands,256-1,3
47 |
48 |
49 | gt_dict["strand_positions"]=gt_strand_positions
50 | gt_dict["strand_directions"]=gt_dirs
51 |
52 |
53 | return gt_dict
54 |
55 |
56 | class HyperParamsStrandVAE:
57 | def __init__(self):
58 | #####Output
59 | #####decode dir######
60 | self.decode_type="dir"
61 | self.scale_init=30.0
62 | self.nr_verts_per_strand=256
63 | self.nr_values_to_decode=255
64 | self.dim_per_value_decoded=3
65 |
66 |
67 | ###LOSS######
68 | #these are the same values that were used to train the strand vae
69 | self.loss_pos_weight=0.5 ##FOR LAMBDA
70 | self.loss_dir_weight=1.0
71 | self.loss_curv_weight=20.0
72 | self.loss_kl_weight=0.0
73 |
74 |
75 |
76 | def main():
77 |
78 | #argparse
79 | parser = argparse.ArgumentParser(description='Get the weights of each dimensions after training a strand VAE')
80 | parser.add_argument('--checkpoint_path', required=True, help='Path to the strandVAE checkpoint')
81 | args = parser.parse_args()
82 |
83 | path_strand_vae_model=args.checkpoint_path
84 |
85 |
86 | hyperparams=HyperParamsStrandVAE()
87 |
88 |
89 | normalization_dict=DiffLocksDataset.get_normalization_data()
90 |
91 | model = StrandCodec(do_vae=False,
92 | decode_type="dir",
93 | scale_init=30.0,
94 | nr_verts_per_strand=256, nr_values_to_decode=255,
95 | dim_per_value_decoded=3).cuda()
96 | model.load_state_dict(torch.load(path_strand_vae_model))
97 | model = torch.compile(model)
98 |
99 |
100 |
101 | #latent of dimension 64 and get GT which is the mean strand
102 | latent=torch.zeros(1,64).cuda()
103 | pred_dict = model.decoder(latent, None, normalization_dict)
104 | pred_points=pred_dict["strand_positions"]
105 | gt_strand=pred_points
106 | print("gt_strand",gt_strand.shape)
107 |
108 |
109 | #make loss function
110 | loss_computer= StrandVAELoss()
111 |
112 |
113 | #for each dimension, perturb it by 0.8 and check the error towards the mean strand (GT)
114 | loss_per_dim=[]
115 | for i in range(64):
116 | latent=torch.zeros(1,64).cuda()
117 | latent[:,i]=0.8
118 | pred_dict = model.decoder(latent, None, normalization_dict)
119 | pred_points=pred_dict["strand_positions"]
120 |
121 | #make dicts
122 | gt_dict={"strand_positions": gt_strand}
123 | pred_dict={"strand_positions": pred_points}
124 | latent_dict={}
125 |
126 | #loss
127 | loss_dict = loss_computer(None, gt_dict, pred_dict, latent_dict, hyperparams)
128 | loss=loss_dict["loss"]
129 | loss_per_dim.append(loss)
130 |
131 | # print("loss", loss)
132 |
133 | #normalize losses
134 | loss_normalization=max(loss_per_dim)
135 | weight_per_dim = [(x/loss_normalization).item() for x in loss_per_dim]
136 |
137 |
138 |
139 | #print them in order
140 | for i in range(64):
141 | # print("i", i, " w: ", weight_per_dim[i].item())
142 | print(weight_per_dim[i])
143 |
144 | with open("loss_weight_strand_latent.json", "w") as final:
145 | json.dump(weight_per_dim, final)
146 |
147 |
148 | #finished
149 | return
150 |
151 |
152 | if __name__ == '__main__':
153 | main()
154 |
155 |
--------------------------------------------------------------------------------
/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import k_diffusion as K
3 |
4 | @torch.no_grad()
5 | def sample_images(nr_images, model_ema, model_config, nr_iters=100, extra_args={}, callback=None):
6 | model_ema.eval()
7 | sigma_min = model_config['sigma_min']
8 | sigma_max = model_config['sigma_max']
9 | size = model_config['input_size']
10 | n_per_proc = nr_images
11 | x = torch.randn([1, n_per_proc, model_config['input_channels'], size[0], size[1]]).cuda()
12 | x = x[0] * sigma_max
13 | model_fn = model_ema
14 | sigmas = K.sampling.get_sigmas_karras(nr_iters, sigma_min, sigma_max, rho=7., device="cuda")
15 | x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=False, callback=callback)
16 | return x_0
17 |
18 | @torch.no_grad()
19 | #samples using classifier-free guidance, enabling cfg_val only when sigma is within the given interval.
20 | #idea from this paper: https://arxiv.org/pdf/2404.07724
21 | def sample_images_cfg(nr_images, cfg_val, cfg_interval, model_ema, model_config, nr_iters=100, extra_args={}, callback=None):
22 | model_ema.eval()
23 | sigma_min = model_config['sigma_min']
24 | sigma_max = model_config['sigma_max']
25 | size = model_config['input_size']
26 | n_per_proc = nr_images
27 | x = torch.randn([1, n_per_proc, model_config['input_channels'], size[0], size[1]]).cuda()
28 | x = x[0] * sigma_max
29 | model_fn = model_ema
30 | sigmas = K.sampling.get_sigmas_karras(nr_iters, sigma_min, sigma_max, rho=7., device="cuda")
31 | x_0 = K.sampling.sample_dpmpp_2m_sde_cfg(model_fn, x, sigmas, cfg_val, cfg_interval, extra_args=extra_args, eta=0.0, solver_type='heun', disable=False, callback=callback)
32 | return x_0
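# Usage sketch (illustrative addition, not part of the repository file): the samplers above only
# need a callable denoiser (model_ema) and a model_config dict exposing the keys they read; the
# numbers below are placeholders, not the actual DiffLocks config.
model_config = {
    "sigma_min": 0.01,
    "sigma_max": 160.0,
    "input_size": [64, 64],
    "input_channels": 4,
}
# With a trained denoiser this would produce a (4, 4, 64, 64) tensor on the cuda device:
# x_0 = sample_images(4, trained_denoiser, model_config, nr_iters=100, extra_args={})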
--------------------------------------------------------------------------------
/utils/resize_right/interp_methods.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 |
3 | try:
4 | import torch
5 | except ImportError:
6 | torch = None
7 |
8 | try:
9 | import numpy
10 | except ImportError:
11 | numpy = None
12 |
13 | if numpy is None and torch is None:
14 | raise ImportError("Must have either Numpy or PyTorch but both not found")
15 |
16 |
17 | def set_framework_dependencies(x):
18 | if type(x) is numpy.ndarray:
19 | to_dtype = lambda a: a
20 | fw = numpy
21 | else:
22 | to_dtype = lambda a: a.to(x.dtype)
23 | fw = torch
24 | eps = fw.finfo(fw.float32).eps
25 | return fw, to_dtype, eps
26 |
27 |
28 | def support_sz(sz):
29 | def wrapper(f):
30 | f.support_sz = sz
31 | return f
32 | return wrapper
33 |
34 |
35 | @support_sz(4)
36 | def cubic(x):
37 | fw, to_dtype, eps = set_framework_dependencies(x)
38 | absx = fw.abs(x)
39 | absx2 = absx ** 2
40 | absx3 = absx ** 3
41 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
42 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
43 | to_dtype((1. < absx) & (absx <= 2.)))
44 |
45 |
46 | @support_sz(4)
47 | def lanczos2(x):
48 | fw, to_dtype, eps = set_framework_dependencies(x)
49 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
50 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
51 |
52 |
53 | @support_sz(6)
54 | def lanczos3(x):
55 | fw, to_dtype, eps = set_framework_dependencies(x)
56 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
57 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
58 |
59 |
60 | @support_sz(2)
61 | def linear(x):
62 | fw, to_dtype, eps = set_framework_dependencies(x)
63 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
64 | to_dtype((0 <= x) & (x <= 1)))
65 |
66 |
67 | @support_sz(1)
68 | def box(x):
69 | fw, to_dtype, eps = set_framework_dependencies(x)
70 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
71 |
--------------------------------------------------------------------------------
/utils/vis_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 |
5 | class PCA(Function):
6 | @staticmethod
7 | def forward(ctx, sv): #sv corresponds to the slice values, it has dimensions nr_positions x val_full_dim
8 |
9 | # http://agnesmustar.com/2017/11/01/principal-component-analysis-pca-implemented-pytorch/
10 |
11 |
12 | X=sv.detach().cpu() #we switch to cpu because of memory issues when doing svd on really big images
13 | k=3
14 | # print("x is ", X.shape)
15 | X_mean = torch.mean(X,0)
16 | # print("x_mean is ", X_mean.shape)
17 | X = X - X_mean.expand_as(X)
18 |
19 | # U,S,V = torch.svd(torch.t(X))
20 | U,S,V = torch.pca_lowrank( torch.t(X) )
21 | C = torch.mm(X,U[:,:k])
22 | # print("C has shape ", C.shape)
23 | # print("C min and max is ", C.min(), " ", C.max() )
24 | C-=C.min()
25 | C/=C.max()
26 | # print("after normalization C min and max is ", C.min(), " ", C.max() )
27 |
28 | return C
29 |
30 |
31 | #img supposed to be N,C,H,W
32 | def img_2_pca(img):
33 | assert img.shape[0]==1
34 | img_c=img.shape[1]
35 | img_h=img.shape[2]
36 | img_w=img.shape[3]
37 | vals = img.permute(0,2,3,1) #N,H,W,C
38 | c=PCA.apply(vals.view(-1,img_c))
39 | c=c.view(1, img_h, img_w, 3).permute(0,3,1,2) #N,H,W,C to N,C,H,W
40 | return c
41 |
42 |
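# Usage sketch (illustrative addition, not part of the repository file): project a high-dimensional
# feature image (e.g. a many-channel scalp texture) down to 3 PCA channels so it can be shown as RGB.
import torch

features = torch.randn(1, 64, 32, 32)   # N,C,H,W with a single image
rgb = img_2_pca(features)                # (1, 3, 32, 32), globally normalized to [0, 1]
print(rgb.shape)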
--------------------------------------------------------------------------------