├── LICENSE ├── README.md ├── callbacks ├── callback.py ├── callback_utils.py ├── phase.py ├── state_callback.py ├── tensorboard_callback.py ├── viewer_callback.py └── wandb_callback.py ├── checkpoints └── HOW_TO_DOWNLOAD.txt ├── configs ├── config_scalp_texture_conditional.json └── strand_vae_train.toml ├── data_loader ├── dataloader.py ├── difflocks_bodydata │ ├── scalp.ply │ └── smplx_base.ply └── mesh_utils.py ├── data_processing ├── create_chunked_strands.py ├── create_latents.py ├── create_scalp_textures.py └── uncompress_data.py ├── dockerfile ├── Dockerfile ├── build.sh ├── nvidia-container-runtime-script.sh ├── nvidia-container-runtime-script_ubuntu24.sh └── run.sh ├── download_checkpoints.sh ├── download_dataset.sh ├── download_validation.sh ├── extensions ├── cuda │ └── push_pull_inpaint.cu ├── include │ └── cuMat │ │ ├── CMakeLists.txt │ │ ├── Core │ │ ├── Dense │ │ ├── IterativeLinearSolvers │ │ ├── Sparse │ │ └── src │ │ ├── Allocator.h │ │ ├── BinaryOps.h │ │ ├── BinaryOpsPlugin.inl │ │ ├── CholeskyDecomposition.h │ │ ├── ConjugateGradient.h │ │ ├── Constants.h │ │ ├── Context.h │ │ ├── CublasApi.h │ │ ├── CudaUtils.h │ │ ├── CusolverApi.h │ │ ├── CwiseOp.h │ │ ├── DecompositionBase.h │ │ ├── DenseLinAlgOps.h │ │ ├── DenseLinAlgPlugin.inl │ │ ├── DevicePointer.h │ │ ├── DisableCompilerWarnings.h │ │ ├── EigenInteropHelpers.h │ │ ├── Errors.h │ │ ├── ForwardDeclarations.h │ │ ├── IO.h │ │ ├── IterativeSolverBase.h │ │ ├── Iterator.h │ │ ├── LUDecomposition.h │ │ ├── Logging.h │ │ ├── Macros.h │ │ ├── Matrix.h │ │ ├── MatrixBase.h │ │ ├── MatrixBlock.h │ │ ├── MatrixBlockPluginLvalue.inl │ │ ├── MatrixBlockPluginRvalue.inl │ │ ├── MatrixNullaryOpsPlugin.inl │ │ ├── NullaryOps.h │ │ ├── NumTraits.h │ │ ├── ProductOp.h │ │ ├── Profiling.h │ │ ├── ReductionAlgorithmSelection.h │ │ ├── ReductionOps.h │ │ ├── ReductionOpsPlugin.inl │ │ ├── SimpleRandom.h │ │ ├── SolverBase.h │ │ ├── SparseEvaluation.h │ │ ├── SparseExpressionOp.h │ │ ├── SparseExpressionOpPlugin.inl │ │ ├── SparseMatrix.h │ │ ├── SparseMatrixBase.h │ │ ├── SparseProductEvaluation.h │ │ ├── TransposeOp.h │ │ ├── UnaryOps.h │ │ └── UnaryOpsPlugin.inl └── push_pull_inpaint.cpp ├── imgs ├── dataset.png └── teaser.png ├── inference ├── assets │ ├── blender_vis_base_v26_with_shrinkwrap_full_base.blend │ └── face_landmarker.task ├── img2hair.py └── npz2blender.py ├── inference_difflocks.py ├── k_diffusion ├── __init__.py ├── config.py ├── layers.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── axial_rope.py │ ├── flags.py │ ├── flops.py │ ├── image_transformer_v2_conditional.py │ └── modules.py ├── sampling.py └── utils.py ├── losses ├── loss.py ├── loss_utils.py └── losses.py ├── models ├── rgb_to_material.py └── strand_codec.py ├── modules ├── __init__.py ├── edm2_modules.py └── networks.py ├── requirements.txt ├── samples ├── buzzcut_3.jpg ├── cooper_4.jpg ├── fiennes_7.jpg ├── freeman_2.png ├── harris_5.jpeg ├── hathaway_1.jpg ├── medium_11.png └── uggams_3.png ├── schedulers ├── linearlr.py ├── multisteplr.py ├── pytorch_warmup │ ├── __init__.py │ ├── base.py │ ├── radam.py │ └── untuned.py └── warmup.py ├── train_rgb2material.py ├── train_scalp_diffusion.py ├── train_strandsVAE.py └── utils ├── create_strand_latent_weights.py ├── diffusion_utils.py ├── general_util.py ├── resize_right ├── interp_methods.py └── resize_right.py ├── strand_util.py └── vis_util.py /LICENSE: -------------------------------------------------------------------------------- 1 | 1. 
DiffLocks License (Non-Commercial Scientific Research Use Only) 2 | 3 | 4 | Software Copyright License for non-commercial scientific research purposes 5 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the DiffLocks model, data and software, (the "Model & Software"). By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License 6 | 7 | Ownership / Licensees 8 | The Software and the associated materials has been developed at the 9 | 10 | Meshcapade. 11 | 12 | Any copyright or patent right is owned by and proprietary material of the 13 | 14 | Meshcapade 15 | 16 | hereinafter the “Licensor”. 17 | 18 | License Grant 19 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: 20 | 21 | To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization; 22 | To use the Model & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects; 23 | Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Meshcapade’s prior written permission. 24 | 25 | The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it. 26 | 27 | No Distribution 28 | The Model & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. 29 | 30 | Disclaimer of Representations and Warranties 31 | You expressly acknowledge and agree that the Model & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. 
Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model & Software, (ii) that the use of the Model & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model & Software will not cause any damage of any kind to you or a third party. 32 | 33 | Limitation of Liability 34 | Because this Model & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. 35 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. 36 | Patent claims generated through the usage of the Model & Software cannot be directed towards the copyright holders. 37 | The Model & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Model & Software and is not responsible for any problems such modifications cause. 38 | 39 | No Maintenance Services 40 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Model & Software at any time. 41 | 42 | Defects of the Model & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. 43 | 44 | Publications using the Model & Software 45 | You acknowledge that the Model & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Model & Software. 46 | 47 | Citation: 48 | 49 | 50 | @inproceedings{difflocks2025, 51 | title = {DiffLocks: Generating 3D Hair from a Single Image using Diffusion Models}, 52 | author = {Rosu, Radu Alexandru and Wu, Keyu and Feng, Yao and Zheng, Youyi and Black, Michael J.}, 53 | booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)}, 54 | year = {2025} 55 | } 56 | Commercial licensing opportunities 57 | For commercial uses of the Software, please send email to sales@meshcapade.com 58 | 59 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. 60 | 61 | 62 | 63 | 2. Third-Party Components 64 | 65 | Portions of this code are derived from [k-diffusion](https://github.com/crowsonkb/k-diffusion), which is licensed under the MIT License. 
66 | The original copyright and license are included below: 67 | 68 | Copyright (c) 2022 Katherine Crowson 69 | 70 | Permission is hereby granted, free of charge, to any person obtaining a copy 71 | of this software and associated documentation files (the "Software"), to deal 72 | in the Software without restriction, including without limitation the rights 73 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 74 | copies of the Software, and to permit persons to whom the Software is 75 | furnished to do so, subject to the following conditions: 76 | 77 | The above copyright notice and this permission notice shall be included in 78 | all copies or substantial portions of the Software. 79 | 80 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 81 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 82 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 83 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 84 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 85 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 86 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffLocks: Generating 3D Hair from a Single Image using Diffusion Models # 2 | 3 | [**Paper**](https://arxiv.org/abs/2505.06166) | [**Project Page**](https://radualexandru.github.io/difflocks/) 4 | 5 |

6 | 7 |

8 | 9 | This repository contains the official inference and training code for DiffLocks, which creates realistic strand-based hairstyles from a single image. It also contains the DiffLocks dataset, consisting of 40K synthetic 3D strand-based hairstyles generated in Blender. 10 | 11 | ## Requirements 12 | 13 | DiffLocks dependencies are listed in the provided `requirements.txt` and can be installed in a virtual environment: 14 | 15 | $ python3 -m venv ./difflocks_env 16 | $ source ./difflocks_env/bin/activate 17 | $ pip install -r ./requirements.txt 18 | 19 | Afterwards we need to install custom CUDA kernels for the diffusion model: 20 | * [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. 21 | * [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention. 22 | Please double-check that you install NATTEN for torch 2.5.0 (as per `requirements.txt`). 23 | 24 | Finally, if you want to perform inference, you need to download the checkpoints for the trained models. 25 | The pretrained checkpoints can be downloaded by following [this section](#download-pretrained-checkpoints): 26 | 27 | 28 | 29 | ## Dataset 30 | ![data](/imgs/dataset.png "dataset") 31 | The DiffLocks dataset consists of 40K hairstyles. Each sample includes 3D hair (~100K strands), a corresponding rendered RGB image and metadata regarding the hair. 32 | 33 | The DiffLocks dataset can be downloaded using: 34 | 35 | ./download_dataset.sh 36 | 37 | After downloading, the dataset first has to be uncompressed: 38 | 39 | $ ./data_processing/uncompress_data.py --dataset_zipped_path --out_path 40 | 41 | After uncompressing, we create a processed dataset: 42 | 43 | $ ./data_processing/create_chunked_strands.py --dataset_path 44 | $ ./data_processing/create_latents.py --dataset_path= --out_path 45 | $ ./data_processing/create_scalp_textures.py --dataset_path= --out_path --path_strand_vae_model ./checkpoints/strand_vae/strand_codec.pt 46 | 47 | 48 | ## Download pretrained checkpoints 49 | You can download the pretrained checkpoints by running: 50 | 51 | ./download_checkpoints.sh 52 | 53 | ## Inference 54 | To run inference on an RGB image and create 3D strands use: 55 | 56 | $ ./inference_difflocks.py \ 57 | --img_path=./samples/medium_11.png \ 58 | --out_path=./outputs_inference/ 59 | 60 | You also have the option to export a `.blend` file and an Alembic file by specifying `--blender_path` and `--export-alembic` in the above script. 61 | Note that the Blender path corresponds to a Blender executable of version 4.1.1. It will likely not work with other versions. 62 | 63 | 64 | ## Train StrandVAE 65 | To train the StrandVAE model: 66 | 67 | $ ./train_strandsVAE.py --dataset_path= --exp_info= 68 | 69 | It will start training and output tensorboard logs in `./tensorboard_logs`. 70 | 71 | 72 | ## Train DiffLocks diffusion model 73 | To train the diffusion model: 74 | 75 | $ ./train_scalp_diffusion.py \ 76 | --config ./configs/config_scalp_texture_conditional.json \ 77 | --batch-size 4 \ 78 | --grad-accum-steps 4 \ 79 | --mixed-precision bf16 \ 80 | --use-tensorboard \ 81 | --save-checkpoints \ 82 | --save-every 100000 \ 83 | --compile \ 84 | --dataset_path= \ 85 | --dataset_processed_path= \ 86 | --name 87 | 88 | It will start training and output tensorboard logs in `./tensorboard_logs`.
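Training progress (for both the StrandVAE and the diffusion model) can be monitored by pointing TensorBoard at that directory, for example:

    $ tensorboard --logdir ./tensorboard_logs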
89 | Start training on multiple GPUs by first running: 90 | 91 | $ accelerate config 92 | 93 | followed by prepending `accelerate launch` to the previous training command: 94 | 95 | $ accelerate launch ./train_scalp_diffusion.py \ 96 | --config ./configs/config_scalp_texture_conditional.json \ 97 | --batch-size 4 \ 98 | 99 | 100 | You will probably need to adjust `batch-size` and `grad-accum-steps` depending on the number of GPUs you have, since together they determine the effective batch size per optimizer step. 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /callbacks/callback.py: -------------------------------------------------------------------------------- 1 | #https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynb 2 | 3 | import re 4 | import torch 5 | 6 | def to_snake_case(string): 7 | """Converts CamelCase string into snake_case.""" 8 | 9 | s = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', string) 10 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s).lower() 11 | 12 | def classname(obj): 13 | return obj.__class__.__name__ 14 | 15 | 16 | class Callback: 17 | """ 18 | The base class inherited by callbacks. 19 | 20 | Provides a lot of hooks invoked on various stages of the training loop 21 | execution. The signature of functions is as broad as possible to allow 22 | flexibility and customization in descendant classes. 23 | """ 24 | def training_started(self, **kwargs): pass 25 | 26 | def training_ended(self, **kwargs): pass 27 | 28 | def epoch_started(self, **kwargs): pass 29 | 30 | def phase_started(self, **kwargs): pass 31 | 32 | def phase_ended(self, **kwargs): pass 33 | 34 | def epoch_ended(self, **kwargs): pass 35 | 36 | def batch_started(self, **kwargs): pass 37 | 38 | def batch_ended(self, **kwargs): pass 39 | 40 | def before_forward_pass(self, **kwargs): pass 41 | 42 | def after_forward_pass(self, **kwargs): pass 43 | 44 | def before_backward_pass(self, **kwargs): pass 45 | 46 | def after_backward_pass(self, **kwargs): pass 47 | 48 | 49 | class CallbacksGroup(Callback): 50 | """ 51 | Groups together several callbacks and delegates training loop 52 | notifications to the encapsulated objects.
53 | """ 54 | def __init__(self, callbacks): 55 | self.callbacks = callbacks 56 | self.named_callbacks = {to_snake_case(classname(cb)): cb for cb in self.callbacks} 57 | 58 | def __getitem__(self, item): 59 | item = to_snake_case(item) 60 | if item in self.named_callbacks: 61 | return self.named_callbacks[item] 62 | raise KeyError(f'callback name is not found: {item}') 63 | 64 | def training_started(self, **kwargs): self.invoke('training_started', **kwargs) 65 | 66 | def training_ended(self, **kwargs): self.invoke('training_ended', **kwargs) 67 | 68 | def epoch_started(self, **kwargs): self.invoke('epoch_started', **kwargs) 69 | 70 | def phase_started(self, **kwargs): self.invoke('phase_started', **kwargs) 71 | 72 | def phase_ended(self, **kwargs): self.invoke('phase_ended', **kwargs) 73 | 74 | def epoch_ended(self, **kwargs): self.invoke('epoch_ended', **kwargs) 75 | 76 | def batch_started(self, **kwargs): self.invoke('batch_started', **kwargs) 77 | 78 | def batch_ended(self, **kwargs): self.invoke('batch_ended', **kwargs) 79 | 80 | def before_forward_pass(self, **kwargs): self.invoke('before_forward_pass', **kwargs) 81 | 82 | def after_forward_pass(self, **kwargs): self.invoke('after_forward_pass', **kwargs) 83 | 84 | def before_backward_pass(self, **kwargs): self.invoke('before_backward_pass', **kwargs) 85 | 86 | def after_backward_pass(self, **kwargs): self.invoke('after_backward_pass', **kwargs) 87 | 88 | def invoke(self, method, **kwargs): 89 | with torch.set_grad_enabled(False): 90 | for cb in self.callbacks: 91 | getattr(cb, method)(**kwargs) -------------------------------------------------------------------------------- /callbacks/callback_utils.py: -------------------------------------------------------------------------------- 1 | # from permuto_sdf import TrainParams 2 | from callbacks.callback import * 3 | from callbacks.wandb_callback import * 4 | from callbacks.state_callback import * 5 | from callbacks.phase import * 6 | 7 | 8 | def create_callbacks(with_tensorboard, with_visualizer, experiment_name, viewer_config_path=None): 9 | cb_list = [] 10 | if(with_tensorboard): 11 | from callbacks.tensorboard_callback import TensorboardCallback #we put it here in case we don't have tensorboard installed 12 | tensorboard_callback=TensorboardCallback(experiment_name) 13 | cb_list.append(tensorboard_callback) 14 | if(with_visualizer): 15 | from callbacks.viewer_callback import ViewerCallback #we put it here because we might not have the visualizer package installed 16 | viewer_callback=ViewerCallback(viewer_config_path, experiment_name) 17 | cb_list.append(viewer_callback) 18 | cb_list.append(StateCallback()) 19 | cb = CallbacksGroup(cb_list) 20 | 21 | return cb 22 | -------------------------------------------------------------------------------- /callbacks/phase.py: -------------------------------------------------------------------------------- 1 | #https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynbA 2 | 3 | # from permuto_sdf_py.callbacks.scores import * 4 | 5 | class Phase: 6 | """ 7 | Model training loop phase. 8 | 9 | Each model's training loop iteration could be separated into (at least) two 10 | phases: training and validation. The instances of this class track 11 | metrics and counters, related to the specific phase, and keep the reference 12 | to subset of data, used during phase. 
13 | """ 14 | 15 | def __init__(self, name, loader, grad): 16 | self.name = name 17 | self.loader = loader 18 | self.grad = grad 19 | self.iter_nr = 0 20 | self.epoch_nr = 0 21 | self.samples_processed_this_epoch = 0 22 | # self.scores= Scores() 23 | self.loss_acum_per_epoch=0.0 24 | self.loss_pos_acum_per_epoch=0.0 25 | self.loss_dir_acum_per_epoch=0.0 26 | -------------------------------------------------------------------------------- /callbacks/state_callback.py: -------------------------------------------------------------------------------- 1 | from callbacks.callback import * 2 | import os 3 | import torch 4 | 5 | 6 | class StateCallback(Callback): 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def after_forward_pass(self, phase, loss, loss_pos, loss_dir, loss_curv, **kwargs): 12 | phase.iter_nr+=1 13 | phase.samples_processed_this_epoch+=1 14 | phase.loss_acum_per_epoch+=loss 15 | phase.loss_pos_acum_per_epoch+=loss_pos 16 | phase.loss_dir_acum_per_epoch+=loss_dir 17 | phase.loss_curv_acum_per_epoch+=loss_curv 18 | 19 | 20 | def epoch_started(self, phase, **kwargs): 21 | phase.loss_acum_per_epoch=0.0 22 | phase.loss_pos_acum_per_epoch=0.0 23 | phase.loss_dir_acum_per_epoch=0.0 24 | phase.loss_curv_acum_per_epoch=0.0 25 | 26 | def epoch_ended(self, phase, **kwargs): 27 | 28 | phase.epoch_nr+=1 29 | 30 | def phase_started(self, phase, **kwargs): 31 | phase.samples_processed_this_epoch=0 32 | 33 | def phase_ended(self, phase, model, hyperparams, experiment_name, output_training_path, **kwargs): 34 | 35 | if (phase.epoch_nr%hyperparams.save_checkpoint_every_x_epoch==0) and hyperparams.save_checkpoint and phase.grad: 36 | model.save(output_training_path, experiment_name, hyperparams, phase.epoch_nr) 37 | 38 | 39 | -------------------------------------------------------------------------------- /callbacks/tensorboard_callback.py: -------------------------------------------------------------------------------- 1 | from callbacks.callback import * 2 | from torch.utils.tensorboard import SummaryWriter 3 | 4 | class TensorboardCallback(Callback): 5 | 6 | def __init__(self, experiment_name): 7 | self.tensorboard_writer = SummaryWriter("tensorboard_logs/"+experiment_name) 8 | self.experiment_name=experiment_name 9 | 10 | 11 | def after_forward_pass(self, phase, loss=0, loss_pos=0, loss_dir=0, loss_curv=0, loss_kl=0, lr=0, z_deviation=None, z=None, z_no_eps=None, **kwargs): 12 | 13 | if phase.iter_nr%300==0 and phase.grad: 14 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss', loss.item(), phase.iter_nr) 15 | if loss_pos!=0: 16 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_pos', loss_pos.item(), phase.iter_nr) 17 | # if loss_l1!=0: 18 | # self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_l1', loss_l1.item(), phase.iter_nr) 19 | # if loss_l2!=0: 20 | # self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_l2', loss_l2.item(), phase.iter_nr) 21 | if loss_dir!=0: 22 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_dir', loss_dir.item(), phase.iter_nr) 23 | if loss_curv!=0: 24 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_curv', loss_curv.item(), phase.iter_nr) 25 | if loss_kl!=0: 26 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_kl', loss_kl.item(), phase.iter_nr) 27 | 28 | if lr!=0: 29 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/lr', lr, phase.iter_nr) 30 | 31 | if z_deviation is not None: 32 | 
self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_deviation', z_deviation.std(), phase.iter_nr) 33 | 34 | # if z is not None: 35 | # self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_max', z.max(), phase.iter_nr) 36 | if z_no_eps is not None: 37 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_no_eps_mean', z_no_eps.mean(), phase.iter_nr) 38 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/z_no_eps_std', z_no_eps.std(), phase.iter_nr) 39 | 40 | 41 | def epoch_ended(self, phase, **kwargs): 42 | avg_loss_pos=phase.loss_pos_acum_per_epoch/phase.samples_processed_this_epoch 43 | avg_loss_dir=phase.loss_dir_acum_per_epoch/phase.samples_processed_this_epoch 44 | avg_loss_curv=phase.loss_curv_acum_per_epoch/phase.samples_processed_this_epoch 45 | 46 | if phase.grad==False and phase.loss_pos_acum_per_epoch!=0: 47 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_pos_avg', avg_loss_pos.item(), phase.epoch_nr) 48 | if phase.grad==False and phase.loss_dir_acum_per_epoch!=0: 49 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_dir_avg', avg_loss_dir.item(), phase.epoch_nr) 50 | if phase.grad==False and phase.loss_curv_acum_per_epoch!=0: 51 | self.tensorboard_writer.add_scalar('hair_forge/' + phase.name + '/loss_curv_avg', avg_loss_curv.item(), phase.epoch_nr) 52 | -------------------------------------------------------------------------------- /callbacks/viewer_callback.py: -------------------------------------------------------------------------------- 1 | from callbacks.callback import * 2 | import numpy as np 3 | from gloss import * 4 | 5 | class ViewerCallback(Callback): 6 | 7 | def __init__(self, viewer_config_path, experiment_name): 8 | gloss_setup_logger(log_level=LogLevel.Info) 9 | self.viewer=Viewer(viewer_config_path) 10 | self.scene=self.viewer.get_scene() 11 | 12 | self.first_time=True 13 | 14 | self.visualize_every_x_iters=100 15 | self.visualize_strand_at_idx=None 16 | # self.visualize_strand_at_idx=0 17 | 18 | def after_forward_pass(self, phase, gt_cloud=None, pred_cloud=None, **kwargs): 19 | 20 | # if phase.iter_nr==1000: 21 | # exit(1) 22 | 23 | if phase.iter_nr%self.visualize_every_x_iters==0: 24 | if self.visualize_strand_at_idx is not None: 25 | gt_cloud=gt_cloud[self.visualize_strand_at_idx,:,:] 26 | gt_ent = self.scene.get_or_spawn_renderable("gt_ent") 27 | gt_ent.insert(Verts(gt_cloud.reshape(-1,3).cpu().numpy())) 28 | gt_ent.insert(Colors(gt_cloud.reshape(-1,3).cpu().numpy())) #just to avoid verts being different than Colors when we switch nr_verts for the same entity 29 | if self.first_time: 30 | gt_ent.insert(VisPoints(show_points=True, \ 31 | point_size=3.0, \ 32 | color_type=PointColorType.Solid)) 33 | 34 | 35 | #pred 36 | if self.visualize_strand_at_idx is not None: 37 | pred_cloud=pred_cloud[self.visualize_strand_at_idx,:,:] 38 | pred_ent = self.scene.get_or_spawn_renderable("pred_ent") 39 | pred_ent.insert(Verts(pred_cloud.reshape(-1,3).cpu().numpy())) 40 | pred_ent.insert(Colors(pred_cloud.reshape(-1,3).cpu().numpy())) #just to avoid verts being different than Colors when we switch nr_verts for the same entity 41 | if self.first_time: 42 | pred_ent.insert(VisPoints(show_points=True, \ 43 | point_size=3.0, \ 44 | point_color=[0.1, 0.2, 0.8, 1.0], 45 | color_type=PointColorType.Solid)) 46 | 47 | #render 48 | self.viewer.start_frame() 49 | self.viewer.update() 50 | 51 | 52 | self.first_time=False 53 | 54 | 55 | 56 | 57 | 58 | # viewer=Viewer() 59 | # 
scene=viewer.get_scene() 60 | 61 | # mesh = scene.get_or_spawn_renderable("test") 62 | # mesh.insert(Verts( 63 | # np.array([ 64 | # [0,0,0], 65 | # [0,1,0], 66 | # [1,0,0], 67 | # [1.5,1.5,-1] 68 | # ], dtype = "float32") )) 69 | # mesh.insert(Colors( 70 | # np.array([ 71 | # [1,0,0], 72 | # [0,0,1], 73 | # [1,1,0], 74 | # [1,0,1] 75 | # ], dtype = "float32"))) 76 | 77 | # mesh.insert(VisPoints(show_points=True, \ 78 | # show_points_indices=True, \ 79 | # point_size=10.0, \ 80 | # color_type=PointColorType.PerVert)) 81 | 82 | # while True: 83 | # viewer.start_frame() 84 | # viewer.update() 85 | 86 | 87 | -------------------------------------------------------------------------------- /callbacks/wandb_callback.py: -------------------------------------------------------------------------------- 1 | from callbacks.callback import * 2 | import wandb 3 | import hjson 4 | 5 | class WandBCallback(Callback): 6 | 7 | def __init__(self, experiment_name, config_path, entity): 8 | self.experiment_name=experiment_name 9 | # loading the config file like this and giving it to wandb stores them on the website 10 | with open(config_path, 'r') as j: 11 | cfg = hjson.loads(j.read()) 12 | # Before this init can be run, you have to use wandb login in the console you are starting the script from (https://docs.wandb.ai/ref/cli/wandb-login, https://docs.wandb.ai/ref/python/init) 13 | # entity= your username 14 | wandb.init(project=experiment_name, entity=entity,config = cfg) 15 | 16 | 17 | def after_forward_pass(self, phase, loss, loss_kl, **kwargs): 18 | 19 | # / act as seperators. If you would like to log train and test separately you would log test loss in test/loss 20 | wandb.log({'train/loss': loss}, step=phase.iter_nr) 21 | if loss_kl!=0: 22 | wandb.log({'train/loss_kl': loss_kl}, step=phase.iter_nr) 23 | 24 | 25 | def epoch_ended(self, phase, **kwargs): 26 | pass -------------------------------------------------------------------------------- /checkpoints/HOW_TO_DOWNLOAD.txt: -------------------------------------------------------------------------------- 1 | If you want to download the checkpoints please use the download_checkpoints.sh script from the root of the repo. 
2 | It will download all necessary checkpoints and unzip them in this folder -------------------------------------------------------------------------------- /configs/config_scalp_texture_conditional.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "model": { 4 | "type": "image_transformer_v2_conditional", 5 | "cross_cond": 1, 6 | "condition_dropout_rate": 0.1, 7 | "input_channels": 65, 8 | "input_size": [256, 256], 9 | "patch_size": [2, 2], 10 | "depths": [2, 2, 2], 11 | "widths": [1024, 2048, 3072], 12 | "d_ffs": [1024, 2048, 3072], 13 | "self_attns": [ 14 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, 15 | {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, 16 | {"type": "global", "d_head": 64} 17 | ], 18 | "parametrization": "v", 19 | "loss_config": "karras", 20 | "loss_weighting": "snr", 21 | "dropout_rate": [0.0, 0.0, 0.1], 22 | "mapping_width": 768, 23 | "mapping_d_ff": 1024, 24 | "mapping_dropout_rate": 0.0, 25 | "augment_prob": 0.0, 26 | "sigma_data": 0.3, 27 | "sigma_min": 1e-2, 28 | "sigma_max": 160, 29 | "sigma_sample_density": { 30 | "type": "cosine-interpolated", 31 | "noise_d_low": 64 32 | }, 33 | 34 | "rgb_condition":{ 35 | "global_condition_shape": [1, 1024], 36 | "local_condition_shapes": [ 37 | {"shape": [1, 1024, 55, 55]}, 38 | {"shape": [1, 1024, 55, 55]}, 39 | {"shape": [1, 1024, 55, 55]} 40 | ], 41 | "cross_condition_dim": [512,512,512], 42 | "self_attn": [false,true,true] 43 | }, 44 | 45 | "loss_weight_per_channel":[ 46 | 0.05580984428524971, 0.030566953122615814, 0.025709286332130432, 0.05716482549905777, 0.023201454430818558, 0.08562534302473068, 0.017659854143857956, 0.07521561533212662, 0.0483829528093338, 0.017475251108407974, 0.021301859989762306, 0.006980733945965767, 0.06842825561761856, 0.16614384949207306, 0.028791628777980804, 0.07338058203458786, 0.08613209426403046, 0.01631935127079487, 0.015131217427551746, 0.06258893758058548, 0.058801501989364624, 0.05222005024552345, 0.03300714120268822, 0.028619959950447083, 0.066630519926548, 0.10768849402666092, 0.07185778766870499, 0.08246062695980072, 0.08738923817873001, 0.5705997943878174, 0.059733662754297256, 0.1032959371805191, 0.008284210227429867, 0.028203167021274567, 0.015574352815747261, 0.031722474843263626, 0.11103591322898865, 1.0, 0.33545804023742676, 0.02750393934547901, 0.037128742784261703, 0.04466118663549423, 0.0203529242426157, 0.018273167312145233, 0.140573188662529, 0.08498218655586243, 0.05887111648917198, 0.022760389372706413, 0.1844673603773117, 0.052569273859262466, 0.12437944859266281, 0.011494286358356476, 0.3124508857727051, 0.0704188272356987, 0.21793445944786072, 0.04560207203030586, 0.14500388503074646, 0.01105482131242752, 0.028649816289544106, 0.015449784696102142, 0.06804649531841278, 0.01901610754430294, 0.036636028438806534, 0.04171859472990036, 47 | 48 | 1.0 49 | ] 50 | }, 51 | "optimizer": { 52 | "lr": 5e-4, 53 | "betas": [0.9, 0.95], 54 | "eps": 1e-8, 55 | "weight_decay": 1e-3 56 | }, 57 | "lr_sched": { 58 | "type": "constant", 59 | "warmup": 0.0 60 | }, 61 | "ema_sched": { 62 | "type": "inverse", 63 | "power": 0.75, 64 | "max_value": 0.9999 65 | } 66 | } 67 | 68 | -------------------------------------------------------------------------------- /configs/strand_vae_train.toml: -------------------------------------------------------------------------------- 1 | [core] 2 | floor_type= "grid" 3 | floor_origin = [0,0,0] 4 | floor_scale = 30.0 5 | auto_create_logger=false 6 | log_level = "off" 7 | 8 | 
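# Note: this file appears to be the config for the interactive gloss viewer that can be enabled during strand-VAE training (see callbacks/viewer_callback.py); [core] above configures the floor grid and logging, and [render] below sets the MSAA sample count.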
[render] 9 | msaa_nr_samples = 4 10 | -------------------------------------------------------------------------------- /data_loader/difflocks_bodydata/scalp.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/data_loader/difflocks_bodydata/scalp.ply -------------------------------------------------------------------------------- /data_loader/difflocks_bodydata/smplx_base.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/data_loader/difflocks_bodydata/smplx_base.ply -------------------------------------------------------------------------------- /data_processing/create_chunked_strands.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | #from the full_strands.npz it creates chunked versions of it they can be more easily loaded by the trainer for strand vae 4 | 5 | # ./create_chunked_strands.py --dataset_path 6 | 7 | 8 | import sys 9 | import os 10 | import argparse 11 | import numpy as np 12 | from tqdm import tqdm 13 | from torch.utils.data import DataLoader 14 | import math 15 | sys.path.append( 16 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 17 | from data_loader.dataloader import DiffLocksDataset 18 | 19 | 20 | def main(): 21 | 22 | #argparse 23 | parser = argparse.ArgumentParser(description='Create latents') 24 | parser.add_argument('--dataset_path', required=True, help='Path to the hair_synth dataset') 25 | args = parser.parse_args() 26 | 27 | 28 | difflocks_dataset = DiffLocksDataset(args.dataset_path, 29 | check_validity=False, 30 | load_full_strands=True 31 | ) 32 | loader = DataLoader(difflocks_dataset, batch_size=1, num_workers=8, shuffle=False, pin_memory=True, persistent_workers=True) 33 | 34 | 35 | 36 | progress_bar = tqdm(range(0, len(difflocks_dataset)), desc="Training progress") 37 | 38 | for batch in loader: 39 | progress_bar.update() 40 | 41 | # print("batch",batch) 42 | 43 | path_hairstyle=batch["path"][0] 44 | print("path", path_hairstyle) 45 | 46 | positions=batch["full_strands"]["positions"].squeeze(0).cpu().numpy() 47 | root_uv=batch["full_strands"]["root_uv"].squeeze(0).cpu().numpy() 48 | root_normal=batch["full_strands"]["root_normal"].squeeze(0).cpu().numpy() 49 | # print("positions", positions.shape) 50 | # print("root_uv", root_uv.shape) 51 | # print("root_normal", root_normal.shape) 52 | 53 | 54 | select_nr_random_strands=10000 55 | nr_strands_per_chunk_list=[1000,100] 56 | 57 | #select random strands because the original 100K is too much 58 | if select_nr_random_strands: 59 | nr_strands_left = positions.shape[0] 60 | per_curve_keep_random = np.random.choice(nr_strands_left, select_nr_random_strands, replace=False) 61 | positions=positions[per_curve_keep_random,:,:] 62 | root_uv=root_uv[per_curve_keep_random,:] 63 | root_normal=root_normal[per_curve_keep_random,:] 64 | 65 | 66 | #break the strands into chunks if needed and write those too 67 | nr_strands_total = positions.shape[0] 68 | if nr_strands_per_chunk_list: 69 | for nr_strands_cur_chunk in nr_strands_per_chunk_list: 70 | nr_chunks = math.ceil(nr_strands_total/nr_strands_cur_chunk) 71 | 72 | positions_chunked = np.array_split(positions, nr_chunks,axis=0) 73 | root_uv_chunked = np.array_split(root_uv, nr_chunks,axis=0) 74 | root_normal_chunked = 
np.array_split(root_normal, nr_chunks,axis=0) 75 | 76 | #make path for this chunked data 77 | chunked_data_path=os.path.join(path_hairstyle,"full_strands_chunked","nr_strands_"+str(nr_strands_cur_chunk)) 78 | os.makedirs(chunked_data_path, exist_ok=True) 79 | 80 | #write each chunk 81 | for idx_chunk in range(nr_chunks): 82 | npz_path = os.path.join(chunked_data_path,str(idx_chunk)+".npz") 83 | np.savez(npz_path, positions=positions_chunked[idx_chunk],\ 84 | root_uv=root_uv_chunked[idx_chunk],\ 85 | root_normal=root_normal_chunked[idx_chunk]) 86 | 87 | 88 | return 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | 94 | -------------------------------------------------------------------------------- /data_processing/create_latents.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | #creates latent representations of all the rgb images in the dataset, this will be useful when using them to condition the diffusion model 4 | 5 | #./create_latents.py --subsample_factor 1 --dataset_path= --out_path 6 | 7 | 8 | import sys 9 | import os 10 | import argparse 11 | import torch 12 | import torchvision 13 | sys.path.append( 14 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 15 | from models.strand_codec import StrandCodec 16 | from torch.utils.data import DataLoader 17 | import utils.resize_right.resize_right as resize_right 18 | import utils.resize_right.interp_methods as interp_methods 19 | from tqdm import tqdm 20 | import torchvision.transforms as T 21 | from data_loader.dataloader import DiffLocksDataset 22 | 23 | 24 | 25 | torch.set_grad_enabled(False) 26 | 27 | 28 | def horizontally_flip(batch): 29 | 30 | rgb_img=batch["rgb_img"] 31 | rgb_img=torchvision.transforms.functional.hflip(rgb_img) 32 | batch["rgb_img"]=rgb_img 33 | 34 | return batch 35 | 36 | 37 | def generate_latents_dinov2(args, batch, preprocessor, model, output_latents_path): 38 | #encode img 39 | rgb_img=batch["rgb_img"].cuda() 40 | rgb_img=rgb_img[:,0:3,:,:] 41 | 42 | 43 | rgb_input = preprocessor(rgb_img).to("cuda") 44 | ret = model.forward_features(rgb_input) 45 | patch_tok = ret["x_norm_patchtokens"].clone() 46 | cls_tok = ret["x_norm_clstoken"].clone() 47 | 48 | # print("outputs",outputs) 49 | 50 | #only makes sense to write the last layer becuase it';s dinov2 and the other ones are not coarser representations 51 | cls_token=cls_tok 52 | # print("cls", cls_token.shape) 53 | patch_embeddings = patch_tok 54 | #reshape to [Batch_size, h, w, embedding] 55 | batch_size, num_patches, hidden_size = patch_embeddings.shape 56 | h = w = int(num_patches ** 0.5) # Assuming the number of patches is a perfect square (e.g., 14x14) 57 | patch_embeddings_reshaped = patch_embeddings.reshape(batch_size, h, w, hidden_size) 58 | patch_embeddings_reshaped=patch_embeddings_reshaped.permute(0,3,1,2).contiguous() #Make it bchw 59 | #write last layer 60 | out_path_final_latent=os.path.join(output_latents_path, "final_latent.pt") 61 | torch.save(patch_embeddings_reshaped, out_path_final_latent) 62 | #write cls token which is like an embedding for the whole image 63 | out_path_cls_token=os.path.join(output_latents_path, "cls_token.pt") 64 | #writing cls token of size 65 | # print("cls_token",cls_token.shape) 66 | torch.save(cls_token, out_path_cls_token) 67 | 68 | # feat_pca = img_2_pca(patch_embeddings_reshaped) 69 | # torchvision.utils.save_image(feat_pca.squeeze(0), os.path.join(output_latents_path, "final_latent.png")) 70 | 71 | 72 | 73 | #write 
a file to signify that we are done with this folder 74 | #start with x so that rsync copies it last if we copy to local 75 | open( os.path.join(output_latents_path, "x_done.txt"), 'a').close() 76 | 77 | 78 | def main(): 79 | 80 | #argparse 81 | parser = argparse.ArgumentParser(description='Create latents') 82 | parser.add_argument('--dataset_path', required=True, help='Path to the hair_synth dataset') 83 | parser.add_argument('--out_path', required=True, type=str, help='Where to output the processed hair_synth dataset') 84 | parser.add_argument('--subsample_factor', default=1, type=int, help='Subsample factor for the RGB img') 85 | parser.add_argument('--skip_validity_check', dest='check_validity', action='store_false', help='Wether to check for the validity of each hairstyle we read from the dataset. Some older dataset versions might need this turned to false') 86 | args = parser.parse_args() 87 | 88 | 89 | #v2 from torch 90 | image_size = int(768/(2**(args.subsample_factor-1))) 91 | print("Selected dino with img size", image_size) 92 | #going to the nearest multiple of 14 because 14 is the patch size 93 | if image_size==768: 94 | image_size=770 95 | else: 96 | print("I haven't implemented the other ones yet") 97 | latents_preprocessor = T.Compose([ 98 | T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC), 99 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 100 | ]) 101 | latents_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg') 102 | latents_model.cuda() 103 | 104 | latents_model.eval() 105 | 106 | 107 | 108 | 109 | print("args.check_validity",args.check_validity) 110 | 111 | difflocks_dataset = DiffLocksDataset(args.dataset_path, 112 | check_validity=args.check_validity, 113 | load_rgb_imgs=True, 114 | processed_difflocks_path = args.out_path, 115 | subsample_factor=args.subsample_factor, 116 | ) 117 | loader = DataLoader(difflocks_dataset, batch_size=1, num_workers=8, shuffle=False, pin_memory=True, persistent_workers=True) 118 | 119 | 120 | 121 | progress_bar = tqdm(range(0, len(difflocks_dataset)), desc="Training progress") 122 | 123 | for batch in loader: 124 | progress_bar.update() 125 | 126 | #make the output path 127 | output_latents_path=os.path.join(args.out_path, "processed_hairstyles", batch["file"][0], "latents_"+"dinov2"+"_subsample_"+str(args.subsample_factor)) 128 | os.makedirs(output_latents_path, exist_ok=True) 129 | #check if we already created this one 130 | if not os.path.isfile( os.path.join(output_latents_path,"x_done.txt")): 131 | # if True: 132 | #if it doesn't exist or we can't load it we create it 133 | generate_latents_dinov2(args, batch, latents_preprocessor, latents_model, output_latents_path) 134 | 135 | 136 | # #generate also a flipped texture, the reason being that just flipping the rgb does not result in a flipped latents neceserily so we have to horizontally flip the data in the batch then encode a new flipped latent 137 | #make the output path 138 | output_latents_path=os.path.join(args.out_path, "processed_hairstyles", batch["file"][0], "latents_flipped_"+"dinov2"+"_subsample_"+str(args.subsample_factor)) 139 | os.makedirs(output_latents_path, exist_ok=True) 140 | #check if we already created this one 141 | if not os.path.isfile( os.path.join(output_latents_path,"x_done.txt")): 142 | # if True: 143 | batch=horizontally_flip(batch) 144 | #if it doesn't exist or we can't load it we create it 145 | generate_latents_dinov2(args, batch, latents_preprocessor, latents_model, output_latents_path) 146 | 147 
| 148 | 149 | #finished training 150 | return 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | 156 | -------------------------------------------------------------------------------- /data_processing/uncompress_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | #given the difflocks dataset downloaded from mpi uncompress everything 4 | 5 | # ./uncompress_data.py --dataset_zipped_path --out_path 6 | 7 | 8 | 9 | import sys 10 | import os 11 | import argparse 12 | import subprocess 13 | from os import listdir 14 | from os.path import isfile, join 15 | 16 | 17 | def main(): 18 | 19 | #argparse 20 | parser = argparse.ArgumentParser(description='Uncompress dataset') 21 | parser.add_argument('--dataset_zipped_path', required=True, help='Path to the difflocks folder containing all zip files') 22 | parser.add_argument('--out_path', required=True, type=str, help='Where to output the difflocks dataset') 23 | args = parser.parse_args() 24 | 25 | 26 | in_path=args.dataset_zipped_path 27 | onlyfiles = [f for f in listdir(in_path) if isfile(join(in_path, f))] 28 | for file_name in onlyfiles: 29 | filepath=os.path.join(in_path,file_name) 30 | cmd=["7zz","x", "-y", filepath, "-o"+args.out_path] 31 | # print("cmd",cmd) 32 | subprocess.run(cmd, capture_output=False) 33 | print("filename", file_name) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | 39 | -------------------------------------------------------------------------------- /dockerfile/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Check args 4 | if [ "$#" -ne 0 ]; then 5 | echo "usage: ./build.sh" 6 | return 1 7 | fi 8 | 9 | 10 | # Build the docker image 11 | docker build\ 12 | --build-arg user=$USER\ 13 | --build-arg uid=$UID\ 14 | --build-arg home=$HOME\ 15 | --build-arg workspace=$SCRIPTPATH\ 16 | --build-arg shell=$SHELL\ 17 | -t difflocks_dock \ 18 | --progress=plain \ 19 | -f Dockerfile . 20 | 21 | -------------------------------------------------------------------------------- /dockerfile/nvidia-container-runtime-script.sh: -------------------------------------------------------------------------------- 1 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/gpgkey | \ 2 | sudo apt-key add - 3 | distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) 4 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/$distribution/nvidia-container-runtime.list | \ 5 | sudo tee /etc/apt/sources.list.d/nvidia-container-runtime.list 6 | sudo apt-get update 7 | sudo apt-get install -y nvidia-container-toolkit 8 | sudo systemctl restart docker -------------------------------------------------------------------------------- /dockerfile/nvidia-container-runtime-script_ubuntu24.sh: -------------------------------------------------------------------------------- 1 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ 2 | && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ 3 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ 4 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list 5 | sudo apt-get update 6 | sudo apt-get install -y nvidia-container-toolkit 7 | sudo systemctl restart docker -------------------------------------------------------------------------------- /dockerfile/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Check args 4 | if [ "$#" -ne 0 ]; then 5 | echo "usage: ./run.sh" 6 | exit 1 7 | fi 8 | 9 | # Get this script's path 10 | pushd `dirname $0` > /dev/null 11 | SCRIPTPATH=`pwd` 12 | popd > /dev/null 13 | 14 | set -e 15 | 16 | # --volume $SSH_AUTH_SOCK:/ssh-agent \ 17 | # --env SSH_AUTH_SOCK=/ssh-agent \ 18 | 19 | 20 | # for more info see: https://medium.com/@benjamin.botto/opengl-and-cuda-applications-in-docker-af0eece000f1 21 | # for more info see: https://gist.github.com/RafaelPalomar/f594933bb5c07184408c480184c2afb4 22 | # Run the container with shared X11 23 | docker run\ 24 | --rm \ 25 | --shm-size 12G\ 26 | --gpus all\ 27 | --net host\ 28 | --privileged\ 29 | -e SHELL\ 30 | -e DISPLAY\ 31 | -e DOCKER=1\ 32 | -v /dev:/dev\ 33 | --volume=/run/user/${USER_UID}/pulse:/run/user/1000/pulse \ 34 | -e PULSE_SERVER=unix:${XDG_RUNTIME_DIR}/pulse/native \ 35 | -v ${XDG_RUNTIME_DIR}/pulse/native:${XDG_RUNTIME_DIR}/pulse/native \ 36 | -v ~/.config/pulse/cookie:/root/.config/pulse/cookie \ 37 | --group-add $(getent group audio | cut -d: -f3) \ 38 | --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \ 39 | --volume="/etc/group:/etc/group:ro" \ 40 | --volume="/etc/passwd:/etc/passwd:ro" \ 41 | --volume="/etc/shadow:/etc/shadow:ro" \ 42 | -v "$HOME:$HOME:rw"\ 43 | --name difflocks_dock\ 44 | -it difflocks_dock:latest 45 | 46 | -------------------------------------------------------------------------------- /download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 3 | 4 | 5 | # username and password input 6 | echo -e "\nYou need to register at https://difflocks.is.tue.mpg.de/" 7 | read -p "Username (DIFFLOCKS):" username 8 | read -p "Password (DIFFLOCKS):" password 9 | username=$(urle $username) 10 | password=$(urle $password) 11 | 12 | echo -e "\nDownloading files..." 
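# The urle() helper defined at the top of this script percent-encodes the username and password so they can be sent safely as POST form data in the wget request below.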
13 | 14 | 15 | #checkpoints 16 | wget --post-data "username=$username&password=$password" \ 17 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=difflocks_checkpoints.zip" \ 18 | -O "difflocks_checkpoints.zip" --no-check-certificate --continue 19 | # Check if wget failed (non-zero exit code) 20 | if [ $? -ne 0 ]; then 21 | echo "❌ Error downloading body data. Exiting script." 22 | exit 1 23 | fi 24 | 25 | #unzip 26 | unzip difflocks_checkpoints.zip 27 | 28 | 29 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 3 | 4 | # Get the download path from the first argument 5 | TARGET_PATH="$1" 6 | 7 | # Check if the path was given 8 | if [ -z "$TARGET_PATH" ]; then 9 | echo "Usage: $0 " 10 | exit 1 11 | fi 12 | 13 | # Make sure the directory exists 14 | mkdir -p "$TARGET_PATH" 15 | 16 | # username and password input 17 | echo -e "\nYou need to register at https://difflocks.is.tue.mpg.de/" 18 | read -p "Username (DIFFLOCKS):" username 19 | read -p "Password (DIFFLOCKS):" password 20 | username=$(urle $username) 21 | password=$(urle $password) 22 | 23 | echo -e "\nDownloading files..." 24 | 25 | 26 | #BODY DATA 27 | filename='difflocks_dataset_body_data.7z' 28 | wget --post-data "username=$username&password=$password" \ 29 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=$filename" \ 30 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue 31 | # Check if wget failed (non-zero exit code) 32 | if [ $? -ne 0 ]; then 33 | echo "❌ Error downloading body data. Exiting script." 34 | exit 1 35 | fi 36 | 37 | #IMGS 38 | for i in $(seq 0 20); do 39 | filename="difflocks_dataset_imgs_chunk_${i}.7z" 40 | # Run wget and capture status 41 | wget --post-data "username=$username&password=$password" \ 42 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=${filename}" \ 43 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue 44 | # Check if wget failed (non-zero exit code) 45 | if [ $? -ne 0 ]; then 46 | echo "❌ Error downloading $filename. Exiting script." 47 | exit 1 48 | fi 49 | done 50 | 51 | #IMGS v2 52 | for i in $(seq 0 20); do 53 | filename="difflocks_dataset_imgs_v2_chunk_${i}.7z" 54 | # Run wget and capture status 55 | wget --post-data "username=$username&password=$password" \ 56 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=${filename}" \ 57 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue 58 | # Check if wget failed (non-zero exit code) 59 | if [ $? -ne 0 ]; then 60 | echo "❌ Error downloading $filename. Exiting script." 61 | exit 1 62 | fi 63 | done 64 | 65 | #HAIRSTYLES 66 | for i in $(seq 0 4455); do 67 | filename="difflocks_dataset_hairstyles_chunk_${i}.7z" 68 | # Run wget and capture status 69 | wget --post-data "username=$username&password=$password" \ 70 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=${filename}" \ 71 | -O "$TARGET_PATH/$filename" --no-check-certificate --continue 72 | # Check if wget failed (non-zero exit code) 73 | if [ $? -ne 0 ]; then 74 | echo "❌ Error downloading $filename. Exiting script." 
75 | exit 1 76 | fi 77 | done 78 | 79 | 80 | -------------------------------------------------------------------------------- /download_validation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 3 | 4 | 5 | # username and password input 6 | echo -e "\nYou need to register at https://difflocks.is.tue.mpg.de/" 7 | read -p "Username (DIFFLOCKS):" username 8 | read -p "Password (DIFFLOCKS):" password 9 | username=$(urle $username) 10 | password=$(urle $password) 11 | 12 | echo -e "\nDownloading files..." 13 | 14 | 15 | #checkpoints 16 | wget --post-data "username=$username&password=$password" \ 17 | "https://download.is.tue.mpg.de/download.php?domain=difflocks&sfile=difflocks_validation.7z" \ 18 | -O "difflocks_validation.7z" --no-check-certificate --continue 19 | # Check if wget failed (non-zero exit code) 20 | if [ $? -ne 0 ]; then 21 | echo "❌ Error downloading body data. Exiting script." 22 | exit 1 23 | fi 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /extensions/include/cuMat/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Header-only library 2 | 3 | # Define 'cuMat_headers' variable to list all the header files 4 | # Todo: maybe replace by GLOB 5 | set(CUMAT_HEADERS 6 | src/Macros.h 7 | src/ForwardDeclarations.h 8 | src/Profiling.h 9 | src/Logging.h 10 | src/Errors.h 11 | src/Constants.h 12 | src/Context.h 13 | src/Allocator.h 14 | src/NumTraits.h 15 | src/DevicePointer.h 16 | src/EigenInteropHelpers.h 17 | src/MatrixBase.h 18 | src/Matrix.h 19 | src/IO.h 20 | src/CwiseOp.h 21 | src/MatrixBlock.h 22 | src/MatrixBlockPluginLvalue.inl 23 | src/MatrixBlockPluginRvalue.inl 24 | src/NullaryOps.h 25 | src/MatrixNullaryOpsPlugin.inl 26 | src/UnaryOps.h 27 | src/UnaryOpsPlugin.inl 28 | src/CudaUtils.h 29 | src/TransposeOp.h 30 | src/BinaryOps.h 31 | src/BinaryOpsPlugin.inl 32 | src/ReductionOps.h 33 | src/ReductionOpsPlugin.inl 34 | src/ReductionAlgorithmSelection.h 35 | src/Iterator.h 36 | src/CublasApi.h 37 | src/SimpleRandom.h 38 | src/ProductOp.h 39 | Core 40 | 41 | src/CusolverApi.h 42 | src/SolverBase.h 43 | src/DecompositionBase.h 44 | src/LUDecomposition.h 45 | src/CholeskyDecomposition.h 46 | src/DenseLinAlgOps.h 47 | src/DenseLinAlgPlugin.inl 48 | Dense 49 | 50 | src/SparseMatrixBase.h 51 | src/SparseMatrix.h 52 | src/SparseEvaluation.h 53 | src/SparseProductEvaluation.h 54 | src/SparseExpressionOp.h 55 | src/SparseExpressionOpPlugin.inl 56 | Sparse 57 | 58 | src/IterativeSolverBase.h 59 | src/ConjugateGradient.h 60 | IterativeLinearSolvers 61 | ) 62 | 63 | add_library(cuMat INTERFACE) # 'moduleA' is an INTERFACE pseudo target 64 | 65 | # 66 | # From here, the target 'moduleA' can be customised 67 | # 68 | target_include_directories(cuMat INTERFACE ${CMAKE_SOURCE_DIR}) # Transitively forwarded 69 | target_include_directories(cuMat INTERFACE ${CUDA_INCLUDE_DIRS}) 70 | #install(TARGETS cuMat ...) 
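# Note: since cuMat is an INTERFACE (header-only) target, a consuming project only needs target_link_libraries(<consumer> cuMat) to inherit the include directories declared above.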
71 | 72 | # 73 | # HACK: have the files showing in the IDE, under the name 'moduleA_ide' 74 | # 75 | option(CUMAT_SOURCE_LIBRARY "show the source headers as a library in the IDE" OFF) 76 | if(CUMAT_SOURCE_LIBRARY) 77 | cuda_add_library(cuMat_ide ${CUMAT_HEADERS}) 78 | target_include_directories(cuMat_ide INTERFACE ${CMAKE_SOURCE_DIR}) # Transitively forwarded 79 | target_include_directories(cuMat_ide INTERFACE ${CUDA_INCLUDE_DIRS}) 80 | endif(CUMAT_SOURCE_LIBRARY) 81 | -------------------------------------------------------------------------------- /extensions/include/cuMat/Core: -------------------------------------------------------------------------------- 1 | 2 | 3 | /* 4 | Specifies the core modules. 5 | This enables all BLAS-related modules 6 | */ 7 | 8 | //first thing: tell the compiler to ignore some stuid warnings 9 | #include "src/DisableCompilerWarnings.h" 10 | 11 | #include "src/Macros.h" 12 | #include "src/ForwardDeclarations.h" 13 | #include "src/NumTraits.h" 14 | 15 | #include "src/Logging.h" 16 | #include "src/Errors.h" 17 | #include "src/Context.h" 18 | 19 | #include "src/MatrixBase.h" 20 | #include "src/Matrix.h" 21 | #include "src/MatrixBlock.h" 22 | #include "src/EigenInteropHelpers.h" 23 | #include "src/IO.h" 24 | 25 | #include "src/Iterator.h" 26 | #include "src/NullaryOps.h" 27 | #include "src/UnaryOps.h" 28 | #include "src/TransposeOp.h" 29 | #include "src/BinaryOps.h" 30 | #include "src/ReductionOps.h" 31 | #include "src/ProductOp.h" 32 | 33 | #include "src/SimpleRandom.h" 34 | -------------------------------------------------------------------------------- /extensions/include/cuMat/Dense: -------------------------------------------------------------------------------- 1 | 2 | /* 3 | * Specifies dense linear algebra routines (LAPACK) 4 | * This needs cuSolver. 5 | */ 6 | 7 | // Core is always needed 8 | #include "Core" 9 | 10 | #include "src/CusolverApi.h" 11 | #include "src/LUDecomposition.h" 12 | #include "src/CholeskyDecomposition.h" 13 | #include "src/DenseLinAlgOps.h" 14 | -------------------------------------------------------------------------------- /extensions/include/cuMat/IterativeLinearSolvers: -------------------------------------------------------------------------------- 1 | /* 2 | * Specifies iterative linear solver routines. 3 | */ 4 | 5 | // Core is always needed 6 | #include "Core" 7 | 8 | #include "src/IterativeSolverBase.h" 9 | #include "src/ConjugateGradient.h" 10 | -------------------------------------------------------------------------------- /extensions/include/cuMat/Sparse: -------------------------------------------------------------------------------- 1 | /* 2 | * Specifies sparse linear algebra routines 3 | * This needs cuSparse. 
4 | */ 5 | 6 | // Core is always needed 7 | #include "Core" 8 | 9 | // sparse classes 10 | #include "src/SparseMatrixBase.h" 11 | #include "src/SparseMatrix.h" 12 | #include "src/SparseExpressionOp.h" 13 | #include "src/SparseEvaluation.h" 14 | #include "src/SparseProductEvaluation.h" 15 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/Allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_ALLOCATOR_H__ 2 | #define __CUMAT_ALLOCATOR_H__ 3 | 4 | #include "Macros.h" 5 | 6 | //TODO: add interface to support custom allocators 7 | //Up to now, the allocator from CUB is hard-coded 8 | 9 | CUMAT_NAMESPACE_BEGIN 10 | 11 | class AllocatorBase 12 | { 13 | 14 | }; 15 | 16 | CUMAT_NAMESPACE_END 17 | 18 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/BinaryOpsPlugin.inl: -------------------------------------------------------------------------------- 1 | //Included inside MatrixBase, define the accessors 2 | 3 | #define BINARY_OP_ACCESSOR(Name) \ 4 | template \ 5 | BinaryOp<_Derived, _Right, functor::BinaryMathFunctor_ ## Name > Name (const MatrixBase<_Right>& rhs) const { \ 6 | CUMAT_ERROR_IF_NO_NVCC(Name) \ 7 | return BinaryOp<_Derived, _Right, functor::BinaryMathFunctor_ ## Name >(derived(), rhs.derived()); \ 8 | } \ 9 | template::value, \ 10 | BinaryOp<_Derived, HostScalar, functor::BinaryMathFunctor_ ## Name > >::type > \ 11 | T Name(const _Right& rhs) const { \ 12 | CUMAT_ERROR_IF_NO_NVCC(Name) \ 13 | return BinaryOp<_Derived, HostScalar, functor::BinaryMathFunctor_ ## Name >(derived(), HostScalar(rhs)); \ 14 | } 15 | #define BINARY_OP_ACCESSOR_INV(Name) \ 16 | template \ 17 | BinaryOp<_Left, _Derived, functor::BinaryMathFunctor_ ## Name > Name ## Inv(const MatrixBase<_Left>& lhs) const { \ 18 | CUMAT_ERROR_IF_NO_NVCC(Name) \ 19 | return BinaryOp<_Left, _Derived, functor::BinaryMathFunctor_ ## Name >(lhs.derived(), derived()); \ 20 | } \ 21 | template::value, \ 22 | BinaryOp, _Derived, functor::BinaryMathFunctor_ ## Name > >::type > \ 23 | T Name ## Inv(const _Left& lhs) const { \ 24 | CUMAT_ERROR_IF_NO_NVCC(Name) \ 25 | return BinaryOp, _Derived, functor::BinaryMathFunctor_ ## Name >(HostScalar(lhs), derived()); \ 26 | } 27 | 28 | /** 29 | * \brief computes the component-wise multiplation (this*rhs) 30 | */ 31 | BINARY_OP_ACCESSOR(cwiseMul) 32 | 33 | /** 34 | * \brief computes the component-wise dot-product. 35 | * The dot product of every individual element is computed, 36 | * for regular matrices/vectors with scalar entries, 37 | * this is exactly equal to the component-wise multiplication (\ref cwiseMul). 38 | * However, one can also use Matrix with submatrices/vectors as entries 39 | * and then this operation might have the real dot-product 40 | * if the respective functor \ref functor::BinaryMathFunctor_cwiseMDot is specialized. 
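 *
 * A minimal usage sketch (the matrix alias, sizes and constructor arguments
 * below are illustrative assumptions, not taken from this file):
 * \code
 * using Mat = Matrix<float, Dynamic, Dynamic, 1, ColumnMajor>;
 * Mat a(10, 10, 1), b(10, 10, 1);
 * auto d = a.cwiseDot(b).eval(); // for scalar entries this equals a.cwiseMul(b)
 * \endcode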
41 | */ 42 | BINARY_OP_ACCESSOR(cwiseDot) 43 | 44 | /** 45 | * \brief computes the component-wise division (this/rhs) 46 | */ 47 | BINARY_OP_ACCESSOR(cwiseDiv) 48 | 49 | /** 50 | * \brief computes the inverted component-wise division (rhs/this) 51 | */ 52 | BINARY_OP_ACCESSOR_INV(cwiseDiv) 53 | 54 | /** 55 | * \brief computes the component-wise exponent (this^rhs) 56 | */ 57 | BINARY_OP_ACCESSOR(cwisePow) 58 | 59 | /** 60 | * \brief computes the inverted component-wise exponent (rhs^this) 61 | */ 62 | BINARY_OP_ACCESSOR_INV(cwisePow) 63 | 64 | /** 65 | * \brief computes the component-wise binary AND (this & rhs). 66 | * Only available for integer matrices. 67 | */ 68 | BINARY_OP_ACCESSOR(cwiseBinaryAnd) 69 | 70 | /** 71 | * \brief computes the component-wise binary OR (this | rhs). 72 | * Only available for integer matrices. 73 | */ 74 | BINARY_OP_ACCESSOR(cwiseBinaryOr) 75 | 76 | /** 77 | * \brief computes the component-wise binary XOR (this ^ rhs). 78 | * Only available for integer matrices. 79 | */ 80 | BINARY_OP_ACCESSOR(cwiseBinaryXor) 81 | 82 | /** 83 | * \brief computes the component-wise logical AND (this && rhs). 84 | * Only available for boolean matrices 85 | */ 86 | BINARY_OP_ACCESSOR(cwiseLogicalAnd) 87 | 88 | /** 89 | * \brief computes the component-wise logical OR (this || rhs). 90 | * Only available for boolean matrices 91 | */ 92 | BINARY_OP_ACCESSOR(cwiseLogicalOr) 93 | 94 | /** 95 | * \brief computes the component-wise logical XOR (this ^ rhs). 96 | * Only available for boolean matrices 97 | */ 98 | BINARY_OP_ACCESSOR(cwiseLogicalXor) 99 | 100 | #undef BINARY_OP_ACCESSOR 101 | #undef BINARY_OP_ACCESSOR_INV 102 | 103 | /** 104 | * \brief Custom binary expression. 105 | * The binary functor must support look as follow: 106 | * \code 107 | * struct MyFunctor 108 | * { 109 | * typedef OutputType ReturnType; 110 | * __device__ CUMAT_STRONG_INLINE ReturnType operator()(const LeftType& x, const RightType& y, Index row, Index col, Index batch) const 111 | * { 112 | * return ... 113 | * } 114 | * }; 115 | * \endcode 116 | * where \c LeftType is the type of this matrix expression, 117 | * \c RightType is the type of the rhs matrix, 118 | * and \c OutputType is the output type. 119 | * 120 | * \param rhs the matrix expression on the right hand side 121 | * \param functor the functor to apply component-wise 122 | * \return an expression of a component-wise binary expression with a custom functor applied per component. 123 | */ 124 | template 125 | UnaryOp<_Derived, Functor> binaryExpr(const Right& rhs, const Functor& functor = Functor()) const 126 | { 127 | CUMAT_ERROR_IF_NO_NVCC(binaryExpr) 128 | return BinaryOp<_Derived, Right, Functor>(derived(), rhs.derived(), functor); 129 | } -------------------------------------------------------------------------------- /extensions/include/cuMat/src/Constants.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_CONSTANTS_H__ 2 | #define __CUMAT_CONSTANTS_H__ 3 | 4 | 5 | #include "Macros.h" 6 | 7 | CUMAT_NAMESPACE_BEGIN 8 | 9 | /** 10 | * \brief This value means that a positive quantity (e.g., a size) is not known at compile-time, 11 | * and that instead the value is stored in some runtime variable. 12 | */ 13 | const int Dynamic = -1; 14 | 15 | /** 16 | * \brief The bit flags for the matrix expressions. 17 | */ 18 | enum Flags 19 | { 20 | /** 21 | * \brief The storage is column major (the default). 22 | */ 23 | ColumnMajor = 0x00, 24 | /** 25 | * \brief The storage is row major. 
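 * (The CUMAT_IS_ROW_MAJOR / CUMAT_IS_COLUMN_MAJOR macros defined directly
 * below test for this flag.)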
26 | */ 27 | RowMajor = 0x01, 28 | 29 | }; 30 | #define CUMAT_IS_COLUMN_MAJOR(flags) (((flags) & CUMAT_NAMESPACE Flags::RowMajor)==0) 31 | #define CUMAT_IS_ROW_MAJOR(flags) ((flags) & CUMAT_NAMESPACE Flags::RowMajor) 32 | 33 | /** 34 | * \brief Flags that specify how the data in a MatrixBase-expression can be accessed. 35 | */ 36 | enum AccessFlags 37 | { 38 | /** 39 | * \brief component-wise read is available. 40 | * The following method must be provided: 41 | * \code 42 | * __device__ const Scalar& coeff(Index row, Index col, Index batch, Index index) const; 43 | * \endcode 44 | * The parameter \c index is the same as the linear index from the writing procedure and is used 45 | * by optimized routines only if the user explicitly enables them. 46 | * (This is only supported by SparseMatrix yet) 47 | * If some operation can't pass the linear index to the expressions, -1 might be used instead. 48 | */ 49 | ReadCwise = 0x01, 50 | /** 51 | * \brief direct read is available, the underlying memory is directly adressable. 52 | * The following method must be provided: 53 | * \code 54 | * __host__ __device__ const _Scalar* data() const; 55 | * __host__ bool isExclusiveUse() const; 56 | * \endcode 57 | */ 58 | ReadDirect = 0x02, 59 | /** 60 | * \brief Component-wise read is available. 61 | * To allow the implementation to specify the access order, the following methods have to be provided: 62 | * \code 63 | * __host__ Index size() const; 64 | * __device__ void index(Index index, Index& row, Index& col, Index& batch) const; 65 | * __device__ void setRawCoeff(Index index, const Scalar& newValue); 66 | * \endcode 67 | */ 68 | WriteCwise = 0x10, 69 | /** 70 | * \brief Direct write is available, the underlying memory can be directly written to. 71 | * The following method must be provided: 72 | * \code 73 | * __host__ __device__ _Scalar* data() 74 | * \endcode 75 | */ 76 | WriteDirect = 0x20, 77 | /** 78 | * \brief This extends \c WriteCwise and allows inplace modifications (compound operators) by additionally providing the function 79 | * \code __device__ const Scalar& getRawCoeff(Index index) const; \endcode . 80 | * To enable compound assignment with this as target type, either RWCwise or RWCwiseRef (or both) have to be defined. 81 | */ 82 | RWCwise = 0x40, 83 | /** 84 | * \brief This extends \c WriteCwise and allows inplace modifications (compound operators) by additionally providing the function 85 | * \code __device__ Scalar& rawCoeff(Index index); \endcode for read-write access to that entry. 86 | * To enable compound assignment with this as target type, either RWCwise or RWCwiseRef (or both) have to be defined. 87 | */ 88 | RWCwiseRef = 0x80, 89 | }; 90 | 91 | /** 92 | * \brief The axis over which reductions are performed. 93 | */ 94 | enum Axis 95 | { 96 | NoAxis = 0, 97 | Row = 1, 98 | Column = 2, 99 | Batch = 4, 100 | All = Row | Column | Batch 101 | }; 102 | 103 | /** 104 | * \brief Tags for the different reduction algorithms 105 | */ 106 | namespace ReductionAlg 107 | { 108 | /** 109 | * \brief reduction with cub::DeviceSegmentedReduce 110 | */ 111 | struct Segmented {}; 112 | /** 113 | * \brief Thread reduction. Each thread reduces a batch. 114 | */ 115 | struct Thread {}; 116 | /** 117 | * \brief Warp reduction. Each warp reduces a batch. 118 | */ 119 | struct Warp {}; 120 | /** 121 | * \brief Block reduction. Each block reduces a batch. 122 | * \tparam N the block size 123 | */ 124 | template 125 | struct Block {}; 126 | /** 127 | * \brief Device reduction. 
128 | * Each reduction per batch is computed with a separate call to cub::DeviceReduce, 129 | * parallelized over N cuda streams. 130 | * \tparam N the number of parallel streams 131 | */ 132 | template 133 | struct Device {}; 134 | /** 135 | * \brief Automatic algorithm selection. 136 | * Chooses the algorithm during runtime based on the matrix sizes. 137 | */ 138 | struct Auto {}; 139 | } 140 | 141 | /** 142 | * \brief Specifies the assignment mode in \c Assignment::assign() . 143 | * This is the difference between regular assignment (operator==, \c AssignmentMode::ASSIGN) 144 | * and inplace modifications like operator+= (\c AssignmentMode::ADD). 145 | * 146 | * Note that not all assignment modes have to be supported for all scalar types 147 | * and all right hand sides. 148 | * For example: 149 | * - MUL (=*) and DIV (=\) are only supported for scalar right hand sides (broadcasting) 150 | * to avoid the ambiguity if component-wise or matrix operations are meant 151 | * - MOD (%=), AND (&=), OR (|=) are only supported for integer types 152 | */ 153 | enum class AssignmentMode 154 | { 155 | ASSIGN, 156 | ADD, 157 | SUB, 158 | MUL, 159 | DIV, 160 | MOD, 161 | AND, 162 | OR, 163 | }; 164 | 165 | /** 166 | * \brief Flags for the SparseMatrix class. 167 | */ 168 | enum SparseFlags 169 | { 170 | /** 171 | * \brief Matrix stored in the Compressed Sparse Column format. 172 | */ 173 | CSC = 1, 174 | /** 175 | * \brief Matrix stored in the Compressed Sparse Row format. 176 | */ 177 | CSR = 2, 178 | /** 179 | * \brief Column-major ELLPACK format. 180 | * This format is optimized for matrices with uniform nnz per row. 181 | */ 182 | ELLPACK = 3, 183 | }; 184 | 185 | CUMAT_NAMESPACE_END 186 | 187 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/CudaUtils.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_CUDA_UTILS_H__ 2 | #define __CUMAT_CUDA_UTILS_H__ 3 | 4 | #include "Macros.h" 5 | 6 | // SOME CUDA UTILITIES 7 | 8 | CUMAT_NAMESPACE_BEGIN 9 | 10 | namespace cuda 11 | { 12 | /** 13 | * \brief Loads the given scalar value from the specified address, 14 | * possible cached. 
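 *
 * A minimal usage sketch inside device code (kernel and parameter names are
 * illustrative only):
 * \code
 * __global__ void copyFirst(const float* src, float* dst) { dst[0] = cuda::load(src); }
 * \endcode
 *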
15 | * \tparam T the type of scalar 16 | * \param ptr the pointer to the scalar 17 | * \return the value at that adress 18 | */ 19 | template 20 | __device__ CUMAT_STRONG_INLINE T load(const T* ptr) 21 | { 22 | //#if __CUDA_ARCH__ >= 350 23 | // return __ldg(ptr); 24 | //#else 25 | return *ptr; 26 | //#endif 27 | } 28 | #if __CUDA_ARCH__ >= 350 29 | #define LOAD(T) \ 30 | template<> \ 31 | __device__ CUMAT_STRONG_INLINE T load(const T* ptr) \ 32 | { \ 33 | return __ldg(ptr); \ 34 | } 35 | #else 36 | #define LOAD(T) 37 | #endif 38 | 39 | LOAD(char); 40 | LOAD(short); 41 | LOAD(int); 42 | LOAD(long); 43 | LOAD(long long); 44 | LOAD(unsigned char); 45 | LOAD(unsigned short); 46 | LOAD(unsigned int); 47 | LOAD(unsigned long); 48 | LOAD(unsigned long long); 49 | LOAD(int2); 50 | LOAD(int4); 51 | LOAD(uint2); 52 | LOAD(uint4); 53 | LOAD(float); 54 | LOAD(float2); 55 | LOAD(float4); 56 | LOAD(double); 57 | LOAD(double2); 58 | 59 | #undef LOAD 60 | 61 | } 62 | 63 | CUMAT_NAMESPACE_END 64 | 65 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/CusolverApi.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/extensions/include/cuMat/src/CusolverApi.h -------------------------------------------------------------------------------- /extensions/include/cuMat/src/CwiseOp.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_CWISE_OP_H__ 2 | #define __CUMAT_CWISE_OP_H__ 3 | 4 | #include "Macros.h" 5 | #include "ForwardDeclarations.h" 6 | #include "MatrixBase.h" 7 | #include "Context.h" 8 | #include "Logging.h" 9 | #include "Profiling.h" 10 | 11 | CUMAT_NAMESPACE_BEGIN 12 | 13 | namespace internal { 14 | //compound assignment 15 | template 16 | struct CwiseAssignmentHandler 17 | { 18 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index) 19 | { 20 | matrix.setRawCoeff(index, value); 21 | } 22 | }; 23 | 24 | #define CUMAT_DECLARE_ASSIGNMENT_HANDLER(mode, op1, op2) \ 25 | template \ 26 | struct CwiseAssignmentHandler \ 27 | { \ 28 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index, std::integral_constant) /*reference access*/ \ 29 | { \ 30 | matrix.rawCoeff(index) op1 value; \ 31 | } \ 32 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index, std::integral_constant) /*read-write access*/ \ 33 | { \ 34 | matrix.setRawCoeff(index) = matrix.getRawCoeff(index) op2 value; \ 35 | } \ 36 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index, std::integral_constant) /*both is possible, use reference*/ \ 37 | { \ 38 | matrix.rawCoeff(index) op1 value; \ 39 | } \ 40 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index) \ 41 | { \ 42 | assign(matrix, value, index, std::integral_constant::AccessFlags & (AccessFlags::RWCwiseRef | AccessFlags::RWCwise)>()); \ 43 | } \ 44 | }; 45 | 46 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(ADD, +=, +) 47 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(SUB, -=, -) 48 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(MUL, *=, *) 49 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(DIV, /=, /) 50 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(MOD, %=, %) 51 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(AND, &=, &) 52 | CUMAT_DECLARE_ASSIGNMENT_HANDLER(OR, |=, |) 53 | 54 | //partial specialization for normal assignment 55 | template 56 | struct CwiseAssignmentHandler 
57 | { 58 | static __device__ CUMAT_STRONG_INLINE void assign(M& matrix, V value, Index index) 59 | { 60 | matrix.setRawCoeff(index, value); 61 | } 62 | }; 63 | 64 | namespace kernels 65 | { 66 | template 67 | __global__ void CwiseEvaluationKernel(dim3 virtual_size, const T expr, M matrix) 68 | { 69 | //By using a 1D-loop over the linear index, 70 | //the target matrix can determine the order of rows, columns and batches. 71 | //E.g. by storage order (row major / column major) 72 | //Later, this may come in hand if sparse matrices or diagonal matrices are allowed 73 | //that only evaluate certain elements. 74 | CUMAT_KERNEL_1D_LOOP(index, virtual_size) 75 | 76 | Index i, j, k; 77 | matrix.index(index, i, j, k); 78 | 79 | //there seems to be a bug in CUDA if the result of expr.coeff is directly passed to setRawCoeff. 80 | //By saving it in a local variable, this is prevented 81 | auto val = expr.coeff(i, j, k, index); 82 | internal::CwiseAssignmentHandler::assign(matrix, val, index); 83 | 84 | CUMAT_KERNEL_1D_LOOP_END 85 | } 86 | } 87 | } 88 | 89 | /** 90 | * \brief Base class of all component-wise expressions. 91 | * It defines the evaluation logic. 92 | * 93 | * A component-wise expression can be evaluated to any object that 94 | * - inherits MatrixBase 95 | * - defines a __host__ Index size() const method that returns the number of entries 96 | * - defines a __device__ void index(Index index, Index& row, Index& col, Index& batch) const 97 | * method to convert from raw index (from 0 to size()-1) to row, column and batch index 98 | * - defines a __Device__ void setRawCoeff(Index index, const Scalar& newValue) method 99 | * that is used to write the results back. 100 | * 101 | * Currently, the following classes support this interface and can therefore be used 102 | * as the left-hand-side of a component-wise expression: 103 | * - Matrix 104 | * - MatrixBlock 105 | * 106 | * \tparam _Derived the type of the derived expression 107 | */ 108 | template 109 | class CwiseOp : public MatrixBase<_Derived> 110 | { 111 | public: 112 | typedef _Derived Type; 113 | typedef MatrixBase<_Derived> Base; 114 | CUMAT_PUBLIC_API 115 | using Base::rows; 116 | using Base::cols; 117 | using Base::batches; 118 | using Base::size; 119 | 120 | __device__ CUMAT_STRONG_INLINE const Scalar& coeff(Index row, Index col, Index batch, Index index) const 121 | { 122 | return derived().coeff(row, col, batch, index); 123 | } 124 | 125 | }; 126 | 127 | namespace internal 128 | { 129 | //General assignment for everything that fullfills CwiseSrcTag into DenseDstTag (cwise dense evaluation) 130 | template 131 | struct Assignment<_Dst, _Src, _Mode, DenseDstTag, CwiseSrcTag> 132 | { 133 | static void assign(_Dst& dst, const _Src& src) 134 | { 135 | #if CUMAT_NVCC==1 136 | typedef typename _Dst::Type DstActual; 137 | typedef typename _Src::Type SrcActual; 138 | CUMAT_PROFILING_INC(EvalCwise); 139 | CUMAT_PROFILING_INC(EvalAny); 140 | if (dst.size() == 0) return; 141 | CUMAT_ASSERT(src.rows() == dst.rows()); 142 | CUMAT_ASSERT(src.cols() == dst.cols()); 143 | CUMAT_ASSERT(src.batches() == dst.batches()); 144 | 145 | CUMAT_LOG_DEBUG("Evaluate component wise expression " << typeid(src.derived()).name() 146 | << "\n rows=" << src.rows() << ", cols=" << src.cols() << ", batches=" << src.batches()); 147 | 148 | //here is now the real logic 149 | Context& ctx = Context::current(); 150 | KernelLaunchConfig cfg = ctx.createLaunchConfig1D(static_cast(dst.size()), kernels::CwiseEvaluationKernel); 151 | kernels::CwiseEvaluationKernel <<>> 
(cfg.virtual_size, src.derived(), dst.derived()); 152 | CUMAT_CHECK_ERROR(); 153 | CUMAT_LOG_DEBUG("Evaluation done"); 154 | #else 155 | CUMAT_ERROR_IF_NO_NVCC(general_component_wise_evaluation) 156 | #endif 157 | } 158 | }; 159 | } 160 | 161 | CUMAT_NAMESPACE_END 162 | 163 | #endif 164 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/DecompositionBase.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_DECOMPOSITION_BASE_H__ 2 | #define __CUMAT_DECOMPOSITION_BASE_H__ 3 | 4 | #include "Macros.h" 5 | #include "SolverBase.h" 6 | 7 | CUMAT_NAMESPACE_BEGIN 8 | 9 | template 10 | class DecompositionBase : public SolverBase<_DecompositionImpl> 11 | { 12 | public: 13 | using Base = SolverBase<_DecompositionImpl>; 14 | using typename Base::Scalar; 15 | using Base::Rows; 16 | using Base::Columns; 17 | using Base::Batches; 18 | using Base::impl; 19 | 20 | typedef SolveOp<_DecompositionImpl, NullaryOp > > InverseResultType; 21 | /** 22 | * \brief Computes the inverse of the input matrix. 23 | * \return The inverse matrix 24 | */ 25 | InverseResultType inverse() const 26 | { 27 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(Rows > 0 && Columns > 0, Rows == Columns), 28 | "Static count of rows and columns must be equal (square matrix)"); 29 | CUMAT_ASSERT(impl().rows() == impl().cols()); 30 | 31 | return impl().solve(NullaryOp >( 32 | impl().rows(), impl().cols(), impl().batches(), functor::IdentityFunctor())); 33 | } 34 | 35 | typedef Matrix DeterminantMatrix; 36 | 37 | /** 38 | * \brief Computes the determinant of this matrix 39 | * \return The determinant 40 | */ 41 | DeterminantMatrix determinant() const 42 | { 43 | return impl().determinant(); 44 | } 45 | /** 46 | * \brief Computes the log-determinant of this matrix. 47 | * \return The log-determinant 48 | */ 49 | DeterminantMatrix logDeterminant() const 50 | { 51 | return impl().logDeterminant(); 52 | } 53 | }; 54 | 55 | CUMAT_NAMESPACE_END 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/DenseLinAlgPlugin.inl: -------------------------------------------------------------------------------- 1 | //Included inside MatrixBase, define the accessors 2 | 3 | /** 4 | * \brief Computes and returns the LU decomposition with pivoting of this matrix. 5 | * The resulting decomposition can then be used to compute the determinant of the matrix, 6 | * invert the matrix and solve multiple linear equation systems. 7 | */ 8 | LUDecomposition<_Derived> decompositionLU() const 9 | { 10 | return LUDecomposition<_Derived>(derived()); 11 | } 12 | 13 | /** 14 | * \brief Computes and returns the Cholesky decompositionof this matrix. 15 | * The matrix must be Hermetian and positive definite. 16 | * The resulting decomposition can then be used to compute the determinant of the matrix, 17 | * invert the matrix and solve multiple linear equation systems. 18 | */ 19 | CholeskyDecomposition<_Derived> decompositionCholesky() const 20 | { 21 | return CholeskyDecomposition<_Derived>(derived()); 22 | } 23 | 24 | /** 25 | * \brief Computes the determinant of this matrix. 26 | * \return the determinant of this matrix 27 | */ 28 | DeterminantOp<_Derived> determinant() const 29 | { 30 | return DeterminantOp<_Derived>(derived()); 31 | } 32 | 33 | /** 34 | * \brief Computes the log-determinant of this matrix. 
35 | * This is only supported for hermitian positive definite matrices, because no sign is computed. 36 | * A negative determinant would return in an complex logarithm which requires to return 37 | * a complex result for real matrices. This is not desired. 38 | * \return the log-determinant of this matrix 39 | */ 40 | Matrix::Scalar, 1, 1, internal::traits<_Derived>::BatchesAtCompileTime, ColumnMajor> logDeterminant() const 41 | { 42 | //TODO: implement direct methods for matrices up to 4x4. 43 | return decompositionLU().logDeterminant(); 44 | } 45 | 46 | /** 47 | * \brief Computes the determinant of this matrix. 48 | * For matrices of up to 4x4, an explicit formula is used. For larger matrices, this method falls back to a Cholesky Decomposition. 49 | * \return the inverse of this matrix 50 | */ 51 | InverseOp<_Derived> inverse() const 52 | { 53 | return InverseOp<_Derived>(derived()); 54 | } 55 | 56 | /** 57 | * \brief Computation of matrix inverse and determinant in one kernel call. 58 | * 59 | * This is only for fixed-size square matrices of size up to 4x4. 60 | * 61 | * \param inverse Reference to the matrix in which to store the inverse. 62 | * \param determinant Reference to the variable in which to store the determinant. 63 | * 64 | * \see inverse(), determinant() 65 | */ 66 | template 67 | void computeInverseAndDet(InverseType& inverseOut, DetType& detOut) const 68 | { 69 | CUMAT_STATIC_ASSERT(Rows >= 1 && Rows <= 4, "This matrix must be a compile-time 1x1, 2x2, 3x3 or 4x4 matrix"); 70 | CUMAT_STATIC_ASSERT(Columns >= 1 && Columns <= 4, "This matrix must be a compile-time 1x1, 2x2, 3x3 or 4x4 matrix"); 71 | CUMAT_STATIC_ASSERT(Rows == Columns, "This matrix must be symmetric"); 72 | CUMAT_STATIC_ASSERT(Rows >= 1 && internal::traits::RowsAtCompileTime, "The output matrix must have the same compile-time size as this matrix"); 73 | CUMAT_STATIC_ASSERT(Columns >= 1 && internal::traits::ColsAtCompileTime, "The output matrix must have the same compile-time size as this matrix"); 74 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(Batches > 0 && internal::traits::BatchesAtCompileTime > 0, Batches == internal::traits::BatchesAtCompileTime), 75 | "This matrix and the output matrix must have the same batch size"); 76 | CUMAT_ASSERT_DIMENSION(batches() == inverseOut.batches()); 77 | 78 | CUMAT_STATIC_ASSERT(internal::traits::RowsAtCompileTime == 1, "The determinant output must be a (batched) scalar, i.e. compile-time 1x1 matrix"); 79 | CUMAT_STATIC_ASSERT(internal::traits::ColsAtCompileTime == 1, "The determinant output must be a (batched) scalar, i.e. 
compile-time 1x1 matrix"); 80 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(Batches > 0 && internal::traits::BatchesAtCompileTime > 0, Batches == internal::traits::BatchesAtCompileTime), 81 | "This matrix and the determinant matrix must have the same batch size"); 82 | CUMAT_ASSERT_DIMENSION(batches() == detOut.batches()); 83 | 84 | CUMAT_NAMESPACE ComputeInverseWithDet<_Derived, Rows, InverseType, DetType>::run(derived(), inverseOut, detOut); 85 | } 86 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/DevicePointer.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_DEVICE_POINTER_H__ 2 | #define __CUMAT_DEVICE_POINTER_H__ 3 | 4 | #include 5 | 6 | #include "Macros.h" 7 | #include "Context.h" 8 | 9 | CUMAT_NAMESPACE_BEGIN 10 | 11 | template 12 | class DevicePointer 13 | { 14 | private: 15 | T* pointer_; 16 | int* counter_; 17 | CUMAT_NAMESPACE Context* context_; 18 | friend class DevicePointer::type>; 19 | 20 | __host__ __device__ 21 | void release() 22 | { 23 | #ifndef __CUDA_ARCH__ 24 | //no decrement of the counter in CUDA-code, counter is in host-memory 25 | assert(CUMAT_IMPLIES(counter_, (*counter_) > 0) && "Attempt to calling release() twice"); 26 | if ((counter_) && (--(*counter_) == 0)) 27 | { 28 | delete counter_; 29 | 30 | //DEBUG 31 | if (&Context::current() != context_) 32 | { 33 | CUMAT_LOG_WARNING( 34 | "Freeing memory with a different context than the current context.\n" 35 | "This will likely crash with an invalid-resource-handle error due to different Cub-Allocators"); 36 | } 37 | 38 | context_->freeDevice(pointer_); 39 | } 40 | #endif 41 | } 42 | 43 | public: 44 | DevicePointer(size_t size, CUMAT_NAMESPACE Context& ctx) 45 | : pointer_(nullptr) 46 | , counter_(nullptr) 47 | { 48 | context_ = &ctx; 49 | pointer_ = static_cast(context_->mallocDevice(size * sizeof(T))); 50 | try { 51 | counter_ = new int(1); 52 | } 53 | catch (...) 
54 | { 55 | context_->freeDevice(const_cast::type*>(pointer_)); 56 | throw; 57 | } 58 | } 59 | DevicePointer(size_t size) 60 | : DevicePointer(size, CUMAT_NAMESPACE Context::current()) 61 | {} 62 | 63 | __host__ __device__ 64 | DevicePointer() 65 | : pointer_(nullptr) 66 | , counter_(nullptr) 67 | , context_(nullptr) 68 | {} 69 | 70 | __host__ __device__ 71 | DevicePointer(const DevicePointer& rhs) 72 | : pointer_(rhs.pointer_) 73 | , counter_(rhs.counter_) 74 | , context_(rhs.context_) 75 | { 76 | #ifndef __CUDA_ARCH__ 77 | //no increment of the counter in CUDA-code, counter is in host-memory 78 | if (counter_) { 79 | ++(*counter_); 80 | } 81 | #endif 82 | } 83 | 84 | __host__ __device__ 85 | DevicePointer(DevicePointer&& rhs) noexcept 86 | : pointer_(std::move(rhs.pointer_)) 87 | , counter_(std::move(rhs.counter_)) 88 | , context_(std::move(rhs.context_)) 89 | { 90 | rhs.pointer_ = nullptr; 91 | rhs.counter_ = nullptr; 92 | rhs.context_ = nullptr; 93 | } 94 | 95 | __host__ __device__ 96 | DevicePointer& operator=(const DevicePointer& rhs) 97 | { 98 | release(); 99 | pointer_ = rhs.pointer_; 100 | counter_ = rhs.counter_; 101 | context_ = rhs.context_; 102 | #ifndef __CUDA_ARCH__ 103 | //no increment of the counter in CUDA-code, counter is in host-memory 104 | if (counter_) { 105 | ++(*counter_); 106 | } 107 | #endif 108 | return *this; 109 | } 110 | 111 | __host__ __device__ 112 | DevicePointer& operator=(DevicePointer&& rhs) noexcept 113 | { 114 | release(); 115 | pointer_ = std::move(rhs.pointer_); 116 | counter_ = std::move(rhs.counter_); 117 | context_ = std::move(rhs.context_); 118 | rhs.pointer_ = nullptr; 119 | rhs.counter_ = nullptr; 120 | rhs.context_ = nullptr; 121 | return *this; 122 | } 123 | 124 | __host__ __device__ 125 | void swap(DevicePointer& rhs) throw() 126 | { 127 | std::swap(pointer_, rhs.pointer_); 128 | std::swap(counter_, rhs.counter_); 129 | std::swap(context_, rhs.context_); 130 | } 131 | 132 | __host__ __device__ 133 | ~DevicePointer() 134 | { 135 | release(); 136 | } 137 | 138 | __host__ __device__ T* pointer() { return pointer_; } 139 | __host__ __device__ const T* pointer() const { return pointer_; } 140 | 141 | /** 142 | * \brief Returns the current value of the reference counter. 143 | * This can be used to determine if this memory is used uniquely 144 | * by an object. 145 | * \return the current number of references 146 | */ 147 | size_t getCounter() const { return counter_ ? 
*counter_ : 0; } 148 | }; 149 | 150 | CUMAT_NAMESPACE_END 151 | 152 | #endif 153 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/DisableCompilerWarnings.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_DISABLE_COMPILER_WARNINGS_H__ 2 | #define __CUMAT_DISABLE_COMPILER_WARNINGS_H__ 3 | 4 | 5 | #if defined __GNUC__ 6 | 7 | //This is a stupid warning in GCC that pops up everytime storage qualifies or so are compared 8 | #pragma GCC diagnostic ignored "-Wenum-compare" 9 | 10 | #endif 11 | 12 | 13 | #endif 14 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/EigenInteropHelpers.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_EIGEN_INTEROP_HELPERS_H__ 2 | #define __CUMAT_EIGEN_INTEROP_HELPERS_H__ 3 | 4 | 5 | #include "Macros.h" 6 | #include "Constants.h" 7 | 8 | #if CUMAT_EIGEN_SUPPORT==1 9 | #include 10 | #include 11 | 12 | CUMAT_NAMESPACE_BEGIN 13 | 14 | //forward-declare Matrix 15 | template 16 | class Matrix; 17 | 18 | /** 19 | * Namespace for eigen interop. 20 | */ 21 | namespace eigen 22 | { 23 | //There are now a lot of redundant qualifiers, but I want to make 24 | //it clear if we are in the Eigen world or in cuMat world. 25 | 26 | //Flag conversion 27 | 28 | template 29 | struct StorageCuMatToEigen {}; 30 | template<> 31 | struct StorageCuMatToEigen<::cuMat::Flags::ColumnMajor> 32 | { 33 | enum { value = ::Eigen::StorageOptions::ColMajor }; 34 | }; 35 | template<> 36 | struct StorageCuMatToEigen<::cuMat::Flags::RowMajor> 37 | { 38 | enum { value = ::Eigen::StorageOptions::RowMajor }; 39 | }; 40 | 41 | template 42 | struct StorageEigenToCuMat {}; 43 | template<> 44 | struct StorageEigenToCuMat<::Eigen::StorageOptions::RowMajor> 45 | { 46 | enum { value = ::cuMat::Flags::RowMajor }; 47 | }; 48 | template<> 49 | struct StorageEigenToCuMat<::Eigen::StorageOptions::ColMajor> 50 | { 51 | enum { value = ::cuMat::Flags::ColumnMajor }; 52 | }; 53 | 54 | //Size conversion (for dynamic tag) 55 | 56 | template 57 | struct SizeCuMatToEigen 58 | { 59 | enum { size = _Size }; 60 | }; 61 | template<> 62 | struct SizeCuMatToEigen<::cuMat::Dynamic> 63 | { 64 | enum {size = ::Eigen::Dynamic}; 65 | }; 66 | 67 | template 68 | struct SizeEigenToCuMat 69 | { 70 | enum {size = _Size}; 71 | }; 72 | template<> 73 | struct SizeEigenToCuMat<::Eigen::Dynamic> 74 | { 75 | enum{size = ::cuMat::Dynamic}; 76 | }; 77 | 78 | //Type conversion 79 | template 80 | struct TypeCuMatToEigen 81 | { 82 | typedef T type; 83 | }; 84 | template<> 85 | struct TypeCuMatToEigen 86 | { 87 | typedef std::complex type; 88 | }; 89 | template<> 90 | struct TypeCuMatToEigen 91 | { 92 | typedef std::complex type; 93 | }; 94 | 95 | template 96 | struct TypeEigenToCuMat 97 | { 98 | typedef T type; 99 | }; 100 | template<> 101 | struct TypeEigenToCuMat > 102 | { 103 | typedef cfloat type; 104 | }; 105 | template<> 106 | struct TypeEigenToCuMat > 107 | { 108 | typedef cdouble type; 109 | }; 110 | 111 | //Matrix type conversion 112 | 113 | template 114 | struct MatrixCuMatToEigen 115 | { 116 | using type = ::Eigen::Matrix< 117 | //typename TypeCuMatToEigen::type, 118 | typename _CuMatMatrixType::Scalar, 119 | SizeCuMatToEigen<_CuMatMatrixType::Rows>::size, 120 | SizeCuMatToEigen<_CuMatMatrixType::Columns>::size, 121 | //Eigen requires specific storage types for vector sizes 122 | ((_CuMatMatrixType::Rows==1) ? 
::Eigen::StorageOptions::RowMajor 123 | : (_CuMatMatrixType::Columns==1) ? ::Eigen::StorageOptions::ColMajor 124 | : StorageCuMatToEigen<_CuMatMatrixType::Flags>::value) 125 | | ::Eigen::DontAlign //otherwise, toEigen() will produce strange errors because we access the native data pointer 126 | >; 127 | }; 128 | template 129 | struct MatrixEigenToCuMat 130 | { 131 | using type = ::cuMat::Matrix< 132 | //typename TypeEigenToCuMat::type, 133 | typename _EigenMatrixType::Scalar, 134 | SizeEigenToCuMat<_EigenMatrixType::RowsAtCompileTime>::size, 135 | SizeEigenToCuMat<_EigenMatrixType::ColsAtCompileTime>::size, 136 | 1, //batch size of 1 137 | StorageEigenToCuMat<_EigenMatrixType::Options>::value 138 | >; 139 | }; 140 | 141 | } 142 | 143 | CUMAT_NAMESPACE_END 144 | 145 | //tell Eigen how to handle cfloat and cdouble 146 | namespace Eigen 147 | { 148 | template<> struct NumTraits : NumTraits> {}; 149 | template<> struct NumTraits : NumTraits> {}; 150 | } 151 | 152 | 153 | #endif 154 | 155 | #endif 156 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/Errors.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_ERRORS_H__ 2 | #define __CUMAT_ERRORS_H__ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "Macros.h" 13 | #include "Logging.h" 14 | 15 | CUMAT_NAMESPACE_BEGIN 16 | 17 | class cuda_error : public std::exception 18 | { 19 | private: 20 | std::string message_; 21 | public: 22 | cuda_error(std::string message) 23 | : message_(message) 24 | {} 25 | 26 | const char* what() const throw() override 27 | { 28 | return message_.c_str(); 29 | } 30 | }; 31 | 32 | namespace internal { 33 | 34 | class ErrorHelpers 35 | { 36 | public: 37 | static std::string vformat(const char *fmt, va_list ap) 38 | { 39 | // Allocate a buffer on the stack that's big enough for us almost 40 | // all the time. Be prepared to allocate dynamically if it doesn't fit. 41 | size_t size = 1024; 42 | char stackbuf[1024]; 43 | std::vector dynamicbuf; 44 | char *buf = &stackbuf[0]; 45 | va_list ap_copy; 46 | 47 | while (1) { 48 | // Try to vsnprintf into our buffer. 49 | va_copy(ap_copy, ap); 50 | int needed = vsnprintf(buf, size, fmt, ap); 51 | va_end(ap_copy); 52 | 53 | // NB. C99 (which modern Linux and OS X follow) says vsnprintf 54 | // failure returns the length it would have needed. But older 55 | // glibc and current Windows return -1 for failure, i.e., not 56 | // telling us how much was needed. 57 | 58 | if (needed <= (int)size && needed >= 0) { 59 | // It fit fine so we're done. 60 | return std::string(buf, (size_t)needed); 61 | } 62 | 63 | // vsnprintf reported that it wanted to write more characters 64 | // than we allotted. So try again using a dynamic buffer. This 65 | // doesn't happen very often if we chose our initial size well. 66 | size = (needed > 0) ? (needed + 1) : (size * 2); 67 | dynamicbuf.resize(size); 68 | buf = &dynamicbuf[0]; 69 | } 70 | } 71 | //Taken from https://stackoverflow.com/a/69911/4053176 72 | 73 | static std::string format(const char *fmt, ...) 
74 | { 75 | va_list ap; 76 | va_start(ap, fmt); 77 | std::string buf = vformat(fmt, ap); 78 | va_end(ap); 79 | return buf; 80 | } 81 | //Taken from https://stackoverflow.com/a/69911/4053176 82 | 83 | // Taken from https://codeyarns.com/2011/03/02/how-to-do-error-checking-in-cuda/ 84 | // and adopted 85 | 86 | private: 87 | static bool evalError(cudaError err, const char* file, const int line) 88 | { 89 | if (cudaErrorCudartUnloading == err) { 90 | std::string msg = format("cudaCheckError() failed at %s:%i : %s\nThis error can happen in multi-threaded applications during shut-down and is ignored.\n", 91 | file, line, cudaGetErrorString(err)); 92 | CUMAT_LOG_SEVERE(msg); 93 | return false; 94 | } 95 | else if (cudaSuccess != err) { 96 | std::string msg = format("cudaSafeCall() failed at %s:%i : %s\n", 97 | file, line, cudaGetErrorString(err)); 98 | CUMAT_LOG_SEVERE(msg); 99 | throw cuda_error(msg); 100 | } 101 | return true; 102 | } 103 | static bool evalErrorNoThrow(cudaError err, const char* file, const int line) 104 | { 105 | if (cudaSuccess != err) { 106 | std::string msg = format("cudaSafeCall() failed at %s:%i : %s\n", 107 | file, line, cudaGetErrorString(err)); 108 | CUMAT_LOG_SEVERE(msg); 109 | return false; 110 | } 111 | return true; 112 | } 113 | public: 114 | static void cudaSafeCall(cudaError err, const char *file, const int line) 115 | { 116 | if (!evalError(err, file, line)) return; 117 | #if CUMAT_VERBOSE_ERROR_CHECKING==1 118 | //insert a device-sync 119 | err = cudaDeviceSynchronize(); 120 | evalError(err, file, line); 121 | #endif 122 | } 123 | 124 | static bool cudaSafeCallNoThrow(cudaError err, const char* file, const int line) 125 | { 126 | if (!evalErrorNoThrow(err, file, line)) return false; 127 | #if CUMAT_VERBOSE_ERROR_CHECKING==1 128 | //insert a device-sync 129 | err = cudaDeviceSynchronize(); 130 | if (!evalError(err, file, line)) return false; 131 | #endif 132 | return true; 133 | } 134 | 135 | static void cudaCheckError(const char *file, const int line) 136 | { 137 | cudaError err = cudaGetLastError(); 138 | if (!evalError(err, file, line)) return; 139 | 140 | #if CUMAT_VERBOSE_ERROR_CHECKING==1 141 | // More careful checking. However, this will affect performance. 142 | err = cudaDeviceSynchronize(); 143 | evalError(err, file, line); 144 | #endif 145 | } 146 | }; 147 | 148 | /** 149 | * \brief Tests if the cuda library call wrapped inside the bracets was executed successfully, aka returned cudaSuccess. 150 | * Throws an cuMat::cuda_error if unsuccessfull 151 | * \param err the error code 152 | */ 153 | #define CUMAT_SAFE_CALL( err ) CUMAT_NAMESPACE internal::ErrorHelpers::cudaSafeCall( err, __FILE__, __LINE__ ) 154 | 155 | /** 156 | * \brief Tests if the cuda library call wrapped inside the bracets was executed successfully, aka returned cudaSuccess 157 | * Returns false iff unsuccessfull 158 | * \param err the error code 159 | */ 160 | #define CUMAT_SAFE_CALL_NO_THROW( err ) CUMAT_NAMESPACE internal::ErrorHelpers::cudaSafeCallNoThrow( err, __FILE__, __LINE__ ) 161 | /** 162 | * \brief Issue this after kernel launches to check for errors in the kernel. 163 | */ 164 | #define CUMAT_CHECK_ERROR() CUMAT_NAMESPACE internal::ErrorHelpers::cudaCheckError( __FILE__, __LINE__ ) 165 | 166 | //TODO: find a better place in some Utility header 167 | 168 | /** 169 | * \brief Numeric type conversion with overflow check. 170 | * If CUMAT_ENABLE_HOST_ASSERTIONS==1, this method 171 | * throws an std::runtime_error if the conversion results in 172 | * an overflow. 
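 * For example (values illustrative, assuming a 32-bit int), narrow_cast<int>(size_t(1) << 40)
 * would throw in such a build, while narrow_cast<int>(size_t(42)) simply returns 42.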
173 | * 174 | * If CUMAT_ENABLE_HOST_ASSERTIONS is not defined (default in release mode), 175 | * this method simply becomes static_cast. 176 | * 177 | * Source: The C++ Programming Language 4th Edition by Bjarne Stroustrup 178 | * https://stackoverflow.com/a/30114062/1786598 179 | * 180 | * \tparam Target the target type 181 | * \tparam Source the source type 182 | * \param v the source value 183 | * \return the casted target value 184 | */ 185 | template 186 | CUMAT_STRONG_INLINE Target narrow_cast(Source v) 187 | { 188 | #if CUMAT_ENABLE_HOST_ASSERTIONS==1 189 | auto r = static_cast(v); // convert the value to the target type 190 | if (static_cast(r) != v) 191 | throw std::runtime_error("narrow_cast<>() failed"); 192 | return r; 193 | #else 194 | return static_cast(v); 195 | #endif 196 | } 197 | 198 | } 199 | CUMAT_NAMESPACE_END 200 | 201 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/Logging.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_LOGGING_H__ 2 | #define __CUMAT_LOGGING_H__ 3 | 4 | #include "Macros.h" 5 | 6 | /** 7 | * Defines the logging macros. 8 | * All logging messages are started with a call to 9 | * e.g. CUMAT_LOG_INFO(message) 10 | * 11 | * You can define the CUMAT_LOG and the related logging levels 12 | * to point to your own logging implementation. 13 | * If you don't overwrite these, a very trivial logger is used 14 | * that simply prints to std::cout. 15 | * 16 | * This is achieved by globally defining CUMAT_LOGGING_PLUGIN 17 | * that includes a file that then defines all the logging macros: 18 | * CUMAT_LOG_DEBUG, CUMAT_LOG_INFO, CUMAT_LOG_WARNING, CUMAT_LOG_SEVERE 19 | */ 20 | 21 | #ifndef CUMAT_LOG 22 | 23 | #include 24 | #include 25 | #include 26 | 27 | #ifdef CUMAT_LOGGING_PLUGIN 28 | #include CUMAT_LOGGING_PLUGIN 29 | #endif 30 | 31 | CUMAT_NAMESPACE_BEGIN 32 | 33 | 34 | class DummyLogger 35 | { 36 | private: 37 | std::ios_base::fmtflags flags_; 38 | bool enabled_; 39 | 40 | public: 41 | DummyLogger(const std::string& level) 42 | : enabled_(level != "[debug]") //to disable the many, many logs during kernel evaluations in the test suites 43 | { 44 | flags_ = std::cout.flags(); 45 | if (enabled_) std::cout << level << " "; 46 | } 47 | ~DummyLogger() 48 | { 49 | if (enabled_) { 50 | std::cout << std::endl; 51 | std::cout.flags(flags_); 52 | } 53 | } 54 | 55 | template 56 | DummyLogger& operator<<(const T& t) { 57 | if (enabled_) std::cout << t; 58 | return *this; 59 | } 60 | }; 61 | 62 | #ifndef CUMAT_LOG_DEBUG 63 | /** 64 | * Logs the message as a debug message 65 | */ 66 | #define CUMAT_LOG_DEBUG(...) DummyLogger("[debug]") << __VA_ARGS__ 67 | #endif 68 | 69 | #ifndef CUMAT_LOG_INFO 70 | /** 71 | * Logs the message as a information message 72 | */ 73 | #define CUMAT_LOG_INFO(...) DummyLogger("[info]") << __VA_ARGS__ 74 | #endif 75 | #ifndef CUMAT_LOG_WARNING 76 | /** 77 | * Logs the message as a warning message 78 | */ 79 | #define CUMAT_LOG_WARNING(...) DummyLogger("[warning]") << __VA_ARGS__ 80 | #endif 81 | #ifndef CUMAT_LOG_SEVERE 82 | /** 83 | * Logs the message as a severe error message 84 | */ 85 | #define CUMAT_LOG_SEVERE(...) 
DummyLogger("[SEVERE]") << __VA_ARGS__ 86 | #endif 87 | 88 | CUMAT_NAMESPACE_END 89 | 90 | #endif 91 | 92 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/Macros.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_MACROS_H__ 2 | #define __CUMAT_MACROS_H__ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | /* 10 | * This file contains global macros and type definitions used everywhere 11 | */ 12 | 13 | #ifndef CUMAT_NAMESPACE 14 | /** 15 | * \brief The namespace of the library 16 | */ 17 | #define CUMAT_NAMESPACE ::cuMat:: 18 | #endif 19 | 20 | #ifndef CUMAT_NAMESPACE_BEGIN 21 | /** 22 | * \brief Defines the namespace in which everything of cuMat lives in 23 | */ 24 | #define CUMAT_NAMESPACE_BEGIN namespace cuMat { 25 | #endif 26 | 27 | #ifndef CUMAT_NAMESPACE_END 28 | /** 29 | * \brief Closes the namespace opened with CUMAT_NAMESPACE_BEGIN 30 | */ 31 | #define CUMAT_NAMESPACE_END } 32 | #endif 33 | 34 | #ifndef CUMAT_FUNCTION_NAMESPACE_BEGIN 35 | /** 36 | * \brief Defines the namespace in which overloaed math functions are defined. 37 | * Examples are 'sin(x)' and 'pow(a,b)'. 38 | * This defaults to "cuMat::functions", but can be changed if needed. 39 | */ 40 | #define CUMAT_FUNCTION_NAMESPACE_BEGIN namespace cuMat { namespace functions { 41 | #endif 42 | 43 | #ifndef CUMAT_FUNCTION_NAMESPACE_END 44 | /** 45 | * \brief Closes the namespace openeded with CUMAT_FUNCTION_NAMESPACE_BEGIN 46 | */ 47 | #define CUMAT_FUNCTION_NAMESPACE_END }} 48 | #endif 49 | 50 | #ifndef CUMAT_FUNCTION_NAMESPACE 51 | /** 52 | * \brief The namespace in which overloaded math functions are defined. 53 | */ 54 | #define CUMAT_FUNCTION_NAMESPACE ::cuMat::functions:: 55 | #endif 56 | 57 | #ifdef _MSC_VER 58 | //running under MS Visual Studio -> no thread_local 59 | #define CUMAT_THREAD_LOCAL __declspec( thread ) 60 | #else 61 | //C++11 compatible 62 | #define CUMAT_THREAD_LOCAL thread_local 63 | #endif 64 | 65 | #ifndef CUMAT_EIGEN_SUPPORT 66 | /** 67 | * \brief Set CUMAT_EIGEN_INTEROP to 1 to enable Eigen interop. 68 | * This enables the methods to convert between Eigen matrices and cuMat matrices. 69 | * Default: 0 70 | */ 71 | #define CUMAT_EIGEN_SUPPORT 0 72 | #endif 73 | 74 | #ifdef __CUDACC__ 75 | /** 76 | * If the current source file is compiled with the NVCC as a CUDA source file, 77 | * this macro is set to one, else to zero. 78 | */ 79 | #define CUMAT_NVCC 1 80 | #else 81 | #define CUMAT_NVCC 0 82 | #endif 83 | 84 | #if CUMAT_NVCC==0 85 | #define CUMAT_ERROR_IF_NO_NVCC(name) {THIS_FUNCTION_REQUIRES_THE_FILE_TO_BE_COMPILED_WITH_NVCC name;} 86 | #else 87 | #define CUMAT_ERROR_IF_NO_NVCC(name) 88 | #endif 89 | 90 | /** 91 | * \brief Define this macro in a class that should not be copyable or assignable 92 | * \param TypeName the name of the class 93 | */ 94 | #define CUMAT_DISALLOW_COPY_AND_ASSIGN(TypeName)\ 95 | TypeName(const TypeName&) = delete; \ 96 | void operator=(const TypeName&) = delete 97 | 98 | #define CUMAT_STR_DETAIL(x) #x 99 | #define CUMAT_STR(x) CUMAT_STR_DETAIL(x) 100 | 101 | 102 | /* 103 | * \brief enable verbose error checking after each kernel launch. 104 | * This implies a synchronization point after kernel. 
105 | * 106 | * By default, this is only enabled in a debug build, but it can be 107 | * manually activated by defining CUMAT_VERBOSE_ERROR_CHECKING=1 108 | */ 109 | #ifndef CUMAT_VERBOSE_ERROR_CHECKING 110 | #if defined(_DEBUG) || (!defined(NDEBUG) && !defined(_NDEBUG)) 111 | #define CUMAT_VERBOSE_ERROR_CHECKING 1 112 | #else 113 | #define CUMAT_VERBOSE_ERROR_CHECKING 0 114 | #endif 115 | #endif 116 | 117 | /* 118 | * \brief enable host assertions. 119 | * The assertion macros (CUMAT_ASSERT* without CUMAT_ASSERT_CUDA) 120 | * will throw an exception if triggered. 121 | * 122 | * By default, this is only enabled in a debug build, but it can be 123 | * manually activated by defining CUMAT_ENABLE_HOST_ASSERTIONS=1 124 | */ 125 | #ifndef CUMAT_ENABLE_HOST_ASSERTIONS 126 | #if defined(_DEBUG) || (!defined(NDEBUG) && !defined(_NDEBUG)) 127 | #define CUMAT_ENABLE_HOST_ASSERTIONS 1 128 | #else 129 | #define CUMAT_ENABLE_HOST_ASSERTIONS 0 130 | #endif 131 | #endif 132 | 133 | /* 134 | * \brief enable device assertions. 135 | * The assertion macros CUMAT_ASSERT_CUDA 136 | * will throw an exception if triggered. 137 | * 138 | * By default, this is only enabled in a debug build, but it can be 139 | * manually activated by defining CUMAT_ENABLE_DEVICE_ASSERTIONS=1 140 | */ 141 | #ifndef CUMAT_ENABLE_DEVICE_ASSERTIONS 142 | #if defined(_DEBUG) || (!defined(NDEBUG) && !defined(_NDEBUG)) 143 | #define CUMAT_ENABLE_DEVICE_ASSERTIONS 1 144 | #else 145 | #define CUMAT_ENABLE_DEVICE_ASSERTIONS 0 146 | #endif 147 | #endif 148 | 149 | #if CUMAT_ENABLE_HOST_ASSERTIONS==1 150 | 151 | /** 152 | * \brief Runtime assertion, uses assert() 153 | * Only use for something that should never happen 154 | * \param x the expression that must be true 155 | */ 156 | #define CUMAT_ASSERT(x) assert(x) 157 | 158 | #define CUMAT_ASSERT_ARGUMENT(x) \ 159 | if (!(x)) throw std::invalid_argument(__FILE__ ":" CUMAT_STR(__LINE__) ": Invalid argument: " #x); 160 | #define CUMAT_ASSERT_BOUNDS(x) \ 161 | if (!(x)) throw std::out_of_range(__FILE__ ":" CUMAT_STR(__LINE__) "Out of bounds: " #x); 162 | #define CUMAT_ASSERT_ERROR(x) \ 163 | if (!(x)) throw std::runtime_error(__FILE__ ":" CUMAT_STR(__LINE__) "Runtime Error: " #x); 164 | #define CUMAT_ASSERT_DIMENSION(x) \ 165 | if (!(x)) throw std::invalid_argument(__FILE__ ":" CUMAT_STR(__LINE__) ": Invalid dimensions: " #x); 166 | 167 | #else 168 | 169 | #define CUMAT_ASSERT(x) 170 | #define CUMAT_ASSERT_ARGUMENT(x) 171 | #define CUMAT_ASSERT_BOUNDS(x) 172 | #define CUMAT_ASSERT_ERROR(x) 173 | #define CUMAT_ASSERT_DIMENSION(x) 174 | 175 | #endif 176 | 177 | #if CUMAT_ENABLE_DEVICE_ASSERTIONS==1 178 | /** 179 | * \brief Assertions in device code (if supported) 180 | * \param x the expression that must be true 181 | */ 182 | #define CUMAT_ASSERT_CUDA(x) assert(x) 183 | #else 184 | #define CUMAT_ASSERT_CUDA(x) 185 | #endif 186 | 187 | #define CUMAT_THROW_INVALID_ARGUMENT(msg) \ 188 | throw std::invalid_argument(__FILE__ ":" CUMAT_STR(__LINE__) ": " msg); 189 | #define CUMAT_THROW_RUNTIME_ERROR(msg) \ 190 | throw std::runtime_error(__FILE__ ":" CUMAT_STR(__LINE__) ": " msg); 191 | #define CUMAT_THROW_OUT_OF_RANGE(msg) \ 192 | throw std::out_of_range(__FILE__ ":" CUMAT_STR(__LINE__) ": " msg); 193 | 194 | /** 195 | * \brief A static assertion 196 | * \param exp the compile-time boolean expression that must be true 197 | * \param msg an error message if exp is false 198 | */ 199 | #define CUMAT_STATIC_ASSERT(exp, msg) static_assert(exp, msg) 200 | 201 | #define CUMAT_STRONG_INLINE __inline__ 202 
| 203 | 204 | /** 205 | * \brief Returns the integer division x/y rounded up. 206 | * Taken from https://stackoverflow.com/a/2745086/4053176 207 | */ 208 | #define CUMAT_DIV_UP(x, y) (((x) + (y) - 1) / (y)) 209 | 210 | /** 211 | * \brief Computes the logical implication (a -> b) 212 | */ 213 | #define CUMAT_IMPLIES(a,b) (!(a) || (b)) 214 | 215 | 216 | #define CUMAT_PUBLIC_API_NO_METHODS \ 217 | enum \ 218 | { \ 219 | Flags = internal::traits::Flags, \ 220 | Rows = internal::traits::RowsAtCompileTime, \ 221 | Columns = internal::traits::ColsAtCompileTime, \ 222 | Batches = internal::traits::BatchesAtCompileTime \ 223 | }; \ 224 | using Scalar = typename internal::traits::Scalar; \ 225 | using SrcTag = typename internal::traits::SrcTag; \ 226 | using DstTag = typename internal::traits::DstTag; 227 | /** 228 | * \brief Defines the basic API of each cumat class. 229 | * The typedefs and enums that have to be exposed. 230 | * But before that, you have to define the current class in \c Type 231 | * and the base class in \c Base. 232 | */ 233 | #define CUMAT_PUBLIC_API \ 234 | CUMAT_PUBLIC_API_NO_METHODS \ 235 | using Base::derived; \ 236 | using Base::eval_t; 237 | 238 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/MatrixBase.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_MATRIX_BASE_H__ 2 | #define __CUMAT_MATRIX_BASE_H__ 3 | 4 | #include 5 | 6 | #include "Macros.h" 7 | #include "ForwardDeclarations.h" 8 | #include "Constants.h" 9 | 10 | CUMAT_NAMESPACE_BEGIN 11 | 12 | /** 13 | * \brief The base class of all matrix types and matrix expressions. 14 | * \tparam _Derived 15 | */ 16 | template 17 | class MatrixBase 18 | { 19 | public: 20 | typedef _Derived Type; 21 | typedef MatrixBase<_Derived> Base; 22 | CUMAT_PUBLIC_API_NO_METHODS 23 | 24 | /** 25 | * \returns a reference to the _Derived object 26 | */ 27 | __host__ __device__ CUMAT_STRONG_INLINE _Derived& derived() { return *static_cast<_Derived*>(this); } 28 | 29 | /** 30 | * \returns a const reference to the _Derived object 31 | */ 32 | __host__ __device__ CUMAT_STRONG_INLINE const _Derived& derived() const { return *static_cast(this); } 33 | 34 | /** 35 | * \brief Returns the number of rows of this matrix. 36 | * \returns the number of rows. 37 | */ 38 | __host__ __device__ CUMAT_STRONG_INLINE Index rows() const { return derived().rows(); } 39 | /** 40 | * \brief Returns the number of columns of this matrix. 41 | * \returns the number of columns. 42 | */ 43 | __host__ __device__ CUMAT_STRONG_INLINE Index cols() const { return derived().cols(); } 44 | /** 45 | * \brief Returns the number of batches of this matrix. 46 | * \returns the number of batches. 47 | */ 48 | __host__ __device__ CUMAT_STRONG_INLINE Index batches() const { return derived().batches(); } 49 | /** 50 | * \brief Returns the total number of entries in this matrix. 51 | * This value is computed as \code rows()*cols()*batches()* \endcode 52 | * \return the total number of entries 53 | */ 54 | __host__ __device__ CUMAT_STRONG_INLINE Index size() const { return rows()*cols()*batches(); } 55 | 56 | // EVALUATION 57 | 58 | typedef Matrix< 59 | typename internal::traits<_Derived>::Scalar, 60 | internal::traits<_Derived>::RowsAtCompileTime, 61 | internal::traits<_Derived>::ColsAtCompileTime, 62 | internal::traits<_Derived>::BatchesAtCompileTime, 63 | internal::traits<_Derived>::Flags 64 | > eval_t; 65 | 66 | /** 67 | * \brief Evaluates this into a matrix. 
68 | * This evaluates any expression template. If this is already a matrix, it is returned unchanged. 69 | * \return the evaluated matrix 70 | */ 71 | eval_t eval() const 72 | { 73 | return eval_t(derived()); 74 | } 75 | 76 | /** 77 | * \brief Conversion: Matrix of size 1-1-1 (scalar) in device memory to the host memory scalar. 78 | * 79 | * This is expecially usefull to directly use the results of full reductions in host code. 80 | * 81 | * \tparam T 82 | */ 83 | explicit operator Scalar () const 84 | { 85 | CUMAT_STATIC_ASSERT( 86 | internal::traits<_Derived>::RowsAtCompileTime == 1 && 87 | internal::traits<_Derived>::ColsAtCompileTime == 1 && 88 | internal::traits<_Derived>::BatchesAtCompileTime == 1, 89 | "Conversion only possible for compile-time scalars"); 90 | eval_t m = eval(); 91 | Scalar v; 92 | m.copyToHost(&v); 93 | return v; 94 | } 95 | 96 | 97 | // CWISE EXPRESSIONS 98 | #include "MatrixBlockPluginRvalue.inl" 99 | #include "UnaryOpsPlugin.inl" 100 | #include "BinaryOpsPlugin.inl" 101 | #include "ReductionOpsPlugin.inl" 102 | #include "DenseLinAlgPlugin.inl" 103 | #include "SparseExpressionOpPlugin.inl" 104 | }; 105 | 106 | 107 | 108 | template 109 | struct MatrixReadWrapper 110 | { 111 | private: 112 | enum 113 | { 114 | //the existing access flags 115 | flags = internal::traits<_Derived>::AccessFlags, 116 | //boolean if the access is sufficient 117 | sufficient = (flags & _AccessFlags) 118 | }; 119 | public: 120 | /** 121 | * \brief The wrapped type: either the type itself, if the access is sufficient, 122 | * or the evaluated type if not. 123 | */ 124 | using type = typename std::conditional::type; 125 | 126 | /* 127 | template>::type> 128 | static type wrap(const T& m) 129 | { 130 | return m.derived(); 131 | } 132 | template>::type> 133 | static type wrap(const T& m) 134 | { 135 | return m.derived().eval(); 136 | } 137 | */ 138 | 139 | private: 140 | MatrixReadWrapper(){} //not constructible 141 | }; 142 | 143 | CUMAT_NAMESPACE_END 144 | 145 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/NullaryOps.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_NULLARY_OPS_H__ 2 | #define __CUMAT_NULLARY_OPS_H__ 3 | 4 | #include "Macros.h" 5 | #include "ForwardDeclarations.h" 6 | #include "CwiseOp.h" 7 | 8 | CUMAT_NAMESPACE_BEGIN 9 | 10 | namespace internal { 11 | template 12 | struct traits > 13 | { 14 | typedef _Scalar Scalar; 15 | enum 16 | { 17 | Flags = _Flags, 18 | RowsAtCompileTime = _Rows, 19 | ColsAtCompileTime = _Columns, 20 | BatchesAtCompileTime = _Batches, 21 | AccessFlags = ReadCwise 22 | }; 23 | typedef CwiseSrcTag SrcTag; 24 | typedef DeletedDstTag DstTag; 25 | }; 26 | } 27 | 28 | /** 29 | * \brief A generic nullary operator. 30 | * This is used as a leaf in the expression template tree. 
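 * (The ConstantFunctor and IdentityFunctor defined further below in this file
 * are the canonical examples of such nullary functors.)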
31 | * The nullary functor can be any class or structure that supports 32 | * the following method: 33 | * \code 34 | * __device__ const _Scalar& operator()(Index row, Index col, Index batch); 35 | * \endcode 36 | * \tparam _Scalar 37 | * \tparam _NullaryFunctor 38 | */ 39 | template 40 | class NullaryOp : public CwiseOp > 41 | { 42 | public: 43 | typedef CwiseOp > Base; 44 | typedef NullaryOp<_Scalar, _Rows, _Columns, _Batches, _Flags, _NullaryFunctor> Type; 45 | CUMAT_PUBLIC_API 46 | 47 | protected: 48 | const Index rows_; 49 | const Index cols_; 50 | const Index batches_; 51 | const _NullaryFunctor functor_; 52 | 53 | public: 54 | NullaryOp(Index rows, Index cols, Index batches, const _NullaryFunctor& functor) 55 | : rows_(rows), cols_(cols), batches_(batches), functor_(functor) 56 | {} 57 | explicit NullaryOp(const _NullaryFunctor& functor) 58 | : rows_(_Rows), cols_(_Columns), batches_(_Batches), functor_(functor) 59 | { 60 | CUMAT_STATIC_ASSERT(_Rows > 0, "The one-parameter constructor is only available for fully fixed-size instances"); 61 | CUMAT_STATIC_ASSERT(_Columns > 0, "The one-parameter constructor is only available for fully fixed-size instances"); 62 | CUMAT_STATIC_ASSERT(_Batches > 0, "The one-parameter constructor is only available for fully fixed-size instances"); 63 | } 64 | 65 | __host__ __device__ CUMAT_STRONG_INLINE Index rows() const { return rows_; } 66 | __host__ __device__ CUMAT_STRONG_INLINE Index cols() const { return cols_; } 67 | __host__ __device__ CUMAT_STRONG_INLINE Index batches() const { return batches_; } 68 | 69 | __device__ CUMAT_STRONG_INLINE Scalar coeff(Index row, Index col, Index batch, Index linear) const 70 | { 71 | return functor_(row, col, batch); 72 | } 73 | }; 74 | 75 | namespace functor 76 | { 77 | template 78 | struct ConstantFunctor 79 | { 80 | private: 81 | const _Scalar value_; 82 | public: 83 | ConstantFunctor(_Scalar value) 84 | : value_(value) 85 | {} 86 | 87 | __device__ CUMAT_STRONG_INLINE const _Scalar& operator()(Index row, Index col, Index batch) const 88 | { 89 | return value_; 90 | } 91 | }; 92 | 93 | template 94 | struct IdentityFunctor 95 | { 96 | public: 97 | __device__ CUMAT_STRONG_INLINE _Scalar operator()(Index row, Index col, Index batch) const 98 | { 99 | return (row==col) ? _Scalar(1) : _Scalar(0); 100 | } 101 | }; 102 | } 103 | 104 | template 105 | HostScalar<_Scalar> make_host_scalar(const _Scalar& value) 106 | { 107 | return HostScalar<_Scalar>(1, 1, 1, functor::ConstantFunctor<_Scalar>(value)); 108 | } 109 | 110 | CUMAT_NAMESPACE_END 111 | 112 | #endif 113 | -------------------------------------------------------------------------------- /extensions/include/cuMat/src/NumTraits.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_NUM_TRAITS_H__ 2 | #define __CUMAT_NUM_TRAITS_H__ 3 | 4 | #include 5 | 6 | #include "Macros.h" 7 | #include "ForwardDeclarations.h" 8 | #include 9 | #include 10 | 11 | CUMAT_NAMESPACE_BEGIN 12 | 13 | namespace internal 14 | { 15 | /** 16 | * \brief General implementation of NumTraits 17 | * \tparam T 18 | */ 19 | template 20 | struct NumTraits 21 | { 22 | /** 23 | * \brief The type itself 24 | */ 25 | typedef T Type; 26 | /** 27 | * \brief For complex types: the corresponding real type; equals to Type for non-complex types 28 | */ 29 | typedef T RealType; 30 | /** 31 | * \brief For compound types (blocked types): the corresponding element type (e.g. 
cfloat3->cfloat, double4->double); 32 | * equals to Type for non-blocked types 33 | */ 34 | typedef T ElementalType; 35 | enum 36 | { 37 | /** 38 | * \brief Equals one if cuBlas supports this type 39 | * \see CublasApi 40 | */ 41 | IsCudaNumeric = 0, 42 | /** 43 | * \brief Equal to true if this type is a complex type, and hence Type!=RealType 44 | */ 45 | IsComplex = false 46 | }; 47 | /** 48 | * \brief The default epsilon for approximate comparisons 49 | */ 50 | static constexpr CUMAT_STRONG_INLINE ElementalType epsilon() {return std::numeric_limits::epsilon();} 51 | }; 52 | 53 | template <> 54 | struct NumTraits 55 | { 56 | typedef float Type; 57 | typedef float RealType; 58 | typedef float ElementalType; 59 | enum 60 | { 61 | IsCudaNumeric = 1, 62 | IsComplex = false, 63 | }; 64 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits::epsilon();} 65 | }; 66 | template <> 67 | struct NumTraits 68 | { 69 | typedef double Type; 70 | typedef double RealType; 71 | typedef double ElementalType; 72 | enum 73 | { 74 | IsCudaNumeric = 1, 75 | IsComplex = false, 76 | }; 77 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits::epsilon();} 78 | }; 79 | 80 | template <> 81 | struct NumTraits 82 | { 83 | typedef cfloat Type; 84 | typedef float RealType; 85 | typedef cfloat ElementalType; 86 | enum 87 | { 88 | IsCudaNumeric = 1, 89 | IsComplex = true, 90 | }; 91 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits::epsilon();} 92 | }; 93 | 94 | template <> 95 | struct NumTraits 96 | { 97 | typedef cdouble Type; 98 | typedef double RealType; 99 | typedef cdouble ElementalType; 100 | enum 101 | { 102 | IsCudaNumeric = 1, 103 | IsComplex = true, 104 | }; 105 | static constexpr CUMAT_STRONG_INLINE RealType epsilon() {return std::numeric_limits::epsilon();} 106 | }; 107 | 108 | template 109 | struct isPrimitive : std::is_arithmetic {}; 110 | template <> struct isPrimitive : std::integral_constant {}; 111 | template <> struct isPrimitive : std::integral_constant {}; 112 | 113 | /** 114 | * \brief Can the type T be used for broadcasting when the scalar type of the other matrix is S? 115 | */ 116 | template 117 | struct canBroadcast : std::integral_constant::value && CUMAT_NAMESPACE internal::isPrimitive::value) \ 119 | || std::is_same::type, typename std::remove_cv::type>::value 120 | > {}; 121 | 122 | template 123 | struct NumOps //special functions for numbers 124 | { 125 | static __host__ __device__ CUMAT_STRONG_INLINE T conj(const T& v) {return v;} 126 | }; 127 | template<> 128 | struct NumOps 129 | { 130 | static __host__ __device__ CUMAT_STRONG_INLINE cfloat conj(const cfloat& v) {return ::thrust::conj(v);} 131 | }; 132 | template<> 133 | struct NumOps 134 | { 135 | static __host__ __device__ CUMAT_STRONG_INLINE cdouble conj(const cdouble& v) {return ::thrust::conj(v);} 136 | }; 137 | } 138 | 139 | CUMAT_NAMESPACE_END 140 | 141 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/Profiling.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_PROFILING_H__ 2 | #define __CUMAT_PROFILING_H__ 3 | 4 | #include "Macros.h" 5 | 6 | #ifndef CUMAT_PROFILING 7 | /** 8 | * \brief Define this macro as '1' to enable profiling. 9 | * For the created statistics, see the class Profiling. 
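 * Since the macro sits inside an #ifndef guard, a definition supplied on the
 * compiler command line (e.g. -DCUMAT_PROFILING=1) takes precedence over the
 * default of 0 below, so enabling profiling does not require editing this header.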
10 | */ 11 | #define CUMAT_PROFILING 0 12 | #endif 13 | 14 | CUMAT_NAMESPACE_BEGIN 15 | 16 | 17 | 18 | /** 19 | * \brief This class contains the counters used to profile the library. 20 | * See Counter for which statistics are available. 21 | * For easy access, several macros are available with the prefix CUMAT_PROFILING_ . 22 | */ 23 | class Profiling 24 | { 25 | public: 26 | enum Counter 27 | { 28 | DeviceMemAlloc, 29 | DeviceMemFree, 30 | HostMemAlloc, 31 | HostMemFree, 32 | 33 | MemcpyDeviceToDevice, 34 | MemcpyHostToHost, 35 | MemcpyDeviceToHost, 36 | MemcpyHostToDevice, 37 | 38 | /** 39 | * \brief Any evaluation has happend 40 | */ 41 | EvalAny, 42 | /** 43 | * \brief Component-wise evaluation 44 | */ 45 | EvalCwise, 46 | /** 47 | * \brief Special transposition operation 48 | */ 49 | EvalTranspose, 50 | /** 51 | * \brief Reduction operation with CUB 52 | */ 53 | EvalReduction, 54 | /** 55 | * \brief Matrix-Matrix multiplication with cuBLAS 56 | */ 57 | EvalMatmul, 58 | /** 59 | * \brief Sparse component-wise evaluation 60 | */ 61 | EvalCwiseSparse, 62 | /** 63 | * \brief Sparse matrix-matrix multiplication 64 | */ 65 | EvalMatmulSparse, 66 | 67 | _NumCounter_ 68 | }; 69 | private: 70 | size_t counters_[_NumCounter_]; 71 | 72 | public: 73 | void resetAll() 74 | { 75 | for (size_t i = 0; i < _NumCounter_; ++i) counters_[i] = 0; 76 | } 77 | void reset(Counter counter) 78 | { 79 | counters_[counter] = 0; 80 | } 81 | void inc(Counter counter) 82 | { 83 | counters_[counter]++; 84 | } 85 | size_t get(Counter counter) 86 | { 87 | return counters_[counter]; 88 | } 89 | size_t getReset(Counter counter) 90 | { 91 | size_t v = counters_[counter]; 92 | counters_[counter] = 0; 93 | return v; 94 | } 95 | 96 | private: 97 | Profiling() 98 | { 99 | resetAll(); 100 | } 101 | CUMAT_DISALLOW_COPY_AND_ASSIGN(Profiling); 102 | 103 | public: 104 | static Profiling& instance() 105 | { 106 | static Profiling p; 107 | return p; 108 | } 109 | }; 110 | 111 | #ifdef CUMAT_PARSED_BY_DOXYGEN 112 | 113 | /** 114 | * \brief Increments the counter 'counter' (an element of Profiling::Counter). 115 | */ 116 | #define CUMAT_PROFILING_INC(counter) 117 | 118 | /** 119 | * \brief Gets and resets the counter 'counter' (an element of Profiling::Counter). 120 | * If profiling is disabled (CUMAT_PROFILING is not defined or unequal to 1), the result 121 | * is not defined. 
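 * A minimal usage sketch (only meaningful when CUMAT_PROFILING is 1):
 * \code
 * CUMAT_PROFILING_RESET();
 * // ... evaluate some cuMat expressions ...
 * size_t cwiseEvals = CUMAT_PROFILING_GET(EvalCwise);
 * \endcode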
122 | */ 123 | #define CUMAT_PROFILING_GET(counter) 124 | 125 | /** 126 | * \brief Resets all counters 127 | */ 128 | #define CUMAT_PROFILING_RESET() 129 | 130 | #else 131 | 132 | #if CUMAT_PROFILING==1 133 | //Profiling enabled 134 | #define CUMAT_PROFILING_INC(counter) \ 135 | CUMAT_NAMESPACE Profiling::instance().inc(CUMAT_NAMESPACE Profiling::Counter::counter) 136 | #define CUMAT_PROFILING_GET(counter) \ 137 | CUMAT_NAMESPACE Profiling::instance().getReset(CUMAT_NAMESPACE Profiling::Counter::counter) 138 | #define CUMAT_PROFILING_RESET() \ 139 | CUMAT_NAMESPACE Profiling::instance().resetAll() 140 | #else 141 | //Profiling disabled 142 | #define CUMAT_PROFILING_INC(counter) ((void)0) 143 | #define CUMAT_PROFILING_GET(counter) ((void)0) 144 | #define CUMAT_PROFILING_RESET() ((void)0) 145 | #endif 146 | 147 | #endif 148 | 149 | 150 | 151 | CUMAT_NAMESPACE_END 152 | 153 | 154 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/ReductionAlgorithmSelection.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_REDUCTION_ALGORITHM_SELECTION_H__ 2 | #define __CUMAT_REDUCTION_ALGORITHM_SELECTION_H__ 3 | 4 | #include 5 | #include 6 | 7 | #include "Macros.h" 8 | #include "ForwardDeclarations.h" 9 | 10 | CUMAT_NAMESPACE_BEGIN 11 | 12 | namespace internal 13 | { 14 | /** 15 | * \brief Algorithms possible during dynamic selection. 16 | * These are a subset of the tags from namespace ReductionAlg. 17 | */ 18 | enum class ReductionAlgorithm 19 | { 20 | Segmented, 21 | Thread, 22 | Warp, 23 | Block256, 24 | Device1, 25 | Device2, 26 | Device4 27 | }; 28 | 29 | /** 30 | * \brief Selects the best reduction algorithm dynamically given 31 | * the reduction axis (inner, middle, outer). 32 | * Given a matrix in ColumnMajor order, these axis 33 | * correspond to the following: inner=Row, middle=Column, outer=Batch. 34 | * 35 | * The timings from which those selections were determined were evaluated 36 | * on a Nvidia RTX 2070. 37 | * For a different architecture, you might need to tweak the conditions 38 | * in the source code. 
39 | */ 40 | struct ReductionAlgorithmSelection 41 | { 42 | private: 43 | typedef std::tuple condition; 44 | static constexpr int MAX_CONDITION = 3; 45 | typedef std::array conditions; 46 | typedef std::tuple choice; 47 | 48 | template 49 | static ReductionAlgorithm select( 50 | const choice(&conditions)[N], ReductionAlgorithm def, 51 | Index numBatches, Index batchSize) 52 | { 53 | const double nb = std::log2(numBatches); 54 | const double bs = std::log2(batchSize); 55 | for (const choice& c : conditions) 56 | { 57 | bool success = true; 58 | for (int i = 0; i < std::get<1>(c); ++i) { 59 | const auto& cond = std::get<2>(c)[i]; 60 | if (std::get<0>(cond)*nb + std::get<1>(cond)*bs < std::get<2>(cond)) 61 | success = false; 62 | } 63 | if (success) 64 | return std::get<0>(c); 65 | } 66 | return def; 67 | } 68 | 69 | public: 70 | static ReductionAlgorithm inner(Index numBatches, Index batchSize) 71 | { 72 | //adopt to new architectures 73 | static const choice CONDITONS[] = { 74 | choice{ReductionAlgorithm::Device1, 2, {condition{1.2, 1.0, 19.5}, condition{-1,0,-2.5}}}, 75 | choice{ReductionAlgorithm::Device2, 3, {condition{0.42857142857142855,1,17.821428571428573}, condition{-1,0,-4.25}, condition{1,0,2.5}}}, 76 | choice{ReductionAlgorithm::Device4, 3, {condition{0,1,16.25}, condition{-1,0,-5.5}, condition{1,0,4.25}}}, 77 | choice{ReductionAlgorithm::Block256,2, {condition{-1.6, 1, 8}, condition{-1,0,-5}}}, 78 | choice{ReductionAlgorithm::Thread, 2, {condition{0.475, -1, 2.01875}, condition{0, -1, -4.75}}} 79 | }; 80 | static const ReductionAlgorithm DEFAULT = ReductionAlgorithm::Warp; 81 | return select(CONDITONS, DEFAULT, numBatches, batchSize); 82 | } 83 | 84 | static ReductionAlgorithm middle(Index numBatches, Index batchSize) 85 | { 86 | //adopt to new architectures 87 | static const choice CONDITONS[] = { 88 | choice{ReductionAlgorithm::Device1, 2, {condition{1.5,1,19.5}, condition{-1,0,-2.5}}}, 89 | choice{ReductionAlgorithm::Device2, 3, {condition{0,1,15.5}, condition{1,0,2.5}, condition{-1,0,-4}}}, 90 | choice{ReductionAlgorithm::Device4, 3, {condition{0,1,15.75}, condition{1,0,4}, condition{-1,0,-5.75}}}, 91 | choice{ReductionAlgorithm::Block256,2, {condition{0,1,9}, condition{-1,0,-2.5}}}, 92 | choice{ReductionAlgorithm::Warp, 2, {condition{0,1,4}, condition{-1,0,-11.75}}} 93 | }; 94 | static const ReductionAlgorithm DEFAULT = ReductionAlgorithm::Thread; 95 | return select(CONDITONS, DEFAULT, numBatches, batchSize); 96 | } 97 | 98 | static ReductionAlgorithm outer(Index numBatches, Index batchSize) 99 | { 100 | //adopt to new architectures 101 | static const choice CONDITONS[] = { 102 | choice{ReductionAlgorithm::Device1, 2, {condition{-1,0,-2}, condition{1.875,1,19}}}, 103 | choice{ReductionAlgorithm::Device4, 3, {condition{1,0,2}, condition{-1,0,-4.25}, condition{10, 9, 184.25}}}, 104 | choice{ReductionAlgorithm::Device2, 3, {condition{1,0,2}, condition{-1,0,-4.25}, condition{-0.22222, 1, 14.085555}}}, 105 | choice{ReductionAlgorithm::Segmented, 3, {condition{1,0,4}, condition{0,1,11.5}, condition{-1,0,-8.5}}}, 106 | choice{ReductionAlgorithm::Block256,2, {condition{0,1,8}, condition{-1,0,-2}}}, 107 | choice{ReductionAlgorithm::Warp, 2, {condition{0,1,2.75}, condition{-1,0,-11.75}}} 108 | }; 109 | static const ReductionAlgorithm DEFAULT = ReductionAlgorithm::Thread; 110 | return select(CONDITONS, DEFAULT, numBatches, batchSize); 111 | } 112 | }; 113 | } 114 | 115 | CUMAT_NAMESPACE_END 116 | 117 | #endif 
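A short sketch of how the tuning tables above are consulted: each choice carries up to MAX_CONDITION linear conditions (a, b, c), and a choice is taken only if a*log2(numBatches) + b*log2(batchSize) >= c holds for all of them; otherwise the next choice, and finally the default, is used. The snippet below is illustrative only — it assumes the default cuMat namespace and an include path rooted at extensions/include, the sizes are made-up values, and in normal use the reduction evaluators call this selection internally.

// Sketch only: shows the selection entry point; not part of the library itself.
#include "cuMat/src/ReductionAlgorithmSelection.h"

void reductionSelectionExample()
{
    cuMat::Index numBatches = 1024;    // hypothetical: number of independent reductions
    cuMat::Index batchSize  = 1 << 16; // hypothetical: elements reduced per batch

    // pick an algorithm for a reduction along the inner (row) axis
    auto alg = cuMat::internal::ReductionAlgorithmSelection::inner(numBatches, batchSize);
    (void)alg; // one of Segmented, Thread, Warp, Block256, Device1, Device2, Device4
}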
-------------------------------------------------------------------------------- /extensions/include/cuMat/src/SolverBase.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_SOLVER_BASE_H__ 2 | #define __CUMAT_SOLVER_BASE_H__ 3 | 4 | #include "Macros.h" 5 | #include "ForwardDeclarations.h" 6 | 7 | CUMAT_NAMESPACE_BEGIN 8 | 9 | template 10 | class SolverBase 11 | { 12 | protected: 13 | CUMAT_STRONG_INLINE _DecompositionImpl& impl() { return *static_cast<_DecompositionImpl*>(this); } 14 | CUMAT_STRONG_INLINE const _DecompositionImpl& impl() const { return *static_cast(this); } 15 | 16 | public: 17 | using Scalar = typename internal::traits<_DecompositionImpl>::Scalar; 18 | using MatrixType = typename internal::traits<_DecompositionImpl>::MatrixType; 19 | enum 20 | { 21 | Flags = internal::traits::Flags, 22 | Rows = internal::traits::RowsAtCompileTime, 23 | Columns = internal::traits::ColsAtCompileTime, 24 | Batches = internal::traits::BatchesAtCompileTime, 25 | }; 26 | 27 | /** 28 | * \brief Solves the system of linear equations 29 | * \tparam _RHS the type of the right hand side 30 | * \param rhs the right hand side matrix 31 | * \return The operation that computes the solution of the linear system 32 | */ 33 | template 34 | SolveOp<_DecompositionImpl, _RHS> solve(const MatrixBase<_RHS>& rhs) const 35 | { 36 | return SolveOp<_DecompositionImpl, _RHS>(impl(), rhs.derived()); 37 | } 38 | }; 39 | 40 | namespace internal 41 | { 42 | struct SolveSrcTag {}; 43 | template 44 | struct traits > 45 | { 46 | //using Scalar = typename internal::traits<_Child>::Scalar; 47 | using Scalar = typename internal::traits<_RHS>::Scalar; 48 | enum 49 | { 50 | Flags = ColumnMajor, 51 | RowsAtCompileTime = internal::traits<_RHS>::RowsAtCompileTime, 52 | ColsAtCompileTime = internal::traits<_RHS>::ColsAtCompileTime, 53 | BatchesAtCompileTime = internal::traits<_RHS>::BatchesAtCompileTime, 54 | AccessFlags = 0 55 | }; 56 | typedef SolveSrcTag SrcTag; 57 | typedef DeletedDstTag DstTag; 58 | }; 59 | } 60 | 61 | /** 62 | * \brief General solver operation. 63 | * Delegates to _solve_impl of the solver implementation. 
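 * This is a lazy expression: nothing is computed until the SolveOp is assigned
 * to a matrix (operator=), at which point _solve_impl of the decomposition is
 * invoked. Rough usage sketch, assuming a decomposition object 'decomp' and a
 * right hand side 'b' of matching scalar type:
 * \code
 * VectorXf x = decomp.solve(b); // the SolveOp is evaluated into x here
 * \endcode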
64 | * \tparam _Solver 65 | * \tparam _RHS 66 | */ 67 | template 68 | class SolveOp : public MatrixBase> 69 | { 70 | public: 71 | typedef MatrixBase> Base; 72 | typedef SolveOp<_Solver, _RHS> Type; 73 | using Scalar = typename internal::traits<_RHS>::Scalar; 74 | enum 75 | { 76 | Flags = ColumnMajor, 77 | Rows = internal::traits<_RHS>::RowsAtCompileTime, 78 | Columns = internal::traits<_RHS>::ColsAtCompileTime, 79 | Batches = internal::traits<_RHS>::BatchesAtCompileTime 80 | }; 81 | 82 | private: 83 | const _Solver& decomposition_; 84 | const _RHS rhs_; 85 | 86 | public: 87 | SolveOp(const _Solver& decomposition, const MatrixBase<_RHS>& rhs) 88 | : decomposition_(decomposition) 89 | , rhs_(rhs.derived()) 90 | { 91 | CUMAT_STATIC_ASSERT((std::is_same< 92 | typename internal::NumTraits::ElementalType, 93 | typename internal::NumTraits::Scalar>::ElementalType>::value), 94 | "Datatype of left- and right hand side must match"); 95 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(_Solver::Batches > 1 && _RHS::Batches > 0, _Solver::Batches == _RHS::Batches), 96 | "Static count of batches must match"); //note: _Solver::Batches>1 to allow broadcasting 97 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(_Solver::Rows > 0 && _Solver::Columns > 0, _Solver::Rows == _Solver::Columns), 98 | "Static count of rows and columns must be equal (square matrix)"); 99 | CUMAT_STATIC_ASSERT(CUMAT_IMPLIES(_Solver::Rows > 0 && _RHS::Rows > 0, _Solver::Rows == _RHS::Rows), 100 | "Left and right hand side are not compatible"); 101 | 102 | CUMAT_ASSERT(CUMAT_IMPLIES(_Solver::Batches!=1, decomposition.batches() == rhs.batches()) && "batch size of the matrix and the right hand side does not match"); 103 | CUMAT_ASSERT(decomposition.rows() == decomposition.cols() && "matrix must be square"); //TODO: relax for Least-Squares problems 104 | CUMAT_ASSERT(decomposition.cols() == rhs.rows() && "matrix size does not match right-hand-side"); 105 | } 106 | 107 | __host__ __device__ CUMAT_STRONG_INLINE Index rows() const { return rhs_.rows(); } 108 | __host__ __device__ CUMAT_STRONG_INLINE Index cols() const { return rhs_.cols(); } 109 | __host__ __device__ CUMAT_STRONG_INLINE Index batches() const { return rhs_.batches(); } 110 | 111 | const _Solver& getDecomposition() const { return decomposition_; } 112 | const _RHS& getRhs() const { return rhs_; } 113 | }; 114 | 115 | namespace internal 116 | { 117 | //Assignment for decompositions that call SolveOp::evalTo 118 | template 119 | struct Assignment<_Dst, _Src, _Mode, DenseDstTag, SolveSrcTag> 120 | { 121 | static void assign(_Dst& dst, const _Src& src) 122 | { 123 | static_assert(_Mode == AssignmentMode::ASSIGN, "Decompositions only support AssignmentMode::ASSIGN (operator=)"); 124 | src.derived().getDecomposition()._solve_impl(src.derived().getRhs(), dst.derived()); 125 | } 126 | }; 127 | } 128 | 129 | CUMAT_NAMESPACE_END 130 | 131 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/SparseEvaluation.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_SPARSE_EVALUATION__ 2 | #define __CUMAT_SPARSE_EVALUATION__ 3 | 4 | #include "Macros.h" 5 | #include "CwiseOp.h" 6 | #include "SparseMatrix.h" 7 | 8 | CUMAT_NAMESPACE_BEGIN 9 | 10 | namespace internal 11 | { 12 | namespace kernels 13 | { 14 | template 15 | __global__ void CwiseCSREvaluationKernel(dim3 virtual_size, const T expr, M matrix) 16 | { 17 | const int* JA = matrix.getSparsityPattern().JA.data(); 18 | const int* IA = 
matrix.getSparsityPattern().IA.data(); 19 | Index batchStride = matrix.getSparsityPattern().nnz; 20 | //TODO: Profiling, what is the best way to loop over the batches? 21 | CUMAT_KERNEL_2D_LOOP(outer, batch, virtual_size) 22 | int start = JA[outer]; 23 | int end = JA[outer + 1]; 24 | for (int i = start; i < end; ++i) 25 | { 26 | int inner = IA[i]; 27 | Index row = outer; 28 | Index col = inner; 29 | Index idx = i + batch * batchStride; 30 | auto val = expr.coeff(row, col, batch, idx); 31 | internal::CwiseAssignmentHandler::assign(matrix, val, idx); 32 | } 33 | CUMAT_KERNEL_2D_LOOP_END 34 | } 35 | template 36 | __global__ void CwiseCSCEvaluationKernel(dim3 virtual_size, const T expr, M matrix) 37 | { 38 | const int* JA = matrix.getSparsityPattern().JA.data(); 39 | const int* IA = matrix.getSparsityPattern().IA.data(); 40 | Index batchStride = matrix.getSparsityPattern().nnz; 41 | //TODO: Profiling, what is the best way to loop over the batches? 42 | CUMAT_KERNEL_2D_LOOP(outer, batch, virtual_size) 43 | int start = JA[outer]; 44 | int end = JA[outer + 1]; 45 | for (int i = start; i < end; ++i) 46 | { 47 | int inner = IA[i]; 48 | Index row = inner; 49 | Index col = outer; 50 | Index idx = i + batch * batchStride; 51 | auto val = expr.coeff(row, col, batch, idx); 52 | internal::CwiseAssignmentHandler::assign(matrix, val, idx); 53 | } 54 | CUMAT_KERNEL_2D_LOOP_END 55 | } 56 | template 57 | __global__ void CwiseELLPACKEvaluationKernel(dim3 virtual_size, const T expr, M matrix) 58 | { 59 | const SparsityPattern::IndexMatrix indices = matrix.getSparsityPattern().indices; 60 | Index nnzPerRow = matrix.getSparsityPattern().nnzPerRow; 61 | Index rows = matrix.getSparsityPattern().rows; 62 | Index batchStride = rows * nnzPerRow; 63 | //TODO: Profiling, what is the best way to loop over the batches? 
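 //Layout note: in ELLPACK each of the 'rows' rows owns exactly 'nnzPerRow' slots;
 //indices(row, i) stores the column of slot i (negative = padding/unused slot), and the
 //matching value lives at linear index row + i*rows + batch*batchStride, i.e. slots are
 //stored so that neighbouring rows touch neighbouring memory.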
64 | CUMAT_KERNEL_2D_LOOP(row, batch, virtual_size) 65 | for (int i = 0; i < nnzPerRow; ++i) 66 | { 67 | Index col = indices.coeff(row, i, 0, -1); 68 | if (col < 0) continue; //TODO: test if it is faster to continue reading (and set col=0) and discard before assignment 69 | Index idx = row + i * rows + batch * batchStride; 70 | auto val = expr.coeff(row, col, batch, idx); 71 | internal::CwiseAssignmentHandler::assign(matrix, val, idx); 72 | } 73 | CUMAT_KERNEL_2D_LOOP_END 74 | } 75 | } 76 | } 77 | 78 | namespace internal { 79 | 80 | #if CUMAT_NVCC==1 81 | //General assignment for everything that fulfills CwiseSrcTag into SparseDstTag (cwise sparse evaluation) 82 | //The source expression is only evaluated at the non-zero entries of the target SparseMatrix 83 | template 84 | struct Assignment<_Dst, _Src, _Mode, SparseDstTag, CwiseSrcTag> 85 | { 86 | private: 87 | static void assign(_Dst& dst, const _Src& src, std::integral_constant) 88 | { 89 | //here is now the real logic 90 | Context& ctx = Context::current(); 91 | KernelLaunchConfig cfg = ctx.createLaunchConfig2D(static_cast(dst.derived().outerSize()), static_cast(dst.derived().batches()), 92 | kernels::CwiseCSREvaluationKernel); 93 | kernels::CwiseCSREvaluationKernel 94 | <<< cfg.block_count, cfg.thread_per_block, 0, ctx.stream() >>> 95 | (cfg.virtual_size, src.derived(), dst.derived()); 96 | CUMAT_CHECK_ERROR(); 97 | } 98 | static void assign(_Dst& dst, const _Src& src, std::integral_constant) 99 | { 100 | //here is now the real logic 101 | Context& ctx = Context::current(); 102 | KernelLaunchConfig cfg = ctx.createLaunchConfig2D(static_cast(dst.derived().outerSize()), static_cast(dst.derived().batches()), 103 | kernels::CwiseCSCEvaluationKernel); 104 | kernels::CwiseCSCEvaluationKernel 105 | <<< cfg.block_count, cfg.thread_per_block, 0, ctx.stream() >>> 106 | (cfg.virtual_size, src.derived(), dst.derived()); 107 | CUMAT_CHECK_ERROR(); 108 | } 109 | static void assign(_Dst& dst, const _Src& src, std::integral_constant) 110 | { 111 | //here is now the real logic 112 | Context& ctx = Context::current(); 113 | KernelLaunchConfig cfg = ctx.createLaunchConfig2D(static_cast(dst.derived().outerSize()), static_cast(dst.derived().batches()), 114 | kernels::CwiseELLPACKEvaluationKernel); 115 | kernels::CwiseELLPACKEvaluationKernel 116 | <<< cfg.block_count, cfg.thread_per_block, 0, ctx.stream() >>> 117 | (cfg.virtual_size, src.derived(), dst.derived()); 118 | CUMAT_CHECK_ERROR(); 119 | } 120 | 121 | public: 122 | static void assign(_Dst& dst, const _Src& src) 123 | { 124 | typedef typename _Dst::Type DstActual; 125 | typedef typename _Src::Type SrcActual; 126 | CUMAT_PROFILING_INC(EvalCwiseSparse); 127 | CUMAT_PROFILING_INC(EvalAny); 128 | if (dst.size() == 0) return; 129 | CUMAT_ASSERT(src.rows() == dst.rows()); 130 | CUMAT_ASSERT(src.cols() == dst.cols()); 131 | CUMAT_ASSERT(src.batches() == dst.batches()); 132 | 133 | CUMAT_LOG_DEBUG("Evaluate component wise sparse expression " << typeid(src.derived()).name() 134 | << "\n rows=" << src.rows() << ", cols=" << src.cols() << ", batches=" << src.batches()); 135 | assign(dst, src, std::integral_constant()); 136 | CUMAT_LOG_DEBUG("Evaluation done"); 137 | } 138 | }; 139 | #endif 140 | } 141 | 142 | CUMAT_NAMESPACE_END 143 | 144 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/SparseExpressionOp.h: -------------------------------------------------------------------------------- 1 | #ifndef __CUMAT_SPARSE_EXPRESSION_OP__ 2 | 
#define __CUMAT_SPARSE_EXPRESSION_OP__ 3 | 4 | #include "Macros.h" 5 | #include "ForwardDeclarations.h" 6 | #include "MatrixBase.h" 7 | #include "SparseMatrixBase.h" 8 | 9 | CUMAT_NAMESPACE_BEGIN 10 | 11 | namespace internal 12 | { 13 | template 14 | struct traits > 15 | { 16 | using Scalar = typename internal::traits<_Child>::Scalar; 17 | enum 18 | { 19 | SFlags = _SparseFlags, 20 | Flags = internal::traits<_Child>::Flags, 21 | RowsAtCompileTime = internal::traits<_Child>::RowsAtCompileTime, 22 | ColsAtCompileTime = internal::traits<_Child>::ColsAtCompileTime, 23 | BatchesAtCompileTime = internal::traits<_Child>::BatchesAtCompileTime, 24 | AccessFlags = ReadCwise 25 | }; 26 | typedef CwiseSrcTag SrcTag; 27 | typedef DeletedDstTag DstTag; 28 | }; 29 | } 30 | 31 | template 32 | class SparseExpressionOp : public SparseMatrixBase> 33 | { 34 | public: 35 | typedef SparseExpressionOp<_Child, _SparseFlags> Type; 36 | typedef SparseMatrixBase Base; 37 | CUMAT_PUBLIC_API 38 | enum 39 | { 40 | SFlags = _SparseFlags 41 | }; 42 | 43 | using typename Base::StorageIndex; 44 | using Base::rows; 45 | using Base::cols; 46 | using Base::batches; 47 | using Base::nnz; 48 | using Base::size; 49 | 50 | protected: 51 | typedef typename MatrixReadWrapper<_Child, AccessFlags::ReadCwise>::type child_wrapped_t; 52 | const child_wrapped_t child_; 53 | 54 | public: 55 | SparseExpressionOp(const MatrixBase<_Child>& child, const SparsityPattern<_SparseFlags>& sparsityPattern) 56 | : Base(sparsityPattern, child.batches()) 57 | , child_(child.derived()) 58 | { 59 | CUMAT_ASSERT_DIMENSION(child.rows() == sparsityPattern.rows); 60 | CUMAT_ASSERT_DIMENSION(child.cols() == sparsityPattern.cols); 61 | } 62 | 63 | __device__ CUMAT_STRONG_INLINE Scalar coeff(Index row, Index col, Index batch, Index index) const 64 | { 65 | return child_.derived().coeff(row, col, batch, index), row, col, batch; 66 | } 67 | 68 | __device__ CUMAT_STRONG_INLINE Scalar getSparseCoeff(Index row, Index col, Index batch, Index index) const 69 | { 70 | return child_.derived().coeff(row, col, batch, index); 71 | } 72 | }; 73 | 74 | CUMAT_NAMESPACE_END 75 | 76 | 77 | #endif -------------------------------------------------------------------------------- /extensions/include/cuMat/src/SparseExpressionOpPlugin.inl: -------------------------------------------------------------------------------- 1 | //Included in MatrixBase 2 | 3 | /** 4 | * \brief Views this matrix expression as a sparse matrix. 5 | * This enforces the specified sparsity pattern and the coefficients 6 | * of this matrix expression are then only evaluated at these positions. 7 | * 8 | * For now, the only use case is the sparse matrix-vector product. 9 | * For example: 10 | * \code 11 | * SparseMatrix<...> m1, m2; //all initialized with the same sparsity pattern 12 | * VectorXf v1, v2 = ...; 13 | * v1 = (m1 + m2) * v2; 14 | * \endcode 15 | * In this form, this would trigger a dense matrix vector multiplication, which is 16 | * completely unfeasable. This is because the the addition expression 17 | * does not know anything about sparse matrices and the product operation 18 | * then only sees an addition expression on the left hand side. Thus, 19 | * because of lacking knowledge, it has to trigger a dense evaluation. 20 | * 21 | * Improvement: 22 | * \code 23 | * v1 = (m1 + m2).sparseView(m1.getSparsityPattern()) * v2; 24 | * \endcode 25 | * with Format being either CSR or CSC. 26 | * This enforces the sparsity pattern of m1 onto the addition expression. 27 | * Thus the (immediately following!) 
product expression sees this sparse expression 28 | * and can trigger a sparse matrix-vector multiplication. 29 | * But the sparse matrices \c m1 and \c m2 now have to search for their coefficients. 30 | * This is because they don't know that the parent operations (add+multiply) will call 31 | * their coefficients in order of their own entries. In other words, that the \c linear 32 | * index parameter in \ref coeff(Index row, Index col, Index batch, Index linear) matches 33 | * the linear index in their data array. 34 | * (This is a valid assumption if you take transpose and block operations that change the 35 | * access pattern into considerations. Further, this allows e.g. \c m2 to have a different 36 | * sparsity pattern from m1, but only the entries that are included in both are used.) 37 | * 38 | * To overcome the above problem, one has to make one last adjustion: 39 | * \code 40 | * v1 = (m1.direct() + m2.direct()).sparseView(m1.getSparsityPattern()) * v2; 41 | * \endcode 42 | * \ref SparseMatrix::direct() tells the matrix that the linear index in 43 | * \ref coeff(Index row, Index col, Index batch, Index linear) matches the linear index 44 | * in the data array and thus can be used directly. This discards and checks that 45 | * the row, column and batch index actually match. So use this with care 46 | * if you know that access pattern is not changed in the operation. 47 | * (This holds true for all non-broadcasting component wise expressions) 48 | * 49 | * \param pattern the enforced sparsity pattern 50 | * \tparam _SparseFlags the sparse format: CSC or CSR 51 | */ 52 | template 53 | SparseExpressionOp 54 | sparseView(const SparsityPattern<_SparseFlags>& pattern) 55 | { 56 | return SparseExpressionOp(derived(), pattern); 57 | } -------------------------------------------------------------------------------- /extensions/push_pull_inpaint.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | // #include "cuda/SplatSliceGPU.cuh" 6 | 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_DIM(x, d) TORCH_CHECK((x.dim() == (d)), #x " must be a tensor with ", d, " dimensions, but has shape ", x.sizes()) 11 | #define CHECK_SIZE(x, d, s) TORCH_CHECK((x.size(d) == (s)), #x " must have ", s, " entries at dimension ", d, ", but has ", x.size(d), " entries") 12 | 13 | #define MAX_CHANNELS 128 14 | 15 | // CUDA forward declarations 16 | std::tuple 17 | push_pull_inpaint_recursion_cuda( 18 | const torch::Tensor& mask, 19 | const torch::Tensor& data); 20 | 21 | 22 | 23 | 24 | torch::Tensor push_pull_inpaint( 25 | const torch::Tensor& mask, 26 | const torch::Tensor& data) 27 | { 28 | //check input 29 | CHECK_CUDA(mask); 30 | CHECK_CONTIGUOUS(mask); 31 | CHECK_DIM(mask, 3); 32 | int64_t B = mask.size(0); 33 | int64_t H = mask.size(1); 34 | int64_t W = mask.size(2); 35 | 36 | CHECK_CUDA(data); 37 | CHECK_CONTIGUOUS(data); 38 | CHECK_DIM(data, 4); 39 | CHECK_SIZE(data, 0, B); 40 | int64_t C = data.size(1); 41 | CHECK_SIZE(data, 2, H); 42 | CHECK_SIZE(data, 3, W); 43 | TORCH_CHECK(C < MAX_CHANNELS, "Inpainting::fastInpaint only supports up to 128 channels, but got " + std::to_string(C)); 44 | 45 | //inpaint recursivly 46 | return std::get<1>(push_pull_inpaint_recursion_cuda(mask, data)); 47 | } 48 | 49 | 50 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 51 | m.def("push_pull_inpaint", 
&push_pull_inpaint); 52 | } -------------------------------------------------------------------------------- /imgs/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/imgs/dataset.png -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/imgs/teaser.png -------------------------------------------------------------------------------- /inference/assets/blender_vis_base_v26_with_shrinkwrap_full_base.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/inference/assets/blender_vis_base_v26_with_shrinkwrap_full_base.blend -------------------------------------------------------------------------------- /inference/assets/face_landmarker.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/inference/assets/face_landmarker.task -------------------------------------------------------------------------------- /inference_difflocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ./inference_difflocks.py \ 4 | # --img_path=./samples/medium_11.png \ 5 | # --out_path=./outputs_inference/ 6 | 7 | from inference.img2hair import DiffLocksInference 8 | import subprocess 9 | import os 10 | import argparse 11 | 12 | def run(): 13 | 14 | #argparse 15 | parser = argparse.ArgumentParser(description='Get the weights of each dimensions after training a strand VAE') 16 | parser.add_argument('--strand_checkpoint_path', default="./checkpoints/strand_vae/strand_codec.pt", type=str, help='Path to the strandVAE checkpoint') 17 | parser.add_argument('--difflocks_checkpoint_path', default="./checkpoints/difflocks_diffusion/scalp_v9_40k_06730000.pth", type=str, help='Path to the difflocks checkpoint') 18 | parser.add_argument('--difflocks_config_path', default="./configs/config_scalp_texture_conditional.json", type=str, help='Path to the difflocks config') 19 | parser.add_argument('--rgb2mat_checkpoint_path', default="./checkpoints/rgb2material/rgb2material.pt", type=str, help='Path to the rgb2material checkpoint') 20 | parser.add_argument('--blender_path', type=str, default="", help='Path to the blender executable') 21 | parser.add_argument('--blender_nr_threads', default=8, type=int, help='Number of threads for blender to use') 22 | parser.add_argument('--blender_strands_subsample', default=1.0, type=float, help='Amount of subsample of the strands(1.0=full strands, 0.5=half strands)') 23 | parser.add_argument('--blender_vertex_subsample', default=1.0, type=float, help='Amount of subsample of the vertices(1.0=all vertex, 0.5=half number of vertices per strand)') 24 | parser.add_argument('--alembic_resolution', default=7, type=int, help='Resolution of the exported alembic') 25 | parser.add_argument('--export_alembic', action='store_true', help='weather to export alembic or not') 26 | parser.add_argument('--do_shrinkwrap', action='store_true', help='applies a shrinkwrap modifier in blender that pushes the strands away from the scalp so they 
dont pass through the head') 27 | parser.add_argument('--img_path', type=str, required=True, help='Path to the image to do inference on') 28 | parser.add_argument('--out_path', type=str, required=True, help='Path to the image to do inference on') 29 | args = parser.parse_args() 30 | 31 | print("args is", args) 32 | 33 | out_path="./outputs_inference/" 34 | 35 | difflocks= DiffLocksInference(args.strand_checkpoint_path, args.difflocks_config_path, args.difflocks_checkpoint_path, args.rgb2mat_checkpoint_path) 36 | 37 | 38 | #run---- 39 | # img_path="./samples/medium_11.png" 40 | strand_points_world, hair_material_dict=difflocks.file2hair(args.img_path, args.out_path) 41 | print("hair_material_dict",hair_material_dict) 42 | 43 | 44 | #create blender file and optionally an alembic file 45 | if args.blender_path!="": 46 | cmd=[args.blender_path, "-t", str(args.blender_nr_threads), "--background", "--python", "./inference/npz2blender.py", "--", "--input_npz", os.path.join(out_path,"difflocks_output_strands.npz"), "--out_path", args.out_path, "--strands_subsample", str(args.blender_strands_subsample), "--vertex_subsample", str(args.blender_vertex_subsample), "--alembic_resolution", str(args.alembic_resolution) ] 47 | if args.do_shrinkwrap: 48 | cmd.append("--shrinkwrap") 49 | if args.export_alembic: 50 | cmd.append("--export_alembic") 51 | subprocess.run(cmd, capture_output=False) 52 | 53 | print("Finished writing to ", args.out_path) 54 | 55 | if __name__ == '__main__': 56 | 57 | run() 58 | -------------------------------------------------------------------------------- /k_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . import config, layers, models, sampling, utils 2 | from .layers import Denoiser 3 | -------------------------------------------------------------------------------- /k_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import flops 2 | from .flags import checkpointing, get_checkpointing 3 | # from .image_v1 import ImageDenoiserModelV1 4 | # from .image_transformer_v2 import ImageTransformerDenoiserModelV2 5 | from .image_transformer_v2_conditional import ImageTransformerDenoiserModelV2Conditional 6 | 7 | -------------------------------------------------------------------------------- /k_diffusion/models/attention.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from inspect import isfunction 3 | import math 4 | from k_diffusion.models.modules import AxialRoPE, apply_rotary_emb_ 5 | from k_diffusion.models.modules import AdaRMSNorm, FeedForwardBlock, LinearGEGLU, RMSNorm, apply_wd, use_flash_2, zero_init 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn, einsum 9 | from einops import rearrange, repeat 10 | from typing import Optional, Any 11 | 12 | import flash_attn 13 | 14 | 15 | 16 | try: 17 | import xformers 18 | import xformers.ops 19 | 20 | XFORMERS_IS_AVAILBLE = True 21 | except: 22 | XFORMERS_IS_AVAILBLE = False 23 | 24 | # CrossAttn precision handling 25 | import os 26 | 27 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 28 | 29 | 30 | def zero_module(module): 31 | """ 32 | Zero out the parameters of a module and return it. 
33 | """ 34 | for p in module.parameters(): 35 | p.detach().zero_() 36 | return module 37 | 38 | 39 | def scale_for_cosine_sim(q, k, scale, eps): 40 | dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32)) 41 | sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True) 42 | sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True) 43 | sqrt_scale = torch.sqrt(scale.to(dtype)) 44 | scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps) 45 | scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps) 46 | return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype) 47 | 48 | def scale_for_cosine_sim_qkv(qkv, scale, eps): 49 | q, k, v = qkv.unbind(2) 50 | q, k = scale_for_cosine_sim(q, k, scale[:, None], eps) 51 | return torch.stack((q, k, v), dim=2) 52 | 53 | def scale_for_cosine_sim_single(q, scale, eps): 54 | dtype = reduce(torch.promote_types, (q.dtype, scale.dtype, torch.float32)) 55 | sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True) 56 | sqrt_scale = torch.sqrt(scale.to(dtype)) 57 | scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps) 58 | return q * scale_q.to(q.dtype) 59 | 60 | class SpatialTransformerSimpleV2(nn.Module): 61 | """ 62 | Transformer block for image-like data. 63 | First, project the input (aka embedding) 64 | and reshape to b, t, d. 65 | Then apply standard transformer action. 66 | Finally, reshape to image 67 | NEW: use_linear for more efficiency instead of the 1x1 convs 68 | """ 69 | 70 | def __init__(self, in_channels, n_heads, d_head, 71 | global_cond_dim, 72 | do_self_attention=True, 73 | dropout=0., 74 | context_dim=None, 75 | ): 76 | super().__init__() 77 | 78 | 79 | self.in_channels = in_channels 80 | inner_dim = n_heads * d_head 81 | self.n_heads = n_heads 82 | self.d_head = d_head 83 | self.do_self_attention=do_self_attention 84 | 85 | 86 | self.x_in_norm = AdaRMSNorm(in_channels, global_cond_dim) 87 | 88 | #x to qkv 89 | if self.do_self_attention: 90 | self.x_qkv_proj = apply_wd(torch.nn.Linear(in_channels, inner_dim * 3, bias=False)) 91 | else: 92 | self.x_q_proj = apply_wd(torch.nn.Linear(in_channels, inner_dim, bias=False)) 93 | self.x_scale = nn.Parameter(torch.full([self.n_heads], 10.0)) 94 | 95 | self.x_pos_emb = AxialRoPE(d_head // 2, self.n_heads) 96 | 97 | 98 | #context to kv 99 | self.cond_kv_proj = apply_wd(torch.nn.Linear(context_dim, inner_dim * 2, bias=False)) 100 | self.cond_scale = nn.Parameter(torch.full([self.n_heads], 10.0)) 101 | self.cond_pos_emb = AxialRoPE(d_head // 2, self.n_heads) 102 | 103 | self.ff = FeedForwardBlock(in_channels, d_ff=int(in_channels*2), cond_features=global_cond_dim, dropout=dropout) 104 | 105 | self.dropout = nn.Dropout(dropout) 106 | self.proj_out = apply_wd(zero_module(nn.Linear(in_channels, inner_dim))) 107 | 108 | 109 | def forward(self, x, pos, global_cond, context=None, context_pos=None): 110 | b, c, h, w = x.shape 111 | x_in = x 112 | x = rearrange(x, 'b c h w -> b h w c') 113 | context = rearrange(context, 'b c h w -> b h w c') 114 | x = self.x_in_norm(x, global_cond) 115 | 116 | if self.do_self_attention: 117 | #x to qkv 118 | x_qkv = self.x_qkv_proj(x) 119 | pos = rearrange(pos, "... h w e -> ... 
(h w) e").to(x_qkv.dtype) 120 | x_theta = self.x_pos_emb(pos) 121 | if use_flash_2(x_qkv): 122 | x_qkv = rearrange(x_qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) 123 | x_qkv = scale_for_cosine_sim_qkv(x_qkv, self.x_scale, 1e-6) 124 | x_theta = torch.stack((x_theta, x_theta, torch.zeros_like(x_theta)), dim=-3) 125 | x_qkv = apply_rotary_emb_(x_qkv, x_theta) 126 | x_q, x_k, x_v = x_qkv.chunk(3,dim=-3) 127 | else: 128 | print("we couldnt run flash2, maybe it's not installed or the input si not bfloat16") 129 | exit(1) 130 | else: 131 | #x to q 132 | x_q = self.x_q_proj(x) 133 | pos = rearrange(pos, "... h w e -> ... (h w) e").to(x_q.dtype) 134 | x_theta = self.x_pos_emb(pos) 135 | if use_flash_2(x_q): 136 | x_q = rearrange(x_q, "n h w (nh e) -> n (h w) nh e", e=self.d_head) 137 | x_q = scale_for_cosine_sim_single(x_q, self.x_scale[:, None], 1e-6) 138 | x_q=x_q.unsqueeze(2) #n (h w) 1 nh e 139 | x_theta=x_theta.unsqueeze(1) 140 | x_q = apply_rotary_emb_(x_q, x_theta) 141 | else: 142 | print("we couldnt run flash2, maybe it's not installed or the input si not bfloat16") 143 | exit(1) 144 | 145 | 146 | #context to kv 147 | cond_kv = self.cond_kv_proj(context) 148 | # print("cond_kv init",cond_kv.shape) 149 | context_pos = rearrange(context_pos, "... h w e -> ... (h w) e").to(cond_kv.dtype) 150 | cond_theta = self.cond_pos_emb(context_pos) 151 | if use_flash_2(cond_kv): 152 | cond_kv = rearrange(cond_kv, "n h w (t nh e) -> n (h w) t nh e", t=2, e=self.d_head) 153 | cond_k, cond_v = cond_kv.unbind(2) # makes each n (h w) nh e 154 | cond_k = scale_for_cosine_sim_single(cond_k, self.cond_scale[:, None], 1e-6) 155 | cond_k=cond_k.unsqueeze(2) #n (h w) 1 nh e 156 | cond_theta=cond_theta.unsqueeze(1) 157 | cond_k = apply_rotary_emb_(cond_k, cond_theta) 158 | cond_k=cond_k.squeeze(2) 159 | else: 160 | print("we couldnt run flash2, maybe it's not installed or the input si not bfloat16") 161 | exit(1) 162 | 163 | #doing self attention by concating K and V between X and cond 164 | if self.do_self_attention: 165 | k = torch.cat([x_k, cond_k.unsqueeze(2)], dim=1) 166 | v = torch.cat([x_v, cond_v.unsqueeze(2)], dim=1) 167 | else: 168 | # print("not doing self attention") 169 | k=cond_k.unsqueeze(2) 170 | v=cond_v.unsqueeze(2) 171 | q=x_q 172 | 173 | 174 | #rearange a bit 175 | q=q.squeeze(2) 176 | kv=torch.cat([k,v],2) 177 | # print("final q before giving to flash",q.shape) 178 | # print("final kv before giving to flash",kv.shape) 179 | 180 | x = flash_attn.flash_attn_kvpacked_func(q, kv, softmax_scale=1.0) 181 | 182 | x = rearrange(x, 'b (h w) nh e -> b (h w) (nh e)', nh=self.n_heads, e=self.d_head, h=h, w=w) 183 | 184 | #last ff 185 | x = self.dropout(x) 186 | x = self.proj_out(x) 187 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w, c=c) 188 | 189 | x= x + x_in 190 | 191 | #attention part finished------ 192 | 193 | #linear feed forward 194 | x = rearrange(x, 'b c h w -> b h w c', h=h, w=w, c=c) 195 | 196 | # print("x before ff is ", x.shape) 197 | x = self.ff(x, global_cond) 198 | 199 | 200 | x = rearrange(x, 'b h w c -> b c h w', h=h, w=w, c=c) 201 | 202 | return x 203 | 204 | 205 | -------------------------------------------------------------------------------- /k_diffusion/models/axial_rope.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch._dynamo 5 | from torch import nn 6 | 7 | from . 
import flags 8 | 9 | if flags.get_use_compile(): 10 | torch._dynamo.config.suppress_errors = True 11 | 12 | 13 | def rotate_half(x): 14 | x1, x2 = x[..., 0::2], x[..., 1::2] 15 | x = torch.stack((-x2, x1), dim=-1) 16 | *shape, d, r = x.shape 17 | return x.view(*shape, d * r) 18 | 19 | 20 | @flags.compile_wrap 21 | def apply_rotary_emb(freqs, t, start_index=0, scale=1.0): 22 | freqs = freqs.to(t) 23 | rot_dim = freqs.shape[-1] 24 | end_index = start_index + rot_dim 25 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 26 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 27 | t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) 28 | return torch.cat((t_left, t, t_right), dim=-1) 29 | 30 | 31 | def centers(start, stop, num, dtype=None, device=None): 32 | edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) 33 | return (edges[:-1] + edges[1:]) / 2 34 | 35 | 36 | def make_grid(h_pos, w_pos): 37 | grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1) 38 | h, w, d = grid.shape 39 | return grid.view(h * w, d) 40 | 41 | 42 | def bounding_box(h, w, pixel_aspect_ratio=1.0): 43 | # Adjusted dimensions 44 | w_adj = w 45 | h_adj = h * pixel_aspect_ratio 46 | 47 | # Adjusted aspect ratio 48 | ar_adj = w_adj / h_adj 49 | 50 | # Determine bounding box based on the adjusted aspect ratio 51 | y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0 52 | if ar_adj > 1: 53 | y_min, y_max = -1 / ar_adj, 1 / ar_adj 54 | elif ar_adj < 1: 55 | x_min, x_max = -ar_adj, ar_adj 56 | 57 | return y_min, y_max, x_min, x_max 58 | 59 | 60 | def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None): 61 | y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio) 62 | if align_corners: 63 | h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device) 64 | w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device) 65 | else: 66 | h_pos = centers(y_min, y_max, h, dtype=dtype, device=device) 67 | w_pos = centers(x_min, x_max, w, dtype=dtype, device=device) 68 | return make_grid(h_pos, w_pos) 69 | 70 | 71 | def freqs_pixel(max_freq=10.0): 72 | def init(shape): 73 | freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi 74 | return freqs.log().expand(shape) 75 | return init 76 | 77 | 78 | def freqs_pixel_log(max_freq=10.0): 79 | def init(shape): 80 | log_min = math.log(math.pi) 81 | log_max = math.log(max_freq * math.pi / 2) 82 | return torch.linspace(log_min, log_max, shape[-1]).expand(shape) 83 | return init 84 | 85 | 86 | class AxialRoPE(nn.Module): 87 | def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)): 88 | super().__init__() 89 | self.n_heads = n_heads 90 | self.start_index = start_index 91 | log_freqs = freqs_init((n_heads, dim // 4)) 92 | self.freqs_h = nn.Parameter(log_freqs.clone()) 93 | self.freqs_w = nn.Parameter(log_freqs.clone()) 94 | 95 | def extra_repr(self): 96 | dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2 97 | return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}" 98 | 99 | def get_freqs(self, pos): 100 | if pos.shape[-1] != 2: 101 | raise ValueError("input shape must be (..., 2)") 102 | freqs_h = pos[..., None, None, 0] * self.freqs_h.exp() 103 | freqs_w = pos[..., None, None, 1] * self.freqs_w.exp() 104 | freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1) 105 | 
return freqs.transpose(-2, -3) 106 | 107 | def forward(self, x, pos): 108 | freqs = self.get_freqs(pos) 109 | return apply_rotary_emb(freqs, x, self.start_index) 110 | -------------------------------------------------------------------------------- /k_diffusion/models/flags.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from functools import update_wrapper 3 | import os 4 | import threading 5 | 6 | import torch 7 | 8 | 9 | def get_use_compile(): 10 | return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1" 11 | 12 | 13 | def get_use_flash_attention_2(): 14 | return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1" 15 | 16 | 17 | state = threading.local() 18 | state.checkpointing = False 19 | 20 | 21 | @contextmanager 22 | def checkpointing(enable=True): 23 | try: 24 | old_checkpointing, state.checkpointing = state.checkpointing, enable 25 | yield 26 | finally: 27 | state.checkpointing = old_checkpointing 28 | 29 | 30 | def get_checkpointing(): 31 | return getattr(state, "checkpointing", False) 32 | 33 | 34 | class compile_wrap: 35 | def __init__(self, function, *args, **kwargs): 36 | self.function = function 37 | self.args = args 38 | self.kwargs = kwargs 39 | self._compiled_function = None 40 | update_wrapper(self, function) 41 | 42 | @property 43 | def compiled_function(self): 44 | if self._compiled_function is not None: 45 | return self._compiled_function 46 | if get_use_compile(): 47 | try: 48 | self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs) 49 | except RuntimeError: 50 | self._compiled_function = self.function 51 | else: 52 | self._compiled_function = self.function 53 | return self._compiled_function 54 | 55 | def __call__(self, *args, **kwargs): 56 | return self.compiled_function(*args, **kwargs) 57 | -------------------------------------------------------------------------------- /k_diffusion/models/flops.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import math 3 | import threading 4 | 5 | 6 | state = threading.local() 7 | state.flop_counter = None 8 | 9 | 10 | @contextmanager 11 | def flop_counter(enable=True): 12 | try: 13 | old_flop_counter = state.flop_counter 14 | state.flop_counter = FlopCounter() if enable else None 15 | yield state.flop_counter 16 | finally: 17 | state.flop_counter = old_flop_counter 18 | 19 | 20 | class FlopCounter: 21 | def __init__(self): 22 | self.ops = [] 23 | 24 | def op(self, op, *args, **kwargs): 25 | self.ops.append((op, args, kwargs)) 26 | 27 | @property 28 | def flops(self): 29 | flops = 0 30 | for op, args, kwargs in self.ops: 31 | flops += op(*args, **kwargs) 32 | return flops 33 | 34 | 35 | def op(op, *args, **kwargs): 36 | if getattr(state, "flop_counter", None): 37 | state.flop_counter.op(op, *args, **kwargs) 38 | 39 | 40 | def op_linear(x, weight): 41 | return math.prod(x) * weight[0] 42 | 43 | 44 | def op_attention(q, k, v): 45 | *b, s_q, d_q = q 46 | *b, s_k, d_k = k 47 | *b, s_v, d_v = v 48 | return math.prod(b) * s_q * s_k * (d_q + d_v) 49 | 50 | 51 | def op_natten(q, k, v, kernel_size): 52 | *q_rest, d_q = q 53 | *_, d_v = v 54 | return math.prod(q_rest) * (d_q + d_v) * kernel_size**2 55 | -------------------------------------------------------------------------------- /losses/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from losses.losses import 
compute_loss_dir_l1, compute_loss_curv_l1, compute_loss_l1, compute_loss_l2, compute_loss_kl 3 | import numpy as np 4 | 5 | class StrandVAELoss(torch.nn.Module): 6 | def __init__( 7 | self, 8 | eps: float = 1e-8, 9 | ): 10 | super().__init__() 11 | 12 | self.eps = eps 13 | 14 | def forward(self, phase, gt_dict, pred_dict, latent_dict, hyperparams): 15 | gt_strands= gt_dict["strand_positions"] 16 | pred_hair_strands= pred_dict["strand_positions"] 17 | 18 | # w_kl=map_range_val(phase.epoch_nr, hyperparams.loss_kl_anneal_epoch_nr_start, hyperparams.loss_kl_anneal_epoch_nr_finish, 0.0, hyperparams.loss_kl_weight) 19 | 20 | 21 | loss_dict = {} 22 | 23 | #we always predict the xyz loss just because it's fast 24 | # loss_l2 = compute_loss_l2(gt_strands, pred_hair_strands) 25 | loss_pos = compute_loss_l1(gt_strands, pred_hair_strands) 26 | loss_dir = compute_loss_dir_l1(gt_strands, pred_hair_strands) 27 | loss_curv = compute_loss_curv_l1(gt_strands, pred_hair_strands) 28 | loss_kl = 0.0 29 | if "z_logstd" in latent_dict: 30 | loss_kl = compute_loss_kl(latent_dict["z_mean"], latent_dict["z_logstd"]) 31 | loss = loss_pos*hyperparams.loss_pos_weight + loss_dir*hyperparams.loss_dir_weight + loss_curv*hyperparams.loss_curv_weight + loss_kl*hyperparams.loss_kl_weight 32 | # loss = loss_pos*hyperparams.loss_pos_weight + loss_dir*hyperparams.loss_dir_weight + loss_curv*hyperparams.loss_curv_weight + loss_kl*w_kl 33 | loss_dict['loss'] = loss 34 | loss_dict['loss_pos'] = loss_pos 35 | loss_dict['loss_dir'] = loss_dir 36 | loss_dict['loss_curv'] = loss_curv 37 | loss_dict['loss_kl'] = loss_kl 38 | 39 | 40 | return loss_dict 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /losses/loss_utils.py: -------------------------------------------------------------------------------- 1 | import scipy.signal 2 | import torch 3 | 4 | def apply_reduction(losses, reduction="none"): 5 | """Apply reduction to collection of losses.""" 6 | if reduction == "mean": 7 | losses = losses.mean() 8 | elif reduction == "sum": 9 | losses = losses.sum() 10 | return losses 11 | 12 | -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch 3 | import numpy as np 4 | from typing import List, Any 5 | from losses.loss_utils import apply_reduction 6 | from utils.general_util import get_window 7 | from utils.strand_util import compute_dirs, compute_curv 8 | 9 | 10 | def compute_loss_l2(gt_hair_strands, pred_hair_strands): 11 | loss_l2 = ((pred_hair_strands - gt_hair_strands) ** 2).mean() 12 | 13 | return loss_l2 14 | 15 | def compute_loss_l1(gt_hair_strands, pred_hair_strands): 16 | loss_l1 = torch.nn.functional.l1_loss(pred_hair_strands, gt_hair_strands).mean() 17 | 18 | return loss_l1 19 | 20 | def compute_loss_dir_dot(gt_hair_strands, pred_hair_strands): 21 | nr_verts_per_strand=256 22 | 23 | pred_points = pred_hair_strands.view(-1, nr_verts_per_strand, 3) 24 | gt_hair_strands=gt_hair_strands.view(-1, nr_verts_per_strand, 3) 25 | 26 | # get also a loss for the direciton, we need to compute the direction 27 | pred_deltas = compute_dirs(pred_points) 28 | pred_deltas = pred_deltas.view(-1, 3) 29 | pred_deltas = torch.nn.functional.normalize(pred_deltas, dim=-1) 30 | 31 | gt_dir = compute_dirs(gt_hair_strands) 32 | gt_dir = gt_dir.view(-1, 3) 33 | gt_dir = torch.nn.functional.normalize(gt_dir, dim=-1) 34 | # loss_dir = 
self.cosine_embed_loss(pred_deltas, gt_dir, torch.ones(gt_dir.shape[0]).cuda()) 35 | 36 | dot = torch.sum(pred_deltas * gt_dir, dim=-1) 37 | 38 | loss_dir = (1.0 - dot).mean() 39 | 40 | return loss_dir 41 | 42 | def compute_loss_dir_l1(gt_hair_strands, pred_hair_strands): 43 | nr_verts_per_strand=256 44 | 45 | pred_points = pred_hair_strands.view(-1, nr_verts_per_strand, 3) 46 | gt_hair_strands=gt_hair_strands.view(-1, nr_verts_per_strand, 3) 47 | 48 | # get also a loss for the direciton, we need to compute the direction 49 | pred_deltas = compute_dirs(pred_points) 50 | pred_deltas = pred_deltas.view(-1, 3) 51 | pred_deltas = pred_deltas*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss 52 | 53 | gt_dir = compute_dirs(gt_hair_strands) 54 | gt_dir = gt_dir.view(-1, 3) 55 | gt_dir = gt_dir*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss 56 | # loss_dir = self.cosine_embed_loss(pred_deltas, gt_dir, torch.ones(gt_dir.shape[0]).cuda()) 57 | 58 | loss_l1 = torch.nn.functional.l1_loss(pred_deltas, gt_dir).mean() 59 | return loss_l1 60 | 61 | def compute_loss_curv_l1(gt_hair_strands, pred_hair_strands): 62 | nr_verts_per_strand=256 63 | 64 | pred_points = pred_hair_strands.view(-1, nr_verts_per_strand, 3) 65 | gt_hair_strands=gt_hair_strands.view(-1, nr_verts_per_strand, 3) 66 | 67 | # get also a loss for the direciton, we need to compute the direction 68 | pred_deltas = compute_dirs(pred_points) 69 | pred_curvs = compute_curv(pred_deltas) 70 | pred_curvs = pred_curvs.view(-1, 3) 71 | pred_curvs = pred_curvs*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss 72 | 73 | gt_dir = compute_dirs(gt_hair_strands) 74 | gt_curvs = compute_curv(gt_dir) 75 | gt_curvs = gt_curvs.view(-1, 3) 76 | gt_curvs = gt_curvs*nr_verts_per_strand #Just because the deltas are very tiny and I want them in a nicer range for the loss 77 | # loss_dir = self.cosine_embed_loss(pred_deltas, gt_dir, torch.ones(gt_dir.shape[0]).cuda()) 78 | 79 | loss_l1 = torch.nn.functional.l1_loss(pred_curvs, gt_curvs).mean() 80 | return loss_l1 81 | 82 | def compute_loss_kl(mean, logstd): 83 | #get input data 84 | kl_loss = 0 85 | 86 | #kl loss 87 | 88 | kl_shape = kl( mean, logstd) 89 | # free bits from IAF-VAE. 
so that if the KL drops below a certan value, then we stop reducing the KL 90 | kl_shape = torch.clamp(kl_shape, min=0.25) 91 | 92 | kl_loss = kl_shape.mean() 93 | 94 | return kl_loss 95 | 96 | def kl(mean, logstd): 97 | kl = (-0.5 - logstd + 0.5 * mean ** 2 + 0.5 * torch.exp(2 * logstd)) 98 | return kl 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /models/rgb_to_material.py: -------------------------------------------------------------------------------- 1 | from modules.networks import kaiming_init 2 | import torch 3 | from torch import nn 4 | import os 5 | import json 6 | import numpy as np 7 | 8 | 9 | class RGB2MaterialModel(nn.Module): 10 | 11 | def __init__(self, input_dim, out_dim, hidden_dim): 12 | super().__init__() 13 | 14 | self.out_dim=out_dim 15 | 16 | 17 | 18 | #attempt 2 19 | self.dino2conf=nn.Sequential( 20 | nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=1, padding=0, bias=True), 21 | nn.SiLU(), 22 | nn.Conv2d(in_channels=hidden_dim, out_channels=1, kernel_size=1, padding=0, bias=True), 23 | nn.Sigmoid() 24 | ) 25 | self.dino2mat=nn.Sequential( 26 | nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=1, padding=0, bias=True), 27 | nn.SiLU(), 28 | nn.Conv2d(in_channels=hidden_dim, out_channels=out_dim, kernel_size=1, padding=0, bias=True), 29 | nn.Sigmoid() 30 | ) 31 | 32 | 33 | 34 | 35 | self.apply(lambda x: kaiming_init(x, False, nonlinearity="silu")) 36 | 37 | def save(self, root_folder, experiment_name, hyperparams, iter_nr, info=None): 38 | name=str(iter_nr) 39 | if info is not None: 40 | name+="_"+info 41 | models_path = os.path.join(root_folder, experiment_name, name, "models") 42 | if not os.path.exists(models_path): 43 | os.makedirs(models_path, exist_ok=True) 44 | torch.save(self.state_dict(), os.path.join(models_path, "rgb2material.pt")) 45 | 46 | hyperparams_params_path=os.path.join(models_path, "hyperparams.json") 47 | with open(hyperparams_params_path, 'w', encoding='utf-8') as f: 48 | json.dump(vars(hyperparams), f, ensure_ascii=False, indent=4) 49 | 50 | 51 | def forward(self, batch_dict): 52 | 53 | x=batch_dict["dinov2_latents"] #BCHW (1,1024,55,55) 54 | 55 | 56 | #attempt 2, each patch predicts a confidence and a material, then we average all the materials across all patches, weighted by confidence 57 | conf=self.dino2conf(x) 58 | mat=self.dino2mat(x) 59 | 60 | 61 | #average the mat across the pixels 62 | avg_mat = (mat*conf).sum((2,3)) / (conf.sum((2,3)) +1e-6) #sum across all H and W dimensions 63 | x=avg_mat 64 | 65 | 66 | #split the material in parameters, at least the ones that are meaningfull and actually have a loss applied to them 67 | melanin=x[:,3] 68 | redness=x[:,4] 69 | root_darkness_start=x[:,8] 70 | root_darkness_end=x[:,9] 71 | root_darkness_strength=x[:,10] 72 | 73 | 74 | 75 | pred_dict={} 76 | pred_dict["material"]=x 77 | pred_dict["melanin"]=melanin 78 | pred_dict["redness"]=redness 79 | pred_dict["root_darkness_start"]=root_darkness_start 80 | pred_dict["root_darkness_end"]=root_darkness_end 81 | pred_dict["root_darkness_strength"]=root_darkness_strength 82 | 83 | 84 | return pred_dict -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/modules/__init__.py 
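The confidence-weighted pooling in rgb_to_material.py above can be read in isolation: every DINOv2 patch predicts a material vector and a confidence in [0, 1] (both heads end in a Sigmoid), and the final material is the confidence-weighted mean over all patches. Below is a small self-contained sketch of just that pooling step, with made-up shapes (batch 2, 11 material channels, a 55x55 patch grid):

import torch

conf = torch.rand(2, 1, 55, 55)   # per-patch confidence in [0, 1]
mat = torch.rand(2, 11, 55, 55)   # per-patch material prediction in [0, 1]

# confidence-weighted mean over the spatial dimensions, as in RGB2MaterialModel.forward
avg_mat = (mat * conf).sum((2, 3)) / (conf.sum((2, 3)) + 1e-6)
print(avg_mat.shape)  # torch.Size([2, 11])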
-------------------------------------------------------------------------------- /modules/edm2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class MPFourier(torch.nn.Module): 6 | def __init__(self, num_channels, bandwidth=1): 7 | super().__init__() 8 | self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth) 9 | self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels)) 10 | 11 | def forward(self, x): 12 | y = x.to(torch.float32) 13 | y = y.ger(self.freqs.to(torch.float32)) 14 | y = y + self.phases.to(torch.float32) 15 | y = y.cos() * np.sqrt(2) 16 | return y.to(x.dtype) -------------------------------------------------------------------------------- /modules/networks.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | class LinearDummy(torch.nn.Module): 9 | def __init__(self, in_channels, out_channels, bias=True): 10 | super().__init__() 11 | self.out_channels = out_channels 12 | self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels)) 13 | self.bias=None 14 | if bias: 15 | self.bias = torch.nn.Parameter(torch.zeros(out_channels)) 16 | 17 | def forward(self, x, gain=1): 18 | return F.linear(x, self.weight, self.bias) 19 | 20 | 21 | class BlockSiren(torch.nn.Module): 22 | def __init__(self, in_channels, out_channels, use_bias=True, is_first_layer=False, scale_init=1.0): 23 | super(BlockSiren, self).__init__() 24 | self.use_bias = use_bias 25 | # self.activ = activ 26 | self.is_first_layer = is_first_layer 27 | self.scale_init = scale_init 28 | # self.freq_scaling = scale_init 29 | 30 | self.conv = torch.nn.Linear(in_channels, out_channels, bias=self.use_bias).cuda() 31 | 32 | with torch.no_grad(): 33 | 34 | #following the official implementation from https://github.com/vsitzmann/siren/blob/master/explore_siren.ipynb 35 | if self.is_first_layer: 36 | self.conv.weight.uniform_(-1 / in_channels, 37 | 1 / in_channels) 38 | else: 39 | self.conv.weight.uniform_(-np.sqrt(6 / in_channels) / self.scale_init, 40 | np.sqrt(6 / in_channels) / self.scale_init) 41 | 42 | self.conv.bias.data.zero_() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | 47 | x = self.scale_init * x 48 | 49 | x = torch.sin(x) 50 | 51 | return x 52 | 53 | 54 | class LinearWN_v2(torch.nn.Linear): 55 | def __init__(self, in_features, out_features, bias=True): 56 | super(LinearWN_v2, self).__init__(in_features, out_features, bias) 57 | self.g = torch.nn.Parameter(torch.ones(out_features)) 58 | 59 | 60 | self.in_features=in_features 61 | self.out_features=out_features 62 | 63 | def forward(self, input): 64 | w= torch._weight_norm(self.weight, self.g, 0) 65 | out=F.linear(input, w, self.bias) 66 | return out 67 | 68 | 69 | 70 | class Conv1dWN_v2(torch.nn.Conv1d): 71 | def __init__( 72 | self, 73 | in_channels, 74 | out_channels, 75 | kernel_size, 76 | stride=1, 77 | padding=0, 78 | padding_mode="zeros", 79 | dilation=1, 80 | groups=1, 81 | ): 82 | super(Conv1dWN_v2, self).__init__( 83 | in_channels, 84 | out_channels, 85 | kernel_size, 86 | stride, 87 | padding, 88 | dilation, 89 | groups, 90 | True, 91 | ) 92 | self.g = torch.nn.Parameter(torch.ones(out_channels)) 93 | 94 | self.padding_amount=padding 95 | self.padding_mode=padding_mode 96 | 97 | def forward(self, x): 98 | w= torch._weight_norm(self.weight, self.g, 0) 99 | 100 | 101 | 102 
| if self.padding_mode != 'zeros': 103 | x= F.pad(x, (self.padding_amount,self.padding_amount), mode=self.padding_mode) 104 | 105 | x= F.conv1d( 106 | x, 107 | w, 108 | bias=self.bias, 109 | stride=self.stride, 110 | padding=0, 111 | dilation=self.dilation, 112 | groups=self.groups, 113 | ) 114 | # print("x after conv", x.shape) 115 | return x 116 | else: 117 | 118 | return F.conv1d( 119 | x, 120 | w, 121 | bias=self.bias, 122 | stride=self.stride, 123 | padding=self.padding, 124 | dilation=self.dilation, 125 | groups=self.groups, 126 | ) 127 | 128 | 129 | def kaiming_init(m, is_linear, nonlinearity="silu"): 130 | # gain = math.sqrt(2.0 / (1.0 + alpha**2)) 131 | 132 | # gain=np.sqrt(10.5) 133 | if nonlinearity=="silu": 134 | gain=np.sqrt(2.3) #works fine with silu 135 | elif nonlinearity=="relu": 136 | gain=np.sqrt(2) #works fine with silu 137 | # gain=np.sqrt(2.15) 138 | # gain=np.sqrt(0.92) #for mpsilu 139 | scale=1.0 140 | 141 | if is_linear: 142 | gain = 1 143 | 144 | # print("effective scale", gain*scale) 145 | 146 | # print("m is ",m) 147 | # help(m) 148 | 149 | if isinstance(m, torch.nn.Conv2d): 150 | ksize = m.kernel_size[0] * m.kernel_size[1] 151 | n1 = m.in_channels 152 | n2 = m.out_channels 153 | 154 | # std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) 155 | std = gain * math.sqrt(n1 * ksize) 156 | elif isinstance(m, torch.nn.ConvTranspose2d): 157 | ksize = m.kernel_size[0] * m.kernel_size[1] // 4 158 | n1 = m.in_channels 159 | n2 = m.out_channels 160 | 161 | # std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) 162 | std = gain * math.sqrt(n1 * ksize) 163 | elif isinstance(m, torch.nn.Conv1d): 164 | ksize = m.kernel_size[0] 165 | n1 = m.in_channels 166 | n2 = m.out_channels 167 | # std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) 168 | std = gain * math.sqrt(n1 * ksize) 169 | elif isinstance(m, torch.nn.Linear): 170 | n1 = m.in_features 171 | n2 = m.out_features 172 | 173 | # std = gain * math.sqrt(2.0 / (n1 + n2)) 174 | std = gain * math.sqrt(n1) 175 | # elif isinstance(m, SMConv1d_v2): 176 | # print("SMConv1d_v2") 177 | # exit() 178 | else: 179 | return 180 | 181 | # help(m) 182 | 183 | # print("m is ", m) 184 | 185 | # print("std is", std) 186 | 187 | # m.weight.data.normal_(0, std) 188 | # m.weight.data.uniform_(-std * scale, std * scale) 189 | fan = torch.nn.init._calculate_correct_fan(m.weight, "fan_in") 190 | std = gain / math.sqrt(fan) 191 | # print("std is", std) 192 | with torch.no_grad(): 193 | # m.weight.normal_(0, std) 194 | m.weight.data.uniform_(-std * math.sqrt(3.0) * scale, std * math.sqrt(3.0) * scale) 195 | if m.bias is not None: 196 | m.bias.data.zero_() 197 | 198 | if isinstance(m, torch.nn.ConvTranspose2d): 199 | # hardcoded for stride=2 for now 200 | m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] 201 | m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2] 202 | m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] 203 | 204 | # if isinstance(m, Conv2dWNUB) or isinstance(m, ConvTranspose2dWNUB) or isinstance(m, LinearWN): 205 | # print("m is ", m) 206 | if ( 207 | # isinstance(m, torch.Conv2dWNUB) 208 | # isinstance(m, torch.Conv2dWN) 209 | # or isinstance(m, torch.ConvTranspose2dWN) 210 | # or isinstance(m, torch.ConvTranspose2dWNUB) 211 | isinstance(m, LinearWN_v2) 212 | or isinstance(m, Conv1dWN_v2) 213 | ): 214 | # print("selected m is ", m) 215 | # help(m) 216 | # norm = np.sqrt(torch.sum(m.weight.data[:] ** 2)) 217 | dims = list(range(1, len(m.weight.shape))) 218 | norm = torch.norm(m.weight, 2, dim=dims, 
keepdim=True) 219 | # print("weight is ", m.weight.shape) 220 | # print("norm",norm.shape) 221 | # print("m.g.data",m.g.data.shape) 222 | # norm = torch.norm(m.weight, 2, dim=0, keepdim=True) 223 | m.g.data = norm 224 | 225 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.31.0 2 | torch==2.5.0 3 | torchvision==0.20.0 4 | opencv-python==4.9.0.80 5 | mediapipe==0.10.10 6 | trimesh==4.0.4 7 | jsonmerge==1.9.2 8 | dctorch==0.1.2 9 | einops==0.8.1 10 | torchdiffeq==0.2.5 11 | torchsde==0.2.6 12 | wheel==0.41.2 13 | setuptools==59.5.0 14 | libigl==2.5.1 15 | 16 | 17 | -------------------------------------------------------------------------------- /samples/buzzcut_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/buzzcut_3.jpg -------------------------------------------------------------------------------- /samples/cooper_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/cooper_4.jpg -------------------------------------------------------------------------------- /samples/fiennes_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/fiennes_7.jpg -------------------------------------------------------------------------------- /samples/freeman_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/freeman_2.png -------------------------------------------------------------------------------- /samples/harris_5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/harris_5.jpeg -------------------------------------------------------------------------------- /samples/hathaway_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/hathaway_1.jpg -------------------------------------------------------------------------------- /samples/medium_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/medium_11.png -------------------------------------------------------------------------------- /samples/uggams_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Meshcapade/difflocks/4f4b2ecf93c452808d9fd7e25f5fa276c6f28d58/samples/uggams_3.png -------------------------------------------------------------------------------- /schedulers/linearlr.py: -------------------------------------------------------------------------------- 1 | #from pytorch 1.10 2 | 3 | import types 4 | import math 5 | from torch import inf 6 | from functools import wraps 7 | import warnings 8 | import weakref 9 | from collections import Counter 10 | from bisect import 
bisect_right 11 | 12 | # from .optimizer import Optimizer 13 | 14 | import torch 15 | from torch.optim import Optimizer 16 | 17 | class LinearLR(torch.optim.lr_scheduler._LRScheduler): 18 | """Decays the learning rate of each parameter group by linearly changing small 19 | multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. 20 | Notice that such decay can happen simultaneously with other changes to the learning rate 21 | from outside this scheduler. When last_epoch=-1, sets initial lr as lr. 22 | 23 | Args: 24 | optimizer (Optimizer): Wrapped optimizer. 25 | start_factor (float): The number we multiply learning rate in the first epoch. 26 | The multiplication factor changes towards end_factor in the following epochs. 27 | Default: 1./3. 28 | end_factor (float): The number we multiply learning rate at the end of linear changing 29 | process. Default: 1.0. 30 | total_iters (int): The number of iterations that multiplicative factor reaches to 1. 31 | Default: 5. 32 | last_epoch (int): The index of the last epoch. Default: -1. 33 | verbose (bool): If ``True``, prints a message to stdout for 34 | each update. Default: ``False``. 35 | 36 | Example: 37 | >>> # Assuming optimizer uses lr = 0.05 for all groups 38 | >>> # lr = 0.025 if epoch == 0 39 | >>> # lr = 0.03125 if epoch == 1 40 | >>> # lr = 0.0375 if epoch == 2 41 | >>> # lr = 0.04375 if epoch == 3 42 | >>> # lr = 0.005 if epoch >= 4 43 | >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) 44 | >>> for epoch in range(100): 45 | >>> train(...) 46 | >>> validate(...) 47 | >>> scheduler.step() 48 | """ 49 | 50 | def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, 51 | verbose=False): 52 | if start_factor > 1.0 or start_factor < 0: 53 | raise ValueError('Starting multiplicative factor expected to be between 0 and 1.') 54 | 55 | if end_factor > 1.0 or end_factor < 0: 56 | raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') 57 | 58 | self.start_factor = start_factor 59 | self.end_factor = end_factor 60 | self.total_iters = total_iters 61 | super(LinearLR, self).__init__(optimizer, last_epoch, verbose) 62 | 63 | def get_lr(self): 64 | if not self._get_lr_called_within_step: 65 | warnings.warn("To get the last learning rate computed by the scheduler, " 66 | "please use `get_last_lr()`.", UserWarning) 67 | 68 | if self.last_epoch == 0: 69 | return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] 70 | 71 | if (self.last_epoch > self.total_iters): 72 | return [group['lr'] for group in self.optimizer.param_groups] 73 | 74 | return [group['lr'] * (1. 
+ (self.end_factor - self.start_factor) / 75 | (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) 76 | for group in self.optimizer.param_groups] 77 | 78 | def _get_closed_form_lr(self): 79 | return [base_lr * (self.start_factor + 80 | (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) 81 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /schedulers/multisteplr.py: -------------------------------------------------------------------------------- 1 | #from pytorch 1.10 2 | 3 | import types 4 | import math 5 | from torch import inf 6 | from functools import wraps 7 | import warnings 8 | import weakref 9 | from collections import Counter 10 | from bisect import bisect_right 11 | 12 | # from .optimizer import Optimizer 13 | 14 | import torch 15 | from torch.optim import Optimizer 16 | 17 | 18 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): 19 | """Decays the learning rate of each parameter group by gamma once the 20 | number of epoch reaches one of the milestones. Notice that such decay can 21 | happen simultaneously with other changes to the learning rate from outside 22 | this scheduler. When last_epoch=-1, sets initial lr as lr. 23 | 24 | Args: 25 | optimizer (Optimizer): Wrapped optimizer. 26 | milestones (list): List of epoch indices. Must be increasing. 27 | gamma (float): Multiplicative factor of learning rate decay. 28 | Default: 0.1. 29 | last_epoch (int): The index of last epoch. Default: -1. 30 | verbose (bool): If ``True``, prints a message to stdout for 31 | each update. Default: ``False``. 32 | 33 | Example: 34 | >>> # Assuming optimizer uses lr = 0.05 for all groups 35 | >>> # lr = 0.05 if epoch < 30 36 | >>> # lr = 0.005 if 30 <= epoch < 80 37 | >>> # lr = 0.0005 if epoch >= 80 38 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) 39 | >>> for epoch in range(100): 40 | >>> train(...) 41 | >>> validate(...) 
42 | >>> scheduler.step() 43 | """ 44 | 45 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False): 46 | self.milestones = Counter(milestones) 47 | self.gamma = gamma 48 | super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose) 49 | 50 | def get_lr(self): 51 | if not self._get_lr_called_within_step: 52 | warnings.warn("To get the last learning rate computed by the scheduler, " 53 | "please use `get_last_lr()`.", UserWarning) 54 | 55 | if self.last_epoch not in self.milestones: 56 | return [group['lr'] for group in self.optimizer.param_groups] 57 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] 58 | for group in self.optimizer.param_groups] 59 | 60 | def _get_closed_form_lr(self): 61 | milestones = list(sorted(self.milestones.elements())) 62 | return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) 63 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /schedulers/pytorch_warmup/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseWarmup, LinearWarmup, ExponentialWarmup 2 | from .untuned import UntunedLinearWarmup, UntunedExponentialWarmup 3 | from .radam import RAdamWarmup, rho_fn, rho_inf_fn, get_offset 4 | 5 | __version__ = "0.2.0.dev0" 6 | 7 | __all__ = [ 8 | 'BaseWarmup', 9 | 'LinearWarmup', 10 | 'ExponentialWarmup', 11 | 'UntunedLinearWarmup', 12 | 'UntunedExponentialWarmup', 13 | 'RAdamWarmup', 14 | 'rho_fn', 15 | 'rho_inf_fn', 16 | 'get_offset', 17 | ] -------------------------------------------------------------------------------- /schedulers/pytorch_warmup/base.py: -------------------------------------------------------------------------------- 1 | ##from https://github.com/Tony-Y/pytorch_warmup/blob/master/pytorch_warmup/base.py 2 | 3 | import math 4 | from contextlib import contextmanager 5 | from torch.optim import Optimizer 6 | 7 | 8 | def _check_optimizer(optimizer): 9 | if not isinstance(optimizer, Optimizer): 10 | raise TypeError('{} ({}) is not an Optimizer.'.format( 11 | optimizer, type(optimizer).__name__)) 12 | 13 | 14 | class BaseWarmup(object): 15 | """Base class for all warmup schedules 16 | 17 | Arguments: 18 | optimizer (Optimizer): an instance of a subclass of Optimizer 19 | warmup_params (list): warmup paramters 20 | last_step (int): The index of last step. (Default: -1) 21 | """ 22 | 23 | def __init__(self, optimizer, warmup_params, last_step=-1): 24 | self.optimizer = optimizer 25 | self.warmup_params = warmup_params 26 | self.last_step = last_step 27 | self.lrs = [group['lr'] for group in self.optimizer.param_groups] 28 | self.dampen() 29 | 30 | def state_dict(self): 31 | """Returns the state of the warmup scheduler as a :class:`dict`. 32 | 33 | It contains an entry for every variable in self.__dict__ which 34 | is not the optimizer. 35 | """ 36 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 37 | 38 | def load_state_dict(self, state_dict): 39 | """Loads the warmup scheduler's state. 40 | 41 | Arguments: 42 | state_dict (dict): warmup scheduler state. Should be an object returned 43 | from a call to :meth:`state_dict`. 44 | """ 45 | self.__dict__.update(state_dict) 46 | 47 | def dampen(self, step=None): 48 | """Dampen the learning rates. 49 | 50 | Arguments: 51 | step (int): The index of current step. 
(Default: None) 52 | """ 53 | if step is None: 54 | step = self.last_step + 1 55 | self.last_step = step 56 | 57 | for group, params in zip(self.optimizer.param_groups, self.warmup_params): 58 | omega = self.warmup_factor(step, **params) 59 | group['lr'] *= omega 60 | 61 | @contextmanager 62 | def dampening(self): 63 | for group, lr in zip(self.optimizer.param_groups, self.lrs): 64 | group['lr'] = lr 65 | yield 66 | self.lrs = [group['lr'] for group in self.optimizer.param_groups] 67 | self.dampen() 68 | 69 | def warmup_factor(self, step, **params): 70 | raise NotImplementedError 71 | 72 | 73 | def get_warmup_params(warmup_period, group_count): 74 | if isinstance(warmup_period, list): 75 | if len(warmup_period) != group_count: 76 | raise ValueError( 77 | 'The size of warmup_period ({}) does not match the size of param_groups ({}).'.format( 78 | len(warmup_period), group_count)) 79 | for x in warmup_period: 80 | if not isinstance(x, int): 81 | raise TypeError( 82 | 'An element in warmup_period, {}, is not an int.'.format( 83 | type(x).__name__)) 84 | if x <= 0: 85 | raise ValueError( 86 | 'An element in warmup_period must be a positive integer, but is {}.'.format(x)) 87 | warmup_params = [dict(warmup_period=x) for x in warmup_period] 88 | elif isinstance(warmup_period, int): 89 | if warmup_period <= 0: 90 | raise ValueError( 91 | 'warmup_period must be a positive integer, but is {}.'.format(warmup_period)) 92 | warmup_params = [dict(warmup_period=warmup_period) 93 | for _ in range(group_count)] 94 | else: 95 | raise TypeError('{} ({}) is not a list nor an int.'.format( 96 | warmup_period, type(warmup_period).__name__)) 97 | return warmup_params 98 | 99 | 100 | class LinearWarmup(BaseWarmup): 101 | """Linear warmup schedule. 102 | 103 | Arguments: 104 | optimizer (Optimizer): an instance of a subclass of Optimizer 105 | warmup_period (int or list): Warmup period 106 | last_step (int): The index of last step. (Default: -1) 107 | """ 108 | 109 | def __init__(self, optimizer, warmup_period, last_step=-1): 110 | _check_optimizer(optimizer) 111 | group_count = len(optimizer.param_groups) 112 | warmup_params = get_warmup_params(warmup_period, group_count) 113 | super(LinearWarmup, self).__init__(optimizer, warmup_params, last_step) 114 | 115 | def warmup_factor(self, step, warmup_period): 116 | return min(1.0, (step+1) / warmup_period) 117 | 118 | 119 | class ExponentialWarmup(BaseWarmup): 120 | """Exponential warmup schedule. 121 | 122 | Arguments: 123 | optimizer (Optimizer): an instance of a subclass of Optimizer 124 | warmup_period (int or list): Effective warmup period 125 | last_step (int): The index of last step. 
(Default: -1) 126 | """ 127 | 128 | def __init__(self, optimizer, warmup_period, last_step=-1): 129 | _check_optimizer(optimizer) 130 | group_count = len(optimizer.param_groups) 131 | warmup_params = get_warmup_params(warmup_period, group_count) 132 | super(ExponentialWarmup, self).__init__(optimizer, warmup_params, last_step) 133 | 134 | def warmup_factor(self, step, warmup_period): 135 | return 1.0 - math.exp(-(step+1) / warmup_period) -------------------------------------------------------------------------------- /schedulers/pytorch_warmup/radam.py: -------------------------------------------------------------------------------- 1 | ###from https://github.com/Tony-Y/pytorch_warmup/blob/master/pytorch_warmup/radam.py 2 | 3 | import math 4 | from .base import BaseWarmup, _check_optimizer 5 | 6 | 7 | def rho_inf_fn(beta2): 8 | return 2.0 / (1 - beta2) - 1 9 | 10 | 11 | def rho_fn(t, beta2, rho_inf): 12 | b2t = beta2 ** t 13 | rho_t = rho_inf - 2 * t * b2t / (1 - b2t) 14 | return rho_t 15 | 16 | 17 | def get_offset(beta2, rho_inf): 18 | if not beta2 > 0.6: 19 | raise ValueError('beta2 ({}) must be greater than 0.6'.format(beta2)) 20 | offset = 1 21 | while True: 22 | if rho_fn(offset, beta2, rho_inf) > 4: 23 | return offset 24 | offset += 1 25 | 26 | 27 | class RAdamWarmup(BaseWarmup): 28 | """RAdam warmup schedule. 29 | 30 | This warmup scheme is described in 31 | `On the adequacy of untuned warmup for adaptive optimization 32 | `_. 33 | 34 | Arguments: 35 | optimizer (Optimizer): an Adam optimizer 36 | last_step (int): The index of last step. (Default: -1) 37 | """ 38 | 39 | def __init__(self, optimizer, last_step=-1): 40 | _check_optimizer(optimizer) 41 | warmup_params = [ 42 | dict( 43 | beta2=x['betas'][1], 44 | rho_inf=rho_inf_fn(x['betas'][1]), 45 | ) 46 | for x in optimizer.param_groups 47 | ] 48 | for x in warmup_params: 49 | x['offset'] = get_offset(**x) 50 | super(RAdamWarmup, self).__init__(optimizer, warmup_params, last_step) 51 | 52 | def warmup_factor(self, step, beta2, rho_inf, offset): 53 | rho = rho_fn(step+offset, beta2, rho_inf) 54 | numerator = (rho - 4) * (rho - 2) * rho_inf 55 | denominator = (rho_inf - 4) * (rho_inf - 2) * rho 56 | return math.sqrt(numerator/denominator) -------------------------------------------------------------------------------- /schedulers/pytorch_warmup/untuned.py: -------------------------------------------------------------------------------- 1 | from .base import LinearWarmup, ExponentialWarmup, _check_optimizer 2 | 3 | 4 | class UntunedLinearWarmup(LinearWarmup): 5 | """Untuned linear warmup schedule for Adam. 6 | 7 | This warmup scheme is described in 8 | `On the adequacy of untuned warmup for adaptive optimization 9 | `_. 10 | 11 | Arguments: 12 | optimizer (Optimizer): an Adam optimizer 13 | last_step (int): The index of last step. (Default: -1) 14 | """ 15 | 16 | def __init__(self, optimizer, last_step=-1): 17 | _check_optimizer(optimizer) 18 | 19 | def warmup_period_fn(beta2): 20 | return int(2.0 / (1.0-beta2)) 21 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups] 22 | super(UntunedLinearWarmup, self).__init__(optimizer, warmup_period, last_step) 23 | 24 | 25 | class UntunedExponentialWarmup(ExponentialWarmup): 26 | """Untuned exponetial warmup schedule for Adam. 27 | 28 | This warmup scheme is described in 29 | `On the adequacy of untuned warmup for adaptive optimization 30 | `_. 31 | 32 | Arguments: 33 | optimizer (Optimizer): an Adam optimizer 34 | last_step (int): The index of last step. 
(Default: -1) 35 | """ 36 | 37 | def __init__(self, optimizer, last_step=-1): 38 | _check_optimizer(optimizer) 39 | 40 | def warmup_period_fn(beta2): 41 | return int(1.0 / (1.0-beta2)) 42 | warmup_period = [warmup_period_fn(x['betas'][1]) for x in optimizer.param_groups] 43 | super(UntunedExponentialWarmup, self).__init__(optimizer, warmup_period, last_step) -------------------------------------------------------------------------------- /schedulers/warmup.py: -------------------------------------------------------------------------------- 1 | #from https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py 2 | 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | from torch.optim.lr_scheduler import ReduceLROnPlateau 5 | 6 | 7 | class GradualWarmupScheduler(_LRScheduler): 8 | """ Gradually warm-up(increasing) learning rate in optimizer. 9 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 10 | Args: 11 | optimizer (Optimizer): Wrapped optimizer. 12 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 13 | total_epoch: target learning rate is reached at total_epoch, gradually 14 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 15 | """ 16 | 17 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 18 | self.multiplier = multiplier 19 | if self.multiplier < 1.: 20 | raise ValueError('multiplier should be greater thant or equal to 1.') 21 | self.total_epoch = total_epoch 22 | self.after_scheduler = after_scheduler 23 | self.finished = False 24 | super(GradualWarmupScheduler, self).__init__(optimizer) 25 | 26 | def get_lr(self): 27 | if self.last_epoch > self.total_epoch: 28 | if self.after_scheduler: 29 | if not self.finished: 30 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 31 | self.finished = True 32 | return self.after_scheduler.get_last_lr() 33 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 34 | 35 | if self.multiplier == 1.0: 36 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 37 | else: 38 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 39 | 40 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 41 | if epoch is None: 42 | epoch = self.last_epoch + 1 43 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 44 | if self.last_epoch <= self.total_epoch: 45 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 46 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 47 | param_group['lr'] = lr 48 | else: 49 | if epoch is None: 50 | self.after_scheduler.step(metrics, None) 51 | else: 52 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 53 | 54 | def step(self, epoch=None, metrics=None): 55 | if type(self.after_scheduler) != ReduceLROnPlateau: 56 | if self.finished and self.after_scheduler: 57 | if epoch is None: 58 | self.after_scheduler.step(None) 59 | else: 60 | self.after_scheduler.step(epoch - self.total_epoch) 61 | self._last_lr = self.after_scheduler.get_last_lr() 62 | else: 63 | return super(GradualWarmupScheduler, self).step(epoch) 64 | else: 65 | self.step_ReduceLROnPlateau(metrics, epoch) -------------------------------------------------------------------------------- /utils/create_strand_latent_weights.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | #when training the diffusion model we apply an L2 loss on all the channels of the scalp texture. However some of the channels are essentially noise. 4 | #This is because some dimensions of the strand vae encode very little information and don't significantly modify the strand shape. 5 | #here we measure how much the position, direction and curvature change when perturbing each dimension of the latent, and we use this delta as a weight for the diffusion model to downweight the uninformative channels 6 | 7 | 8 | 9 | #python3 ./create_strand_latent_weights.py --checkpoint_path 10 | 11 | 12 | 13 | 14 | import argparse 15 | import torch 16 | from models.strand_codec import StrandCodec 17 | from utils.strand_util import compute_dirs 18 | from losses.loss import StrandVAELoss 19 | import json 20 | from data_loader.dataloader import DiffLocksDataset 21 | from data_loader.mesh_utils import world_to_tbn_space 22 | 23 | 24 | 25 | 26 | torch.set_grad_enabled(False) 27 | 28 | 29 | #transforms the data to a local space, puts it on the cuda device and reshapes it the way we expect it to be 30 | def prepare_gt_batch(batch): 31 | gt_dict = {} 32 | 33 | tbn=batch['full_strands']["tbn"].cuda() 34 | positions=batch['full_strands']["positions"].cuda() 35 | root_normal=batch['full_strands']["root_normal"].cuda() 36 | 37 | #get it in local space 38 | gt_strand_positions, gt_root_normals = world_to_tbn_space(tbn, 39 | positions, 40 | root_normal) 41 | gt_strand_positions=gt_strand_positions.cuda() 42 | 43 | #reshape it to be nr_strands, nr_points, dim 44 | gt_strand_positions=gt_strand_positions.reshape(-1,256,3) 45 | 46 | gt_dirs=compute_dirs(gt_strand_positions, append_last_dir=False) #nr_strands,256-1,3 47 | 48 | 49 | gt_dict["strand_positions"]=gt_strand_positions 50 | gt_dict["strand_directions"]=gt_dirs 51 | 52 | 53 | return gt_dict 54 | 55 | 56 | class HyperParamsStrandVAE: 57 | def __init__(self): 58 | #####Output 59 | #####decode dir###### 60 | self.decode_type="dir" 61 | self.scale_init=30.0 62 | self.nr_verts_per_strand=256 63 | self.nr_values_to_decode=255 64 | self.dim_per_value_decoded=3 65 | 66 | 67 | ###LOSS###### 68 | #these are the same values that were used to train the strand vae 69 | self.loss_pos_weight=0.5 ##FOR LAMBDA 70 | self.loss_dir_weight=1.0 71 | self.loss_curv_weight=20.0 72 | self.loss_kl_weight=0.0 73 | 74 | 75 | 76 | def main(): 77 | 78 | #argparse 79 | parser = argparse.ArgumentParser(description='Get the weights of each dimension after training a strand VAE') 80 |
parser.add_argument('--checkpoint_path', required=True, help='Path to the strandVAE checkpoint') 81 | args = parser.parse_args() 82 | 83 | path_strand_vae_model=args.checkpoint_path 84 | 85 | 86 | hyperparams=HyperParamsStrandVAE() 87 | 88 | 89 | normalization_dict=DiffLocksDataset.get_normalization_data() 90 | 91 | model = StrandCodec(do_vae=False, 92 | decode_type="dir", 93 | scale_init=30.0, 94 | nr_verts_per_strand=256, nr_values_to_decode=255, 95 | dim_per_value_decoded=3).cuda() 96 | model.load_state_dict(torch.load(path_strand_vae_model)) 97 | model = torch.compile(model) 98 | 99 | 100 | 101 | #latent of dimension 64 and get GT which is the mean strand 102 | latent=torch.zeros(1,64).cuda() 103 | pred_dict = model.decoder(latent, None, normalization_dict) 104 | pred_points=pred_dict["strand_positions"] 105 | gt_strand=pred_points 106 | print("gt_strand",gt_strand.shape) 107 | 108 | 109 | #make loss function 110 | loss_computer= StrandVAELoss() 111 | 112 | 113 | #for each dimension change it by 0.5 and check the error towards the mean strand (GT) 114 | loss_per_dim=[] 115 | for i in range(64): 116 | latent=torch.zeros(1,64).cuda() 117 | latent[:,i]=0.8 118 | pred_dict = model.decoder(latent, None, normalization_dict) 119 | pred_points=pred_dict["strand_positions"] 120 | 121 | #make dicts 122 | gt_dict={"strand_positions": gt_strand} 123 | pred_dict={"strand_positions": pred_points} 124 | latent_dict={} 125 | 126 | #loss 127 | loss_dict = loss_computer(None, gt_dict, pred_dict, latent_dict, hyperparams) 128 | loss=loss_dict["loss"] 129 | loss_per_dim.append(loss) 130 | 131 | # print("loss", loss) 132 | 133 | #normalize losses 134 | loss_normalization=max(loss_per_dim) 135 | weight_per_dim = [(x/loss_normalization).item() for x in loss_per_dim] 136 | 137 | 138 | 139 | #print them in order 140 | for i in range(64): 141 | # print("i", i, " w: ", weight_per_dim[i].item()) 142 | print(weight_per_dim[i]) 143 | 144 | with open("loss_weight_strand_latent.json", "w") as final: 145 | json.dump(weight_per_dim, final) 146 | 147 | 148 | #finished 149 | return 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | 155 | -------------------------------------------------------------------------------- /utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import k_diffusion as K 3 | 4 | @torch.no_grad() 5 | def sample_images(nr_images, model_ema, model_config, nr_iters=100, extra_args={}, callback=None): 6 | model_ema.eval() 7 | sigma_min = model_config['sigma_min'] 8 | sigma_max = model_config['sigma_max'] 9 | size = model_config['input_size'] 10 | n_per_proc = nr_images 11 | x = torch.randn([1, n_per_proc, model_config['input_channels'], size[0], size[1]]).cuda() 12 | x = x[0] * sigma_max 13 | model_fn = model_ema 14 | sigmas = K.sampling.get_sigmas_karras(nr_iters, sigma_min, sigma_max, rho=7., device="cuda") 15 | x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=False, callback=callback) 16 | return x_0 17 | 18 | @torch.no_grad() 19 | #samples using classifier free guidance and only enables the cfg_val when the sigma is within the interval. 
20 | #idead from this paper: https://arxiv.org/pdf/2404.07724 21 | def sample_images_cfg(nr_images, cfg_val, cfg_interval, model_ema, model_config, nr_iters=100, extra_args={}, callback=None): 22 | model_ema.eval() 23 | sigma_min = model_config['sigma_min'] 24 | sigma_max = model_config['sigma_max'] 25 | size = model_config['input_size'] 26 | n_per_proc = nr_images 27 | x = torch.randn([1, n_per_proc, model_config['input_channels'], size[0], size[1]]).cuda() 28 | x = x[0] * sigma_max 29 | model_fn = model_ema 30 | sigmas = K.sampling.get_sigmas_karras(nr_iters, sigma_min, sigma_max, rho=7., device="cuda") 31 | x_0 = K.sampling.sample_dpmpp_2m_sde_cfg(model_fn, x, sigmas, cfg_val, cfg_interval, extra_args=extra_args, eta=0.0, solver_type='heun', disable=False, callback=callback) 32 | return x_0 -------------------------------------------------------------------------------- /utils/resize_right/interp_methods.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | torch = None 7 | 8 | try: 9 | import numpy 10 | except ImportError: 11 | numpy = None 12 | 13 | if numpy is None and torch is None: 14 | raise ImportError("Must have either Numpy or PyTorch but both not found") 15 | 16 | 17 | def set_framework_dependencies(x): 18 | if type(x) is numpy.ndarray: 19 | to_dtype = lambda a: a 20 | fw = numpy 21 | else: 22 | to_dtype = lambda a: a.to(x.dtype) 23 | fw = torch 24 | eps = fw.finfo(fw.float32).eps 25 | return fw, to_dtype, eps 26 | 27 | 28 | def support_sz(sz): 29 | def wrapper(f): 30 | f.support_sz = sz 31 | return f 32 | return wrapper 33 | 34 | 35 | @support_sz(4) 36 | def cubic(x): 37 | fw, to_dtype, eps = set_framework_dependencies(x) 38 | absx = fw.abs(x) 39 | absx2 = absx ** 2 40 | absx3 = absx ** 3 41 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + 42 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * 43 | to_dtype((1. 
< absx) & (absx <= 2.))) 44 | 45 | 46 | @support_sz(4) 47 | def lanczos2(x): 48 | fw, to_dtype, eps = set_framework_dependencies(x) 49 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / 50 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) 51 | 52 | 53 | @support_sz(6) 54 | def lanczos3(x): 55 | fw, to_dtype, eps = set_framework_dependencies(x) 56 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / 57 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) 58 | 59 | 60 | @support_sz(2) 61 | def linear(x): 62 | fw, to_dtype, eps = set_framework_dependencies(x) 63 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * 64 | to_dtype((0 <= x) & (x <= 1))) 65 | 66 | 67 | @support_sz(1) 68 | def box(x): 69 | fw, to_dtype, eps = set_framework_dependencies(x) 70 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) 71 | -------------------------------------------------------------------------------- /utils/vis_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class PCA(Function): 6 | @staticmethod 7 | def forward(ctx, sv): #sv corresponds to the slice values, it has dimensions nr_positions x val_full_dim 8 | 9 | # http://agnesmustar.com/2017/11/01/principal-component-analysis-pca-implemented-pytorch/ 10 | 11 | 12 | X=sv.detach().cpu() #we switch to cpu because of memory issues when doing svd on really big images 13 | k=3 14 | # print("x is ", X.shape) 15 | X_mean = torch.mean(X,0) 16 | # print("x_mean is ", X_mean.shape) 17 | X = X - X_mean.expand_as(X) 18 | 19 | # U,S,V = torch.svd(torch.t(X)) 20 | U,S,V = torch.pca_lowrank( torch.t(X) ) 21 | C = torch.mm(X,U[:,:k]) 22 | # print("C has shape ", C.shape) 23 | # print("C min and max is ", C.min(), " ", C.max() ) 24 | C-=C.min() 25 | C/=C.max() 26 | # print("after normalization C min and max is ", C.min(), " ", C.max() ) 27 | 28 | return C 29 | 30 | 31 | #img is supposed to be N,C,H,W 32 | def img_2_pca(img): 33 | assert img.shape[0]==1 34 | img_c=img.shape[1] 35 | img_h=img.shape[2] 36 | img_w=img.shape[3] 37 | vals = img.permute(0,2,3,1) #N,H,W,C 38 | c=PCA.apply(vals.view(-1,img_c)) 39 | c=c.view(1, img_h, img_w, 3).permute(0,3,1,2) #N,H,W,C to N,C,H,W 40 | return c 41 | 42 | --------------------------------------------------------------------------------
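The img_2_pca helper above is a quick way to inspect a multi-channel feature map (e.g. a scalp-texture latent) as an RGB image: it flattens the N,C,H,W tensor to pixels x channels, projects onto the first three principal components, and rescales the result to [0,1]. A minimal usage sketch (the 64-channel 128x128 input is an arbitrary stand-in, and the import assumes the repository root is on PYTHONPATH):

import torch
from utils.vis_util import img_2_pca

features = torch.rand(1, 64, 128, 128)   # N,C,H,W; the batch size must be 1
rgb = img_2_pca(features)                # 1,3,H,W with values rescaled to [0,1]
print(rgb.shape)                         # torch.Size([1, 3, 128, 128])

The output can then be saved or logged like any other image tensor, e.g. with torchvision.utils.save_image(rgb, "pca_vis.png").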