├── README.md ├── configs ├── sd15_feature_extractor.yaml ├── sd21_depth_prober.yaml └── sd21_feature_extractor.yaml ├── docs ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ ├── index.css │ └── twentytwenty.css │ ├── images │ ├── architecture.png │ ├── classification_quantitative.png │ ├── correspondence_quantitative.png │ ├── correspondence_table.png │ ├── correspondences.png │ ├── depth │ │ ├── 1 │ │ │ ├── depth_input.png │ │ │ ├── depth_ours.png │ │ │ └── depth_sd.png │ │ ├── 2 │ │ │ ├── depth_input.png │ │ │ ├── depth_ours.png │ │ │ └── depth_sd.png │ │ └── 3 │ │ │ ├── depth_input.png │ │ │ ├── depth_ours.png │ │ │ └── depth_sd.png │ ├── depth_table.png │ ├── sem_seg │ │ ├── image_0_0.png │ │ ├── image_14_2.png │ │ ├── image_41_3.png │ │ ├── image_51_2.png │ │ ├── image_55_3.png │ │ ├── image_59_0.png │ │ ├── image_66_2.png │ │ ├── image_67_2.png │ │ ├── image_67_3.png │ │ ├── image_6_0.png │ │ ├── image_7_2.png │ │ ├── image_8_2.png │ │ ├── image_9_1.png │ │ ├── mask_0_0_ours.png │ │ ├── mask_0_0_sd.png │ │ ├── mask_14_2_ours.png │ │ ├── mask_14_2_sd.png │ │ ├── mask_41_3_ours.png │ │ ├── mask_41_3_sd.png │ │ ├── mask_51_2_ours.png │ │ ├── mask_51_2_sd.png │ │ ├── mask_55_3_ours.png │ │ ├── mask_55_3_sd.png │ │ ├── mask_59_0_ours.png │ │ ├── mask_59_0_sd.png │ │ ├── mask_66_2_ours.png │ │ ├── mask_66_2_sd.png │ │ ├── mask_67_2_ours.png │ │ ├── mask_67_2_sd.png │ │ ├── mask_67_3_ours.png │ │ ├── mask_67_3_sd.png │ │ ├── mask_6_0_ours.png │ │ ├── mask_6_0_sd.png │ │ ├── mask_7_2_ours.png │ │ ├── mask_7_2_sd.png │ │ ├── mask_8_2_ours.png │ │ ├── mask_8_2_sd.png │ │ ├── mask_9_1_ours.png │ │ └── mask_9_1_sd.png │ ├── sem_seg_quantitative.png │ └── teaser_fig.jpg │ ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ ├── jquery-3.2.1.min.js │ ├── jquery.event.move.js │ └── jquery.twentytwenty.js │ └── pdfs │ └── cleandift.pdf ├── notebooks ├── depth.ipynb ├── example_images │ ├── 2007_009889.jpg │ ├── 2008_004188.jpg │ └── depth │ │ ├── gt.png │ │ └── input.png └── get_correspondences.ipynb ├── requirements.txt ├── src ├── __init__.py ├── ae.py ├── dataloader.py ├── depth.py ├── layers.py ├── min_sd15.py ├── min_sd21.py ├── sd_feature_extraction.py └── utils.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 |

🧹CleanDIFT: Diffusion Features without Noise

Nick Stracke* · Stefan A. Baumann* · Kolja Bauer* · Frank Fundel · Björn Ommer

CompVis @ LMU Munich
CVPR 2025 (Oral)

[![Project Page](https://img.shields.io/badge/Project-Page-blue)](https://compvis.github.io/cleandift/)
[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/pdf/2412.03439)
[![Weights](https://img.shields.io/badge/HuggingFace-Weights-orange)](https://huggingface.co/CompVis/cleandift)

This repository contains the official implementation of the paper "CleanDIFT: Diffusion Features without Noise".

We propose CleanDIFT, a novel method to extract noise-free, timestep-independent features by enabling diffusion models to work directly with clean input images. Our approach is efficient, training on a single GPU in just 30 minutes.

![teaser](./docs/static/images/teaser_fig.jpg)

## 🚀 Usage

### Setup

Just clone the repo and install the requirements via `pip install -r requirements.txt`, and you're ready to go.

### Training

To train a feature extractor on your own, run `python train.py`. The training script expects your data to be stored in `./data` in the following format: a single-level directory with images named `filename.jpg` and corresponding JSON files `filename.json` that contain the key `caption`.

### Feature Extraction

For feature extraction, please refer to one of the notebooks at [`notebooks`](https://github.com/CompVis/cleandift/tree/main/notebooks). We demonstrate how to extract features and use them for semantic correspondence detection and depth prediction.

Our checkpoints are fully compatible with the `diffusers` library. If you already have a pipeline using SD 1.5 or SD 2.1 from `diffusers`, you can simply replace the U-Net state dict:

```python
from diffusers import UNet2DConditionModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
ckpt_pth = hf_hub_download(repo_id="CompVis/cleandift", filename="cleandift_sd21_unet.safetensors")
state_dict = load_file(ckpt_pth)
unet.load_state_dict(state_dict, strict=True)
```

#### Combination with Telling Left From Right (TLFR)

Our checkpoints can also be combined with more advanced feature extraction methods than [DIFT](https://diffusionfeatures.github.io/), such as [TLFR](https://telling-left-from-right.github.io/). We provide an adapted version of their codebase that can be used to reproduce our state-of-the-art zero-shot semantic correspondence results in the [`geoaware-sc-eval`](https://github.com/CompVis/cleandift/tree/geoaware-sc-eval) branch.
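For orientation, here is a minimal, self-contained sketch of how the swapped-in U-Net can be queried for an intermediate feature map with a plain `diffusers` setup and a forward hook. This is an illustration rather than the notebooks' exact pipeline: the choice of `up_blocks[1]`, the t=0 timestep, the empty prompt, and the example image are assumptions.

```python
# Minimal sketch (assumption-laden): encode a clean image with the SD 2.1 VAE,
# run the CleanDIFT U-Net once, and grab an intermediate feature map via a hook.
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"
repo = "stabilityai/stable-diffusion-2-1"

vae = AutoencoderKL.from_pretrained(repo, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet").to(device)
ckpt = hf_hub_download(repo_id="CompVis/cleandift", filename="cleandift_sd21_unet.safetensors")
unet.load_state_dict(load_file(ckpt), strict=True)
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder").to(device)

# Capture one intermediate U-Net feature map with a forward hook (up_blocks[1] is illustrative).
feats = {}
handle = unet.up_blocks[1].register_forward_hook(lambda mod, inp, out: feats.update(us=out))

img = Image.open("notebooks/example_images/2007_009889.jpg").convert("RGB").resize((768, 768))
x = torch.from_numpy(np.array(img)).float().permute(2, 0, 1)[None].to(device) / 127.5 - 1.0

with torch.no_grad():
    latents = vae.encode(x).latent_dist.mean * vae.config.scaling_factor
    tokens = tokenizer([""], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
    text_emb = text_encoder(tokens.input_ids.to(device))[0]
    # Clean image at t = 0: CleanDIFT features need no noising step.
    unet(latents, torch.zeros(1, dtype=torch.long, device=device), encoder_hidden_states=text_emb)

handle.remove()
print(feats["us"].shape)  # spatial feature map, e.g. (1, C, h, w)
```

Which feature map works best is task-dependent; the notebooks and the `feature_dims` entries in the configs list the maps the repo extracts.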
51 | 52 | ## 🎓 Citation 53 | 54 | If you use this codebase or otherwise found our work valuable, please cite our paper: 55 | 56 | ```bibtex 57 | @inproceedings{stracke2025cleandift, 58 | title={CleanDIFT: Diffusion Features without Noise}, 59 | author={Nick Stracke and Stefan Andreas Baumann and Kolja Bauer and Frank Fundel and Björn Ommer}, 60 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 61 | year={2025} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /configs/sd15_feature_extractor.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | max_val_steps: 10 3 | val_freq: 100 4 | checkpoint_freq: 100 5 | checkpoint_dir: ./checkpoints 6 | lr: 1e-5 7 | max_steps: null 8 | 9 | grad_accum_steps: 1 10 | 11 | data: 12 | _target_: src.dataloader.DataModule 13 | dataset_dir: ./data 14 | batch_size: 8 15 | img_size: 512 16 | 17 | model: 18 | _target_: src.sd_feature_extraction.StableFeatureAligner 19 | sd_version: sd15 20 | t_max: 999 # Max timestep used during training 21 | num_t_stratification_bins: 3 22 | train_unet: True 23 | learn_timestep: True 24 | use_text_condition: true 25 | 26 | ae: 27 | _target_: src.ae.AutoencoderKL 28 | repo: stable-diffusion-v1-5/stable-diffusion-v1-5 29 | mapping: 30 | _target_: src.utils.MappingSpec 31 | depth: 2 32 | width: 256 33 | d_ff: 768 34 | dropout: 0.0 35 | adapter_layer_class: src.sd_feature_extraction.FFNStack 36 | adapter_layer_params: 37 | depth: 3 38 | ffn_expansion: 1 39 | dim_cond: ${..mapping.width} 40 | feature_extractor_cls: src.sd_feature_extraction.SD15UNetFeatureExtractor 41 | feature_dims: 42 | mid: 1280 43 | us1: 1280 44 | us2: 1280 45 | us3: 1280 46 | us4: 1280 47 | us5: 1280 48 | us6: 1280 49 | us7: 640 50 | us8: 640 51 | us9: 640 52 | us10: 320 53 | 54 | 55 | lr_scheduler: 56 | name: constant_with_warmup 57 | num_warmup_steps: 2000 58 | num_training_steps: null 59 | scheduler_specific_kwargs: {} 60 | 61 | hydra: 62 | job: 63 | chdir: false -------------------------------------------------------------------------------- /configs/sd21_depth_prober.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | _target_: src.depth.DepthPred 5 | loss: 6 | _target_: src.depth.SigLoss 7 | model_config_path: ./configs/sd21_feature_extractor.yaml 8 | diffusion_image_size: 768 9 | channels: 1280 10 | base_model_timestep: 199 11 | use_base_model_features: false 12 | adapter_timestep: null 13 | interpolate_features: NONE 14 | 15 | hydra: 16 | job: 17 | chdir: false 18 | 19 | 20 | -------------------------------------------------------------------------------- /configs/sd21_feature_extractor.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | max_val_steps: 100 3 | val_freq: 100 4 | checkpoint_freq: 200 5 | checkpoint_dir: ./checkpoints 6 | lr: 1e-5 7 | max_steps: null 8 | 9 | grad_accum_steps: 1 10 | 11 | data: 12 | _target_: src.dataloader.DataModule 13 | dataset_dir: ./data 14 | batch_size: 8 15 | img_size: 768 16 | 17 | model: 18 | _target_: src.sd_feature_extraction.StableFeatureAligner 19 | sd_version: sd21 20 | t_max: 999 # Max timestep used during training 21 | num_t_stratification_bins: 3 22 | train_unet: True 23 | learn_timestep: True 24 | use_text_condition: true 25 | 26 | ae: 27 | _target_: src.ae.AutoencoderKL 28 | repo: stabilityai/stable-diffusion-2-1 29 | mapping: 30 | 
_target_: src.utils.MappingSpec 31 | depth: 2 32 | width: 256 33 | d_ff: 768 34 | dropout: 0.0 35 | adapter_layer_class: src.sd_feature_extraction.FFNStack 36 | adapter_layer_params: 37 | depth: 3 38 | ffn_expansion: 1 39 | dim_cond: ${..mapping.width} 40 | feature_extractor_cls: src.sd_feature_extraction.SD21UNetFeatureExtractor 41 | feature_dims: 42 | mid: 1280 43 | us1: 1280 44 | us2: 1280 45 | us3: 1280 46 | us4: 1280 47 | us5: 1280 48 | us6: 1280 49 | us7: 640 50 | us8: 640 51 | us9: 640 52 | us10: 320 53 | 54 | 55 | lr_scheduler: 56 | name: constant_with_warmup 57 | num_warmup_steps: 2000 58 | num_training_steps: null 59 | scheduler_specific_kwargs: {} 60 | 61 | hydra: 62 | job: 63 | chdir: false 64 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 32 | CleanDIFT: Diffusion Features without Noise 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 94 | 95 | 96 | 97 |
98 |
99 |
100 |
101 |
102 |

🧹 CleanDIFT: Diffusion Features without Noise

103 |
Nick Stracke*, Stefan Andreas Baumann*, Kolja Bauer*, Frank Fundel, Björn Ommer
119 | 120 |
121 | CompVis @ LMU Munich   *equal contribution 122 | 123 |
124 |
125 | CVPR 2025 (Oral) 126 |
127 | 128 |
129 | 179 |
180 |
181 |
182 |
183 |
184 |
185 | 186 | 187 |
188 |
189 |
190 |
191 |
192 | 193 | 194 |
195 | 196 |
197 | 198 |
199 |
200 |
201 |
202 |

TL;DR: Diffusion models learn powerful world representations that have proven valuable for tasks like semantic correspondence detection, depth estimation, semantic segmentation, and classification. However, diffusion models require noisy input images, which destroys information and introduces the noise level as a hyperparameter that needs to be tuned for each task. We propose a novel method to extract noise-free, timestep-independent features by enabling diffusion models to work directly with clean input images. Our approach is efficient, training on a single GPU in just 30 minutes.

212 | 213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 | 221 | 222 | 223 |
224 |
225 |
226 |
227 |

Overview

228 |
229 |

Internal features from large-scale pre-trained diffusion models have recently been established as powerful semantic descriptors for a wide range of downstream tasks. Works that use these features generally need to add noise to images before passing them through the model to obtain the semantic features, as the models do not offer the most useful features when given images with little to no noise. We show that this noise has a critical impact on the usefulness of these features that cannot be remedied by ensembling with different random noises. We address this issue by introducing a lightweight, unsupervised fine-tuning method that enables diffusion backbones to provide high-quality, noise-free semantic features. We show that these features readily outperform previous diffusion features by a wide margin in a wide variety of extraction setups and downstream tasks, offering better performance than even ensemble-based methods at a fraction of the cost.

241 |
242 |
243 |
244 |
245 |
246 | 247 | 248 | 249 | 250 | 251 | 252 |
253 |
254 |
255 |
256 |

Clean Features → Clean Predictions

257 |
258 | 259 |

We evaluate our features on a wide range of downstream tasks: unsupervised zero-shot semantic correspondence, monocular depth estimation, semantic segmentation, and classification. We compare our features against standard diffusion features, methods that combine diffusion features with additional features, and non-diffusion-based approaches.

265 | 266 |
267 |
268 |
269 | 270 |
271 |
272 |

Input Image

273 |
274 |
275 |

Depth Estimation

276 |
277 |
278 |

Input Image

279 |
280 |
281 |

Semantic Segmentation

282 |
283 |
284 | 285 | 286 | 287 |
288 |
289 | Input Image 290 |
291 |
292 |
293 |
294 | 295 |
296 |
297 | 298 |
299 |
300 |
301 |
302 | Input Image 303 |
304 |
305 |
306 |
307 | 308 |
309 |
310 | 311 |
312 |
313 |
314 |
315 | 316 | 317 |
318 |
319 | Input Image 320 |
321 |
322 |
323 |
324 | 325 |
326 |
327 | 328 |
329 |
330 |
331 |
332 | Input Image 333 |
334 |
335 |
336 |
337 | 338 |
339 |
340 | 341 |
342 |
343 |
344 |
345 | 346 | 347 |
348 |
349 | Input Image 350 |
351 |
352 |
353 |
354 | 355 |
356 |
357 | 358 |
359 |
360 |
361 |
362 | Input Image 363 |
364 |
365 |
366 |
367 | 368 |
369 |
370 | 371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 | 379 | 399 | 400 | 428 | 429 | 432 | 433 | 434 |

We compare depth estimation and semantic segmentation using linear probes on standard diffusion features and our CleanDIFT features. Note how the CleanDIFT features are far less noisy than the standard diffusion features. Depth probes are trained on the NYUv2 dataset, segmentation probes on PASCAL VOC. Standard diffusion features use t=100 for semantic segmentation and t=300 for depth prediction.

441 | 442 | 444 |

Zero-shot semantic correspondence matching using DIFT features with standard SD 2.1 (t=261) and our CleanDIFT features. Our clean features show significantly fewer incorrect matches than the standard diffusion features.
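The matching itself is plain nearest-neighbour search in feature space. The sketch below is an illustrative stand-in (the function name, shapes, and grid-to-pixel mapping are assumptions); notebooks/get_correspondences.ipynb contains the actual procedure.

```python
# Illustrative DIFT-style matching: nearest-neighbour search in feature space.
import torch
import torch.nn.functional as F

def match_point(feat_src: torch.Tensor, feat_tgt: torch.Tensor,
                xy_src: tuple[int, int], image_size: int) -> tuple[int, int]:
    """feat_*: (C, h, w) feature maps; xy_src: (x, y) pixel in the source image.
    Returns the target-image pixel with the highest cosine similarity."""
    c, h, w = feat_src.shape
    gx = min(int(xy_src[0] * w / image_size), w - 1)   # source pixel -> feature-grid cell
    gy = min(int(xy_src[1] * h / image_size), h - 1)
    query = F.normalize(feat_src[:, gy, gx], dim=0)     # (C,)
    keys = F.normalize(feat_tgt.reshape(c, -1), dim=0)  # (C, h*w)
    idx = (query @ keys).argmax().item()
    ty, tx = divmod(idx, w)
    # Feature-grid cell -> pixel in the target image.
    return int((tx + 0.5) * image_size / w), int((ty + 0.5) * image_size / h)
```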

450 |
451 |
452 |
453 |
454 |
455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 |
463 |
464 |
465 |
466 |

How it works

467 |
468 | 469 |

We train our feature extraction model to match the diffusion model's internal representations. We initialize the feature extraction model as a trainable copy of the diffusion model. Crucially, the feature extraction model is given the clean input image, while the diffusion model receives the noisy image and the corresponding timestep as input. Our goal is to obtain a single, noise-free feature map from the feature extraction model that consolidates the information of the diffusion model's timestep-dependent internal representations. To align our model's representations with the timestep-dependent diffusion model features during training, we introduce point-wise, timestep-conditioned feature projection heads. The feature maps predicted by these projection heads are then aligned to the diffusion model's features. For feature extraction at inference time, we usually discard the projection heads and directly use the feature extraction model's internal representations. However, the projection heads can also be used to efficiently obtain feature maps for specific timesteps by reusing the feature extraction model's internal representations and passing them through the projection heads for different values of t.
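Schematically, one training step of this alignment can be pictured as below. This is a hypothetical sketch, not the repo's src/sd_feature_extraction.py: the head design and the MSE objective stand in for the paper's timestep-conditioned projection heads and alignment loss, and extract_feats / noise_schedule are assumed helper objects.

```python
# Schematic alignment step: clean-image features are projected per timestep and
# matched against the frozen diffusion model's noisy, timestep-dependent features.
import torch
import torch.nn.functional as F
from torch import nn

class TimestepConditionedHead(nn.Module):
    """Point-wise (1x1 conv) projection, conditioned on the diffusion timestep."""
    def __init__(self, dim: int, t_dim: int = 256):
        super().__init__()
        self.t_embed = nn.Sequential(nn.Linear(1, t_dim), nn.SiLU(), nn.Linear(t_dim, dim))
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, feat: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        shift = self.t_embed(t.float().view(-1, 1))[:, :, None, None]
        return self.proj(feat + shift)

def alignment_step(feature_model, diffusion_model, heads, x0, noise_schedule, extract_feats):
    """feature_model: trainable copy of the backbone, fed the clean image x0.
    diffusion_model: frozen backbone, fed the noised image and timestep.
    heads: dict of TimestepConditionedHead, one per aligned feature map.
    extract_feats(model, x, t) -> dict[str, Tensor] is a user-supplied getter."""
    t = torch.randint(0, 1000, (x0.shape[0],), device=x0.device)
    x_t = noise_schedule.add_noise(x0, torch.randn_like(x0), t)

    with torch.no_grad():
        target = extract_feats(diffusion_model, x_t, t)             # noisy, t-dependent
    clean = extract_feats(feature_model, x0, torch.zeros_like(t))   # clean, t-free

    # Project the clean features to the sampled timestep and match the frozen targets.
    return sum(F.mse_loss(heads[k](clean[k], t), target[k]) for k in target)
```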

485 |
486 | 498 |
499 |
500 |
501 |
502 | 503 | 504 |
505 |
506 |
507 |
508 |

Quantitative Comparison

509 |
510 | 511 | 512 |

Zero-Shot Semantic Correspondence

513 | 515 |

Zero-shot unsupervised semantic correspondence matching performance comparison on SPair-71k. Our features consistently lead to substantial improvements in matching performance. Numbers show our reproductions.

520 | 521 | 523 |

We evaluate semantic correspondence matching accuracy for different noise levels. Our feature extractor outperforms the standard noisy diffusion features across all timesteps t. We additionally demonstrate that simply providing the diffusion model with a clean image and a non-zero timestep does not result in improved performance.

529 |

Monocular Depth Estimation

530 | 531 | 532 |

We evaluate metric depth prediction on NYUv2 using a linear probe. Our clean features outperform the noisy features by a significant margin. Probes trained on the noisy features can be reused for the clean features, but yield a smaller performance gain.
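As a rough picture of the probing setup, the sketch below shows a linear (1x1 convolution) depth probe on frozen features. It is a stand-in for the repo's src/depth.py prober: the 1280-channel input matches the config's feature dims, but the softplus head and bilinear upsampling are assumptions, and training would minimize a scale-invariant loss (cf. SigLoss) against ground-truth depth.

```python
# Illustrative linear depth probe on a frozen feature map (assumptions noted above).
import torch
import torch.nn.functional as F
from torch import nn

class LinearDepthProbe(nn.Module):
    """1x1 convolution mapping a frozen feature map to per-pixel depth."""
    def __init__(self, in_channels: int = 1280):
        super().__init__()
        self.head = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, feats: torch.Tensor, out_hw: tuple[int, int]) -> torch.Tensor:
        depth = F.softplus(self.head(feats))  # keep predicted depth positive
        return F.interpolate(depth, size=out_hw, mode="bilinear", align_corners=False)
```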

537 |

Semantic Segmentation

538 | 540 | 541 |

Performance on semantic segmentation for the PASCAL VOC dataset using linear probes. Our clean features outperform the noisy diffusion features for the best noising timestep t. Semantic segmentation performance of a standard diffusion model depends heavily on the noising timestep used. Unlike for semantic correspondence matching, the optimal t value appears to be around t=100.

548 |

Classification

549 | 550 | 551 |

Classification performance on ImageNet1k, using a kNN classifier with k=10 and cosine similarity as the distance metric. We sweep over different timesteps and feature maps. We find that the feature map with the lowest spatial resolution (feature map #0) yields the highest classification accuracy.
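The kNN protocol itself is straightforward; a hypothetical sketch with scikit-learn on globally pooled feature maps (the pooling and the (N, C, h, w) array shapes are assumptions) looks like this:

```python
# Hypothetical kNN evaluation sketch on globally pooled feature maps.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=10):
    tr = train_feats.mean(axis=(2, 3))  # global average pool -> (N, C)
    te = test_feats.mean(axis=(2, 3))
    clf = KNeighborsClassifier(n_neighbors=k, metric="cosine")
    clf.fit(tr, train_labels)
    return float((clf.predict(te) == np.asarray(test_labels)).mean())
```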

557 |
558 |
559 |
560 |
561 |
562 | 563 | 564 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | 572 |
573 |
574 |

BibTeX

575 |
@inproceedings{stracke2025cleandift,
576 |     title={CleanDIFT: Diffusion Features without Noise}, 
577 |     author={Nick Stracke and Stefan Andreas Baumann and Kolja Bauer and Frank Fundel and Björn Ommer},
578 |     booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
579 |     year={2025}
580 | }
581 |
582 |
583 | 584 | 585 | 586 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 638 | 639 | 640 | 641 | -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel 
.slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .methodname { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | /* width: fit-content; */ 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | width: 100%; 116 | } 117 | 118 | .results-carousel .twoitem { 119 | background-color: #000; 120 | display: flex; 121 | flex-direction: column; 122 | align-items: center; 123 | 
justify-content: center; 124 | height: 95%; 125 | margin: 5px; 126 | overflow: hidden; 127 | border: 1px solid #bbb; 128 | border-radius: 10px; 129 | padding: 0; 130 | font-size: 0; 131 | } 132 | 133 | .slider-navigation-previous { 134 | top: 47.5%; 135 | margin-top: -15px; 136 | left: calc(0% + 15px); 137 | box-shadow: 0 0 12px rgba(0, 0, 0, 1); 138 | } 139 | 140 | .slider-navigation-next { 141 | top: 47.5%; 142 | margin-top: -15px; 143 | left: calc(100% - 15px - 42px); 144 | box-shadow: 0 0 12px rgba(0, 0, 0, 1); 145 | } 146 | 147 | .slider-pagination { 148 | top: calc(100% - 22px); 149 | } 150 | .slider-pagination .slider-page { 151 | background-color: #FFF; 152 | box-shadow: 0 0 4px rgba(0, 0, 0, 1); 153 | } 154 | 155 | .selectable{ 156 | -webkit-touch-callout: all; /* iOS Safari */ 157 | -webkit-user-select: all; /* Safari */ 158 | -khtml-user-select: all; /* Konqueror HTML */ 159 | -moz-user-select: all; /* Firefox */ 160 | -ms-user-select: all; /* Internet Explorer/Edge */ 161 | user-select: all; /* Chrome and Opera */ 162 | } 163 | 164 | 165 | .column-titles h3, .column-title { 166 | font-size: 1.2rem; /* Adjust font size */ 167 | font-weight: bold; /* Make it prominent */ 168 | text-align: center; /* Center-align titles */ 169 | margin-bottom: 0px; /* Add spacing between the title and content */ 170 | } 171 | 172 | .column-title.small { 173 | font-size: 1rem; /* Adjust font size */ 174 | } -------------------------------------------------------------------------------- /docs/static/css/twentytwenty.css: -------------------------------------------------------------------------------- 1 | .twentytwenty-horizontal .twentytwenty-handle:before, 2 | .twentytwenty-horizontal .twentytwenty-handle:after, 3 | .twentytwenty-vertical .twentytwenty-handle:before, 4 | .twentytwenty-vertical .twentytwenty-handle:after { 5 | content: " "; 6 | display: block; 7 | background: white; 8 | position: absolute; 9 | z-index: 30; 10 | -webkit-box-shadow: 0px 0px 12px rgba(51, 51, 51, 0.5); 11 | -moz-box-shadow: 0px 0px 12px rgba(51, 51, 51, 0.5); 12 | box-shadow: 0px 0px 12px rgba(51, 51, 51, 0.5); 13 | } 14 | 15 | .twentytwenty-horizontal .twentytwenty-handle:before, 16 | .twentytwenty-horizontal .twentytwenty-handle:after { 17 | width: 3px; 18 | height: 9999px; 19 | left: 50%; 20 | margin-left: -1.5px; 21 | } 22 | 23 | .twentytwenty-vertical .twentytwenty-handle:before, 24 | .twentytwenty-vertical .twentytwenty-handle:after { 25 | width: 9999px; 26 | height: 3px; 27 | top: 50%; 28 | margin-top: -1.5px; 29 | } 30 | 31 | .twentytwenty-before-label, 32 | .twentytwenty-after-label, 33 | .twentytwenty-overlay { 34 | position: absolute; 35 | top: 0; 36 | width: 100%; 37 | height: 100%; 38 | } 39 | 40 | .twentytwenty-before-label, 41 | .twentytwenty-after-label, 42 | .twentytwenty-overlay { 43 | -webkit-transition-duration: 0.5s; 44 | -moz-transition-duration: 0.5s; 45 | transition-duration: 0.5s; 46 | } 47 | 48 | .twentytwenty-before-label, 49 | .twentytwenty-after-label { 50 | -webkit-transition-property: opacity; 51 | -moz-transition-property: opacity; 52 | transition-property: opacity; 53 | } 54 | 55 | .twentytwenty-before-label:before, 56 | .twentytwenty-after-label:before { 57 | color: white; 58 | font-size: 13px; 59 | letter-spacing: 0.1em; 60 | } 61 | 62 | .twentytwenty-before-label:before, 63 | .twentytwenty-after-label:before { 64 | position: absolute; 65 | background: rgba(255, 255, 255, 0.2); 66 | line-height: 38px; 67 | padding: 0 20px; 68 | -webkit-border-radius: 2px; 69 | -moz-border-radius: 
2px; 70 | border-radius: 2px; 71 | } 72 | 73 | .twentytwenty-horizontal .twentytwenty-before-label:before, 74 | .twentytwenty-horizontal .twentytwenty-after-label:before { 75 | top: 50%; 76 | margin-top: -19px; 77 | } 78 | 79 | .twentytwenty-vertical .twentytwenty-before-label:before, 80 | .twentytwenty-vertical .twentytwenty-after-label:before { 81 | left: 50%; 82 | margin-left: -45px; 83 | text-align: center; 84 | width: 90px; 85 | } 86 | 87 | .twentytwenty-left-arrow, 88 | .twentytwenty-right-arrow, 89 | .twentytwenty-up-arrow, 90 | .twentytwenty-down-arrow { 91 | width: 0; 92 | height: 0; 93 | border: 6px inset transparent; 94 | position: absolute; 95 | } 96 | 97 | .twentytwenty-left-arrow, 98 | .twentytwenty-right-arrow { 99 | top: 50%; 100 | margin-top: -6px; 101 | } 102 | 103 | .twentytwenty-up-arrow, 104 | .twentytwenty-down-arrow { 105 | left: 50%; 106 | margin-left: -6px; 107 | } 108 | 109 | .twentytwenty-container { 110 | -webkit-box-sizing: content-box; 111 | -moz-box-sizing: content-box; 112 | box-sizing: content-box; 113 | z-index: 0; 114 | overflow: hidden; 115 | position: relative; 116 | -webkit-user-select: none; 117 | -moz-user-select: none; 118 | -ms-user-select: none; 119 | width: auto; 120 | } 121 | 122 | .twentytwenty-container .cmpcontent { 123 | display: block; 124 | max-width: 100%; 125 | width: auto; 126 | } 127 | 128 | .twentytwenty-container .cmpcontent.twentytwenty-before { 129 | position: absolute; 130 | top: 0; 131 | bottom: 0; 132 | left: 0; 133 | right: 0; 134 | } 135 | 136 | .twentytwenty-container cmpcomponent { 137 | width: 100%; 138 | max-width: 100%; 139 | } 140 | 141 | .twentytwenty-container.active .twentytwenty-overlay, 142 | .twentytwenty-container.active :hover.twentytwenty-overlay { 143 | background: rgba(0, 0, 0, 0); 144 | } 145 | 146 | .twentytwenty-container.active .twentytwenty-overlay .twentytwenty-before-label, 147 | .twentytwenty-container.active .twentytwenty-overlay .twentytwenty-after-label, 148 | .twentytwenty-container.active 149 | :hover.twentytwenty-overlay 150 | .twentytwenty-before-label, 151 | .twentytwenty-container.active 152 | :hover.twentytwenty-overlay 153 | .twentytwenty-after-label { 154 | opacity: 0; 155 | } 156 | 157 | .twentytwenty-container * { 158 | -webkit-box-sizing: content-box; 159 | -moz-box-sizing: content-box; 160 | box-sizing: content-box; 161 | } 162 | 163 | .twentytwenty-before-label { 164 | opacity: 0; 165 | } 166 | 167 | .twentytwenty-before-label:before { 168 | content: attr(data-content); 169 | } 170 | 171 | .twentytwenty-after-label { 172 | opacity: 0; 173 | } 174 | 175 | .twentytwenty-after-label:before { 176 | content: attr(data-content); 177 | } 178 | 179 | .twentytwenty-horizontal .twentytwenty-before-label:before { 180 | left: 10px; 181 | } 182 | 183 | .twentytwenty-horizontal .twentytwenty-after-label:before { 184 | right: 10px; 185 | } 186 | 187 | .twentytwenty-vertical .twentytwenty-before-label:before { 188 | top: 10px; 189 | } 190 | 191 | .twentytwenty-vertical .twentytwenty-after-label:before { 192 | bottom: 10px; 193 | } 194 | 195 | .twentytwenty-overlay { 196 | -webkit-transition-property: background; 197 | -moz-transition-property: background; 198 | transition-property: background; 199 | background: rgba(0, 0, 0, 0); 200 | z-index: 25; 201 | } 202 | 203 | .twentytwenty-overlay:hover { 204 | background: rgba(0, 0, 0, 0.5); 205 | } 206 | 207 | .twentytwenty-overlay:hover .twentytwenty-after-label { 208 | opacity: 1; 209 | } 210 | 211 | .twentytwenty-overlay:hover .twentytwenty-before-label { 
212 | opacity: 1; 213 | } 214 | 215 | .twentytwenty-before { 216 | z-index: 20; 217 | } 218 | 219 | .twentytwenty-after { 220 | z-index: 0; 221 | } 222 | 223 | .twentytwenty-handle { 224 | height: 38px; 225 | width: 38px; 226 | position: absolute; 227 | left: 50%; 228 | top: 50%; 229 | margin-left: -22px; 230 | margin-top: -22px; 231 | border: 3px solid white; 232 | -webkit-border-radius: 1000px; 233 | -moz-border-radius: 1000px; 234 | border-radius: 1000px; 235 | -webkit-box-shadow: 0px 0px 12px rgba(51, 51, 51, 0.5); 236 | -moz-box-shadow: 0px 0px 12px rgba(51, 51, 51, 0.5); 237 | box-shadow: 0px 0px 12px rgba(51, 51, 51, 0.5); 238 | z-index: 40; 239 | cursor: pointer; 240 | } 241 | 242 | .twentytwenty-horizontal .twentytwenty-handle:before { 243 | bottom: 50%; 244 | margin-bottom: 22px; 245 | -webkit-box-shadow: 0 3px 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 246 | -moz-box-shadow: 0 3px 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 247 | box-shadow: 0 3px 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 248 | } 249 | 250 | .twentytwenty-horizontal .twentytwenty-handle:after { 251 | top: 50%; 252 | margin-top: 22px; 253 | -webkit-box-shadow: 0 -3px 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 254 | -moz-box-shadow: 0 -3px 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 255 | box-shadow: 0 -3px 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 256 | } 257 | 258 | .twentytwenty-vertical .twentytwenty-handle:before { 259 | left: 50%; 260 | margin-left: 22px; 261 | -webkit-box-shadow: 3px 0 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 262 | -moz-box-shadow: 3px 0 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 263 | box-shadow: 3px 0 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 264 | } 265 | 266 | .twentytwenty-vertical .twentytwenty-handle:after { 267 | right: 50%; 268 | margin-right: 22px; 269 | -webkit-box-shadow: -3px 0 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 270 | -moz-box-shadow: -3px 0 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 271 | box-shadow: -3px 0 0 white, 0px 0px 12px rgba(51, 51, 51, 0.5); 272 | } 273 | 274 | .twentytwenty-left-arrow { 275 | border-right: 6px solid white; 276 | left: 50%; 277 | margin-left: -17px; 278 | } 279 | 280 | .twentytwenty-right-arrow { 281 | border-left: 6px solid white; 282 | right: 50%; 283 | margin-right: -17px; 284 | } 285 | 286 | .twentytwenty-up-arrow { 287 | border-bottom: 6px solid white; 288 | top: 50%; 289 | margin-top: -17px; 290 | } 291 | 292 | .twentytwenty-down-arrow { 293 | border-top: 6px solid white; 294 | bottom: 50%; 295 | margin-bottom: -17px; 296 | } 297 | 298 | .input-image-container { 299 | width: 20%; 300 | } 301 | 302 | .image-slider-row { 303 | display: flex; 304 | justify-content: space-between; 305 | margin-bottom: 20px; /* Space between rows */ 306 | align-items: center; 307 | } 308 | 309 | .our-container { 310 | display: flex; /* Enables Flexbox */ 311 | gap: 20px; /* Adds space between columns */ 312 | margin-bottom: 20px; 313 | } 314 | 315 | .our-container.no-margin { 316 | margin-bottom: 0; 317 | } 318 | 319 | @media (max-width: 640px) { 320 | .our-container { 321 | flex-direction: column; /* Stacks columns vertically */ 322 | gap: 0px; /* Removes space between columns */ 323 | margin-bottom: 20px; 324 | } 325 | } 326 | 327 | .our-column { 328 | flex: 2 1 20%; /* Makes columns equal width */ 329 | } 330 | 331 | .our-column.depth { 332 | flex: 1.335; /* Makes columns equal width */ 333 | } 334 | 335 | .our-column.seg { 336 | flex: 1; /* Makes columns equal width */ 337 | } 338 | 339 | 
-------------------------------------------------------------------------------- /docs/static/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/architecture.png -------------------------------------------------------------------------------- /docs/static/images/classification_quantitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/classification_quantitative.png -------------------------------------------------------------------------------- /docs/static/images/correspondence_quantitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/correspondence_quantitative.png -------------------------------------------------------------------------------- /docs/static/images/correspondence_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/correspondence_table.png -------------------------------------------------------------------------------- /docs/static/images/correspondences.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/correspondences.png -------------------------------------------------------------------------------- /docs/static/images/depth/1/depth_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/1/depth_input.png -------------------------------------------------------------------------------- /docs/static/images/depth/1/depth_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/1/depth_ours.png -------------------------------------------------------------------------------- /docs/static/images/depth/1/depth_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/1/depth_sd.png -------------------------------------------------------------------------------- /docs/static/images/depth/2/depth_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/2/depth_input.png -------------------------------------------------------------------------------- /docs/static/images/depth/2/depth_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/2/depth_ours.png -------------------------------------------------------------------------------- /docs/static/images/depth/2/depth_sd.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/2/depth_sd.png -------------------------------------------------------------------------------- /docs/static/images/depth/3/depth_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/3/depth_input.png -------------------------------------------------------------------------------- /docs/static/images/depth/3/depth_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/3/depth_ours.png -------------------------------------------------------------------------------- /docs/static/images/depth/3/depth_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth/3/depth_sd.png -------------------------------------------------------------------------------- /docs/static/images/depth_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/depth_table.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_0_0.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_14_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_14_2.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_41_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_41_3.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_51_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_51_2.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_55_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_55_3.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_59_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_59_0.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_66_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_66_2.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_67_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_67_2.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_67_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_67_3.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_6_0.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_7_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_7_2.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_8_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_8_2.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/image_9_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/image_9_1.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_0_0_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_0_0_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_0_0_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_0_0_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_14_2_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_14_2_ours.png 
-------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_14_2_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_14_2_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_41_3_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_41_3_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_41_3_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_41_3_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_51_2_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_51_2_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_51_2_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_51_2_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_55_3_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_55_3_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_55_3_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_55_3_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_59_0_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_59_0_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_59_0_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_59_0_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_66_2_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_66_2_ours.png -------------------------------------------------------------------------------- 
/docs/static/images/sem_seg/mask_66_2_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_66_2_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_67_2_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_67_2_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_67_2_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_67_2_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_67_3_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_67_3_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_67_3_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_67_3_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_6_0_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_6_0_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_6_0_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_6_0_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_7_2_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_7_2_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_7_2_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_7_2_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_8_2_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_8_2_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_8_2_sd.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_8_2_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_9_1_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_9_1_ours.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg/mask_9_1_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg/mask_9_1_sd.png -------------------------------------------------------------------------------- /docs/static/images/sem_seg_quantitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/sem_seg_quantitative.png -------------------------------------------------------------------------------- /docs/static/images/teaser_fig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/images/teaser_fig.jpg -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define([], factory); 6 | else if(typeof exports === 'object') 7 | exports["bulmaSlider"] = factory(); 8 | else 9 | root["bulmaSlider"] = factory(); 10 | })(typeof self !== 'undefined' ? 
self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { 50 | /******/ configurable: false, 51 | /******/ enumerable: true, 52 | /******/ get: getter 53 | /******/ }); 54 | /******/ } 55 | /******/ }; 56 | /******/ 57 | /******/ // getDefaultExport function for compatibility with non-harmony modules 58 | /******/ __webpack_require__.n = function(module) { 59 | /******/ var getter = module && module.__esModule ? 
60 | /******/ function getDefault() { return module['default']; } : 61 | /******/ function getModuleExports() { return module; }; 62 | /******/ __webpack_require__.d(getter, 'a', getter); 63 | /******/ return getter; 64 | /******/ }; 65 | /******/ 66 | /******/ // Object.prototype.hasOwnProperty.call 67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 68 | /******/ 69 | /******/ // __webpack_public_path__ 70 | /******/ __webpack_require__.p = ""; 71 | /******/ 72 | /******/ // Load entry module and return exports 73 | /******/ return __webpack_require__(__webpack_require__.s = 0); 74 | /******/ }) 75 | /************************************************************************/ 76 | /******/ ([ 77 | /* 0 */ 78 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 79 | 80 | "use strict"; 81 | Object.defineProperty(__webpack_exports__, "__esModule", { value: true }); 82 | /* harmony export (binding) */ __webpack_require__.d(__webpack_exports__, "isString", function() { return isString; }); 83 | /* harmony import */ var __WEBPACK_IMPORTED_MODULE_0__events__ = __webpack_require__(1); 84 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; 85 | 86 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 87 | 88 | var _typeof = typeof Symbol === "function" && typeof Symbol.iterator === "symbol" ? function (obj) { return typeof obj; } : function (obj) { return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj; }; 89 | 90 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 91 | 92 | function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; } 93 | 94 | function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; } 95 | 96 | 97 | 98 | var isString = function isString(unknown) { 99 | return typeof unknown === 'string' || !!unknown && (typeof unknown === 'undefined' ? 
'undefined' : _typeof(unknown)) === 'object' && Object.prototype.toString.call(unknown) === '[object String]'; 100 | }; 101 | 102 | var bulmaSlider = function (_EventEmitter) { 103 | _inherits(bulmaSlider, _EventEmitter); 104 | 105 | function bulmaSlider(selector) { 106 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 107 | 108 | _classCallCheck(this, bulmaSlider); 109 | 110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this)); 111 | 112 | _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector; 113 | // An invalid selector or non-DOM node has been provided. 114 | if (!_this.element) { 115 | throw new Error('An invalid selector or non-DOM node has been provided.'); 116 | } 117 | 118 | _this._clickEvents = ['click']; 119 | /// Set default options and merge with instance defined 120 | _this.options = _extends({}, options); 121 | 122 | _this.onSliderInput = _this.onSliderInput.bind(_this); 123 | 124 | _this.init(); 125 | return _this; 126 | } 127 | 128 | /** 129 | * Initiate all DOM element containing selector 130 | * @method 131 | * @return {Array} Array of all slider instances 132 | */ 133 | 134 | 135 | _createClass(bulmaSlider, [{ 136 | key: 'init', 137 | 138 | 139 | /** 140 | * Initiate plugin 141 | * @method init 142 | * @return {void} 143 | */ 144 | value: function init() { 145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999)); 146 | this.output = this._findOutputForSlider(); 147 | 148 | this._bindEvents(); 149 | 150 | if (this.output) { 151 | if (this.element.classList.contains('has-output-tooltip')) { 152 | // Get new output position 153 | var newPosition = this._getSliderOutputPosition(); 154 | 155 | // Set output position 156 | this.output.style['left'] = newPosition.position; 157 | } 158 | } 159 | 160 | this.emit('bulmaslider:ready', this.element.value); 161 | } 162 | }, { 163 | key: '_findOutputForSlider', 164 | value: function _findOutputForSlider() { 165 | var _this2 = this; 166 | 167 | var result = null; 168 | var outputs = document.getElementsByTagName('output') || []; 169 | 170 | Array.from(outputs).forEach(function (output) { 171 | if (output.htmlFor == _this2.element.getAttribute('id')) { 172 | result = output; 173 | return true; 174 | } 175 | }); 176 | return result; 177 | } 178 | }, { 179 | key: '_getSliderOutputPosition', 180 | value: function _getSliderOutputPosition() { 181 | // Update output position 182 | var newPlace, minValue; 183 | 184 | var style = window.getComputedStyle(this.element, null); 185 | // Measure width of range input 186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10); 187 | 188 | // Figure out placement percentage between left and right of input 189 | if (!this.element.getAttribute('min')) { 190 | minValue = 0; 191 | } else { 192 | minValue = this.element.getAttribute('min'); 193 | } 194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue); 195 | 196 | // Prevent bubble from going beyond left or right (unsupported browsers) 197 | if (newPoint < 0) { 198 | newPlace = 0; 199 | } else if (newPoint > 1) { 200 | newPlace = sliderWidth; 201 | } else { 202 | newPlace = sliderWidth * newPoint; 203 | } 204 | 205 | return { 206 | 'position': newPlace + 'px' 207 | }; 208 | } 209 | 210 | /** 211 | * Bind all events 212 | * @method _bindEvents 213 | * @return {void} 214 | */ 215 | 216 | }, { 217 | key: 
'_bindEvents', 218 | value: function _bindEvents() { 219 | if (this.output) { 220 | // Add event listener to update output when slider value change 221 | this.element.addEventListener('input', this.onSliderInput, false); 222 | } 223 | } 224 | }, { 225 | key: 'onSliderInput', 226 | value: function onSliderInput(e) { 227 | e.preventDefault(); 228 | 229 | if (this.element.classList.contains('has-output-tooltip')) { 230 | // Get new output position 231 | var newPosition = this._getSliderOutputPosition(); 232 | 233 | // Set output position 234 | this.output.style['left'] = newPosition.position; 235 | } 236 | 237 | // Check for prefix and postfix 238 | var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : ''; 239 | var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : ''; 240 | 241 | // Update output with slider value 242 | this.output.value = prefix + this.element.value + postfix; 243 | 244 | this.emit('bulmaslider:ready', this.element.value); 245 | } 246 | }], [{ 247 | key: 'attach', 248 | value: function attach() { 249 | var _this3 = this; 250 | 251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider'; 252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 253 | 254 | var instances = new Array(); 255 | 256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector]; 257 | elements.forEach(function (element) { 258 | if (typeof element[_this3.constructor.name] === 'undefined') { 259 | var instance = new bulmaSlider(element, options); 260 | element[_this3.constructor.name] = instance; 261 | instances.push(instance); 262 | } else { 263 | instances.push(element[_this3.constructor.name]); 264 | } 265 | }); 266 | 267 | return instances; 268 | } 269 | }]); 270 | 271 | return bulmaSlider; 272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]); 273 | 274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider); 275 | 276 | /***/ }), 277 | /* 1 */ 278 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 279 | 280 | "use strict"; 281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 282 | 283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 284 | 285 | var EventEmitter = function () { 286 | function EventEmitter() { 287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : []; 288 | 289 | _classCallCheck(this, EventEmitter); 290 | 291 | this._listeners = new Map(listeners); 292 | this._middlewares = new Map(); 293 | } 294 | 295 | _createClass(EventEmitter, [{ 296 | key: "listenerCount", 297 | value: function listenerCount(eventName) { 298 | if (!this._listeners.has(eventName)) { 299 | return 0; 300 | } 301 | 302 | var eventListeners = this._listeners.get(eventName); 303 | return eventListeners.length; 304 | } 305 | }, { 306 | key: "removeListeners", 307 | value: function removeListeners() { 308 | var _this = this; 309 | 310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false; 312 | 313 | if (eventName !== null) { 314 | if (Array.isArray(eventName)) { 315 | name.forEach(function (e) { 316 | return _this.removeListeners(e, middleware); 317 | }); 318 | } else { 319 | this._listeners.delete(eventName); 320 | 321 | if (middleware) { 322 | this.removeMiddleware(eventName); 323 | } 324 | } 325 | } else { 326 | this._listeners = new Map(); 327 | } 328 | } 329 | }, { 330 | key: "middleware", 331 | value: function middleware(eventName, fn) { 332 | var _this2 = this; 333 | 334 | if (Array.isArray(eventName)) { 335 | name.forEach(function (e) { 336 | return _this2.middleware(e, fn); 337 | }); 338 | } else { 339 | if (!Array.isArray(this._middlewares.get(eventName))) { 340 | this._middlewares.set(eventName, []); 341 | } 342 | 343 | this._middlewares.get(eventName).push(fn); 344 | } 345 | } 346 | }, { 347 | key: "removeMiddleware", 348 | value: function removeMiddleware() { 349 | var _this3 = this; 350 | 351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 352 | 353 | if (eventName !== null) { 354 | if (Array.isArray(eventName)) { 355 | name.forEach(function (e) { 356 | return _this3.removeMiddleware(e); 357 | }); 358 | } else { 359 | this._middlewares.delete(eventName); 360 | } 361 | } else { 362 | this._middlewares = new Map(); 363 | } 364 | } 365 | }, { 366 | key: "on", 367 | value: function on(name, callback) { 368 | var _this4 = this; 369 | 370 | var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 371 | 372 | if (Array.isArray(name)) { 373 | name.forEach(function (e) { 374 | return _this4.on(e, callback); 375 | }); 376 | } else { 377 | name = name.toString(); 378 | var split = name.split(/,|, | /); 379 | 380 | if (split.length > 1) { 381 | split.forEach(function (e) { 382 | return _this4.on(e, callback); 383 | }); 384 | } else { 385 | if (!Array.isArray(this._listeners.get(name))) { 386 | this._listeners.set(name, []); 387 | } 388 | 389 | this._listeners.get(name).push({ once: once, callback: callback }); 390 | } 391 | } 392 | } 393 | }, { 394 | key: "once", 395 | value: function once(name, callback) { 396 | this.on(name, callback, true); 397 | } 398 | }, { 399 | key: "emit", 400 | value: function emit(name, data) { 401 | var _this5 = this; 402 | 403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? 
arguments[2] : false; 404 | 405 | name = name.toString(); 406 | var listeners = this._listeners.get(name); 407 | var middlewares = null; 408 | var doneCount = 0; 409 | var execute = silent; 410 | 411 | if (Array.isArray(listeners)) { 412 | listeners.forEach(function (listener, index) { 413 | // Start Middleware checks unless we're doing a silent emit 414 | if (!silent) { 415 | middlewares = _this5._middlewares.get(name); 416 | // Check and execute Middleware 417 | if (Array.isArray(middlewares)) { 418 | middlewares.forEach(function (middleware) { 419 | middleware(data, function () { 420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 421 | 422 | if (newData !== null) { 423 | data = newData; 424 | } 425 | doneCount++; 426 | }, name); 427 | }); 428 | 429 | if (doneCount >= middlewares.length) { 430 | execute = true; 431 | } 432 | } else { 433 | execute = true; 434 | } 435 | } 436 | 437 | // If Middleware checks have been passed, execute 438 | if (execute) { 439 | if (listener.once) { 440 | listeners[index] = null; 441 | } 442 | listener.callback(data); 443 | } 444 | }); 445 | 446 | // Dirty way of removing used Events 447 | while (listeners.indexOf(null) !== -1) { 448 | listeners.splice(listeners.indexOf(null), 1); 449 | } 450 | } 451 | } 452 | }]); 453 | 454 | return EventEmitter; 455 | }(); 456 | 457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter); 458 | 459 | /***/ }) 460 | /******/ ])["default"]; 461 | }); -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/jquery.event.move.js: -------------------------------------------------------------------------------- 1 | // DOM.event.move 2 | // 3 | // 2.0.0 4 | // 5 | // Stephen Band 6 | // 7 | // Triggers 'movestart', 'move' and 'moveend' events after 8 | // mousemoves following a mousedown cross a distance threshold, 9 | // similar to the native 'dragstart', 'drag' and 'dragend' events. 10 | // Move events are throttled to animation frames. Move event objects 11 | // have the properties: 12 | // 13 | // pageX: 14 | // pageY: Page coordinates of pointer. 15 | // startX: 16 | // startY: Page coordinates of pointer at movestart. 17 | // distX: 18 | // distY: Distance the pointer has moved since movestart. 
19 | // deltaX: 20 | // deltaY: Distance the finger has moved since last event. 21 | // velocityX: 22 | // velocityY: Average velocity over last few events. 23 | 24 | 25 | (function(fn) { 26 | if (typeof define === 'function' && define.amd) { 27 | define([], fn); 28 | } else if ((typeof module !== "undefined" && module !== null) && module.exports) { 29 | module.exports = fn; 30 | } else { 31 | fn(); 32 | } 33 | })(function(){ 34 | var assign = Object.assign || window.jQuery && jQuery.extend; 35 | 36 | // Number of pixels a pressed pointer travels before movestart 37 | // event is fired. 38 | var threshold = 8; 39 | 40 | // Shim for requestAnimationFrame, falling back to timer. See: 41 | // see http://paulirish.com/2011/requestanimationframe-for-smart-animating/ 42 | var requestFrame = (function(){ 43 | return ( 44 | window.requestAnimationFrame || 45 | window.webkitRequestAnimationFrame || 46 | window.mozRequestAnimationFrame || 47 | window.oRequestAnimationFrame || 48 | window.msRequestAnimationFrame || 49 | function(fn, element){ 50 | return window.setTimeout(function(){ 51 | fn(); 52 | }, 25); 53 | } 54 | ); 55 | })(); 56 | 57 | // Shim for customEvent 58 | // see https://developer.mozilla.org/en-US/docs/Web/API/CustomEvent/CustomEvent#Polyfill 59 | (function () { 60 | if ( typeof window.CustomEvent === "function" ) return false; 61 | function CustomEvent ( event, params ) { 62 | params = params || { bubbles: false, cancelable: false, detail: undefined }; 63 | var evt = document.createEvent( 'CustomEvent' ); 64 | evt.initCustomEvent( event, params.bubbles, params.cancelable, params.detail ); 65 | return evt; 66 | } 67 | 68 | CustomEvent.prototype = window.Event.prototype; 69 | window.CustomEvent = CustomEvent; 70 | })(); 71 | 72 | var ignoreTags = { 73 | textarea: true, 74 | input: true, 75 | select: true, 76 | button: true 77 | }; 78 | 79 | var mouseevents = { 80 | move: 'mousemove', 81 | cancel: 'mouseup dragstart', 82 | end: 'mouseup' 83 | }; 84 | 85 | var touchevents = { 86 | move: 'touchmove', 87 | cancel: 'touchend', 88 | end: 'touchend' 89 | }; 90 | 91 | var rspaces = /\s+/; 92 | 93 | 94 | // DOM Events 95 | 96 | var eventOptions = { bubbles: true, cancelable: true }; 97 | 98 | var eventsSymbol = typeof Symbol === "function" ? 
Symbol('events') : {}; 99 | 100 | function createEvent(type) { 101 | return new CustomEvent(type, eventOptions); 102 | } 103 | 104 | function getEvents(node) { 105 | return node[eventsSymbol] || (node[eventsSymbol] = {}); 106 | } 107 | 108 | function on(node, types, fn, data, selector) { 109 | types = types.split(rspaces); 110 | 111 | var events = getEvents(node); 112 | var i = types.length; 113 | var handlers, type; 114 | 115 | function handler(e) { fn(e, data); } 116 | 117 | while (i--) { 118 | type = types[i]; 119 | handlers = events[type] || (events[type] = []); 120 | handlers.push([fn, handler]); 121 | node.addEventListener(type, handler); 122 | } 123 | } 124 | 125 | function off(node, types, fn, selector) { 126 | types = types.split(rspaces); 127 | 128 | var events = getEvents(node); 129 | var i = types.length; 130 | var type, handlers, k; 131 | 132 | if (!events) { return; } 133 | 134 | while (i--) { 135 | type = types[i]; 136 | handlers = events[type]; 137 | if (!handlers) { continue; } 138 | k = handlers.length; 139 | while (k--) { 140 | if (handlers[k][0] === fn) { 141 | node.removeEventListener(type, handlers[k][1]); 142 | handlers.splice(k, 1); 143 | } 144 | } 145 | } 146 | } 147 | 148 | function trigger(node, type, properties) { 149 | // Don't cache events. It prevents you from triggering an event of a 150 | // given type from inside the handler of another event of that type. 151 | var event = createEvent(type); 152 | if (properties) { assign(event, properties); } 153 | node.dispatchEvent(event); 154 | } 155 | 156 | 157 | // Constructors 158 | 159 | function Timer(fn){ 160 | var callback = fn, 161 | active = false, 162 | running = false; 163 | 164 | function trigger(time) { 165 | if (active){ 166 | callback(); 167 | requestFrame(trigger); 168 | running = true; 169 | active = false; 170 | } 171 | else { 172 | running = false; 173 | } 174 | } 175 | 176 | this.kick = function(fn) { 177 | active = true; 178 | if (!running) { trigger(); } 179 | }; 180 | 181 | this.end = function(fn) { 182 | var cb = callback; 183 | 184 | if (!fn) { return; } 185 | 186 | // If the timer is not running, simply call the end callback. 187 | if (!running) { 188 | fn(); 189 | } 190 | // If the timer is running, and has been kicked lately, then 191 | // queue up the current callback and the end callback, otherwise 192 | // just the end callback. 193 | else { 194 | callback = active ? 195 | function(){ cb(); fn(); } : 196 | fn ; 197 | 198 | active = true; 199 | } 200 | }; 201 | } 202 | 203 | 204 | // Functions 205 | 206 | function noop() {} 207 | 208 | function preventDefault(e) { 209 | e.preventDefault(); 210 | } 211 | 212 | function isIgnoreTag(e) { 213 | return !!ignoreTags[e.target.tagName.toLowerCase()]; 214 | } 215 | 216 | function isPrimaryButton(e) { 217 | // Ignore mousedowns on any button other than the left (or primary) 218 | // mouse button, or when a modifier key is pressed. 219 | return (e.which === 1 && !e.ctrlKey && !e.altKey); 220 | } 221 | 222 | function identifiedTouch(touchList, id) { 223 | var i, l; 224 | 225 | if (touchList.identifiedTouch) { 226 | return touchList.identifiedTouch(id); 227 | } 228 | 229 | // touchList.identifiedTouch() does not exist in 230 | // webkit yet… we must do the search ourselves... 
231 | 232 | i = -1; 233 | l = touchList.length; 234 | 235 | while (++i < l) { 236 | if (touchList[i].identifier === id) { 237 | return touchList[i]; 238 | } 239 | } 240 | } 241 | 242 | function changedTouch(e, data) { 243 | var touch = identifiedTouch(e.changedTouches, data.identifier); 244 | 245 | // This isn't the touch you're looking for. 246 | if (!touch) { return; } 247 | 248 | // Chrome Android (at least) includes touches that have not 249 | // changed in e.changedTouches. That's a bit annoying. Check 250 | // that this touch has changed. 251 | if (touch.pageX === data.pageX && touch.pageY === data.pageY) { return; } 252 | 253 | return touch; 254 | } 255 | 256 | 257 | // Handlers that decide when the first movestart is triggered 258 | 259 | function mousedown(e){ 260 | // Ignore non-primary buttons 261 | if (!isPrimaryButton(e)) { return; } 262 | 263 | // Ignore form and interactive elements 264 | if (isIgnoreTag(e)) { return; } 265 | 266 | on(document, mouseevents.move, mousemove, e); 267 | on(document, mouseevents.cancel, mouseend, e); 268 | } 269 | 270 | function mousemove(e, data){ 271 | checkThreshold(e, data, e, removeMouse); 272 | } 273 | 274 | function mouseend(e, data) { 275 | removeMouse(); 276 | } 277 | 278 | function removeMouse() { 279 | off(document, mouseevents.move, mousemove); 280 | off(document, mouseevents.cancel, mouseend); 281 | } 282 | 283 | function touchstart(e) { 284 | // Don't get in the way of interaction with form elements 285 | if (ignoreTags[e.target.tagName.toLowerCase()]) { return; } 286 | 287 | var touch = e.changedTouches[0]; 288 | 289 | // iOS live updates the touch objects whereas Android gives us copies. 290 | // That means we can't trust the touchstart object to stay the same, 291 | // so we must copy the data. This object acts as a template for 292 | // movestart, move and moveend event objects. 293 | var data = { 294 | target: touch.target, 295 | pageX: touch.pageX, 296 | pageY: touch.pageY, 297 | identifier: touch.identifier, 298 | 299 | // The only way to make handlers individually unbindable is by 300 | // making them unique. 301 | touchmove: function(e, data) { touchmove(e, data); }, 302 | touchend: function(e, data) { touchend(e, data); } 303 | }; 304 | 305 | on(document, touchevents.move, data.touchmove, data); 306 | on(document, touchevents.cancel, data.touchend, data); 307 | } 308 | 309 | function touchmove(e, data) { 310 | var touch = changedTouch(e, data); 311 | if (!touch) { return; } 312 | checkThreshold(e, data, touch, removeTouch); 313 | } 314 | 315 | function touchend(e, data) { 316 | var touch = identifiedTouch(e.changedTouches, data.identifier); 317 | if (!touch) { return; } 318 | removeTouch(data); 319 | } 320 | 321 | function removeTouch(data) { 322 | off(document, touchevents.move, data.touchmove); 323 | off(document, touchevents.cancel, data.touchend); 324 | } 325 | 326 | function checkThreshold(e, data, touch, fn) { 327 | var distX = touch.pageX - data.pageX; 328 | var distY = touch.pageY - data.pageY; 329 | 330 | // Do nothing if the threshold has not been crossed. 331 | if ((distX * distX) + (distY * distY) < (threshold * threshold)) { return; } 332 | 333 | triggerStart(e, data, touch, distX, distY, fn); 334 | } 335 | 336 | function triggerStart(e, data, touch, distX, distY, fn) { 337 | var touches = e.targetTouches; 338 | var time = e.timeStamp - data.timeStamp; 339 | 340 | // Create a movestart object with some special properties that 341 | // are passed only to the movestart handlers. 
342 | var template = { 343 | altKey: e.altKey, 344 | ctrlKey: e.ctrlKey, 345 | shiftKey: e.shiftKey, 346 | startX: data.pageX, 347 | startY: data.pageY, 348 | distX: distX, 349 | distY: distY, 350 | deltaX: distX, 351 | deltaY: distY, 352 | pageX: touch.pageX, 353 | pageY: touch.pageY, 354 | velocityX: distX / time, 355 | velocityY: distY / time, 356 | identifier: data.identifier, 357 | targetTouches: touches, 358 | finger: touches ? touches.length : 1, 359 | enableMove: function() { 360 | this.moveEnabled = true; 361 | this.enableMove = noop; 362 | e.preventDefault(); 363 | } 364 | }; 365 | 366 | // Trigger the movestart event. 367 | trigger(data.target, 'movestart', template); 368 | 369 | // Unbind handlers that tracked the touch or mouse up till now. 370 | fn(data); 371 | } 372 | 373 | 374 | // Handlers that control what happens following a movestart 375 | 376 | function activeMousemove(e, data) { 377 | var timer = data.timer; 378 | 379 | data.touch = e; 380 | data.timeStamp = e.timeStamp; 381 | timer.kick(); 382 | } 383 | 384 | function activeMouseend(e, data) { 385 | var target = data.target; 386 | var event = data.event; 387 | var timer = data.timer; 388 | 389 | removeActiveMouse(); 390 | 391 | endEvent(target, event, timer, function() { 392 | // Unbind the click suppressor, waiting until after mouseup 393 | // has been handled. 394 | setTimeout(function(){ 395 | off(target, 'click', preventDefault); 396 | }, 0); 397 | }); 398 | } 399 | 400 | function removeActiveMouse() { 401 | off(document, mouseevents.move, activeMousemove); 402 | off(document, mouseevents.end, activeMouseend); 403 | } 404 | 405 | function activeTouchmove(e, data) { 406 | var event = data.event; 407 | var timer = data.timer; 408 | var touch = changedTouch(e, event); 409 | 410 | if (!touch) { return; } 411 | 412 | // Stop the interface from gesturing 413 | e.preventDefault(); 414 | 415 | event.targetTouches = e.targetTouches; 416 | data.touch = touch; 417 | data.timeStamp = e.timeStamp; 418 | 419 | timer.kick(); 420 | } 421 | 422 | function activeTouchend(e, data) { 423 | var target = data.target; 424 | var event = data.event; 425 | var timer = data.timer; 426 | var touch = identifiedTouch(e.changedTouches, event.identifier); 427 | 428 | // This isn't the touch you're looking for. 429 | if (!touch) { return; } 430 | 431 | removeActiveTouch(data); 432 | endEvent(target, event, timer); 433 | } 434 | 435 | function removeActiveTouch(data) { 436 | off(document, touchevents.move, data.activeTouchmove); 437 | off(document, touchevents.end, data.activeTouchend); 438 | } 439 | 440 | 441 | // Logic for triggering move and moveend events 442 | 443 | function updateEvent(event, touch, timeStamp) { 444 | var time = timeStamp - event.timeStamp; 445 | 446 | event.distX = touch.pageX - event.startX; 447 | event.distY = touch.pageY - event.startY; 448 | event.deltaX = touch.pageX - event.pageX; 449 | event.deltaY = touch.pageY - event.pageY; 450 | 451 | // Average the velocity of the last few events using a decay 452 | // curve to even out spurious jumps in values. 
453 | event.velocityX = 0.3 * event.velocityX + 0.7 * event.deltaX / time; 454 | event.velocityY = 0.3 * event.velocityY + 0.7 * event.deltaY / time; 455 | event.pageX = touch.pageX; 456 | event.pageY = touch.pageY; 457 | } 458 | 459 | function endEvent(target, event, timer, fn) { 460 | timer.end(function(){ 461 | trigger(target, 'moveend', event); 462 | return fn && fn(); 463 | }); 464 | } 465 | 466 | 467 | // Set up the DOM 468 | 469 | function movestart(e) { 470 | if (e.defaultPrevented) { return; } 471 | if (!e.moveEnabled) { return; } 472 | 473 | var event = { 474 | startX: e.startX, 475 | startY: e.startY, 476 | pageX: e.pageX, 477 | pageY: e.pageY, 478 | distX: e.distX, 479 | distY: e.distY, 480 | deltaX: e.deltaX, 481 | deltaY: e.deltaY, 482 | velocityX: e.velocityX, 483 | velocityY: e.velocityY, 484 | identifier: e.identifier, 485 | targetTouches: e.targetTouches, 486 | finger: e.finger 487 | }; 488 | 489 | var data = { 490 | target: e.target, 491 | event: event, 492 | timer: new Timer(update), 493 | touch: undefined, 494 | timeStamp: e.timeStamp 495 | }; 496 | 497 | function update(time) { 498 | updateEvent(event, data.touch, data.timeStamp); 499 | trigger(data.target, 'move', event); 500 | } 501 | 502 | if (e.identifier === undefined) { 503 | // We're dealing with a mouse event. 504 | // Stop clicks from propagating during a move 505 | on(e.target, 'click', preventDefault); 506 | on(document, mouseevents.move, activeMousemove, data); 507 | on(document, mouseevents.end, activeMouseend, data); 508 | } 509 | else { 510 | // In order to unbind correct handlers they have to be unique 511 | data.activeTouchmove = function(e, data) { activeTouchmove(e, data); }; 512 | data.activeTouchend = function(e, data) { activeTouchend(e, data); }; 513 | 514 | // We're dealing with a touch. 515 | on(document, touchevents.move, data.activeTouchmove, data); 516 | on(document, touchevents.end, data.activeTouchend, data); 517 | } 518 | } 519 | 520 | on(document, 'mousedown', mousedown); 521 | on(document, 'touchstart', touchstart); 522 | on(document, 'movestart', movestart); 523 | 524 | 525 | // jQuery special events 526 | // 527 | // jQuery event objects are copies of DOM event objects. They need 528 | // a little help copying the move properties across. 
529 | 530 | if (!window.jQuery) { return; } 531 | 532 | var properties = ("startX startY pageX pageY distX distY deltaX deltaY velocityX velocityY").split(' '); 533 | 534 | function enableMove1(e) { e.enableMove(); } 535 | function enableMove2(e) { e.enableMove(); } 536 | function enableMove3(e) { e.enableMove(); } 537 | 538 | function add(handleObj) { 539 | var handler = handleObj.handler; 540 | 541 | handleObj.handler = function(e) { 542 | // Copy move properties across from originalEvent 543 | var i = properties.length; 544 | var property; 545 | 546 | while(i--) { 547 | property = properties[i]; 548 | e[property] = e.originalEvent[property]; 549 | } 550 | 551 | handler.apply(this, arguments); 552 | }; 553 | } 554 | 555 | jQuery.event.special.movestart = { 556 | setup: function() { 557 | // Movestart must be enabled to allow other move events 558 | on(this, 'movestart', enableMove1); 559 | 560 | // Do listen to DOM events 561 | return false; 562 | }, 563 | 564 | teardown: function() { 565 | off(this, 'movestart', enableMove1); 566 | return false; 567 | }, 568 | 569 | add: add 570 | }; 571 | 572 | jQuery.event.special.move = { 573 | setup: function() { 574 | on(this, 'movestart', enableMove2); 575 | return false; 576 | }, 577 | 578 | teardown: function() { 579 | off(this, 'movestart', enableMove2); 580 | return false; 581 | }, 582 | 583 | add: add 584 | }; 585 | 586 | jQuery.event.special.moveend = { 587 | setup: function() { 588 | on(this, 'movestart', enableMove3); 589 | return false; 590 | }, 591 | 592 | teardown: function() { 593 | off(this, 'movestart', enableMove3); 594 | return false; 595 | }, 596 | 597 | add: add 598 | }; 599 | }); 600 | -------------------------------------------------------------------------------- /docs/static/js/jquery.twentytwenty.js: -------------------------------------------------------------------------------- 1 | (function($){ 2 | 3 | $.fn.twentytwenty = function(options) { 4 | var options = $.extend({ 5 | default_offset_pct: 0.5, 6 | orientation: 'horizontal', 7 | before_label: 'Before', 8 | after_label: 'After', 9 | no_overlay: false, 10 | move_slider_on_hover: false, 11 | move_with_handle_only: true, 12 | click_to_move: false, 13 | }, options); 14 | 15 | return this.each(function() { 16 | var container = $(this); 17 | var sliderOrientation = options.orientation; 18 | var beforeDirection = (sliderOrientation === 'vertical') ? 'down' : 'left'; 19 | var afterDirection = (sliderOrientation === 'vertical') ? 'up' : 'right'; 20 | var this_Offset_Pct = $(container).attr("default_offset_pct"); 21 | var sliderPct = this_Offset_Pct ? this_Offset_Pct : options.default_offset_pct; 22 | 23 | container.wrap("
"); 24 | if(!options.no_overlay) { 25 | container.append("
"); 26 | var overlay = container.find(".twentytwenty-overlay"); 27 | overlay.append("
"); 28 | overlay.append("
"); 29 | } 30 | var beforeImg = container.find(".cmpcontent:first"); 31 | var afterImg = container.find(".cmpcontent:last"); 32 | container.append("
"); 33 | var slider = container.find(".twentytwenty-handle"); 34 | slider.append(""); 35 | slider.append(""); 36 | container.addClass("twentytwenty-container"); 37 | beforeImg.addClass("twentytwenty-before"); 38 | afterImg.addClass("twentytwenty-after"); 39 | 40 | var calcOffset = function(dimensionPct) { 41 | var w = $(container).width(); 42 | var h = $(container).height(); 43 | 44 | return { 45 | w: w+"px", 46 | h: h+"px", 47 | cw: (dimensionPct*w)+"px", 48 | ch: (dimensionPct*h)+"px" 49 | }; 50 | }; 51 | 52 | var adjustContainer = function(offset) { 53 | if (sliderOrientation === 'vertical') { 54 | beforeImg.css("clip", "rect(0,"+offset.w+","+offset.ch+",0)"); 55 | afterImg.css("clip", "rect("+offset.ch+","+offset.w+","+offset.h+",0)"); 56 | } 57 | else { 58 | beforeImg.css("clip", "rect(0,"+offset.cw+","+offset.h+",0)"); 59 | afterImg.css("clip", "rect(0,"+offset.w+","+offset.h+","+offset.cw+")"); 60 | } 61 | container.css("height", offset.h); 62 | }; 63 | 64 | var adjustSlider = function(pct) { 65 | var offset = calcOffset(pct); 66 | slider.css((sliderOrientation==="vertical") ? "top" : "left", (sliderOrientation==="vertical") ? offset.ch : offset.cw); 67 | adjustContainer(offset); 68 | }; 69 | 70 | // Return the number specified or the min/max number if it outside the range given. 71 | var minMaxNumber = function(num, min, max) { 72 | return Math.max(min, Math.min(max, num)); 73 | }; 74 | 75 | // Calculate the slider percentage based on the position. 76 | var getSliderPercentage = function(positionX, positionY) { 77 | var sliderPercentage = (sliderOrientation === 'vertical') ? 78 | (positionY-offsetY)/imgHeight : 79 | (positionX-offsetX)/imgWidth; 80 | 81 | return minMaxNumber(sliderPercentage, 0, 1); 82 | }; 83 | 84 | 85 | $(window).on("resize.twentytwenty", function(e) { 86 | adjustSlider(sliderPct); 87 | }); 88 | 89 | var offsetX = 0; 90 | var offsetY = 0; 91 | var imgWidth = 0; 92 | var imgHeight = 0; 93 | var onMoveStart = function(e) { 94 | if (((e.distX > e.distY && e.distX < -e.distY) || (e.distX < e.distY && e.distX > -e.distY)) && sliderOrientation !== 'vertical') { 95 | e.preventDefault(); 96 | } 97 | else if (((e.distX < e.distY && e.distX < -e.distY) || (e.distX > e.distY && e.distX > -e.distY)) && sliderOrientation === 'vertical') { 98 | e.preventDefault(); 99 | } 100 | container.addClass("active"); 101 | offsetX = container.offset().left; 102 | offsetY = container.offset().top; 103 | imgWidth = beforeImg.width(); 104 | imgHeight = beforeImg.height(); 105 | }; 106 | var onMove = function(e) { 107 | if (container.hasClass("active")) { 108 | sliderPct = getSliderPercentage(e.pageX, e.pageY); 109 | adjustSlider(sliderPct); 110 | } 111 | }; 112 | var onMoveEnd = function() { 113 | container.removeClass("active"); 114 | }; 115 | 116 | var moveTarget = options.move_with_handle_only ? 
slider : container; 117 | moveTarget.on("movestart",onMoveStart); 118 | moveTarget.on("move",onMove); 119 | moveTarget.on("moveend",onMoveEnd); 120 | 121 | if (options.move_slider_on_hover) { 122 | container.on("mouseenter", onMoveStart); 123 | container.on("mousemove", onMove); 124 | container.on("mouseleave", onMoveEnd); 125 | } 126 | 127 | slider.on("touchmove", function(e) { 128 | e.preventDefault(); 129 | }); 130 | 131 | container.find("cmpcontent").on("mousedown", function(event) { 132 | event.preventDefault(); 133 | }); 134 | 135 | if (options.click_to_move) { 136 | container.on('click', function(e) { 137 | offsetX = container.offset().left; 138 | offsetY = container.offset().top; 139 | imgWidth = beforeImg.width(); 140 | imgHeight = beforeImg.height(); 141 | 142 | sliderPct = getSliderPercentage(e.pageX, e.pageY); 143 | adjustSlider(sliderPct); 144 | }); 145 | } 146 | 147 | $(window).trigger("resize.twentytwenty"); 148 | }); 149 | }; 150 | 151 | })(jQuery); -------------------------------------------------------------------------------- /docs/static/pdfs/cleandift.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/docs/static/pdfs/cleandift.pdf -------------------------------------------------------------------------------- /notebooks/example_images/2007_009889.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/notebooks/example_images/2007_009889.jpg -------------------------------------------------------------------------------- /notebooks/example_images/2008_004188.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/notebooks/example_images/2008_004188.jpg -------------------------------------------------------------------------------- /notebooks/example_images/depth/gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/notebooks/example_images/depth/gt.png -------------------------------------------------------------------------------- /notebooks/example_images/depth/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/notebooks/example_images/depth/input.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | diffusers>=0.27 3 | einops 4 | hydra-core 5 | jaxtyping 6 | numpy 7 | omegaconf 8 | torch>=2.1.0 9 | torchvision 10 | transformers 11 | tqdm -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/cleandift/87bc45edc290eddc82c6a9bba9fc17896043f3ec/src/__init__.py -------------------------------------------------------------------------------- /src/ae.py: -------------------------------------------------------------------------------- 1 | import diffusers 2 | import torch 3 | from torch import nn 4 | 5 | class 
AutoencoderKL(nn.Module): 6 | def __init__(self, scale: float = 0.18215, shift: float = 0.0, repo="stabilityai/stable-diffusion-2-1"): 7 | super().__init__() 8 | self.scale = scale 9 | self.shift = shift 10 | self.ae = diffusers.AutoencoderKL.from_pretrained(repo, subfolder="vae") 11 | self.ae.eval() 12 | self.ae.compile() 13 | self.ae.requires_grad_(False) 14 | 15 | def forward(self, img): 16 | return self.encode(img) 17 | 18 | @torch.no_grad() 19 | def encode(self, img): 20 | latent = self.ae.encode(img, return_dict=False)[0].sample() 21 | return (latent - self.shift) * self.scale 22 | 23 | @torch.no_grad() 24 | def decode(self, latent): 25 | rec = self.ae.decode(latent / self.scale + self.shift, return_dict=False)[0] 26 | return rec -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import torch 5 | import torch.utils.data as data 6 | from PIL import Image 7 | from torchvision.transforms.functional import to_tensor 8 | from torchvision import transforms 9 | from typing import Tuple 10 | 11 | def load_image(path: str, img_size: int): 12 | image = Image.open(path).convert("RGB") 13 | resize_transform = transforms.Resize((img_size, img_size)) 14 | image = resize_transform(image) 15 | image = to_tensor(image) * 2 - 1 16 | return image 17 | 18 | class DummyDataset(data.Dataset): 19 | def __init__(self, dataset_dir: str, img_size: int = 512, train: bool = True): 20 | self.dataset_dir = dataset_dir 21 | self.img_size = img_size 22 | self.data = [] 23 | 24 | jpg_files = [f for f in os.listdir(dataset_dir) if f.endswith('.jpg')] 25 | for img_path in jpg_files: 26 | json_path = os.path.join(dataset_dir, os.path.splitext(img_path)[0] + ".json") 27 | assert os.path.exists(json_path) 28 | with open(json_path, 'r') as json_file: 29 | json_dict = json.load(json_file) 30 | assert "caption" in json_dict.keys() 31 | self.data.append({"img_path": os.path.join(dataset_dir, img_path), "caption": json_dict["caption"]}) 32 | 33 | 34 | def __len__(self): 35 | return len(self.data) 36 | 37 | def __getitem__(self, idx): 38 | item = self.data[idx] 39 | sample = {} 40 | # Load image 41 | sample["x"] = load_image(item["img_path"], img_size=self.img_size) 42 | sample["caption"] = item["caption"] 43 | return sample 44 | 45 | class DataModule: 46 | def __init__(self, dataset_dir: str, batch_size: int = 1, img_size: int = 512): 47 | self.batch_size = batch_size 48 | 49 | train_dataset = DummyDataset(dataset_dir=dataset_dir, train=True, img_size=img_size) 50 | self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) 51 | 52 | val_dataset = DummyDataset(dataset_dir=dataset_dir, train=False, img_size=img_size) 53 | self.val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size) 54 | 55 | 56 | def train_dataloader(self): 57 | return self.train_loader 58 | 59 | def val_dataloader(self): 60 | return self.val_loader -------------------------------------------------------------------------------- /src/depth.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from omegaconf import OmegaConf 3 | import hydra 4 | import torch 5 | from torch import nn 6 | from jaxtyping import Float 7 | import torch.nn.functional as F 8 | 9 | 10 | class SigLoss(nn.Module): 11 | """SigLoss. 12 | 13 | This follows `AdaBins `_. 
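Written out, the scale-invariant term implemented in `sigloss` below is

$$g = \log(\mathrm{pred} + \epsilon) - \log(\mathrm{gt} + \epsilon), \qquad \mathcal{L}_{\mathrm{sig}} = \sqrt{\operatorname{Var}(g) + 0.15\,\bar{g}^{\,2}},$$

computed over valid pixels (gt > 0 and, when `max_depth` is set, gt ≤ `max_depth`); during the warm-up stage only the $0.15\,\bar{g}^{\,2}$ term is kept under the square root.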
14 | 15 | Args: 16 | valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. 17 | loss_weight (float): Weight of the loss. Default: 1.0. 18 | max_depth (int): When filtering invalid gt, set a max threshold. Default: None. 19 | warm_up (bool): A simple warm up stage to help convergence. Default: False. 20 | warm_iter (int): The number of warm up stage. Default: 100. 21 | 22 | Adapted from: https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/depth/models/losses/sigloss.py 23 | """ 24 | 25 | def __init__( 26 | self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss" 27 | ): 28 | super(SigLoss, self).__init__() 29 | self.valid_mask = valid_mask 30 | self.loss_weight = loss_weight 31 | self.max_depth = max_depth 32 | self.loss_name = loss_name 33 | 34 | self.eps = 0.001 35 | 36 | self.warm_up = warm_up 37 | self.warm_iter = warm_iter 38 | self.warm_up_counter = 0 39 | 40 | def sigloss(self, input, target): 41 | if self.valid_mask: 42 | valid_mask = target > 0 43 | if self.max_depth is not None: 44 | valid_mask = torch.logical_and(target > 0, target <= self.max_depth) 45 | input = input[valid_mask] 46 | target = target[valid_mask] 47 | 48 | if self.warm_up: 49 | if self.warm_up_counter < self.warm_iter: 50 | g = torch.log(input + self.eps) - torch.log(target + self.eps) 51 | g = 0.15 * torch.pow(torch.mean(g), 2) 52 | self.warm_up_counter += 1 53 | return torch.sqrt(g) 54 | 55 | g = torch.log(input + self.eps) - torch.log(target + self.eps) 56 | Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) 57 | return torch.sqrt(Dg) 58 | 59 | def forward(self, depth_pred, depth_gt): 60 | """Forward function.""" 61 | 62 | loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt) 63 | return loss_depth 64 | 65 | 66 | class DepthPred(torch.nn.Module): 67 | """ 68 | Adapted from: 69 | - https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/depth/models/decode_heads/decode_head.py 70 | - https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/depth/models/decode_heads/linear_head.py 71 | """ 72 | 73 | def __init__( 74 | self, 75 | model_config_path: str, 76 | channels: int, 77 | loss: nn.Module, 78 | diffusion_image_size: int, 79 | classify=True, 80 | n_bins=256, 81 | min_depth=0.1, 82 | max_depth=10.0, 83 | bins_strategy="UD", 84 | norm_strategy="linear", 85 | scale_up=False, 86 | use_base_model_features=False, 87 | base_model_timestep: None | int = None, 88 | adapter_timestep: int | None = None, 89 | extraction_layer="us6", 90 | interpolate_features: Literal["DINO", "FULL", "NONE"] = "NONE", 91 | n_vis_samples: int = 0, 92 | ): 93 | super().__init__() 94 | self.loss = loss 95 | self.classify = classify 96 | self.n_bins = n_bins 97 | self.min_depth = min_depth 98 | self.max_depth = max_depth 99 | self.scale_up = scale_up 100 | self.use_base_model_features = use_base_model_features 101 | self.base_model_timestep = base_model_timestep 102 | self.diffusion_image_size = diffusion_image_size 103 | self.extraction_layer = extraction_layer 104 | self.interpolate_features = interpolate_features 105 | self.adapter_timestep = adapter_timestep 106 | self.n_vis_samples = n_vis_samples 107 | 108 | if use_base_model_features: 109 | assert base_model_timestep is not None, "Need to provide base_model_timestep if using base model features" 110 | 111 | cfg_model = OmegaConf.load(model_config_path) 112 | OmegaConf.resolve(cfg_model) 113 | self.feature_extractor = hydra.utils.instantiate(cfg_model).model 114 | 115 | 
self.feature_extractor.requires_grad_(False) 116 | self.feature_extractor.eval() 117 | 118 | if self.classify: 119 | assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" 120 | assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" 121 | 122 | self.bins_strategy = bins_strategy 123 | self.norm_strategy = norm_strategy 124 | self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=1, padding=0, stride=1) 125 | else: 126 | self.conv_depth = nn.Conv2d(channels, 1, kernel_size=1, padding=0, stride=1) 127 | 128 | def depth_pred(self, feat: Float[torch.Tensor, "B C H W"]) -> Float[torch.Tensor, "B 1 H W"]: 129 | """Prediction each pixel.""" 130 | if self.classify: 131 | logit = self.conv_depth(feat) 132 | 133 | if self.bins_strategy == "UD": 134 | bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device, dtype=feat.dtype) 135 | elif self.bins_strategy == "SID": 136 | bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device, dtype=feat.dtype) 137 | 138 | # following Adabins, default linear 139 | if self.norm_strategy == "linear": 140 | logit = torch.relu(logit) 141 | eps = 0.1 142 | logit = logit + eps 143 | logit = logit / logit.sum(dim=1, keepdim=True) 144 | elif self.norm_strategy == "softmax": 145 | logit = torch.softmax(logit, dim=1) 146 | elif self.norm_strategy == "sigmoid": 147 | logit = torch.sigmoid(logit) 148 | logit = logit / logit.sum(dim=1, keepdim=True) 149 | 150 | output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) 151 | 152 | else: 153 | if self.scale_up: 154 | output = torch.sigmoid(self.conv_depth(feat)) * self.max_depth 155 | else: 156 | output = torch.relu(self.conv_depth(feat)) + self.min_depth 157 | 158 | return output 159 | 160 | def forward(self, img, depth_gt, *args, **kwargs): 161 | depth_target = depth_gt.data[0].cuda().bfloat16() 162 | image = img.data[0].cuda().bfloat16() 163 | 164 | depth_pred = self.predict(image) 165 | 166 | if self.interpolate_features != "FULL": 167 | H, W = depth_pred.shape[-2:] 168 | depth_target = F.interpolate(depth_target, (H, W), mode="bilinear", align_corners=False) 169 | 170 | return self.loss(depth_pred, depth_target) 171 | 172 | def predict(self, image): 173 | 174 | B, C, H, W = image.shape 175 | image = F.interpolate( 176 | image, size=(self.diffusion_image_size, self.diffusion_image_size), mode="bilinear", align_corners=False 177 | ) 178 | 179 | with torch.no_grad(): 180 | timestep = None 181 | if self.use_base_model_features: 182 | timestep = torch.tensor([self.base_model_timestep] * B, device=image.device) 183 | elif self.adapter_timestep is not None: 184 | timestep = torch.tensor([self.adapter_timestep] * B, device=image.device) 185 | features = self.feature_extractor.get_features( 186 | image, 187 | ["A photo of a room"] * B, 188 | timestep, 189 | self.extraction_layer, 190 | use_base_model=self.use_base_model_features, 191 | ) 192 | 193 | if self.interpolate_features == "FULL": 194 | features = F.interpolate(features, size=(H, W), mode="bilinear", align_corners=False) 195 | elif self.interpolate_features == "DINO": 196 | H, W = features.shape[-2:] 197 | features = F.interpolate(features, size=(H * 4, W * 4), mode="bilinear", align_corners=False) 198 | 199 | depth_pred = self.depth_pred(features) 200 | 201 | return depth_pred 202 | -------------------------------------------------------------------------------- /src/layers.py: 
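The blocks defined in this file (`FeedForwardBlock`, `MappingNetwork`, `FourierFeatures`, `Linear`) are the pieces that `StableFeatureAligner` in `src/sd_feature_extraction.py` uses to turn a diffusion timestep into a conditioning vector for the feature adapters. Below is a minimal sketch of that wiring; the sizes (`width`, `d_ff`, `n_layers`) and the batch of timesteps are illustrative placeholders, not the repo's configured values:

```python
import torch
from src.layers import FourierFeatures, Linear, MappingNetwork

width, d_ff, n_layers = 256, 768, 2              # illustrative sizes only
time_emb = FourierFeatures(1, width)             # random Fourier features of the scalar timestep
time_in_proj = Linear(width, width, bias=False)
mapping = MappingNetwork(n_layers, width, d_ff)

t = torch.randint(0, 1000, (4, 1)).float() / 999.0   # timesteps normalized by t_max_model
cond = mapping(time_in_proj(time_emb(t)))             # (4, width) conditioning vector
```

In the default `AdaRMS` setup, this `cond` tensor is what the adapter stacks receive as `cond_norm` through `AdaRMSNorm`.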
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import reduce 6 | from .utils import zero_init 7 | 8 | 9 | class FeedForwardBlock(nn.Module): 10 | def __init__(self, d_model, d_ff, d_cond_norm=None, dropout=0.0): 11 | super().__init__() 12 | if d_cond_norm is not None: 13 | self.norm = AdaRMSNorm(d_model, d_cond_norm) 14 | else: 15 | self.norm = RMSNorm(d_model) 16 | self.up_proj = LinearSwiGLU(d_model, d_ff, bias=False) 17 | self.dropout = nn.Dropout(dropout) 18 | self.down_proj = zero_init(Linear(d_ff, d_model, bias=False)) 19 | 20 | def forward(self, x, cond_norm=None, **kwargs): 21 | skip = x 22 | if cond_norm is not None: 23 | x = self.norm(x, cond_norm) 24 | else: 25 | x = self.norm(x) 26 | x = self.up_proj(x) 27 | x = self.dropout(x) 28 | x = self.down_proj(x) 29 | return x + skip 30 | 31 | class RMSNorm(nn.Module): 32 | def __init__(self, shape, eps=1e-6): 33 | super().__init__() 34 | self.eps = eps 35 | self.scale = nn.Parameter(torch.ones(shape)) 36 | 37 | def extra_repr(self): 38 | return f"shape={tuple(self.scale.shape)}, eps={self.eps}" 39 | 40 | def forward(self, x): 41 | return rms_norm(x, self.scale, self.eps) 42 | 43 | 44 | def rms_norm(x, scale, eps): 45 | dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) 46 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 47 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 48 | return x * scale.to(x.dtype) 49 | 50 | class AdaRMSNorm(nn.Module): 51 | def __init__(self, features, cond_features, eps=1e-6): 52 | super().__init__() 53 | self.eps = eps 54 | self.linear = zero_init(Linear(cond_features, features, bias=False)) 55 | 56 | def extra_repr(self): 57 | return f"eps={self.eps}," 58 | 59 | def forward(self, x, cond): 60 | return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) 61 | 62 | class MappingNetwork(nn.Module): 63 | def __init__(self, n_layers, d_model, d_ff, dropout=0.0): 64 | super().__init__() 65 | self.in_norm = RMSNorm(d_model) 66 | self.blocks = nn.ModuleList([FeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)]) 67 | self.out_norm = RMSNorm(d_model) 68 | 69 | def forward(self, x): 70 | x = self.in_norm(x) 71 | for block in self.blocks: 72 | x = block(x) 73 | x = self.out_norm(x) 74 | return x 75 | 76 | 77 | def linear_swiglu(x, weight, bias=None): 78 | x = x @ weight.mT 79 | if bias is not None: 80 | x = x + bias 81 | x, gate = x.chunk(2, dim=-1) 82 | return x * F.silu(gate) 83 | 84 | class LinearSwiGLU(nn.Linear): 85 | def __init__(self, in_features, out_features, bias=True): 86 | super().__init__(in_features, out_features * 2, bias=bias) 87 | self.out_features = out_features 88 | 89 | def forward(self, x): 90 | return linear_swiglu(x, self.weight, self.bias) 91 | 92 | 93 | class Linear(nn.Linear): 94 | def forward(self, x): 95 | return super().forward(x) 96 | 97 | 98 | class FourierFeatures(nn.Module): 99 | def __init__(self, in_features, out_features, std=1.): 100 | super().__init__() 101 | assert out_features % 2 == 0 102 | self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) 103 | 104 | def forward(self, input): 105 | f = 2 * math.pi * input @ self.weight.T 106 | return torch.cat([f.cos(), f.sin()], dim=-1) -------------------------------------------------------------------------------- /src/min_sd15.py: 
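The minimal U-Net below uses the same parameter names as the `diffusers` SD 1.5 U-Net, which is why `StableFeatureAligner` can copy the pretrained weights over directly with a strict `load_state_dict`. A small sketch of that loading step, using the same HuggingFace repo id as `src/sd_feature_extraction.py`:

```python
import torch
from diffusers import DiffusionPipeline
from src.min_sd15 import SD15UNetModel

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
unet = SD15UNetModel().to(dtype=torch.bfloat16)
unet.load_state_dict(pipe.unet.state_dict())  # parameter names match the diffusers U-Net
unet.eval()
```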
-------------------------------------------------------------------------------- 1 | # Obviously modified from the original source code 2 | # https://github.com/huggingface/diffusers 3 | # So has APACHE 2.0 license 4 | 5 | # Author : Simo Ryu 6 | # Adapted for SD15 by Nick Stracke 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | 13 | from collections import namedtuple 14 | 15 | 16 | class Timesteps(nn.Module): 17 | def __init__(self, num_channels: int = 320): 18 | super().__init__() 19 | self.num_channels = num_channels 20 | 21 | def forward(self, timesteps): 22 | half_dim = self.num_channels // 2 23 | exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) 24 | exponent = exponent / (half_dim - 0.0) 25 | 26 | emb = torch.exp(exponent) 27 | emb = timesteps[:, None].float() * emb[None, :] 28 | 29 | sin_emb = torch.sin(emb) 30 | cos_emb = torch.cos(emb) 31 | emb = torch.cat([cos_emb, sin_emb], dim=-1) 32 | 33 | return emb 34 | 35 | 36 | class TimestepEmbedding(nn.Module): 37 | def __init__(self, in_features, out_features): 38 | super(TimestepEmbedding, self).__init__() 39 | self.linear_1 = nn.Linear(in_features, out_features, bias=True) 40 | self.act = nn.SiLU() 41 | self.linear_2 = nn.Linear(out_features, out_features, bias=True) 42 | 43 | def forward(self, sample): 44 | sample = self.linear_1(sample) 45 | sample = self.act(sample) 46 | sample = self.linear_2(sample) 47 | 48 | return sample 49 | 50 | 51 | class ResnetBlock2D(nn.Module): 52 | def __init__(self, in_channels, out_channels, conv_shortcut=True): 53 | super(ResnetBlock2D, self).__init__() 54 | self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True) 55 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 56 | self.time_emb_proj = nn.Linear(1280, out_channels, bias=True) 57 | self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True) 58 | self.dropout = nn.Dropout(p=0.0, inplace=False) 59 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 60 | self.nonlinearity = nn.SiLU() 61 | self.conv_shortcut = None 62 | if conv_shortcut: 63 | self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) 64 | 65 | def forward(self, input_tensor, temb): 66 | hidden_states = input_tensor 67 | hidden_states = self.norm1(hidden_states) 68 | hidden_states = self.nonlinearity(hidden_states) 69 | 70 | hidden_states = self.conv1(hidden_states) 71 | 72 | temb = self.nonlinearity(temb) 73 | temb = self.time_emb_proj(temb)[:, :, None, None] 74 | hidden_states = hidden_states + temb 75 | hidden_states = self.norm2(hidden_states) 76 | 77 | hidden_states = self.nonlinearity(hidden_states) 78 | hidden_states = self.dropout(hidden_states) 79 | hidden_states = self.conv2(hidden_states) 80 | 81 | if self.conv_shortcut is not None: 82 | input_tensor = self.conv_shortcut(input_tensor) 83 | 84 | output_tensor = input_tensor + hidden_states 85 | 86 | return output_tensor 87 | 88 | 89 | class Attention(nn.Module): 90 | def __init__(self, inner_dim, cross_attention_dim=None, num_heads=8, dropout=0.0): 91 | super(Attention, self).__init__() 92 | if num_heads is None: 93 | self.head_dim = 64 94 | self.num_heads = inner_dim // self.head_dim 95 | else: 96 | self.num_heads = num_heads 97 | self.head_dim = inner_dim // num_heads 98 | 99 | self.scale = self.head_dim**-0.5 100 | if cross_attention_dim is None: 101 | cross_attention_dim = inner_dim 102 | self.to_q = 
nn.Linear(inner_dim, inner_dim, bias=False) 103 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) 104 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) 105 | 106 | self.to_out = nn.ModuleList([nn.Linear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)]) 107 | 108 | def forward(self, hidden_states, encoder_hidden_states=None): 109 | q = self.to_q(hidden_states) 110 | k = self.to_k(encoder_hidden_states) if encoder_hidden_states is not None else self.to_k(hidden_states) 111 | v = self.to_v(encoder_hidden_states) if encoder_hidden_states is not None else self.to_v(hidden_states) 112 | b, t, c = q.size() 113 | 114 | q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2) 115 | k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2) 116 | v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2) 117 | 118 | attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=self.scale) 119 | attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c) 120 | 121 | for layer in self.to_out: 122 | attn_output = layer(attn_output) 123 | 124 | return attn_output 125 | 126 | 127 | class GEGLU(nn.Module): 128 | def __init__(self, in_features, out_features): 129 | super(GEGLU, self).__init__() 130 | self.proj = nn.Linear(in_features, out_features * 2, bias=True) 131 | 132 | def forward(self, x): 133 | x_proj = self.proj(x) 134 | x1, x2 = x_proj.chunk(2, dim=-1) 135 | return x1 * torch.nn.functional.gelu(x2) 136 | 137 | 138 | class FeedForward(nn.Module): 139 | def __init__(self, in_features, out_features): 140 | super(FeedForward, self).__init__() 141 | 142 | self.net = nn.ModuleList( 143 | [ 144 | GEGLU(in_features, out_features * 4), 145 | nn.Dropout(p=0.0, inplace=False), 146 | nn.Linear(out_features * 4, out_features, bias=True), 147 | ] 148 | ) 149 | 150 | def forward(self, x): 151 | for layer in self.net: 152 | x = layer(x) 153 | return x 154 | 155 | 156 | class BasicTransformerBlock(nn.Module): 157 | def __init__(self, hidden_size): 158 | super(BasicTransformerBlock, self).__init__() 159 | self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) 160 | self.attn1 = Attention(hidden_size) 161 | self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) 162 | self.attn2 = Attention(hidden_size, 768) 163 | self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) 164 | self.ff = FeedForward(hidden_size, hidden_size) 165 | 166 | def forward(self, x, encoder_hidden_states=None): 167 | residual = x 168 | 169 | x = self.norm1(x) 170 | x = self.attn1(x) 171 | x = x + residual 172 | 173 | residual = x 174 | 175 | x = self.norm2(x) 176 | if encoder_hidden_states is not None: 177 | x = self.attn2(x, encoder_hidden_states) 178 | else: 179 | x = self.attn2(x) 180 | x = x + residual 181 | 182 | residual = x 183 | 184 | x = self.norm3(x) 185 | x = self.ff(x) 186 | x = x + residual 187 | return x 188 | 189 | 190 | class Transformer2DModel(nn.Module): 191 | def __init__(self, in_channels, out_channels, n_layers): 192 | super(Transformer2DModel, self).__init__() 193 | self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True) 194 | self.proj_in = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 195 | self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(out_channels) for _ in range(n_layers)]) 196 | self.proj_out = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0) 197 | 198 
| def forward(self, hidden_states, encoder_hidden_states=None): 199 | batch, _, height, width = hidden_states.shape 200 | res = hidden_states 201 | hidden_states = self.norm(hidden_states) 202 | inner_dim = hidden_states.shape[1] 203 | hidden_states = self.proj_in(hidden_states) 204 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 205 | 206 | for block in self.transformer_blocks: 207 | hidden_states = block(hidden_states, encoder_hidden_states) 208 | 209 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 210 | hidden_states = self.proj_out(hidden_states) 211 | 212 | return hidden_states + res 213 | 214 | 215 | class Downsample2D(nn.Module): 216 | def __init__(self, in_channels, out_channels): 217 | super(Downsample2D, self).__init__() 218 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 219 | 220 | def forward(self, x): 221 | return self.conv(x) 222 | 223 | 224 | class Upsample2D(nn.Module): 225 | def __init__(self, in_channels, out_channels): 226 | super(Upsample2D, self).__init__() 227 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 228 | 229 | def forward(self, x): 230 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 231 | return self.conv(x) 232 | 233 | 234 | class DownBlock2D(nn.Module): 235 | def __init__(self, in_channels, out_channels, has_downsamplers=True): 236 | super(DownBlock2D, self).__init__() 237 | self.resnets = nn.ModuleList( 238 | [ 239 | ResnetBlock2D(in_channels, out_channels, conv_shortcut=False), 240 | ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), 241 | ] 242 | ) 243 | self.downsamplers = None 244 | if has_downsamplers: 245 | self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) 246 | 247 | def forward(self, hidden_states, temb): 248 | output_states = [] 249 | for module in self.resnets: 250 | hidden_states = module(hidden_states, temb) 251 | output_states.append(hidden_states) 252 | 253 | if self.downsamplers is not None: 254 | hidden_states = self.downsamplers[0](hidden_states) 255 | output_states.append(hidden_states) 256 | 257 | return hidden_states, output_states 258 | 259 | 260 | class CrossAttnDownBlock2D(nn.Module): 261 | def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True, conv_shortcut=True): 262 | super(CrossAttnDownBlock2D, self).__init__() 263 | self.attentions = nn.ModuleList( 264 | [ 265 | Transformer2DModel(out_channels, out_channels, n_layers), 266 | Transformer2DModel(out_channels, out_channels, n_layers), 267 | ] 268 | ) 269 | self.resnets = nn.ModuleList( 270 | [ 271 | ResnetBlock2D(in_channels, out_channels, conv_shortcut), 272 | ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), 273 | ] 274 | ) 275 | self.downsamplers = None 276 | if has_downsamplers: 277 | self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) 278 | 279 | def forward(self, hidden_states, temb, encoder_hidden_states): 280 | output_states = [] 281 | for resnet, attn in zip(self.resnets, self.attentions): 282 | hidden_states = resnet(hidden_states, temb) 283 | hidden_states = attn( 284 | hidden_states, 285 | encoder_hidden_states=encoder_hidden_states, 286 | ) 287 | output_states.append(hidden_states) 288 | 289 | if self.downsamplers is not None: 290 | hidden_states = self.downsamplers[0](hidden_states) 291 | output_states.append(hidden_states) 292 | 293 | return hidden_states, output_states 294 | 
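# Note on the skip connections: every down block returns one hidden state per
# resnet (plus the downsampled output). SD15UNetModel collects the conv_in output
# and these per-block states as s0..s11; the up blocks below then pop them from
# res_hidden_states_tuple in reverse order and concatenate them onto the
# upsampling path.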
295 | 296 | class CrossAttnUpBlock2D(nn.Module): 297 | def __init__(self, in_channels, out_channels, prev_output_channel, n_layers, has_upsamplers=True): 298 | super(CrossAttnUpBlock2D, self).__init__() 299 | self.attentions = nn.ModuleList( 300 | [ 301 | Transformer2DModel(out_channels, out_channels, n_layers), 302 | Transformer2DModel(out_channels, out_channels, n_layers), 303 | Transformer2DModel(out_channels, out_channels, n_layers), 304 | ] 305 | ) 306 | self.resnets = nn.ModuleList( 307 | [ 308 | ResnetBlock2D(prev_output_channel + out_channels, out_channels), 309 | ResnetBlock2D(2 * out_channels, out_channels), 310 | ResnetBlock2D(out_channels + in_channels, out_channels), 311 | ] 312 | ) 313 | 314 | self.upsamplers = None 315 | if has_upsamplers: 316 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) 317 | 318 | def forward(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states): 319 | all_hidden_states = [] 320 | for resnet, attn in zip(self.resnets, self.attentions): 321 | # pop res hidden states 322 | res_hidden_states = res_hidden_states_tuple[-1] 323 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 324 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 325 | hidden_states = resnet(hidden_states, temb) 326 | hidden_states = attn( 327 | hidden_states, 328 | encoder_hidden_states=encoder_hidden_states, 329 | ) 330 | all_hidden_states.append(hidden_states) 331 | 332 | if self.upsamplers is not None: 333 | for upsampler in self.upsamplers: 334 | hidden_states = upsampler(hidden_states) 335 | all_hidden_states[-1] = hidden_states 336 | 337 | return hidden_states, all_hidden_states 338 | 339 | 340 | class UpBlock2D(nn.Module): 341 | def __init__(self, in_channels, out_channels, prev_output_channel, has_upsamplers=True): 342 | super(UpBlock2D, self).__init__() 343 | self.resnets = nn.ModuleList( 344 | [ 345 | ResnetBlock2D(out_channels + prev_output_channel, out_channels), 346 | ResnetBlock2D(out_channels * 2, out_channels), 347 | ResnetBlock2D(out_channels + in_channels, out_channels), 348 | ] 349 | ) 350 | 351 | self.upsamplers = None 352 | if has_upsamplers: 353 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) 354 | 355 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None): 356 | all_hidden_states = [] 357 | for resnet in self.resnets: 358 | res_hidden_states = res_hidden_states_tuple[-1] 359 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 360 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 361 | hidden_states = resnet(hidden_states, temb) 362 | all_hidden_states.append(hidden_states) 363 | 364 | if self.upsamplers is not None: 365 | for upsampler in self.upsamplers: 366 | hidden_states = upsampler(hidden_states) 367 | all_hidden_states[-1] = hidden_states 368 | 369 | return hidden_states, all_hidden_states 370 | 371 | 372 | class UNetMidBlock2DCrossAttn(nn.Module): 373 | def __init__(self, in_features): 374 | super(UNetMidBlock2DCrossAttn, self).__init__() 375 | self.attentions = nn.ModuleList([Transformer2DModel(in_features, in_features, n_layers=1)]) 376 | self.resnets = nn.ModuleList( 377 | [ 378 | ResnetBlock2D(in_features, in_features, conv_shortcut=False), 379 | ResnetBlock2D(in_features, in_features, conv_shortcut=False), 380 | ] 381 | ) 382 | 383 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 384 | hidden_states = self.resnets[0](hidden_states, temb) 385 | for attn, resnet in zip(self.attentions, 
self.resnets[1:]): 386 | hidden_states = attn( 387 | hidden_states, 388 | encoder_hidden_states=encoder_hidden_states, 389 | ) 390 | hidden_states = resnet(hidden_states, temb) 391 | 392 | return hidden_states 393 | 394 | 395 | class SD15UNetModel(nn.Module): 396 | def __init__(self): 397 | super(SD15UNetModel, self).__init__() 398 | 399 | # This is needed to imitate huggingface config behavior 400 | # has nothing to do with the model itself 401 | # remove this if you don't use diffuser's pipeline 402 | self.config = namedtuple("config", "in_channels time_cond_proj_dim sample_size") 403 | self.config.in_channels = 4 404 | # self.config.addition_time_embed_dim = 256 405 | self.config.sample_size = 64 406 | self.config.time_cond_proj_dim = None 407 | 408 | self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1) 409 | self.time_proj = Timesteps() 410 | self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280) 411 | self.down_blocks = nn.ModuleList( 412 | [ 413 | CrossAttnDownBlock2D(in_channels=320, out_channels=320, n_layers=1, conv_shortcut=False), 414 | CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=1), 415 | CrossAttnDownBlock2D(in_channels=640, out_channels=1280, n_layers=1), 416 | DownBlock2D(in_channels=1280, out_channels=1280, has_downsamplers=False), 417 | ] 418 | ) 419 | self.up_blocks = nn.ModuleList( 420 | [ 421 | UpBlock2D(in_channels=1280, out_channels=1280, prev_output_channel=1280), 422 | CrossAttnUpBlock2D( 423 | in_channels=640, 424 | out_channels=1280, 425 | prev_output_channel=1280, 426 | n_layers=1, 427 | ), 428 | CrossAttnUpBlock2D( 429 | in_channels=320, 430 | out_channels=640, 431 | prev_output_channel=1280, 432 | n_layers=1, 433 | ), 434 | CrossAttnUpBlock2D( 435 | in_channels=320, out_channels=320, prev_output_channel=640, n_layers=1, has_upsamplers=False 436 | ), 437 | ] 438 | ) 439 | 440 | self.mid_block = UNetMidBlock2DCrossAttn(1280) 441 | self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True) 442 | self.conv_act = nn.SiLU() 443 | self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1) 444 | 445 | def forward(self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs): 446 | # Implement the forward pass through the model 447 | timesteps = timesteps.expand(sample.shape[0]) 448 | t_emb = self.time_proj(timesteps).to(dtype=sample.dtype) 449 | emb = self.time_embedding(t_emb) 450 | 451 | sample = self.conv_in(sample) 452 | 453 | # 3. down 454 | s0 = sample 455 | sample, [s1, s2, s3] = self.down_blocks[0]( 456 | sample, 457 | temb=emb, 458 | encoder_hidden_states=encoder_hidden_states, 459 | ) 460 | 461 | sample, [s4, s5, s6] = self.down_blocks[1]( 462 | sample, 463 | temb=emb, 464 | encoder_hidden_states=encoder_hidden_states, 465 | ) 466 | 467 | sample, [s7, s8, s9] = self.down_blocks[2]( 468 | sample, 469 | temb=emb, 470 | encoder_hidden_states=encoder_hidden_states, 471 | ) 472 | 473 | sample, [s10, s11] = self.down_blocks[3]( 474 | sample, 475 | temb=emb, 476 | ) 477 | 478 | # 4. mid 479 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 480 | 481 | # 5. 
up 482 | _, [us1, us2, us3] = self.up_blocks[0]( 483 | hidden_states=sample, 484 | temb=emb, 485 | res_hidden_states_tuple=[s9, s10, s11], 486 | ) 487 | 488 | _, [us4, us5, us6] = self.up_blocks[1]( 489 | hidden_states=us3, 490 | temb=emb, 491 | res_hidden_states_tuple=[s6, s7, s8], 492 | encoder_hidden_states=encoder_hidden_states, 493 | ) 494 | 495 | _, [us7, us8, us9] = self.up_blocks[2]( 496 | hidden_states=us6, 497 | temb=emb, 498 | res_hidden_states_tuple=[s3, s4, s5], 499 | encoder_hidden_states=encoder_hidden_states, 500 | ) 501 | 502 | _, [us10, us11, us12] = self.up_blocks[3]( 503 | hidden_states=us9, 504 | temb=emb, 505 | res_hidden_states_tuple=[s0, s1, s2], 506 | encoder_hidden_states=encoder_hidden_states, 507 | ) 508 | 509 | # 6. post-process 510 | sample = self.conv_norm_out(us12) 511 | sample = self.conv_act(sample) 512 | sample = self.conv_out(sample) 513 | 514 | return [sample] 515 | -------------------------------------------------------------------------------- /src/min_sd21.py: -------------------------------------------------------------------------------- 1 | # Obviously modified from the original source code 2 | # https://github.com/huggingface/diffusers 3 | # So has APACHE 2.0 license 4 | 5 | # Author : Simo Ryu 6 | # Adapted for SD21 by Nick Stracke 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | 13 | from collections import namedtuple 14 | 15 | 16 | class Timesteps(nn.Module): 17 | def __init__(self, num_channels: int = 320): 18 | super().__init__() 19 | self.num_channels = num_channels 20 | 21 | def forward(self, timesteps): 22 | half_dim = self.num_channels // 2 23 | exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) 24 | exponent = exponent / (half_dim - 0.0) 25 | 26 | emb = torch.exp(exponent) 27 | emb = timesteps[:, None].float() * emb[None, :] 28 | 29 | sin_emb = torch.sin(emb) 30 | cos_emb = torch.cos(emb) 31 | emb = torch.cat([cos_emb, sin_emb], dim=-1) 32 | 33 | return emb 34 | 35 | 36 | class TimestepEmbedding(nn.Module): 37 | def __init__(self, in_features, out_features): 38 | super(TimestepEmbedding, self).__init__() 39 | self.linear_1 = nn.Linear(in_features, out_features, bias=True) 40 | self.act = nn.SiLU() 41 | self.linear_2 = nn.Linear(out_features, out_features, bias=True) 42 | 43 | def forward(self, sample): 44 | sample = self.linear_1(sample) 45 | sample = self.act(sample) 46 | sample = self.linear_2(sample) 47 | 48 | return sample 49 | 50 | 51 | class ResnetBlock2D(nn.Module): 52 | def __init__(self, in_channels, out_channels, conv_shortcut=True): 53 | super(ResnetBlock2D, self).__init__() 54 | self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True) 55 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 56 | self.time_emb_proj = nn.Linear(1280, out_channels, bias=True) 57 | self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True) 58 | self.dropout = nn.Dropout(p=0.0, inplace=False) 59 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 60 | self.nonlinearity = nn.SiLU() 61 | self.conv_shortcut = None 62 | if conv_shortcut: 63 | self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) 64 | 65 | def forward(self, input_tensor, temb): 66 | hidden_states = input_tensor 67 | hidden_states = self.norm1(hidden_states) 68 | hidden_states = self.nonlinearity(hidden_states) 69 | 70 | hidden_states = self.conv1(hidden_states) 71 
| 72 | temb = self.nonlinearity(temb) 73 | temb = self.time_emb_proj(temb)[:, :, None, None] 74 | hidden_states = hidden_states + temb 75 | hidden_states = self.norm2(hidden_states) 76 | 77 | hidden_states = self.nonlinearity(hidden_states) 78 | hidden_states = self.dropout(hidden_states) 79 | hidden_states = self.conv2(hidden_states) 80 | 81 | if self.conv_shortcut is not None: 82 | input_tensor = self.conv_shortcut(input_tensor) 83 | 84 | output_tensor = input_tensor + hidden_states 85 | 86 | return output_tensor 87 | 88 | 89 | class Attention(nn.Module): 90 | def __init__(self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0): 91 | super(Attention, self).__init__() 92 | if num_heads is None: 93 | self.head_dim = 64 94 | self.num_heads = inner_dim // self.head_dim 95 | else: 96 | self.num_heads = num_heads 97 | self.head_dim = inner_dim // num_heads 98 | 99 | self.scale = self.head_dim**-0.5 100 | if cross_attention_dim is None: 101 | cross_attention_dim = inner_dim 102 | self.to_q = nn.Linear(inner_dim, inner_dim, bias=False) 103 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) 104 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) 105 | 106 | self.to_out = nn.ModuleList([nn.Linear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)]) 107 | 108 | def forward(self, hidden_states, encoder_hidden_states=None): 109 | q = self.to_q(hidden_states) 110 | k = self.to_k(encoder_hidden_states) if encoder_hidden_states is not None else self.to_k(hidden_states) 111 | v = self.to_v(encoder_hidden_states) if encoder_hidden_states is not None else self.to_v(hidden_states) 112 | 113 | # # NOTE SD21 upcasts the attention computions 114 | # I checked and not upcasting in bfloat16 basically leads to no visible differences ~Nick 115 | # dtype = q.dtype 116 | # q = q.float() 117 | # k = k.float() 118 | # # typically v is not upcasted but we need to do it to work with sdpa 119 | # v = v.float() 120 | 121 | b, t, c = q.size() 122 | 123 | q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2) 124 | k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2) 125 | v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2) 126 | 127 | attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=self.scale) 128 | # attn_output = attn_output.to(dtype) 129 | attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c) 130 | 131 | for layer in self.to_out: 132 | attn_output = layer(attn_output) 133 | 134 | return attn_output 135 | 136 | 137 | class GEGLU(nn.Module): 138 | def __init__(self, in_features, out_features): 139 | super(GEGLU, self).__init__() 140 | self.proj = nn.Linear(in_features, out_features * 2, bias=True) 141 | 142 | def forward(self, x): 143 | x_proj = self.proj(x) 144 | x1, x2 = x_proj.chunk(2, dim=-1) 145 | return x1 * torch.nn.functional.gelu(x2) 146 | 147 | 148 | class FeedForward(nn.Module): 149 | def __init__(self, in_features, out_features): 150 | super(FeedForward, self).__init__() 151 | 152 | self.net = nn.ModuleList( 153 | [ 154 | GEGLU(in_features, out_features * 4), 155 | nn.Dropout(p=0.0, inplace=False), 156 | nn.Linear(out_features * 4, out_features, bias=True), 157 | ] 158 | ) 159 | 160 | def forward(self, x): 161 | for layer in self.net: 162 | x = layer(x) 163 | return x 164 | 165 | 166 | class BasicTransformerBlock(nn.Module): 167 | def __init__(self, hidden_size): 168 | super(BasicTransformerBlock, self).__init__() 169 | self.norm1 = 
nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) 170 | self.attn1 = Attention(hidden_size) 171 | self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) 172 | self.attn2 = Attention(hidden_size, 1024) 173 | self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) 174 | self.ff = FeedForward(hidden_size, hidden_size) 175 | 176 | def forward(self, x, encoder_hidden_states=None): 177 | residual = x 178 | 179 | x = self.norm1(x) 180 | x = self.attn1(x) 181 | x = x + residual 182 | 183 | residual = x 184 | 185 | x = self.norm2(x) 186 | if encoder_hidden_states is not None: 187 | x = self.attn2(x, encoder_hidden_states) 188 | else: 189 | x = self.attn2(x) 190 | x = x + residual 191 | 192 | residual = x 193 | 194 | x = self.norm3(x) 195 | x = self.ff(x) 196 | x = x + residual 197 | return x 198 | 199 | 200 | class Transformer2DModel(nn.Module): 201 | def __init__(self, in_channels, out_channels, n_layers): 202 | super(Transformer2DModel, self).__init__() 203 | self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True) 204 | self.proj_in = nn.Linear(in_channels, out_channels, bias=True) 205 | self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(out_channels) for _ in range(n_layers)]) 206 | self.proj_out = nn.Linear(out_channels, out_channels, bias=True) 207 | 208 | def forward(self, hidden_states, encoder_hidden_states=None): 209 | batch, _, height, width = hidden_states.shape 210 | res = hidden_states 211 | hidden_states = self.norm(hidden_states) 212 | inner_dim = hidden_states.shape[1] 213 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 214 | hidden_states = self.proj_in(hidden_states) 215 | 216 | for block in self.transformer_blocks: 217 | hidden_states = block(hidden_states, encoder_hidden_states) 218 | 219 | hidden_states = self.proj_out(hidden_states) 220 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 221 | 222 | return hidden_states + res 223 | 224 | 225 | class Downsample2D(nn.Module): 226 | def __init__(self, in_channels, out_channels): 227 | super(Downsample2D, self).__init__() 228 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 229 | 230 | def forward(self, x): 231 | return self.conv(x) 232 | 233 | 234 | class Upsample2D(nn.Module): 235 | def __init__(self, in_channels, out_channels): 236 | super(Upsample2D, self).__init__() 237 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 238 | 239 | def forward(self, x): 240 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 241 | return self.conv(x) 242 | 243 | 244 | class DownBlock2D(nn.Module): 245 | def __init__(self, in_channels, out_channels, has_downsamplers=True): 246 | super(DownBlock2D, self).__init__() 247 | self.resnets = nn.ModuleList( 248 | [ 249 | ResnetBlock2D(in_channels, out_channels, conv_shortcut=False), 250 | ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), 251 | ] 252 | ) 253 | self.downsamplers = None 254 | if has_downsamplers: 255 | self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) 256 | 257 | def forward(self, hidden_states, temb): 258 | output_states = [] 259 | for module in self.resnets: 260 | hidden_states = module(hidden_states, temb) 261 | output_states.append(hidden_states) 262 | 263 | if self.downsamplers is not None: 264 | hidden_states = self.downsamplers[0](hidden_states) 265 | output_states.append(hidden_states) 266 | 
267 | return hidden_states, output_states 268 | 269 | 270 | class CrossAttnDownBlock2D(nn.Module): 271 | def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True, conv_shortcut=True): 272 | super(CrossAttnDownBlock2D, self).__init__() 273 | self.attentions = nn.ModuleList( 274 | [ 275 | Transformer2DModel(out_channels, out_channels, n_layers), 276 | Transformer2DModel(out_channels, out_channels, n_layers), 277 | ] 278 | ) 279 | self.resnets = nn.ModuleList( 280 | [ 281 | ResnetBlock2D(in_channels, out_channels, conv_shortcut), 282 | ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), 283 | ] 284 | ) 285 | self.downsamplers = None 286 | if has_downsamplers: 287 | self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) 288 | 289 | def forward(self, hidden_states, temb, encoder_hidden_states): 290 | output_states = [] 291 | for resnet, attn in zip(self.resnets, self.attentions): 292 | hidden_states = resnet(hidden_states, temb) 293 | hidden_states = attn( 294 | hidden_states, 295 | encoder_hidden_states=encoder_hidden_states, 296 | ) 297 | output_states.append(hidden_states) 298 | 299 | if self.downsamplers is not None: 300 | hidden_states = self.downsamplers[0](hidden_states) 301 | output_states.append(hidden_states) 302 | 303 | return hidden_states, output_states 304 | 305 | 306 | class CrossAttnUpBlock2D(nn.Module): 307 | def __init__(self, in_channels, out_channels, prev_output_channel, n_layers, has_upsamplers=True): 308 | super(CrossAttnUpBlock2D, self).__init__() 309 | self.attentions = nn.ModuleList( 310 | [ 311 | Transformer2DModel(out_channels, out_channels, n_layers), 312 | Transformer2DModel(out_channels, out_channels, n_layers), 313 | Transformer2DModel(out_channels, out_channels, n_layers), 314 | ] 315 | ) 316 | self.resnets = nn.ModuleList( 317 | [ 318 | ResnetBlock2D(prev_output_channel + out_channels, out_channels), 319 | ResnetBlock2D(2 * out_channels, out_channels), 320 | ResnetBlock2D(out_channels + in_channels, out_channels), 321 | ] 322 | ) 323 | 324 | self.upsamplers = None 325 | if has_upsamplers: 326 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) 327 | 328 | def forward(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states): 329 | all_hidden_states = [] 330 | for resnet, attn in zip(self.resnets, self.attentions): 331 | # pop res hidden states 332 | res_hidden_states = res_hidden_states_tuple[-1] 333 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 334 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 335 | hidden_states = resnet(hidden_states, temb) 336 | hidden_states = attn( 337 | hidden_states, 338 | encoder_hidden_states=encoder_hidden_states, 339 | ) 340 | all_hidden_states.append(hidden_states) 341 | 342 | if self.upsamplers is not None: 343 | for upsampler in self.upsamplers: 344 | hidden_states = upsampler(hidden_states) 345 | all_hidden_states[-1] = hidden_states 346 | 347 | return hidden_states, all_hidden_states 348 | 349 | 350 | class UpBlock2D(nn.Module): 351 | def __init__(self, in_channels, out_channels, prev_output_channel, has_upsamplers=True): 352 | super(UpBlock2D, self).__init__() 353 | self.resnets = nn.ModuleList( 354 | [ 355 | ResnetBlock2D(out_channels + prev_output_channel, out_channels), 356 | ResnetBlock2D(out_channels * 2, out_channels), 357 | ResnetBlock2D(out_channels + in_channels, out_channels), 358 | ] 359 | ) 360 | 361 | self.upsamplers = None 362 | if has_upsamplers: 363 | self.upsamplers = 
nn.ModuleList([Upsample2D(out_channels, out_channels)]) 364 | 365 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None): 366 | all_hidden_states = [] 367 | for resnet in self.resnets: 368 | res_hidden_states = res_hidden_states_tuple[-1] 369 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 370 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 371 | hidden_states = resnet(hidden_states, temb) 372 | all_hidden_states.append(hidden_states) 373 | 374 | if self.upsamplers is not None: 375 | for upsampler in self.upsamplers: 376 | hidden_states = upsampler(hidden_states) 377 | all_hidden_states[-1] = hidden_states 378 | 379 | return hidden_states, all_hidden_states 380 | 381 | 382 | class UNetMidBlock2DCrossAttn(nn.Module): 383 | def __init__(self, in_features): 384 | super(UNetMidBlock2DCrossAttn, self).__init__() 385 | self.attentions = nn.ModuleList([Transformer2DModel(in_features, in_features, n_layers=1)]) 386 | self.resnets = nn.ModuleList( 387 | [ 388 | ResnetBlock2D(in_features, in_features, conv_shortcut=False), 389 | ResnetBlock2D(in_features, in_features, conv_shortcut=False), 390 | ] 391 | ) 392 | 393 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 394 | hidden_states = self.resnets[0](hidden_states, temb) 395 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 396 | hidden_states = attn( 397 | hidden_states, 398 | encoder_hidden_states=encoder_hidden_states, 399 | ) 400 | hidden_states = resnet(hidden_states, temb) 401 | 402 | return hidden_states 403 | 404 | 405 | class SD21UNetModel(nn.Module): 406 | def __init__(self): 407 | super(SD21UNetModel, self).__init__() 408 | 409 | # This is needed to imitate huggingface config behavior 410 | # has nothing to do with the model itself 411 | # remove this if you don't use diffuser's pipeline 412 | self.config = namedtuple("config", "in_channels time_cond_proj_dim sample_size") 413 | self.config.in_channels = 4 414 | # self.config.addition_time_embed_dim = 256 415 | self.config.sample_size = 96 416 | self.config.time_cond_proj_dim = None 417 | 418 | self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1) 419 | self.time_proj = Timesteps() 420 | self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280) 421 | self.down_blocks = nn.ModuleList( 422 | [ 423 | CrossAttnDownBlock2D(in_channels=320, out_channels=320, n_layers=1, conv_shortcut=False), 424 | CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=1), 425 | CrossAttnDownBlock2D(in_channels=640, out_channels=1280, n_layers=1), 426 | DownBlock2D(in_channels=1280, out_channels=1280, has_downsamplers=False), 427 | ] 428 | ) 429 | self.up_blocks = nn.ModuleList( 430 | [ 431 | UpBlock2D(in_channels=1280, out_channels=1280, prev_output_channel=1280), 432 | CrossAttnUpBlock2D( 433 | in_channels=640, 434 | out_channels=1280, 435 | prev_output_channel=1280, 436 | n_layers=1, 437 | ), 438 | CrossAttnUpBlock2D( 439 | in_channels=320, 440 | out_channels=640, 441 | prev_output_channel=1280, 442 | n_layers=1, 443 | ), 444 | CrossAttnUpBlock2D( 445 | in_channels=320, out_channels=320, prev_output_channel=640, n_layers=1, has_upsamplers=False 446 | ), 447 | ] 448 | ) 449 | 450 | self.mid_block = UNetMidBlock2DCrossAttn(1280) 451 | self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True) 452 | self.conv_act = nn.SiLU() 453 | self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1) 454 | 455 | def forward(self, sample, timesteps, encoder_hidden_states, 
added_cond_kwargs, **kwargs): 456 | # Implement the forward pass through the model 457 | timesteps = timesteps.expand(sample.shape[0]) 458 | t_emb = self.time_proj(timesteps).to(dtype=sample.dtype) 459 | emb = self.time_embedding(t_emb) 460 | 461 | sample = self.conv_in(sample) 462 | 463 | # 3. down 464 | s0 = sample 465 | sample, [s1, s2, s3] = self.down_blocks[0]( 466 | sample, 467 | temb=emb, 468 | encoder_hidden_states=encoder_hidden_states, 469 | ) 470 | 471 | sample, [s4, s5, s6] = self.down_blocks[1]( 472 | sample, 473 | temb=emb, 474 | encoder_hidden_states=encoder_hidden_states, 475 | ) 476 | 477 | sample, [s7, s8, s9] = self.down_blocks[2]( 478 | sample, 479 | temb=emb, 480 | encoder_hidden_states=encoder_hidden_states, 481 | ) 482 | 483 | sample, [s10, s11] = self.down_blocks[3]( 484 | sample, 485 | temb=emb, 486 | ) 487 | 488 | # 4. mid 489 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 490 | 491 | # 5. up 492 | _, [us1, us2, us3] = self.up_blocks[0]( 493 | hidden_states=sample, 494 | temb=emb, 495 | res_hidden_states_tuple=[s9, s10, s11], 496 | ) 497 | 498 | _, [us4, us5, us6] = self.up_blocks[1]( 499 | hidden_states=us3, 500 | temb=emb, 501 | res_hidden_states_tuple=[s6, s7, s8], 502 | encoder_hidden_states=encoder_hidden_states, 503 | ) 504 | 505 | _, [us7, us8, us9] = self.up_blocks[2]( 506 | hidden_states=us6, 507 | temb=emb, 508 | res_hidden_states_tuple=[s3, s4, s5], 509 | encoder_hidden_states=encoder_hidden_states, 510 | ) 511 | 512 | _, [us10, us11, us12] = self.up_blocks[3]( 513 | hidden_states=us9, 514 | temb=emb, 515 | res_hidden_states_tuple=[s0, s1, s2], 516 | encoder_hidden_states=encoder_hidden_states, 517 | ) 518 | 519 | # 6. post-process 520 | sample = self.conv_norm_out(us12) 521 | sample = self.conv_act(sample) 522 | sample = self.conv_out(sample) 523 | 524 | return [sample] 525 | -------------------------------------------------------------------------------- /src/sd_feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import einops 5 | from diffusers import DiffusionPipeline 6 | from jaxtyping import Float, Int 7 | from pydoc import locate 8 | from typing import Literal 9 | from .layers import FeedForwardBlock, FourierFeatures, Linear, MappingNetwork 10 | from .min_sd15 import SD15UNetModel 11 | from .min_sd21 import SD21UNetModel 12 | 13 | 14 | class SD15UNetFeatureExtractor(SD15UNetModel): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs): 19 | timesteps = timesteps.expand(sample.shape[0]) 20 | t_emb = self.time_proj(timesteps).to(dtype=sample.dtype) 21 | emb = self.time_embedding(t_emb) 22 | 23 | sample = self.conv_in(sample) 24 | 25 | # 3. down 26 | s0 = sample 27 | sample, [s1, s2, s3] = self.down_blocks[0]( 28 | sample, 29 | temb=emb, 30 | encoder_hidden_states=encoder_hidden_states, 31 | ) 32 | 33 | sample, [s4, s5, s6] = self.down_blocks[1]( 34 | sample, 35 | temb=emb, 36 | encoder_hidden_states=encoder_hidden_states, 37 | ) 38 | 39 | sample, [s7, s8, s9] = self.down_blocks[2]( 40 | sample, 41 | temb=emb, 42 | encoder_hidden_states=encoder_hidden_states, 43 | ) 44 | 45 | sample, [s10, s11] = self.down_blocks[3]( 46 | sample, 47 | temb=emb, 48 | ) 49 | 50 | # 4. mid 51 | sample_mid = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 52 | 53 | # 5. 
up 54 | _, [us1, us2, us3] = self.up_blocks[0]( 55 | hidden_states=sample_mid, 56 | temb=emb, 57 | res_hidden_states_tuple=[s9, s10, s11], 58 | ) 59 | 60 | _, [us4, us5, us6] = self.up_blocks[1]( 61 | hidden_states=us3, 62 | temb=emb, 63 | res_hidden_states_tuple=[s6, s7, s8], 64 | encoder_hidden_states=encoder_hidden_states, 65 | ) 66 | 67 | _, [us7, us8, us9] = self.up_blocks[2]( 68 | hidden_states=us6, 69 | temb=emb, 70 | res_hidden_states_tuple=[s3, s4, s5], 71 | encoder_hidden_states=encoder_hidden_states, 72 | ) 73 | 74 | _, [us10, us11, _] = self.up_blocks[3]( 75 | hidden_states=us9, 76 | temb=emb, 77 | res_hidden_states_tuple=[s0, s1, s2], 78 | encoder_hidden_states=encoder_hidden_states, 79 | ) 80 | 81 | return { 82 | "mid": sample_mid, 83 | "us1": us1, 84 | "us2": us2, 85 | "us3": us3, 86 | "us4": us4, 87 | "us5": us5, 88 | "us6": us6, 89 | "us7": us7, 90 | "us8": us8, 91 | "us9": us9, 92 | "us10": us10, 93 | } 94 | 95 | 96 | class SD21UNetFeatureExtractor(SD21UNetModel): 97 | def __init__(self): 98 | super().__init__() 99 | 100 | def forward(self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs): 101 | timesteps = timesteps.expand(sample.shape[0]) 102 | t_emb = self.time_proj(timesteps).to(dtype=sample.dtype) 103 | emb = self.time_embedding(t_emb) 104 | 105 | sample = self.conv_in(sample) 106 | 107 | # 3. down 108 | s0 = sample 109 | sample, [s1, s2, s3] = self.down_blocks[0]( 110 | sample, 111 | temb=emb, 112 | encoder_hidden_states=encoder_hidden_states, 113 | ) 114 | 115 | sample, [s4, s5, s6] = self.down_blocks[1]( 116 | sample, 117 | temb=emb, 118 | encoder_hidden_states=encoder_hidden_states, 119 | ) 120 | 121 | sample, [s7, s8, s9] = self.down_blocks[2]( 122 | sample, 123 | temb=emb, 124 | encoder_hidden_states=encoder_hidden_states, 125 | ) 126 | 127 | sample, [s10, s11] = self.down_blocks[3]( 128 | sample, 129 | temb=emb, 130 | ) 131 | 132 | # 4. mid 133 | sample_mid = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 134 | 135 | # 5. 
up 136 | _, [us1, us2, us3] = self.up_blocks[0]( 137 | hidden_states=sample_mid, 138 | temb=emb, 139 | res_hidden_states_tuple=[s9, s10, s11], 140 | ) 141 | 142 | _, [us4, us5, us6] = self.up_blocks[1]( 143 | hidden_states=us3, 144 | temb=emb, 145 | res_hidden_states_tuple=[s6, s7, s8], 146 | encoder_hidden_states=encoder_hidden_states, 147 | ) 148 | 149 | _, [us7, us8, us9] = self.up_blocks[2]( 150 | hidden_states=us6, 151 | temb=emb, 152 | res_hidden_states_tuple=[s3, s4, s5], 153 | encoder_hidden_states=encoder_hidden_states, 154 | ) 155 | 156 | _, [us10, us11, _] = self.up_blocks[3]( 157 | hidden_states=us9, 158 | temb=emb, 159 | res_hidden_states_tuple=[s0, s1, s2], 160 | encoder_hidden_states=encoder_hidden_states, 161 | ) 162 | 163 | return { 164 | "mid": sample_mid, 165 | "us1": us1, 166 | "us2": us2, 167 | "us3": us3, 168 | "us4": us4, 169 | "us5": us5, 170 | "us6": us6, 171 | "us7": us7, 172 | "us8": us8, 173 | "us9": us9, 174 | "us10": us10, 175 | } 176 | 177 | class FeedForwardBlockCustom(FeedForwardBlock): 178 | def __init__(self, d_model: int, d_ff: int, d_cond_norm: int = None, norm_type: Literal['AdaRMS', 'FiLM'] = 'AdaRMS', use_gating: bool = True): 179 | super().__init__(d_model=d_model, d_ff=d_ff, d_cond_norm=d_cond_norm) 180 | if not use_gating: 181 | self.up_proj = LinearSwish(d_model, d_ff, bias=False) 182 | if norm_type == 'FiLM': 183 | self.norm = FiLMNorm(d_model, d_cond_norm) 184 | 185 | class FFNStack(nn.Module): 186 | def __init__(self, dim: int, depth: int, ffn_expansion: float, dim_cond: int, 187 | norm_type: Literal['AdaRMS', 'FiLM'] = 'AdaRMS', use_gating: bool = True) -> None: 188 | super().__init__() 189 | self.layers = nn.ModuleList( 190 | [FeedForwardBlockCustom(d_model=dim, d_ff=int(dim * ffn_expansion), d_cond_norm=dim_cond, norm_type=norm_type, use_gating=use_gating) 191 | for _ in range(depth)]) 192 | 193 | def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: 194 | for layer in self.layers: 195 | x = layer(x, cond_norm=cond) 196 | return x 197 | 198 | class FiLMNorm(nn.Module): 199 | def __init__(self, features, cond_features): 200 | super().__init__() 201 | self.linear = Linear(cond_features, features * 2, bias=False) 202 | self.feature_dim = features 203 | 204 | def forward(self, x, cond): 205 | B, _, D = x.shape 206 | scale, shift = self.linear(cond).chunk(2, dim=-1) 207 | # broadcast scale and shift across all features 208 | scale = scale.view(B, 1, D) 209 | shift = scale.view(B, 1, D) 210 | return scale * x + shift 211 | 212 | class LinearSwish(nn.Linear): 213 | def __init__(self, in_features, out_features, bias=True): 214 | super().__init__(in_features, out_features, bias=bias) 215 | 216 | def forward(self, x): 217 | return F.silu(super().forward(x)) 218 | 219 | 220 | class ArgSequential(nn.Module): # Utility class to enable instantiating nn.Sequential instances with Hydra 221 | def __init__(self, *layers) -> None: 222 | super().__init__() 223 | self.layers = nn.ModuleList(layers) 224 | 225 | def forward(self, x, *args, **kwargs): 226 | for layer in self.layers: 227 | x = layer(x, *args, **kwargs) 228 | return x 229 | 230 | class StableFeatureAligner(nn.Module): 231 | def __init__( 232 | self, 233 | ae: nn.Module, 234 | mapping, 235 | adapter_layer_class: str, 236 | feature_dims: dict[str, int], 237 | feature_extractor_cls: str, 238 | sd_version: Literal["sd15", "sd21"], 239 | adapter_layer_params: dict = {}, 240 | use_text_condition: bool = False, 241 | t_min: int = 1, 242 | t_max: int = 999, 243 | t_max_model: int = 999, 
244 | num_t_stratification_bins: int = 3, 245 | alignment_loss: Literal["cossim", "mse", "l1"] = "cossim", 246 | train_unet: bool = True, 247 | train_adapter: bool = True, 248 | t_init: int = 261, 249 | learn_timestep: bool = False, 250 | val_dataset: torch.utils.data.Dataset | None = None, 251 | val_t: int = 261, 252 | val_feature_key: str = "us6", 253 | val_chunk_size: int = 10, 254 | use_adapters: bool = True 255 | ): 256 | super().__init__() 257 | self.ae = ae 258 | self.sd_version = sd_version 259 | self.val_t = val_t 260 | self.val_feature_key = val_feature_key 261 | self.val_dataset = val_dataset 262 | self.val_chunk_size = val_chunk_size 263 | self.use_adapters = use_adapters 264 | 265 | if sd_version == "sd15": 266 | self.repo = "stable-diffusion-v1-5/stable-diffusion-v1-5" 267 | elif sd_version == "sd21": 268 | self.repo = "stabilityai/stable-diffusion-2-1" 269 | else: 270 | raise ValueError(f"Invalid SD version: {sd_version}") 271 | 272 | self.mapping = None 273 | if use_adapters: 274 | self.time_emb = FourierFeatures(1, mapping.width) 275 | self.time_in_proj = Linear(mapping.width, mapping.width, bias=False) 276 | self.mapping = MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout) 277 | self.mapping.compile() 278 | 279 | if use_adapters: 280 | self.adapters = nn.ModuleDict() 281 | for k, dim in feature_dims.items(): 282 | self.adapters[k] = locate(adapter_layer_class)(dim=dim, **adapter_layer_params) 283 | self.adapters[k].requires_grad_(train_adapter) 284 | 285 | self.unet_feature_extractor_base = locate(feature_extractor_cls)().cuda() 286 | self.pipe = DiffusionPipeline.from_pretrained( 287 | self.repo, 288 | torch_dtype=torch.bfloat16, 289 | use_safetensors=True, 290 | ).to("cuda") 291 | self.unet_feature_extractor_base.load_state_dict(self.pipe.unet.state_dict()) 292 | self.unet_feature_extractor_base.eval() 293 | self.unet_feature_extractor_base.requires_grad_(False) 294 | self.unet_feature_extractor_base.compile() 295 | 296 | self.unet_feature_extractor_cleandift = locate(feature_extractor_cls)().cuda() 297 | self.unet_feature_extractor_cleandift.load_state_dict( 298 | {k: v.detach().clone() for k, v in self.unet_feature_extractor_base.state_dict().items()} 299 | ) 300 | 301 | if train_unet or learn_timestep: 302 | self.unet_feature_extractor_cleandift.train() 303 | else: 304 | self.unet_feature_extractor_cleandift.eval() 305 | self.unet_feature_extractor_cleandift.requires_grad_(train_unet) 306 | self.unet_feature_extractor_cleandift.compile() 307 | 308 | self.use_text_condition = use_text_condition 309 | if self.use_text_condition: 310 | self.pipe.text_encoder.compile() 311 | else: 312 | with torch.no_grad(): 313 | prompt_embeds_dict = self.get_prompt_embeds([""]) 314 | self._empty_prompt_embeds = prompt_embeds_dict["prompt_embeds"] 315 | del self.pipe.text_encoder 316 | 317 | del self.pipe.unet, self.pipe.vae 318 | 319 | self.t_min = t_min 320 | self.t_max = t_max 321 | self.t_max_model = t_max_model 322 | self.num_t_stratification_bins = num_t_stratification_bins 323 | self.alignment_loss = alignment_loss 324 | self.timestep = nn.Parameter( 325 | torch.tensor(float(t_init), requires_grad=learn_timestep), requires_grad=learn_timestep 326 | ) 327 | 328 | def get_prompt_embeds(self, prompt: list[str]) -> dict[str, torch.Tensor | None]: 329 | self.prompt_embeds, _ = self.pipe.encode_prompt( 330 | prompt=prompt, 331 | device=torch.device("cuda"), 332 | num_images_per_prompt=1, 333 | do_classifier_free_guidance=False, 334 | ) 335 | return 
{"prompt_embeds": self.prompt_embeds} 336 | 337 | def _get_unet_conds(self, prompts: list[str], device, dtype, N_T) -> dict[str, torch.Tensor]: 338 | B = len(prompts) 339 | if self.use_text_condition: 340 | prompt_embeds_dict = self.get_prompt_embeds(prompts) 341 | else: 342 | prompt_embeds_dict = {"prompt_embeds": einops.repeat(self._empty_prompt_embeds, "b ... -> (B b) ...", B=B)} 343 | 344 | unet_conds = { 345 | "encoder_hidden_states": einops.repeat( 346 | prompt_embeds_dict["prompt_embeds"], "B ... -> (B N_T) ...", N_T=N_T 347 | ).to(dtype=dtype, device=device), 348 | "added_cond_kwargs": {}, 349 | } 350 | 351 | return unet_conds 352 | 353 | def forward( 354 | self, x: Float[torch.Tensor, "b c h w"], caption: list[str], **kwargs 355 | ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: 356 | B, *_ = x.shape 357 | device = x.device 358 | t_range = self.t_max - self.t_min 359 | t_range_per_bin = t_range / self.num_t_stratification_bins 360 | t: Int[torch.Tensor, "B N_T"] = ( 361 | self.t_min 362 | + torch.rand((B, self.num_t_stratification_bins), device=device) * t_range_per_bin 363 | + torch.arange(0, self.num_t_stratification_bins, device=device)[None, :] * t_range_per_bin 364 | ).long() 365 | B, N_T = t.shape 366 | 367 | with torch.no_grad(): 368 | unet_conds = self._get_unet_conds(caption, device, x.dtype, N_T) 369 | x_0: Float[torch.Tensor, "(B N_T) ..."] = self.ae.encode(x) 370 | x_0 = einops.repeat(x_0, "B ... -> (B N_T) ...", N_T=N_T) 371 | _, *latent_shape = x_0.shape 372 | noise_sample = torch.randn((B * N_T, *latent_shape), device=device, dtype=x.dtype) 373 | 374 | x_t: Float[torch.Tensor, "(B N_T) ..."] = self.pipe.scheduler.add_noise( 375 | x_0, 376 | noise_sample, 377 | einops.rearrange(t, "B N_T -> (B N_T)"), 378 | ) 379 | 380 | feats_base: dict[str, Float[torch.Tensor, "B N_T ..."]] = { 381 | k: einops.rearrange(v, "(B N_T) D H W -> B N_T (H W) D", B=B, N_T=N_T) 382 | for k, v in self.unet_feature_extractor_base( 383 | x_t, 384 | einops.rearrange(t, "B N_T -> (B N_T)"), 385 | **unet_conds, 386 | ).items() 387 | } 388 | 389 | feats_cleandift: dict[str, Float[torch.Tensor, "B N_T ..."]] = { 390 | k: einops.rearrange(v, "(B N_T) D H W -> B N_T (H W) D", N_T=N_T) 391 | for k, v in self.unet_feature_extractor_cleandift( 392 | x_0, 393 | einops.rearrange(torch.ones_like(t) * self.timestep, "B N_T -> (B N_T)"), 394 | **unet_conds, 395 | ).items() 396 | } 397 | 398 | if self.use_adapters: 399 | # time conditioning for adapters 400 | if not self.mapping is None: 401 | map_cond: Float[torch.Tensor, "(B N_T) ..."] = self.mapping( 402 | self.time_in_proj( 403 | self.time_emb( 404 | einops.rearrange(t, "B N_T -> (B N_T) 1").to(dtype=x.dtype, device=device) / self.t_max_model 405 | ) 406 | ) 407 | ) 408 | 409 | feats_cleandift: dict[str, Float[torch.Tensor, "B N_T ..."]] = { 410 | k: einops.rearrange( 411 | self.adapters[k](einops.rearrange(v, "B N_T ... -> (B N_T) ..."), cond=map_cond), 412 | "(B N_T) ... 
-> B N_T ...", 413 | B=B, 414 | N_T=N_T, 415 | ) 416 | for k, v in feats_cleandift.items() 417 | } 418 | 419 | if self.alignment_loss == "mse": 420 | return {f"mse_{k}": F.mse_loss(feats_cleandift[k], v.detach()) for k, v in feats_base.items()} 421 | elif self.alignment_loss == "l1": 422 | return {f"l1_{k}": F.l1_loss(feats_cleandift[k], v.detach()) for k, v in feats_base.items()} 423 | elif self.alignment_loss == "cossim": 424 | return { 425 | f"neg_cossim_{k}": -F.cosine_similarity(feats_cleandift[k], v.detach(), dim=-1).mean() 426 | for k, v in feats_base.items() 427 | } 428 | else: 429 | raise ValueError(f"Invalid alignment loss type: {self.alignment_loss}") 430 | 431 | @torch.no_grad() 432 | def get_features( 433 | self, 434 | x: Float[torch.Tensor, "b c h w"], 435 | caption: list[str] | None, 436 | t: Int[torch.Tensor, "b"] | None, 437 | feat_key: str, 438 | use_base_model: bool = False, 439 | input_pure_noise: bool = False, 440 | eps: torch.Tensor = None, 441 | ) -> Float[torch.Tensor, "b d h' w'"]: 442 | if use_base_model: 443 | assert not t is None 444 | B, *_ = x.shape 445 | 446 | if caption is None: 447 | caption = [""] * B 448 | 449 | unet_conds = self._get_unet_conds(caption, x.device, x.dtype, 1) 450 | x_0 = self.ae.encode(x) 451 | eps = torch.randn_like(x_0) if eps is None else eps 452 | if input_pure_noise: 453 | assert torch.allclose( 454 | t, torch.full_like(t, 999) 455 | ), "Sanity check. Pure noise means that no x_t is given to the U-Net, just pure noise (eps)." 456 | x_t = eps 457 | else: 458 | x_t = self.pipe.scheduler.add_noise(x_0, eps, t) 459 | 460 | if feat_key is None: 461 | feats = self.unet_feature_extractor_base(x_t, t, **unet_conds) 462 | else: 463 | feats = self.unet_feature_extractor_base(x_t, t, **unet_conds)[feat_key] 464 | return feats 465 | else: 466 | (B, *_), device = x.shape, x.device 467 | 468 | if caption is None: 469 | caption = [""] * B 470 | 471 | unet_conds = self._get_unet_conds(caption, device, x.dtype, 1) 472 | x_0 = self.ae.encode(x) 473 | 474 | feats = self.unet_feature_extractor_cleandift( 475 | x_0, 476 | torch.ones((B,), device=device, dtype=self.timestep.dtype) * self.timestep, 477 | **unet_conds, 478 | ) 479 | 480 | if feat_key is not None: 481 | feats = feats[feat_key] 482 | 483 | feats = einops.rearrange(feats,"B D H W -> B H W D",) 484 | if t is None: 485 | return einops.rearrange(feats, "B H W D -> B D H W") 486 | else: 487 | assert self.use_adapters, "Adapters must be enabled to use t conditioning on cleandift model" 488 | map_cond: Float[torch.Tensor, "B ..."] = self.mapping( 489 | self.time_in_proj(self.time_emb(t[:, None].to(dtype=x.dtype, device=device) / self.t_max_model)) 490 | ) 491 | if feat_key is not None: 492 | return einops.rearrange(self.adapters[feat_key](feats, cond=map_cond), "B H W D -> B D H W") 493 | else: 494 | return {key: einops.rearrange(self.adapters[key](feats[key], cond=map_cond), "B H W D -> B D H W") for key in feats.keys()} -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | from dataclasses import dataclass 6 | from typing import Dict, Union, Any 7 | 8 | def set_seed(seed=42, cuda=True): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | if cuda: 13 | torch.cuda.manual_seed_all(seed) 14 | 15 | 16 | def dict_to(d: Dict[str, Union[torch.Tensor, Any]], 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import random
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Union, Any

def set_seed(seed=42, cuda=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def dict_to(d: Dict[str, Union[torch.Tensor, Any]], **to_kwargs) -> Dict[str, Union[torch.Tensor, Any]]:
    return {k: (v.to(**to_kwargs) if isinstance(v, torch.Tensor) else v) for k, v in d.items()}


# Helpers
def zero_init(layer):
    nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer


@dataclass
class MappingSpec:
    depth: int
    width: int
    d_ff: int
    dropout: float
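
# Example (illustrative only, not part of the original module): dict_to forwards its keyword
# arguments to Tensor.to for every tensor value and leaves everything else untouched, so a
# mixed dataloader batch can be moved in one call:
#
#   batch = dict_to(batch, device="cuda", dtype=torch.bfloat16)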
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import hydra
import logging
import os
import torch
from typing import Optional
from omegaconf import DictConfig, OmegaConf
from src.utils import set_seed, dict_to
from transformers import get_scheduler
from tqdm.auto import tqdm


@hydra.main(config_path="configs", config_name="sd15_feature_extractor", version_base="1.1")
def main(cfg: DictConfig):
    OmegaConf.resolve(cfg)
    set_seed(cfg.seed)
    logger = logging.getLogger(__name__)
    device = torch.device("cuda:0")

    # Instantiate model and data from the Hydra config
    cfg = hydra.utils.instantiate(cfg)
    model = cfg.model.to(device)
    model.train()

    data = cfg.data
    dataloader_train = data.train_dataloader()

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_scheduler.num_warmup_steps,
        num_training_steps=cfg.lr_scheduler.num_training_steps,
        scheduler_specific_kwargs=OmegaConf.to_container(cfg.lr_scheduler.scheduler_specific_kwargs),
    )

    i_epoch = -1
    stop = False
    max_steps: Optional[int] = cfg.max_steps

    val_freq: Optional[int] = cfg.val_freq
    if val_freq is not None:
        dataloader_val = data.val_dataloader()
        max_val_steps: Optional[int] = cfg.max_val_steps
    checkpoint_freq: Optional[int] = cfg.checkpoint_freq
    checkpoint_dir: str = cfg.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    grad_accum_steps = cfg.grad_accum_steps
    print(f"grad_accum_steps={grad_accum_steps}")

    step = 0

    while not stop:  # Epochs
        i_epoch += 1
        for batch in (pbar := tqdm(dataloader_train, desc=f"Optimizing (Epoch {i_epoch + 1})")):
            loss_sum = 0
            # Gradient accumulation reuses the same batch; each pass draws fresh noise and timesteps in model.forward.
            for accum_step in range(grad_accum_steps):
                losses = model(**dict_to(batch, device=device))
                loss = sum(v.mean() for v in losses.values())
                loss.backward()
                loss_sum += float(loss.detach().item())
                pbar.set_postfix({"loss": loss_sum / (accum_step + 1)})

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if val_freq is not None and step % val_freq == 0:
                model.eval()

                with torch.no_grad():
                    val_losses_accumulated = []
                    for i, val_batch in enumerate(tqdm(dataloader_val, desc="Validating", total=max_val_steps)):
                        val_losses = model(**dict_to(val_batch, device=device))
                        val_loss = sum(v.mean() for v in val_losses.values())
                        val_losses_accumulated.append(val_loss.cpu().item())

                        if max_val_steps is not None and i + 1 >= max_val_steps:
                            break

                    val_loss = sum(val_losses_accumulated) / len(val_losses_accumulated)
                    logger.info(f"Validation loss: {val_loss}")

                # Put the model back into train mode
                model.train()

            if checkpoint_freq is not None and (step + 1) % checkpoint_freq == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f"step_{(step + 1)}.pth")
                torch.save(model.state_dict(), checkpoint_path)
                logger.info(f"Saved checkpoint to {checkpoint_path}")

            if max_steps is not None and step == max_steps:
                stop = True
                break

            step += 1


if __name__ == "__main__":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit)
    main()
--------------------------------------------------------------------------------