├── .gitignore ├── README.md ├── assets ├── fig5_sevir_420_advance.png └── pipeline.png ├── configs └── sevir_used │ ├── EarthFormer.yaml │ ├── autoencoder_kl_gan.yaml │ ├── cascast_diffusion.yaml │ ├── compress_earthformer.yaml │ └── compress_gt.yaml ├── datasets ├── hko7 │ ├── hko7.py │ └── hko7_list │ │ ├── hko7_full_list.txt │ │ ├── hko7_rainy_test.txt │ │ ├── hko7_rainy_train.txt │ │ ├── hko7_rainy_valid.txt │ │ └── txt2pkl.py ├── meteonet │ ├── meteonet.py │ └── preprocess_upload.py ├── sevir_diffusion_eval.py ├── sevir_latent_used.py ├── sevir_list │ ├── test.txt │ ├── train.txt │ └── val.txt ├── sevir_preprocess_used.py ├── sevir_pretrain_used.py ├── sevir_used.py └── sevir_util │ └── sevir_cmap.py ├── evaluation.py ├── experiments ├── EarthFormer │ └── world_size1-ckpt │ │ └── training_options.yaml └── cascast_diffusion │ └── world_size1-ckpt │ └── training_options.yaml ├── latent_preprocess.py ├── megatron_utils ├── __init__.py ├── parallel_state.py ├── tensor_parallel │ ├── __init__.py │ ├── cross_entropy.py │ ├── data.py │ ├── layers.py │ ├── mappings.py │ ├── random.py │ └── utils.py └── utils.py ├── models ├── autoencoder_kl_gan_model.py ├── latent_compress_model.py ├── latent_diffusion_model.py ├── latent_diffusion_model_eval.py ├── model.py └── non_ar_model.py ├── networks ├── autoencoder_kl.py ├── casformer.py ├── earthformer_xy.py ├── ldm │ ├── ldm_attention.py │ ├── ldm_util.py │ └── unet2d_openai.py ├── lpipsWithDisc.py ├── prediff │ ├── __init__.py │ ├── models │ │ ├── cuboid_transformer │ │ │ ├── __init__.py │ │ │ ├── cuboid_transformer.py │ │ │ ├── cuboid_transformer_patterns.py │ │ │ └── cuboid_transformer_unet.py │ │ ├── openaimodel.py │ │ ├── time_embed.py │ │ └── utils.py │ ├── taming │ │ ├── README.md │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoder_kl.py │ │ ├── losses │ │ │ ├── contperceptual.py │ │ │ ├── lpips.py │ │ │ ├── model.py │ │ │ └── util.py │ │ ├── resnet.py │ │ ├── unet_2d_blocks.py │ │ └── vae.py │ └── utils │ │ ├── __init__.py │ │ ├── distributions.py │ │ ├── download.py │ │ ├── ema.py │ │ ├── gifmaker.py │ │ ├── layout.py │ │ ├── optim.py │ │ ├── path.py │ │ ├── pl_checkpoint.py │ │ └── registry.py └── utils │ ├── MIMBlock.py │ ├── MIMN.py │ ├── SpatioTemporalLSTMCell.py │ ├── cuboid_transformer.py │ ├── cuboid_transformer_patterns.py │ ├── cuboid_transformer_unet_dec.py │ ├── mlp.py │ └── utils.py ├── scripts ├── compress_earthformer.sh ├── compress_gt.sh ├── eval_deterministic.sh ├── eval_diffusion.sh ├── eval_diffusion_infer.sh ├── train_autoencoder.sh ├── train_deterministic.sh └── train_diffusion.sh ├── src └── diffusers │ ├── __init__.py │ ├── commands │ ├── __init__.py │ ├── diffusers_cli.py │ ├── env.py │ └── fp16_safetensors.py │ ├── configuration_utils.py │ ├── dependency_versions_check.py │ ├── dependency_versions_table.py │ ├── experimental │ ├── README.md │ ├── __init__.py │ └── rl │ │ ├── __init__.py │ │ └── value_guided_sampling.py │ ├── image_processor.py │ ├── loaders │ ├── __init__.py │ ├── ip_adapter.py │ ├── lora.py │ ├── lora_conversion_utils.py │ ├── single_file.py │ ├── textual_inversion.py │ ├── unet.py │ └── utils.py │ ├── models │ ├── README.md │ ├── __init__.py │ ├── activations.py │ ├── adapter.py │ ├── attention.py │ ├── attention_flax.py │ ├── attention_processor.py │ ├── autoencoder_asym_kl.py │ ├── autoencoder_kl.py │ ├── autoencoder_kl_temporal_decoder.py │ ├── autoencoder_tiny.py │ ├── consistency_decoder_vae.py │ ├── controlnet.py │ ├── controlnet_flax.py │ ├── dual_transformer_2d.py │ ├── 
embeddings.py │ ├── embeddings_flax.py │ ├── lora.py │ ├── modeling_flax_pytorch_utils.py │ ├── modeling_flax_utils.py │ ├── modeling_outputs.py │ ├── modeling_pytorch_flax_utils.py │ ├── modeling_utils.py │ ├── normalization.py │ ├── prior_transformer.py │ ├── resnet.py │ ├── resnet_flax.py │ ├── t5_film_transformer.py │ ├── transformer_2d.py │ ├── transformer_temporal.py │ ├── unet_1d.py │ ├── unet_1d_blocks.py │ ├── unet_2d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_blocks_flax.py │ ├── unet_2d_condition.py │ ├── unet_2d_condition_flax.py │ ├── unet_3d_blocks.py │ ├── unet_3d_condition.py │ ├── unet_kandi3.py │ ├── unet_motion_model.py │ ├── unet_spatio_temporal_condition.py │ ├── vae.py │ ├── vae_flax.py │ └── vq_model.py │ ├── optimization.py │ ├── pipelines │ ├── README.md │ ├── __init__.py │ ├── alt_diffusion │ │ ├── __init__.py │ │ ├── modeling_roberta_series.py │ │ ├── pipeline_alt_diffusion.py │ │ ├── pipeline_alt_diffusion_img2img.py │ │ └── pipeline_output.py │ ├── animatediff │ │ ├── __init__.py │ │ └── pipeline_animatediff.py │ ├── audio_diffusion │ │ ├── __init__.py │ │ ├── mel.py │ │ └── pipeline_audio_diffusion.py │ ├── audioldm │ │ ├── __init__.py │ │ └── pipeline_audioldm.py │ ├── audioldm2 │ │ ├── __init__.py │ │ ├── modeling_audioldm2.py │ │ └── pipeline_audioldm2.py │ ├── auto_pipeline.py │ ├── blip_diffusion │ │ ├── __init__.py │ │ ├── blip_image_processing.py │ │ ├── modeling_blip2.py │ │ ├── modeling_ctx_clip.py │ │ └── pipeline_blip_diffusion.py │ ├── consistency_models │ │ ├── __init__.py │ │ └── pipeline_consistency_models.py │ ├── controlnet │ │ ├── __init__.py │ │ ├── multicontrolnet.py │ │ ├── pipeline_controlnet.py │ │ ├── pipeline_controlnet_blip_diffusion.py │ │ ├── pipeline_controlnet_img2img.py │ │ ├── pipeline_controlnet_inpaint.py │ │ ├── pipeline_controlnet_inpaint_sd_xl.py │ │ ├── pipeline_controlnet_sd_xl.py │ │ ├── pipeline_controlnet_sd_xl_img2img.py │ │ └── pipeline_flax_controlnet.py │ ├── dance_diffusion │ │ ├── __init__.py │ │ └── pipeline_dance_diffusion.py │ ├── ddim │ │ ├── __init__.py │ │ └── pipeline_ddim.py │ ├── ddpm │ │ ├── __init__.py │ │ └── pipeline_ddpm.py │ ├── deepfloyd_if │ │ ├── __init__.py │ │ ├── pipeline_if.py │ │ ├── pipeline_if_img2img.py │ │ ├── pipeline_if_img2img_superresolution.py │ │ ├── pipeline_if_inpainting.py │ │ ├── pipeline_if_inpainting_superresolution.py │ │ ├── pipeline_if_superresolution.py │ │ ├── pipeline_output.py │ │ ├── safety_checker.py │ │ ├── timesteps.py │ │ └── watermark.py │ ├── dit │ │ ├── __init__.py │ │ └── pipeline_dit.py │ ├── kandinsky │ │ ├── __init__.py │ │ ├── pipeline_kandinsky.py │ │ ├── pipeline_kandinsky_combined.py │ │ ├── pipeline_kandinsky_img2img.py │ │ ├── pipeline_kandinsky_inpaint.py │ │ ├── pipeline_kandinsky_prior.py │ │ └── text_encoder.py │ ├── kandinsky2_2 │ │ ├── __init__.py │ │ ├── pipeline_kandinsky2_2.py │ │ ├── pipeline_kandinsky2_2_combined.py │ │ ├── pipeline_kandinsky2_2_controlnet.py │ │ ├── pipeline_kandinsky2_2_controlnet_img2img.py │ │ ├── pipeline_kandinsky2_2_img2img.py │ │ ├── pipeline_kandinsky2_2_inpainting.py │ │ ├── pipeline_kandinsky2_2_prior.py │ │ └── pipeline_kandinsky2_2_prior_emb2emb.py │ ├── kandinsky3 │ │ ├── __init__.py │ │ ├── kandinsky3_pipeline.py │ │ └── kandinsky3img2img_pipeline.py │ ├── latent_consistency_models │ │ ├── __init__.py │ │ ├── pipeline_latent_consistency_img2img.py │ │ └── pipeline_latent_consistency_text2img.py │ ├── latent_diffusion │ │ ├── __init__.py │ │ ├── pipeline_latent_diffusion.py │ │ └── 
pipeline_latent_diffusion_superresolution.py │ ├── latent_diffusion_uncond │ │ ├── __init__.py │ │ └── pipeline_latent_diffusion_uncond.py │ ├── musicldm │ │ ├── __init__.py │ │ └── pipeline_musicldm.py │ ├── onnx_utils.py │ ├── paint_by_example │ │ ├── __init__.py │ │ ├── image_encoder.py │ │ └── pipeline_paint_by_example.py │ ├── pipeline_flax_utils.py │ ├── pipeline_utils.py │ ├── pixart_alpha │ │ ├── __init__.py │ │ └── pipeline_pixart_alpha.py │ ├── pndm │ │ ├── __init__.py │ │ └── pipeline_pndm.py │ ├── repaint │ │ ├── __init__.py │ │ └── pipeline_repaint.py │ ├── score_sde_ve │ │ ├── __init__.py │ │ └── pipeline_score_sde_ve.py │ ├── semantic_stable_diffusion │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ └── pipeline_semantic_stable_diffusion.py │ ├── shap_e │ │ ├── __init__.py │ │ ├── camera.py │ │ ├── pipeline_shap_e.py │ │ ├── pipeline_shap_e_img2img.py │ │ └── renderer.py │ ├── spectrogram_diffusion │ │ ├── __init__.py │ │ ├── continous_encoder.py │ │ ├── midi_utils.py │ │ ├── notes_encoder.py │ │ └── pipeline_spectrogram_diffusion.py │ ├── stable_diffusion │ │ ├── README.md │ │ ├── __init__.py │ │ ├── clip_image_project_model.py │ │ ├── convert_from_ckpt.py │ │ ├── pipeline_cycle_diffusion.py │ │ ├── pipeline_flax_stable_diffusion.py │ │ ├── pipeline_flax_stable_diffusion_img2img.py │ │ ├── pipeline_flax_stable_diffusion_inpaint.py │ │ ├── pipeline_onnx_stable_diffusion.py │ │ ├── pipeline_onnx_stable_diffusion_img2img.py │ │ ├── pipeline_onnx_stable_diffusion_inpaint.py │ │ ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py │ │ ├── pipeline_onnx_stable_diffusion_upscale.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion.py │ │ ├── pipeline_stable_diffusion_attend_and_excite.py │ │ ├── pipeline_stable_diffusion_depth2img.py │ │ ├── pipeline_stable_diffusion_diffedit.py │ │ ├── pipeline_stable_diffusion_gligen.py │ │ ├── pipeline_stable_diffusion_gligen_text_image.py │ │ ├── pipeline_stable_diffusion_image_variation.py │ │ ├── pipeline_stable_diffusion_img2img.py │ │ ├── pipeline_stable_diffusion_inpaint.py │ │ ├── pipeline_stable_diffusion_inpaint_legacy.py │ │ ├── pipeline_stable_diffusion_instruct_pix2pix.py │ │ ├── pipeline_stable_diffusion_k_diffusion.py │ │ ├── pipeline_stable_diffusion_latent_upscale.py │ │ ├── pipeline_stable_diffusion_ldm3d.py │ │ ├── pipeline_stable_diffusion_model_editing.py │ │ ├── pipeline_stable_diffusion_panorama.py │ │ ├── pipeline_stable_diffusion_paradigms.py │ │ ├── pipeline_stable_diffusion_pix2pix_zero.py │ │ ├── pipeline_stable_diffusion_sag.py │ │ ├── pipeline_stable_diffusion_upscale.py │ │ ├── pipeline_stable_unclip.py │ │ ├── pipeline_stable_unclip_img2img.py │ │ ├── safety_checker.py │ │ ├── safety_checker_flax.py │ │ └── stable_unclip_image_normalizer.py │ ├── stable_diffusion_safe │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion_safe.py │ │ └── safety_checker.py │ ├── stable_diffusion_xl │ │ ├── __init__.py │ │ ├── pipeline_flax_stable_diffusion_xl.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion_xl.py │ │ ├── pipeline_stable_diffusion_xl_img2img.py │ │ ├── pipeline_stable_diffusion_xl_inpaint.py │ │ ├── pipeline_stable_diffusion_xl_instruct_pix2pix.py │ │ └── watermark.py │ ├── stable_video_diffusion │ │ ├── __init__.py │ │ └── pipeline_stable_video_diffusion.py │ ├── stochastic_karras_ve │ │ ├── __init__.py │ │ └── pipeline_stochastic_karras_ve.py │ ├── t2i_adapter │ │ ├── __init__.py │ │ ├── pipeline_stable_diffusion_adapter.py │ │ └── 
pipeline_stable_diffusion_xl_adapter.py │ ├── text_to_video_synthesis │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ ├── pipeline_text_to_video_synth.py │ │ ├── pipeline_text_to_video_synth_img2img.py │ │ ├── pipeline_text_to_video_zero.py │ │ └── pipeline_text_to_video_zero_sdxl.py │ ├── unclip │ │ ├── __init__.py │ │ ├── pipeline_unclip.py │ │ ├── pipeline_unclip_image_variation.py │ │ └── text_proj.py │ ├── unidiffuser │ │ ├── __init__.py │ │ ├── modeling_text_decoder.py │ │ ├── modeling_uvit.py │ │ └── pipeline_unidiffuser.py │ ├── versatile_diffusion │ │ ├── __init__.py │ │ ├── modeling_text_unet.py │ │ ├── pipeline_versatile_diffusion.py │ │ ├── pipeline_versatile_diffusion_dual_guided.py │ │ ├── pipeline_versatile_diffusion_image_variation.py │ │ └── pipeline_versatile_diffusion_text_to_image.py │ ├── vq_diffusion │ │ ├── __init__.py │ │ └── pipeline_vq_diffusion.py │ └── wuerstchen │ │ ├── __init__.py │ │ ├── modeling_paella_vq_model.py │ │ ├── modeling_wuerstchen_common.py │ │ ├── modeling_wuerstchen_diffnext.py │ │ ├── modeling_wuerstchen_prior.py │ │ ├── pipeline_wuerstchen.py │ │ ├── pipeline_wuerstchen_combined.py │ │ └── pipeline_wuerstchen_prior.py │ ├── py.typed │ ├── schedulers │ ├── README.md │ ├── __init__.py │ ├── deprecated │ │ ├── __init__.py │ │ ├── scheduling_karras_ve.py │ │ └── scheduling_sde_vp.py │ ├── scheduling_consistency_decoder.py │ ├── scheduling_consistency_models.py │ ├── scheduling_ddim.py │ ├── scheduling_ddim_flax.py │ ├── scheduling_ddim_inverse.py │ ├── scheduling_ddim_parallel.py │ ├── scheduling_ddpm.py │ ├── scheduling_ddpm_flax.py │ ├── scheduling_ddpm_parallel.py │ ├── scheduling_ddpm_wuerstchen.py │ ├── scheduling_deis_multistep.py │ ├── scheduling_dpmsolver_multistep.py │ ├── scheduling_dpmsolver_multistep_flax.py │ ├── scheduling_dpmsolver_multistep_inverse.py │ ├── scheduling_dpmsolver_sde.py │ ├── scheduling_dpmsolver_singlestep.py │ ├── scheduling_euler_ancestral_discrete.py │ ├── scheduling_euler_discrete.py │ ├── scheduling_euler_discrete_flax.py │ ├── scheduling_heun_discrete.py │ ├── scheduling_ipndm.py │ ├── scheduling_k_dpm_2_ancestral_discrete.py │ ├── scheduling_k_dpm_2_discrete.py │ ├── scheduling_karras_ve_flax.py │ ├── scheduling_lcm.py │ ├── scheduling_lms_discrete.py │ ├── scheduling_lms_discrete_flax.py │ ├── scheduling_pndm.py │ ├── scheduling_pndm_flax.py │ ├── scheduling_repaint.py │ ├── scheduling_sde_ve.py │ ├── scheduling_sde_ve_flax.py │ ├── scheduling_unclip.py │ ├── scheduling_unipc_multistep.py │ ├── scheduling_utils.py │ ├── scheduling_utils_flax.py │ └── scheduling_vq_diffusion.py │ ├── training_utils.py │ └── utils │ ├── __init__.py │ ├── accelerate_utils.py │ ├── constants.py │ ├── deprecation_utils.py │ ├── doc_utils.py │ ├── dummy_flax_and_transformers_objects.py │ ├── dummy_flax_objects.py │ ├── dummy_note_seq_objects.py │ ├── dummy_onnx_objects.py │ ├── dummy_pt_objects.py │ ├── dummy_torch_and_librosa_objects.py │ ├── dummy_torch_and_scipy_objects.py │ ├── dummy_torch_and_torchsde_objects.py │ ├── dummy_torch_and_transformers_and_k_diffusion_objects.py │ ├── dummy_torch_and_transformers_and_onnx_objects.py │ ├── dummy_torch_and_transformers_objects.py │ ├── dummy_transformers_and_torch_and_note_seq_objects.py │ ├── dynamic_modules_utils.py │ ├── export_utils.py │ ├── hub_utils.py │ ├── import_utils.py │ ├── loading_utils.py │ ├── logging.py │ ├── model_card_template.md │ ├── outputs.py │ ├── peft_utils.py │ ├── pil_utils.py │ ├── state_dict_utils.py │ ├── testing_utils.py │ ├── torch_utils.py │ └── 
versions.py ├── train.py ├── train_debug.sh └── utils ├── __init__.py ├── builder.py ├── checkpoint_ceph.py ├── distributedsample.py ├── logger.py ├── metrics.py ├── misc.py ├── optim.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /train_job 3 | /tools 4 | /wandb 5 | /None 6 | /diffusion_noise 7 | /preprocess 8 | /wandb 9 | **/__pycache__ 10 | **/valid_options.yaml 11 | 12 | 13 | *.pyc 14 | *.pth 15 | *.pkl 16 | *.pt 17 | *.npy 18 | -------------------------------------------------------------------------------- /assets/fig5_sevir_420_advance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/CasCast/487a68b5ade9aa829fe7df2e8f6746b4d9acc233/assets/fig5_sevir_420_advance.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/CasCast/487a68b5ade9aa829fe7df2e8f6746b4d9acc233/assets/pipeline.png -------------------------------------------------------------------------------- /configs/sevir_used/EarthFormer.yaml: -------------------------------------------------------------------------------- 1 | sevir_used: &sevir 2 | type: sevir 3 | input_length: &input_length 13 4 | pred_length: &pred_length 12 5 | total_length: &total_length 25 6 | base_freq: 5min 7 | data_dir: path/to/sevir 8 | 9 | dataset: 10 | train: 11 | <<: *sevir 12 | 13 | valid: 14 | <<: *sevir 15 | 16 | sampler: 17 | type: DistributedSampler 18 | 19 | dataloader: 20 | num_workers: 8 21 | pin_memory: False 22 | prefetch_factor: 2 23 | persistent_workers: True 24 | 25 | trainer: 26 | batch_size: 8 # to check 27 | valid_batch_size: 16 28 | max_epoch: &max_epoch 1 29 | max_step: 100000 30 | 31 | model: 32 | type: non_ar_model 33 | params: 34 | sub_model: 35 | EarthFormer_xy: 36 | in_len: 13 37 | out_len: 12 38 | height: 384 39 | width: 384 40 | 41 | save_best: &loss_type MSE 42 | use_ceph: False 43 | ceph_checkpoint_path: "mpas:s3://sevir/checkpoint" 44 | metrics_type: SEVIRSkillScore 45 | data_type: fp32 46 | 47 | visualizer: 48 | visualizer_type: sevir_visualizer 49 | visualizer_step: 8000 50 | 51 | optimizer: 52 | EarthFormer_xy: 53 | type: AdamW 54 | params: 55 | lr: 0.001 56 | betas: [0.9, 0.999] 57 | weight_decay: 0.00001 58 | # eps: 0.000001 59 | 60 | lr_scheduler: 61 | EarthFormer_xy: 62 | by_step: True 63 | sched: cosine 64 | epochs: *max_epoch 65 | min_lr: 0.00001 66 | warmup_lr: 0.00001 67 | warmup_epochs: 0.1 68 | lr_noise: 69 | cooldown_epochs: 0 70 | 71 | extra_params: 72 | loss_type: MSELoss 73 | enabled_amp: False 74 | log_step: 20 75 | z_score_delta: False 76 | 77 | wandb: 78 | project_name: sevir -------------------------------------------------------------------------------- /configs/sevir_used/autoencoder_kl_gan.yaml: -------------------------------------------------------------------------------- 1 | sevir: &sevir 2 | type: sevir_pretrain 3 | 4 | dataset: 5 | train: 6 | <<: *sevir 7 | 8 | valid: 9 | <<: *sevir 10 | 11 | sampler: 12 | type: TrainingSampler 13 | 14 | dataloader: 15 | num_workers: 8 16 | pin_memory: False 17 | prefetch_factor: 2 18 | persistent_workers: True 19 | 20 | trainer: 21 | batch_size: 8 # to check 22 | valid_batch_size: 16 23 | max_epoch: &max_epoch 1 24 | max_step: 200000 25 | 26 | model: 27 | type: autoencoder_kl_gan_model 28 | params: 29 | sub_model: 30 | 
autoencoder_kl: 31 | in_channels: 1 32 | out_channels: 1 33 | down_block_types: ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'] 34 | up_block_types: ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'] 35 | block_out_channels: [128, 256, 512, 512] 36 | layers_per_block: 2 37 | latent_channels: 4 38 | norm_num_groups: 32 39 | 40 | lpipsWithDisc: 41 | disc_start: 25001 ## not default 42 | logvar_init: 0.0 43 | kl_weight: 0.000001 44 | pixelloss_weight: 1.0 45 | disc_num_layers: 3 46 | disc_in_channels: 1 47 | disc_factor: 1.0 48 | disc_weight: 0.5 49 | perceptual_weight: 0.0 50 | 51 | save_best: &loss_type MSE 52 | use_ceph: False 53 | ceph_checkpoint_path: "mpas:s3://sevir/checkpoint" 54 | metrics_type: SEVIRSkillScore 55 | data_type: fp32 56 | 57 | visualizer: 58 | visualizer_type: sevir_visualizer 59 | visualizer_step: 1000 60 | 61 | optimizer: 62 | autoencoder_kl: 63 | type: AdamW 64 | params: 65 | lr: 0.0001 66 | betas: [0.9, 0.999] 67 | weight_decay: 0.00001 68 | # eps: 0.000001 69 | 70 | lpipsWithDisc: 71 | type: AdamW 72 | params: 73 | lr: 0.0001 74 | betas: [0.9, 0.999] 75 | weight_decay: 0.00001 76 | 77 | 78 | lr_scheduler: 79 | autoencoder_kl: 80 | by_step: True 81 | sched: cosine 82 | epochs: *max_epoch 83 | min_lr: 0.000001 84 | warmup_lr: 0.000001 85 | warmup_epochs: 0.1 86 | lr_noise: 87 | cooldown_epochs: 0 88 | 89 | lpipsWithDisc: 90 | by_step: True 91 | sched: cosine 92 | epochs: *max_epoch 93 | min_lr: 0.000001 94 | warmup_lr: 0.000001 95 | warmup_epochs: 0.1 96 | lr_noise: 97 | cooldown_epochs: 0 98 | 99 | extra_params: 100 | loss_type: MSELoss 101 | enabled_amp: False 102 | log_step: 20 103 | z_score_delta: False 104 | # checkpoint_path: EarthFormer_xy/world_size1-xytest/checkpoint_latest.pth ## for pretrained advective predictor 105 | 106 | # wandb: 107 | # project_name: sevir -------------------------------------------------------------------------------- /configs/sevir_used/cascast_diffusion.yaml: -------------------------------------------------------------------------------- 1 | sevir: &sevir 2 | type: sevir_latent 3 | input_length: &input_length 13 4 | pred_length: &pred_length 12 5 | total_length: &total_length 25 6 | base_freq: 5min 7 | data_dir: radar:s3://sevir_latent ## path/to/sevir 8 | latent_gt_dir: path/to/latent_gt 9 | latent_deterministic_dir: path/to/latent_prediction 10 | 11 | latent_size: 48x48x4 12 | 13 | dataset: 14 | train: 15 | <<: *sevir 16 | 17 | valid: 18 | <<: *sevir 19 | 20 | sampler: 21 | type: TrainingSampler 22 | 23 | dataloader: 24 | num_workers: 8 25 | pin_memory: False 26 | prefetch_factor: 2 27 | persistent_workers: True 28 | 29 | trainer: 30 | batch_size: 8 # to check 31 | valid_batch_size: 12 32 | max_epoch: &max_epoch 1 33 | max_step: 100000 34 | 35 | model: 36 | type: latent_diffusion_model 37 | params: 38 | diffusion_kwargs: 39 | noise_scheduler: 40 | DDPMScheduler: 41 | num_train_timesteps: &num_classes 1000 42 | beta_start: &sigma_start 0.0001 43 | beta_end: &sigma_end 0.02 44 | beta_schedule: &sigma_dist linear 45 | clip_sample_range: 13 46 | prediction_type: epsilon 47 | classifier_free_guidance: 48 | p_uncond: 0.1 49 | guidance_weight: 1 ## TODO 50 | 51 | sub_model: 52 | casformer: 53 | arch: DiT-custom 54 | config: 55 | input_size: 48 56 | in_channels: 8 57 | mlp_ratio: 4.0 58 | learn_sigma: False 59 | out_channels: 48 60 | split_num: 12 61 | num_heads: 16 62 | single_heads_num: 4 63 | hidden_size: 1152 64 | enc_hidden_size: 256 65 | patch_size: 2 66 | 
enc_depth: 12 67 | latent_depth: 12 68 | 69 | autoencoder_kl: 70 | in_channels: 1 71 | out_channels: 1 72 | down_block_types: ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'] 73 | up_block_types: ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'] 74 | block_out_channels: [128, 256, 512, 512] 75 | layers_per_block: 2 76 | latent_channels: 4 77 | norm_num_groups: 32 78 | 79 | 80 | save_best: &loss_type MSE 81 | use_ceph: False 82 | ceph_checkpoint_path: "mpas:s3://sevir/checkpoint" 83 | metrics_type: None 84 | data_type: fp32 85 | 86 | visualizer: 87 | visualizer_type: sevir_visualizer 88 | visualizer_step: 4000 89 | 90 | optimizer: 91 | casformer: 92 | type: AdamW 93 | params: 94 | lr: 0.0005 95 | betas: [0.9, 0.95] 96 | # eps: 0.000001 97 | 98 | lr_scheduler: 99 | casformer: 100 | by_step: True 101 | sched: cosine 102 | epochs: *max_epoch 103 | min_lr: 0.00001 104 | warmup_lr: 0.00001 105 | warmup_epochs: 0.1 106 | lr_noise: 107 | cooldown_epochs: 0 108 | 109 | extra_params: 110 | loss_type: MSELoss 111 | enabled_amp: False 112 | log_step: 20 113 | predictor_checkpoint_path: None ## for pretrained advective predictor 114 | autoencoder_checkpoint_path: ckpts/autoencoder/ckpt.pth ## for pretrained autoencoder 115 | save_epoch_interval: 20 116 | 117 | wandb: 118 | project_name: sevir -------------------------------------------------------------------------------- /configs/sevir_used/compress_earthformer.yaml: -------------------------------------------------------------------------------- 1 | sevir: &sevir 2 | type: sevir_preprocess 3 | input_length: &input_length 13 4 | pred_length: &pred_length 12 5 | total_length: &total_length 25 6 | base_freq: 5min 7 | data_dir: radar:s3://weather_radar_datasets/sevir 8 | 9 | dataset: 10 | train: 11 | <<: *sevir 12 | 13 | valid: 14 | <<: *sevir 15 | 16 | test: 17 | <<: *sevir 18 | 19 | sampler: 20 | type: DistributedSampler 21 | 22 | dataloader: 23 | num_workers: 8 24 | pin_memory: False 25 | prefetch_factor: 2 26 | persistent_workers: True 27 | drop_last: False 28 | 29 | trainer: 30 | batch_size: 3 # to check 31 | valid_batch_size: 3 32 | test_batch_size: 3 33 | max_epoch: &max_epoch 1 34 | max_step: 100000 35 | 36 | model: 37 | type: latent_compress_model 38 | params: 39 | latent_size: 48x48x4 40 | model_name: earthformer 41 | latent_data_save_dir: latent_data 42 | 43 | sub_model: 44 | autoencoder_kl: 45 | in_channels: 1 46 | out_channels: 1 47 | down_block_types: ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'] 48 | up_block_types: ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'] 49 | block_out_channels: [128, 256, 512, 512] 50 | layers_per_block: 2 51 | latent_channels: 4 52 | norm_num_groups: 32 53 | 54 | EarthFormer_xy: 55 | in_len: 13 56 | out_len: 12 57 | height: 384 58 | width: 384 59 | 60 | 61 | save_best: &loss_type MSE 62 | use_ceph: False 63 | ceph_checkpoint_path: "mpas:s3://sevir/checkpoint" 64 | metrics_type: SEVIRSkillScore 65 | data_type: fp32 66 | 67 | visualizer: 68 | visualizer_type: sevir_visualizer 69 | visualizer_step: 1000 70 | 71 | optimizer: 72 | autoencoder_kl: 73 | type: AdamW 74 | params: 75 | lr: 0.001 76 | betas: [0.9, 0.95] 77 | # eps: 0.000001 78 | 79 | lr_scheduler: 80 | autoencoder_kl: 81 | by_step: False 82 | sched: cosine 83 | epochs: *max_epoch 84 | min_lr: 0.00001 85 | warmup_lr: 0.00001 86 | warmup_epochs: 0.1 87 | lr_noise: 88 | cooldown_epochs: 0 89 | 90 | extra_params: 91 
| loss_type: MSELoss 92 | enabled_amp: False 93 | log_step: 20 94 | predictor_checkpoint_path: ckpts/earthformer/ckpt.pth ## for pretrained advective predictor 95 | autoencoder_checkpoint_path: ckpts/autoencoder/ckpt.pth ## for pretrained autoencoder -------------------------------------------------------------------------------- /configs/sevir_used/compress_gt.yaml: -------------------------------------------------------------------------------- 1 | sevir: &sevir 2 | type: sevir_preprocess 3 | input_length: &input_length 13 4 | pred_length: &pred_length 12 5 | total_length: &total_length 25 6 | base_freq: 5min 7 | data_dir: radar:s3://weather_radar_datasets/sevir 8 | 9 | dataset: 10 | train: 11 | <<: *sevir 12 | 13 | valid: 14 | <<: *sevir 15 | 16 | test: 17 | <<: *sevir 18 | 19 | sampler: 20 | type: DistributedSampler 21 | 22 | dataloader: 23 | num_workers: 8 24 | pin_memory: False 25 | prefetch_factor: 2 26 | persistent_workers: True 27 | drop_last: False 28 | 29 | trainer: 30 | batch_size: 3 # to check 31 | valid_batch_size: 3 32 | test_batch_size: 3 33 | max_epoch: &max_epoch 1 34 | max_step: 100000 35 | 36 | model: 37 | type: latent_compress_model 38 | params: 39 | latent_size: 48x48x4 40 | model_name: gt 41 | latent_data_save_dir: latent_data 42 | 43 | sub_model: 44 | autoencoder_kl: 45 | in_channels: 1 46 | out_channels: 1 47 | down_block_types: ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'] 48 | up_block_types: ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'] 49 | block_out_channels: [128, 256, 512, 512] 50 | layers_per_block: 2 51 | latent_channels: 4 52 | norm_num_groups: 32 53 | 54 | 55 | save_best: &loss_type MSE 56 | use_ceph: False 57 | ceph_checkpoint_path: "mpas:s3://sevir/checkpoint" 58 | metrics_type: SEVIRSkillScore 59 | data_type: fp32 60 | 61 | visualizer: 62 | visualizer_type: sevir_visualizer 63 | visualizer_step: 1000 64 | 65 | optimizer: 66 | autoencoder_kl: 67 | type: AdamW 68 | params: 69 | lr: 0.001 70 | betas: [0.9, 0.95] 71 | # eps: 0.000001 72 | 73 | lr_scheduler: 74 | autoencoder_kl: 75 | by_step: False 76 | sched: cosine 77 | epochs: *max_epoch 78 | min_lr: 0.00001 79 | warmup_lr: 0.00001 80 | warmup_epochs: 0.1 81 | lr_noise: 82 | cooldown_epochs: 0 83 | 84 | extra_params: 85 | loss_type: MSELoss 86 | enabled_amp: False 87 | log_step: 20 88 | predictor_checkpoint_path: None #EarthFormer_xy/world_size1-xytest/checkpoint_latest.pth ## for pretrained advective predictor 89 | autoencoder_checkpoint_path: ckpts/autoencoder/ckpt.pth ## for pretrained autoencoder -------------------------------------------------------------------------------- /datasets/hko7/hko7_list/txt2pkl.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | txt_path = '/mnt/cache/gongjunchao/workdir/radar_forecasting/datasets/hko7_list/hko7_full_list.txt' 4 | 5 | with open(txt_path, 'r', encoding='utf-8') as file: 6 | lines = file.readlines() 7 | 8 | lines_list = [line.strip() for line in lines] 9 | 10 | 11 | with open('/mnt/cache/gongjunchao/workdir/radar_forecasting/datasets/hko7_list/hko7_full_list.pkl', 'wb') as pkl_file: 12 | pickle.dump(lines_list, pkl_file) 13 | -------------------------------------------------------------------------------- /datasets/meteonet/meteonet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import 
torch 5 | try: 6 | from petrel_client.client import Client 7 | except ImportError: 8 | pass 9 | import io 10 | 11 | 12 | 13 | class meteonet_24(Dataset): 14 | def __init__(self, split, input_length=12, pred_length=12, data_dir='radar:s3://meteonet_data/24Frames', **kwargs): 15 | super().__init__() 16 | assert input_length == 12 and pred_length == 12 17 | self.input_length = 12 18 | self.pred_length = 12 19 | self.total_length = self.input_length + self.pred_length 20 | 21 | self.file_list = self._init_file_list(split) 22 | 23 | self.data_dir = data_dir 24 | self.client = Client("~/petreloss.conf") 25 | 26 | def _init_file_list(self, split): 27 | if split == 'train': 28 | txt_path = '/mnt/cache/gongjunchao/workdir/radar_forecasting/datasets/meteonet/train_2h.txt' 29 | elif split == 'valid': 30 | txt_path = '/mnt/cache/gongjunchao/workdir/radar_forecasting/datasets/meteonet/valid_2h.txt' 31 | elif split == 'test': 32 | txt_path = '/mnt/cache/gongjunchao/workdir/radar_forecasting/datasets/meteonet/test_2h.txt' 33 | files = [] 34 | with open(f'{txt_path}', 'r') as file: 35 | for line in file.readlines(): 36 | files.append(line.strip()) 37 | return files 38 | 39 | def __len__(self): 40 | return len(self.file_list) 41 | 42 | def _load_frames(self, file): 43 | file_path = os.path.join(self.data_dir, file) 44 | with io.BytesIO(self.client.get(file_path)) as f: 45 | frame_data = np.load(f) 46 | tensor = torch.from_numpy(frame_data) / 70 ##TODO: get max 47 | ## t, h, w -> t, c, h, w 48 | tensor = tensor.unsqueeze(dim=1) 49 | return tensor 50 | 51 | def __getitem__(self, index): 52 | file = self.file_list[index] 53 | frame_data = self._load_frames(file) 54 | packed_results = dict() 55 | packed_results['inputs'] = frame_data[:self.input_length] 56 | packed_results['data_samples'] = frame_data[self.input_length:self.input_length+self.pred_length] 57 | return packed_results 58 | 59 | 60 | 61 | if __name__ == "__main__": 62 | dataset = meteonet_24(split='train', input_length=12, pred_length=12, data_dir='radar:s3://meteonet_data/24Frames') 63 | print(len(dataset)) 64 | 65 | import time 66 | st_time = time.time() 67 | _max = 0 68 | for i in range(len(dataset)): 69 | data = dataset.__getitem__(i) 70 | ed_time = time.time() 71 | _max = max(data['inputs'].max(), _max, data['data_samples'].max()) 72 | 73 | print((ed_time - st_time)/(i+1)) 74 | print(data['inputs'].shape) 75 | print(data['data_samples'].shape) 76 | print(_max) 77 | 78 | ### srun -p ai4earth --kill-on-bad-exit=1 --quotatype=reserved --gres=gpu:0 python -u meteonet.py ### 79 | 80 | -------------------------------------------------------------------------------- /datasets/sevir_diffusion_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import torch 5 | try: 6 | from petrel_client.client import Client 7 | except ImportError: 8 | pass 9 | import io 10 | 11 | 12 | def get_sevir_latent_dataset(split, input_length=13, pred_length=12, data_dir='/mnt/data/oss_beijing/video_prediction_dataset/sevir/sevir', base_freq='5min', height=384, width=384, **kwargs): 13 | return sevir_latent(split, input_length=input_length, pred_length=pred_length, data_dir=data_dir, base_freq=base_freq, height=height, width=width, **kwargs) 14 | 15 | 16 | 17 | class sevir_latent(Dataset): 18 | def __init__(self, split, input_length=13, pred_length=12, base_freq='5min', height=384, width=384, 19 | latent_size='48x48x4', 20 | data_dir='None', latent_gt_dir='None',
latent_deterministic_dir='None', 21 | **kwargs): 22 | super().__init__() 23 | assert input_length == 13 and pred_length == 12 24 | self.input_length = 13 25 | self.pred_length = 12 26 | 27 | self.latent_size = latent_size 28 | 29 | self.file_list = self._init_file_list(split) 30 | self.data_dir = os.path.join(data_dir, f'{split}_2h') 31 | self.latent_gt_dir = os.path.join(latent_gt_dir, f'{split}_2h') 32 | self.latent_deterministic_dir = os.path.join(latent_deterministic_dir, f'{split}_2h') 33 | 34 | def _init_file_list(self, split): 35 | if split == 'train': 36 | txt_path = 'datasets/sevir_list/train.txt' 37 | elif split == 'valid': 38 | txt_path = 'datasets/sevir_list/val.txt' 39 | elif split == 'test': 40 | txt_path = 'datasets/sevir_list/test.txt' 41 | files = [] 42 | with open(f'{txt_path}', 'r') as file: 43 | for line in file.readlines(): 44 | files.append(line.strip()) 45 | return files 46 | 47 | def __len__(self): 48 | return len(self.file_list) 49 | 50 | def _load_latent_frames(self, file, datasource): 51 | file_path = os.path.join(datasource, file) 52 | frame_data = np.load(file_path) 53 | ## t, c, h, w ## 54 | tensor = torch.from_numpy(frame_data) 55 | return tensor 56 | 57 | def _load_frames(self, file): 58 | file_path = os.path.join(self.data_dir, file) 59 | frame_data = np.load(file_path) 60 | tensor = torch.from_numpy(frame_data) / 255 61 | ## 1, h, w, t -> t, c, h, w 62 | tensor = tensor.permute(3, 0, 1, 2) 63 | return tensor 64 | 65 | def __getitem__(self, index): 66 | file = self.file_list[index] 67 | coarse_latent_data = self._load_latent_frames(file, datasource=self.latent_deterministic_dir) 68 | packed_results = dict() 69 | packed_results['inputs'] = coarse_latent_data 70 | return packed_results -------------------------------------------------------------------------------- /datasets/sevir_latent_used.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import torch 5 | try: 6 | from petrel_client.client import Client 7 | except ImportError: 8 | pass 9 | import io 10 | 11 | 12 | def get_sevir_latent_dataset(split, input_length=13, pred_length=12, data_dir='/mnt/data/oss_beijing/video_prediction_dataset/sevir/sevir', base_freq='5min', height=384, width=384, **kwargs): 13 | return sevir_latent(split, input_length=input_length, pred_length=pred_length, data_dir=data_dir, base_freq=base_freq, height=height, width=width, **kwargs) 14 | 15 | 16 | 17 | class sevir_latent(Dataset): 18 | def __init__(self, split, input_length=13, pred_length=12, base_freq='5min', height=384, width=384, 19 | latent_size='48x48x4', 20 | data_dir='None', latent_gt_dir='None', latent_deterministic_dir='None', 21 | **kwargs): 22 | super().__init__() 23 | assert input_length == 13 and pred_length == 12 24 | self.input_length = 13 25 | self.pred_length = 12 26 | 27 | self.latent_size = latent_size 28 | 29 | self.file_list = self._init_file_list(split) 30 | self.data_dir = os.path.join(data_dir, f'{split}_2h') 31 | self.latent_gt_dir = os.path.join(latent_gt_dir, f'{split}_2h') 32 | self.latent_deterministic_dir = os.path.join(latent_deterministic_dir, f'{split}_2h') 33 | 34 | def _init_file_list(self, split): 35 | if split == 'train': 36 | txt_path = 'datasets/sevir_list/train.txt' 37 | elif split == 'valid': 38 | txt_path = 'datasets/sevir_list/val.txt' 39 | elif split == 'test': 40 | txt_path = 'datasets/sevir_list/test.txt' 41 | files = [] 42 | with open(f'{txt_path}', 'r') as file: 43 | for
line in file.readlines(): 44 | files.append(line.strip()) 45 | return files 46 | 47 | def __len__(self): 48 | return len(self.file_list) 49 | 50 | def _load_latent_frames(self, file, datasource): 51 | file_path = os.path.join(datasource, file) 52 | frame_data = np.load(file_path) 53 | ## t, c, h, w ## 54 | tensor = torch.from_numpy(frame_data) 55 | return tensor 56 | 57 | def _load_frames(self, file): 58 | file_path = os.path.join(self.data_dir, file) 59 | frame_data = np.load(file_path) 60 | tensor = torch.from_numpy(frame_data) / 255 61 | ## 1, h, w, t -> t, c, h, w 62 | tensor = tensor.permute(3, 0, 1, 2) 63 | return tensor 64 | 65 | def __getitem__(self, index): 66 | file = self.file_list[index] 67 | gt_data = self._load_frames(file)[self.input_length:] 68 | gt_latent_data = self._load_latent_frames(file, datasource=self.latent_gt_dir) 69 | coarse_latent_data = self._load_latent_frames(file, datasource=self.latent_deterministic_dir) 70 | packed_results = dict() 71 | packed_results['inputs'] = coarse_latent_data 72 | packed_results['data_samples'] = {'latent':gt_latent_data, 'original':gt_data} 73 | return packed_results -------------------------------------------------------------------------------- /datasets/sevir_preprocess_used.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import torch 5 | try: 6 | from petrel_client.client import Client 7 | except ImportError: 8 | pass 9 | import io 10 | 11 | 12 | 13 | # SEVIR dataset factory 14 | def get_sevir_dataset(split, input_length=13, pred_length=12, data_dir='/mnt/data/oss_beijing/video_prediction_dataset/sevir/sevir', base_freq='5min', height=384, width=384, **kwargs): 15 | return sevir_preprocess(split, input_length=input_length, pred_length=pred_length, data_dir=data_dir, base_freq=base_freq, height=height, width=width, **kwargs) 16 | 17 | 18 | 19 | class sevir_preprocess(Dataset): 20 | def __init__(self, split, input_length=13, pred_length=12, data_dir='radar:s3://weather_radar_datasets/sevir', base_freq='5min', height=384, width=384, **kwargs): 21 | super().__init__() 22 | assert input_length == 13 and pred_length == 12 23 | self.input_length = 13 24 | self.pred_length = 12 25 | 26 | self.file_list = self._init_file_list(split) 27 | self.data_dir = os.path.join(data_dir, f'{split}_2h') 28 | 29 | 30 | 31 | def _init_file_list(self, split): 32 | if split == 'train': 33 | txt_path = 'datasets/sevir_list/train.txt' 34 | elif split == 'valid': 35 | txt_path = 'datasets/sevir_list/val.txt' 36 | elif split == 'test': 37 | txt_path = 'datasets/sevir_list/test.txt' 38 | files = [] 39 | with open(f'{txt_path}', 'r') as file: 40 | for line in file.readlines(): 41 | files.append(line.strip()) 42 | return files 43 | 44 | def __len__(self): 45 | return len(self.file_list) 46 | 47 | def _load_frames(self, file): 48 | file_path = os.path.join(self.data_dir, file) 49 | frame_data = np.load(file_path) 50 | tensor = torch.from_numpy(frame_data) / 255 51 | ## 1, h, w, t -> t, c, h, w 52 | tensor = tensor.permute(3, 0, 1, 2) 53 | return tensor 54 | 55 | 56 | def __getitem__(self, index): 57 | file = self.file_list[index] 58 | frame_data = self._load_frames(file) 59 | packed_results = dict() 60 | packed_results['inputs'] = frame_data[:self.input_length] 61 | packed_results['data_samples'] = frame_data[self.input_length:self.input_length+self.pred_length] 62 | packed_results['file_name'] = file 63 | return packed_results
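As a quick sanity check, the SEVIR dataset variants above can be exercised directly with a `DataLoader`. A minimal sketch, assuming preprocessed `.npy` clips sit under a local `path/to/sevir`; the loader settings mirror `configs/sevir_used/compress_gt.yaml`, and the printed shapes follow from `_load_frames` above (illustrative, not part of the repository):

```python
from torch.utils.data import DataLoader

from datasets.sevir_preprocess_used import sevir_preprocess

# Illustrative smoke test; batch_size=3, num_workers=8, pin_memory=False mirror compress_gt.yaml.
dataset = sevir_preprocess(split='train', data_dir='path/to/sevir')
loader = DataLoader(dataset, batch_size=3, num_workers=8, pin_memory=False)

batch = next(iter(loader))
print(batch['inputs'].shape)        # torch.Size([3, 13, 1, 384, 384]): b, t, c, h, w
print(batch['data_samples'].shape)  # torch.Size([3, 12, 1, 384, 384])
print(batch['file_name'][:3])       # default collation keeps file names as a list of str
```

Among the variants, `sevir_preprocess` is the only one that returns `file_name`, which the latent-compression stage presumably uses to write one latent file per clip under `latent_data_save_dir`.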
-------------------------------------------------------------------------------- /datasets/sevir_pretrain_used.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import torch 5 | import random 6 | try: 7 | from petrel_client.client import Client 8 | except ImportError: 9 | pass 10 | import io 11 | 12 | 13 | 14 | class sevir_pretrain(Dataset): 15 | def __init__(self, split, input_length=13, pred_length=12, data_dir='path/to/sevir', base_freq='5min', height=384, width=384, **kwargs): 16 | super().__init__() 17 | assert input_length == 13 and pred_length == 12 18 | self.input_length = 13 19 | self.pred_length = 12 20 | self.total_length = self.input_length + self.pred_length 21 | 22 | self.file_list = self._init_file_list(split) 23 | self.data_dir = os.path.join(data_dir, f'{split}_2h') 24 | 25 | 26 | def _init_file_list(self, split): 27 | if split == 'train': 28 | txt_path = 'datasets/sevir_list/train.txt' 29 | elif split == 'valid': 30 | txt_path = 'datasets/sevir_list/val.txt' 31 | elif split == 'test': 32 | txt_path = 'datasets/sevir_list/test.txt' 33 | files = [] 34 | with open(f'{txt_path}', 'r') as file: 35 | for line in file.readlines(): 36 | files.append(line.strip()) 37 | return files 38 | 39 | def __len__(self): 40 | return len(self.file_list) 41 | 42 | def _load_frames(self, file): 43 | file_path = os.path.join(self.data_dir, file) 44 | frame_data = np.load(file_path) 45 | tensor = torch.from_numpy(frame_data) / 255 46 | ## 1, h, w, t -> t, c, h, w 47 | tensor = tensor.permute(3, 0, 1, 2) 48 | return tensor 49 | 50 | 51 | def __getitem__(self, index): 52 | file = self.file_list[index] 53 | frame_data = self._load_frames(file) 54 | packed_results = dict() 55 | 56 | select_frame_idx = random.randint(0, self.total_length - 1) 57 | packed_results['inputs'] = frame_data[select_frame_idx] 58 | packed_results['data_samples'] = frame_data[select_frame_idx] 59 | return packed_results -------------------------------------------------------------------------------- /datasets/sevir_used.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import torch 5 | try: 6 | from petrel_client.client import Client 7 | except ImportError: 8 | pass 9 | import io 10 | 11 | 12 | 13 | def get_sevir_dataset(split, input_length=13, pred_length=12, data_dir='/mnt/data/oss_beijing/video_prediction_dataset/sevir/sevir', base_freq='5min', height=384, width=384, **kwargs): 14 | return sevir(split, input_length=input_length, pred_length=pred_length, data_dir=data_dir, base_freq=base_freq, height=height, width=width, **kwargs) 15 | 16 | 17 | 18 | class sevir(Dataset): 19 | def __init__(self, split, input_length=13, pred_length=12, data_dir='path/to/sevir', base_freq='5min', height=384, width=384, **kwargs): 20 | super().__init__() 21 | assert input_length == 13 and pred_length == 12 22 | self.input_length = 13 23 | self.pred_length = 12 24 | 25 | self.file_list = self._init_file_list(split) 26 | self.data_dir = os.path.join(data_dir, f'{split}_2h') 27 | 28 | 29 | def _init_file_list(self, split): 30 | if split == 'train': 31 | txt_path = 'datasets/sevir_list/train.txt' 32 | elif split == 'valid': 33 | txt_path = 'datasets/sevir_list/val.txt' 34 | elif split == 'test': 35 | txt_path = 'datasets/sevir_list/test.txt' 36 | files = [] 37 | with open(f'{txt_path}', 'r') as file: 38 | for line in
file.readlines(): 39 | files.append(line.strip()) 40 | return files 41 | 42 | def __len__(self): 43 | return len(self.file_list) 44 | 45 | def _load_frames(self, file): 46 | file_path = os.path.join(self.data_dir, file) 47 | frame_data = np.load(file_path) 48 | tensor = torch.from_numpy(frame_data) / 255 49 | ## 1, h, w, t -> t, c, h, w 50 | tensor = tensor.permute(3, 0, 1, 2) 51 | return tensor 52 | 53 | 54 | def __getitem__(self, index): 55 | file = self.file_list[index] 56 | frame_data = self._load_frames(file) 57 | packed_results = dict() 58 | packed_results['inputs'] = frame_data[:self.input_length] 59 | packed_results['data_samples'] = frame_data[self.input_length:self.input_length+self.pred_length] 60 | return packed_results 61 | -------------------------------------------------------------------------------- /experiments/EarthFormer/world_size1-ckpt/training_options.yaml: -------------------------------------------------------------------------------- 1 | tensor_model_parallel_size: 1 2 | resume: false 3 | resume_from_config: false 4 | seed: 0 5 | cuda: 0 6 | world_size: 1 7 | per_cpus: 4 8 | local_rank: 0 9 | init_method: tcp://127.0.0.1:34182 10 | outdir: /mnt/lustre/gongjunchao/release_code/cascast/experiments/EarthFormer 11 | cfg: ./configs/sevir_used/EarthFormer.yaml 12 | desc: debug 13 | visual_vars: null 14 | debug: true 15 | resume_checkpoint: null 16 | resume_cfg_file: null 17 | rank: 0 18 | distributed: false 19 | relative_checkpoint_dir: EarthFormer/world_size1-debug 20 | sevir: 21 | type: sevir 22 | input_length: 13 23 | pred_length: 12 24 | total_length: 25 25 | base_freq: 5min 26 | data_dir: pixel_data/sevir #path/to/sevir 27 | dataset: 28 | train: 29 | type: sevir 30 | input_length: 13 31 | pred_length: 12 32 | total_length: 25 33 | base_freq: 5min 34 | data_dir: pixel_data/sevir #path/to/sevir 35 | valid: 36 | type: sevir 37 | input_length: 13 38 | pred_length: 12 39 | total_length: 25 40 | base_freq: 5min 41 | data_dir: pixel_data/sevir #path/to/sevir 42 | sampler: 43 | type: DistributedSampler 44 | dataloader: 45 | num_workers: 8 46 | pin_memory: false 47 | prefetch_factor: 2 48 | persistent_workers: true 49 | trainer: 50 | batch_size: 8 51 | valid_batch_size: 16 52 | max_epoch: 1 53 | max_step: 100000 54 | model: 55 | type: non_ar_model 56 | params: 57 | sub_model: 58 | EarthFormer_xy: 59 | in_len: 13 60 | out_len: 12 61 | height: 384 62 | width: 384 63 | save_best: MSE 64 | use_ceph: false 65 | ceph_checkpoint_path: mpas:s3://sevir/checkpoint 66 | metrics_type: SEVIRSkillScore 67 | data_type: fp32 68 | visualizer: 69 | visualizer_type: sevir_visualizer 70 | visualizer_step: 8000 71 | optimizer: 72 | EarthFormer_xy: 73 | type: AdamW 74 | params: 75 | lr: 0.001 76 | betas: 77 | - 0.9 78 | - 0.999 79 | weight_decay: 1.0e-05 80 | lr_scheduler: 81 | EarthFormer_xy: 82 | by_step: true 83 | sched: cosine 84 | epochs: 1 85 | min_lr: 1.0e-05 86 | warmup_lr: 1.0e-05 87 | warmup_epochs: 0.1 88 | lr_noise: null 89 | cooldown_epochs: 0 90 | extra_params: 91 | loss_type: MSELoss 92 | enabled_amp: false 93 | log_step: 20 94 | z_score_delta: false 95 | wandb: 96 | project_name: sevir 97 | -------------------------------------------------------------------------------- /megatron_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import megatron_utils.parallel_state as parallel_state 2 | import megatron_utils.tensor_parallel 3 | import megatron_utils.utils 4 | 5 | # Alias parallel_state as mpu, its legacy name 6 | mpu = 
parallel_state 7 | 8 | __all__ = [ 9 | "parallel_state", 10 | "tensor_parallel", 11 | "utils", 12 | ] -------------------------------------------------------------------------------- /megatron_utils/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import vocab_parallel_cross_entropy 2 | from .data import broadcast_data 3 | 4 | from .layers import ( 5 | ColumnParallelLinear, 6 | RowParallelLinear, 7 | VocabParallelEmbedding, 8 | set_tensor_model_parallel_attributes, 9 | set_defaults_if_not_set_tensor_model_parallel_attributes, 10 | copy_tensor_model_parallel_attributes, 11 | param_is_not_tensor_parallel_duplicate, 12 | linear_with_grad_accumulation_and_async_allreduce 13 | 14 | ) 15 | 16 | from .mappings import ( 17 | copy_to_tensor_model_parallel_region, 18 | gather_from_tensor_model_parallel_region, 19 | gather_from_sequence_parallel_region, 20 | scatter_to_tensor_model_parallel_region, 21 | scatter_to_sequence_parallel_region, 22 | ) 23 | 24 | from .random import ( 25 | checkpoint, 26 | get_cuda_rng_tracker, 27 | model_parallel_cuda_manual_seed, 28 | ) 29 | 30 | from .utils import ( 31 | split_tensor_along_last_dim, 32 | split_tensor_into_1d_equal_chunks, 33 | gather_split_1d_tensor, 34 | ) 35 | 36 | __all__ = [ 37 | # cross_entropy.py 38 | "vocab_parallel_cross_entropy", 39 | # data.py 40 | "broadcast_data", 41 | #layers.py 42 | "ColumnParallelLinear", 43 | "RowParallelLinear", 44 | "VocabParallelEmbedding", 45 | "set_tensor_model_parallel_attributes", 46 | "set_defaults_if_not_set_tensor_model_parallel_attributes", 47 | "copy_tensor_model_parallel_attributes", 48 | "param_is_not_tensor_parallel_duplicate", 49 | "linear_with_grad_accumulation_and_async_allreduce", 50 | # mappings.py 51 | "copy_to_tensor_model_parallel_region", 52 | "gather_from_tensor_model_parallel_region", 53 | "gather_from_sequence_parallel_region", 54 | # "reduce_from_tensor_model_parallel_region", 55 | "scatter_to_tensor_model_parallel_region", 56 | "scatter_to_sequence_parallel_region", 57 | # random.py 58 | "checkpoint", 59 | "get_cuda_rng_tracker", 60 | "model_parallel_cuda_manual_seed", 61 | # utils.py 62 | "split_tensor_along_last_dim", 63 | "split_tensor_into_1d_equal_chunks", 64 | "gather_split_1d_tensor", 65 | ] -------------------------------------------------------------------------------- /networks/autoencoder_kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from networks.prediff.taming.autoencoder_kl import AutoencoderKL 4 | 5 | 6 | class autoencoder_kl(nn.Module): 7 | def __init__(self, config): 8 | super(autoencoder_kl, self).__init__() 9 | self.config = config 10 | self.net = AutoencoderKL(**config) 11 | 12 | def forward(self, sample, sample_posterior=True, return_posterior=True, generator=None): 13 | pred, posterior = self.net(sample, sample_posterior, return_posterior, generator) 14 | out = [pred, posterior] 15 | return out 16 | 17 | 18 | -------------------------------------------------------------------------------- /networks/lpipsWithDisc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from networks.prediff.taming.autoencoder_kl import AutoencoderKL 4 | from networks.prediff.taming.losses.contperceptual import LPIPSWithDiscriminator 5 | from networks.prediff.utils.distributions import DiagonalGaussianDistribution 6 | 7 | class 
lpipsWithDisc(nn.Module): 8 | def __init__(self, config): 9 | super(lpipsWithDisc, self).__init__() 10 | self.config = config 11 | self.net = LPIPSWithDiscriminator(**config) 12 | 13 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, global_step, mask=None, last_layer=None, split='train'): 14 | loss, loss_dict = self.net(inputs=inputs, reconstructions=reconstructions, posteriors=posteriors, 15 | optimizer_idx=optimizer_idx, global_step=global_step, mask=mask, last_layer=last_layer, split=split) 16 | out = [loss, loss_dict] 17 | return out -------------------------------------------------------------------------------- /networks/prediff/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.1.dev' -------------------------------------------------------------------------------- /networks/prediff/models/cuboid_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .cuboid_transformer_unet import CuboidTransformerUNet -------------------------------------------------------------------------------- /networks/prediff/taming/README.md: -------------------------------------------------------------------------------- 1 | # Taming Transformers for High-Resolution Image Synthesis 2 | This subdirectory contains the implementations of `AutoencoderKL` and `LPIPSWithDiscriminator`. 3 | All code is adapted from [taming-transformers](https://github.com/CompVis/taming-transformers), [stable-diffusion](https://github.com/CompVis/stable-diffusion), and [Diffusers](https://huggingface.co/docs/diffusers). 4 | 5 | Alternatively, you can use the implementation of `AutoencoderKL` from [diffusers==0.13.0](https://github.com/huggingface/diffusers/blob/v0.13.0/src/diffusers/models/autoencoder_kl.py) via 6 | ```python 7 | from diffusers.models import AutoencoderKL 8 | ``` 9 | and set `return_dict=False` in the methods `forward`, `encode`, `_decode`, and `decode`.
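To make the README's suggestion concrete, below is a minimal sketch of the diffusers-based swap. It assumes `diffusers==0.13.0`, passes `return_dict=False` per call instead of editing the method defaults, and mirrors the `autoencoder_kl` section of `configs/sevir_used/autoencoder_kl_gan.yaml`:

```python
import torch
from diffusers.models import AutoencoderKL

# Constructor arguments mirror configs/sevir_used/autoencoder_kl_gan.yaml.
vae = AutoencoderKL(
    in_channels=1,
    out_channels=1,
    down_block_types=('DownEncoderBlock2D',) * 4,
    up_block_types=('UpDecoderBlock2D',) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    latent_channels=4,
    norm_num_groups=32,
)

x = torch.randn(2, 1, 384, 384)                  # a toy batch of single-channel radar frames
posterior = vae.encode(x, return_dict=False)[0]  # DiagonalGaussianDistribution
z = posterior.sample()                           # (2, 4, 48, 48)
recon = vae.decode(z, return_dict=False)[0]      # (2, 1, 384, 384)
```

The three 2x downsampling stages between the four encoder blocks give the 384 -> 48 spatial reduction, matching the `latent_size: 48x48x4` used throughout the diffusion configs.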
10 | -------------------------------------------------------------------------------- /networks/prediff/taming/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder_kl import AutoencoderKL 2 | from .losses.contperceptual import LPIPSWithDiscriminator 3 | -------------------------------------------------------------------------------- /networks/prediff/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/CasCast/487a68b5ade9aa829fe7df2e8f6746b4d9acc233/networks/prediff/utils/__init__.py -------------------------------------------------------------------------------- /networks/prediff/utils/distributions.py: -------------------------------------------------------------------------------- 1 | """Code is adapted from https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/distributions/distributions.py""" 2 | from typing import Optional 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class AbstractDistribution: 8 | def sample(self): 9 | raise NotImplementedError() 10 | 11 | def mode(self): 12 | raise NotImplementedError() 13 | 14 | 15 | class DiracDistribution(AbstractDistribution): 16 | def __init__(self, value): 17 | self.value = value 18 | 19 | def sample(self): 20 | return self.value 21 | 22 | def mode(self): 23 | return self.value 24 | 25 | 26 | class DiagonalGaussianDistribution(object): 27 | def __init__(self, parameters, deterministic=False): 28 | self.parameters = parameters 29 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 30 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 31 | self.deterministic = deterministic 32 | self.std = torch.exp(0.5 * self.logvar) 33 | self.var = torch.exp(self.logvar) 34 | if self.deterministic: 35 | self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) 36 | 37 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: 38 | # make sure sample is on the same device as the parameters and has same dtype 39 | sample = torch.randn(self.mean.shape, generator=generator, 40 | device=self.parameters.device, dtype=self.parameters.dtype) 41 | x = self.mean + self.std * sample 42 | return x 43 | 44 | def kl(self, other=None): 45 | if self.deterministic: 46 | return torch.Tensor([0.]) 47 | else: 48 | if other is None: 49 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 50 | + self.var 51 | - 1.0 52 | - self.logvar, 53 | dim=[1, 2, 3], ) 54 | else: 55 | return 0.5 * torch.sum( 56 | torch.pow(self.mean - other.mean, 2) / other.var 57 | + self.var / other.var 58 | - 1.0 59 | - self.logvar 60 | + other.logvar, 61 | dim=[1, 2, 3], ) 62 | 63 | def nll(self, sample, dims=[1, 2, 3]): 64 | if self.deterministic: 65 | return torch.Tensor([0.]) 66 | logtwopi = np.log(2.0 * np.pi) 67 | return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims) 69 | 70 | def mode(self): 71 | return self.mean 72 | 73 | 74 | def normal_kl(mean1, logvar1, mean2, logvar2): 75 | """ 76 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 77 | Compute the KL divergence between two gaussians. 78 | Shapes are automatically broadcasted, so batches can be compared to 79 | scalars, among other use cases. 
80 | """ 81 | tensor = None 82 | for obj in (mean1, logvar1, mean2, logvar2): 83 | if isinstance(obj, torch.Tensor): 84 | tensor = obj 85 | break 86 | assert tensor is not None, "at least one argument must be a Tensor" 87 | 88 | # Force variances to be Tensors. Broadcasting helps convert scalars to 89 | # Tensors, but it does not work for torch.exp(). 90 | logvar1, logvar2 = [ 91 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 92 | for x in (logvar1, logvar2) 93 | ] 94 | 95 | return 0.5 * ( 96 | -1.0 97 | + logvar2 98 | - logvar1 99 | + torch.exp(logvar1 - logvar2) 100 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 101 | ) 102 | -------------------------------------------------------------------------------- /networks/prediff/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | 5 | pretrained_sevirlr_vae_name = "pretrained_sevirlr_vae_8x8x64_v1.pt" 6 | pretrained_sevirlr_earthformerunet_name = "pretrained_sevirlr_earthformerunet_v1.pt" 7 | pretrained_sevirlr_alignment_name = "pretrained_sevirlr_alignment_avg_x_cuboid_v1.pt" 8 | 9 | file_id_dict = { 10 | pretrained_sevirlr_vae_name: "10OicEQuOPzSKDp5WYF3zDHsL-COywe98", 11 | pretrained_sevirlr_earthformerunet_name: "1cVB0Sm2V4OMTLxNNEXlb2__ONqSAjUDJ", 12 | pretrained_sevirlr_alignment_name: "1CzrzNVDVTyc8ivnWqEOiJnrN6wraY6fy", 13 | } 14 | 15 | 16 | def download_pretrained_weights(ckpt_name, save_dir=None, exist_ok=False): 17 | r""" 18 | Download pretrained weights from Google Drive. 19 | 20 | Parameters 21 | ---------- 22 | ckpt_name: str 23 | save_dir: str 24 | exist_ok: bool 25 | """ 26 | if save_dir is None: 27 | from .path import default_pretrained_dir 28 | save_dir = default_pretrained_dir 29 | ckpt_path = os.path.join(save_dir, ckpt_name) 30 | if os.path.exists(ckpt_path) and not exist_ok: 31 | warnings.warn(f"Checkpoint file {ckpt_path} already exists!") 32 | else: 33 | os.makedirs(save_dir, exist_ok=True) 34 | file_id = file_id_dict[ckpt_name] 35 | os.system(f"wget --load-cookies /tmp/cookies.txt " 36 | f"\"https://docs.google.com/uc?export=download&confirm=" 37 | f"$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate " 38 | f"'https://docs.google.com/uc?export=download&id={file_id}'" 39 | f" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')" 40 | f"&id={file_id}\" -O {ckpt_path} && rm -rf /tmp/cookies.txt") 41 | -------------------------------------------------------------------------------- /networks/prediff/utils/ema.py: -------------------------------------------------------------------------------- 1 | """Code is adapted from https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/ema.py""" 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class LitEma(nn.Module): 7 | def __init__(self, model, decay=0.9999, use_num_upates=True): 8 | super().__init__() 9 | if decay < 0.0 or decay > 1.0: 10 | raise ValueError('Decay must be between 0 and 1') 11 | 12 | self.m_name2s_name = {} 13 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 14 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 15 | else torch.tensor(-1, dtype=torch.int)) 16 | 17 | for name, p in model.named_parameters(): 18 | if p.requires_grad: 19 | # remove as '.'-character is not allowed in buffers 20 | s_name = name.replace('.', '') 21 | self.m_name2s_name.update({name: s_name}) 22 | self.register_buffer(s_name, 
p.clone().detach().data) 23 | 24 | self.collected_params = [] 25 | 26 | def forward(self, model): 27 | decay = self.decay 28 | 29 | if self.num_updates >= 0: 30 | self.num_updates += 1 31 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 32 | 33 | one_minus_decay = 1.0 - decay 34 | 35 | with torch.no_grad(): 36 | m_param = dict(model.named_parameters()) 37 | shadow_params = dict(self.named_buffers()) 38 | 39 | for key in m_param: 40 | if m_param[key].requires_grad: 41 | sname = self.m_name2s_name[key] 42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 44 | else: 45 | assert key not in self.m_name2s_name 46 | 47 | def copy_to(self, model): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | for key in m_param: 51 | if m_param[key].requires_grad: 52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 53 | else: 54 | assert key not in self.m_name2s_name 55 | 56 | def store(self, parameters): 57 | """ 58 | Save the current parameters for restoring later. 59 | Args: 60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 61 | temporarily stored. 62 | """ 63 | self.collected_params = [param.clone() for param in parameters] 64 | 65 | def restore(self, parameters): 66 | """ 67 | Restore the parameters stored with the `store` method. 68 | Useful to validate the model with EMA parameters without affecting the 69 | original optimization process. Store the parameters before the 70 | `copy_to` method. After validation (or model saving), use this to 71 | restore the former parameters. 72 | Args: 73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 74 | updated with the stored parameters. 75 | """ 76 | for c_param, param in zip(self.collected_params, parameters): 77 | param.data.copy_(c_param.data) 78 | -------------------------------------------------------------------------------- /networks/prediff/utils/gifmaker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def save_gif(single_seq, fname): 6 | """Save a single gif consisting of image sequence in single_seq to fname.""" 7 | img_seq = [Image.fromarray(img.astype(np.float32) * 255, 'F').convert("L") for img in single_seq] 8 | img = img_seq[0] 9 | img.save(fname, save_all=True, append_images=img_seq[1:]) 10 | -------------------------------------------------------------------------------- /networks/prediff/utils/layout.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Any 3 | 4 | 5 | def layout_to_in_out_slice(layout, in_len, out_len=None): 6 | t_axis = layout.find("T") 7 | num_axes = len(layout) 8 | in_slice = [slice(None, None), ] * num_axes 9 | out_slice = deepcopy(in_slice) 10 | in_slice[t_axis] = slice(None, in_len) 11 | if out_len is None: 12 | out_slice[t_axis] = slice(in_len, None) 13 | else: 14 | out_slice[t_axis] = slice(in_len, in_len + out_len) 15 | return in_slice, out_slice 16 | 17 | 18 | def parse_layout_shape(layout: str) -> Dict[str, Any]: 19 | r""" 20 | 21 | Parameters 22 | ---------- 23 | layout: str 24 | e.g., "NTHWC", "NHWC". 
25 | 
26 | Returns
27 | -------
28 | ret: Dict. Maps "batch_axis", "t_axis", "h_axis", "w_axis" and "c_axis" to the index of that axis in `layout` (-1 if the axis is absent).
29 | """
30 | batch_axis = layout.find("N")
31 | t_axis = layout.find("T")
32 | h_axis = layout.find("H")
33 | w_axis = layout.find("W")
34 | c_axis = layout.find("C")
35 | return {
36 | "batch_axis": batch_axis,
37 | "t_axis": t_axis,
38 | "h_axis": h_axis,
39 | "w_axis": w_axis,
40 | "c_axis": c_axis,
41 | }
42 | 
-------------------------------------------------------------------------------- /networks/prediff/utils/optim.py: --------------------------------------------------------------------------------
1 | from typing import Callable
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | 
6 | def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
7 | def ret_lambda(epoch):
8 | if epoch <= warmup_steps:
9 | return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
10 | else:
11 | return 1.0
12 | return ret_lambda
13 | 
14 | 
15 | def get_loss_fn(loss: str = "l2") -> Callable:
16 | if loss in ("l2", "mse"):
17 | return F.mse_loss
18 | elif loss in ("l1", "mae"):
19 | return F.l1_loss
20 | else:
21 | raise NotImplementedError(f"Unsupported loss type: {loss}")  # fail loudly instead of silently returning None
22 | 
23 | 
24 | def disabled_train(self):
25 | """Overwrite model.train with this function to make sure train/eval mode
26 | does not change anymore."""
27 | return self
28 | 
29 | 
30 | def disable_train(model: nn.Module):
31 | r"""
32 | Put the model in eval mode and freeze its parameters, to avoid errors when it is used in a pl.LightningModule
33 | """
34 | model.eval()
35 | model.train = disabled_train
36 | for param in model.parameters():
37 | param.requires_grad = False
38 | return model
39 | 
-------------------------------------------------------------------------------- /networks/prediff/utils/path.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
4 | 
5 | default_exps_dir = os.path.abspath(os.path.join(root_dir, "experiments"))
6 | 
7 | default_dataset_dir = os.path.abspath(os.path.join(root_dir, "datasets"))
8 | default_dataset_sevir_dir = os.path.abspath(os.path.join(default_dataset_dir, "sevir"))
9 | default_dataset_sevirlr_dir = os.path.abspath(os.path.join(default_dataset_dir, "sevirlr"))
10 | 
11 | default_pretrained_dir = os.path.abspath(os.path.join(root_dir, "pretrained"))
12 | default_pretrained_metrics_dir = os.path.abspath(os.path.join(default_pretrained_dir, "metrics"))
13 | default_pretrained_vae_dir = os.path.abspath(os.path.join(default_pretrained_dir, "vae"))
14 | default_pretrained_earthformerunet_dir = os.path.abspath(os.path.join(default_pretrained_dir, "earthformerunet"))
15 | default_pretrained_alignment_dir = os.path.abspath(os.path.join(default_pretrained_dir, "alignment"))
16 | 
-------------------------------------------------------------------------------- /networks/prediff/utils/pl_checkpoint.py: --------------------------------------------------------------------------------
1 | from typing import IO, Union, Callable
2 | from collections import OrderedDict
3 | import torch
4 | from lightning_fabric.utilities.cloud_io import _load
5 | from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
6 | from lightning.pytorch.utilities.migration import pl_legacy_patch
7 | from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
8 | 
9 | 
10 | def pl_load(
11 | path_or_url: Union[IO, _PATH],
12 | map_location: _MAP_LOCATION_TYPE = None,
13 | ) -> OrderedDict[str, torch.Tensor]:
14 | r"""
15 | Load the `state_dict` only from a PyTorch-Lightning checkpoint.
16 | Code is adapted from https://github.com/Lightning-AI/lightning/blob/255b18823e7da265e0e2e3996f55dcd0f78e9f3e/src/lightning/pytorch/core/saving.py
17 | """
18 | with pl_legacy_patch():
19 | checkpoint = _load(path_or_url, map_location=map_location)
20 | # convert legacy checkpoints to the new format
21 | checkpoint = _pl_migrate_checkpoint(
22 | checkpoint, checkpoint_path=(path_or_url if isinstance(path_or_url, _PATH) else None)
23 | )
24 | return checkpoint["state_dict"]
25 | 
26 | 
27 | def pl_ckpt_to_state_dict(
28 | checkpoint_path: str,
29 | map_location: _MAP_LOCATION_TYPE = None,
30 | key_fn: Callable = lambda x: x,
31 | ):
32 | r"""
33 | Parameters
34 | ----------
35 | checkpoint_path: str. Path of the PyTorch-Lightning checkpoint to load.
36 | map_location: _MAP_LOCATION_TYPE
37 | A function, torch.device, string or a dict specifying how to remap storage locations.
38 | The same as the arg `map_location` in `torch.load()`.
39 | key_fn: Callable
40 | A function to map the keys in the loaded checkpoint to the desired keys in the returned state_dict.
41 | 
42 | Returns
43 | -------
44 | state_dict: OrderedDict
45 | """
46 | if map_location is None:
47 | map_location = lambda storage, loc: storage
48 | pl_ckpt_state_dict = pl_load(checkpoint_path, map_location=map_location)
49 | state_dict = {key_fn(key): val for key, val in pl_ckpt_state_dict.items()}
50 | return state_dict
51 | 
-------------------------------------------------------------------------------- /scripts/compress_earthformer.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | gpus=1
4 | node_num=1
5 | single_gpus=`expr $gpus / $node_num`
6 | 
7 | cpus=13
8 | 
9 | # export NCCL_IB_DISABLE=1
10 | # export NCCL_SOCKET_IFNAME=eth0
11 | # export NCCL_DEBUG=INFO
12 | # export NCCL_DEBUG_SUBSYS=ALL
13 | # export TORCH_DISTRIBUTED_DEBUG=INFO
14 | 
15 | while true
16 | do
17 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000))
18 | break
19 | done
20 | echo $PORT
21 | 
22 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL
23 | 
24 | srun -p ai4earth --kill-on-bad-exit=1 --quotatype=reserved --ntasks-per-node=$single_gpus --time=43200 --cpus-per-task=$cpus -N $node_num -o train_job/%j.out --gres=gpu:$single_gpus --async python -u latent_preprocess.py \
25 | --init_method 'tcp://127.0.0.1:'$PORT \
26 | -c ./configs/sevir_used/compress_earthformer.yaml \
27 | --world_size $gpus \
28 | --per_cpus $cpus \
29 | --tensor_model_parallel_size 1 \
30 | --outdir '/mnt/cache/gongjunchao/workdir/radar_forecasting/experiments' \
31 | --desc 'earthformer_48x48x4'
32 | 
33 | # 
34 | sleep 2
35 | rm -f batchscript-*
-------------------------------------------------------------------------------- /scripts/compress_gt.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | gpus=1
4 | node_num=1
5 | single_gpus=`expr $gpus / $node_num`
6 | 
7 | cpus=13
8 | 
9 | # export NCCL_IB_DISABLE=1
10 | # export NCCL_SOCKET_IFNAME=eth0
11 | # export NCCL_DEBUG=INFO
12 | # export NCCL_DEBUG_SUBSYS=ALL
13 | # export TORCH_DISTRIBUTED_DEBUG=INFO
14 | 
15 | while true
16 | do
17 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000))
18 | break
19 | done
20 | echo $PORT
21 | 
22 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL
23 | 
24 | srun -p ai4earth --kill-on-bad-exit=1 --quotatype=reserved --ntasks-per-node=$single_gpus --time=43200 --cpus-per-task=$cpus -N $node_num -o train_job/%j.out --gres=gpu:$single_gpus --async python -u latent_preprocess.py \
25 | --init_method 'tcp://127.0.0.1:'$PORT \
26 | -c 
./configs/sevir_used/compress_gt.yaml \ 27 | --world_size $gpus \ 28 | --per_cpus $cpus \ 29 | --tensor_model_parallel_size 1 \ 30 | --outdir '/mnt/cache/gongjunchao/workdir/radar_forecasting/experiments' \ 31 | --desc 'gt_48x48x4' 32 | 33 | # 34 | sleep 2 35 | rm -f batchscript-* -------------------------------------------------------------------------------- /scripts/eval_deterministic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=1 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=8 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | 14 | 15 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 16 | 17 | echo $PORT 18 | 19 | srun -p ai4earth --quotatype=reserved --ntasks-per-node=$single_gpus --cpus-per-task=$cpus --time=43200 -N $node_num --gres=gpu:$single_gpus python -u evaluation.py \ 20 | --init_method 'tcp://127.0.0.1:'$PORT \ 21 | --world_size $gpus \ 22 | --per_cpus $cpus \ 23 | --batch_size 8 \ 24 | --num_workers 8 \ 25 | --cfgdir /mnt/lustre/gongjunchao/release_code/cascast/experiments/EarthFormer/world_size1-ckpt \ 26 | --pred_len 12 \ 27 | --test_name test \ 28 | --metrics_type SEVIRSkillScore 29 | 30 | sleep 2 31 | rm -f batchscript-* -------------------------------------------------------------------------------- /scripts/eval_diffusion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=1 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=8 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | 14 | 15 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 16 | 17 | echo $PORT 18 | 19 | 20 | srun -p ai4earth --quotatype=auto --ntasks-per-node=$single_gpus -x SH-IDC1-10-140-24-110 --cpus-per-task=$cpus --time=43200 -N $node_num --gres=gpu:$single_gpus python -u evaluation.py \ 21 | --init_method 'tcp://127.0.0.1:'$PORT \ 22 | --world_size $gpus \ 23 | --per_cpus $cpus \ 24 | --batch_size 8 \ 25 | --num_workers 8 \ 26 | --cfgdir /mnt/lustre/gongjunchao/release_code/cascast/experiments/cascast_diffusion/world_size1-ckpt \ 27 | --pred_len 12 \ 28 | --test_name test \ 29 | --ens_member 1 \ 30 | --cfg_weight 2 \ 31 | --metrics_type SEVIRSkillScore 32 | 33 | sleep 2 34 | rm -f batchscript-* -------------------------------------------------------------------------------- /scripts/eval_diffusion_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=1 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=8 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | 14 | 15 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 16 | 17 | echo $PORT 18 | 19 | 20 | srun -p ai4earth --quotatype=auto --ntasks-per-node=$single_gpus -x SH-IDC1-10-140-24-110 --cpus-per-task=$cpus --time=43200 -N $node_num --gres=gpu:$single_gpus python -u evaluation.py \ 21 | --init_method 'tcp://127.0.0.1:'$PORT \ 22 | --world_size $gpus \ 23 | --per_cpus $cpus \ 24 | --batch_size 8 \ 25 | --num_workers 8 \ 26 | --cfgdir ./experiments/cascast_diffusion/world_size1-ckpt \ 27 | --pred_len 12 \ 28 | --test_name test \ 29 | --ens_member 1 \ 30 | --cfg_weight 2 \ 31 | --metrics_type SEVIRSkillScore 32 | 33 | sleep 2 34 | rm -f 
batchscript-* -------------------------------------------------------------------------------- /scripts/train_autoencoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=4 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=13 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | # export TORCH_DISTRIBUTED_DEBUG=INFO 14 | 15 | while true 16 | do 17 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 18 | break 19 | done 20 | echo $PORT 21 | 22 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL 23 | 24 | srun -p ai4earth --kill-on-bad-exit=1 --quotatype=reserved --ntasks-per-node=$single_gpus --time=43200 --cpus-per-task=$cpus -N $node_num -o train_job/%j.out --gres=gpu:$single_gpus --async python -u train.py \ 25 | --init_method 'tcp://127.0.0.1:'$PORT \ 26 | -c ./configs/sevir_used/autoencoder_kl_gan.yaml \ 27 | --world_size $gpus \ 28 | --per_cpus $cpus \ 29 | --tensor_model_parallel_size 1 \ 30 | --outdir '/mnt/lustre/gongjunchao/release_code/cascast/experiments' \ 31 | --desc 'bs8_200k' 32 | 33 | # 34 | sleep 2 35 | rm -f batchscript-* -------------------------------------------------------------------------------- /scripts/train_deterministic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=4 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=13 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | # export TORCH_DISTRIBUTED_DEBUG=INFO 14 | 15 | while true 16 | do 17 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 18 | break 19 | done 20 | echo $PORT 21 | 22 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL 23 | 24 | srun -p ai4earth --kill-on-bad-exit=1 --quotatype=reserved --ntasks-per-node=$single_gpus --time=43200 --cpus-per-task=$cpus -N $node_num -o train_job/%j.out --gres=gpu:$single_gpus --async python -u train.py \ 25 | --init_method 'tcp://127.0.0.1:'$PORT \ 26 | -c ./configs/sevir_used/EarthFormer.yaml \ 27 | --world_size $gpus \ 28 | --per_cpus $cpus \ 29 | --tensor_model_parallel_size 1 \ 30 | --outdir '/mnt/lustre/gongjunchao/release_code/cascast/experiments' \ 31 | --desc 'earthformer_bs32_100k' 32 | 33 | # 34 | sleep 2 35 | rm -f batchscript-* -------------------------------------------------------------------------------- /scripts/train_diffusion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=4 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=4 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | # export TORCH_DISTRIBUTED_DEBUG=INFO 14 | 15 | while true 16 | do 17 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 18 | break 19 | done 20 | echo $PORT 21 | 22 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL 23 | 24 | srun -p ai4earth --kill-on-bad-exit=1 --quotatype=auto --ntasks-per-node=$gpus --time=43200 --cpus-per-task=$cpus -N $node_num --gres=gpu:$single_gpus python -u train.py \ 25 | --init_method 'tcp://127.0.0.1:'$PORT \ 26 | -c ./configs/sevir_used/cascast_diffusion.yaml \ 27 | --world_size $gpus \ 28 | --per_cpus $cpus \ 29 | --tensor_model_parallel_size 1 \ 30 | --outdir '/mnt/lustre/gongjunchao/release_code/cascast/experiments' \ 31 | --desc 'debug' \ 32 | --debug 33 | # 
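34 | # Note: this launch passes --debug together with --desc 'debug'; for a real
35 | # training run, presumably drop --debug and give --desc a meaningful tag.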
36 | sleep 2
37 | rm -f batchscript-*
-------------------------------------------------------------------------------- /src/diffusers/commands/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from abc import ABC, abstractmethod
16 | from argparse import ArgumentParser
17 | 
18 | 
19 | class BaseDiffusersCLICommand(ABC):
20 | @staticmethod
21 | @abstractmethod
22 | def register_subcommand(parser: ArgumentParser):
23 | raise NotImplementedError()
24 | 
25 | @abstractmethod
26 | def run(self):
27 | raise NotImplementedError()
28 | 
-------------------------------------------------------------------------------- /src/diffusers/commands/diffusers_cli.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from argparse import ArgumentParser
17 | 
18 | from .env import EnvironmentCommand
19 | from .fp16_safetensors import FP16SafetensorsCommand
20 | 
21 | 
22 | def main():
23 | parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
24 | commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
25 | 
26 | # Register commands
27 | EnvironmentCommand.register_subcommand(commands_parser)
28 | FP16SafetensorsCommand.register_subcommand(commands_parser)
29 | 
30 | # Let's go
31 | args = parser.parse_args()
32 | 
33 | if not hasattr(args, "func"):
34 | parser.print_help()
35 | exit(1)
36 | 
37 | # Run
38 | service = args.func(args)
39 | service.run()
40 | 
41 | 
42 | if __name__ == "__main__":
43 | main()
44 | 
-------------------------------------------------------------------------------- /src/diffusers/commands/env.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import platform 16 | from argparse import ArgumentParser 17 | 18 | import huggingface_hub 19 | 20 | from .. import __version__ as version 21 | from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available 22 | from . import BaseDiffusersCLICommand 23 | 24 | 25 | def info_command_factory(_): 26 | return EnvironmentCommand() 27 | 28 | 29 | class EnvironmentCommand(BaseDiffusersCLICommand): 30 | @staticmethod 31 | def register_subcommand(parser: ArgumentParser): 32 | download_parser = parser.add_parser("env") 33 | download_parser.set_defaults(func=info_command_factory) 34 | 35 | def run(self): 36 | hub_version = huggingface_hub.__version__ 37 | 38 | pt_version = "not installed" 39 | pt_cuda_available = "NA" 40 | if is_torch_available(): 41 | import torch 42 | 43 | pt_version = torch.__version__ 44 | pt_cuda_available = torch.cuda.is_available() 45 | 46 | transformers_version = "not installed" 47 | if is_transformers_available(): 48 | import transformers 49 | 50 | transformers_version = transformers.__version__ 51 | 52 | accelerate_version = "not installed" 53 | if is_accelerate_available(): 54 | import accelerate 55 | 56 | accelerate_version = accelerate.__version__ 57 | 58 | xformers_version = "not installed" 59 | if is_xformers_available(): 60 | import xformers 61 | 62 | xformers_version = xformers.__version__ 63 | 64 | info = { 65 | "`diffusers` version": version, 66 | "Platform": platform.platform(), 67 | "Python version": platform.python_version(), 68 | "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", 69 | "Huggingface_hub version": hub_version, 70 | "Transformers version": transformers_version, 71 | "Accelerate version": accelerate_version, 72 | "xFormers version": xformers_version, 73 | "Using GPU in script?": "", 74 | "Using distributed or parallel set-up in script?": "", 75 | } 76 | 77 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") 78 | print(self.format_dict(info)) 79 | 80 | return info 81 | 82 | @staticmethod 83 | def format_dict(d): 84 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" 85 | -------------------------------------------------------------------------------- /src/diffusers/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .dependency_versions_table import deps 16 | from .utils.versions import require_version, require_version_core 17 | 18 | 19 | # define which module versions we always want to check at run time 20 | # (usually the ones defined in `install_requires` in setup.py) 21 | # 22 | # order specific notes: 23 | # - tqdm must be checked before tokenizers 24 | 25 | pkgs_to_check_at_runtime = "python requests filelock numpy".split() 26 | for pkg in pkgs_to_check_at_runtime: 27 | if pkg in deps: 28 | require_version_core(deps[pkg]) 29 | else: 30 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 31 | 32 | 33 | def dep_version_check(pkg, hint=None): 34 | require_version(deps[pkg], hint) 35 | -------------------------------------------------------------------------------- /src/diffusers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. run `make deps_table_update` 4 | deps = { 5 | "Pillow": "Pillow", 6 | "accelerate": "accelerate>=0.11.0", 7 | "compel": "compel==0.1.8", 8 | "datasets": "datasets", 9 | "filelock": "filelock", 10 | "flax": "flax>=0.4.1", 11 | "hf-doc-builder": "hf-doc-builder>=0.3.0", 12 | "huggingface-hub": "huggingface-hub>=0.19.4", 13 | "requests-mock": "requests-mock==1.10.0", 14 | "importlib_metadata": "importlib_metadata", 15 | "invisible-watermark": "invisible-watermark>=0.2.0", 16 | "isort": "isort>=5.5.4", 17 | "jax": "jax>=0.4.1", 18 | "jaxlib": "jaxlib>=0.4.1", 19 | "Jinja2": "Jinja2", 20 | "k-diffusion": "k-diffusion>=0.0.12", 21 | "torchsde": "torchsde", 22 | "note_seq": "note_seq", 23 | "librosa": "librosa", 24 | "numpy": "numpy", 25 | "omegaconf": "omegaconf", 26 | "parameterized": "parameterized", 27 | "peft": "peft>=0.6.0", 28 | "protobuf": "protobuf>=3.20.3,<4", 29 | "pytest": "pytest", 30 | "pytest-timeout": "pytest-timeout", 31 | "pytest-xdist": "pytest-xdist", 32 | "python": "python>=3.8.0", 33 | "ruff": "ruff>=0.1.5,<=0.2", 34 | "safetensors": "safetensors>=0.3.1", 35 | "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", 36 | "scipy": "scipy", 37 | "onnx": "onnx", 38 | "regex": "regex!=2019.12.17", 39 | "requests": "requests", 40 | "tensorboard": "tensorboard", 41 | "torch": "torch>=1.4", 42 | "torchvision": "torchvision", 43 | "transformers": "transformers>=4.25.1", 44 | "urllib3": "urllib3<=2.0.0", 45 | } 46 | -------------------------------------------------------------------------------- /src/diffusers/experimental/README.md: -------------------------------------------------------------------------------- 1 | # 🧨 Diffusers Experimental 2 | 3 | We are adding experimental code to support novel applications and usages of the Diffusers library. 4 | Currently, the following experiments are supported: 5 | * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. 
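6 | 
7 | A minimal sketch of how the value-guided RL pipeline is typically driven (assumptions: a D4RL-style `gym` environment is available, and the checkpoint id below is illustrative; the actual interface is defined in `rl/value_guided_sampling.py`):
8 | 
9 | ```python
10 | import gym
11 | 
12 | from diffusers.experimental import ValueGuidedRLPipeline
13 | 
14 | env = gym.make("hopper-medium-v2")  # assumes a d4rl-style environment is installed
15 | pipeline = ValueGuidedRLPipeline.from_pretrained(
16 |     "bglick13/hopper-medium-v2-value-function-gamma-0.99",  # illustrative checkpoint id
17 |     env=env,
18 | )
19 | 
20 | obs = env.reset()
21 | # sample a plan conditioned on the current observation and take the first action
22 | denorm_actions = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
23 | obs, reward, done, info = env.step(denorm_actions)
24 | ```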
-------------------------------------------------------------------------------- /src/diffusers/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from .rl import ValueGuidedRLPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/experimental/rl/__init__.py: -------------------------------------------------------------------------------- 1 | from .value_guided_sampling import ValueGuidedRLPipeline 2 | -------------------------------------------------------------------------------- /src/diffusers/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate 4 | from ..utils.import_utils import is_torch_available, is_transformers_available 5 | 6 | 7 | def text_encoder_lora_state_dict(text_encoder): 8 | deprecate( 9 | "text_encoder_load_state_dict in `models`", 10 | "0.27.0", 11 | "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", 12 | ) 13 | state_dict = {} 14 | 15 | for name, module in text_encoder_attn_modules(text_encoder): 16 | for k, v in module.q_proj.lora_linear_layer.state_dict().items(): 17 | state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v 18 | 19 | for k, v in module.k_proj.lora_linear_layer.state_dict().items(): 20 | state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v 21 | 22 | for k, v in module.v_proj.lora_linear_layer.state_dict().items(): 23 | state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v 24 | 25 | for k, v in module.out_proj.lora_linear_layer.state_dict().items(): 26 | state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v 27 | 28 | return state_dict 29 | 30 | 31 | if is_transformers_available(): 32 | 33 | def text_encoder_attn_modules(text_encoder): 34 | deprecate( 35 | "text_encoder_attn_modules in `models`", 36 | "0.27.0", 37 | "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. 
See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", 38 | ) 39 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 40 | 41 | attn_modules = [] 42 | 43 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): 44 | for i, layer in enumerate(text_encoder.text_model.encoder.layers): 45 | name = f"text_model.encoder.layers.{i}.self_attn" 46 | mod = layer.self_attn 47 | attn_modules.append((name, mod)) 48 | else: 49 | raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") 50 | 51 | return attn_modules 52 | 53 | 54 | _import_structure = {} 55 | 56 | if is_torch_available(): 57 | _import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"] 58 | _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] 59 | _import_structure["utils"] = ["AttnProcsLayers"] 60 | 61 | if is_transformers_available(): 62 | _import_structure["single_file"].extend(["FromSingleFileMixin"]) 63 | _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"] 64 | _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] 65 | _import_structure["ip_adapter"] = ["IPAdapterMixin"] 66 | 67 | 68 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 69 | if is_torch_available(): 70 | from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin 71 | from .unet import UNet2DConditionLoadersMixin 72 | from .utils import AttnProcsLayers 73 | 74 | if is_transformers_available(): 75 | from .ip_adapter import IPAdapterMixin 76 | from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin 77 | from .single_file import FromSingleFileMixin 78 | from .textual_inversion import TextualInversionLoaderMixin 79 | else: 80 | import sys 81 | 82 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 83 | -------------------------------------------------------------------------------- /src/diffusers/loaders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Dict 16 | 17 | import torch 18 | 19 | 20 | class AttnProcsLayers(torch.nn.Module): 21 | def __init__(self, state_dict: Dict[str, torch.Tensor]): 22 | super().__init__() 23 | self.layers = torch.nn.ModuleList(state_dict.values()) 24 | self.mapping = dict(enumerate(state_dict.keys())) 25 | self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} 26 | 27 | # .processor for unet, .self_attn for text encoder 28 | self.split_keys = [".processor", ".self_attn"] 29 | 30 | # we add a hook to state_dict() and load_state_dict() so that the 31 | # naming fits with `unet.attn_processors` 32 | def map_to(module, state_dict, *args, **kwargs): 33 | new_state_dict = {} 34 | for key, value in state_dict.items(): 35 | num = int(key.split(".")[1]) # 0 is always "layers" 36 | new_key = key.replace(f"layers.{num}", module.mapping[num]) 37 | new_state_dict[new_key] = value 38 | 39 | return new_state_dict 40 | 41 | def remap_key(key, state_dict): 42 | for k in self.split_keys: 43 | if k in key: 44 | return key.split(k)[0] + k 45 | 46 | raise ValueError( 47 | f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." 48 | ) 49 | 50 | def map_from(module, state_dict, *args, **kwargs): 51 | all_keys = list(state_dict.keys()) 52 | for key in all_keys: 53 | replace_key = remap_key(key, state_dict) 54 | new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") 55 | state_dict[new_key] = state_dict[key] 56 | del state_dict[key] 57 | 58 | self._register_state_dict_hook(map_to) 59 | self._register_load_state_dict_pre_hook(map_from, with_module=True) 60 | -------------------------------------------------------------------------------- /src/diffusers/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview). -------------------------------------------------------------------------------- /src/diffusers/models/embeddings_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | 16 | import flax.linen as nn 17 | import jax.numpy as jnp 18 | 19 | 20 | def get_sinusoidal_embeddings( 21 | timesteps: jnp.ndarray, 22 | embedding_dim: int, 23 | freq_shift: float = 1, 24 | min_timescale: float = 1, 25 | max_timescale: float = 1.0e4, 26 | flip_sin_to_cos: bool = False, 27 | scale: float = 1.0, 28 | ) -> jnp.ndarray: 29 | """Returns the positional encoding (same as Tensor2Tensor). 30 | 31 | Args: 32 | timesteps: a 1-D Tensor of N indices, one per batch element. 33 | These may be fractional. 34 | embedding_dim: The number of output channels. 35 | min_timescale: The smallest time unit (should probably be 0.0). 
36 | max_timescale: The largest time unit. 37 | Returns: 38 | a Tensor of timing signals [N, num_channels] 39 | """ 40 | assert timesteps.ndim == 1, "Timesteps should be a 1d-array" 41 | assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" 42 | num_timescales = float(embedding_dim // 2) 43 | log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) 44 | inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) 45 | emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) 46 | 47 | # scale embeddings 48 | scaled_time = scale * emb 49 | 50 | if flip_sin_to_cos: 51 | signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) 52 | else: 53 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) 54 | signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) 55 | return signal 56 | 57 | 58 | class FlaxTimestepEmbedding(nn.Module): 59 | r""" 60 | Time step Embedding Module. Learns embeddings for input time steps. 61 | 62 | Args: 63 | time_embed_dim (`int`, *optional*, defaults to `32`): 64 | Time step embedding dimension 65 | dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): 66 | Parameters `dtype` 67 | """ 68 | 69 | time_embed_dim: int = 32 70 | dtype: jnp.dtype = jnp.float32 71 | 72 | @nn.compact 73 | def __call__(self, temb): 74 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) 75 | temb = nn.silu(temb) 76 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) 77 | return temb 78 | 79 | 80 | class FlaxTimesteps(nn.Module): 81 | r""" 82 | Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 83 | 84 | Args: 85 | dim (`int`, *optional*, defaults to `32`): 86 | Time step embedding dimension 87 | """ 88 | 89 | dim: int = 32 90 | flip_sin_to_cos: bool = False 91 | freq_shift: float = 1 92 | 93 | @nn.compact 94 | def __call__(self, timesteps): 95 | return get_sinusoidal_embeddings( 96 | timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift 97 | ) 98 | -------------------------------------------------------------------------------- /src/diffusers/models/modeling_outputs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..utils import BaseOutput 4 | 5 | 6 | @dataclass 7 | class AutoencoderKLOutput(BaseOutput): 8 | """ 9 | Output of AutoencoderKL encoding method. 10 | 11 | Args: 12 | latent_dist (`DiagonalGaussianDistribution`): 13 | Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. 14 | `DiagonalGaussianDistribution` allows for sampling latents from the distribution. 
15 | """ 16 | 17 | latent_dist: "DiagonalGaussianDistribution" # noqa: F821 18 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/alt_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"] 25 | _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"] 26 | _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"] 27 | 28 | _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"] 29 | 30 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 31 | try: 32 | if not (is_transformers_available() and is_torch_available()): 33 | raise OptionalDependencyNotAvailable() 34 | except OptionalDependencyNotAvailable: 35 | from ...utils.dummy_torch_and_transformers_objects import * 36 | 37 | else: 38 | from .modeling_roberta_series import RobertaSeriesModelWithTransformation 39 | from .pipeline_alt_diffusion import AltDiffusionPipeline 40 | from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline 41 | from .pipeline_output import AltDiffusionPipelineOutput 42 | 43 | else: 44 | import sys 45 | 46 | sys.modules[__name__] = _LazyModule( 47 | __name__, 48 | globals()["__file__"], 49 | _import_structure, 50 | module_spec=__spec__, 51 | ) 52 | for name, value in _dummy_objects.items(): 53 | setattr(sys.modules[__name__], name, value) 54 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/alt_diffusion/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from ...utils import ( 8 | BaseOutput, 9 | ) 10 | 11 | 12 | @dataclass 13 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt 14 | class AltDiffusionPipelineOutput(BaseOutput): 15 | """ 16 | Output class for Alt Diffusion pipelines. 17 | 18 | Args: 19 | images (`List[PIL.Image.Image]` or `np.ndarray`) 20 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 21 | num_channels)`. 22 | nsfw_content_detected (`List[bool]`) 23 | List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or 24 | `None` if safety checking could not be performed. 
25 | """ 26 | 27 | images: Union[List[PIL.Image.Image], np.ndarray] 28 | nsfw_content_detected: Optional[List[bool]] 29 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/animatediff/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"] 25 | 26 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 27 | try: 28 | if not (is_transformers_available() and is_torch_available()): 29 | raise OptionalDependencyNotAvailable() 30 | except OptionalDependencyNotAvailable: 31 | from ...utils.dummy_torch_and_transformers_objects import * 32 | 33 | else: 34 | from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput 35 | 36 | else: 37 | import sys 38 | 39 | sys.modules[__name__] = _LazyModule( 40 | __name__, 41 | globals()["__file__"], 42 | _import_structure, 43 | module_spec=__spec__, 44 | ) 45 | for name, value in _dummy_objects.items(): 46 | setattr(sys.modules[__name__], name, value) 47 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/audio_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = { 7 | "mel": ["Mel"], 8 | "pipeline_audio_diffusion": ["AudioDiffusionPipeline"], 9 | } 10 | 11 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 12 | from .mel import Mel 13 | from .pipeline_audio_diffusion import AudioDiffusionPipeline 14 | 15 | else: 16 | import sys 17 | 18 | sys.modules[__name__] = _LazyModule( 19 | __name__, 20 | globals()["__file__"], 21 | _import_structure, 22 | module_spec=__spec__, 23 | ) 24 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/audioldm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | is_torch_available, 8 | is_transformers_available, 9 | is_transformers_version, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils.dummy_torch_and_transformers_objects import ( 21 | AudioLDMPipeline, 22 | ) 23 | 24 | _dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline}) 25 | else: 26 | _import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"] 27 | 28 | 29 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 30 | try: 31 
| if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 32 | raise OptionalDependencyNotAvailable() 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import ( 35 | AudioLDMPipeline, 36 | ) 37 | 38 | else: 39 | from .pipeline_audioldm import AudioLDMPipeline 40 | else: 41 | import sys 42 | 43 | sys.modules[__name__] = _LazyModule( 44 | __name__, 45 | globals()["__file__"], 46 | _import_structure, 47 | module_spec=__spec__, 48 | ) 49 | 50 | for name, value in _dummy_objects.items(): 51 | setattr(sys.modules[__name__], name, value) 52 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/audioldm2/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | is_transformers_version, 11 | ) 12 | 13 | 14 | _dummy_objects = {} 15 | _import_structure = {} 16 | 17 | try: 18 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 19 | raise OptionalDependencyNotAvailable() 20 | except OptionalDependencyNotAvailable: 21 | from ...utils import dummy_torch_and_transformers_objects 22 | 23 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 24 | else: 25 | _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"] 26 | _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"] 27 | 28 | 29 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 30 | try: 31 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 32 | raise OptionalDependencyNotAvailable() 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import * 35 | 36 | else: 37 | from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel 38 | from .pipeline_audioldm2 import AudioLDM2Pipeline 39 | 40 | else: 41 | import sys 42 | 43 | sys.modules[__name__] = _LazyModule( 44 | __name__, 45 | globals()["__file__"], 46 | _import_structure, 47 | module_spec=__spec__, 48 | ) 49 | for name, value in _dummy_objects.items(): 50 | setattr(sys.modules[__name__], name, value) 51 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/blip_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL 6 | from PIL import Image 7 | 8 | from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available 9 | 10 | 11 | try: 12 | if not (is_transformers_available() and is_torch_available()): 13 | raise OptionalDependencyNotAvailable() 14 | except OptionalDependencyNotAvailable: 15 | from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline 16 | else: 17 | from .blip_image_processing import BlipImageProcessor 18 | from .modeling_blip2 import Blip2QFormerModel 19 | from .modeling_ctx_clip import ContextCLIPTextModel 20 | from .pipeline_blip_diffusion import BlipDiffusionPipeline 21 | 
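22 | # Usage sketch (illustrative checkpoint id and arguments; see
23 | # `pipeline_blip_diffusion.py` for the actual signature):
24 | #
25 | #     pipe = BlipDiffusionPipeline.from_pretrained("Salesforce/blipdiffusion")
26 | #     image = pipe(
27 | #         "swimming underwater",  # target prompt
28 | #         reference_image,  # PIL image of the source subject
29 | #         source_subject_category="dog",
30 | #         target_subject_category="dog",
31 | #         guidance_scale=7.5,
32 | #     ).images[0]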
-------------------------------------------------------------------------------- /src/diffusers/pipelines/consistency_models/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | _LazyModule, 6 | ) 7 | 8 | 9 | _import_structure = { 10 | "pipeline_consistency_models": ["ConsistencyModelPipeline"], 11 | } 12 | 13 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 14 | from .pipeline_consistency_models import ConsistencyModelPipeline 15 | 16 | else: 17 | import sys 18 | 19 | sys.modules[__name__] = _LazyModule( 20 | __name__, 21 | globals()["__file__"], 22 | _import_structure, 23 | module_spec=__spec__, 24 | ) 25 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/controlnet/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_flax_available, 9 | is_torch_available, 10 | is_transformers_available, 11 | ) 12 | 13 | 14 | _dummy_objects = {} 15 | _import_structure = {} 16 | 17 | try: 18 | if not (is_transformers_available() and is_torch_available()): 19 | raise OptionalDependencyNotAvailable() 20 | except OptionalDependencyNotAvailable: 21 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 22 | 23 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 24 | else: 25 | _import_structure["multicontrolnet"] = ["MultiControlNetModel"] 26 | _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] 27 | _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"] 28 | _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] 29 | _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] 30 | _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] 31 | _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] 32 | _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] 33 | try: 34 | if not (is_transformers_available() and is_flax_available()): 35 | raise OptionalDependencyNotAvailable() 36 | except OptionalDependencyNotAvailable: 37 | from ...utils import dummy_flax_and_transformers_objects # noqa F403 38 | 39 | _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) 40 | else: 41 | _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] 42 | 43 | 44 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 45 | try: 46 | if not (is_transformers_available() and is_torch_available()): 47 | raise OptionalDependencyNotAvailable() 48 | 49 | except OptionalDependencyNotAvailable: 50 | from ...utils.dummy_torch_and_transformers_objects import * 51 | else: 52 | from .multicontrolnet import MultiControlNetModel 53 | from .pipeline_controlnet import StableDiffusionControlNetPipeline 54 | from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline 55 | from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline 56 | from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline 57 | from 
.pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline 58 | from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline 59 | from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline 60 | 61 | try: 62 | if not (is_transformers_available() and is_flax_available()): 63 | raise OptionalDependencyNotAvailable() 64 | except OptionalDependencyNotAvailable: 65 | from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 66 | else: 67 | from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline 68 | 69 | 70 | else: 71 | import sys 72 | 73 | sys.modules[__name__] = _LazyModule( 74 | __name__, 75 | globals()["__file__"], 76 | _import_structure, 77 | module_spec=__spec__, 78 | ) 79 | for name, value in _dummy_objects.items(): 80 | setattr(sys.modules[__name__], name, value) 81 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/dance_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_dance_diffusion": ["DanceDiffusionPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_dance_diffusion import DanceDiffusionPipeline 10 | else: 11 | import sys 12 | 13 | sys.modules[__name__] = _LazyModule( 14 | __name__, 15 | globals()["__file__"], 16 | _import_structure, 17 | module_spec=__spec__, 18 | ) 19 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/ddim/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_ddim": ["DDIMPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_ddim import DDIMPipeline 10 | else: 11 | import sys 12 | 13 | sys.modules[__name__] = _LazyModule( 14 | __name__, 15 | globals()["__file__"], 16 | _import_structure, 17 | module_spec=__spec__, 18 | ) 19 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | _LazyModule, 6 | ) 7 | 8 | 9 | _import_structure = {"pipeline_ddpm": ["DDPMPipeline"]} 10 | 11 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 12 | from .pipeline_ddpm import DDPMPipeline 13 | 14 | else: 15 | import sys 16 | 17 | sys.modules[__name__] = _LazyModule( 18 | __name__, 19 | globals()["__file__"], 20 | _import_structure, 21 | module_spec=__spec__, 22 | ) 23 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/deepfloyd_if/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = { 15 | "timesteps": [ 16 | "fast27_timesteps", 17 | "smart100_timesteps", 18 | "smart185_timesteps", 19 | "smart27_timesteps", 20 | 
"smart50_timesteps", 21 | "super100_timesteps", 22 | "super27_timesteps", 23 | "super40_timesteps", 24 | ] 25 | } 26 | 27 | try: 28 | if not (is_transformers_available() and is_torch_available()): 29 | raise OptionalDependencyNotAvailable() 30 | except OptionalDependencyNotAvailable: 31 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 32 | 33 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 34 | else: 35 | _import_structure["pipeline_if"] = ["IFPipeline"] 36 | _import_structure["pipeline_if_img2img"] = ["IFImg2ImgPipeline"] 37 | _import_structure["pipeline_if_img2img_superresolution"] = ["IFImg2ImgSuperResolutionPipeline"] 38 | _import_structure["pipeline_if_inpainting"] = ["IFInpaintingPipeline"] 39 | _import_structure["pipeline_if_inpainting_superresolution"] = ["IFInpaintingSuperResolutionPipeline"] 40 | _import_structure["pipeline_if_superresolution"] = ["IFSuperResolutionPipeline"] 41 | _import_structure["pipeline_output"] = ["IFPipelineOutput"] 42 | _import_structure["safety_checker"] = ["IFSafetyChecker"] 43 | _import_structure["watermark"] = ["IFWatermarker"] 44 | 45 | 46 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 47 | try: 48 | if not (is_transformers_available() and is_torch_available()): 49 | raise OptionalDependencyNotAvailable() 50 | 51 | except OptionalDependencyNotAvailable: 52 | from ...utils.dummy_torch_and_transformers_objects import * 53 | else: 54 | from .pipeline_if import IFPipeline 55 | from .pipeline_if_img2img import IFImg2ImgPipeline 56 | from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline 57 | from .pipeline_if_inpainting import IFInpaintingPipeline 58 | from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline 59 | from .pipeline_if_superresolution import IFSuperResolutionPipeline 60 | from .pipeline_output import IFPipelineOutput 61 | from .safety_checker import IFSafetyChecker 62 | from .timesteps import ( 63 | fast27_timesteps, 64 | smart27_timesteps, 65 | smart50_timesteps, 66 | smart100_timesteps, 67 | smart185_timesteps, 68 | super27_timesteps, 69 | super40_timesteps, 70 | super100_timesteps, 71 | ) 72 | from .watermark import IFWatermarker 73 | 74 | else: 75 | import sys 76 | 77 | sys.modules[__name__] = _LazyModule( 78 | __name__, 79 | globals()["__file__"], 80 | _import_structure, 81 | module_spec=__spec__, 82 | ) 83 | 84 | for name, value in _dummy_objects.items(): 85 | setattr(sys.modules[__name__], name, value) 86 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/deepfloyd_if/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from ...utils import BaseOutput 8 | 9 | 10 | @dataclass 11 | class IFPipelineOutput(BaseOutput): 12 | """ 13 | Args: 14 | Output class for Stable Diffusion pipelines. 15 | images (`List[PIL.Image.Image]` or `np.ndarray`) 16 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 17 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 18 | nsfw_detected (`List[bool]`) 19 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 20 | (nsfw) content or a watermark. `None` if safety checking could not be performed. 
21 | watermark_detected (`List[bool]`) 22 | List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety 23 | checking could not be performed. 24 | """ 25 | 26 | images: Union[List[PIL.Image.Image], np.ndarray] 27 | nsfw_detected: Optional[List[bool]] 28 | watermark_detected: Optional[List[bool]] 29 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/deepfloyd_if/safety_checker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel 5 | 6 | from ...utils import logging 7 | 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | class IFSafetyChecker(PreTrainedModel): 13 | config_class = CLIPConfig 14 | 15 | _no_split_modules = ["CLIPEncoderLayer"] 16 | 17 | def __init__(self, config: CLIPConfig): 18 | super().__init__(config) 19 | 20 | self.vision_model = CLIPVisionModelWithProjection(config.vision_config) 21 | 22 | self.p_head = nn.Linear(config.vision_config.projection_dim, 1) 23 | self.w_head = nn.Linear(config.vision_config.projection_dim, 1) 24 | 25 | @torch.no_grad() 26 | def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): 27 | image_embeds = self.vision_model(clip_input)[0] 28 | 29 | nsfw_detected = self.p_head(image_embeds) 30 | nsfw_detected = nsfw_detected.flatten() 31 | nsfw_detected = nsfw_detected > p_threshold 32 | nsfw_detected = nsfw_detected.tolist() 33 | 34 | if any(nsfw_detected): 35 | logger.warning( 36 | "Potential NSFW content was detected in one or more images. A black image will be returned instead." 37 | " Try again with a different prompt and/or seed." 38 | ) 39 | 40 | for idx, nsfw_detected_ in enumerate(nsfw_detected): 41 | if nsfw_detected_: 42 | images[idx] = np.zeros(images[idx].shape) 43 | 44 | watermark_detected = self.w_head(image_embeds) 45 | watermark_detected = watermark_detected.flatten() 46 | watermark_detected = watermark_detected > w_threshold 47 | watermark_detected = watermark_detected.tolist() 48 | 49 | if any(watermark_detected): 50 | logger.warning( 51 | "Potential watermarked content was detected in one or more images. A black image will be returned instead." 52 | " Try again with a different prompt and/or seed." 
53 | ) 54 | 55 | for idx, watermark_detected_ in enumerate(watermark_detected): 56 | if watermark_detected_: 57 | images[idx] = np.zeros(images[idx].shape) 58 | 59 | return images, nsfw_detected, watermark_detected 60 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/deepfloyd_if/watermark.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import PIL.Image 4 | import torch 5 | from PIL import Image 6 | 7 | from ...configuration_utils import ConfigMixin 8 | from ...models.modeling_utils import ModelMixin 9 | from ...utils import PIL_INTERPOLATION 10 | 11 | 12 | class IFWatermarker(ModelMixin, ConfigMixin): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) 17 | self.watermark_image_as_pil = None 18 | 19 | def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None): 20 | # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 21 | 22 | h = images[0].height 23 | w = images[0].width 24 | 25 | sample_size = sample_size or h 26 | 27 | coef = min(h / sample_size, w / sample_size) 28 | img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) 29 | 30 | S1, S2 = 1024**2, img_w * img_h 31 | K = (S2 / S1) ** 0.5 32 | wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) 33 | 34 | if self.watermark_image_as_pil is None: 35 | watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() 36 | watermark_image = Image.fromarray(watermark_image, mode="RGBA") 37 | self.watermark_image_as_pil = watermark_image 38 | 39 | wm_img = self.watermark_image_as_pil.resize( 40 | (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None 41 | ) 42 | 43 | for pil_img in images: 44 | pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) 45 | 46 | return images 47 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/dit/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_dit": ["DiTPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_dit import DiTPipeline 10 | 11 | else: 12 | import sys 13 | 14 | sys.modules[__name__] = _LazyModule( 15 | __name__, 16 | globals()["__file__"], 17 | _import_structure, 18 | module_spec=__spec__, 19 | ) 20 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/kandinsky/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | 
else: 24 | _import_structure["pipeline_kandinsky"] = ["KandinskyPipeline"] 25 | _import_structure["pipeline_kandinsky_combined"] = [ 26 | "KandinskyCombinedPipeline", 27 | "KandinskyImg2ImgCombinedPipeline", 28 | "KandinskyInpaintCombinedPipeline", 29 | ] 30 | _import_structure["pipeline_kandinsky_img2img"] = ["KandinskyImg2ImgPipeline"] 31 | _import_structure["pipeline_kandinsky_inpaint"] = ["KandinskyInpaintPipeline"] 32 | _import_structure["pipeline_kandinsky_prior"] = ["KandinskyPriorPipeline", "KandinskyPriorPipelineOutput"] 33 | _import_structure["text_encoder"] = ["MultilingualCLIP"] 34 | 35 | 36 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 37 | try: 38 | if not (is_transformers_available() and is_torch_available()): 39 | raise OptionalDependencyNotAvailable() 40 | except OptionalDependencyNotAvailable: 41 | from ...utils.dummy_torch_and_transformers_objects import * 42 | 43 | else: 44 | from .pipeline_kandinsky import KandinskyPipeline 45 | from .pipeline_kandinsky_combined import ( 46 | KandinskyCombinedPipeline, 47 | KandinskyImg2ImgCombinedPipeline, 48 | KandinskyInpaintCombinedPipeline, 49 | ) 50 | from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline 51 | from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline 52 | from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput 53 | from .text_encoder import MultilingualCLIP 54 | 55 | else: 56 | import sys 57 | 58 | sys.modules[__name__] = _LazyModule( 59 | __name__, 60 | globals()["__file__"], 61 | _import_structure, 62 | module_spec=__spec__, 63 | ) 64 | 65 | for name, value in _dummy_objects.items(): 66 | setattr(sys.modules[__name__], name, value) 67 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/kandinsky/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel 3 | 4 | 5 | class MCLIPConfig(XLMRobertaConfig): 6 | model_type = "M-CLIP" 7 | 8 | def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs): 9 | self.transformerDimensions = transformerDimSize 10 | self.numDims = imageDimSize 11 | super().__init__(**kwargs) 12 | 13 | 14 | class MultilingualCLIP(PreTrainedModel): 15 | config_class = MCLIPConfig 16 | 17 | def __init__(self, config, *args, **kwargs): 18 | super().__init__(config, *args, **kwargs) 19 | self.transformer = XLMRobertaModel(config) 20 | self.LinearTransformation = torch.nn.Linear( 21 | in_features=config.transformerDimensions, out_features=config.numDims 22 | ) 23 | 24 | def forward(self, input_ids, attention_mask): 25 | embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0] 26 | embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None] 27 | return self.LinearTransformation(embs2), embs 28 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/kandinsky2_2/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise 
OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["pipeline_kandinsky2_2"] = ["KandinskyV22Pipeline"] 25 | _import_structure["pipeline_kandinsky2_2_combined"] = [ 26 | "KandinskyV22CombinedPipeline", 27 | "KandinskyV22Img2ImgCombinedPipeline", 28 | "KandinskyV22InpaintCombinedPipeline", 29 | ] 30 | _import_structure["pipeline_kandinsky2_2_controlnet"] = ["KandinskyV22ControlnetPipeline"] 31 | _import_structure["pipeline_kandinsky2_2_controlnet_img2img"] = ["KandinskyV22ControlnetImg2ImgPipeline"] 32 | _import_structure["pipeline_kandinsky2_2_img2img"] = ["KandinskyV22Img2ImgPipeline"] 33 | _import_structure["pipeline_kandinsky2_2_inpainting"] = ["KandinskyV22InpaintPipeline"] 34 | _import_structure["pipeline_kandinsky2_2_prior"] = ["KandinskyV22PriorPipeline"] 35 | _import_structure["pipeline_kandinsky2_2_prior_emb2emb"] = ["KandinskyV22PriorEmb2EmbPipeline"] 36 | 37 | 38 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 39 | try: 40 | if not (is_transformers_available() and is_torch_available()): 41 | raise OptionalDependencyNotAvailable() 42 | 43 | except OptionalDependencyNotAvailable: 44 | from ...utils.dummy_torch_and_transformers_objects import * 45 | else: 46 | from .pipeline_kandinsky2_2 import KandinskyV22Pipeline 47 | from .pipeline_kandinsky2_2_combined import ( 48 | KandinskyV22CombinedPipeline, 49 | KandinskyV22Img2ImgCombinedPipeline, 50 | KandinskyV22InpaintCombinedPipeline, 51 | ) 52 | from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline 53 | from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline 54 | from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline 55 | from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline 56 | from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline 57 | from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline 58 | 59 | else: 60 | import sys 61 | 62 | sys.modules[__name__] = _LazyModule( 63 | __name__, 64 | globals()["__file__"], 65 | _import_structure, 66 | module_spec=__spec__, 67 | ) 68 | 69 | for name, value in _dummy_objects.items(): 70 | setattr(sys.modules[__name__], name, value) 71 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/kandinsky3/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"] 25 | _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"] 26 | 27 | 28 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 29 | try: 30 | if not 
(is_transformers_available() and is_torch_available()): 31 | raise OptionalDependencyNotAvailable() 32 | 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import * 35 | else: 36 | from .kandinsky3_pipeline import Kandinsky3Pipeline 37 | from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline 38 | else: 39 | import sys 40 | 41 | sys.modules[__name__] = _LazyModule( 42 | __name__, 43 | globals()["__file__"], 44 | _import_structure, 45 | module_spec=__spec__, 46 | ) 47 | 48 | for name, value in _dummy_objects.items(): 49 | setattr(sys.modules[__name__], name, value) 50 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/latent_consistency_models/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | 17 | try: 18 | if not (is_transformers_available() and is_torch_available()): 19 | raise OptionalDependencyNotAvailable() 20 | except OptionalDependencyNotAvailable: 21 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 22 | 23 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 24 | else: 25 | _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"] 26 | _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"] 27 | 28 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 29 | try: 30 | if not (is_transformers_available() and is_torch_available()): 31 | raise OptionalDependencyNotAvailable() 32 | 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import * 35 | else: 36 | from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline 37 | from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline 38 | 39 | else: 40 | import sys 41 | 42 | sys.modules[__name__] = _LazyModule( 43 | __name__, 44 | globals()["__file__"], 45 | _import_structure, 46 | module_spec=__spec__, 47 | ) 48 | 49 | for name, value in _dummy_objects.items(): 50 | setattr(sys.modules[__name__], name, value) 51 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"] 25 | _import_structure["pipeline_latent_diffusion_superresolution"] = 
["LDMSuperResolutionPipeline"] 26 | 27 | 28 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 29 | try: 30 | if not (is_transformers_available() and is_torch_available()): 31 | raise OptionalDependencyNotAvailable() 32 | 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import * 35 | else: 36 | from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline 37 | from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline 38 | 39 | else: 40 | import sys 41 | 42 | sys.modules[__name__] = _LazyModule( 43 | __name__, 44 | globals()["__file__"], 45 | _import_structure, 46 | module_spec=__spec__, 47 | ) 48 | 49 | for name, value in _dummy_objects.items(): 50 | setattr(sys.modules[__name__], name, value) 51 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/latent_diffusion_uncond/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_latent_diffusion_uncond": ["LDMPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_latent_diffusion_uncond import LDMPipeline 10 | else: 11 | import sys 12 | 13 | sys.modules[__name__] = _LazyModule( 14 | __name__, 15 | globals()["__file__"], 16 | _import_structure, 17 | module_spec=__spec__, 18 | ) 19 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/musicldm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | is_transformers_version, 11 | ) 12 | 13 | 14 | _dummy_objects = {} 15 | _import_structure = {} 16 | 17 | try: 18 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 19 | raise OptionalDependencyNotAvailable() 20 | except OptionalDependencyNotAvailable: 21 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 22 | 23 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 24 | else: 25 | _import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"] 26 | 27 | 28 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 29 | try: 30 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): 31 | raise OptionalDependencyNotAvailable() 32 | 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import * 35 | else: 36 | from .pipeline_musicldm import MusicLDMPipeline 37 | 38 | else: 39 | import sys 40 | 41 | sys.modules[__name__] = _LazyModule( 42 | __name__, 43 | globals()["__file__"], 44 | _import_structure, 45 | module_spec=__spec__, 46 | ) 47 | 48 | for name, value in _dummy_objects.items(): 49 | setattr(sys.modules[__name__], name, value) 50 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/paint_by_example/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TYPE_CHECKING, List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL 6 | from 
PIL import Image 7 | 8 | from ...utils import ( 9 | DIFFUSERS_SLOW_IMPORT, 10 | OptionalDependencyNotAvailable, 11 | _LazyModule, 12 | get_objects_from_module, 13 | is_torch_available, 14 | is_transformers_available, 15 | ) 16 | 17 | 18 | _dummy_objects = {} 19 | _import_structure = {} 20 | 21 | try: 22 | if not (is_transformers_available() and is_torch_available()): 23 | raise OptionalDependencyNotAvailable() 24 | except OptionalDependencyNotAvailable: 25 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 26 | 27 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 28 | else: 29 | _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"] 30 | _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"] 31 | 32 | 33 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 34 | try: 35 | if not (is_transformers_available() and is_torch_available()): 36 | raise OptionalDependencyNotAvailable() 37 | 38 | except OptionalDependencyNotAvailable: 39 | from ...utils.dummy_torch_and_transformers_objects import * 40 | else: 41 | from .image_encoder import PaintByExampleImageEncoder 42 | from .pipeline_paint_by_example import PaintByExamplePipeline 43 | 44 | else: 45 | import sys 46 | 47 | sys.modules[__name__] = _LazyModule( 48 | __name__, 49 | globals()["__file__"], 50 | _import_structure, 51 | module_spec=__spec__, 52 | ) 53 | 54 | for name, value in _dummy_objects.items(): 55 | setattr(sys.modules[__name__], name, value) 56 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/paint_by_example/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch 15 | from torch import nn 16 | from transformers import CLIPPreTrainedModel, CLIPVisionModel 17 | 18 | from ...models.attention import BasicTransformerBlock 19 | from ...utils import logging 20 | 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | class PaintByExampleImageEncoder(CLIPPreTrainedModel): 26 | def __init__(self, config, proj_size=None): 27 | super().__init__(config) 28 | self.proj_size = proj_size or getattr(config, "projection_dim", 768) 29 | 30 | self.model = CLIPVisionModel(config) 31 | self.mapper = PaintByExampleMapper(config) 32 | self.final_layer_norm = nn.LayerNorm(config.hidden_size) 33 | self.proj_out = nn.Linear(config.hidden_size, self.proj_size) 34 | 35 | # uncondition for scaling 36 | self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) 37 | 38 | def forward(self, pixel_values, return_uncond_vector=False): 39 | clip_output = self.model(pixel_values=pixel_values) 40 | latent_states = clip_output.pooler_output 41 | latent_states = self.mapper(latent_states[:, None]) 42 | latent_states = self.final_layer_norm(latent_states) 43 | latent_states = self.proj_out(latent_states) 44 | if return_uncond_vector: 45 | return latent_states, self.uncond_vector 46 | 47 | return latent_states 48 | 49 | 50 | class PaintByExampleMapper(nn.Module): 51 | def __init__(self, config): 52 | super().__init__() 53 | num_layers = (config.num_hidden_layers + 1) // 5 54 | hid_size = config.hidden_size 55 | num_heads = 1 56 | self.blocks = nn.ModuleList( 57 | [ 58 | BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) 59 | for _ in range(num_layers) 60 | ] 61 | ) 62 | 63 | def forward(self, hidden_states): 64 | for block in self.blocks: 65 | hidden_states = block(hidden_states) 66 | 67 | return hidden_states 68 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/pixart_alpha/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | 17 | try: 18 | if not (is_transformers_available() and is_torch_available()): 19 | raise OptionalDependencyNotAvailable() 20 | except OptionalDependencyNotAvailable: 21 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 22 | 23 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 24 | else: 25 | _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"] 26 | 27 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 28 | try: 29 | if not (is_transformers_available() and is_torch_available()): 30 | raise OptionalDependencyNotAvailable() 31 | 32 | except OptionalDependencyNotAvailable: 33 | from ...utils.dummy_torch_and_transformers_objects import * 34 | else: 35 | from .pipeline_pixart_alpha import PixArtAlphaPipeline 36 | 37 | else: 38 | import sys 39 | 40 | sys.modules[__name__] = _LazyModule( 41 | __name__, 42 | globals()["__file__"], 43 | _import_structure, 44 | module_spec=__spec__, 45 | ) 46 | 47 | for name, value in _dummy_objects.items(): 48 | setattr(sys.modules[__name__], name, value) 49 | -------------------------------------------------------------------------------- 
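Every pipeline `__init__.py` in this tree repeats the same lazy-import convention: at import time the module only declares an `_import_structure` mapping, and `sys.modules[__name__]` is replaced with a `_LazyModule` so the heavyweight submodule imports run on first attribute access rather than at package import. A minimal sketch of the idea (illustrative only — the `LazyModule` below is a simplified stand-in, not the actual `diffusers.utils._LazyModule`):

import importlib
import types


class LazyModule(types.ModuleType):
    # Simplified stand-in for diffusers.utils._LazyModule (illustration only).
    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the submodule that defines it, e.g.
        # {"PixArtAlphaPipeline": "pipeline_pixart_alpha"}.
        self._class_to_module = {
            cls: mod for mod, classes in import_structure.items() for cls in classes
        }
        self.__all__ = list(self._class_to_module)

    def __getattr__(self, name):
        # Only reached when `name` is not yet a real attribute, so the
        # submodule import is deferred until this first access.
        if name not in self._class_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        submodule = importlib.import_module(f"{self.__name__}.{self._class_to_module[name]}")
        value = getattr(submodule, name)
        setattr(self, name, value)  # cache, so later lookups bypass __getattr__
        return value

Swapping `sys.modules[__name__]` for such an object keeps `import diffusers.pipelines.pixart_alpha` cheap while still letting `pixart_alpha.PixArtAlphaPipeline` resolve on demand; the `TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT` branch bypasses the mechanism entirely for static analyzers and debugging.
--------------------------------------------------------------------------------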
/src/diffusers/pipelines/pndm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_pndm": ["PNDMPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_pndm import PNDMPipeline 10 | else: 11 | import sys 12 | 13 | sys.modules[__name__] = _LazyModule( 14 | __name__, 15 | globals()["__file__"], 16 | _import_structure, 17 | module_spec=__spec__, 18 | ) 19 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/repaint/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_repaint": ["RePaintPipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_repaint import RePaintPipeline 10 | 11 | else: 12 | import sys 13 | 14 | sys.modules[__name__] = _LazyModule( 15 | __name__, 16 | globals()["__file__"], 17 | _import_structure, 18 | module_spec=__spec__, 19 | ) 20 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/score_sde_ve/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_score_sde_ve": ["ScoreSdeVePipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_score_sde_ve import ScoreSdeVePipeline 10 | 11 | else: 12 | import sys 13 | 14 | sys.modules[__name__] = _LazyModule( 15 | __name__, 16 | globals()["__file__"], 17 | _import_structure, 18 | module_spec=__spec__, 19 | ) 20 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/semantic_stable_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"] 25 | _import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] 26 | 27 | 28 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 29 | try: 30 | if not (is_transformers_available() and is_torch_available()): 31 | raise OptionalDependencyNotAvailable() 32 | 33 | except OptionalDependencyNotAvailable: 34 | from ...utils.dummy_torch_and_transformers_objects import * 35 | else: 36 | from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline 37 | 38 | else: 39 | import sys 40 | 41 | sys.modules[__name__] = _LazyModule( 42 | __name__, 43 | globals()["__file__"], 44 | _import_structure, 45 | module_spec=__spec__, 
46 | ) 47 | 48 | for name, value in _dummy_objects.items(): 49 | setattr(sys.modules[__name__], name, value) 50 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from ...utils import BaseOutput 8 | 9 | 10 | @dataclass 11 | class SemanticStableDiffusionPipelineOutput(BaseOutput): 12 | """ 13 | Output class for Stable Diffusion pipelines. 14 | 15 | Args: 16 | images (`List[PIL.Image.Image]` or `np.ndarray`) 17 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 18 | num_channels)`. 19 | nsfw_content_detected (`List[bool]`) 20 | List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or 21 | `None` if safety checking could not be performed. 22 | """ 23 | 24 | images: Union[List[PIL.Image.Image], np.ndarray] 25 | nsfw_content_detected: Optional[List[bool]] 26 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/shap_e/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["camera"] = ["create_pan_cameras"] 25 | _import_structure["pipeline_shap_e"] = ["ShapEPipeline"] 26 | _import_structure["pipeline_shap_e_img2img"] = ["ShapEImg2ImgPipeline"] 27 | _import_structure["renderer"] = [ 28 | "BoundingBoxVolume", 29 | "ImportanceRaySampler", 30 | "MLPNeRFModelOutput", 31 | "MLPNeRSTFModel", 32 | "ShapEParamsProjModel", 33 | "ShapERenderer", 34 | "StratifiedRaySampler", 35 | "VoidNeRFModel", 36 | ] 37 | 38 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 39 | try: 40 | if not (is_transformers_available() and is_torch_available()): 41 | raise OptionalDependencyNotAvailable() 42 | 43 | except OptionalDependencyNotAvailable: 44 | from ...utils.dummy_torch_and_transformers_objects import * 45 | else: 46 | from .camera import create_pan_cameras 47 | from .pipeline_shap_e import ShapEPipeline 48 | from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline 49 | from .renderer import ( 50 | BoundingBoxVolume, 51 | ImportanceRaySampler, 52 | MLPNeRFModelOutput, 53 | MLPNeRSTFModel, 54 | ShapEParamsProjModel, 55 | ShapERenderer, 56 | StratifiedRaySampler, 57 | VoidNeRFModel, 58 | ) 59 | 60 | else: 61 | import sys 62 | 63 | sys.modules[__name__] = _LazyModule( 64 | __name__, 65 | globals()["__file__"], 66 | _import_structure, 67 | module_spec=__spec__, 68 | ) 69 | 70 | for name, value in _dummy_objects.items(): 71 | setattr(sys.modules[__name__], name, value) 72 | -------------------------------------------------------------------------------- 
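The other idiom repeated above is the optional-dependency guard: when `torch` or `transformers` is missing, `_import_structure` stays empty and `_dummy_objects` is filled with importable placeholders that only fail when used. A rough sketch of how such a placeholder can behave (an assumption-level illustration — the real `DummyObject` in `diffusers.utils` is a metaclass-based variant with per-class backend lists, and the `ShapEPipeline` stand-in here is purely hypothetical):

import importlib.util


def is_torch_available() -> bool:
    # Illustrative check; the real diffusers helper caches the result and
    # verifies that the package actually imports.
    return importlib.util.find_spec("torch") is not None


class DummyObject:
    """Importable placeholder for a class whose backend is missing."""

    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        raise ImportError(
            f"{type(self).__name__} requires the backends {self._backends}, "
            "which could not be found in your environment."
        )


# Name-only stand-in, exported in place of the real class when torch is absent:
ShapEPipeline = type("ShapEPipeline", (DummyObject,), {})

This is why `from diffusers.pipelines.shap_e import ShapEPipeline` still succeeds in a torch-free environment: the failure surfaces as an actionable `ImportError` only when the pipeline is actually constructed.
--------------------------------------------------------------------------------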
/src/diffusers/pipelines/spectrogram_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import TYPE_CHECKING 3 | from ...utils import DIFFUSERS_SLOW_IMPORT 4 | from ...utils import ( 5 | _LazyModule, 6 | is_note_seq_available, 7 | OptionalDependencyNotAvailable, 8 | is_torch_available, 9 | is_transformers_available, 10 | get_objects_from_module, 11 | ) 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["continous_encoder"] = ["SpectrogramContEncoder"] 25 | _import_structure["notes_encoder"] = ["SpectrogramNotesEncoder"] 26 | _import_structure["pipeline_spectrogram_diffusion"] = [ 27 | "SpectrogramContEncoder", 28 | "SpectrogramDiffusionPipeline", 29 | "T5FilmDecoder", 30 | ] 31 | try: 32 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): 33 | raise OptionalDependencyNotAvailable() 34 | except OptionalDependencyNotAvailable: 35 | from ...utils import dummy_transformers_and_torch_and_note_seq_objects 36 | 37 | _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) 38 | else: 39 | _import_structure["midi_utils"] = ["MidiProcessor"] 40 | 41 | 42 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 43 | try: 44 | if not (is_transformers_available() and is_torch_available()): 45 | raise OptionalDependencyNotAvailable() 46 | 47 | except OptionalDependencyNotAvailable: 48 | from ...utils.dummy_torch_and_transformers_objects import * 49 | else: 50 | from .pipeline_spectrogram_diffusion import SpectrogramDiffusionPipeline 51 | from .pipeline_spectrogram_diffusion import SpectrogramContEncoder 52 | from .pipeline_spectrogram_diffusion import SpectrogramNotesEncoder 53 | from .pipeline_spectrogram_diffusion import T5FilmDecoder 54 | 55 | try: 56 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): 57 | raise OptionalDependencyNotAvailable() 58 | except OptionalDependencyNotAvailable: 59 | from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * 60 | 61 | else: 62 | from .midi_utils import MidiProcessor 63 | 64 | else: 65 | import sys 66 | 67 | sys.modules[__name__] = _LazyModule( 68 | __name__, 69 | globals()["__file__"], 70 | _import_structure, 71 | module_spec=__spec__, 72 | ) 73 | 74 | for name, value in _dummy_objects.items(): 75 | setattr(sys.modules[__name__], name, value) 76 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Music Spectrogram Diffusion Authors. 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | from transformers.modeling_utils import ModuleUtilsMixin 19 | from transformers.models.t5.modeling_t5 import ( 20 | T5Block, 21 | T5Config, 22 | T5LayerNorm, 23 | ) 24 | 25 | from ...configuration_utils import ConfigMixin, register_to_config 26 | from ...models import ModelMixin 27 | 28 | 29 | class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): 30 | @register_to_config 31 | def __init__( 32 | self, 33 | input_dims: int, 34 | targets_context_length: int, 35 | d_model: int, 36 | dropout_rate: float, 37 | num_layers: int, 38 | num_heads: int, 39 | d_kv: int, 40 | d_ff: int, 41 | feed_forward_proj: str, 42 | is_decoder: bool = False, 43 | ): 44 | super().__init__() 45 | 46 | self.input_proj = nn.Linear(input_dims, d_model, bias=False) 47 | 48 | self.position_encoding = nn.Embedding(targets_context_length, d_model) 49 | self.position_encoding.weight.requires_grad = False 50 | 51 | self.dropout_pre = nn.Dropout(p=dropout_rate) 52 | 53 | t5config = T5Config( 54 | d_model=d_model, 55 | num_heads=num_heads, 56 | d_kv=d_kv, 57 | d_ff=d_ff, 58 | feed_forward_proj=feed_forward_proj, 59 | dropout_rate=dropout_rate, 60 | is_decoder=is_decoder, 61 | is_encoder_decoder=False, 62 | ) 63 | self.encoders = nn.ModuleList() 64 | for lyr_num in range(num_layers): 65 | lyr = T5Block(t5config) 66 | self.encoders.append(lyr) 67 | 68 | self.layer_norm = T5LayerNorm(d_model) 69 | self.dropout_post = nn.Dropout(p=dropout_rate) 70 | 71 | def forward(self, encoder_inputs, encoder_inputs_mask): 72 | x = self.input_proj(encoder_inputs) 73 | 74 | # terminal relative positional encodings 75 | max_positions = encoder_inputs.shape[1] 76 | input_positions = torch.arange(max_positions, device=encoder_inputs.device) 77 | 78 | seq_lens = encoder_inputs_mask.sum(-1) 79 | input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) 80 | x += self.position_encoding(input_positions) 81 | 82 | x = self.dropout_pre(x) 83 | 84 | # inverted the attention mask 85 | input_shape = encoder_inputs.size() 86 | extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) 87 | 88 | for lyr in self.encoders: 89 | x = lyr(x, extended_attention_mask)[0] 90 | x = self.layer_norm(x) 91 | 92 | return self.dropout_post(x), encoder_inputs_mask 93 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Music Spectrogram Diffusion Authors. 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | from transformers.modeling_utils import ModuleUtilsMixin 19 | from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm 20 | 21 | from ...configuration_utils import ConfigMixin, register_to_config 22 | from ...models import ModelMixin 23 | 24 | 25 | class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): 26 | @register_to_config 27 | def __init__( 28 | self, 29 | max_length: int, 30 | vocab_size: int, 31 | d_model: int, 32 | dropout_rate: float, 33 | num_layers: int, 34 | num_heads: int, 35 | d_kv: int, 36 | d_ff: int, 37 | feed_forward_proj: str, 38 | is_decoder: bool = False, 39 | ): 40 | super().__init__() 41 | 42 | self.token_embedder = nn.Embedding(vocab_size, d_model) 43 | 44 | self.position_encoding = nn.Embedding(max_length, d_model) 45 | self.position_encoding.weight.requires_grad = False 46 | 47 | self.dropout_pre = nn.Dropout(p=dropout_rate) 48 | 49 | t5config = T5Config( 50 | vocab_size=vocab_size, 51 | d_model=d_model, 52 | num_heads=num_heads, 53 | d_kv=d_kv, 54 | d_ff=d_ff, 55 | dropout_rate=dropout_rate, 56 | feed_forward_proj=feed_forward_proj, 57 | is_decoder=is_decoder, 58 | is_encoder_decoder=False, 59 | ) 60 | 61 | self.encoders = nn.ModuleList() 62 | for lyr_num in range(num_layers): 63 | lyr = T5Block(t5config) 64 | self.encoders.append(lyr) 65 | 66 | self.layer_norm = T5LayerNorm(d_model) 67 | self.dropout_post = nn.Dropout(p=dropout_rate) 68 | 69 | def forward(self, encoder_input_tokens, encoder_inputs_mask): 70 | x = self.token_embedder(encoder_input_tokens) 71 | 72 | seq_length = encoder_input_tokens.shape[1] 73 | inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) 74 | x += self.position_encoding(inputs_positions) 75 | 76 | x = self.dropout_pre(x) 77 | 78 | # inverted the attention mask 79 | input_shape = encoder_input_tokens.size() 80 | extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) 81 | 82 | for lyr in self.encoders: 83 | x = lyr(x, extended_attention_mask)[0] 84 | x = self.layer_norm(x) 85 | 86 | return self.dropout_post(x), encoder_inputs_mask 87 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The GLIGEN Authors and HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from torch import nn 16 | 17 | from ...configuration_utils import ConfigMixin, register_to_config 18 | from ...models.modeling_utils import ModelMixin 19 | 20 | 21 | class CLIPImageProjection(ModelMixin, ConfigMixin): 22 | @register_to_config 23 | def __init__(self, hidden_size: int = 768): 24 | super().__init__() 25 | self.hidden_size = hidden_size 26 | self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False) 27 | 28 | def forward(self, x): 29 | return self.project(x) 30 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from ...utils import BaseOutput, is_flax_available 8 | 9 | 10 | @dataclass 11 | class StableDiffusionPipelineOutput(BaseOutput): 12 | """ 13 | Output class for Stable Diffusion pipelines. 14 | 15 | Args: 16 | images (`List[PIL.Image.Image]` or `np.ndarray`) 17 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 18 | num_channels)`. 19 | nsfw_content_detected (`List[bool]`) 20 | List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or 21 | `None` if safety checking could not be performed. 22 | """ 23 | 24 | images: Union[List[PIL.Image.Image], np.ndarray] 25 | nsfw_content_detected: Optional[List[bool]] 26 | 27 | 28 | if is_flax_available(): 29 | import flax 30 | 31 | @flax.struct.dataclass 32 | class FlaxStableDiffusionPipelineOutput(BaseOutput): 33 | """ 34 | Output class for Flax-based Stable Diffusion pipelines. 35 | 36 | Args: 37 | images (`np.ndarray`): 38 | Denoised images of array shape of `(batch_size, height, width, num_channels)`. 39 | nsfw_content_detected (`List[bool]`): 40 | List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content 41 | or `None` if safety checking could not be performed. 42 | """ 43 | 44 | images: np.ndarray 45 | nsfw_content_detected: List[bool] 46 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Union 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from ...configuration_utils import ConfigMixin, register_to_config 21 | from ...models.modeling_utils import ModelMixin 22 | 23 | 24 | class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): 25 | """ 26 | This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. 
27 | 28 | It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image 29 | embeddings. 30 | """ 31 | 32 | @register_to_config 33 | def __init__( 34 | self, 35 | embedding_dim: int = 768, 36 | ): 37 | super().__init__() 38 | 39 | self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) 40 | self.std = nn.Parameter(torch.ones(1, embedding_dim)) 41 | 42 | def to( 43 | self, 44 | torch_device: Optional[Union[str, torch.device]] = None, 45 | torch_dtype: Optional[torch.dtype] = None, 46 | ): 47 | self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) 48 | self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) 49 | return self 50 | 51 | def scale(self, embeds): 52 | embeds = (embeds - self.mean) * 1.0 / self.std 53 | return embeds 54 | 55 | def unscale(self, embeds): 56 | embeds = (embeds * self.std) + self.mean 57 | return embeds 58 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion_safe/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import TYPE_CHECKING, List, Optional, Union 4 | 5 | import numpy as np 6 | import PIL 7 | from PIL import Image 8 | 9 | from ...utils import ( 10 | DIFFUSERS_SLOW_IMPORT, 11 | BaseOutput, 12 | OptionalDependencyNotAvailable, 13 | _LazyModule, 14 | get_objects_from_module, 15 | is_torch_available, 16 | is_transformers_available, 17 | ) 18 | 19 | 20 | @dataclass 21 | class SafetyConfig(object): 22 | WEAK = { 23 | "sld_warmup_steps": 15, 24 | "sld_guidance_scale": 20, 25 | "sld_threshold": 0.0, 26 | "sld_momentum_scale": 0.0, 27 | "sld_mom_beta": 0.0, 28 | } 29 | MEDIUM = { 30 | "sld_warmup_steps": 10, 31 | "sld_guidance_scale": 1000, 32 | "sld_threshold": 0.01, 33 | "sld_momentum_scale": 0.3, 34 | "sld_mom_beta": 0.4, 35 | } 36 | STRONG = { 37 | "sld_warmup_steps": 7, 38 | "sld_guidance_scale": 2000, 39 | "sld_threshold": 0.025, 40 | "sld_momentum_scale": 0.5, 41 | "sld_mom_beta": 0.7, 42 | } 43 | MAX = { 44 | "sld_warmup_steps": 0, 45 | "sld_guidance_scale": 5000, 46 | "sld_threshold": 1.0, 47 | "sld_momentum_scale": 0.5, 48 | "sld_mom_beta": 0.7, 49 | } 50 | 51 | 52 | _dummy_objects = {} 53 | _additional_imports = {} 54 | _import_structure = {} 55 | 56 | _additional_imports.update({"SafetyConfig": SafetyConfig}) 57 | 58 | try: 59 | if not (is_transformers_available() and is_torch_available()): 60 | raise OptionalDependencyNotAvailable() 61 | except OptionalDependencyNotAvailable: 62 | from ...utils import dummy_torch_and_transformers_objects 63 | 64 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 65 | else: 66 | _import_structure.update( 67 | { 68 | "pipeline_output": ["StableDiffusionSafePipelineOutput"], 69 | "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], 70 | "safety_checker": ["SafeStableDiffusionSafetyChecker"], 71 | } 72 | ) 73 | 74 | 75 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 76 | try: 77 | if not (is_transformers_available() and is_torch_available()): 78 | raise OptionalDependencyNotAvailable() 79 | except OptionalDependencyNotAvailable: 80 | from ...utils.dummy_torch_and_transformers_objects import * 81 | else: 82 | from .pipeline_output import StableDiffusionSafePipelineOutput 83 | from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe 84 | from .safety_checker import SafeStableDiffusionSafetyChecker 85 | 86 | 
else: 87 | import sys 88 | 89 | sys.modules[__name__] = _LazyModule( 90 | __name__, 91 | globals()["__file__"], 92 | _import_structure, 93 | module_spec=__spec__, 94 | ) 95 | 96 | for name, value in _dummy_objects.items(): 97 | setattr(sys.modules[__name__], name, value) 98 | for name, value in _additional_imports.items(): 99 | setattr(sys.modules[__name__], name, value) 100 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from ...utils import ( 8 | BaseOutput, 9 | ) 10 | 11 | 12 | @dataclass 13 | class StableDiffusionSafePipelineOutput(BaseOutput): 14 | """ 15 | Output class for Safe Stable Diffusion pipelines. 16 | 17 | Args: 18 | images (`List[PIL.Image.Image]` or `np.ndarray`) 19 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 20 | num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline. 21 | nsfw_content_detected (`List[bool]`) 22 | List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" 23 | (nsfw) content, or `None` if safety checking could not be performed. 24 | unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`) 25 | List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work" 26 | (nsfw) content, or `None` if no safety check was performed or no images were flagged. 27 | applied_safety_concept (`str`) 28 | The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled. 29 | """ 30 | 31 | images: Union[List[PIL.Image.Image], np.ndarray] 32 | nsfw_content_detected: Optional[List[bool]] 33 | unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] 34 | applied_safety_concept: Optional[str] 35 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion_xl/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_flax_available, 9 | is_torch_available, 10 | is_transformers_available, 11 | ) 12 | 13 | 14 | _dummy_objects = {} 15 | _additional_imports = {} 16 | _import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]} 17 | 18 | if is_transformers_available() and is_flax_available(): 19 | _import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"]) 20 | try: 21 | if not (is_transformers_available() and is_torch_available()): 22 | raise OptionalDependencyNotAvailable() 23 | except OptionalDependencyNotAvailable: 24 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 25 | 26 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 27 | else: 28 | _import_structure["pipeline_stable_diffusion_xl"] = ["StableDiffusionXLPipeline"] 29 | _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] 30 | _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] 31 |
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] 32 | 33 | if is_transformers_available() and is_flax_available(): 34 | from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState 35 | 36 | _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) 37 | _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"] 38 | 39 | 40 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 41 | try: 42 | if not (is_transformers_available() and is_torch_available()): 43 | raise OptionalDependencyNotAvailable() 44 | except OptionalDependencyNotAvailable: 45 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 46 | else: 47 | from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline 48 | from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline 49 | from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline 50 | from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline 51 | 52 | try: 53 | if not (is_transformers_available() and is_flax_available()): 54 | raise OptionalDependencyNotAvailable() 55 | except OptionalDependencyNotAvailable: 56 | from ...utils.dummy_flax_objects import * 57 | else: 58 | from .pipeline_flax_stable_diffusion_xl import ( 59 | FlaxStableDiffusionXLPipeline, 60 | ) 61 | from .pipeline_output import FlaxStableDiffusionXLPipelineOutput 62 | 63 | else: 64 | import sys 65 | 66 | sys.modules[__name__] = _LazyModule( 67 | __name__, 68 | globals()["__file__"], 69 | _import_structure, 70 | module_spec=__spec__, 71 | ) 72 | 73 | for name, value in _dummy_objects.items(): 74 | setattr(sys.modules[__name__], name, value) 75 | for name, value in _additional_imports.items(): 76 | setattr(sys.modules[__name__], name, value) 77 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | 7 | from ...utils import BaseOutput, is_flax_available 8 | 9 | 10 | @dataclass 11 | class StableDiffusionXLPipelineOutput(BaseOutput): 12 | """ 13 | Output class for Stable Diffusion pipelines. 14 | 15 | Args: 16 | images (`List[PIL.Image.Image]` or `np.ndarray`) 17 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 18 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 19 | """ 20 | 21 | images: Union[List[PIL.Image.Image], np.ndarray] 22 | 23 | 24 | if is_flax_available(): 25 | import flax 26 | 27 | @flax.struct.dataclass 28 | class FlaxStableDiffusionXLPipelineOutput(BaseOutput): 29 | """ 30 | Output class for Flax Stable Diffusion XL pipelines. 31 | 32 | Args: 33 | images (`np.ndarray`) 34 | Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. 
35 | """ 36 | 37 | images: np.ndarray 38 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_diffusion_xl/watermark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ...utils import is_invisible_watermark_available 5 | 6 | 7 | if is_invisible_watermark_available(): 8 | from imwatermark import WatermarkEncoder 9 | 10 | 11 | # Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 12 | WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 13 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 14 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] 15 | 16 | 17 | class StableDiffusionXLWatermarker: 18 | def __init__(self): 19 | self.watermark = WATERMARK_BITS 20 | self.encoder = WatermarkEncoder() 21 | 22 | self.encoder.set_watermark("bits", self.watermark) 23 | 24 | def apply_watermark(self, images: torch.FloatTensor): 25 | # can't encode images that are smaller than 256 26 | if images.shape[-1] < 256: 27 | return images 28 | 29 | images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() 30 | 31 | images = [self.encoder.encode(image, "dwtDct") for image in images] 32 | 33 | images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) 34 | 35 | images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) 36 | return images 37 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stable_video_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | BaseOutput, 6 | OptionalDependencyNotAvailable, 7 | _LazyModule, 8 | get_objects_from_module, 9 | is_torch_available, 10 | is_transformers_available, 11 | ) 12 | 13 | 14 | _dummy_objects = {} 15 | _import_structure = {} 16 | 17 | try: 18 | if not (is_transformers_available() and is_torch_available()): 19 | raise OptionalDependencyNotAvailable() 20 | except OptionalDependencyNotAvailable: 21 | from ...utils import dummy_torch_and_transformers_objects 22 | 23 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 24 | else: 25 | _import_structure.update( 26 | { 27 | "pipeline_stable_video_diffusion": [ 28 | "StableVideoDiffusionPipeline", 29 | "StableVideoDiffusionPipelineOutput", 30 | ], 31 | } 32 | ) 33 | 34 | 35 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 36 | try: 37 | if not (is_transformers_available() and is_torch_available()): 38 | raise OptionalDependencyNotAvailable() 39 | except OptionalDependencyNotAvailable: 40 | from ...utils.dummy_torch_and_transformers_objects import * 41 | else: 42 | from .pipeline_stable_video_diffusion import ( 43 | StableVideoDiffusionPipeline, 44 | StableVideoDiffusionPipelineOutput, 45 | ) 46 | 47 | else: 48 | import sys 49 | 50 | sys.modules[__name__] = _LazyModule( 51 | __name__, 52 | globals()["__file__"], 53 | _import_structure, 54 | module_spec=__spec__, 55 | ) 56 | 57 | for name, value in _dummy_objects.items(): 58 | setattr(sys.modules[__name__], name, value) 59 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/stochastic_karras_ve/__init__.py: 
-------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule 4 | 5 | 6 | _import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]} 7 | 8 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 9 | from .pipeline_stochastic_karras_ve import KarrasVePipeline 10 | 11 | else: 12 | import sys 13 | 14 | sys.modules[__name__] = _LazyModule( 15 | __name__, 16 | globals()["__file__"], 17 | _import_structure, 18 | module_spec=__spec__, 19 | ) 20 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/t2i_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["pipeline_stable_diffusion_adapter"] = ["StableDiffusionAdapterPipeline"] 25 | _import_structure["pipeline_stable_diffusion_xl_adapter"] = ["StableDiffusionXLAdapterPipeline"] 26 | 27 | 28 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 29 | try: 30 | if not (is_transformers_available() and is_torch_available()): 31 | raise OptionalDependencyNotAvailable() 32 | except OptionalDependencyNotAvailable: 33 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 34 | else: 35 | from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline 36 | from .pipeline_stable_diffusion_xl_adapter import StableDiffusionXLAdapterPipeline 37 | else: 38 | import sys 39 | 40 | sys.modules[__name__] = _LazyModule( 41 | __name__, 42 | globals()["__file__"], 43 | _import_structure, 44 | module_spec=__spec__, 45 | ) 46 | for name, value in _dummy_objects.items(): 47 | setattr(sys.modules[__name__], name, value) 48 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/text_to_video_synthesis/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"] 25 | _import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"] 26 | _import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"] 27 | 
_import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"] 28 | _import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"] 29 | 30 | 31 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 32 | try: 33 | if not (is_transformers_available() and is_torch_available()): 34 | raise OptionalDependencyNotAvailable() 35 | except OptionalDependencyNotAvailable: 36 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 37 | else: 38 | from .pipeline_output import TextToVideoSDPipelineOutput 39 | from .pipeline_text_to_video_synth import TextToVideoSDPipeline 40 | from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline 41 | from .pipeline_text_to_video_zero import TextToVideoZeroPipeline 42 | from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline 43 | 44 | else: 45 | import sys 46 | 47 | sys.modules[__name__] = _LazyModule( 48 | __name__, 49 | globals()["__file__"], 50 | _import_structure, 51 | module_spec=__spec__, 52 | ) 53 | for name, value in _dummy_objects.items(): 54 | setattr(sys.modules[__name__], name, value) 55 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...utils import ( 8 | BaseOutput, 9 | ) 10 | 11 | 12 | @dataclass 13 | class TextToVideoSDPipelineOutput(BaseOutput): 14 | """ 15 | Output class for text-to-video pipelines. 16 | 17 | Args: 18 | frames (`List[np.ndarray]` or `torch.FloatTensor`) 19 | List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as 20 | a `torch` tensor. The length of the list denotes the video length (the number of frames). 
21 | """ 22 | 23 | frames: Union[List[np.ndarray], torch.FloatTensor] 24 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/unclip/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | is_torch_available, 8 | is_transformers_available, 9 | is_transformers_version, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline 21 | 22 | _dummy_objects.update( 23 | {"UnCLIPImageVariationPipeline": UnCLIPImageVariationPipeline, "UnCLIPPipeline": UnCLIPPipeline} 24 | ) 25 | else: 26 | _import_structure["pipeline_unclip"] = ["UnCLIPPipeline"] 27 | _import_structure["pipeline_unclip_image_variation"] = ["UnCLIPImageVariationPipeline"] 28 | _import_structure["text_proj"] = ["UnCLIPTextProjModel"] 29 | 30 | 31 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 32 | try: 33 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): 34 | raise OptionalDependencyNotAvailable() 35 | except OptionalDependencyNotAvailable: 36 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 37 | else: 38 | from .pipeline_unclip import UnCLIPPipeline 39 | from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline 40 | from .text_proj import UnCLIPTextProjModel 41 | 42 | else: 43 | import sys 44 | 45 | sys.modules[__name__] = _LazyModule( 46 | __name__, 47 | globals()["__file__"], 48 | _import_structure, 49 | module_spec=__spec__, 50 | ) 51 | for name, value in _dummy_objects.items(): 52 | setattr(sys.modules[__name__], name, value) 53 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/unidiffuser/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | is_torch_available, 8 | is_transformers_available, 9 | ) 10 | 11 | 12 | _dummy_objects = {} 13 | _import_structure = {} 14 | 15 | try: 16 | if not (is_transformers_available() and is_torch_available()): 17 | raise OptionalDependencyNotAvailable() 18 | except OptionalDependencyNotAvailable: 19 | from ...utils.dummy_torch_and_transformers_objects import ( 20 | ImageTextPipelineOutput, 21 | UniDiffuserPipeline, 22 | ) 23 | 24 | _dummy_objects.update( 25 | {"ImageTextPipelineOutput": ImageTextPipelineOutput, "UniDiffuserPipeline": UniDiffuserPipeline} 26 | ) 27 | else: 28 | _import_structure["modeling_text_decoder"] = ["UniDiffuserTextDecoder"] 29 | _import_structure["modeling_uvit"] = ["UniDiffuserModel", "UTransformer2DModel"] 30 | _import_structure["pipeline_unidiffuser"] = ["ImageTextPipelineOutput", "UniDiffuserPipeline"] 31 | 32 | 33 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 34 | try: 35 | if not (is_transformers_available() and is_torch_available()): 36 | raise OptionalDependencyNotAvailable() 37 | except OptionalDependencyNotAvailable: 38 | from 
...utils.dummy_torch_and_transformers_objects import ( 39 | ImageTextPipelineOutput, 40 | UniDiffuserPipeline, 41 | ) 42 | else: 43 | from .modeling_text_decoder import UniDiffuserTextDecoder 44 | from .modeling_uvit import UniDiffuserModel, UTransformer2DModel 45 | from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline 46 | 47 | else: 48 | import sys 49 | 50 | sys.modules[__name__] = _LazyModule( 51 | __name__, 52 | globals()["__file__"], 53 | _import_structure, 54 | module_spec=__spec__, 55 | ) 56 | 57 | for name, value in _dummy_objects.items(): 58 | setattr(sys.modules[__name__], name, value) 59 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/versatile_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | is_torch_available, 8 | is_transformers_available, 9 | is_transformers_version, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils.dummy_torch_and_transformers_objects import ( 21 | VersatileDiffusionDualGuidedPipeline, 22 | VersatileDiffusionImageVariationPipeline, 23 | VersatileDiffusionPipeline, 24 | VersatileDiffusionTextToImagePipeline, 25 | ) 26 | 27 | _dummy_objects.update( 28 | { 29 | "VersatileDiffusionDualGuidedPipeline": VersatileDiffusionDualGuidedPipeline, 30 | "VersatileDiffusionImageVariationPipeline": VersatileDiffusionImageVariationPipeline, 31 | "VersatileDiffusionPipeline": VersatileDiffusionPipeline, 32 | "VersatileDiffusionTextToImagePipeline": VersatileDiffusionTextToImagePipeline, 33 | } 34 | ) 35 | else: 36 | _import_structure["modeling_text_unet"] = ["UNetFlatConditionModel"] 37 | _import_structure["pipeline_versatile_diffusion"] = ["VersatileDiffusionPipeline"] 38 | _import_structure["pipeline_versatile_diffusion_dual_guided"] = ["VersatileDiffusionDualGuidedPipeline"] 39 | _import_structure["pipeline_versatile_diffusion_image_variation"] = ["VersatileDiffusionImageVariationPipeline"] 40 | _import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"] 41 | 42 | 43 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 44 | try: 45 | if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): 46 | raise OptionalDependencyNotAvailable() 47 | except OptionalDependencyNotAvailable: 48 | from ...utils.dummy_torch_and_transformers_objects import ( 49 | VersatileDiffusionDualGuidedPipeline, 50 | VersatileDiffusionImageVariationPipeline, 51 | VersatileDiffusionPipeline, 52 | VersatileDiffusionTextToImagePipeline, 53 | ) 54 | else: 55 | from .pipeline_versatile_diffusion import VersatileDiffusionPipeline 56 | from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline 57 | from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline 58 | from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline 59 | 60 | else: 61 | import sys 62 | 63 | sys.modules[__name__] = _LazyModule( 64 | __name__, 65 | globals()["__file__"], 66 | _import_structure, 67 | 
module_spec=__spec__, 68 | ) 69 | 70 | for name, value in _dummy_objects.items(): 71 | setattr(sys.modules[__name__], name, value) 72 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/vq_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | is_torch_available, 8 | is_transformers_available, 9 | ) 10 | 11 | 12 | _dummy_objects = {} 13 | _import_structure = {} 14 | 15 | try: 16 | if not (is_transformers_available() and is_torch_available()): 17 | raise OptionalDependencyNotAvailable() 18 | except OptionalDependencyNotAvailable: 19 | from ...utils.dummy_torch_and_transformers_objects import ( 20 | LearnedClassifierFreeSamplingEmbeddings, 21 | VQDiffusionPipeline, 22 | ) 23 | 24 | _dummy_objects.update( 25 | { 26 | "LearnedClassifierFreeSamplingEmbeddings": LearnedClassifierFreeSamplingEmbeddings, 27 | "VQDiffusionPipeline": VQDiffusionPipeline, 28 | } 29 | ) 30 | else: 31 | _import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"] 32 | 33 | 34 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 35 | try: 36 | if not (is_transformers_available() and is_torch_available()): 37 | raise OptionalDependencyNotAvailable() 38 | except OptionalDependencyNotAvailable: 39 | from ...utils.dummy_torch_and_transformers_objects import ( 40 | LearnedClassifierFreeSamplingEmbeddings, 41 | VQDiffusionPipeline, 42 | ) 43 | else: 44 | from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline 45 | 46 | else: 47 | import sys 48 | 49 | sys.modules[__name__] = _LazyModule( 50 | __name__, 51 | globals()["__file__"], 52 | _import_structure, 53 | module_spec=__spec__, 54 | ) 55 | 56 | for name, value in _dummy_objects.items(): 57 | setattr(sys.modules[__name__], name, value) 58 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/wuerstchen/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_torch_and_transformers_objects 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) 23 | else: 24 | _import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"] 25 | _import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"] 26 | _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"] 27 | _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"] 28 | _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"] 29 | _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"] 30 | 31 | 32 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 33 | try: 34 | if not (is_transformers_available() and is_torch_available()): 35 | raise 
OptionalDependencyNotAvailable() 36 | except OptionalDependencyNotAvailable: 37 | from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 38 | else: 39 | from .modeling_paella_vq_model import PaellaVQModel 40 | from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt 41 | from .modeling_wuerstchen_prior import WuerstchenPrior 42 | from .pipeline_wuerstchen import WuerstchenDecoderPipeline 43 | from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline 44 | from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline 45 | else: 46 | import sys 47 | 48 | sys.modules[__name__] = _LazyModule( 49 | __name__, 50 | globals()["__file__"], 51 | _import_structure, 52 | module_spec=__spec__, 53 | ) 54 | 55 | for name, value in _dummy_objects.items(): 56 | setattr(sys.modules[__name__], name, value) 57 | -------------------------------------------------------------------------------- /src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Dominic Rampas MIT License 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
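# (This module defines the shared building blocks used by the Wuerstchen prior
#  and decoder networks: a channels-last LayerNorm (WuerstchenLayerNorm), a
#  timestep-conditioned scale/shift block (TimestepBlock), a ConvNeXt-V2-style
#  residual block (ResBlock) built around GlobalResponseNorm, and a combined
#  self/cross-attention block (AttnBlock).)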
15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from ...models.attention_processor import Attention 20 | from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear 21 | from ...utils import USE_PEFT_BACKEND 22 | 23 | 24 | class WuerstchenLayerNorm(nn.LayerNorm): 25 | def __init__(self, *args, **kwargs): 26 | super().__init__(*args, **kwargs) 27 | 28 | def forward(self, x): 29 | x = x.permute(0, 2, 3, 1) 30 | x = super().forward(x) 31 | return x.permute(0, 3, 1, 2) 32 | 33 | 34 | class TimestepBlock(nn.Module): 35 | def __init__(self, c, c_timestep): 36 | super().__init__() 37 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 38 | self.mapper = linear_cls(c_timestep, c * 2) 39 | 40 | def forward(self, x, t): 41 | a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) 42 | return x * (1 + a) + b 43 | 44 | 45 | class ResBlock(nn.Module): 46 | def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): 47 | super().__init__() 48 | 49 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 50 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 51 | 52 | self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) 53 | self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) 54 | self.channelwise = nn.Sequential( 55 | linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c) 56 | ) 57 | 58 | def forward(self, x, x_skip=None): 59 | x_res = x 60 | if x_skip is not None: 61 | x = torch.cat([x, x_skip], dim=1) 62 | x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) 63 | x = self.channelwise(x).permute(0, 3, 1, 2) 64 | return x + x_res 65 | 66 | 67 | # from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 68 | class GlobalResponseNorm(nn.Module): 69 | def __init__(self, dim): 70 | super().__init__() 71 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 72 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 73 | 74 | def forward(self, x): 75 | agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) 76 | stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) 77 | return self.gamma * (x * stand_div_norm) + self.beta + x 78 | 79 | 80 | class AttnBlock(nn.Module): 81 | def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): 82 | super().__init__() 83 | 84 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 85 | 86 | self.self_attn = self_attn 87 | self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) 88 | self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) 89 | self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c)) 90 | 91 | def forward(self, x, kv): 92 | kv = self.kv_mapper(kv) 93 | norm_x = self.norm(x) 94 | if self.self_attn: 95 | batch_size, channel, _, _ = x.shape 96 | kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) 97 | x = x + self.attention(norm_x, encoder_hidden_states=kv) 98 | return x 99 | -------------------------------------------------------------------------------- /src/diffusers/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/CasCast/487a68b5ade9aa829fe7df2e8f6746b4d9acc233/src/diffusers/py.typed -------------------------------------------------------------------------------- 
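A minimal shape-check sketch for the Wuerstchen building blocks above, assuming the vendored `src/diffusers` package is importable as `diffusers` (the module path below mirrors the file layout in this repo); the tensor sizes are illustrative only.

import torch

from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common import (
    GlobalResponseNorm,
    TimestepBlock,
    WuerstchenLayerNorm,
)

x = torch.randn(2, 64, 8, 8)  # (batch, channels, height, width)

# WuerstchenLayerNorm normalizes over channels by permuting to channels-last and back.
norm = WuerstchenLayerNorm(64, elementwise_affine=False, eps=1e-6)
assert norm(x).shape == x.shape

# TimestepBlock turns a timestep embedding into a per-channel scale and shift.
block = TimestepBlock(c=64, c_timestep=128)
t_emb = torch.randn(2, 128)
assert block(x, t_emb).shape == x.shape

# GlobalResponseNorm expects channels-last input (note the permutes in ResBlock.forward).
grn = GlobalResponseNorm(64)
assert grn(x.permute(0, 2, 3, 1)).shape == (2, 8, 8, 64)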
/src/diffusers/schedulers/README.md: -------------------------------------------------------------------------------- 1 | # Schedulers 2 | 3 | For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers/overview). -------------------------------------------------------------------------------- /src/diffusers/schedulers/deprecated/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from ...utils import ( 4 | DIFFUSERS_SLOW_IMPORT, 5 | OptionalDependencyNotAvailable, 6 | _LazyModule, 7 | get_objects_from_module, 8 | is_torch_available, 9 | is_transformers_available, 10 | ) 11 | 12 | 13 | _dummy_objects = {} 14 | _import_structure = {} 15 | 16 | try: 17 | if not (is_transformers_available() and is_torch_available()): 18 | raise OptionalDependencyNotAvailable() 19 | except OptionalDependencyNotAvailable: 20 | from ...utils import dummy_pt_objects # noqa F403 21 | 22 | _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) 23 | else: 24 | _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] 25 | _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] 26 | 27 | if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 28 | try: 29 | if not is_torch_available(): 30 | raise OptionalDependencyNotAvailable() 31 | 32 | except OptionalDependencyNotAvailable: 33 | from ...utils.dummy_pt_objects import * # noqa F403 34 | else: 35 | from .scheduling_karras_ve import KarrasVeScheduler 36 | from .scheduling_sde_vp import ScoreSdeVpScheduler 37 | 38 | 39 | else: 40 | import sys 41 | 42 | sys.modules[__name__] = _LazyModule( 43 | __name__, 44 | globals()["__file__"], 45 | _import_structure, 46 | module_spec=__spec__, 47 | ) 48 | 49 | for name, value in _dummy_objects.items(): 50 | setattr(sys.modules[__name__], name, value) 51 | -------------------------------------------------------------------------------- /src/diffusers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | from packaging import version 19 | 20 | from ..
import __version__ 21 | from .constants import ( 22 | CONFIG_NAME, 23 | DEPRECATED_REVISION_ARGS, 24 | DIFFUSERS_CACHE, 25 | DIFFUSERS_DYNAMIC_MODULE_NAME, 26 | FLAX_WEIGHTS_NAME, 27 | HF_MODULES_CACHE, 28 | HUGGINGFACE_CO_RESOLVE_ENDPOINT, 29 | MIN_PEFT_VERSION, 30 | ONNX_EXTERNAL_WEIGHTS_NAME, 31 | ONNX_WEIGHTS_NAME, 32 | SAFETENSORS_WEIGHTS_NAME, 33 | USE_PEFT_BACKEND, 34 | WEIGHTS_NAME, 35 | ) 36 | from .deprecation_utils import deprecate 37 | from .doc_utils import replace_example_docstring 38 | from .dynamic_modules_utils import get_class_from_dynamic_module 39 | from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video 40 | from .hub_utils import ( 41 | HF_HUB_OFFLINE, 42 | PushToHubMixin, 43 | _add_variant, 44 | _get_model_file, 45 | extract_commit_hash, 46 | http_user_agent, 47 | ) 48 | from .import_utils import ( 49 | BACKENDS_MAPPING, 50 | DIFFUSERS_SLOW_IMPORT, 51 | ENV_VARS_TRUE_AND_AUTO_VALUES, 52 | ENV_VARS_TRUE_VALUES, 53 | USE_JAX, 54 | USE_TF, 55 | USE_TORCH, 56 | DummyObject, 57 | OptionalDependencyNotAvailable, 58 | _LazyModule, 59 | get_objects_from_module, 60 | is_accelerate_available, 61 | is_accelerate_version, 62 | is_bs4_available, 63 | is_flax_available, 64 | is_ftfy_available, 65 | is_inflect_available, 66 | is_invisible_watermark_available, 67 | is_k_diffusion_available, 68 | is_k_diffusion_version, 69 | is_librosa_available, 70 | is_note_seq_available, 71 | is_omegaconf_available, 72 | is_onnx_available, 73 | is_peft_available, 74 | is_scipy_available, 75 | is_tensorboard_available, 76 | is_torch_available, 77 | is_torch_version, 78 | is_torch_xla_available, 79 | is_torchsde_available, 80 | is_transformers_available, 81 | is_transformers_version, 82 | is_unidecode_available, 83 | is_wandb_available, 84 | is_xformers_available, 85 | requires_backends, 86 | ) 87 | from .loading_utils import load_image 88 | from .logging import get_logger 89 | from .outputs import BaseOutput 90 | from .peft_utils import ( 91 | check_peft_version, 92 | delete_adapter_layers, 93 | get_adapter_name, 94 | get_peft_kwargs, 95 | recurse_remove_peft_layers, 96 | scale_lora_layers, 97 | set_adapter_layers, 98 | set_weights_and_activate_adapters, 99 | unscale_lora_layers, 100 | ) 101 | from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil 102 | from .state_dict_utils import ( 103 | convert_state_dict_to_diffusers, 104 | convert_state_dict_to_peft, 105 | convert_unet_state_dict_to_peft, 106 | ) 107 | 108 | 109 | logger = get_logger(__name__) 110 | 111 | 112 | def check_min_version(min_version): 113 | if version.parse(__version__) < version.parse(min_version): 114 | if "dev" in min_version: 115 | error_message = ( 116 | "This example requires a source install from HuggingFace diffusers (see " 117 | "`https://huggingface.co/docs/diffusers/installation#install-from-source`)," 118 | ) 119 | else: 120 | error_message = f"This example requires a minimum version of {min_version}," 121 | error_message += f" but the version found is {__version__}.\n" 122 | raise ImportError(error_message) 123 | -------------------------------------------------------------------------------- /src/diffusers/utils/accelerate_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Accelerate utilities: Utilities related to accelerate 16 | """ 17 | 18 | from packaging import version 19 | 20 | from .import_utils import is_accelerate_available 21 | 22 | 23 | if is_accelerate_available(): 24 | import accelerate 25 | 26 | 27 | def apply_forward_hook(method): 28 | """ 29 | Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful 30 | for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the 31 | appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. 32 | 33 | This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. 34 | 35 | :param method: The method to decorate. This method should be a method of a PyTorch module. 36 | """ 37 | if not is_accelerate_available(): 38 | return method 39 | accelerate_version = version.parse(accelerate.__version__).base_version 40 | if version.parse(accelerate_version) < version.parse("0.17.0"): 41 | return method 42 | 43 | def wrapper(self, *args, **kwargs): 44 | if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): 45 | self._hf_hook.pre_forward(self) 46 | return method(self, *args, **kwargs) 47 | 48 | return wrapper 49 | -------------------------------------------------------------------------------- /src/diffusers/utils/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
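# (Overview of this module: it centralizes the vendored package's constants,
#  including weight and config filenames, Hub endpoints and cache locations,
#  and the USE_PEFT_BACKEND switch, which is computed further down from the
#  installed `peft` and `transformers` versions.)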
14 | import importlib 15 | import os 16 | 17 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home 18 | from packaging import version 19 | 20 | from ..dependency_versions_check import dep_version_check 21 | from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available 22 | 23 | 24 | default_cache_path = HUGGINGFACE_HUB_CACHE 25 | 26 | MIN_PEFT_VERSION = "0.6.0" 27 | MIN_TRANSFORMERS_VERSION = "4.34.0" 28 | _CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES 29 | 30 | 31 | CONFIG_NAME = "config.json" 32 | WEIGHTS_NAME = "diffusion_pytorch_model.bin" 33 | FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" 34 | ONNX_WEIGHTS_NAME = "model.onnx" 35 | SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" 36 | ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" 37 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") 38 | DIFFUSERS_CACHE = default_cache_path 39 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" 40 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) 41 | DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] 42 | 43 | # `USE_PEFT_BACKEND` below should be `True` when the installed versions of `peft` and `transformers` are 44 | # compatible with the PEFT backend; in that case the PEFT backend is enabled 45 | # automatically. 46 | # For PEFT the version has to be greater than or equal to 0.6.0, and for transformers it has to be greater than or equal to 4.34.0. 47 | _required_peft_version = is_peft_available() and version.parse( 48 | version.parse(importlib.metadata.version("peft")).base_version 49 | ) >= version.parse(MIN_PEFT_VERSION) 50 | _required_transformers_version = is_transformers_available() and version.parse( 51 | version.parse(importlib.metadata.version("transformers")).base_version 52 | ) >= version.parse(MIN_TRANSFORMERS_VERSION) 53 | 54 | USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version 55 | 56 | if USE_PEFT_BACKEND and _CHECK_PEFT: 57 | dep_version_check("peft") 58 | -------------------------------------------------------------------------------- /src/diffusers/utils/deprecation_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from typing import Any, Dict, Optional, Union 4 | 5 | from packaging import version 6 | 7 | 8 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): 9 | from .. import __version__ 10 | 11 | deprecated_kwargs = take_from 12 | values = () 13 | if not isinstance(args[0], tuple): 14 | args = (args,) 15 | 16 | for attribute, version_name, message in args: 17 | if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): 18 | raise ValueError( 19 | f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" 20 | f" version {__version__} is >= {version_name}" 21 | ) 22 | 23 | warning = None 24 | if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: 25 | values += (deprecated_kwargs.pop(attribute),) 26 | warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." 27 | elif hasattr(deprecated_kwargs, attribute): 28 | values += (getattr(deprecated_kwargs, attribute),) 29 | warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
30 | elif deprecated_kwargs is None: 31 | warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." 32 | 33 | if warning is not None: 34 | warning = warning + " " if standard_warn else "" 35 | warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel) 36 | 37 | if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: 38 | call_frame = inspect.getouterframes(inspect.currentframe())[1] 39 | filename = call_frame.filename 40 | line_number = call_frame.lineno 41 | function = call_frame.function 42 | key, value = next(iter(deprecated_kwargs.items())) 43 | raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") 44 | 45 | if len(values) == 0: 46 | return 47 | elif len(values) == 1: 48 | return values[0] 49 | return values 50 | -------------------------------------------------------------------------------- /src/diffusers/utils/doc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Doc utilities: Utilities related to documentation 16 | """ 17 | import re 18 | 19 | 20 | def replace_example_docstring(example_docstring): 21 | def docstring_decorator(fn): 22 | func_doc = fn.__doc__ 23 | lines = func_doc.split("\n") 24 | i = 0 25 | while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None: 26 | i += 1 27 | if i < len(lines): 28 | lines[i] = example_docstring 29 | func_doc = "\n".join(lines) 30 | else: 31 | raise ValueError( 32 | f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, " 33 | f"current docstring is:\n{func_doc}" 34 | ) 35 | fn.__doc__ = func_doc 36 | return fn 37 | 38 | return docstring_decorator 39 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_flax_and_transformers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
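# (These stubs rely on the `DummyObject` metaclass imported below: looking up
#  an undefined attribute on the class, or calling `__init__`, `from_config`,
#  or `from_pretrained`, routes through `requires_backends`, which raises an
#  ImportError naming the backends to install, in this file "flax" and
#  "transformers".)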
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): 6 | _backends = ["flax", "transformers"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["flax", "transformers"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["flax", "transformers"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["flax", "transformers"]) 18 | 19 | 20 | class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 21 | _backends = ["flax", "transformers"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["flax", "transformers"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["flax", "transformers"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["flax", "transformers"]) 33 | 34 | 35 | class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): 36 | _backends = ["flax", "transformers"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["flax", "transformers"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["flax", "transformers"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["flax", "transformers"]) 48 | 49 | 50 | class FlaxStableDiffusionPipeline(metaclass=DummyObject): 51 | _backends = ["flax", "transformers"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["flax", "transformers"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["flax", "transformers"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["flax", "transformers"]) 63 | 64 | 65 | class FlaxStableDiffusionXLPipeline(metaclass=DummyObject): 66 | _backends = ["flax", "transformers"] 67 | 68 | def __init__(self, *args, **kwargs): 69 | requires_backends(self, ["flax", "transformers"]) 70 | 71 | @classmethod 72 | def from_config(cls, *args, **kwargs): 73 | requires_backends(cls, ["flax", "transformers"]) 74 | 75 | @classmethod 76 | def from_pretrained(cls, *args, **kwargs): 77 | requires_backends(cls, ["flax", "transformers"]) 78 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_note_seq_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class MidiProcessor(metaclass=DummyObject): 6 | _backends = ["note_seq"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["note_seq"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["note_seq"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["note_seq"]) 18 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class OnnxRuntimeModel(metaclass=DummyObject): 6 | _backends = ["onnx"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["onnx"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["onnx"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["onnx"]) 18 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_torch_and_librosa_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class AudioDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "librosa"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "librosa"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "librosa"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "librosa"]) 18 | 19 | 20 | class Mel(metaclass=DummyObject): 21 | _backends = ["torch", "librosa"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "librosa"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "librosa"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "librosa"]) 33 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_torch_and_scipy_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class LMSDiscreteScheduler(metaclass=DummyObject): 6 | _backends = ["torch", "scipy"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "scipy"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "scipy"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "scipy"]) 18 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_torch_and_torchsde_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class DPMSolverSDEScheduler(metaclass=DummyObject): 6 | _backends = ["torch", "torchsde"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "torchsde"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "torchsde"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "torchsde"]) 18 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class StableDiffusionKDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "k_diffusion"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "k_diffusion"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "k_diffusion"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "k_diffusion"]) 18 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "onnx"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "onnx"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "onnx"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "onnx"]) 18 | 19 | 20 | class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): 21 | _backends = ["torch", "transformers", "onnx"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "transformers", "onnx"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "transformers", "onnx"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "transformers", "onnx"]) 33 | 34 | 35 | class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): 36 | _backends = ["torch", "transformers", "onnx"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["torch", "transformers", "onnx"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["torch", "transformers", "onnx"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["torch", "transformers", "onnx"]) 48 | 49 | 50 | class OnnxStableDiffusionPipeline(metaclass=DummyObject): 51 | _backends = ["torch", "transformers", "onnx"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["torch", "transformers", "onnx"]) 55 | 56 | @classmethod 57 | 
def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["torch", "transformers", "onnx"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["torch", "transformers", "onnx"]) 63 | 64 | 65 | class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject): 66 | _backends = ["torch", "transformers", "onnx"] 67 | 68 | def __init__(self, *args, **kwargs): 69 | requires_backends(self, ["torch", "transformers", "onnx"]) 70 | 71 | @classmethod 72 | def from_config(cls, *args, **kwargs): 73 | requires_backends(cls, ["torch", "transformers", "onnx"]) 74 | 75 | @classmethod 76 | def from_pretrained(cls, *args, **kwargs): 77 | requires_backends(cls, ["torch", "transformers", "onnx"]) 78 | 79 | 80 | class StableDiffusionOnnxPipeline(metaclass=DummyObject): 81 | _backends = ["torch", "transformers", "onnx"] 82 | 83 | def __init__(self, *args, **kwargs): 84 | requires_backends(self, ["torch", "transformers", "onnx"]) 85 | 86 | @classmethod 87 | def from_config(cls, *args, **kwargs): 88 | requires_backends(cls, ["torch", "transformers", "onnx"]) 89 | 90 | @classmethod 91 | def from_pretrained(cls, *args, **kwargs): 92 | requires_backends(cls, ["torch", "transformers", "onnx"]) 93 | -------------------------------------------------------------------------------- /src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class SpectrogramDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["transformers", "torch", "note_seq"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["transformers", "torch", "note_seq"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["transformers", "torch", "note_seq"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["transformers", "torch", "note_seq"]) 18 | -------------------------------------------------------------------------------- /src/diffusers/utils/loading_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | 4 | import PIL.Image 5 | import PIL.ImageOps 6 | import requests 7 | 8 | 9 | def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: 10 | """ 11 | Loads `image` to a PIL Image. 12 | 13 | Args: 14 | image (`str` or `PIL.Image.Image`): 15 | The image to convert to the PIL Image format. 16 | Returns: 17 | `PIL.Image.Image`: 18 | A PIL Image. 19 | """ 20 | if isinstance(image, str): 21 | if image.startswith("http://") or image.startswith("https://"): 22 | image = PIL.Image.open(requests.get(image, stream=True).raw) 23 | elif os.path.isfile(image): 24 | image = PIL.Image.open(image) 25 | else: 26 | raise ValueError( 27 | f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path" 28 | ) 29 | elif isinstance(image, PIL.Image.Image): 30 | image = image 31 | else: 32 | raise ValueError( 33 | "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
) 35 | image = PIL.ImageOps.exif_transpose(image) 36 | image = image.convert("RGB") 37 | return image 38 | -------------------------------------------------------------------------------- /src/diffusers/utils/model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | 7 | 8 | # {{ model_name | default("Diffusion Model") }} 9 | 10 | ## Model description 11 | 12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library 13 | on the `{{ dataset_name }}` dataset. 14 | 15 | ## Intended uses & limitations 16 | 17 | #### How to use 18 | 19 | ```python 20 | # TODO: add an example code snippet for running this diffusion pipeline 21 | ``` 22 | 23 | #### Limitations and bias 24 | 25 | [TODO: provide examples of latent issues and potential remediations] 26 | 27 | ## Training data 28 | 29 | [TODO: describe the data used to train the model] 30 | 31 | ### Training hyperparameters 32 | 33 | The following hyperparameters were used during training: 34 | - learning_rate: {{ learning_rate }} 35 | - train_batch_size: {{ train_batch_size }} 36 | - eval_batch_size: {{ eval_batch_size }} 37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }} 38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} 39 | - lr_scheduler: {{ lr_scheduler }} 40 | - lr_warmup_steps: {{ lr_warmup_steps }} 41 | - ema_inv_gamma: {{ ema_inv_gamma }} 42 | - ema_power: {{ ema_power }} 43 | - ema_max_decay: {{ ema_max_decay }} 44 | - mixed_precision: {{ mixed_precision }} 45 | 46 | ### Training results 47 | 48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) 49 | 50 | 51 | -------------------------------------------------------------------------------- /src/diffusers/utils/pil_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import PIL.Image 4 | import PIL.ImageOps 5 | from packaging import version 6 | from PIL import Image 7 | 8 | 9 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 10 | PIL_INTERPOLATION = { 11 | "linear": PIL.Image.Resampling.BILINEAR, 12 | "bilinear": PIL.Image.Resampling.BILINEAR, 13 | "bicubic": PIL.Image.Resampling.BICUBIC, 14 | "lanczos": PIL.Image.Resampling.LANCZOS, 15 | "nearest": PIL.Image.Resampling.NEAREST, 16 | } 17 | else: 18 | PIL_INTERPOLATION = { 19 | "linear": PIL.Image.LINEAR, 20 | "bilinear": PIL.Image.BILINEAR, 21 | "bicubic": PIL.Image.BICUBIC, 22 | "lanczos": PIL.Image.LANCZOS, 23 | "nearest": PIL.Image.NEAREST, 24 | } 25 | 26 | 27 | def pt_to_pil(images): 28 | """ 29 | Convert a torch image to a PIL image. 30 | """ 31 | images = (images / 2 + 0.5).clamp(0, 1) 32 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 33 | images = numpy_to_pil(images) 34 | return images 35 | 36 | 37 | def numpy_to_pil(images): 38 | """ 39 | Convert a numpy image or a batch of images to a PIL image. 40 | """ 41 | if images.ndim == 3: 42 | images = images[None, ...]
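# (numpy_to_pil promotes a single HWC image to a batch of one here; the
#  scaling below assumes float inputs in [0, 1], matching pt_to_pil's output.)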
43 | images = (images * 255).round().astype("uint8") 44 | if images.shape[-1] == 1: 45 | # special case for grayscale (single channel) images 46 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] 47 | else: 48 | pil_images = [Image.fromarray(image) for image in images] 49 | 50 | return pil_images 51 | 52 | 53 | def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image: 54 | """ 55 | Prepares a single grid of images. Useful for visualization purposes. 56 | """ 57 | assert len(images) == rows * cols 58 | 59 | if resize is not None: 60 | images = [img.resize((resize, resize)) for img in images] 61 | 62 | w, h = images[0].size 63 | grid = Image.new("RGB", size=(cols * w, rows * h)) 64 | 65 | for i, img in enumerate(images): 66 | grid.paste(img, box=(i % cols * w, i // cols * h)) 67 | return grid 68 | -------------------------------------------------------------------------------- /train_debug.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=1 4 | node_num=1 5 | single_gpus=`expr $gpus / $node_num` 6 | 7 | cpus=13 8 | 9 | # export NCCL_IB_DISABLE=1 10 | # export NCCL_SOCKET_IFNAME=eth0 11 | # export NCCL_DEBUG=INFO 12 | # export NCCL_DEBUG_SUBSYS=ALL 13 | # export TORCH_DISTRIBUTED_DEBUG=INFO 14 | 15 | while true 16 | do 17 | PORT=$((((RANDOM<<15)|RANDOM)%49152 + 10000)) 18 | break 19 | done 20 | echo $PORT 21 | 22 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL 23 | 24 | srun -p ai4earth --kill-on-bad-exit=1 --quotatype=auto --ntasks-per-node=$single_gpus --time=43200 --cpus-per-task=$cpus -N $node_num --gres=gpu:$single_gpus python -u train.py \ 25 | --init_method 'tcp://127.0.0.1:'$PORT \ 26 | -c ./configs/sevir_used/cascast_diffusion.yaml \ 27 | --world_size $gpus \ 28 | --per_cpus $cpus \ 29 | --tensor_model_parallel_size 1 \ 30 | --outdir '/mnt/lustre/gongjunchao/release_code/cascast/experiments' \ 31 | --desc 'debug' 32 | 33 | # 34 | sleep 2 35 | rm -f batchscript-* -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/CasCast/487a68b5ade9aa829fe7df2e8f6746b4d9acc233/utils/__init__.py -------------------------------------------------------------------------------- /utils/checkpoint_ceph.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from torch.utils.data import Dataset 3 | 4 | try: 5 | from petrel_client.client import Client 6 | except ImportError:  # petrel_client (the ceph storage client) is optional 7 | pass 8 | from tqdm import tqdm 9 | import numpy as np 10 | import io 11 | import torch 12 | import os 13 | 14 | class checkpoint_ceph(object): 15 | def __init__(self, conf_path="~/petreloss.conf", checkpoint_dir="weatherbench:s3://weatherbench/checkpoint") -> None: 16 | self.client = Client(conf_path=conf_path) 17 | self.checkpoint_dir = checkpoint_dir 18 | 19 | def load_checkpoint(self, url): 20 | url = os.path.join(self.checkpoint_dir, url) 21 | # url = self.checkpoint_dir + "/" + url 22 | if not self.client.contains(url): 23 | return None 24 | with io.BytesIO(self.client.get(url, update_cache=True)) as f: 25 | checkpoint_data = torch.load(f, map_location=torch.device('cpu')) 26 | return checkpoint_data 27 | 28 | def load_checkpoint_with_ckptDir(self, url, ckpt_dir): 29 | url = os.path.join(ckpt_dir, url) 30 | if not self.client.contains(url): 31 | return None 32 | with
io.BytesIO(self.client.get(url, update_cache=True)) as f: 33 | checkpoint_data = torch.load(f, map_location=torch.device('cpu')) 34 | return checkpoint_data 35 | 36 | def save_checkpoint(self, url, data): 37 | url = os.path.join(self.checkpoint_dir, url) 38 | # url = self.checkpoint_dir + "/" + url 39 | with io.BytesIO() as f: 40 | torch.save(data, f) 41 | f.seek(0) 42 | self.client.put(url, f) 43 | 44 | def save_prediction_results(self, url, data): 45 | url = os.path.join(self.checkpoint_dir, url) 46 | # url = self.checkpoint_dir + "/" + url 47 | with io.BytesIO() as f: 48 | np.save(f, data) 49 | f.seek(0) 50 | self.client.put(url, f) 51 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | import sys 5 | 6 | logger_initialized = {} 7 | 8 | def get_logger(name, save_dir, distributed_rank, filename="log.log", resume=False): 9 | logger = logging.getLogger(name) 10 | if name in logger_initialized: 11 | return logger 12 | 13 | logger.propagate = False 14 | # don't log results for the non-master process 15 | if distributed_rank > 0: 16 | logger.setLevel(logging.ERROR) 17 | return logger 18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 19 | 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.INFO) 22 | ch.setFormatter(formatter) 23 | logger.addHandler(ch) 24 | 25 | if save_dir: 26 | if resume: 27 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode='a') 28 | else: 29 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode='w') 30 | fh.setLevel(logging.INFO) 31 | fh.setFormatter(formatter) 32 | logger.addHandler(fh) 33 | 34 | logger.setLevel(logging.INFO) 35 | 36 | logger_initialized[name] = True 37 | 38 | return logger 39 | --------------------------------------------------------------------------------
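A minimal usage sketch for utils/logger.py above, assuming the repository root is on PYTHONPATH; only rank 0 attaches INFO-level handlers, matching the `distributed_rank` guard in `get_logger`.

import os

from utils.logger import get_logger

os.makedirs("logs", exist_ok=True)
logger = get_logger("cascast", save_dir="logs", distributed_rank=0, filename="train.log")
logger.info("logger initialized")  # goes to stdout and to logs/train.log

# Repeated calls with the same name return the cached logger without
# attaching duplicate handlers.
same_logger = get_logger("cascast", save_dir="logs", distributed_rank=0)
assert same_logger is logger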