├── LICENSE ├── README.md ├── configs ├── trajscut_fm_flowdcn_b_lion_t100.yml ├── trajscutrw_ddim_lion_dit_t100.yml └── trajscutrw_fm_sit_lion_t100.yml ├── figs ├── coeffs_timestep2.png ├── fid_performance.png ├── flux_cfg2.png ├── method.png └── pixart1024_cfg2.png ├── main.py ├── precompute └── placeholder ├── pretrain_models └── placeholder ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── celeba.py │ ├── cifar10.py │ ├── cifar100.py │ ├── imagenet.py │ ├── metric_dataset.py │ ├── randn.py │ └── transforms.py ├── diffusion │ ├── __init__.py │ ├── base │ │ ├── guidance.py │ │ ├── sampling.py │ │ └── scheduling.py │ ├── ddpm │ │ ├── ddim_sampling.py │ │ ├── dpmsolver_sampling.py │ │ ├── neuralsolver.py │ │ ├── ns_sampling.py │ │ ├── scheduling.py │ │ └── vp_sampling.py │ ├── flow_matching │ │ ├── adam_sampling.py │ │ ├── neuralsolver.py │ │ ├── ns_sampling.py │ │ ├── sampling.py │ │ └── scheduling.py │ └── solver_training.py ├── lightning_data.py ├── lightning_model.py ├── models │ ├── base_model.py │ ├── dit.py │ └── flowdcn.py ├── ops │ ├── DCNv4_op │ │ ├── DCNv4 │ │ │ ├── __init__.py │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ ├── dcnv4_func.py │ │ │ │ ├── flash_deform_attn_func.py │ │ │ │ └── table.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── dcnv4.py │ │ │ │ └── flash_deform_attn.py │ │ ├── MANIFEST.in │ │ ├── __init__.py │ │ ├── make.sh │ │ ├── scripts │ │ │ ├── find_best.py │ │ │ ├── search_bwd.sh │ │ │ ├── search_dcnv4.py │ │ │ ├── search_dcnv4_bwd.py │ │ │ ├── search_dcnv4_bwd_engine.py │ │ │ ├── search_dcnv4_engine.py │ │ │ ├── search_fwd.sh │ │ │ ├── test_dcnv4.py │ │ │ ├── test_dcnv4_bwd.py │ │ │ ├── test_flash_deform_attn.py │ │ │ └── test_flash_deform_attn_backward.py │ │ ├── setup.py │ │ └── src │ │ │ ├── cuda │ │ │ ├── common.h │ │ │ ├── dcnv4_col2im_cuda.cuh │ │ │ ├── dcnv4_cuda.cu │ │ │ ├── dcnv4_cuda.h │ │ │ ├── dcnv4_im2col_cuda.cuh │ │ │ ├── flash_deform_attn_cuda.cu │ │ │ ├── flash_deform_attn_cuda.h │ │ │ ├── flash_deform_col2im_cuda.cuh │ │ │ └── flash_deform_im2col_cuda.cuh │ │ │ ├── dcnv4.h │ │ │ └── vision.cpp │ ├── cuda_kernels │ │ ├── backward.cu │ │ ├── forward.py │ │ ├── function.py │ │ └── setup.py │ ├── triton_kernels │ │ ├── __init__.py │ │ ├── backward.py │ │ ├── forward.py │ │ └── function.py │ └── triton_kernels_udcn │ │ ├── backward.py │ │ ├── forward.py │ │ ├── function.py │ │ └── utils.py └── utils │ ├── __init__.py │ ├── callbacks.py │ ├── metrics.py │ ├── model_loader.py │ ├── saver.py │ └── vae.py ├── t2i_vis ├── __init__.py ├── flux.ipynb ├── fm_scheduling.py ├── pixart_sigma_1024.py ├── pixart_sigma_256.py ├── pixart_sigma_512.py ├── sd3.ipynb └── vp_scheduling.py └── tools ├── fid_curve ├── eular_steps.py ├── flowdcn_256.py ├── flowdcn_512.py ├── mse_curve.py ├── search_model.py └── sit.py ├── is_curve ├── flowdcn_256.py ├── flowdcn_512.py └── sit.py ├── pr_curve ├── flowdcn_256.py ├── flowdcn_512.py └── sit.py ├── recall_curve ├── flowdcn_256.py ├── flowdcn_512.py └── sit.py └── sfid_curve ├── flowdcn_256.py ├── flowdcn_512.py └── sit.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 wang shuai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of 
the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuralSolver: Differentiable Solver Search for Fast Diffusion Sampling 2 | ### Link: [https://huggingface.co/papers/2505.21114](https://huggingface.co/papers/2505.21114) 3 | ![](./figs/coeffs_timestep2.png) 4 | This repository contains the code for the paper: 5 | **Differentiable Solver Search for Fast Diffusion Sampling** 6 | #### [NEWS] [11.25] 🍺 Our searched NeuralSolvers and corresponding code are now available in the official repo! 7 | 8 | ## Methods 9 | Our motivation is that Adams-like linear multi-step methods employ Lagrange interpolation, ignoring the *x*-related terms. 10 | So we define a universal interpolation function `P(x, t)` and directly estimate the pre-integral **coefficients and timesteps** used in sampling. 11 | 12 | ![](./figs/method.png) 13 | 14 | 15 | ## Experiments 16 | 17 | Compared to linear multi-step methods, our NeuralSolvers (searched on FlowDCN-B-R256) consistently improve the FID metric by a large margin. 18 | 19 | ![](./figs/fid_performance.png) 20 | 21 | We provide an Adams-like linear multi-step solver for rectified-flow sampling. The related configs are named with `adam2` or `adam4`. The solver code is placed in `./src/diffusion/flow_matching/adam_sampling.py`. 22 | 23 | Compared to Heun/RK4, the linear multi-step solver is more stable and faster.
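Concretely, both the Adams-like baseline and the searched NeuralSolver can be viewed as an Euler-style step applied to a weighted mixture of the current and past model predictions. The sketch below only illustrates that update rule (the function and argument names are not the repository's API); the actual implementations live in `./src/diffusion/flow_matching/adam_sampling.py` and `./src/diffusion/flow_matching/neuralsolver.py`, where Adams methods fix the mixing coefficients via Lagrange interpolation while NeuralSolver searches both the coefficients and the timesteps.

```python
import torch

def multistep_sample(net, x, labels, timedeltas, coeffs):
    """Generic linear multi-step update: x <- x + dt_i * sum_j c_ij * v_j.

    `timedeltas` (num_steps,) and `coeffs` (num_steps, num_steps) are fixed for
    Adams-like solvers and searched/learned for NeuralSolver; `net(x, t, labels)`
    predicts the flow-matching velocity.
    """
    history = []                                 # past velocity predictions v_0 .. v_{i-1}
    t = torch.zeros(x.shape[0], device=x.device)
    for i, dt in enumerate(timedeltas):
        v = net(x, t, labels)                    # current velocity prediction v_i
        history.append(v)
        mixed = (1.0 - coeffs[i, :i].sum()) * v  # current prediction keeps the residual weight
        for j in range(i):
            mixed = mixed + coeffs[i, j] * history[j]
        x = x + dt * mixed                       # Euler-style step with the mixed velocity
        t = t + dt
    return x
```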
24 | 25 | | SiT-XL-R256 | Steps | NFE-CFG | Extra-Parameters | FID | IS | PR | Recall | 26 | |--|-------|----------|-----------------|------|-------|------|--------| 27 | | Heun | 8 | 16x2 | 0 | 3.68 | / | / | / | 28 | | Heun | 11 | 22x2 | 0 | 2.79 | / | / | / | 29 | | Heun | 15 | 30x2 | 0 | 2.42 | / | / | / | 30 | | Adam2 | 6 | 6x2 | 0 | 6.35 | 190 | 0.75 | 0.55 | 31 | | Adam2 | 8 | 8x2 | 0 | 4.16 | 212 | 0.78 | 0.56 | 32 | | Adam2 | 16 | 16x2 | 0 | 2.42 | 237 | 0.80 | 0.60 | 33 | | Adam4 | 16 | 16x2 | 0 | 2.27 | 243 | 0.80 | 0.60 | 34 | | FlowTurbo | 6 | (7+3)x2 | 30408704(29M) | 3.93 | 223.6 | 0.79 | 0.56 | 35 | | FlowTurbo | 8 | (8+2)x2 | 30408704(29M) | 3.63 | / | / | / | 36 | | FlowTurbo | 10 | (12+2)x2 | 30408704(29M) | 2.69 | / | / | / | 37 | | FlowTurbo | 15 | (17+3)x2 | 30408704(29M) | 2.22 | 248 | 0.81 | 0.60 | 38 | | NeuralSolver | 6 | 6x2 | 21 | 3.57 | 214 | 0.77 | 0.58 | 39 | | NeuralSolver | 7 | 7x2 | 28 | 2.78 | 229 | 0.79 | 0.60 | 40 | | NeuralSolver | 8 | 8x2 | 36 | 2.65 | 234 | 0.79 | 0.60 | 41 | | NeuralSolver | 10 | 10x2 | 55 | 2.40 | 238 | 0.79 | 0.60 | 42 | | NeuralSolver | 15 | 15x2 | 110 | 2.24 | 244 | 0.80 | 0.60 | 43 | 44 | ## Visualizations of zero-shot T2I 45 | 46 | #### Flux Models with Euler-shift3 and our NeuralSolver (searched on SiT-XL-R256) under CFG=2.0 47 | ![](./figs/flux_cfg2.png) 48 | 49 | #### PixArt Models with UniPC/DPMSolver++ and our NeuralSolver (searched on DiT-XL-R256) under CFG=2.0 50 | ![](./figs/pixart1024_cfg2.png) 51 | 52 | 53 | ## Citation 54 | ```bibtex 55 | @article{wangdifferentiable, 56 | title={Differentiable Solver Search for Fast Diffusion Sampling}, 57 | author={Wang, Shuai and Li, Zexian and Song, Tianhui and Li, Xubin and Ge, Tiezheng and Zheng, Bo and Wang, Limin and others} 58 | } 59 | 60 | ``` 61 | -------------------------------------------------------------------------------- /configs/trajscut_fm_flowdcn_b_lion_t100.yml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.2.1 2 | seed_everything: true 3 | torch_hub_dir: null 4 | huggingface_cache_dir: null 5 | tags: 6 | exp: trajscut_fm_flowdcnb_lion_t100 7 | b: &batch_size 4 # batch_per_process 8 | s: &step 6 9 | e: &max_num_epochs 5 10 | trainer: 11 | accelerator: auto 12 | strategy: auto 13 | devices: 0, 14 | num_nodes: 1 15 | precision: 16-mixed 16 | callbacks: 17 | - src.utils.callbacks.DummyCheckpointHook 18 | - class_path: lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint 19 | init_args: 20 | monitor: fid 21 | every_n_epochs: 1 22 | save_top_k: 1 23 | - class_path: src.utils.callbacks.GradientMonitor 24 | fast_dev_run: null 25 | max_epochs: *max_num_epochs 26 | min_epochs: null 27 | max_steps: -1 28 | min_steps: null 29 | max_time: null 30 | limit_train_batches: null 31 | limit_val_batches: null 32 | limit_test_batches: null 33 | limit_predict_batches: null 34 | overfit_batches: 0.0 35 | val_check_interval: null 36 | check_val_every_n_epoch: 1 37 | num_sanity_val_steps: 2 38 | log_every_n_steps: 1 39 | enable_checkpointing: null 40 | enable_progress_bar: null 41 | enable_model_summary: null 42 | accumulate_grad_batches: 1 43 | gradient_clip_val: null 44 | gradient_clip_algorithm: null 45 | deterministic: null 46 | benchmark: null 47 | inference_mode: true 48 | use_distributed_sampler: false 49 | profiler: null 50 | detect_anomaly: false 51 | barebones: false 52 | sync_batchnorm: false 53 | reload_dataloaders_every_n_epochs: 1 54 | default_root_dir: null 55 | model: 56 | vae: 57 | class_path:
src.utils.vae.LatentVAE 58 | init_args: 59 | precompute: true 60 | weight_path: stabilityai/sd-vae-ft-ema 61 | train_denoisers: 62 | - class_path: src.models.msd.DeformableDit 63 | init_args: 64 | patch_size: 2 65 | in_channels: 4 66 | num_groups: 12 67 | hidden_size: 768 68 | num_blocks: 12 69 | num_classes: 1000 70 | learn_sigma: true 71 | load_ema: true 72 | weight_path: ./pretrained/flowdcn_s.pt 73 | eval_denoiser: 74 | class_path: src.models.dit.Dit 75 | init_args: 76 | input_size: 32 77 | patch_size: 2 78 | in_channels: 4 79 | num_groups: 16 80 | hidden_size: 1152 81 | num_blocks: 28 82 | num_classes: 1000 83 | learn_sigma: true 84 | load_ema: false 85 | weight_path: ./pretrained/flowdcn_s.pt 86 | metric: 87 | class_path: src.utils.metrics.UnifiedMetric 88 | init_args: 89 | enabled_metrics: 90 | - fid 91 | - is 92 | - sfid 93 | reset_real_features: false 94 | precompute_data_path: 95 | fid: /data/oss_bucket_0/wangshuai/pretrain_models/precompute/imagenet256_fid_train.pt 96 | sfid: /data/oss_bucket_0/wangshuai/pretrain_models/precompute/imagenet256_sfid_train.pt 97 | solver_trainer: 98 | class_path: src.diffusion.solver_training.TrajsTrainer 99 | init_args: 100 | max_cfg_aug: 1.0 101 | min_cfg_aug: 1.0 102 | target_sampler: 103 | class_path: src.diffusion.flow_matching.sampling2.FlowMatchEulerSampler 104 | init_args: 105 | num_steps: 100 106 | scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler 107 | guidance_fn: &guidance_fn src.diffusion.base.guidance.simple_guidance_fn 108 | null_class: 1000 109 | w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler 110 | step_fn: &step_fn src.diffusion.flow_matching.sampling2.ode_step_fn 111 | guidance: &guidance 1.375 112 | source_sampler: 113 | class_path: src.diffusion.flow_matching.neural_sampling_nonsymsolver.FlowMatchNeuralSampler 114 | init_args: 115 | num_steps: *step 116 | null_class: 1000 117 | guidance: *guidance 118 | scheduler: *scheduler 119 | w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler 120 | guidance_fn: *guidance_fn 121 | step_fn: *step_fn 122 | optimizer: 123 | class_path: lion_pytorch.Lion 124 | init_args: 125 | lr: 1e-2 126 | weight_decay: 0.00 127 | data: 128 | test_gen_root: data/pred 129 | test_nature_root: data/val 130 | train_batch_size: *batch_size 131 | train_num_workers: 2 132 | train_prefetch_factor: 8 133 | eval_batch_size: 32 134 | eval_num_workers: 4 135 | eval_max_num_instances: 50000 # fid10k 136 | eval_seeds: null 137 | eval_selected_classes: null 138 | pred_batch_size: 64 139 | pred_num_workers: 2 140 | pred_seeds: null 141 | test_batch_size: 64 142 | test_num_workers: 16 143 | test_image_size: 144 | - 256 145 | - 256 146 | num_classes: 1000 147 | latent_shape: 148 | - 4 149 | - 32 150 | - 32 -------------------------------------------------------------------------------- /configs/trajscutrw_ddim_lion_dit_t100.yml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.2.1 2 | seed_everything: true 3 | torch_hub_dir: null 4 | huggingface_cache_dir: null 5 | tags: 6 | exp: trajscutrw_ddimsx_dit_t100 7 | b: &batch_size 4 # batch_per_process 8 | s: &step 10 9 | e: &max_num_epochs 5 10 | trainer: 11 | accelerator: auto 12 | strategy: auto 13 | devices: 1 14 | num_nodes: 1 15 | precision: 16-mixed 16 | callbacks: 17 | - src.utils.callbacks.DummyCheckpointHook 18 | - class_path: lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint 19 | init_args: 20 | monitor: fid 21 | every_n_epochs: 1 22 | 
save_top_k: 1 23 | - class_path: src.utils.callbacks.GradientMonitor 24 | fast_dev_run: null 25 | max_epochs: *max_num_epochs 26 | min_epochs: null 27 | max_steps: -1 28 | min_steps: null 29 | max_time: null 30 | limit_train_batches: null 31 | limit_val_batches: null 32 | limit_test_batches: null 33 | limit_predict_batches: null 34 | overfit_batches: 0.0 35 | val_check_interval: null 36 | check_val_every_n_epoch: 1 37 | num_sanity_val_steps: 2 38 | log_every_n_steps: 1 39 | enable_checkpointing: null 40 | enable_progress_bar: null 41 | enable_model_summary: null 42 | accumulate_grad_batches: 1 43 | gradient_clip_val: null 44 | gradient_clip_algorithm: null 45 | deterministic: null 46 | benchmark: null 47 | inference_mode: true 48 | use_distributed_sampler: false 49 | profiler: null 50 | detect_anomaly: false 51 | barebones: false 52 | sync_batchnorm: false 53 | reload_dataloaders_every_n_epochs: 1 54 | default_root_dir: null 55 | model: 56 | vae: 57 | class_path: src.utils.vae.LatentVAE 58 | init_args: 59 | precompute: true 60 | weight_path: stabilityai/sd-vae-ft-ema 61 | train_denoisers: 62 | - class_path: src.models.dit.DiT 63 | init_args: 64 | input_size: 32 65 | patch_size: 2 66 | in_channels: 4 67 | num_groups: 16 68 | hidden_size: 1152 69 | num_blocks: 28 70 | num_classes: &num_classes 1000 71 | learn_sigma: true 72 | load_ema: false 73 | weight_path: ./pretrain_models/DiT-XL-2-256x256.pt 74 | eval_denoiser: 75 | class_path: src.models.dit.DiT 76 | init_args: 77 | input_size: 32 78 | patch_size: 2 79 | in_channels: 4 80 | num_groups: 16 81 | hidden_size: 1152 82 | num_blocks: 28 83 | num_classes: *num_classes 84 | learn_sigma: true 85 | load_ema: false 86 | weight_path: ./pretrain_models/DiT-XL-2-256x256.pt 87 | metric: 88 | class_path: src.utils.metrics.UnifiedMetric 89 | init_args: 90 | enabled_metrics: 91 | - fid 92 | - is 93 | reset_real_features: false 94 | precompute_data_path: 95 | fid: ./precompute/imagenet256_fid_train.pt 96 | solver_trainer: 97 | class_path: src.diffusion.solver_training.TrajsReWeightTrainer 98 | init_args: 99 | max_cfg_aug: 1.0 100 | min_cfg_aug: 1.0 101 | target_sampler: 102 | class_path: src.diffusion.ddpm.ddim_sampling.DDIMSampler 103 | init_args: 104 | num_steps: 100 105 | null_class: *num_classes 106 | guidance: 1.0 107 | scheduler: src.diffusion.ddpm.scheduling.DDPMScheduler 108 | guidance_fn: &guidance_fn src.diffusion.base.guidance.c3_guidance_fn 109 | source_sampler: 110 | class_path: src.diffusion.ddpm.neuralsolver.NeuralSolverSampler 111 | init_args: 112 | num_steps: *step 113 | null_class: *num_classes 114 | train_max_t: 1000 115 | guidance: 1.5 116 | scheduler: src.diffusion.ddpm.scheduling.VPScheduler 117 | guidance_fn: *guidance_fn 118 | optimizer: 119 | class_path: lion_pytorch.Lion 120 | init_args: 121 | lr: 1e-2 122 | weight_decay: 0.00 123 | lr_scheduler: null 124 | save_dir: ditxl_c3_g1.5 125 | data: 126 | test_gen_root: data/pred 127 | test_nature_root: data/val 128 | train_batch_size: *batch_size 129 | train_num_workers: 2 130 | train_prefetch_factor: 8 131 | eval_batch_size: 32 132 | eval_num_workers: 4 133 | eval_max_num_instances: 50000 # fid50k 134 | eval_seeds: null 135 | eval_selected_classes: null 136 | pred_batch_size: 64 137 | pred_num_workers: 2 138 | pred_seeds: null 139 | test_batch_size: 64 140 | test_num_workers: 16 141 | test_image_size: 142 | - 256 143 | - 256 144 | num_classes: *num_classes 145 | latent_shape: 146 | - 4 147 | - 32 148 | - 32 
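In the configs above, the `solver_trainer` pairs a many-step reference `target_sampler` (a 100-step Euler or DDIM run) with a few-step `source_sampler` whose solver coefficients and timesteps are the only parameters handed to the Lion optimizer. Below is a minimal sketch of that search loop under the `BaseSampler.__call__(net, noise, labels)` interface defined in `src/diffusion/base/sampling.py`; the function and argument names are illustrative assumptions, not the repository's `TrajsTrainer`/`TrajsReWeightTrainer` API.

```python
import torch
import torch.nn.functional as F

def solver_search_step(source_sampler, target_sampler, net, noise, labels, optimizer):
    """One hypothetical search step: push the few-step solver's output toward the
    many-step reference sample; only the solver's learned coefficients/timesteps
    receive updates, while the pretrained denoiser `net` stays frozen."""
    with torch.no_grad():                           # teacher: e.g. 100-step DDIM/Euler sampler
        target = target_sampler(net, noise, labels)
    pred = source_sampler(net, noise, labels)       # student: e.g. 10-step NeuralSolver
    loss = F.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()                                 # gradients reach _raw_solver_coeffs / _raw_timedeltas
    optimizer.step()
    return loss.detach()
```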
-------------------------------------------------------------------------------- /configs/trajscutrw_fm_sit_lion_t100.yml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.2.1 2 | seed_everything: true 3 | torch_hub_dir: null 4 | huggingface_cache_dir: null 5 | tags: 6 | exp: trajscutrw_fm_sit_lion_t100 7 | b: &batch_size 4 # batch_per_process 8 | s: &step 10 9 | e: &max_num_epochs 20 10 | trainer: 11 | accelerator: auto 12 | strategy: auto 13 | devices: 1 14 | num_nodes: 1 15 | precision: 16-mixed 16 | callbacks: 17 | - src.utils.callbacks.DummyCheckpointHook 18 | - class_path: lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint 19 | init_args: 20 | monitor: fid 21 | every_n_epochs: 1 22 | save_top_k: 1 23 | - class_path: src.utils.callbacks.GradientMonitor 24 | fast_dev_run: null 25 | max_epochs: *max_num_epochs 26 | min_epochs: null 27 | max_steps: -1 28 | min_steps: null 29 | max_time: null 30 | limit_train_batches: null 31 | limit_val_batches: null 32 | limit_test_batches: null 33 | limit_predict_batches: null 34 | overfit_batches: 0.0 35 | val_check_interval: null 36 | check_val_every_n_epoch: 1 37 | num_sanity_val_steps: 2 38 | log_every_n_steps: 1 39 | enable_checkpointing: null 40 | enable_progress_bar: null 41 | enable_model_summary: null 42 | accumulate_grad_batches: 1 43 | gradient_clip_val: null 44 | gradient_clip_algorithm: null 45 | deterministic: null 46 | benchmark: null 47 | inference_mode: true 48 | use_distributed_sampler: false 49 | profiler: null 50 | detect_anomaly: false 51 | barebones: false 52 | sync_batchnorm: false 53 | reload_dataloaders_every_n_epochs: 1 54 | default_root_dir: null 55 | model: 56 | vae: 57 | class_path: src.utils.vae.LatentVAE 58 | init_args: 59 | precompute: true 60 | weight_path: stabilityai/sd-vae-ft-ema 61 | train_denoisers: 62 | - class_path: src.models.dit.DiT 63 | init_args: 64 | input_size: 32 65 | patch_size: 2 66 | in_channels: 4 67 | num_groups: 16 68 | hidden_size: 1152 69 | num_blocks: 28 70 | num_classes: 1000 71 | learn_sigma: true 72 | load_ema: false 73 | weight_path: ./pretrain_models/SiT-XL-2-256.pt 74 | eval_denoiser: 75 | class_path: src.models.dit.DiT 76 | init_args: 77 | input_size: 32 78 | patch_size: 2 79 | in_channels: 4 80 | num_groups: 16 81 | hidden_size: 1152 82 | num_blocks: 28 83 | num_classes: 1000 84 | learn_sigma: true 85 | load_ema: false 86 | weight_path: ./pretrain_models/SiT-XL-2-256.pt 87 | metric: 88 | class_path: src.utils.metrics.UnifiedMetric 89 | init_args: 90 | enabled_metrics: 91 | - fid 92 | - is 93 | reset_real_features: false 94 | precompute_data_path: 95 | fid: ./precompute/imagenet256_fid_train.pt 96 | solver_trainer: 97 | class_path: src.diffusion.solver_training.TrajsReWeightTrainer 98 | init_args: 99 | max_cfg_aug: 1.0 100 | min_cfg_aug: 1.0 101 | target_sampler: 102 | class_path: src.diffusion.flow_matching.sampling.EulerSampler 103 | init_args: 104 | num_steps: 100 105 | scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler 106 | guidance_fn: &guidance_fn src.diffusion.base.guidance.simple_guidance_fn 107 | null_class: 1000 108 | w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler 109 | step_fn: &step_fn src.diffusion.flow_matching.sampling.ode_step_fn 110 | guidance: &guidance 1.375 111 | source_sampler: 112 | class_path: src.diffusion.flow_matching.neuralsolver.NeuralSolverSampler 113 | init_args: 114 | num_steps: *step 115 | null_class: 1000 116 | guidance: *guidance 117 | 
scheduler: *scheduler 118 | w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler 119 | guidance_fn: *guidance_fn 120 | step_fn: *step_fn 121 | optimizer: 122 | class_path: lion_pytorch.Lion 123 | init_args: 124 | lr: 1e-2 125 | weight_decay: 0.00 126 | data: 127 | test_gen_root: data/pred 128 | test_nature_root: data/val 129 | train_batch_size: *batch_size 130 | train_num_workers: 2 131 | train_prefetch_factor: 8 132 | eval_batch_size: 32 133 | eval_num_workers: 4 134 | eval_max_num_instances: 50000 # fid50k 135 | eval_seeds: null 136 | eval_selected_classes: null 137 | pred_batch_size: 64 138 | pred_num_workers: 2 139 | pred_seeds: null 140 | test_batch_size: 64 141 | test_num_workers: 16 142 | test_image_size: 143 | - 256 144 | - 256 145 | num_classes: 1000 146 | latent_shape: 147 | - 4 148 | - 32 149 | - 32 -------------------------------------------------------------------------------- /figs/coeffs_timestep2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/figs/coeffs_timestep2.png -------------------------------------------------------------------------------- /figs/fid_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/figs/fid_performance.png -------------------------------------------------------------------------------- /figs/flux_cfg2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/figs/flux_cfg2.png -------------------------------------------------------------------------------- /figs/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/figs/method.png -------------------------------------------------------------------------------- /figs/pixart1024_cfg2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/figs/pixart1024_cfg2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import os 3 | import torch 4 | from lightning import Trainer 5 | from lightning.pytorch.cli import LightningCLI, LightningArgumentParser 6 | import os 7 | from src.lightning_data import DataModule 8 | from src.lightning_model import LightningModel 9 | # from src.utils.logger import WandbSaveConfigCallback 10 | 11 | class ReWriteRootDirCli(LightningCLI): 12 | 13 | def before_instantiate_classes(self) -> None: 14 | super().before_instantiate_classes() 15 | subcommand = self.subcommand 16 | 17 | # convert local batch_size to global batch_size 18 | num_nodes = self.config[subcommand]["trainer"]["num_nodes"] 19 | self.config[subcommand]["tags"]['b'] = str(self.config[subcommand]["tags"]['b']) +f"x{num_nodes}" 20 | 21 | # formulate the root dir 22 | default_root_dir = self.config[subcommand]["trainer"]["default_root_dir"] 23 | if default_root_dir is None: 24 | default_root_dir = os.path.join(os.getcwd(), "workdirs") 25 | dirname = "" 26 | for v, k in 
self.config[subcommand]["tags"].items(): 27 | dirname+=f"_{v}{k}" 28 | 29 | dirname = dirname[1:] 30 | default_root_dir = os.path.join(default_root_dir, dirname) 31 | self.config[subcommand]["trainer"]["default_root_dir"] = default_root_dir 32 | 33 | # predict without logger 34 | if subcommand == "predict": 35 | self.config[subcommand]["trainer"]["logger"] = None 36 | 37 | # # predict path check 38 | # if subcommand == "predict": 39 | # pred_root = os.path.join(default_root_dir,self.config[subcommand]["model"]["save_dir"]) 40 | # if os.path.exists(pred_root): 41 | # if len(os.listdir(pred_root)) != 0: 42 | # raise ValueError(f"Prediction path {pred_root} is not empty") 43 | # make stupid odps happy 44 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 45 | class TagsClass: 46 | def __init__(self, exp:str, b:int|str, s:int|str, e:int,): 47 | ... 48 | parser.add_class_arguments(TagsClass, nested_key="tags") 49 | parser.link_arguments("model.metric.precompute_data_path", "data.test_only_gen_data") 50 | 51 | def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 52 | super().add_default_arguments_to_parser(parser) 53 | parser.add_argument("--tables",type=str, default="", help=("make nebu happy" ),) 54 | parser.add_argument("--torch_hub_dir", type=str, default=None, help=("torch hub dir"),) 55 | parser.add_argument("--huggingface_cache_dir", type=str, default=None, help=("huggingface hub dir"),) 56 | 57 | def instantiate_trainer(self, **kwargs: Any) -> Trainer: 58 | trainer = super().instantiate_trainer(**kwargs) 59 | return trainer 60 | 61 | def instantiate_classes(self) -> None: 62 | torch_hub_dir = self._get(self.config, "torch_hub_dir") 63 | huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir") 64 | if huggingface_cache_dir is not None: 65 | os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir 66 | if torch_hub_dir is not None: 67 | os.environ["TORCH_HOME"] = torch_hub_dir 68 | torch.hub.set_dir(torch_hub_dir) 69 | super().instantiate_classes() 70 | 71 | def cli_main(): 72 | # ignore all warnings that could be false positives 73 | torch.set_float32_matmul_precision('medium') 74 | cli = ReWriteRootDirCli(LightningModel, DataModule, auto_configure_optimizers=False, save_config_kwargs={"overwrite": True}) 75 | 76 | if __name__ == "__main__": 77 | cli_main() 78 | -------------------------------------------------------------------------------- /precompute/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/precompute/placeholder -------------------------------------------------------------------------------- /pretrain_models/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/pretrain_models/placeholder -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 
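`ReWriteRootDirCli` above derives each run's working directory from the `tags` block of the chosen config (after suffixing the batch tag with the node count) before handing control to Lightning; presumably a run is launched through this entry point with the usual LightningCLI pattern, e.g. `python main.py fit --config configs/trajscutrw_fm_sit_lion_t100.yml`. The snippet below simply replays that directory-naming logic with illustrative tag values.

```python
import os

# Illustrative tag values; mirrors ReWriteRootDirCli.before_instantiate_classes,
# which appends "x{num_nodes}" to the batch tag and then concatenates key+value pairs.
tags = {"exp": "trajscutrw_fm_sit_lion_t100", "b": "4x1", "s": 10, "e": 20}

dirname = ""
for key, value in tags.items():
    dirname += f"_{key}{value}"
dirname = dirname[1:]

print(os.path.join(os.getcwd(), "workdirs", dirname))
# e.g. <cwd>/workdirs/exptrajscutrw_fm_sit_lion_t100_b4x1_s10_e20
```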
-------------------------------------------------------------------------------- /src/data/celeba.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from torchvision.datasets import CelebA 3 | 4 | 5 | class LocalDataset(CelebA): 6 | def __init__(self, root:str, data_convert:Callable, ): 7 | super(LocalDataset, self).__init__(root, "train") 8 | self.data_convert = data_convert 9 | 10 | def __getitem__(self, idx): 11 | data = super().__getitem__(idx) 12 | img, label = self.data_convert(data) 13 | return img, label -------------------------------------------------------------------------------- /src/data/cifar10.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from torchvision.datasets import CIFAR10 3 | 4 | 5 | class LocalDataset(CIFAR10): 6 | def __init__(self, root:str, data_convert:Callable): 7 | super(LocalDataset, self).__init__(root, True) 8 | self.data_convert = data_convert 9 | 10 | def __getitem__(self, idx): 11 | data = super().__getitem__(idx) 12 | img, label = self.data_convert(data) 13 | return img, label 14 | -------------------------------------------------------------------------------- /src/data/cifar100.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from torchvision.datasets import CIFAR100 3 | 4 | 5 | class LocalDataset(CIFAR100): 6 | def __init__(self, root:str, data_convert:Callable): 7 | super(LocalDataset, self).__init__(root, True) 8 | self.data_convert = data_convert 9 | 10 | def __getitem__(self, idx): 11 | data = super().__getitem__(idx) 12 | img, label = self.data_convert(data) 13 | return img, label 14 | -------------------------------------------------------------------------------- /src/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from typing import Callable 4 | from torchvision.datasets import ImageFolder 5 | 6 | class LocalDataset(ImageFolder): 7 | def __init__(self, root, 8 | data_convert:Callable, 9 | ): 10 | super().__init__(root) 11 | self.data_convert = data_convert 12 | def __getitem__(self, idx): 13 | data = super().__getitem__(idx) 14 | img, label = self.data_convert(data) 15 | return img, label 16 | import numpy as np 17 | from torch.utils.data import Dataset 18 | class LocalCachedDataset(Dataset): 19 | def __init__(self, root,): 20 | super().__init__() 21 | self.root = root 22 | cache_names_file = os.path.join(root, 'cache_names.txt') 23 | with open(cache_names_file, 'r') as f: 24 | self.filenames = f.readlines() 25 | self.filenames = sorted(self.filenames) 26 | self.filenames = [x.strip() for x in self.filenames] 27 | def __getitem__(self, idx: int): 28 | filename = os.path.join(self.root,self.filenames[idx]) 29 | pk_data = torch.load(filename) 30 | mean = pk_data['mean'] 31 | logvar = pk_data['logvar'] 32 | label = pk_data['label'] 33 | if "fliped_mean" in pk_data.keys(): 34 | if np.random.rand() > 0.5: 35 | mean = pk_data['fliped_mean'] 36 | logvar = pk_data['fliped_logvar'] 37 | logvar = torch.clamp(logvar, -30.0, 20.0) 38 | std = torch.exp(0.5 * logvar) 39 | sample = mean + torch.randn_like(mean) * std 40 | return sample, label 41 | def __len__(self) -> int: 42 | return len(self.filenames) -------------------------------------------------------------------------------- /src/data/metric_dataset.py: 
-------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import torch 4 | import random 5 | from torchvision.io.image import read_image 6 | import torchvision.transforms as tvtf 7 | from torch.utils.data import Dataset 8 | from src.data.transforms import CenterCrop 9 | from PIL import Image 10 | IMG_EXTENSIONS = ( 11 | "*.png", 12 | "*.JPEG", 13 | "*.jpeg", 14 | "*.jpg" 15 | ) 16 | 17 | def test_collate(batch): 18 | return torch.stack(batch) 19 | 20 | class ImageDataset(Dataset): 21 | def __init__(self, root, image_size=(224, 224)): 22 | self.root = pathlib.Path(root) 23 | images = [] 24 | for ext in IMG_EXTENSIONS: 25 | images.extend(self.root.rglob(ext)) 26 | random.shuffle(images) 27 | self.images = list(map(lambda x: str(x), images)) 28 | self.transform = tvtf.Compose( 29 | [ 30 | CenterCrop(image_size[0]), 31 | tvtf.ToTensor(), 32 | tvtf.Lambda(lambda x: (x*255).to(torch.uint8)), 33 | tvtf.Lambda(lambda x: x.expand(3, -1, -1)) 34 | ] 35 | ) 36 | self.size = image_size 37 | 38 | def __getitem__(self, idx): 39 | try: 40 | image = Image.open(self.images[idx])#read_image(self.images[idx]) 41 | image = self.transform(image) 42 | except Exception as e: 43 | print(self.images[idx]) 44 | image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8) 45 | 46 | # print(image) 47 | metadata = dict( 48 | path = self.images[idx], 49 | root = self.root, 50 | ) 51 | return image #, metadata 52 | 53 | def __len__(self): 54 | return len(self.images) -------------------------------------------------------------------------------- /src/data/randn.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | 9 | class RandomNDataset(Dataset): 10 | def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, max_num_instances=50000, random_seed=True): 11 | self.selected_classes = selected_classes 12 | if selected_classes is not None: 13 | num_classes = len(selected_classes) 14 | max_num_instances = 10*num_classes 15 | self.num_classes = num_classes 16 | self.seeds = seeds 17 | if seeds is not None: 18 | self.max_num_instances = len(seeds)*num_classes 19 | self.num_seeds = len(seeds) 20 | else: 21 | self.num_seeds = (max_num_instances + num_classes - 1) // num_classes 22 | self.max_num_instances = self.num_seeds*num_classes 23 | 24 | self.latent_shape = latent_shape 25 | self.random_seed = random_seed 26 | 27 | 28 | def __getitem__(self, idx): 29 | label = idx // self.num_seeds 30 | if self.selected_classes: 31 | label = self.selected_classes[label] 32 | if self.random_seed: 33 | seed = random.randint(0, 1<<31) 34 | else: 35 | seed = idx 36 | if self.seeds is not None: 37 | seed = self.seeds[idx % self.num_seeds] 38 | # cls_dir = os.path.join(self.root, f"{label}") 39 | filename = f"{label}_{seed}.png", 40 | generator = torch.Generator().manual_seed(seed) 41 | latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) 42 | return latent, label, filename 43 | def __len__(self): 44 | return self.max_num_instances -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Callable 4 | import torchvision.transforms as tvtf 5 | from PIL import Image 6 | from io import 
BytesIO 7 | import base64 8 | import pickle 9 | 10 | class VARNeBuDataPreTransform: 11 | def __call__(self, sample): 12 | try: 13 | oss_path = sample[0] 14 | class_id = int(sample[1]) 15 | pil_image = Image.open(BytesIO(sample[-1])).convert("RGB") 16 | while pil_image.size[0] >= 512 and pil_image.size[1] >= 512: 17 | pil_image = pil_image.resize( 18 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 19 | ) 20 | return pil_image, class_id 21 | except Exception as e: 22 | print('Failed to pre-process sample: \n', repr(e)) 23 | return None 24 | 25 | class NeBuDataPreTransform: 26 | def __call__(self, sample): 27 | try: 28 | oss_path = sample[0] 29 | class_id = int(sample[1]) 30 | pil_image = Image.open(BytesIO(sample[-1])).convert("RGB") 31 | return pil_image, class_id 32 | except Exception as e: 33 | print('Failed to pre-process sample: \n', repr(e)) 34 | return None 35 | 36 | 37 | class NeBuPreComputedDataPreTransform: 38 | def __call__(self, sample): 39 | try: 40 | base64_data = BytesIO(sample[-1]).getvalue() 41 | pk_data = base64.b64decode(base64_data.decode("utf-8")) 42 | pk_data = pickle.loads(pk_data) 43 | mean = pk_data['mean'] 44 | logvar = pk_data['logvar'] 45 | label = pk_data['label'] 46 | if "fliped_mean" in pk_data.keys(): 47 | if np.random.rand() > 0.5: 48 | mean = pk_data['fliped_mean'] 49 | logvar = pk_data['fliped_logvar'] 50 | logvar = torch.clamp(logvar, -30.0, 20.0) 51 | std = torch.exp(0.5 * logvar) 52 | sample = mean + torch.randn_like(mean) * std 53 | return sample, label 54 | except Exception as e: 55 | print('Failed to pre-process sample: \n', repr(e)) 56 | return torch.randn(4,32, 32), 0 57 | 58 | class CenterCrop: 59 | def __init__(self, size): 60 | self.size = size 61 | def __call__(self, image): 62 | def center_crop_arr(pil_image, image_size): 63 | """ 64 | Center cropping implementation from ADM. 
65 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 66 | """ 67 | while min(*pil_image.size) >= 2 * image_size: 68 | pil_image = pil_image.resize( 69 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 70 | ) 71 | 72 | scale = image_size / min(*pil_image.size) 73 | pil_image = pil_image.resize( 74 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 75 | ) 76 | 77 | arr = np.array(pil_image) 78 | crop_y = (arr.shape[0] - image_size) // 2 79 | crop_x = (arr.shape[1] - image_size) // 2 80 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 81 | 82 | return center_crop_arr(image, self.size) 83 | 84 | def xymetacollate(batch): 85 | latent, label, metadata = zip(*batch) 86 | latent = torch.stack(latent) 87 | label = torch.tensor(label) 88 | return latent, label, metadata 89 | 90 | 91 | class UnifiedTransforms: 92 | def __init__(self, size, pre_transform:Callable=None, transform=(), precomputed_data=False): 93 | self.pre_transform = pre_transform 94 | if precomputed_data: 95 | self.transform = tvtf.Compose( 96 | [ 97 | *transform, 98 | # tvtf.RandomHorizontalFlip(0.5), 99 | ] 100 | ) 101 | else: 102 | self.transform = tvtf.Compose( 103 | [ 104 | *transform, 105 | CenterCrop(size), 106 | tvtf.RandomHorizontalFlip(0.5), 107 | tvtf.ToTensor(), 108 | tvtf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 109 | ] 110 | ) 111 | def __call__(self, data): 112 | if self.pre_transform is not None: 113 | x, y = self.pre_transform(data) 114 | else: 115 | x, y = data 116 | x = self.transform(x) 117 | return x, y -------------------------------------------------------------------------------- /src/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/src/diffusion/__init__.py -------------------------------------------------------------------------------- /src/diffusion/base/guidance.py: -------------------------------------------------------------------------------- 1 | def simple_guidance_fn(out, cfg): 2 | uncondition, condtion = out.chunk(2, dim=0) 3 | out = uncondition + cfg * (condtion - uncondition) 4 | return out 5 | 6 | def c3_guidance_fn(out, cfg): 7 | # guidance function in DiT/SiT, seems like a bug not a feature? 
8 | uncondition, condtion = out.chunk(2, dim=0) 9 | out = condtion 10 | out[:, :3] = uncondition[:, :3] + cfg * (condtion[:, :3] - uncondition[:, :3]) 11 | return out -------------------------------------------------------------------------------- /src/diffusion/base/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import logging 5 | from typing import Callable 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class BaseSampler(nn.Module): 10 | def __init__(self, 11 | null_class, 12 | guidance_fn: Callable, 13 | num_steps: int = 250, 14 | guidance: float = 1.0, 15 | weight_path: str=None, 16 | *args, 17 | **kwargs 18 | ): 19 | super().__init__() 20 | self.null_class = null_class 21 | self.num_steps = num_steps 22 | self.guidance = guidance 23 | self.guidance_fn = guidance_fn 24 | self.weight_path = weight_path 25 | 26 | def _timesteps(self): 27 | raise NotImplementedError 28 | 29 | def _report_stats(self): 30 | return {} 31 | 32 | def report_stats(self): 33 | logger.warning("logging stats of neural sampler:") 34 | stats = self._report_stats() 35 | for k, v in stats.items(): 36 | logger.warning(f"{k}->{v}") 37 | 38 | def _impl_sampling(self, net, images, labels): 39 | raise NotImplementedError 40 | 41 | def __call__(self, net, images, labels, return_x_trajectory=False): 42 | x_trajs = self._impl_sampling(net, images, labels) 43 | if return_x_trajectory: 44 | return x_trajs 45 | return x_trajs[-1] -------------------------------------------------------------------------------- /src/diffusion/base/scheduling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | class BaseScheduler(nn.Module): 6 | def alpha(self, t) -> Tensor: 7 | ... 8 | def sigma(self, t) -> Tensor: 9 | ... 10 | 11 | def dalpha(self, t) -> Tensor: 12 | ... 13 | def dsigma(self, t) -> Tensor: 14 | ... 
15 | 16 | def dalpha_over_alpha(self, t) -> Tensor: 17 | return self.dalpha(t) / self.alpha(t) 18 | 19 | def dsigma_mul_sigma(self, t) -> Tensor: 20 | return self.dsigma(t)*self.sigma(t) 21 | 22 | def drift_coefficient(self, t): 23 | alpha, sigma = self.alpha(t), self.sigma(t) 24 | dalpha, dsigma = self.dalpha(t), self.dsigma(t) 25 | return dalpha/(alpha + 1e-6) 26 | 27 | def diffuse_coefficient(self, t): 28 | alpha, sigma = self.alpha(t), self.sigma(t) 29 | dalpha, dsigma = self.dalpha(t), self.dsigma(t) 30 | return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2 31 | 32 | def w(self, t): 33 | return self.sigma(t) 34 | -------------------------------------------------------------------------------- /src/diffusion/ddpm/ddim_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.diffusion.base.scheduling import * 3 | from src.diffusion.base.sampling import * 4 | 5 | from typing import Callable 6 | 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | class DDIMSampler(BaseSampler): 11 | def __init__( 12 | self, 13 | scheduler: BaseScheduler, 14 | train_num_steps=1000, 15 | *args, 16 | **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.scheduler = scheduler 20 | self.train_num_steps = train_num_steps 21 | assert self.scheduler is not None 22 | 23 | def _impl_sampling(self, net, images, labels): 24 | batch_size = images.shape[0] 25 | steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=images.device) 26 | steps = torch.flip(steps, dims=[0]) 27 | 28 | null_labels = torch.full_like(labels, self.null_class) 29 | labels = torch.cat([null_labels, labels], dim=0) 30 | x = x0 = images 31 | trajs = [x, ] 32 | for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): 33 | t_cur = t_cur.repeat(batch_size) 34 | t_next = t_next.repeat(batch_size) 35 | sigma = self.scheduler.sigma(t_cur) 36 | alpha = self.scheduler.alpha(t_cur) 37 | sigma_next = self.scheduler.sigma(t_next) 38 | alpha_next = self.scheduler.alpha(t_next) 39 | cfg_x = torch.cat([x, x], dim=0) 40 | t = t_cur.repeat(2) 41 | out = net(cfg_x, t, labels) 42 | out = self.guidance_fn(out, self.guidance) 43 | x0 = (x - sigma * out) / alpha 44 | x = alpha_next * x0 + sigma_next * out 45 | trajs.append(x) 46 | return trajs -------------------------------------------------------------------------------- /src/diffusion/ddpm/dpmsolver_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.diffusion.base.scheduling import * 4 | from src.diffusion.base.sampling import * 5 | from src.diffusion.base.guidance import simple_guidance_fn 6 | 7 | from typing import Callable 8 | 9 | 10 | def ode_step_fn(x, v, s, beta, dt): 11 | return x + v*dt 12 | 13 | def sde_step_fn(x, v, s, beta, dt): 14 | return x + (v + 0.5*s*beta)*dt + torch.sqrt(dt*beta)*torch.randn_like(x) 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | 21 | class DPM1sSolverSampler(BaseSampler): 22 | def __init__( 23 | self, 24 | train_max_t=1000, 25 | num_steps: int = 250, 26 | scheduler: BaseScheduler = None, 27 | step_fn: Callable = ode_step_fn, 28 | *args, 29 | **kwargs 30 | ): 31 | super().__init__(*args, **kwargs) 32 | self.scheduler = scheduler 33 | self.num_steps = num_steps 34 | self.step_fn = step_fn 35 | self.train_max_t = train_max_t 36 | assert self.scheduler is not None 37 | 38 | 39 | def _impl_sampling(self, net, images, labels): 40 | batch_size = 
images.shape[0] 41 | null_labels = torch.full_like(labels, self.null_class) 42 | labels = torch.cat([null_labels, labels], dim=0) 43 | x = images 44 | pred_trajectory = [] 45 | trajs = [x, ] 46 | t_cur = torch.ones(1).to(images.device, images.dtype)*0.999 47 | dt = 1/self.num_steps 48 | t_cur = t_cur.repeat(batch_size) 49 | for i in range(self.num_steps): 50 | sigma = self.scheduler.sigma(t_cur) 51 | alpha = self.scheduler.alpha(t_cur) 52 | lamda = (alpha/sigma) 53 | sigma_next = self.scheduler.sigma(t_cur - dt) 54 | alpha_next = self.scheduler.alpha(t_cur - dt) 55 | lamda_next = (alpha_next/sigma_next) 56 | cfg_x = torch.cat([x, x], dim=0) 57 | t = t_cur.repeat(2) 58 | eps = net(cfg_x, t * self.train_max_t, labels) 59 | eps = self.guidance_fn(eps, self.guidance) 60 | x0 = (x - sigma*eps)/alpha 61 | pred_trajectory.append(x0) 62 | delta_lamda = lamda_next - lamda 63 | x = (sigma_next/sigma)*x + sigma_next*(delta_lamda)*x0 64 | t_cur = t_cur - dt 65 | trajs.append(x) 66 | return trajs -------------------------------------------------------------------------------- /src/diffusion/ddpm/neuralsolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.diffusion.base.scheduling import * 4 | from src.diffusion.base.sampling import * 5 | from src.diffusion.base.guidance import simple_guidance_fn 6 | 7 | from typing import Callable 8 | 9 | 10 | def ode_step_fn(x, v, s, beta, dt): 11 | return x + v*dt 12 | 13 | def sde_step_fn(x, v, s, beta, dt): 14 | return x + (v + 0.5*s*beta)*dt + torch.sqrt(dt*beta)*torch.randn_like(x) 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class NeuralSolverSampler(BaseSampler): 21 | def __init__( 22 | self, 23 | train_max_t=1000, 24 | num_steps: int = 250, 25 | scheduler: BaseScheduler = None, 26 | step_fn: Callable = ode_step_fn, 27 | *args, 28 | **kwargs 29 | ): 30 | super().__init__(*args, **kwargs) 31 | self.scheduler = scheduler 32 | self.num_steps = num_steps 33 | self.step_fn = step_fn 34 | self.train_max_t = train_max_t 35 | assert self.scheduler is not None 36 | self._register_parameters(num_steps) 37 | 38 | def _register_parameters(self, num_steps): 39 | self._raw_solver_coeffs = nn.Parameter(torch.eye(num_steps) * 0) 40 | timedeltas = (1 / self.num_steps) 41 | self._raw_timedeltas = nn.Parameter(torch.full((num_steps,), fill_value=timedeltas)) 42 | 43 | def _timesteps(self): 44 | t = torch.softmax(self._raw_timedeltas, dim=0) 45 | for i in range(self.num_steps): 46 | if i > 0: 47 | t[i] += t[i - 1] 48 | return t 49 | 50 | @torch.no_grad() 51 | def _report_stats(self): 52 | timedeltas = torch.softmax(self._raw_timedeltas, dim=0) 53 | solver_coeffs = self._raw_solver_coeffs.detach() 54 | return {"timedeltas": timedeltas, "solver_coeffs": solver_coeffs} 55 | 56 | def _impl_sampling(self, net, images, labels): 57 | batch_size = images.shape[0] 58 | null_labels = torch.full_like(labels, self.null_class) 59 | labels = torch.cat([null_labels, labels], dim=0) 60 | x = images 61 | pred_trajectory = [] 62 | trajs = [x, ] 63 | t_cur = torch.ones(1).to(images.device, images.dtype)*0.999 64 | timedeltas = self._raw_timedeltas.to(images.device, images.dtype) 65 | solver_coeffs = self._raw_solver_coeffs.to(images.device, images.dtype) 66 | t_cur = t_cur.repeat(batch_size) 67 | for i in range(self.num_steps): 68 | dt = timedeltas[i] 69 | sigma = self.scheduler.sigma(t_cur) 70 | alpha = self.scheduler.alpha(t_cur) 71 | lamda = (alpha/sigma) 72 | sigma_next = 
self.scheduler.sigma(t_cur - dt) 73 | alpha_next = self.scheduler.alpha(t_cur - dt) 74 | lamda_next = (alpha_next/sigma_next) 75 | cfg_x = torch.cat([x, x], dim=0) 76 | t = t_cur.repeat(2) 77 | eps = net(cfg_x, t * self.train_max_t, labels) 78 | eps = self.guidance_fn(eps, self.guidance) 79 | x0 = (x - sigma*eps)/alpha 80 | pred_trajectory.append(x0) 81 | dpmeps = torch.zeros_like(x0) 82 | sum_solver_coeff = 0.0 83 | for j in range(i): 84 | if self.num_steps <= 6 and i == self.num_steps - 2 and j != i - 1: continue 85 | if self.num_steps <= 6 and i == self.num_steps - 1 and j != i - 1: continue 86 | dpmeps += solver_coeffs[i, j] * pred_trajectory[j] 87 | sum_solver_coeff += solver_coeffs[i, j] 88 | dpmeps += (1 - sum_solver_coeff) * pred_trajectory[-1] 89 | delta_lamda = lamda_next - lamda 90 | x = (sigma_next / sigma) * x + sigma_next * (delta_lamda) * dpmeps 91 | pred_trajectory.append(x0) 92 | delta_lamda = lamda_next - lamda 93 | x = (sigma_next/sigma)*x + sigma_next*(delta_lamda)*x0 94 | t_cur = t_cur - dt 95 | trajs.append(x) 96 | return trajs -------------------------------------------------------------------------------- /src/diffusion/ddpm/scheduling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from src.diffusion.base.scheduling import * 4 | 5 | 6 | class DDPMScheduler(BaseScheduler): 7 | def __init__( 8 | self, 9 | beta_min=0.0001, 10 | beta_max=0.02, 11 | num_steps=1000, 12 | ): 13 | super().__init__() 14 | self.beta_min = beta_min 15 | self.beta_max = beta_max 16 | self.num_steps = num_steps 17 | 18 | self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda") 19 | self.alphas_table = torch.cumprod(1-self.betas_table, dim=0) 20 | self.sigmas_table = 1-self.alphas_table 21 | 22 | 23 | def beta(self, t) -> Tensor: 24 | t = t.to(torch.long) 25 | return self.betas_table[t].view(-1, 1, 1, 1) 26 | 27 | def alpha(self, t) -> Tensor: 28 | t = t.to(torch.long) 29 | return self.alphas_table[t].view(-1, 1, 1, 1)**0.5 30 | 31 | def sigma(self, t) -> Tensor: 32 | t = t.to(torch.long) 33 | return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5 34 | 35 | def dsigma(self, t) -> Tensor: 36 | raise NotImplementedError("wrong usage") 37 | 38 | def dalpha_over_alpha(self, t) ->Tensor: 39 | raise NotImplementedError("wrong usage") 40 | 41 | def dsigma_mul_sigma(self, t) ->Tensor: 42 | raise NotImplementedError("wrong usage") 43 | 44 | def dalpha(self, t) -> Tensor: 45 | raise NotImplementedError("wrong usage") 46 | 47 | def drift_coefficient(self, t): 48 | raise NotImplementedError("wrong usage") 49 | 50 | def diffuse_coefficient(self, t): 51 | raise NotImplementedError("wrong usage") 52 | 53 | def w(self, t): 54 | raise NotImplementedError("wrong usage") 55 | 56 | 57 | class VPScheduler(BaseScheduler): 58 | def __init__( 59 | self, 60 | beta_min=0.1, 61 | beta_max=20, 62 | ): 63 | super().__init__() 64 | self.beta_min = beta_min 65 | self.beta_d = beta_max - beta_min 66 | def beta(self, t) -> Tensor: 67 | t = torch.clamp(t, min=1e-3, max=1) 68 | return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1) 69 | 70 | def sigma(self, t) -> Tensor: 71 | t = torch.clamp(t, min=1e-3, max=1) 72 | inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t 73 | return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1) 74 | 75 | def dsigma(self, t) -> Tensor: 76 | raise NotImplementedError("wrong usage") 77 | 78 | def dalpha_over_alpha(self, t) ->Tensor: 79 | raise NotImplementedError("wrong 
usage") 80 | 81 | def dsigma_mul_sigma(self, t) ->Tensor: 82 | raise NotImplementedError("wrong usage") 83 | 84 | def dalpha(self, t) -> Tensor: 85 | raise NotImplementedError("wrong usage") 86 | 87 | def alpha(self, t) -> Tensor: 88 | t = torch.clamp(t, min=1e-3, max=1) 89 | inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t 90 | return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1) 91 | 92 | def drift_coefficient(self, t): 93 | raise NotImplementedError("wrong usage") 94 | 95 | def diffuse_coefficient(self, t): 96 | raise NotImplementedError("wrong usage") 97 | 98 | def w(self, t): 99 | return self.diffuse_coefficient(t) 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /src/diffusion/ddpm/vp_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.diffusion.base.scheduling import * 4 | from src.diffusion.base.sampling import * 5 | from typing import Callable 6 | 7 | def ode_step_fn(x, eps, beta, sigma, dt): 8 | return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt 9 | 10 | def sde_step_fn(x, eps, beta, sigma, dt): 11 | return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x) 12 | 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | class VPEulerSampler(BaseSampler): 17 | def __init__( 18 | self, 19 | train_max_t=1000, 20 | scheduler: BaseScheduler = None, 21 | guidance_fn: Callable = None, 22 | step_fn: Callable = ode_step_fn, 23 | last_step=None, 24 | last_step_fn: Callable = ode_step_fn, 25 | *args, 26 | **kwargs 27 | ): 28 | super().__init__(*args, **kwargs) 29 | self.scheduler = scheduler 30 | self.guidance_fn = guidance_fn 31 | self.step_fn = step_fn 32 | self.last_step = last_step 33 | self.last_step_fn = last_step_fn 34 | self.train_max_t = train_max_t 35 | 36 | if self.last_step is None or self.num_steps == 1: 37 | self.last_step = 1.0 / self.num_steps 38 | assert self.last_step > 0.0 39 | assert self.scheduler is not None 40 | 41 | def _impl_sampling(self, net, images, labels): 42 | batch_size = images.shape[0] 43 | steps = torch.linspace(1.0, self.last_step, self.num_steps, device=images.device) 44 | steps = torch.cat([steps, torch.tensor([0.0], device=images.device)], dim=0) 45 | null_labels = torch.full_like(labels, self.null_class) 46 | labels = torch.cat([null_labels, labels], dim=0) 47 | x = images 48 | for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): 49 | dt = t_next - t_cur 50 | t_cur = t_cur.repeat(batch_size) 51 | sigma = self.scheduler.sigma(t_cur) 52 | beta = self.scheduler.beta(t_cur) 53 | cfg_x = torch.cat([x, x], dim=0) 54 | t = t_cur.repeat(2) 55 | out = net(cfg_x, t*self.train_max_t, labels) 56 | eps = self.guidance_fn(out, self.guidance) 57 | if i < self.num_steps -1 : 58 | x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0]) 59 | x = self.step_fn(x, eps, beta, sigma, dt) 60 | else: 61 | x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step) 62 | return x0 -------------------------------------------------------------------------------- /src/diffusion/flow_matching/adam_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.diffusion.base.sampling import * 4 | from src.diffusion.base.guidance import * 5 | from src.diffusion.base.scheduling import * 6 | 7 | from typing import Callable 8 | 9 | def ode_step_fn(x, v, dt, s, w): 10 | return x + v * dt 11 | 12 | 13 | import 
logging 14 | logger = logging.getLogger(__name__) 15 | 16 | class AdamLMSampler(BaseSampler): 17 | def __init__( 18 | self, 19 | order: int = 2, 20 | num_steps: int = 250, 21 | scheduler: BaseScheduler = None, 22 | w_scheduler: BaseScheduler = None, 23 | step_fn: Callable = ode_step_fn, 24 | *args, 25 | **kwargs 26 | ): 27 | super().__init__(*args, **kwargs) 28 | self.scheduler = scheduler 29 | self.num_steps = num_steps 30 | self.step_fn = step_fn 31 | self.w_scheduler = w_scheduler 32 | 33 | assert self.scheduler is not None 34 | assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] 35 | if self.w_scheduler is not None: 36 | if self.step_fn == ode_step_fn: 37 | logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") 38 | self._register_parameters(order) 39 | 40 | # TODO: try non-uniform timesteps ? 41 | def _register_parameters(self, order=2): 42 | self._raw_solver_coeffs = torch.nn.Parameter(torch.eye(self.num_steps) * 0, requires_grad=False) 43 | for i in range(1, self.num_steps): 44 | if i >= 1 and order>=2: 45 | self._raw_solver_coeffs[i, i-1] = -0.5 46 | if i>=2 and order>=3: 47 | self._raw_solver_coeffs[i, i-2:i] = torch.tensor([+5 / 12, -16 / 12]) 48 | if i>=3 and order>=4: 49 | self._raw_solver_coeffs[i, i - 3:i] = torch.tensor([-9 / 24, +37 / 24, -59 / 24]) 50 | timedeltas = (1 / self.num_steps) 51 | self._raw_timedeltas = torch.nn.Parameter(torch.full((self.num_steps,), fill_value=timedeltas)) 52 | 53 | 54 | def _impl_sampling(self, net, images, labels): 55 | """ 56 | sampling process of Euler sampler 57 | - 58 | """ 59 | batch_size = images.shape[0] 60 | null_labels = torch.full_like(labels, self.null_class) 61 | labels = torch.cat([null_labels, labels], dim=0) 62 | x = x0 = images 63 | pred_trajectory = [] 64 | trajs = [x, ] 65 | t_cur = torch.zeros(1).to(images.device, images.dtype) 66 | timedeltas = self._raw_timedeltas 67 | solver_coeffs = self._raw_solver_coeffs 68 | t_cur = t_cur.repeat(batch_size) 69 | for i in range(self.num_steps): 70 | cfg_x = torch.cat([x, x], dim=0) 71 | t = t_cur.repeat(2) 72 | out = net(cfg_x, t, labels) 73 | out = self.guidance_fn(out, self.guidance) 74 | pred_trajectory.append(out) 75 | out = torch.zeros_like(out) 76 | sum_solver_coeff = 0.0 77 | for j in range(i): 78 | out += solver_coeffs[i, j] * pred_trajectory[j] 79 | sum_solver_coeff += solver_coeffs[i, j] 80 | out += (1-sum_solver_coeff)*pred_trajectory[-1] 81 | v = out 82 | dt = timedeltas[i] 83 | x0 = self.step_fn(x, v, 1-t[0], s=0, w=0) 84 | x = self.step_fn(x, v, dt, s=0, w=0) 85 | t_cur += dt 86 | trajs.append(x) 87 | return trajs -------------------------------------------------------------------------------- /src/diffusion/flow_matching/neuralsolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.diffusion.base.sampling import * 4 | from src.diffusion.base.guidance import * 5 | from src.diffusion.base.scheduling import * 6 | 7 | from typing import Callable 8 | 9 | def ode_step_fn(x, v, dt, s, w): 10 | return x + v * dt 11 | 12 | import logging 13 | logger = logging.getLogger(__name__) 14 | 15 | class NeuralSolverSampler(BaseSampler): 16 | def __init__( 17 | self, 18 | num_steps: int = 250, 19 | scheduler: BaseScheduler = None, 20 | w_scheduler: BaseScheduler = None, 21 | step_fn: Callable = ode_step_fn, 22 | *args, 23 | **kwargs 24 | ): 25 | super().__init__(*args, **kwargs) 26 | self.scheduler = scheduler 27 | self.num_steps = num_steps 28 | self.step_fn = step_fn 29 | 
self.w_scheduler = w_scheduler 30 | 31 | assert self.scheduler is not None 32 | assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] 33 | if self.w_scheduler is not None: 34 | if self.step_fn == ode_step_fn: 35 | logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") 36 | self._register_parameters(self.num_steps) 37 | 38 | def _register_parameters(self, num_steps): 39 | self._raw_solver_coeffs = nn.Parameter(torch.eye(num_steps) * 0) 40 | timedeltas = (1 / self.num_steps) 41 | self._raw_timedeltas = nn.Parameter(torch.full((num_steps,), fill_value=timedeltas)) 42 | 43 | def _timesteps(self): 44 | t = torch.softmax(self._raw_timedeltas, dim=0) 45 | for i in range(self.num_steps): 46 | if i > 0: 47 | t[i] += t[i - 1] 48 | return t 49 | 50 | @torch.no_grad() 51 | def _report_stats(self): 52 | timedeltas = torch.softmax(self._raw_timedeltas, dim=0) 53 | solver_coeffs = self._raw_solver_coeffs.detach() 54 | return {"timedeltas": timedeltas, "solver_coeffs": solver_coeffs} 55 | 56 | 57 | def _impl_sampling(self, net, images, labels): 58 | """ 59 | sampling process of Euler sampler 60 | - 61 | """ 62 | batch_size = images.shape[0] 63 | null_labels = torch.full_like(labels, self.null_class) 64 | labels = torch.cat([null_labels, labels], dim=0) 65 | x = x0 = images 66 | pred_trajectory = [] 67 | trajs = [x,] 68 | t_cur = torch.zeros(1).to(images.device, images.dtype) 69 | timedeltas = self._raw_timedeltas.to(images.device, images.dtype) 70 | solver_coeffs = self._raw_solver_coeffs.to(images.device, images.dtype) 71 | t_cur = t_cur.repeat(batch_size) 72 | for i in range(self.num_steps): 73 | cfg_x = torch.cat([x, x], dim=0) 74 | t = t_cur.repeat(2) 75 | out = net(cfg_x, t, labels) 76 | out = self.guidance_fn(out, self.guidance) 77 | pred_trajectory.append(out) 78 | out = torch.zeros_like(out) 79 | sum_solver_coeff = 0.0 80 | for j in range(i): 81 | if self.num_steps <= 6 and i == self.num_steps - 2 and j != i - 1: continue 82 | if self.num_steps <= 6 and i == self.num_steps - 1 and j != i - 1: continue 83 | out += solver_coeffs[i, j] * pred_trajectory[j] 84 | sum_solver_coeff += solver_coeffs[i, j] 85 | out += (1-sum_solver_coeff)*pred_trajectory[-1] 86 | v = out 87 | dt = timedeltas[i] 88 | x0 = self.step_fn(x, v, 1-t[0], s=0, w=0) 89 | x = self.step_fn(x, v, dt, s=0, w=0) 90 | t_cur += dt 91 | trajs.append(x) 92 | return trajs -------------------------------------------------------------------------------- /src/diffusion/flow_matching/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.diffusion.base.guidance import * 4 | from src.diffusion.base.scheduling import * 5 | from src.diffusion.base.sampling import * 6 | 7 | from typing import Callable 8 | 9 | 10 | def ode_step_fn(x, v, dt, s, w): 11 | return x + v * dt 12 | 13 | def sde_mean_step_fn(x, v, dt, s, w): 14 | return x + v * dt + s * w * dt 15 | 16 | def sde_step_fn(x, v, dt, s, w): 17 | return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) 18 | 19 | def sde_preserve_step_fn(x, v, dt, s, w): 20 | return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) 21 | 22 | 23 | import logging 24 | logger = logging.getLogger(__name__) 25 | 26 | class EulerSampler(BaseSampler): 27 | def __init__( 28 | self, 29 | scheduler: BaseScheduler = None, 30 | w_scheduler: BaseScheduler = None, 31 | step_fn: Callable = ode_step_fn, 32 | last_step=None, 33 | last_step_fn: Callable = ode_step_fn, 34 | *args, 35 | **kwargs 36 | 
): 37 | super().__init__(*args, **kwargs) 38 | self.scheduler = scheduler 39 | self.step_fn = step_fn 40 | self.last_step = last_step 41 | self.last_step_fn = last_step_fn 42 | self.w_scheduler = w_scheduler 43 | 44 | if self.last_step is None or self.num_steps == 1: 45 | self.last_step = 1.0 / self.num_steps 46 | assert self.last_step > 0.0 47 | assert self.scheduler is not None 48 | assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] 49 | if self.w_scheduler is not None: 50 | if self.step_fn == ode_step_fn: 51 | logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") 52 | 53 | def _impl_sampling(self, net, images, labels): 54 | """ 55 | sampling process of Euler sampler 56 | - 57 | """ 58 | batch_size = images.shape[0] 59 | steps = torch.linspace(0.0, 1 - self.last_step, self.num_steps, device=images.device) 60 | steps = torch.cat([steps, torch.tensor([1.0], device=images.device)], dim=0) 61 | 62 | null_labels = torch.full_like(labels, self.null_class) 63 | labels = torch.cat([null_labels, labels], dim=0) 64 | x = images 65 | dt = steps[1] - steps[0] 66 | 67 | trajs = [x, ] 68 | for i, t_cur in enumerate(steps[:-1]): 69 | t_cur = t_cur.repeat(batch_size) 70 | sigma = self.scheduler.sigma(t_cur) 71 | dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) 72 | dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) 73 | if self.w_scheduler: 74 | w = self.w_scheduler.w(t_cur) 75 | else: 76 | w = 0.0 77 | 78 | cfg_x = torch.cat([x, x], dim=0) 79 | t = t_cur.repeat(2) 80 | out = net(cfg_x, t, labels) 81 | out = self.guidance_fn(out, self.guidance) 82 | v = out 83 | s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) 84 | if i < self.num_steps -1 : 85 | x = self.step_fn(x, v, dt, s=s, w=w) 86 | else: 87 | x = self.last_step_fn(x, v, self.last_step, s=s, w=w) 88 | trajs.append(x) 89 | return trajs -------------------------------------------------------------------------------- /src/diffusion/flow_matching/scheduling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from src.diffusion.base.scheduling import * 4 | 5 | 6 | class LinearScheduler(BaseScheduler): 7 | def alpha(self, t) -> Tensor: 8 | return (t).view(-1, 1, 1, 1) 9 | def sigma(self, t) -> Tensor: 10 | return (1-t).view(-1, 1, 1, 1) 11 | def dalpha(self, t) -> Tensor: 12 | return torch.full_like(t, 1.0).view(-1, 1, 1, 1) 13 | def dsigma(self, t) -> Tensor: 14 | return torch.full_like(t, -1.0).view(-1, 1, 1, 1) 15 | 16 | # SoTA for ImageNet! 
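# --- Editor's sketch, not part of the original file ---
# Assuming the convention used by the flow-matching samplers in this repo
# (t = 0 is pure noise, t = 1 is data) and that BaseScheduler needs no required
# constructor arguments, a scheduler's alpha/sigma define the interpolant
#     x_t = alpha(t) * x_data + sigma(t) * x_noise
# whose time derivative is the velocity target
#     v_t = dalpha(t) * x_data + dsigma(t) * x_noise,
# which for LinearScheduler reduces to v_t = x_data - x_noise:
def _linear_interpolant_example():
    sched = LinearScheduler()
    t = torch.rand(4)
    x_noise = torch.randn(4, 4, 32, 32)
    x_data = torch.randn(4, 4, 32, 32)
    x_t = sched.alpha(t) * x_data + sched.sigma(t) * x_noise
    v_t = sched.dalpha(t) * x_data + sched.dsigma(t) * x_noise
    assert torch.allclose(v_t, x_data - x_noise)
    return x_t, v_t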
17 | class GVPScheduler(BaseScheduler): 18 | def alpha(self, t) -> Tensor: 19 | return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) 20 | def sigma(self, t) -> Tensor: 21 | return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) 22 | def dalpha(self, t) -> Tensor: 23 | return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) 24 | def dsigma(self, t) -> Tensor: 25 | return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) 26 | def w(self, t): 27 | return torch.sin(t)**2 28 | 29 | class ConstScheduler(BaseScheduler): 30 | def w(self, t): 31 | return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) 32 | 33 | from src.diffusion.ddpm.scheduling import VPScheduler 34 | class VPBetaScheduler(VPScheduler): 35 | def w(self, t): 36 | return self.beta(t).view(-1, 1, 1, 1) 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/diffusion/solver_training.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional 2 | import random 3 | import torch.nn as nn 4 | from src.diffusion.base.sampling import BaseSampler 5 | 6 | class BaseTrainer(nn.Module): 7 | def __init__(self, 8 | min_cfg_aug=1.0, 9 | max_cfg_aug=1.0, 10 | buffer_size=1024, 11 | ): 12 | super().__init__() 13 | self.min_cfg_aug = min_cfg_aug 14 | self.max_cfg_aug = max_cfg_aug 15 | self.buffer_size = buffer_size 16 | def setup(self, target_sampler: BaseSampler, source_sampler: BaseSampler): 17 | self.target_sampler = target_sampler 18 | self.source_sampler = source_sampler 19 | 20 | def _impl_trainstep(self, net, images, labels): 21 | raise NotImplementedError 22 | 23 | def __call__(self, nets, images, labels): 24 | aug_cfg = torch.rand(images.shape[0], 1, 1, 1, device=images.device) * (self.max_cfg_aug-self.min_cfg_aug) + self.min_cfg_aug 25 | target_sampler_guidance = self.target_sampler.guidance 26 | source_sampler_guidance = self.source_sampler.guidance 27 | self.target_sampler.guidance = aug_cfg 28 | self.source_sampler.guidance = aug_cfg 29 | net = random.choice(nets) 30 | out = self._impl_trainstep(net, images, labels) 31 | self.target_sampler.guidance = target_sampler_guidance 32 | self.source_sampler.guidance = source_sampler_guidance 33 | return out 34 | 35 | 36 | class TrajsTrainer(BaseTrainer): 37 | def _impl_trainstep(self, net, images, labels): 38 | target_traj = self.target_sampler(net, images, labels, return_x_trajectory=True) 39 | source_traj = self.source_sampler(net, images, labels, return_x_trajectory=True) 40 | 41 | source_t = self.source_sampler._timesteps() 42 | source_t = torch.round(source_t*(len(target_traj)-1)).long() 43 | # select the corresponding target_traj 44 | selected_target_traj = [target_traj[i] for i in source_t] 45 | selected_source_traj = source_traj[1:] 46 | loss = 0.0 47 | out= dict() 48 | for i, (t, s) in enumerate(zip(selected_target_traj, selected_source_traj)): 49 | iter_loss = torch.nn.functional.mse_loss(t, s, reduction="mean") 50 | loss = loss + iter_loss 51 | out[f"iter{i}"] = iter_loss 52 | out["loss"] = loss 53 | return out 54 | 55 | 56 | class TrajsReWeightTrainer(BaseTrainer): 57 | def _impl_trainstep(self, net, images, labels): 58 | target_traj = self.target_sampler(net, images, labels, return_x_trajectory=True) 59 | source_traj = self.source_sampler(net, images, labels, return_x_trajectory=True) 60 | 61 | source_t = self.source_sampler._timesteps() 62 | source_t = torch.round(source_t*(len(target_traj)-1)).long() 63 | # select the corresponding target_traj 64 | selected_target_traj = 
[target_traj[i] for i in source_t] 65 | selected_source_traj = source_traj[1:] 66 | final_loss = torch.nn.functional.huber_loss(target_traj[-1], source_traj[-1], reduction="mean", delta=0.001)*1000 67 | loss = 0.0 68 | out= dict() 69 | for i, (t, s) in enumerate(zip(selected_target_traj, selected_source_traj)): 70 | iter_loss = torch.nn.functional.mse_loss(t, s, reduction="mean") 71 | loss = loss + iter_loss 72 | out[f"iter{i}"] = iter_loss 73 | out["loss"] = loss + final_loss 74 | return out -------------------------------------------------------------------------------- /src/lightning_data.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any 2 | import torch 3 | import lightning.pytorch as pl 4 | from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS 5 | from torch.utils.data import DataLoader 6 | from src.data.transforms import UnifiedTransforms, xymetacollate 7 | from src.data.randn import RandomNDataset 8 | from src.data.metric_dataset import ImageDataset 9 | 10 | class DataModule(pl.LightningDataModule): 11 | def __init__(self, 12 | test_nature_root, 13 | test_gen_root, 14 | train_full_random_seed=True, 15 | train_batch_size=64, 16 | train_num_workers=4, 17 | train_prefetch_factor=16, 18 | eval_batch_size=16, 19 | eval_num_workers=4, 20 | eval_max_num_instances=32, 21 | eval_seeds="0,1,2,3,4", 22 | eval_selected_classes=(207, 360, 387, 974, 88, 979, 417, 279), 23 | pred_batch_size=16, 24 | pred_num_workers=4, 25 | pred_seeds: str = None, 26 | pred_selected_classes=None, 27 | test_only_gen_data: Any = None, 28 | test_batch_size=64, 29 | test_num_workers=4, 30 | test_image_size=(224, 224), 31 | num_classes=1000, 32 | latent_shape=(4, 64, 64), 33 | ): 34 | super().__init__() 35 | eval_seeds = list(map(lambda x: int(x), eval_seeds.strip().split(","))) if eval_seeds is not None else None 36 | pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None 37 | 38 | # stupid data_convert override, just to make nebular happy 39 | self.train_batch_size = train_batch_size 40 | self.train_num_workers = train_num_workers 41 | self.train_prefetch_factor = train_prefetch_factor 42 | self.train_full_random_seed = train_full_random_seed 43 | 44 | self.test_nature_root = test_nature_root 45 | self.test_gen_root = test_gen_root 46 | self.eval_max_num_instances = eval_max_num_instances 47 | self.eval_seeds = eval_seeds 48 | self.pred_seeds = pred_seeds 49 | self.num_classes = num_classes 50 | self.latent_shape = latent_shape 51 | self.test_image_size = test_image_size 52 | 53 | self.eval_batch_size = eval_batch_size 54 | self.test_batch_size = test_batch_size 55 | self.pred_batch_size = pred_batch_size 56 | 57 | self.pred_num_workers = pred_num_workers 58 | self.eval_num_workers = eval_num_workers 59 | self.test_num_workers = test_num_workers 60 | 61 | self.eval_selected_classes = eval_selected_classes 62 | self.pred_selected_classes = pred_selected_classes 63 | 64 | self.test_only_gen_data = test_only_gen_data 65 | 66 | self._train_dataloader = None 67 | 68 | def prepare_data(self) -> None: 69 | ... 
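    # Editor's note: prepare_data is deliberately a no-op -- the training
    # "images" are random latents generated on the fly by RandomNDataset, and
    # the test/eval image folders are assumed to already exist on disk, so
    # there is nothing to download or pre-process here.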
70 | 71 | def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: 72 | return batch 73 | def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: 74 | return batch 75 | 76 | def train_dataloader(self) -> TRAIN_DATALOADERS: 77 | global_rank = self.trainer.global_rank 78 | world_size = self.trainer.world_size 79 | self.train_dataset = RandomNDataset( 80 | max_num_instances=10000, 81 | num_classes=self.num_classes, 82 | latent_shape=self.latent_shape, 83 | random_seed=self.train_full_random_seed, 84 | ) 85 | from torch.utils.data import DistributedSampler 86 | sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True) 87 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, 88 | num_workers=self.train_num_workers, 89 | prefetch_factor=self.train_prefetch_factor, 90 | collate_fn=xymetacollate, 91 | sampler=sampler 92 | ) 93 | 94 | def test_dataloader(self) -> EVAL_DATALOADERS: 95 | global_rank = self.trainer.global_rank 96 | world_size = self.trainer.world_size 97 | self.test_nature_dataset = ImageDataset( 98 | root=self.test_nature_root, 99 | image_size=self.test_image_size 100 | ) 101 | self.test_gen_dataset = ImageDataset( 102 | root=self.test_gen_root, 103 | image_size=self.test_image_size 104 | ) 105 | from torch.utils.data import DistributedSampler 106 | test_nature_sampler = DistributedSampler(self.test_nature_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) 107 | test_gen_sampler = DistributedSampler(self.test_gen_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) 108 | from src.data.metric_dataset import test_collate 109 | if self.test_only_gen_data: 110 | return [ 111 | DataLoader(self.test_gen_dataset, self.test_batch_size, num_workers=self.test_num_workers, prefetch_factor=2, collate_fn=test_collate, sampler=test_gen_sampler), 112 | ] 113 | return [ 114 | DataLoader(self.test_gen_dataset, self.test_batch_size, num_workers=self.test_num_workers, prefetch_factor=2, collate_fn=test_collate, sampler=test_gen_sampler), 115 | DataLoader(self.test_nature_dataset, self.test_batch_size, num_workers=self.test_num_workers, prefetch_factor=2, collate_fn=test_collate, sampler=test_nature_sampler), 116 | ] 117 | 118 | def val_dataloader(self) -> EVAL_DATALOADERS: 119 | global_rank = self.trainer.global_rank 120 | world_size = self.trainer.world_size 121 | self.eval_dataset = RandomNDataset( 122 | seeds=self.eval_seeds, 123 | latent_shape=self.latent_shape, 124 | max_num_instances=self.eval_max_num_instances, 125 | num_classes=self.num_classes, 126 | selected_classes=self.eval_selected_classes, 127 | ) 128 | from torch.utils.data import DistributedSampler 129 | sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) 130 | return DataLoader(self.eval_dataset, self.eval_batch_size, 131 | num_workers=self.eval_num_workers, 132 | prefetch_factor=4, 133 | collate_fn=xymetacollate, 134 | sampler=sampler 135 | ) 136 | 137 | def predict_dataloader(self) -> EVAL_DATALOADERS: 138 | global_rank = self.trainer.global_rank 139 | world_size = self.trainer.world_size 140 | self.pred_dataset = RandomNDataset( 141 | seeds= self.pred_seeds, 142 | max_num_instances=50000, 143 | num_classes=self.num_classes, 144 | selected_classes=self.pred_selected_classes, 145 | latent_shape=self.latent_shape, 146 | ) 147 | from torch.utils.data import DistributedSampler 148 | sampler = DistributedSampler(self.pred_dataset, 
num_replicas=world_size, rank=global_rank, shuffle=False) 149 | return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size, 150 | num_workers=self.pred_num_workers, 151 | prefetch_factor=4, 152 | collate_fn=xymetacollate, 153 | sampler=sampler 154 | ) 155 | -------------------------------------------------------------------------------- /src/models/dit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 6 | 7 | from src.models.base_model import BaseModel, modulate 8 | 9 | 10 | class DiTBlock(nn.Module): 11 | """ 12 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 13 | """ 14 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 15 | super().__init__() 16 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 17 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 18 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 19 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 20 | approx_gelu = lambda: nn.GELU(approximate="tanh") 21 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 22 | self.adaLN_modulation = nn.Sequential( 23 | nn.SiLU(), 24 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 25 | ) 26 | 27 | def forward(self, x, c): 28 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 29 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 30 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 31 | return x 32 | 33 | 34 | ################################################################################# 35 | # Sine/Cosine Positional Embedding Functions # 36 | ################################################################################# 37 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 38 | 39 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 40 | """ 41 | grid_size: int of the grid height and width 42 | return: 43 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 44 | """ 45 | grid_h = np.arange(grid_size, dtype=np.float32) 46 | grid_w = np.arange(grid_size, dtype=np.float32) 47 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 48 | grid = np.stack(grid, axis=0) 49 | 50 | grid = grid.reshape([2, 1, grid_size, grid_size]) 51 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 52 | if cls_token and extra_tokens > 0: 53 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 54 | return pos_embed 55 | 56 | 57 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 58 | assert embed_dim % 2 == 0 59 | 60 | # use half of dimensions to encode grid_h 61 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 62 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 63 | 64 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 65 | return emb 66 | 67 | 68 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 69 | """ 70 | embed_dim: output dimension for each position 71 | pos: a list of positions to be encoded: size (M,) 72 | out: (M, D) 73 | """ 74 | assert embed_dim 
% 2 == 0 75 | omega = np.arange(embed_dim // 2, dtype=np.float64) 76 | omega /= embed_dim / 2. 77 | omega = 1. / 10000**omega # (D/2,) 78 | 79 | pos = pos.reshape(-1) # (M,) 80 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 81 | 82 | emb_sin = np.sin(out) # (M, D/2) 83 | emb_cos = np.cos(out) # (M, D/2) 84 | 85 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 86 | return emb 87 | 88 | 89 | class DiT(BaseModel): 90 | def __init__(self, *args, **kwargs): 91 | super().__init__(*args, **kwargs) 92 | self.blocks = nn.ModuleList([ 93 | DiTBlock(self.hidden_size, self.num_groups, mlp_ratio=4) for _ in range(self.num_blocks) 94 | ]) 95 | num_patches = self.x_embedder.num_patches 96 | # Will use fixed sin-cos embedding: 97 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.hidden_size), requires_grad=False) 98 | 99 | self.initialize_weights() 100 | def initialize_weights(self): 101 | super().initialize_weights() 102 | # # Initialize (and freeze) pos_embed by sin-cos embedding: 103 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 104 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 105 | 106 | # Zero-out adaLN modulation layers in DiT blocks: 107 | for block in self.blocks: 108 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 109 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import DCNv4Function, FlashDeformAttnFunction 2 | from .modules import DCNv4, FlashDeformAttn -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # from .ms_flash_deform_attn_func import FlashMSDeformAttnFunction 10 | from .flash_deform_attn_func import FlashDeformAttnFunction 11 | from .dcnv4_func import DCNv4Function -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/functions/dcnv4_func.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import torch 12 | import math 13 | import torch.nn.functional as F 14 | from torch.autograd import Function 15 | from torch.autograd.function import once_differentiable 16 | from torch.cuda.amp import custom_bwd, custom_fwd 17 | from .table import TABLE, BWDTABLE 18 | 19 | from DCNv4 import ext 20 | 21 | def factors(N): 22 | res = [] 23 | for i in range(1, N+1): 24 | if N % i == 0: 25 | res.append(i) 26 | return res 27 | 28 | def findspec(B, H, W, G, C): 29 | key = f"{B}x{H}x{W}x{G}x{C}" 30 | if key in TABLE: 31 | return TABLE[key][0], TABLE[key][1] 32 | 33 | d_stride = 8 34 | ms = factors(B*H*W) 35 | multiplier = 1 36 | for m in ms: 37 | if m <= 64 and (m * G * C // d_stride) <= 512: 38 | multiplier = m 39 | n_thread = multiplier * G * C // d_stride 40 | key = f"{B}x{H}x{W}x{G}x{C}" 41 | TABLE[key] = (d_stride, n_thread) 42 | return d_stride, n_thread 43 | 44 | def find_spec_bwd(B, H, W, G, C): 45 | key = f"{B}x{H}x{W}x{G}x{C}" 46 | if key in BWDTABLE: 47 | return BWDTABLE[key][0], BWDTABLE[key][1] 48 | 49 | if C >= 64: 50 | d_stride = 2 51 | else: 52 | d_stride = 1 53 | 54 | ms = factors(B*H*W) 55 | multiplier = 1 56 | for m in ms: 57 | if m <= 64 and (m * G * C // d_stride) <= 256: 58 | multiplier = m 59 | n_thread = multiplier * G * C // d_stride 60 | return d_stride, n_thread 61 | 62 | class DCNv4Function(Function): 63 | @staticmethod 64 | @custom_fwd 65 | def forward( 66 | ctx, input, offset_mask, 67 | kernel_h, kernel_w, stride_h, stride_w, 68 | pad_h, pad_w, dilation_h, dilation_w, 69 | group, group_channels, offset_scale, 70 | im2col_step, remove_center): 71 | 72 | forward_d_stride, forward_block_thread = findspec(input.shape[0], input.shape[1], input.shape[2], group, group_channels) 73 | backward_d_stride, backward_block_thread = find_spec_bwd(input.shape[0], input.shape[1], input.shape[2], group, group_channels) 74 | 75 | ctx.kernel_h = kernel_h 76 | ctx.kernel_w = kernel_w 77 | ctx.stride_h = stride_h 78 | ctx.stride_w = stride_w 79 | ctx.pad_h = pad_h 80 | ctx.pad_w = pad_w 81 | ctx.dilation_h = dilation_h 82 | ctx.dilation_w = dilation_w 83 | ctx.group = group 84 | ctx.group_channels = group_channels 85 | ctx.offset_scale = offset_scale 86 | ctx.im2col_step = im2col_step 87 | ctx.remove_center = remove_center 88 | ctx.backward_d_stride = backward_d_stride 89 | ctx.backward_block_thread = backward_block_thread 90 | 91 | args = [ 92 | input, 
offset_mask, kernel_h, 93 | kernel_w, stride_h, stride_w, pad_h, 94 | pad_w, dilation_h, dilation_w, group, 95 | group_channels, offset_scale, 96 | ctx.im2col_step, 97 | remove_center, 98 | forward_d_stride, 99 | forward_block_thread, 100 | False, 101 | ] 102 | 103 | output = ext.dcnv4_forward(*args) 104 | ctx.save_for_backward(input, offset_mask) 105 | 106 | return output 107 | 108 | @staticmethod 109 | @once_differentiable 110 | @custom_bwd 111 | def backward(ctx, grad_output): 112 | input, offset_mask = ctx.saved_tensors 113 | 114 | args = [ 115 | input, offset_mask, ctx.kernel_h, 116 | ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, 117 | ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, 118 | ctx.group_channels, ctx.offset_scale, ctx.im2col_step, 119 | grad_output.contiguous(), ctx.remove_center, 120 | ctx.backward_d_stride, ctx.backward_block_thread, 121 | False 122 | ] 123 | 124 | grad_input, grad_offset_mask = \ 125 | ext.dcnv4_backward(*args) 126 | 127 | return grad_input, grad_offset_mask, \ 128 | None, None, None, None, None, None, None,\ 129 | None, None, None, None, None, None 130 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/functions/flash_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | import numpy as np 18 | 19 | from DCNv4 import ext 20 | 21 | shm_size_dict = { 22 | "8.0": 163000, 23 | "8.6": 99000, 24 | "8.7": 163000, 25 | "8.9": 99000, 26 | "9.0": 227000, 27 | "7.5": 64000, 28 | "7.0": 96000, 29 | } 30 | 31 | cuda_capability = f"{torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}" 32 | 33 | if cuda_capability not in shm_size_dict: 34 | raise NotImplementedError 35 | 36 | shm_size_cap = shm_size_dict[cuda_capability] 37 | 38 | def factors(N): 39 | res = [] 40 | for i in range(1, N+1): 41 | if N % i == 0: 42 | res.append(i) 43 | return res 44 | 45 | def findspec(B, Q, G, C): 46 | d_stride = 8 47 | ms = factors(B*Q) 48 | multiplier = 1 49 | for m in ms: 50 | if m <= 64 and (m * G * C // d_stride) <= 512: 51 | multiplier = m 52 | n_thread = multiplier * G * C // d_stride 53 | return d_stride, n_thread 54 | 55 | def findspec_bwd(B, Q, G, C): 56 | if C >= 64: 57 | d_stride = 2 58 | else: 59 | d_stride = 1 60 | 61 | ms = factors(B*Q) 62 | multiplier = 1 63 | for m in ms: 64 | if m <= 64 and (m * G * C // d_stride) <= 256: 65 | multiplier = m 66 | n_thread = multiplier * G * C // d_stride 67 | return d_stride, n_thread 68 | 69 | class FlashDeformAttnFunction(Function): 70 | @staticmethod 71 | @torch.autocast("cuda", enabled=True, dtype=torch.float16) 72 | def forward( 73 | 
ctx, value, value_spatial_shapes, value_level_start_index, 74 | sampling_loc_attn, im2col_step, K=8 75 | ): 76 | 77 | ctx.im2col_step = im2col_step 78 | ctx.K = K 79 | d_stride, blockthread = findspec(value.shape[0], sampling_loc_attn.shape[1], value.shape[2], value.shape[3]) 80 | d_stride_backward, blockthread_backward = findspec_bwd(value.shape[0], sampling_loc_attn.shape[1], value.shape[2], value.shape[3]) 81 | 82 | ctx.d_stride_backward = d_stride_backward 83 | ctx.blockthread_backward = blockthread_backward 84 | 85 | output = ext.flash_deform_attn_forward( 86 | value, 87 | value_spatial_shapes, 88 | value_level_start_index, 89 | sampling_loc_attn, 90 | ctx.im2col_step, 91 | K, 92 | d_stride, 93 | blockthread, 94 | ) 95 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_loc_attn) 96 | return output 97 | 98 | @staticmethod 99 | @once_differentiable 100 | def backward(ctx, grad_output): 101 | value, value_spatial_shapes, value_level_start_index, sampling_loc_attn = ctx.saved_tensors 102 | grad_value, grad_sampling_loc_attn = ext.flash_deform_attn_backward( 103 | value, 104 | value_spatial_shapes, 105 | value_level_start_index, 106 | sampling_loc_attn, 107 | grad_output.contiguous(), 108 | ctx.im2col_step, 109 | ctx.K, 110 | ctx.d_stride_backward, 111 | ctx.blockthread_backward, 112 | ) 113 | 114 | return grad_value, None, None, grad_sampling_loc_attn, None, None 115 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .flash_deform_attn import FlashDeformAttn 10 | from .dcnv4 import DCNv4 11 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/modules/dcnv4.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Deformable Convolution v4 3 | # Copyright (c) 2023 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import math 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.nn.init import xavier_uniform_, constant_ 16 | from ..functions import DCNv4Function 17 | 18 | class CenterFeatureScaleModule(nn.Module): 19 | def forward(self, 20 | query, 21 | center_feature_scale_proj_weight, 22 | center_feature_scale_proj_bias): 23 | center_feature_scale = F.linear(query, 24 | weight=center_feature_scale_proj_weight, 25 | bias=center_feature_scale_proj_bias).sigmoid() 26 | return center_feature_scale 27 | 28 | class DCNv4(nn.Module): 29 | def __init__( 30 | self, 31 | channels=64, 32 | kernel_size=3, 33 | stride=1, 34 | pad=1, 35 | dilation=1, 36 | group=4, 37 | offset_scale=1.0, 38 | dw_kernel_size=None, 39 | center_feature_scale=False, 40 | remove_center=False, 41 | output_bias=True, 42 | without_pointwise=False, 43 | **kwargs): 44 | """ 45 | DCNv4 Module 46 | :param channels 47 | :param kernel_size 48 | :param stride 49 | :param pad 50 | :param dilation 51 | :param group 52 | :param offset_scale 53 | :param act_layer 54 | :param norm_layer 55 | """ 56 | super().__init__() 57 | if channels % group != 0: 58 | raise ValueError( 59 | f'channels must be divisible by group, but got {channels} and {group}') 60 | _d_per_group = channels // group 61 | 62 | # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation 63 | assert _d_per_group % 16 == 0 64 | 65 | self.offset_scale = offset_scale 66 | self.channels = channels 67 | self.kernel_size = kernel_size 68 | self.stride = stride 69 | self.dilation = dilation 70 | self.pad = pad 71 | self.group = group 72 | self.group_channels = channels // group 73 | self.offset_scale = offset_scale 74 | self.dw_kernel_size = dw_kernel_size 75 | self.center_feature_scale = center_feature_scale 76 | self.remove_center = int(remove_center) 77 | self.without_pointwise = without_pointwise 78 | 79 | self.K = group * (kernel_size * kernel_size - self.remove_center) 80 | if dw_kernel_size is not None: 81 | self.offset_mask_dw = nn.Conv2d(channels, channels, dw_kernel_size, stride=1, padding=(dw_kernel_size - 1) // 2, groups=channels) 82 | self.offset_mask = nn.Linear(channels, int(math.ceil((self.K * 3)/8)*8)) 83 | if not without_pointwise: 84 | self.value_proj = nn.Linear(channels, channels) 85 | self.output_proj = nn.Linear(channels, channels, bias=output_bias) 86 | self._reset_parameters() 87 | 88 | if center_feature_scale: 89 | 
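            # Editor's note: these parameters form a learnable per-group gate;
            # forward() uses it to blend the deformable aggregation with the
            # value-projected identity path:
            #   x = x * (1 - center_feature_scale) + x_proj * center_feature_scale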
self.center_feature_scale_proj_weight = nn.Parameter( 90 | torch.zeros((group, channels), dtype=torch.float)) 91 | self.center_feature_scale_proj_bias = nn.Parameter( 92 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 93 | self.center_feature_scale_module = CenterFeatureScaleModule() 94 | 95 | def _reset_parameters(self): 96 | constant_(self.offset_mask.weight.data, 0.) 97 | constant_(self.offset_mask.bias.data, 0.) 98 | if not self.without_pointwise: 99 | xavier_uniform_(self.value_proj.weight.data) 100 | constant_(self.value_proj.bias.data, 0.) 101 | xavier_uniform_(self.output_proj.weight.data) 102 | if self.output_proj.bias is not None: 103 | constant_(self.output_proj.bias.data, 0.) 104 | 105 | def forward(self, input, shape=None): 106 | """ 107 | :param query (N, H, W, C) 108 | :return output (N, H, W, C) 109 | """ 110 | N, L, C = input.shape 111 | if shape is not None: 112 | H, W = shape 113 | else: 114 | H, W = int(L**0.5), int(L**0.5) 115 | 116 | 117 | x = input 118 | if not self.without_pointwise: 119 | x = self.value_proj(x) 120 | x = x.reshape(N, H, W, -1) 121 | if self.dw_kernel_size is not None: 122 | offset_mask_input = self.offset_mask_dw(input.view(N, H, W, C).permute(0, 3, 1, 2)) 123 | offset_mask_input = offset_mask_input.permute(0, 2, 3, 1).view(N, L, C) 124 | else: 125 | offset_mask_input = input 126 | offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1) 127 | 128 | x_proj = x 129 | 130 | x = DCNv4Function.apply( 131 | x, offset_mask, 132 | self.kernel_size, self.kernel_size, 133 | self.stride, self.stride, 134 | self.pad, self.pad, 135 | self.dilation, self.dilation, 136 | self.group, self.group_channels, 137 | self.offset_scale, 138 | 256, 139 | self.remove_center 140 | ) 141 | 142 | if self.center_feature_scale: 143 | center_feature_scale = self.center_feature_scale_module( 144 | x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 145 | center_feature_scale = center_feature_scale[..., None].repeat( 146 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 147 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 148 | 149 | x = x.view(N, L, -1) 150 | 151 | if not self.without_pointwise: 152 | x = self.output_proj(x) 153 | return x 154 | 155 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/DCNv4/modules/flash_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import FlashDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n - 1) == 0) and n != 0 28 | 29 | 30 | class FlashDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn( 46 | "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 47 | "which is more efficient in our CUDA implementation." 
48 | ) 49 | 50 | self.im2col_step = 64 51 | 52 | self.d_model = d_model 53 | self.n_levels = n_levels 54 | self.n_heads = n_heads 55 | self.n_points = n_points 56 | 57 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 58 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 59 | self.value_proj = nn.Linear(d_model, d_model) 60 | self.output_proj = nn.Linear(d_model, d_model) 61 | 62 | self._reset_parameters() 63 | 64 | def _reset_parameters(self): 65 | constant_(self.sampling_offsets.weight.data, 0.0) 66 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 67 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 68 | grid_init = ( 69 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) 70 | .view(self.n_heads, 1, 1, 2) 71 | .repeat(1, self.n_levels, self.n_points, 1) 72 | ) 73 | for i in range(self.n_points): 74 | grid_init[:, :, i, :] *= i + 1 75 | with torch.no_grad(): 76 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 77 | constant_(self.attention_weights.weight.data, 0.0) 78 | constant_(self.attention_weights.bias.data, 0.0) 79 | xavier_uniform_(self.value_proj.weight.data) 80 | constant_(self.value_proj.bias.data, 0.0) 81 | xavier_uniform_(self.output_proj.weight.data) 82 | constant_(self.output_proj.bias.data, 0.0) 83 | 84 | def forward( 85 | self, 86 | query, 87 | reference_points, 88 | input_flatten, 89 | input_spatial_shapes, 90 | input_level_start_index, 91 | input_padding_mask=None, 92 | ): 93 | """ 94 | :param query (N, Length_{query}, C) 95 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 96 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 97 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 98 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 99 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 100 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 101 | 102 | :return output (N, Length_{query}, C) 103 | """ 104 | N, Len_q, _ = query.shape 105 | N, Len_in, _ = input_flatten.shape 106 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 107 | 108 | value = self.value_proj(input_flatten) 109 | if input_padding_mask is not None: 110 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 111 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 112 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 113 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 114 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 115 | # N, Len_q, n_heads, n_levels, n_points, 2 116 | if reference_points.shape[-1] == 2: 117 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 118 | sampling_locations = ( 119 | reference_points[:, :, None, :, None, :] 120 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 121 | ) 122 | elif reference_points.shape[-1] == 4: 123 | sampling_locations = ( 124 | reference_points[:, :, None, :, None, :2] 125 | 
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 126 | ) 127 | else: 128 | raise ValueError( 129 | "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) 130 | ) 131 | 132 | output = FlashDeformAttnFunction.apply( 133 | value, 134 | input_spatial_shapes, 135 | input_level_start_index, 136 | sampling_locations, 137 | attention_weights, 138 | self.im2col_step, 139 | self.n_points 140 | ) 141 | output = self.output_proj(output) 142 | return output 143 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/* 2 | include src/cuda/* 3 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/src/ops/DCNv4_op/__init__.py -------------------------------------------------------------------------------- /src/ops/DCNv4_op/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/find_best.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | class LineParser: 4 | def __init__(self) -> None: 5 | self.data = {} 6 | 7 | def parse(self, line): 8 | def startswith(line, lst): 9 | for ele in lst: 10 | if line.startswith(ele): 11 | return True 12 | return False 13 | 14 | if not startswith(line, ['1', '2', '3', '4', '5', '6', '7', '8', '9']): 15 | return 16 | 17 | eles = line.strip().split() 18 | key = eles[0] 19 | if key not in self.data: 20 | self.data[key] = [] 21 | 22 | self.data[key].append([eles[1], float(eles[2])]) 23 | 24 | def sort(self): 25 | for k, v in self.data.items(): 26 | nv = sorted(v, key=lambda x: x[1]) 27 | self.data[k] = nv 28 | 29 | def display_best(self): 30 | for k, v in self.data.items(): 31 | print(f'{k} \t {v[0][0]} \t {v[0][1]:.4f} \t {v[1][0]} \t {v[1][1]:.4f}') 32 | 33 | def display_best_python(self, output): 34 | res = {} 35 | def parse(spec): 36 | d_stride = int(spec.split('/')[0]) 37 | thread = int(spec.split('/')[1].split('(')[0]) 38 | m = int(spec.split('(')[1].split(')')[0]) 39 | return d_stride, thread, m 40 | 41 | for k, v in self.data.items(): 42 | res[k] = parse(v[0][0]) 43 | 44 | with open(output, "w") as f: 45 | json.dump(res, f, indent=4) 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--input', type=str) 50 | parser.add_argument('--output', type=str) 51 | args = parser.parse_args() 52 | 53 | with open(args.input) as 
f: 54 | lines = f.readlines() 55 | 56 | lineparser = LineParser() 57 | for line in lines: 58 | lineparser.parse(line) 59 | lineparser.sort() 60 | lineparser.display_best() 61 | lineparser.display_best_python(args.output) -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/search_bwd.sh: -------------------------------------------------------------------------------- 1 | python search_dcnv4_bwd_engine.py > res_bwd.txt 2 | python find_best.py --input res_bwd.txt --output table_bwd.py -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/search_dcnv4.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import time 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | from torch.autograd import gradcheck 12 | import pandas as pd 13 | from easydict import EasyDict as edict 14 | import argparse 15 | 16 | from torch.cuda import Event 17 | 18 | from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch 19 | from functions.dcnv4_func import DCNv4Function 20 | torch.set_printoptions(threshold=10000) 21 | 22 | torch.manual_seed(3) 23 | 24 | 25 | #@torch.no_grad() 26 | def speed_test(func, args, inputs, name='Unknown'): 27 | 28 | tic = Event(enable_timing=True) 29 | toc = Event(enable_timing=True) 30 | # warmup 31 | for i in range(args.warmup_num): 32 | func(*inputs) 33 | 34 | total_time = 0 35 | tic.record() 36 | for i in range(args.test_num): 37 | o = func(*inputs) 38 | torch.cuda.synchronize() 39 | toc.record() 40 | 41 | avg_time = tic.elapsed_time(toc) / args.test_num 42 | # print( 43 | # f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms') 44 | return avg_time 45 | 46 | @torch.no_grad() 47 | def test(N, H_in, W_in, M, D, spec=None): 48 | Kh, Kw = 3, 3 49 | remove_center = False 50 | P = Kh * Kw - remove_center 51 | offset_scale = 2.0 52 | pad = 1 53 | dilation = 1 54 | stride = 1 55 | H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 56 | W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 57 | 58 | input = torch.rand(N, H_in, W_in, M*D).cuda() 59 | # print(input.shape) 60 | offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*2 61 | # offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0 62 | mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 63 | mask_origin = mask_origin.half() 64 | mask = mask_origin 65 | # mask = torch.nn.functional.softmax(mask_origin, dim=-1) 66 | offset_mask = torch.cat([offset.unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2) 67 | 68 | im2col_step = 128 69 | 70 | input = input.half() 71 | offset = offset.half() 72 | mask = mask.half() 73 | offset_mask = offset_mask.half() 74 | 75 | dcnv3_args = [ 76 | input, 77 | offset, 78 | mask, 79 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 80 | im2col_step, remove_center, 81 | ] 82 | output_pytorch = DCNv3Function.apply(*dcnv3_args) 83 | 84 | input1 = input.detach() 85 | 86 | def pad(om): 87 | padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3] 88 | padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om) 89 | return torch.cat([om, padded], dim=-1) 90 | 91 | dcnv4_args = [ 92 | input1, pad(offset_mask), 93 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, 
dilation, dilation, M, D, offset_scale, 94 | im2col_step, remove_center, 95 | spec[0], spec[1], 2, None 96 | # 8, 512, 2, 256 97 | ] 98 | output_flash_cuda = DCNv4Function.apply(*dcnv4_args) 99 | 100 | fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 101 | max_abs_err = (output_flash_cuda - output_pytorch).abs().max() 102 | max_rel_err = ((output_flash_cuda - output_pytorch).abs() / 103 | (output_pytorch.abs()+ 1e-3)).max() 104 | # print('>>> forward half') 105 | # print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 106 | if not fwdok: 107 | print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})") 108 | return 109 | # assert(fwdok) 110 | 111 | test_args = edict({'warmup_num': 10000, 'test_num': 10000}) 112 | 113 | exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp') 114 | torch.cuda.synchronize() 115 | print(f"{N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]}): {exp_time_dcnv4}") 116 | 117 | 118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument("--n", type=int) 121 | parser.add_argument("--h", type=int) 122 | parser.add_argument("--w", type=int) 123 | parser.add_argument("--g", type=int) 124 | parser.add_argument("--c", type=int) 125 | parser.add_argument("--dstride", type=int) 126 | parser.add_argument("--blockthread", type=int) 127 | parser.add_argument("--multiplier", type=int) 128 | args = parser.parse_args() 129 | test(args.n, args.h, args.w, args.g, args.c, (args.dstride, args.blockthread, args.multiplier)) 130 | 131 | 132 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/search_dcnv4_bwd_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def factors(N): 4 | res = [] 5 | for i in range(1, N+1): 6 | if N % i == 0: 7 | res.append(i) 8 | return res 9 | 10 | if __name__ == '__main__': 11 | BATCH=64 12 | for N, Hin, Win in [(BATCH, 56, 56), (BATCH, 28, 28), (BATCH, 14, 14), (BATCH, 7, 7), 13 | (1, 200, 320), (1, 100, 160), (1, 50, 80), (1, 25, 40), (1, 64, 64)]: 14 | for group_channel in [16, 32, 64]: 15 | for group in [4, 5, 6, 7, 8]: 16 | for d_stride in [1, 2, 4]: 17 | for m in factors(N*Hin*Win): 18 | if m > 64: 19 | break 20 | block_thread = group * (group_channel//d_stride) * m 21 | if block_thread > 1024: 22 | break 23 | cmd = f"python search_dcnv4_bwd.py --n {N} --h {Hin} --w {Win} --g {group} --c {group_channel} --dstride {d_stride} --blockthread {block_thread} --multiplier {m}" 24 | os.system(cmd) -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/search_dcnv4_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def factors(N): 4 | res = [] 5 | for i in range(1, N+1): 6 | if N % i == 0: 7 | res.append(i) 8 | return res 9 | 10 | if __name__ == '__main__': 11 | BATCH=64 12 | for group_channel in [16, 32, 64]: 13 | for group in [4, 5, 6, 7, 8]: 14 | for N, Hin, Win in [(BATCH, 56, 56), (BATCH, 28, 28), (BATCH, 14, 14), (BATCH, 7, 7), 15 | (1, 200, 320), (1, 100, 160), (1, 50, 80), (1, 25, 40), (1, 64, 64)]: 16 | for d_stride in [2, 4, 8, 16]: 17 | for m in factors(N*Hin*Win): 18 | if m > 64: 19 | break 20 | block_thread = group * (group_channel//d_stride) * m 21 | if block_thread > 1024: 22 | break 23 | cmd = f"python search_dcnv4.py --n {N} --h 
{Hin} --w {Win} --g {group} --c {group_channel} --dstride {d_stride} --blockthread {block_thread} --multiplier {m}" 24 | os.system(cmd) -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/search_fwd.sh: -------------------------------------------------------------------------------- 1 | python search_dcnv4_engine.py > res.txt 2 | python find_best.py --input res.txt --output table.py -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/test_dcnv4.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # DCNv4 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import time 12 | import torch 13 | import torch.nn as nn 14 | import math 15 | from torch.autograd import gradcheck 16 | import pandas as pd 17 | from easydict import EasyDict as edict 18 | 19 | from torch.cuda import Event 20 | 21 | from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch 22 | from functions.dcnv4_func import DCNv4Function 23 | torch.set_printoptions(threshold=10000) 24 | 25 | H_in, W_in = 56, 56 26 | N, M, D = 64, 4, 32 27 | 28 | # H_in, W_in = 28, 28 29 | # N, M, D = 64, 8, 32 30 | 31 | # H_in, W_in = 14, 14 32 | # N, M, D = 64, 16, 32 33 | 34 | # H_in, W_in = 7, 7 35 | # N, M, D = 64, 32, 32 36 | 37 | # H_in, W_in = 8, 8 38 | # N, M, D = 128, 4, 16 39 | 40 | 41 | Kh, Kw = 3, 3 42 | remove_center = False 43 | P = Kh * Kw - remove_center 44 | offset_scale = 2.0 45 | pad = 1 46 | dilation = 1 47 | stride = 1 48 | H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 49 | W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 50 | 51 | torch.manual_seed(3) 52 | 53 | #@torch.no_grad() 54 | def speed_test(func, args, inputs, name='Unknown'): 55 | 56 | tic = Event(enable_timing=True) 57 | toc = Event(enable_timing=True) 58 | # warmup 59 | for i in range(args.warmup_num): 60 | func(*inputs) 61 | 62 | total_time = 0 63 | tic.record() 64 | for i in range(args.test_num): 65 | o = func(*inputs) 66 | torch.cuda.synchronize() 67 | toc.record() 68 | 69 | avg_time = tic.elapsed_time(toc) / args.test_num 70 | print( 71 | f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms') 72 | return avg_time 73 | 74 | @torch.no_grad() 75 | def check_forward_equal_with_pytorch_half(): 76 | input = torch.rand(N, H_in, W_in, M*D).cuda() 77 | print(input.shape) 78 | offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*10 79 | # offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0 80 | mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 81 | mask_origin = mask_origin.half() 82 | mask = mask_origin 83 | # mask = torch.nn.functional.softmax(mask_origin, dim=-1) 84 | offset_mask = torch.cat([offset.unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2) 85 | 86 | im2col_step = 128 87 | 88 | input = input.half() 89 | offset = offset.half() 90 | mask = mask.half() 91 | offset_mask = offset_mask.half() 92 | 93 | dcnv3_args = [ 94 | input, 95 | offset, 96 | mask, 97 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 98 | im2col_step, remove_center, 99 | ] 100 | output_pytorch = 
DCNv3Function.apply(*dcnv3_args) 101 | 102 | input1 = input.detach() 103 | 104 | def pad(om): 105 | padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3] 106 | padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om) 107 | return torch.cat([om, padded], dim=-1) 108 | 109 | dcnv4_args = [ 110 | input1, pad(offset_mask), 111 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 112 | im2col_step, remove_center, 8, 512, 2, 256, True, True, 113 | ] 114 | output_flash_cuda = DCNv4Function.apply(*dcnv4_args) 115 | 116 | fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 117 | max_abs_err = (output_flash_cuda - output_pytorch).abs().max() 118 | max_rel_err = ((output_flash_cuda - output_pytorch).abs() / 119 | (output_pytorch.abs()+ 1e-3)).max() 120 | print('>>> forward half') 121 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 122 | assert(fwdok) 123 | 124 | test_args = edict({'warmup_num': 1000, 'test_num': 1000}) 125 | 126 | exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp') 127 | exp_time_dcnv3 = speed_test(DCNv3Function.apply, test_args, dcnv3_args, name='exp') 128 | torch.cuda.synchronize() 129 | 130 | results = [{}] 131 | results[0]['dcnv3_time'] = exp_time_dcnv3 132 | results[0]['dcnv4_time'] = exp_time_dcnv4 133 | columns = list(results[0].keys()) 134 | 135 | outputs = pd.DataFrame(results, columns=columns) 136 | with pd.option_context( 137 | 'display.max_rows', None, 'display.max_columns', None, 138 | 'display.max_colwidth', None, 'display.width', None, 139 | 'display.precision', 4, ): 140 | print(outputs) 141 | 142 | 143 | if __name__ == '__main__': 144 | check_forward_equal_with_pytorch_half() 145 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/test_flash_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | from easydict import EasyDict as edict 13 | from torch.cuda import Event 14 | import pandas as pd 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions import MSDeformAttnFunction, FlashDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | # N, M, D = 1, 4, 8 25 | # # Lq, L, P = 2, 2, 2 26 | # # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | # Lq, L, P = 1, 2, 8 28 | # shapes = torch.as_tensor([(8, 16), (4, 8)], dtype=torch.long).cuda() 29 | 30 | # N, M, D = 1, 8, 32 31 | # # Lq, L, P = 2, 2, 2 32 | # # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 33 | # Lq, L, P = 300, 4, 4 34 | # # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda() 35 | # # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (16, 16)], dtype=torch.long).cuda() 36 | # # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda() 37 | # # shapes = torch.as_tensor([(17, 19), (4, 4)], dtype=torch.long).cuda() 38 | # shapes = torch.as_tensor([(100, 151), (50, 76), (25, 38), (13, 19)], dtype=torch.long).cuda() 39 | # # shapes = torch.as_tensor([(110, 151)], dtype=torch.long).cuda() 40 | 41 | # B:6 42 | # H:232 43 | # W:400 44 | # G:5 45 | # D: 16 46 | # channels: 80 47 | # kernel: 3 points = 3 * 3 48 | # num_split = 45 = kernel *kernel * G 49 | 50 | H = 256 51 | W = 256 52 | N, M, D = 1, 8, 32 53 | Lq, L, P = 100*152, 4, 8 54 | 55 | shapes = torch.Tensor([[100, 152], [ 50, 76], [ 25, 38], [ 13, 19]]).long().cuda() 56 | 57 | # x = x.reshape([B, H*W, G, D + self.num_split * 3]) 58 | # shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda() 59 | # shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda() 60 | # shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda() 61 | 62 | level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) 63 | S = sum([(H * W).item() for H, W in shapes]) 64 | print(S) 65 | 66 | def get_reference_points(spatial_shapes, device): 67 | reference_points_list = [] 68 | for lvl, (H_, W_) in enumerate(spatial_shapes): 69 | 70 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 71 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 72 | ref_y = ref_y.reshape(-1)[None] / (H_) 73 | ref_x = ref_x.reshape(-1)[None] / (W_) 74 | ref = torch.stack((ref_x, ref_y), -1) 75 | reference_points_list.append(ref) 76 | reference_points = torch.cat(reference_points_list, 1) 77 | # reference_points = reference_points[:, :, None] * valid_ratios[:, None] 78 | return reference_points 79 | 80 | 81 | torch.manual_seed(3) 82 | 83 | @torch.no_grad() 84 | def speed_test(func, args, inputs, name='Unknown'): 85 | 86 | tic = Event(enable_timing=True) 87 | toc = Event(enable_timing=True) 88 | # warmup 89 | for i in 
range(args.warmup_num): 90 | func(*inputs) 91 | 92 | tic.record() 93 | for i in range(args.test_num): 94 | func(*inputs) 95 | toc.record() 96 | torch.cuda.synchronize() 97 | 98 | avg_time = tic.elapsed_time(toc) / args.test_num 99 | print( 100 | f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms') 101 | return avg_time 102 | 103 | 104 | @torch.no_grad() 105 | def check_forward_equal_with_pytorch_half(): 106 | value = torch.rand(N, S, M, D).cuda() * 0.01 107 | # offset = (torch.rand(N, Lq, M, L, P, 2).cuda() * 2 - 1) / 10 108 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 109 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 110 | sampling_loc_attn = torch.cat([sampling_locations.reshape(N, Lq, M, L*P*2), attention_weights.reshape(N, Lq, M, L*P)], dim=-1) 111 | attention_weights = torch.nn.functional.softmax(attention_weights.flatten(-2, -1), dim=-1).unflatten(-1, (L, P)) 112 | 113 | 114 | im2col_step = 128 115 | 116 | flash_fn_args = ( 117 | value.half(), 118 | shapes, 119 | level_start_index, 120 | sampling_loc_attn.half(), 121 | im2col_step, 122 | P, 16 123 | ) 124 | output_cuda = ( 125 | FlashDeformAttnFunction.apply(*flash_fn_args) 126 | .detach() 127 | .cpu() 128 | ).double() 129 | 130 | fn_args = ( 131 | value, 132 | shapes, 133 | level_start_index, 134 | sampling_locations, 135 | attention_weights, 136 | im2col_step, 137 | ) 138 | 139 | output_pytorch = ( 140 | MSDeformAttnFunction.apply(*fn_args) 141 | .detach().double() 142 | .cpu() 143 | ) 144 | 145 | max_abs_err = (output_cuda - output_pytorch).abs().max() 146 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 147 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 148 | 149 | print( 150 | f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" 151 | ) 152 | 153 | test_args = edict({'warmup_num': 1000, 'test_num': 1000}) 154 | exp_time_base = speed_test( 155 | MSDeformAttnFunction.apply, test_args, fn_args, name='exp') 156 | exp_time = speed_test( 157 | FlashDeformAttnFunction.apply, test_args, flash_fn_args, name='exp') 158 | 159 | results = [{}] 160 | results[0]['time'] = exp_time 161 | results[0]['time_base'] = exp_time_base 162 | columns = list(results[0].keys()) 163 | 164 | outputs = pd.DataFrame(results, columns=columns) 165 | with pd.option_context( 166 | 'display.max_rows', None, 'display.max_columns', None, 167 | 'display.max_colwidth', None, 'display.width', None, 168 | 'display.precision', 4, ): 169 | print(outputs) 170 | 171 | 172 | if __name__ == "__main__": 173 | check_forward_equal_with_pytorch_half() 174 | 175 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/scripts/test_flash_deform_attn_backward.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | from easydict import EasyDict as edict 13 | from torch.cuda import Event 14 | import pandas as pd 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch, FlashDeformAttnFunction 22 | 23 | 24 | H = 256 25 | W = 256 26 | N, M, D = 1, 8, 16 27 | Lq, L, P = H * W, 1, 8 28 | 29 | # x = x.reshape([B, H*W, G, D + self.num_split * 3]) 30 | shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda() 31 | # shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda() 32 | # shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda() 33 | 34 | H = 256 35 | W = 256 36 | N, M, D = 1, 8, 32 37 | Lq, L, P = 100*152, 4, 8 38 | 39 | shapes = torch.Tensor([[100, 152], [ 50, 76], [ 25, 38], [ 13, 19]]).long().cuda() 40 | 41 | level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) 42 | S = sum([(H * W).item() for H, W in shapes]) 43 | 44 | def get_reference_points(spatial_shapes, device): 45 | reference_points_list = [] 46 | for lvl, (H_, W_) in enumerate(spatial_shapes): 47 | 48 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 49 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 50 | ref_y = ref_y.reshape(-1)[None] / (H_) 51 | ref_x = ref_x.reshape(-1)[None] / (W_) 52 | ref = torch.stack((ref_x, ref_y), -1) 53 | reference_points_list.append(ref) 54 | reference_points = torch.cat(reference_points_list, 1) 55 | # reference_points = reference_points[:, :, None] * valid_ratios[:, None] 56 | return reference_points 57 | 58 | 59 | torch.manual_seed(3) 60 | 61 | @torch.no_grad() 62 | def speed_test(func, args, inputs, name='Unknown'): 63 | 64 | tic = Event(enable_timing=True) 65 | toc = Event(enable_timing=True) 66 | # warmup 67 | for i in range(args.warmup_num): 68 | func(*inputs) 69 | 70 | tic.record() 71 | for i in range(args.test_num): 72 | func(*inputs) 73 | toc.record() 74 | torch.cuda.synchronize() 75 | 76 | avg_time = tic.elapsed_time(toc) / args.test_num 77 | print( 78 | f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms') 79 | return avg_time 80 | 81 | 82 | def check_forward_equal_with_pytorch_half(): 83 | value = torch.rand(N, S, M, D).cuda() * 0.01 84 | offset = (torch.rand(N, Lq, M, L, P, 2).cuda() * 2 - 1) / 10 85 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 86 | attention_weights_origin = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 87 | attention_weights_origin.requires_grad = True 88 | sampling_loc_attn = torch.cat([sampling_locations.detach().reshape(N, Lq, M, L*P*2), attention_weights_origin.detach().reshape(N, Lq, M, L*P)], dim=-1) 89 | 90 | attention_weights = torch.nn.functional.softmax(attention_weights_origin.flatten(-2, -1), dim=-1).unflatten(-1, (L, P)) 91 | 92 | 93 | im2col_step = 128 94 | 95 | value.requires_grad = True 96 | 
sampling_loc_attn.requires_grad = True 97 | output_cuda = ( 98 | FlashDeformAttnFunction.apply( 99 | value.float(), 100 | shapes, 101 | level_start_index, 102 | sampling_loc_attn.float(), 103 | im2col_step, 104 | ) 105 | ) 106 | (output_cuda.float().sum()/10).backward() 107 | 108 | 109 | value1 = value.detach() 110 | value1.requires_grad = True 111 | sampling_locations.requires_grad = True 112 | #attention_weights.requires_grad = True 113 | output_pytorch = ( 114 | ms_deform_attn_core_pytorch(value1, shapes, sampling_locations, attention_weights) 115 | ) 116 | (output_pytorch.sum()/10).backward() 117 | 118 | max_abs_err = (output_cuda.float() - output_pytorch).abs().max() 119 | max_rel_err = ((output_cuda.float() - output_pytorch).abs() / output_pytorch.abs()).max() 120 | fwdok = torch.allclose(output_cuda.float(), output_pytorch, rtol=1e-2, atol=1e-3) 121 | print(fwdok) 122 | print(max_abs_err, max_rel_err) 123 | #exit() 124 | 125 | bwdok1 = torch.allclose(value.grad, value1.grad, rtol=1e-2, atol=1e-3) 126 | print(bwdok1) 127 | # rel_err = (sampling_locations.grad - sampling_loc_attn.grad[..., :L*P*2].reshape(*sampling_locations.shape)).abs()/(sampling_locations.grad.abs()+1e-3) 128 | # print(rel_err.max()) 129 | 130 | locgrad1 = sampling_locations.grad 131 | locgrad2 = sampling_loc_attn.grad[..., :L*P*2].reshape(*sampling_locations.shape) 132 | bwdok2 = torch.allclose(locgrad1, locgrad2, rtol=1e-2, atol=1e-3) 133 | print(bwdok2) 134 | rel_err = (locgrad1 - locgrad2).abs()/(locgrad1.abs()+1e-3) 135 | print(rel_err.max()) 136 | 137 | attngrad1 = attention_weights_origin.grad 138 | attngrad2 = sampling_loc_attn.grad[..., L*P*2:].reshape(*attention_weights_origin.shape) 139 | bwdok3 = torch.allclose(attngrad1, attngrad2, rtol=1e-2, atol=1e-3) 140 | print(bwdok3) 141 | rel_err = (attngrad1 - attngrad2).abs()/(attngrad1.abs()+1e-3) 142 | print(rel_err.max()) 143 | exit() 144 | #exit() 145 | 146 | # pdb.set_trace() 147 | max_abs_err = (output_cuda - output_pytorch).abs().max() 148 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 149 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 150 | 151 | print( 152 | f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" 153 | ) 154 | 155 | 156 | fn_args = ( 157 | value, 158 | shapes, 159 | level_start_index, 160 | sampling_locations, 161 | attention_weights, 162 | im2col_step, 163 | ) 164 | 165 | flash_dcn_fn_args = ( 166 | value.half(), 167 | shapes, 168 | level_start_index, 169 | sampling_loc_attn.half(), 170 | im2col_step, 171 | ) 172 | 173 | 174 | test_args = edict({'warmup_num': 50, 'test_num': 100}) 175 | exp_time = speed_test( 176 | FlashDeformAttnFunction.apply, test_args, flash_dcn_fn_args, name='exp') 177 | exp_time_base = speed_test( 178 | MSDeformAttnFunction.apply, test_args, fn_args, name='exp') 179 | 180 | results = [{}] 181 | results[0]['time'] = exp_time 182 | results[0]['time_base'] = exp_time_base 183 | columns = list(results[0].keys()) 184 | 185 | outputs = pd.DataFrame(results, columns=columns) 186 | with pd.option_context( 187 | 'display.max_rows', None, 'display.max_columns', None, 188 | 'display.max_colwidth', None, 'display.width', None, 189 | 'display.precision', 4, ): 190 | print(outputs) 191 | 192 | 193 | if __name__ == "__main__": 194 | check_forward_equal_with_pytorch_half() -------------------------------------------------------------------------------- /src/ops/DCNv4_op/setup.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable Convolution v4 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | "-O3", 46 | ] 47 | else: 48 | raise NotImplementedError('Cuda is not available') 49 | 50 | sources = [os.path.join(extensions_dir, s) for s in sources] 51 | include_dirs = [extensions_dir] 52 | ext_modules = [ 53 | extension( 54 | "DCNv4.ext", 55 | sources, 56 | include_dirs=include_dirs, 57 | define_macros=define_macros, 58 | extra_compile_args=extra_compile_args, 59 | ) 60 | ] 61 | return ext_modules 62 | 63 | setup( 64 | name="DCNv4", 65 | version="1.0.0.post2", 66 | author="Yuwen Xiong, Feng Wang", 67 | url="", 68 | description="PyTorch Wrapper for CUDA Functions of DCNv4", 69 | packages=['DCNv4', 'DCNv4/functions', 'DCNv4/modules'], 70 | ext_modules=get_extensions(), 71 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 72 | ) 73 | -------------------------------------------------------------------------------- /src/ops/DCNv4_op/src/cuda/dcnv4_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #pragma once 13 | #include 14 | 15 | at::Tensor dcnv4_cuda_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &p_offset, 18 | const int kernel_h, const int kernel_w, const int stride_h, 19 | const int stride_w, const int pad_h, const int pad_w, const int dilation_h, 20 | const int dilation_w, const int group, const int group_channels, 21 | const float offset_scale, const int im2col_step, const int remove_center, 22 | const int d_stride, const int block_thread, const bool softmax); 23 | 24 | std::vector 25 | dcnv4_cuda_backward( 26 | const at::Tensor &value, 27 | const at::Tensor &p_offset, 28 | const int kernel_h, const int kernel_w, const int stride_h, 29 | const int stride_w, const int pad_h, const int pad_w, const int dilation_h, 30 | const int dilation_w, const int group, const int group_channels, 31 | const float offset_scale, const int im2col_step, const at::Tensor &grad_output, 32 | const int remove_center, const int d_stride, const int block_thread, 33 | const bool softmax); -------------------------------------------------------------------------------- /src/ops/DCNv4_op/src/cuda/flash_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #pragma once 13 | #include 14 | 15 | at::Tensor flash_deform_attn_cuda_forward( 16 | const at::Tensor &value, const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, const at::Tensor &sampling_loc_attn, 18 | const int im2col_step, const int K, const int d_stride, const int block_thread); 19 | 20 | std::vector 21 | flash_deform_attn_cuda_backward( 22 | const at::Tensor &value, const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, const at::Tensor &sampling_loc_attn, 24 | const at::Tensor &grad_output, const int im2col_step, const int K, 25 | const int d_stride, const int block_thread); -------------------------------------------------------------------------------- /src/ops/DCNv4_op/src/dcnv4.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #pragma once 13 | 14 | #ifdef WITH_CUDA 15 | #include "cuda/dcnv4_cuda.h" 16 | #include "cuda/flash_deform_attn_cuda.h" 17 | #endif 18 | 19 | at::Tensor flash_deform_attn_forward(const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc_attn, 23 | const int im2col_step, const int K, 24 | const int d_stride, const int block_thread) { 25 | if (value.device().is_cuda()) { 26 | #ifdef WITH_CUDA 27 | return flash_deform_attn_cuda_forward(value, spatial_shapes, 28 | level_start_index, 29 | sampling_loc_attn, im2col_step, 30 | K, d_stride, block_thread); 31 | #else 32 | AT_ERROR("Not compiled with GPU support"); 33 | #endif 34 | } 35 | AT_ERROR("Not implemented on the CPU"); 36 | } 37 | 38 | std::vector 39 | flash_deform_attn_backward(const at::Tensor &value, 40 | const at::Tensor &spatial_shapes, 41 | const at::Tensor &level_start_index, 42 | const at::Tensor &sampling_loc_attn, 43 | const at::Tensor &grad_output, 44 | const int im2col_step, 45 | const int K, 46 | const int d_stride, const int block_thread){ 47 | if (value.device().is_cuda()) { 48 | #ifdef WITH_CUDA 49 | return flash_deform_attn_cuda_backward(value, 50 | spatial_shapes, 51 | level_start_index, 52 | sampling_loc_attn, 53 | grad_output, 54 | im2col_step, 55 | K, d_stride, 56 | block_thread); 57 | #else 58 | AT_ERROR("Not compiled with GPU support"); 59 | #endif 60 | } 61 | AT_ERROR("Not implemented on the CPU"); 62 | } 63 | 64 | at::Tensor dcnv4_forward( 65 | const at::Tensor &value, 66 | const at::Tensor &p_offset, 67 | const int kernel_h, const int kernel_w, const int stride_h, 68 | const int stride_w, const int pad_h, const int pad_w, const int dilation_h, 69 | const int dilation_w, const int group, const int group_channels, 70 | const float offset_scale, const int im2col_step, const int remove_center, 71 | const int d_stride, const int block_thread, const bool softmax) { 72 | if (value.device().is_cuda()) { 73 | #ifdef WITH_CUDA 74 | return dcnv4_cuda_forward( 75 | value, p_offset, kernel_h, kernel_w, stride_h, stride_w, pad_h, 76 | pad_w, dilation_h, dilation_w, group, group_channels, offset_scale, 77 | im2col_step, remove_center, d_stride, block_thread, softmax); 78 | #else 79 | AT_ERROR("Not compiled with GPU support"); 80 | #endif 81 | } 82 | AT_ERROR("Not implemented on the CPU"); 83 | } 84 | 85 | std::vector 86 | dcnv4_backward( 87 | const at::Tensor &value, 88 | const at::Tensor &p_offset, 89 | const int kernel_h, const int kernel_w, const int stride_h, 90 | const int stride_w, const int pad_h, const int pad_w, const int dilation_h, 91 | const int dilation_w, const int group, const int group_channels, 92 | const float offset_scale, const int im2col_step, const at::Tensor &grad_output, 93 | const int remove_center, const int d_stride, const int block_thread, 94 | const bool softmax){ 95 | if (value.device().is_cuda()) { 96 | #ifdef WITH_CUDA 97 | return dcnv4_cuda_backward( 98 | value, p_offset, kernel_h, kernel_w, stride_h, stride_w, pad_h, 99 | pad_w, dilation_h, dilation_w, group, group_channels, offset_scale, 100 | im2col_step, 
grad_output, remove_center, d_stride, block_thread, 101 | softmax); 102 | #else 103 | AT_ERROR("Not compiled with GPU support"); 104 | #endif 105 | } 106 | AT_ERROR("Not implemented on the CPU"); 107 | } -------------------------------------------------------------------------------- /src/ops/DCNv4_op/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #include "dcnv4.h" 13 | 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("flash_deform_attn_forward", &flash_deform_attn_forward, 16 | "flash_deform_attn_forward"); 17 | m.def("flash_deform_attn_backward", &flash_deform_attn_backward, 18 | "flash_deform_attn_backward"); 19 | m.def("dcnv4_forward", &dcnv4_forward, "dcnv4_forward"); 20 | m.def("dcnv4_backward", &dcnv4_backward, "dcnv4_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /src/ops/cuda_kernels/forward.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.autotune( 5 | configs=[ 6 | # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1), 7 | triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), 8 | triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), 9 | ], 10 | key=['B', 'H', 'W', 'G', 'C', 'K'], 11 | ) 12 | @triton.jit 13 | def forward_kernel( 14 | B: tl.constexpr, 15 | H: tl.constexpr, # image_size_h 16 | W: tl.constexpr, # image_size_w 17 | G: tl.constexpr, # num_channels_per_group 18 | C: tl.constexpr, # num_groups 19 | K: tl.constexpr, # kernel size 20 | input_ptr, # input features [B, H, W, G, C] 21 | deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] 22 | weights_ptr, # weights [B, H, W, G, K] 23 | out_ptr, # out [B, H, W, G, C] 24 | BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group 25 | ): 26 | pid = tl.program_id(0) 27 | wid = pid % W 28 | hid = pid // W % H 29 | gid = pid // (W * H) % G 30 | bid = pid // (W * H * G) 31 | 32 | id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) 33 | common_offset = bid*H*W*G + hid*W*G + wid*G + gid 34 | batch_base = bid * H * W * G * C 35 | 36 | for block_base in tl.static_range(0, C, BLOCK_SIZE): 37 | buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) 38 | block_offset = tl.arange(0, BLOCK_SIZE) + block_base 39 | block_mask = (block_offset < C) & id_mask 40 | for k in tl.static_range(K): 41 | deformable_offset = (common_offset * K + k) * 2 42 | 43 | x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid 44 | y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid 45 | 46 | floor_x = x.to(tl.int32) 47 | floor_y = y.to(tl.int32) 48 | ceil_x = floor_x + 1 49 | ceil_y = floor_y + 1 50 | 51 | # load top left 52 | tl_weight = (ceil_x - x) * (ceil_y - y) 53 | tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * 
BLOCK_SIZE 54 | tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) 55 | 56 | # load top right 57 | tr_weight = (x - floor_x) * (ceil_y - y) 58 | tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE 59 | tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) 60 | # load bottom left 61 | bl_weight = (ceil_x - x) * (y - floor_y) 62 | bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE 63 | bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) 64 | # load bottom right 65 | br_weight = (x - floor_x) * (y - floor_y) 66 | br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE 67 | br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) 68 | 69 | # load dynamic weight and mask 70 | weights_offset = common_offset*K + k 71 | weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) 72 | 73 | 74 | 75 | tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) 76 | tl_block_input = tl_block_input * tl_weight 77 | 78 | # load top right 79 | tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) 80 | tr_block_input = tr_block_input * tr_weight 81 | # load bottom left 82 | bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) 83 | bl_block_input = bl_block_input * bl_weight 84 | # load bottom right 85 | br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) 86 | br_block_input = br_block_input * br_weight 87 | 88 | # sampled 89 | sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input 90 | 91 | weighted_sampled_input = sampled_input * weight 92 | buffer = buffer + weighted_sampled_input 93 | # store to out_ptr 94 | tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) 95 | 96 | -------------------------------------------------------------------------------- /src/ops/cuda_kernels/function.py: -------------------------------------------------------------------------------- 1 | import time 2 | import backward 3 | 4 | import math 5 | import torch 6 | from typing import Any 7 | from torch.autograd import Function 8 | from torch.autograd.function import once_differentiable 9 | from torch.cuda.amp import custom_fwd, custom_bwd 10 | from .forward import forward_kernel 11 | 12 | 13 | class DeformableAttentionFunction(Function): 14 | BP_FUNCS = [ 15 | backward.backward_p1_c2_tile16_thread128, 16 | backward.backward_p1_c4_tile16_thread128, 17 | backward.backward_p2_c2_tile16_thread128, 18 | backward.backward_p1_c2_tile16_thread256, 19 | backward.backward_p1_c4_tile16_thread256, 20 | backward.backward_p2_c2_tile16_thread256, 21 | backward.backward_p1_c2_tile16_thread384, 22 | backward.backward_p1_c4_tile16_thread384, 23 | backward.backward_p2_c2_tile16_thread384, 24 | backward.backward_p1_c2_tile16_thread512, 25 | backward.backward_p1_c4_tile16_thread512, 26 | backward.backward_p2_c2_tile16_thread512, 27 | backward.backward_p1_c2_tile16_thread768, 28 | backward.backward_p1_c4_tile16_thread768, 29 | backward.backward_p2_c2_tile16_thread768, 30 | backward.backward_p1_c2_tile32_thread128, 31 | backward.backward_p1_c2_tile32_thread256, 32 | backward.backward_p1_c2_tile32_thread384, 33 | 
backward.backward_p1_c2_tile32_thread512, 34 | ] 35 | BP_TABLES = dict() 36 | 37 | 38 | @staticmethod 39 | @custom_fwd(cast_inputs=torch.float16) 40 | def forward(ctx: Any, inputs, deformables, weights) -> Any: 41 | B, H, W, G, C = inputs.shape 42 | _, _, _, _, K = weights.shape 43 | out = torch.zeros_like(inputs) 44 | grid = lambda META: (B * H * W * G,) 45 | forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out) 46 | ctx.save_for_backward(inputs, deformables, weights) 47 | return out 48 | @staticmethod 49 | def find_bp_funcs(values, deformables, weights, grad_out): 50 | B, H, W, G, C = values.shape 51 | B, H, W, G, K = weights.shape 52 | hash_value = 10000 * B + 100 * H + W + 1000 * G 53 | if hash_value in DeformableAttentionFunction.BP_TABLES.keys(): 54 | return DeformableAttentionFunction.BP_TABLES[hash_value] 55 | print("missing") 56 | candicate_func = None 57 | min_t = 999.0 58 | grad_values = torch.zeros_like(values) 59 | grad_deformables = torch.zeros_like(deformables) 60 | grad_weights = torch.zeros_like(weights) 61 | for func in DeformableAttentionFunction.BP_FUNCS: 62 | t = [] 63 | for i in range(100): 64 | torch.cuda.synchronize() 65 | start_t = time.time() 66 | func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) 67 | torch.cuda.synchronize() 68 | t.append(time.time() - start_t) 69 | t = t[-50:] 70 | t = sum(t) / len(t) 71 | if t < min_t: 72 | min_t = t 73 | DeformableAttentionFunction.BP_TABLES[hash_value] = func 74 | candicate_func = func 75 | assert candicate_func is not None 76 | print(candicate_func) 77 | return candicate_func 78 | 79 | @staticmethod 80 | @once_differentiable 81 | @custom_bwd 82 | def backward(ctx: Any, *grad_outputs: Any) -> Any: 83 | grad_out = grad_outputs[0] 84 | values, deformables, weights = ctx.saved_tensors 85 | B, H, W, G, C = values.shape 86 | B, H, W, G, K = weights.shape 87 | func = DeformableAttentionFunction.find_bp_funcs(values, deformables, weights, grad_out) 88 | grad_values = torch.zeros_like(values) 89 | grad_deformables = torch.zeros_like(deformables) 90 | grad_weights = torch.zeros_like(weights) 91 | func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) 92 | # grad_values = torch.nan_to_num(grad_values, nan=0.0, posinf=0.0, neginf=0.0) 93 | # grad_deformables = torch.nan_to_num(grad_deformables, nan=0.0, posinf=0.0, neginf=0.0) 94 | # grad_weights = torch.nan_to_num(grad_weights, nan=0.0, posinf=0.0, neginf=0.0) 95 | return grad_values, grad_deformables, grad_weights -------------------------------------------------------------------------------- /src/ops/cuda_kernels/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='mycuda', 6 | ext_modules=[ 7 | CUDAExtension('mycuda', ['./backward.cu',], 8 | extra_compile_args={'cxx': [], 'nvcc': [ 9 | "-O3", 10 | "-DCUDA_HAS_FP16=1", 11 | "-D__CUDA_NO_HALF_OPERATORS__", 12 | "-D__CUDA_NO_HALF_CONVERSIONS__", 13 | "-D__CUDA_NO_HALF2_OPERATORS__", 14 | ]} 15 | ), 16 | ], 17 | cmdclass={ 18 | 'build_ext': BuildExtension 19 | } 20 | ) -------------------------------------------------------------------------------- /src/ops/triton_kernels/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/src/ops/triton_kernels/__init__.py -------------------------------------------------------------------------------- /src/ops/triton_kernels/backward.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.autotune( 5 | configs=[ 6 | triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1), 7 | triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), 8 | ], 9 | key=['B', 'H', 'W', 'G', 'C', 'K'], 10 | ) 11 | @triton.jit 12 | def backward_kernel( 13 | B: tl.constexpr, 14 | H: tl.constexpr, # image_size_h 15 | W: tl.constexpr, # image_size_w 16 | G: tl.constexpr, # num_groups 17 | C: tl.constexpr, # num_channels_per_group 18 | K: tl.constexpr, # kernel size 19 | input_ptr, # input features [B, H, W, G, C] 20 | deformable_ptr, # deformable offsets [B, H, W, G, K, 2] 21 | weights_ptr, # weights [B, H, W, G, K] 22 | grad_ptr, # out [B, H, W, G, C] 23 | grad_input_ptr, # input features [B, H, W, G, C] 24 | grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2] 25 | grad_weights_ptr, # weights [B, H, W, G, K] 26 | BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group 27 | ): 28 | 29 | pid = tl.program_id(0) 30 | wid = pid % W 31 | hid = pid // W % H 32 | gid = pid // (W * H) % G 33 | bid = pid // (W * H * G) 34 | 35 | id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) 36 | 37 | common_offset = bid*H*W*G + hid*W*G + wid*G + gid 38 | batch_base = bid * H * W * G * C 39 | for k in tl.static_range(K): 40 | # load dynamic weight and mask 41 | weights_offset = common_offset*K + k 42 | weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) 43 | dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) 44 | dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) 45 | dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty) 46 | deformable_offset = (common_offset * K + k)*2 47 | x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid 48 | y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid 49 | for block_base in tl.static_range(0, C, BLOCK_SIZE): 50 | block_offset = tl.arange(0, BLOCK_SIZE) + block_base 51 | block_mask = (block_offset < C) & id_mask 52 | grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0) 53 | dods = weight*grad 54 | 55 | floor_x = x.to(tl.int32) 56 | floor_y = y.to(tl.int32) 57 | ceil_x = floor_x + 1 58 | ceil_y = floor_y + 1 59 | 60 | # load top left 61 | tl_weight = (ceil_x - x) * (ceil_y - y) 62 | tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset 63 | tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)) 64 | tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0) 65 | tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0) 66 | dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y) 67 | dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x) 68 | dodw = dodw + tl_block_input_dot_grad * tl_weight 69 | 70 | dodtl = dods * tl_weight 71 | tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl) 72 | 73 | 74 | # load top right 75 | tr_weight = (x - floor_x) * (ceil_y - y) 76 | tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset 77 | 
tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)) 78 | tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0) 79 | tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0) 80 | dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y) 81 | dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x) 82 | dodw = dodw + tr_block_input_dot_grad*tr_weight 83 | 84 | dodtr = dods * tr_weight 85 | tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr) 86 | 87 | 88 | # load bottom left 89 | bl_weight = (ceil_x - x) * (y - floor_y) 90 | bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset 91 | bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)) 92 | bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0) 93 | bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0) 94 | dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y) 95 | dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x) 96 | dodw = dodw + bl_block_input_dot_grad*bl_weight 97 | 98 | dodbl = dods * bl_weight 99 | tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl) 100 | 101 | 102 | # load bottom right 103 | br_weight = (x - floor_x) * (y - floor_y) 104 | br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset 105 | br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)) 106 | br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0) 107 | br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask 108 | 109 | dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y) 110 | dody = dody + 1 * br_block_input_dot_grad * (x - floor_x) 111 | dodw = dodw + br_block_input_dot_grad*br_weight 112 | 113 | dodbr = dods * br_weight 114 | tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr) 115 | dodx = dodx * weight 116 | dody = dody * weight 117 | tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask) 118 | tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask) 119 | tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask) 120 | 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /src/ops/triton_kernels/forward.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.autotune( 5 | configs=[ 6 | triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), 7 | # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), 8 | ], 9 | key=['B', 'H', 'W', 'G', 'C', 'K'], 10 | ) 11 | @triton.jit 12 | def forward_kernel( 13 | B: tl.constexpr, 14 | H: tl.constexpr, # image_size_h 15 | W: tl.constexpr, # image_size_w 16 | G: tl.constexpr, # num_channels_per_group 17 | C: tl.constexpr, # num_groups 18 | K: tl.constexpr, # kernel size 19 | input_ptr, # input features [B, H, W, G, C] 20 | deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] 21 | weights_ptr, # weights [B, H, W, G, K] 22 | out_ptr, # out [B, H, W, G, C] 23 | BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group 24 | ): 25 | pid = tl.program_id(0) 26 | wid = pid % W 27 | hid = pid // W % H 28 | gid = pid // 
(W * H) % G 29 | bid = pid // (W * H * G) 30 | 31 | id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) 32 | common_offset = bid*H*W*G + hid*W*G + wid*G + gid 33 | batch_base = bid * H * W * G * C 34 | 35 | for block_base in tl.static_range(0, C, BLOCK_SIZE): 36 | buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) 37 | block_offset = tl.arange(0, BLOCK_SIZE) + block_base 38 | block_mask = (block_offset < C) & id_mask 39 | for k in tl.static_range(K): 40 | deformable_offset = (common_offset * K + k) * 2 41 | 42 | x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid 43 | y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid 44 | 45 | floor_x = x.to(tl.int32) 46 | floor_y = y.to(tl.int32) 47 | ceil_x = floor_x + 1 48 | ceil_y = floor_y + 1 49 | 50 | # load top left 51 | tl_weight = (ceil_x - x) * (ceil_y - y) 52 | tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE 53 | tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) 54 | 55 | # load top right 56 | tr_weight = (x - floor_x) * (ceil_y - y) 57 | tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE 58 | tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) 59 | # load bottom left 60 | bl_weight = (ceil_x - x) * (y - floor_y) 61 | bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE 62 | bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) 63 | # load bottom right 64 | br_weight = (x - floor_x) * (y - floor_y) 65 | br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE 66 | br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) 67 | 68 | # load dynamic weight and mask 69 | weights_offset = common_offset*K + k 70 | weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) 71 | 72 | 73 | 74 | tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) 75 | tl_block_input = tl_block_input * tl_weight 76 | 77 | # load top right 78 | tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) 79 | tr_block_input = tr_block_input * tr_weight 80 | # load bottom left 81 | bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) 82 | bl_block_input = bl_block_input * bl_weight 83 | # load bottom right 84 | br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) 85 | br_block_input = br_block_input * br_weight 86 | 87 | # sampled 88 | sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input 89 | 90 | weighted_sampled_input = sampled_input * weight 91 | buffer = buffer + weighted_sampled_input 92 | # store to out_ptr 93 | tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) 94 | 95 | -------------------------------------------------------------------------------- /src/ops/triton_kernels/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | from typing import Any 4 | from torch.autograd import Function 5 | from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd 6 | from .forward import forward_kernel 7 | from .backward import backward_kernel 8 | 9 | 10 | 11 | 
class DCNFunction(Function): 12 | 13 | @staticmethod 14 | @custom_fwd 15 | def forward(ctx: Any, inputs, deformables, weights) -> Any: 16 | B, H, W, G, C = inputs.shape 17 | _, _, _, _, K, _ = deformables.shape 18 | out = torch.zeros_like(inputs) 19 | grid = lambda META: (B * H * W * G,) 20 | 21 | forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out) 22 | ctx.save_for_backward(inputs, deformables, weights) 23 | return out 24 | 25 | @staticmethod 26 | @custom_bwd 27 | def backward(ctx: Any, *grad_outputs: Any) -> Any: 28 | grad_output = grad_outputs[0].contiguous() 29 | 30 | inputs, deformables, weights = ctx.saved_tensors 31 | B, H, W, G, C = inputs.shape 32 | _, _, _, _, K, _ = deformables.shape 33 | 34 | grad_inputs = torch.zeros_like(inputs) 35 | grad_deformables = torch.zeros_like(deformables) 36 | grad_weights = torch.zeros_like(weights) 37 | grid = lambda META: (B * H * W * G,) 38 | backward_kernel[grid]( 39 | B, H, W, G, C, K, 40 | inputs, 41 | deformables, 42 | weights, 43 | grad_output, 44 | grad_inputs, 45 | grad_deformables, 46 | grad_weights, 47 | ) 48 | return (grad_inputs, grad_deformables, grad_weights) -------------------------------------------------------------------------------- /src/ops/triton_kernels_udcn/backward.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.autotune( 5 | configs=[ 6 | # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1), 7 | # triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1), 8 | # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=2), 9 | triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), 10 | ], 11 | key=['B', 'H', 'W', 'G', 'C', 'K'], 12 | ) 13 | @triton.jit 14 | def _backward_kernel( 15 | B: tl.constexpr, 16 | H: tl.constexpr, # image_size_h 17 | W: tl.constexpr, # image_size_w 18 | N: tl.constexpr, 19 | G: tl.constexpr, # num_groups 20 | C: tl.constexpr, # num_channels_per_group 21 | K: tl.constexpr, # kernel size 22 | input_ptr, # input features [B, H, W, G, C] 23 | deformable_ptr, # deformable offsets [B, H, W, G, K, 2] 24 | weights_ptr, # weights [B, H, W, G, K] 25 | grad_ptr, # out [B, H, W, G, C] 26 | grad_input_ptr, # input features [B, H, W, G, C] 27 | grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2] 28 | grad_weights_ptr, # weights [B, H, W, G, K] 29 | BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group 30 | ): 31 | 32 | pid = tl.program_id(0) 33 | nid = pid % N 34 | gid = pid // N % G 35 | bid = pid // (N * G) 36 | 37 | id_mask = (nid < N) & (gid < G) & (bid < B) 38 | 39 | common_offset = bid*N*G + nid*G + gid 40 | batch_base = bid * H * W * G * C 41 | for k in tl.static_range(K): 42 | # load dynamic weight and mask 43 | weights_offset = common_offset*K + k 44 | weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) 45 | dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) 46 | dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) 47 | dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty) 48 | deformable_offset = (common_offset * K + k)*2 49 | x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) 50 | y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) 51 | for block_base in tl.static_range(0, C, BLOCK_SIZE): 52 | block_offset = tl.arange(0, BLOCK_SIZE) + block_base 53 | block_mask = (block_offset < C) & id_mask 54 | grad = 
tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0) 55 | dods = weight*grad 56 | 57 | floor_x = x.to(tl.int32) 58 | floor_y = y.to(tl.int32) 59 | ceil_x = floor_x + 1 60 | ceil_y = floor_y + 1 61 | 62 | # load top left 63 | tl_weight = (ceil_x - x) * (ceil_y - y) 64 | tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset 65 | tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)) 66 | tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0) 67 | tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0) 68 | dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y) 69 | dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x) 70 | dodw = dodw + tl_block_input_dot_grad * tl_weight 71 | 72 | dodtl = dods * tl_weight 73 | tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl) 74 | 75 | 76 | # load top right 77 | tr_weight = (x - floor_x) * (ceil_y - y) 78 | tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset 79 | tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)) 80 | tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0) 81 | tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0) 82 | dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y) 83 | dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x) 84 | dodw = dodw + tr_block_input_dot_grad*tr_weight 85 | 86 | dodtr = dods * tr_weight 87 | tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr) 88 | 89 | 90 | # load bottom left 91 | bl_weight = (ceil_x - x) * (y - floor_y) 92 | bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset 93 | bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)) 94 | bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0) 95 | bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0) 96 | dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y) 97 | dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x) 98 | dodw = dodw + bl_block_input_dot_grad*bl_weight 99 | 100 | dodbl = dods * bl_weight 101 | tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl) 102 | 103 | 104 | # load bottom right 105 | br_weight = (x - floor_x) * (y - floor_y) 106 | br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset 107 | br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)) 108 | br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0) 109 | br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask 110 | 111 | dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y) 112 | dody = dody + 1 * br_block_input_dot_grad * (x - floor_x) 113 | dodw = dodw + br_block_input_dot_grad*br_weight 114 | 115 | dodbr = dods * br_weight 116 | tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr) 117 | dodx = dodx * weight 118 | dody = dody * weight 119 | tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask) 120 | tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask) 121 | tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, 
mask=id_mask) -------------------------------------------------------------------------------- /src/ops/triton_kernels_udcn/forward.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | @triton.autotune( 5 | configs=[ 6 | # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1), 7 | # triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), 8 | triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), 9 | ], 10 | key=['B', 'H', 'W', 'N', 'G', 'C', 'K'], 11 | ) 12 | @triton.jit 13 | def _forward_kernel( 14 | B: tl.constexpr, 15 | H: tl.constexpr, # image_size_h 16 | W: tl.constexpr, # image_size_w 17 | N: tl.constexpr, 18 | G: tl.constexpr, # num_channels_per_group 19 | C: tl.constexpr, # num_groups 20 | K: tl.constexpr, # kernel size 21 | input_ptr, # input features [B, H, W, G, C] 22 | deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] 23 | weights_ptr, # weights [B, H, W, G, K] 24 | out_ptr, # out [B, H, W, G, C] 25 | BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group 26 | ): 27 | pid = tl.program_id(0) 28 | # wid = pid % W 29 | # hid = pid // W % H 30 | nid = pid % N 31 | gid = pid // (N) % G 32 | bid = pid // (N * G) 33 | 34 | id_mask = (nid < N) & (gid < G) & (bid < B) 35 | common_offset = bid*N*G + nid*G + gid 36 | batch_base = bid * H * W * G * C 37 | 38 | for block_base in tl.static_range(0, C, BLOCK_SIZE): 39 | buffer = tl.zeros((BLOCK_SIZE, ), dtype=out_ptr.dtype.element_ty) 40 | block_offset = tl.arange(0, BLOCK_SIZE) + block_base 41 | block_mask = (block_offset < C) & id_mask 42 | for k in tl.static_range(K): 43 | deformable_offset = (common_offset * K + k) * 2 44 | x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) 45 | y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) 46 | 47 | floor_x = x.to(tl.int32) 48 | floor_y = y.to(tl.int32) 49 | ceil_x = floor_x + 1 50 | ceil_y = floor_y + 1 51 | 52 | # load top left 53 | tl_weight = (ceil_x - x) * (ceil_y - y) 54 | tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE 55 | tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) 56 | 57 | # load top right 58 | tr_weight = (x - floor_x) * (ceil_y - y) 59 | tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE 60 | tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) 61 | # load bottom left 62 | bl_weight = (ceil_x - x) * (y - floor_y) 63 | bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE 64 | bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) 65 | # load bottom right 66 | br_weight = (x - floor_x) * (y - floor_y) 67 | br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE 68 | br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) 69 | 70 | # load dynamic weight and mask 71 | weights_offset = common_offset*K + k 72 | weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) 73 | 74 | 75 | 76 | tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) 77 | tl_block_input = tl_block_input * tl_weight 78 | 79 | # load top right 80 | tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) 81 | tr_block_input = 
tr_block_input * tr_weight 82 | # load bottom left 83 | bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) 84 | bl_block_input = bl_block_input * bl_weight 85 | # load bottom right 86 | br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) 87 | br_block_input = br_block_input * br_weight 88 | 89 | # sampled 90 | sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input 91 | 92 | weighted_sampled_input = sampled_input * weight 93 | buffer = buffer + weighted_sampled_input 94 | # store to out_ptr 95 | tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) 96 | 97 | -------------------------------------------------------------------------------- /src/ops/triton_kernels_udcn/function.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | from typing import Any 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_fwd, custom_bwd 8 | from src.ops.triton_kernels_udcn.forward import _forward_kernel 9 | from src.ops.triton_kernels_udcn.backward import _backward_kernel 10 | 11 | from functools import lru_cache 12 | 13 | @lru_cache() 14 | def static_grid(H, W, device, dtype): 15 | grid = torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W), indexing='xy'), dim=-1) 16 | return grid.to(device=device, dtype=dtype) 17 | 18 | 19 | class DeformableAttentionFunction(Function): 20 | @staticmethod 21 | @custom_fwd 22 | def forward(ctx: Any, inputs, deformables, weights) -> Any: 23 | B, H, W, G, C = inputs.shape 24 | _, N, _, K = weights.shape 25 | out = torch.zeros(B, N, G, C, device=inputs.device, dtype=inputs.dtype) 26 | grid = lambda META: (B * N * G,) 27 | _forward_kernel[grid](B, H, W, N, G, C, K, inputs, deformables, weights, out) 28 | ctx.save_for_backward(inputs, deformables, weights) 29 | return out 30 | @staticmethod 31 | @once_differentiable 32 | @custom_bwd 33 | def backward(ctx: Any, *grad_outputs: Any) -> Any: 34 | grad_out = grad_outputs[0] 35 | values, deformables, weights = ctx.saved_tensors 36 | B, H, W, G, C = values.shape 37 | B, N, G, K = weights.shape 38 | grad_values = torch.zeros_like(values, dtype=torch.float16) 39 | grad_deformables = torch.zeros_like(deformables) 40 | grad_weights = torch.zeros_like(weights) 41 | grid = lambda META: (B * N * G,) 42 | _backward_kernel[grid](B, H, W, N, G, C, K, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) 43 | # grad_values = torch.nan_to_num(grad_values, nan=0.0, posinf=0.0, neginf=0.0) 44 | # grad_deformables = torch.nan_to_num(grad_deformables, nan=0.0, posinf=0.0, neginf=0.0) 45 | # grad_weights = torch.nan_to_num(grad_weights, nan=0.0, posinf=0.0, neginf=0.0) 46 | return grad_values.to(values.dtype), grad_deformables, grad_weights -------------------------------------------------------------------------------- /src/ops/triton_kernels_udcn/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import lru_cache 3 | 4 | @lru_cache() 5 | def _generate_grid(H, W): 6 | """ 7 | Internal function to generate a grid of coordinates and cache it. 8 | This function uses purely hashable types (integers) for caching. 
9 | """ 10 | return torch.stack(torch.meshgrid(torch.arange(W), torch.arange(H), indexing='xy'), dim=-1) 11 | 12 | 13 | def static_grid(H, W, device, dtype): 14 | """ 15 | Generates a grid of coordinates for a given height (H) and width (W), with caching. 16 | 17 | Args: 18 | H (int): Height of the grid. 19 | W (int): Width of the grid. 20 | device (torch.device): The device on which to place the generated grid (e.g., 'cpu', 'cuda'). 21 | dtype (torch.dtype): The desired data type of the grid (e.g., torch.float32, torch.int32). 22 | 23 | Returns: 24 | torch.Tensor: The generated grid of coordinates. 25 | """ 26 | if not isinstance(H, int) or H <= 0: 27 | raise ValueError("H should be a positive integer.") 28 | if not isinstance(W, int) or W <= 0: 29 | raise ValueError("W should be a positive integer.") 30 | 31 | if not isinstance(device, torch.device): 32 | raise ValueError("device should be a torch.device instance.") 33 | 34 | if not isinstance(dtype, torch.dtype): 35 | raise ValueError("dtype should be a torch.dtype instance.") 36 | 37 | # Generate the grid using the cached internal function 38 | grid = _generate_grid(H, W) 39 | 40 | # Move the grid to the specified device and convert it to the desired dtype 41 | return grid.to(device=device, dtype=dtype) -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/model_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from lightning.fabric.utilities.types import _PATH 6 | 7 | 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | class ModelLoader: 12 | def __init__(self,): 13 | super().__init__() 14 | 15 | def prepare(self, rank, world_size, device, dtype, vae, denoisers, metric, sampler): 16 | self._device = device 17 | self._dtype = dtype # not used 18 | self.rank = rank 19 | self.world_size = world_size 20 | 21 | if metric.precompute_data_path: 22 | _precompute_data_path = dict() 23 | for k, v in metric.precompute_data_path.items(): 24 | _precompute_data_path[k] = v 25 | metric.precompute_data_path = _precompute_data_path 26 | 27 | def load(self, vae, denoisers, metric, sampler): 28 | if vae.weight_path: 29 | vae = vae.from_pretrained(vae.weight_path).to(self._device) 30 | for i, denoiser in enumerate(denoisers): 31 | if denoiser.weight_path: 32 | weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu')) 33 | if denoiser.load_ema and "optimizer_states" in weight.keys(): 34 | weight = weight["optimizer_states"][0]["ema"] 35 | params = list(denoiser.parameters()) 36 | for w, p in zip(weight, params): 37 | p.data.copy_(w) 38 | else: 39 | try: 40 | params = list(denoiser.parameters()) 41 | for w, p in zip(weight['state_dict'].values(), params): 42 | p.data.copy_(w) 43 | except: 44 | denoiser.load_state_dict(weight) 45 | denoiser.to(self._device) 46 | if metric.precompute_data_path: 47 | metric.load_precompute_data(metric.precompute_data_path, self.rank, self.world_size) 48 | if sampler.weight_path: 49 | params = list(sampler.parameters()) 50 | weight = torch.load(sampler.weight_path, map_location=torch.device('cpu')) 51 | for w, p in 
zip(weight['state_dict'].values(), params): 52 | p.data.copy_(w) 53 | return vae, denoisers, metric, sampler -------------------------------------------------------------------------------- /src/utils/saver.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import tempfile 3 | import shutil 4 | import time 5 | from typing import Sequence 6 | from torchvision.utils import save_image 7 | 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | class ImageSaver: 12 | def __init__(self, target_dir, rank, max_save_num=50000, compressed=True): 13 | self.target_dir = target_dir 14 | self.tmp_dir = tempfile.TemporaryDirectory() 15 | self.tmp_dir_name = self.tmp_dir.name 16 | self.max_save_num = max_save_num 17 | self._have_saved_num = 0 18 | self.compressed = compressed 19 | self.rank = rank 20 | self.max_upload_num = 500 21 | self._have_uploaded_num = 0 22 | 23 | while True: 24 | time.sleep(1.0) 25 | if not os.path.exists(self.target_dir): 26 | try: 27 | os.makedirs(self.target_dir, exist_ok=True) 28 | except: 29 | logger.warning(f'{self.target_dir} does not exist, trying again...') 30 | else: 31 | break 32 | 33 | def save_image(self, images, filenames): 34 | for sample, filename in zip(images, filenames): 35 | if isinstance(filename, Sequence): 36 | filename = filename[0] 37 | path = f'{self.tmp_dir_name}/{filename}' 38 | if self._have_saved_num >= self.max_save_num: 39 | break 40 | save_image(sample, path, nrow=4, normalize=True, value_range=(-1, 1)) 41 | self._have_saved_num += 1 42 | def upload_image(self, images, filenames): 43 | for sample, filename in zip(images, filenames): 44 | if isinstance(filename, Sequence): 45 | filename = filename[0] 46 | path = f'{self.target_dir}/{filename}' 47 | if self._have_uploaded_num >= self.max_upload_num: 48 | break 49 | save_image(sample, path, nrow=4, normalize=True, value_range=(-1, 1)) 50 | self._have_uploaded_num += 1 51 | 52 | 53 | def upload_all(self, prefix=""): 54 | rank = self.rank 55 | if self.compressed: 56 | # zip the files in tmp dir 57 | shutil.make_archive(f"{rank}", 'zip', self.tmp_dir_name+"/") 58 | # copy to target dir 59 | os.system(f'cp {rank}.zip {self.target_dir}/{prefix}_{rank}.zip') 60 | else: 61 | raise NotImplementedError 62 | # os.system(f'cp -r {self.tmp_dir_name} {self.target_dir}/{rank}') 63 | 64 | # clear tmp dir 65 | self.tmp_dir.cleanup() 66 | -------------------------------------------------------------------------------- /src/utils/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import subprocess 3 | import lightning.pytorch as pl 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | def class_fn_from_str(class_str): 8 | class_module, from_class = class_str.rsplit(".", 1) 9 | class_module = __import__(class_module, fromlist=[from_class]) 10 | return getattr(class_module, from_class) 11 | 12 | 13 | class PixelVAE(torch.nn.Module): 14 | def __init__(self,): 15 | super().__init__() 16 | self.model = torch.nn.Identity() 17 | 18 | def encode(self, x): 19 | return x 20 | def decode(self, x): 21 | return x 22 | 23 | @staticmethod 24 | def from_pretrained(path): 25 | return PixelVAE() 26 | 27 | 28 | class LatentVAE(PixelVAE): 29 | def __init__(self, precompute=False, weight_path:str=None): 30 | super().__init__() 31 | self.precompute = precompute 32 | self.model = None 33 | self.weight_path = weight_path 34 | 35 | @torch.no_grad() 36 | def encode(self, x): 37 | assert self.model is 
not None 38 | if self.precompute: 39 | return x.mul_(0.18215) 40 | return self.model.encode(x).latent_dist.sample().mul_(0.18215) 41 | @torch.no_grad() 42 | def decode(self, x): 43 | assert self.model is not None 44 | return self.model.decode(x.div_(0.18215)).sample 45 | 46 | def from_pretrained(self, path): 47 | vae = self 48 | from diffusers.models import AutoencoderKL 49 | setattr(vae, "model", AutoencoderKL.from_pretrained(path)) 50 | return vae 51 | 52 | -------------------------------------------------------------------------------- /t2i_vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MCG-NJU/NeuralSolver/ce22af301f40973723759e0a2002828490c15250/t2i_vis/__init__.py -------------------------------------------------------------------------------- /t2i_vis/pixart_sigma_1024.py: -------------------------------------------------------------------------------- 1 | from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline, UniPCMultistepScheduler 2 | import torch 3 | 4 | # # Load the pretrained Stable Diffusion model 5 | # model_id = "/data/songtianhui.sth/models/stable-diffusion-v1-5/" 6 | # pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) 7 | # pipe = pipe.to("cuda") 8 | 9 | model_path = "/data/songtianhui.sth/models/PixArt-XL-2-1024-MS/" 10 | prompt_json_path = "/home/shuaiou.ws/flowdcn-t2i/data/coco/coco_val_captions.json" 11 | 12 | import torch 13 | import json 14 | import re 15 | from PIL import Image 16 | from diffusers import Transformer2DModel, PixArtAlphaPipeline 17 | from vp_scheduling import NeuralSolver 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | weight_dtype = torch.float16 20 | def remove_invalid_chars(text): 21 | pattern = r'[^\w\s]' # keep only Chinese characters, digits, letters, and spaces 22 | return re.sub(pattern, '', text) 23 | 24 | 25 | pipe = PixArtAlphaPipeline.from_pretrained( 26 | model_path, 27 | torch_dtype=weight_dtype, 28 | use_safetensors=True, 29 | ) 30 | pipe.to(device) 31 | 32 | # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True) 33 | dpmsolver = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 34 | unipcsolver = UniPCMultistepScheduler.from_config(pipe.scheduler.config) 35 | prompts_list:list = json.load(open(prompt_json_path)) 36 | prompts_list = prompts_list[101:250] 37 | for prompts in prompts_list: 38 | prompt = prompts[0] 39 | dpmimages = [] 40 | dssimages = [] 41 | unipcimages = [] 42 | for step in range(5, 11): 43 | pipe.scheduler = dpmsolver 44 | dpmimage: Image = pipe(prompt, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=2).images[ 45 | 0] 46 | pipe.scheduler = unipcsolver 47 | unipcimage: Image = \ 48 | pipe(prompt, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=2).images[0] 49 | pipe.scheduler = NeuralSolver(num_train_timesteps=1000) 50 | dssimage: Image = pipe(prompt, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=2).images[ 51 | 0] 52 | dpmimages.append(dpmimage) 53 | unipcimages.append(unipcimage) 54 | dssimages.append(dssimage) 55 | dpm_horizontal_concat = Image.new("RGB", (dpmimages[0].width * len(dpmimages), dpmimages[0].height)) 56 | unipc_horizontal_concat = Image.new("RGB", (unipcimages[0].width * len(unipcimages), unipcimages[0].height)) 57 | dss_horizontal_concat = Image.new("RGB", (dssimages[0].width * len(dssimages), dssimages[0].height)) 58 | for i, image in enumerate(dpmimages): 59 | 
dpm_horizontal_concat.paste(image, (i * dpmimages[0].width, 0)) 60 | unipc_horizontal_concat.paste(unipcimages[i], (i * unipcimages[0].width, 0)) 61 | dss_horizontal_concat.paste(dssimages[i], (i * dssimages[0].width, 0)) 62 | vertical_concat = Image.new("RGB", (dpm_horizontal_concat.width, 63 | dpm_horizontal_concat.height + unipc_horizontal_concat.height + dss_horizontal_concat.height)) 64 | vertical_concat.paste(dpm_horizontal_concat, (0, 0)) 65 | vertical_concat.paste(unipc_horizontal_concat, (0, dpm_horizontal_concat.height)) 66 | vertical_concat.paste(dss_horizontal_concat, (0, dpm_horizontal_concat.height + unipc_horizontal_concat.height)) 67 | # vertical_concat.save(f"./pixart-sigma-1024/{remove_invalid_chars(prompt)}.png") 68 | vertical_concat.save(f"./pixart-sigma-1024/{remove_invalid_chars(prompt)}.jpg") 69 | 70 | 71 | -------------------------------------------------------------------------------- /t2i_vis/pixart_sigma_256.py: -------------------------------------------------------------------------------- 1 | from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler, StableDiffusionPipeline, DDIMScheduler 2 | import torch 3 | import os 4 | 5 | # # Load the pretrained Stable Diffusion model 6 | # model_id = "/data/songtianhui.sth/models/stable-diffusion-v1-5/" 7 | # pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) 8 | # pipe = pipe.to("cuda") 9 | step = 9 10 | model_path = "/home/shuaiou.ws/fasterSolver/t2i_vis/PixArt-XL-2-256x256/" 11 | prompt_json_path = "/home/shuaiou.ws/flowdcn-t2i/data/coco/coco_val_captions.json" 12 | save_dir = f"/home/shuaiou.ws/fasterSolver/t2i_vis/pixart-sigma-256-{step}steps" 13 | 14 | import torch 15 | import json 16 | import re 17 | from PIL import Image 18 | from diffusers import Transformer2DModel, PixArtSigmaPipeline 19 | from vp_scheduling import NeuralSolver 20 | device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu") 21 | weight_dtype = torch.float16 22 | def remove_invalid_chars(text): 23 | pattern = r'[^\w\s]' # keep only Chinese characters, digits, letters, and spaces 24 | return re.sub(pattern, '', text) 25 | 26 | 27 | pipe = PixArtSigmaPipeline.from_pretrained( 28 | model_path, 29 | torch_dtype=weight_dtype, 30 | use_safetensors=True, 31 | ) 32 | pipe.to(device) 33 | 34 | # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True) 35 | dpmsolver = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 36 | unipcsolver = UniPCMultistepScheduler.from_config(pipe.scheduler.config) 37 | # dssolver = NeuralSolver(num_train_timesteps=1000) 38 | # pipe.enable_model_cpu_offload() 39 | prompts_list:list = json.load(open(prompt_json_path)) 40 | prompts_list5000 = prompts_list[:5000] 41 | prompt_list = [] 42 | for prompts in prompts_list5000: 43 | prompt = prompts[0] 44 | prompt_list.append(prompt) 45 | 46 | # mkdir for images 47 | for save_type in ["dpm", "unipc", "dss"]: 48 | if not os.path.exists(f"{save_dir}/{save_type}"): 49 | os.makedirs(f"{save_dir}/{save_type}") 50 | 51 | bsz = 16 52 | bnum = len(prompts_list)//bsz + 1 53 | 54 | for bid in range(bnum): 55 | prompts = prompt_list[bid*bsz:(bid+1)*bsz] 56 | pipe.scheduler = dpmsolver 57 | dpmimages:Image = pipe(prompts, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=1.5).images 58 | pipe.scheduler = unipcsolver 59 | unipcimages:Image = pipe(prompts, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=1.5).images 60 | pipe.scheduler = NeuralSolver(num_train_timesteps=1000) 61 | dssimages:Image = 
pipe(prompts, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=1.5).images 62 | for prompt, dpmimage, unipcimage, dssimage in zip(prompts, dpmimages, unipcimages, dssimages): 63 | # save images 64 | dpmimage.save(f"{save_dir}/dpm/{remove_invalid_chars(prompt)}.png") 65 | unipcimage.save(f"{save_dir}/unipc/{remove_invalid_chars(prompt)}.png") 66 | dssimage.save(f"{save_dir}/dss/{remove_invalid_chars(prompt)}.png") 67 | print(bid*bsz) 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /t2i_vis/pixart_sigma_512.py: -------------------------------------------------------------------------------- 1 | from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler, StableDiffusionPipeline, DDIMScheduler 2 | import torch 3 | 4 | # # Load the pretrained Stable Diffusion model 5 | # model_id = "/data/songtianhui.sth/models/stable-diffusion-v1-5/" 6 | # pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) 7 | # pipe = pipe.to("cuda") 8 | 9 | model_path = "/data/songtianhui.sth/models/PixArt-XL-2-512x512/" 10 | prompt_json_path = "/home/shuaiou.ws/flowdcn-t2i/data/coco/coco_val_captions.json" 11 | 12 | import torch 13 | import json 14 | import re 15 | from PIL import Image 16 | from diffusers import Transformer2DModel, PixArtSigmaPipeline 17 | from vp_scheduling import NeuralSolver 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | weight_dtype = torch.float16 20 | def remove_invalid_chars(text): 21 | pattern = r'[^\w\s]' # keep only Chinese characters, digits, letters, and spaces 22 | return re.sub(pattern, '', text) 23 | 24 | 25 | pipe = PixArtSigmaPipeline.from_pretrained( 26 | model_path, 27 | torch_dtype=weight_dtype, 28 | use_safetensors=True, 29 | ) 30 | pipe.to(device) 31 | 32 | # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True) 33 | dpmsolver = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 34 | unipcsolver = UniPCMultistepScheduler.from_config(pipe.scheduler.config) 35 | # dssolver = NeuralSolver(num_train_timesteps=1000) 36 | # pipe.enable_model_cpu_offload() 37 | prompts_list:list = json.load(open(prompt_json_path)) 38 | prompts_list = prompts_list[2:150] 39 | for prompts in prompts_list: 40 | prompt = prompts[0] 41 | dpmimages = [] 42 | unipcimages = [] 43 | dssimages = [] 44 | for step in range(5, 11): 45 | pipe.scheduler = dpmsolver 46 | dpmimage:Image = pipe(prompt, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=2).images[0] 47 | pipe.scheduler = unipcsolver 48 | unipcimage:Image = pipe(prompt, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=2).images[0] 49 | pipe.scheduler = NeuralSolver(num_train_timesteps=1000) 50 | dssimage:Image = pipe(prompt, num_inference_steps=step, generator=torch.Generator(0), guidance_scale=2).images[0] 51 | dpmimages.append(dpmimage) 52 | unipcimages.append(unipcimage) 53 | dssimages.append(dssimage) 54 | dpm_horizontal_concat = Image.new("RGB", (dpmimages[0].width*len(dpmimages), dpmimages[0].height)) 55 | unipc_horizontal_concat = Image.new("RGB", (unipcimages[0].width*len(unipcimages), unipcimages[0].height)) 56 | dss_horizontal_concat = Image.new("RGB", (dssimages[0].width*len(dssimages), dssimages[0].height)) 57 | for i, image in enumerate(dpmimages): 58 | dpm_horizontal_concat.paste(image, (i*dpmimages[0].width, 0)) 59 | unipc_horizontal_concat.paste(unipcimages[i], (i*unipcimages[0].width, 0)) 60 | dss_horizontal_concat.paste(dssimages[i], 
(i*dssimages[0].width, 0)) 61 | vertical_concat = Image.new("RGB", (dpm_horizontal_concat.width, dpm_horizontal_concat.height+unipc_horizontal_concat.height+dss_horizontal_concat.height)) 62 | vertical_concat.paste(dpm_horizontal_concat, (0, 0)) 63 | vertical_concat.paste(unipc_horizontal_concat, (0, dpm_horizontal_concat.height)) 64 | vertical_concat.paste(dss_horizontal_concat, (0, dpm_horizontal_concat.height + unipc_horizontal_concat.height)) 65 | # vertical_concat.save(f"./pixart-sigma-512/{remove_invalid_chars(prompt)}.jpg") 66 | break 67 | 68 | 69 | -------------------------------------------------------------------------------- /tools/fid_curve/eular_steps.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "25-steps": [5.45,4.07,3.39,3.18,2.97,2.82], 4 | "50-steps": [4.97,3.48,3.13,2.77,2.60,2.55], 5 | "100-steps": [4.92,3.57,2.78,2.65,2.47,2.40], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | '25-steps':"#FF595E", 11 | '50-steps':"#FFCA3A", 12 | "100-steps":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | plt.title("Solver with various steps of reference Euler on SiT-XL/2") 20 | plt.ylabel("FID") 21 | plt.ylim(2, 6) 22 | plt.xlabel("Number of steps") 23 | plt.legend() 24 | # plt.margins(0, 0) 25 | plt.savefig(f"tools/plot_figs/coeffs/steps-fid.png", bbox_inches='tight') 26 | plt.close() -------------------------------------------------------------------------------- /tools/fid_curve/flowdcn_256.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [8.47, 5.96, 4.66, 3.87, 3.4, 3.05], 4 | "Adam4-Solver": [8.87, 6.21, 4.65, 3.70, 3.17, 2.81], 5 | "Searched-Solver": [5.10, 3.41, 2.73, 2.60, 2.46, 2.35], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i] - 0.15, data[i] + 0.2, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [2.17, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 2.5, "Eular(50steps) - 2.17", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(256x256)") 25 | plt.ylabel("FID") 26 | plt.ylim(2, 10) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn256-fid.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/fid_curve/flowdcn_512.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [19.21, 12.4, 9.21, 7.38, 6.26, 5.49], 4 | "Adam4-Solver": [19.9, 13.19, 9.55, 7.41, 6.16, 5.33], 5 | "Searched-Solver": [7.24, 4.7, 3.45, 3.18, 3.00, 2.77], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 
'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.15, data[i] + 0.5, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [2.81, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 3.0, "Eular(50steps) - 2.81", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(512x512)") 25 | plt.ylabel("FID") 26 | plt.ylim(2, 20) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn512-fid.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/fid_curve/mse_curve.py: -------------------------------------------------------------------------------- 1 | data_dict = { 2 | "FlowDCN-S/2": [0.0336, 0.0204, 0.0150, 0.0110, 0.0087, 0.0081], 3 | "FlowDCN-B/2": [0.0342, 0.0211, 0.0126, 0.0105, 0.0076, 0.0064], 4 | "SiT-XL/2": [0.0324, 0.0192, 0.0117, 0.0079, 0.0072, 0.0051], 5 | } 6 | 7 | x_axis = [5, 6, 7, 8, 9, 10] 8 | color_table = { 9 | 'SiT-XL/2':"#FF595E", 10 | 'FlowDCN-S/2':"#FFCA3A", 11 | "FlowDCN-B/2":"#8AC926", 12 | } 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | for data_name, data in data_dict.items(): 16 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 17 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 18 | plt.title("RecError of Solver with Different Search Model on SiT-XL/2") 19 | plt.ylabel("MSE") 20 | # plt.ylim(2, 10) 21 | plt.xlabel("Number of steps") 22 | plt.legend() 23 | # plt.margins(0, 0) 24 | plt.savefig(f"tools/plot_figs/coeffs/models-mse.png", bbox_inches='tight') 25 | plt.close() -------------------------------------------------------------------------------- /tools/fid_curve/search_model.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "FlowDCN-S/2": [5.35, 3.47, 3.05, 2.78, 2.54, 2.48], 4 | "FlowDCN-B/2": [4.92, 3.57, 2.78, 2.65, 2.47, 2.40], 5 | "SiT-XL/2": [5.65, 3.46, 2.86, 2.68, 2.58, 2.43], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'SiT-XL/2':"#FF595E", 11 | 'FlowDCN-S/2':"#FFCA3A", 12 | "FlowDCN-B/2":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | plt.title("Solver with Different Search Model on SiT-XL/2") 20 | plt.ylabel("FID") 21 | plt.ylim(2, 6) 22 | plt.xlabel("Number of steps") 23 | plt.legend() 24 | # plt.margins(0, 0) 25 | plt.savefig(f"tools/plot_figs/coeffs/models-fid.png", bbox_inches='tight') 26 | plt.close() -------------------------------------------------------------------------------- /tools/fid_curve/sit.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [8.96, 6.35, 4.96, 4.16, 3.6, 3.26], 4 | "Adam4-Solver": [9.64, 6.92, 5.21, 4.15, 3.51, 3.11], 5 | "Searched-Solver": [4.92, 3.57, 2.78, 2.65, 2.47, 2.40], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | 
color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.05, data[i]+0.15, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [2.23, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 2.5, "Eular(50steps) - 2.23", color="#6A4C93") 24 | plt.title("Performance of solvers on SiT-XL/2") 25 | plt.ylabel("FID") 26 | plt.ylim(2, 10) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.show() 30 | # plt.margins(0, 0) 31 | plt.savefig(f"tools/plot_figs/coeffs/sit-fid.png", bbox_inches='tight') 32 | plt.close() -------------------------------------------------------------------------------- /tools/is_curve/flowdcn_256.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [174, 195, 207, 216, 222, 226], 4 | "Adam4-Solver": [175, 194, 209, 221, 228, 233], 5 | "Searched-Solver": [192, 218, 228, 232, 237, 239], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.1, data[i] + 2, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [247, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 249, "Eular(50steps) - 247", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(256x256)") 25 | plt.ylabel("IS") 26 | plt.ylim(160, 260) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn256-is.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/is_curve/flowdcn_512.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [122, 155, 177, 191, 202, 209], 4 | "Adam4-Solver": [123, 156, 178, 194, 205, 215], 5 | "Searched-Solver": [178, 203, 222, 226, 232, 238], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.15, data[i] + 0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [243, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 240, "Eular(50steps) - 243", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(512x512)") 25 | plt.ylabel("IS") 26 | 
plt.ylim(100, 250) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn512-is.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/is_curve/sit.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [166, 190, 203, 212, 217, 222], 4 | "Adam4-Solver": [166, 187, 203, 214, 221, 227], 5 | "Searched-Solver": [192, 214, 226, 230, 231, 234], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.05, data[i]+3, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [244, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 240, "Eular(50steps) - 244", color="#6A4C93") 24 | plt.title("Performance of solvers on SiT-XL/2") 25 | plt.ylabel("IS") 26 | plt.ylim(160, 260) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.show() 30 | # plt.margins(0, 0) 31 | plt.savefig(f"tools/plot_figs/coeffs/sit-is.png", bbox_inches='tight') 32 | plt.close() -------------------------------------------------------------------------------- /tools/pr_curve/flowdcn_256.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [0.73, 0.76, 0.78, 0.79, 0.79, 0.79], 4 | "Adam4-Solver": [0.73, 0.76, 0.78, 0.79, 0.79, 0.80], 5 | "Searched-Solver": [0.74, 0.78, 0.79, 0.79, 0.80, 0.80], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i] - 0.15, data[i] + 0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [0.81, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 0.80, "Eular(50steps) - 0.81", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(256x256)") 25 | plt.ylabel("Precision") 26 | plt.ylim(0.7, 0.85) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn256-pr.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/pr_curve/flowdcn_512.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [0.708, 0.76, 0.79, 0.81, 0.82, 0.82], 4 | "Adam4-Solver": [0.71, 0.76, 0.79, 0.81, 0.82, 0.825], 5 | "Searched-Solver": [0.80, 0.819, 0.829, 0.830, 0.834, 0.833], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 
13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.15, data[i] + 0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [0.837, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 0.83, "Eular(50steps) - 0.837", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(512x512)") 25 | plt.ylabel("Precision") 26 | plt.ylim(0.65, 0.9) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn512-pr.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/pr_curve/sit.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [0.72, 0.75, 0.77, 0.78, 0.78, 0.78], 4 | "Adam4-Solver": [0.72, 0.74, 0.76, 0.77, 0.78, 0.79], 5 | "Searched-Solver": [0.73, 0.76, 0.77, 0.78, 0.78, 0.78], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.05, data[i]+0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [0.79, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 0.78, "Eular(50steps) - 0.79", color="#6A4C93") 24 | plt.title("Performance of solvers on SiT-XL/2") 25 | plt.ylabel("Precision") 26 | plt.ylim(0.7, 0.85) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.show() 30 | # plt.margins(0, 0) 31 | plt.savefig(f"tools/plot_figs/coeffs/sit-pr.png", bbox_inches='tight') 32 | plt.close() -------------------------------------------------------------------------------- /tools/recall_curve/flowdcn_256.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [0.52, 0.54, 0.54, 0.55, 0.56, 0.57], 4 | "Adam4-Solver": [0.53, 0.54, 0.55, 0.56, 0.56, 0.57], 5 | "Searched-Solver": [0.57, 0.57, 0.58, 0.58, 0.58, 0.58], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i] - 0.15, data[i] + 0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [0.58, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 0.585, "Eular(50steps) - 0.58", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(256x256)") 25 | plt.ylabel("Recall") 26 | plt.ylim(0.5, 0.6) 27 | plt.xlabel("Number of 
steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn256-recall.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/recall_curve/flowdcn_512.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [0.46, 0.47, 0.46, 0.487, 0.47, 0.48], 4 | "Adam4-Solver": [0.458, 0.478, 0.45, 0.465, 0.484, 0.481], 5 | "Searched-Solver": [0.519, 0.553, 0.546, 0.551, 0.559, 0.555], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.15, data[i] + 0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [0.537, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 0.53, "Eular(50steps) - 0.537", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(512x512)") 25 | plt.ylabel("Recall") 26 | plt.ylim(0.4, 0.6) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn512-recall.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/recall_curve/sit.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [0.53, 0.55, 0.56, 0.56, 0.57, 0.58], 4 | "Adam4-Solver": [0.53, 0.55, 0.56, 0.57, 0.58, 0.58], 5 | "Searched-Solver": [0.58, 0.60, 0.59, 0.59, 0.60, 0.60], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.05, data[i]+0.002, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [0.59, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 0.58, "Eular(50steps) - 0.59", color="#6A4C93") 24 | plt.title("Performance of solvers on SiT-XL/2") 25 | plt.ylabel("Recall") 26 | plt.ylim(0.5, 0.65) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.show() 30 | # plt.margins(0, 0) 31 | plt.savefig(f"tools/plot_figs/coeffs/sit-recall.png", bbox_inches='tight') 32 | plt.close() -------------------------------------------------------------------------------- /tools/sfid_curve/flowdcn_256.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [8.48, 7.32, 6.59, 6.06, 5.72, 5.43], 4 | "Adam4-Solver": [12.6, 10.1, 8.41, 7.24, 6.52, 5.97], 5 | "Searched-Solver": [5.50, 5.12, 5.2, 5.33, 5.32, 5.07], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | 
"Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i] - 0.15, data[i] + 0.2, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [4.32, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 4.5, "Eular(50steps) - 4.32", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(256x256)") 25 | plt.ylabel("sFID") 26 | plt.ylim(4, 15) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn256-sfid.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/sfid_curve/flowdcn_512.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [19.12, 15.4, 13.0, 11.3, 10.2, 9.37], 4 | "Adam4-Solver": [24.4, 19.71, 15.9, 13.3, 11.65, 10.39], 5 | "Searched-Solver": [6.07, 5.10, 4.69, 4.7, 4.61, 4.68], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.15, data[i] + 0.5, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [5.44, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 5.40, "Eular(50steps) - 5.44", color="#6A4C93") 24 | plt.title("Performance of solvers on FlowDCN-XL/2(512x512)") 25 | plt.ylabel("sFID") 26 | plt.ylim(3, 20) 27 | plt.xlabel("Number of steps") 28 | plt.legend() 29 | # plt.margins(0, 0) 30 | plt.savefig(f"tools/plot_figs/coeffs/flowdcn512-sfid.png", bbox_inches='tight') 31 | plt.close() -------------------------------------------------------------------------------- /tools/sfid_curve/sit.py: -------------------------------------------------------------------------------- 1 | 2 | data_dict = { 3 | "Adam2-Solver": [12.6, 7.83, 7.08, 6.56, 6.25, 5.96], 4 | "Adam4-Solver": [13.6, 11.1, 9.29, 8.07, 7.33, 6.73], 5 | "Searched-Solver": [4.85, 4.83, 4.79, 4.89, 4.91, 4.96], 6 | } 7 | 8 | x_axis = [5, 6, 7, 8, 9, 10] 9 | color_table = { 10 | 'Adam2-Solver':"#FF595E", 11 | 'Adam4-Solver':"#FFCA3A", 12 | "Searched-Solver":"#8AC926", 13 | } 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | for data_name, data in data_dict.items(): 17 | plt.plot(x_axis, data, label=data_name, color=color_table[data_name]) 18 | plt.scatter(x_axis, data, marker="x", color=color_table[data_name]) 19 | if data_name == "Searched-Solver": 20 | for i in range(len(data)): 21 | plt.text(x_axis[i]-0.05, data[i]+0.15, str(data[i]), color=color_table[data_name]) 22 | plt.plot(x_axis, [4.60, ] * len(x_axis), linestyle="--", color="#6A4C93") 23 | plt.text(5, 4.0, "Eular(50steps) - 4.60", color="#6A4C93") 24 | plt.title("Performance of solvers on SiT-XL/2") 25 | plt.ylabel("sFID") 26 | plt.ylim(3.5, 15) 27 | plt.xlabel("Number of steps") 28 
| plt.legend() 29 | # plt.show() 30 | # plt.margins(0, 0) 31 | plt.savefig(f"tools/plot_figs/coeffs/sit-sfid.png", bbox_inches='tight') 32 | plt.close() --------------------------------------------------------------------------------
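
The `t2i_vis` scripts listed above all follow the same pattern: build a `diffusers` pipeline, swap its scheduler for the searched solver, then sample with only a handful of steps. The snippet below is a minimal sketch of that pattern, not a file from the repo; the checkpoint path, prompt, and output filename are placeholders, and `NeuralSolver` is imported from `t2i_vis/vp_scheduling.py` exactly as in `pixart_sigma_512.py`.

```python
# Minimal usage sketch (not part of the repo). Assumes a local PixArt-Sigma checkpoint
# and a CUDA device; NeuralSolver comes from t2i_vis/vp_scheduling.py.
import torch
from diffusers import PixArtSigmaPipeline
from vp_scheduling import NeuralSolver

pipe = PixArtSigmaPipeline.from_pretrained(
    "/path/to/PixArt-Sigma-XL-2-512x512",  # placeholder checkpoint path
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# The searched solver is a diffusers-style scheduler, used as a drop-in
# replacement for DPM-Solver++ / UniPC in the scripts above.
pipe.scheduler = NeuralSolver(num_train_timesteps=1000)

image = pipe(
    "a photo of a corgi wearing sunglasses",  # placeholder prompt
    num_inference_steps=7,
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=2,
).images[0]
image.save("neuralsolver_7steps.png")  # placeholder output path
```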