├── .github
    └── workflows
    │   ├── bot-autolint.yaml
    │   └── ci.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CIs
    └── add_license_all.sh
├── Dockerfile
├── LICENSE
├── README.md
├── app
    ├── app_sana.py
    ├── app_sana_4bit.py
    ├── app_sana_4bit_compare_bf16.py
    ├── app_sana_controlnet_hed.py
    ├── app_sana_inpaint.py
    ├── app_sana_multithread.py
    ├── app_sana_sprint.py
    ├── safety_check.py
    ├── sana_controlnet_pipeline.py
    ├── sana_pipeline.py
    ├── sana_pipeline_inpaint.py
    └── sana_sprint_pipeline.py
├── asset
    ├── Sana.jpg
    ├── app_styles
    │   └── controlnet_app_style.css
    ├── apple.webp
    ├── controlnet
    │   ├── ref_images
    │   │   ├── A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg
    │   │   ├── a house.png
    │   │   ├── a living room.png
    │   │   └── nvidia.png
    │   └── samples_controlnet.json
    ├── cover.png
    ├── docs
    │   ├── ComfyUI
    │   │   ├── SANA-1.5_FlowEuler.json
    │   │   ├── SANA-Sprint.json
    │   │   ├── Sana_CogVideoX.json
    │   │   ├── Sana_FlowEuler.json
    │   │   ├── Sana_FlowEuler_2K.json
    │   │   ├── Sana_FlowEuler_4K.json
    │   │   └── comfyui.md
    │   ├── inference_scaling
    │   │   ├── inference_scaling.md
    │   │   ├── results.jpg
    │   │   └── scaling_curve.jpg
    │   ├── metrics_toolkit.md
    │   ├── model_zoo.md
    │   ├── quantize
    │   │   ├── 4bit_sana.md
    │   │   └── 8bit_sana.md
    │   ├── sana_controlnet.md
    │   ├── sana_lora_dreambooth.md
    │   ├── sana_sprint.md
    │   └── sana_video.md
    ├── example_data
    │   ├── 00000000.jpg
    │   ├── 00000000.png
    │   ├── 00000000.txt
    │   ├── 00000000_InternVL2-26B.json
    │   ├── 00000000_InternVL2-26B_clip_score.json
    │   ├── 00000000_VILA1-5-13B.json
    │   ├── 00000000_VILA1-5-13B_clip_score.json
    │   ├── 00000000_prompt_clip_score.json
    │   └── meta_data.json
    ├── examples.py
    ├── logo.png
    ├── mit-logo.jpg
    ├── model-incremental.jpg
    ├── model_paths.txt
    ├── paper2video.jpg
    └── samples
    │   ├── sample_i2v.txt
    │   ├── samples.txt
    │   ├── samples_mini.txt
    │   └── video_prompts_samples.txt
├── configs
    ├── sana1-5_config
    │   └── 1024ms
    │   │   ├── Sana_1600M_1024px_AdamW_fsdp.yaml
    │   │   ├── Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml
    │   │   ├── Sana_3200M_1024px_came8bit_grow_constant_allqknorm_bf16_lr2e5.yaml
    │   │   └── Sana_4800M_1024px_came8bit_grow_constant_allqknorm_bf16_lr2e5.yaml
    ├── sana_app_config
    │   ├── Sana_1600M_app.yaml
    │   └── Sana_600M_app.yaml
    ├── sana_base.yaml
    ├── sana_config
    │   ├── 1024ms
    │   │   ├── Sana_1600M_img1024.yaml
    │   │   ├── Sana_1600M_img1024_AdamW.yaml
    │   │   ├── Sana_1600M_img1024_CAME8bit.yaml
    │   │   └── Sana_600M_img1024.yaml
    │   ├── 2048ms
    │   │   └── Sana_1600M_img2048_bf16.yaml
    │   ├── 4096ms
    │   │   └── Sana_1600M_img4096_bf16.yaml
    │   └── 512ms
    │   │   ├── Sana_1600M_img512.yaml
    │   │   ├── Sana_600M_img512.yaml
    │   │   ├── ci_Sana_600M_img512.yaml
    │   │   └── sample_dataset.yaml
    ├── sana_controlnet_config
    │   ├── Sana_1600M_1024px_controlnet_bf16.yaml
    │   └── Sana_600M_img1024_controlnet.yaml
    ├── sana_sprint_config
    │   └── 1024ms
    │   │   ├── SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml
    │   │   ├── SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd_dc_ae_lite.yaml
    │   │   ├── SanaSprint_1600M_img1024_bf16_normT_allqknorm_teacher_ft.yaml
    │   │   ├── SanaSprint_600M_1024px_allqknorm_bf16_scm_ladd.yaml
    │   │   └── SanaSprint_600M_1024px_allqknorm_bf16_scm_ladd_dc_ae_lite.yaml
    └── sana_video_config
    │   ├── Sana_2000M_256px_AdamW_fsdp.yaml
    │   └── Sana_2000M_480px_AdamW_fsdp.yaml
├── diffusion
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── builder.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── sana_data.py
    │   │   ├── sana_data_multi_scale.py
    │   │   ├── utils.py
    │   │   └── video
    │   │   │   └── sana_video_data.py
    │   ├── transforms.py
    │   └── wids
    │   │   ├── __init__.py
    │   │   ├── wids.py
    │   │   ├── wids_dl.py
    │   │   ├── wids_lru.py
    │   │   ├── wids_mmtar.py
    │   │   ├── wids_specs.py
    │   │   └── wids_tar.py
    ├── guiders
    │   ├── __init__.py
    │   └── adaptive_projected_guidance.py
    ├── model
    │   ├── __init__.py
    │   ├── act.py
    │   ├── builder.py
    │   ├── dc_ae
    │   │   └── efficientvit
    │   │   │   ├── __init__.py
    │   │   │   ├── ae_model_zoo.py
    │   │   │   ├── apps
    │   │   │       ├── __init__.py
    │   │   │       ├── setup.py
    │   │   │       ├── trainer
    │   │   │       │   ├── __init__.py
    │   │   │       │   └── run_config.py
    │   │   │       └── utils
    │   │   │       │   ├── __init__.py
    │   │   │       │   ├── dist.py
    │   │   │       │   ├── ema.py
    │   │   │       │   ├── export.py
    │   │   │       │   ├── image.py
    │   │   │       │   ├── init.py
    │   │   │       │   ├── lr.py
    │   │   │       │   ├── metric.py
    │   │   │       │   ├── misc.py
    │   │   │       │   └── opt.py
    │   │   │   └── models
    │   │   │       ├── __init__.py
    │   │   │       ├── efficientvit
    │   │   │           ├── __init__.py
    │   │   │           ├── dc_ae.py
    │   │   │           └── dc_ae_with_temporal.py
    │   │   │       ├── nn
    │   │   │           ├── __init__.py
    │   │   │           ├── act.py
    │   │   │           ├── drop.py
    │   │   │           ├── norm.py
    │   │   │           ├── ops.py
    │   │   │           ├── ops_3d.py
    │   │   │           └── triton_rms_norm.py
    │   │   │       └── utils
    │   │   │           ├── __init__.py
    │   │   │           ├── list.py
    │   │   │           ├── network.py
    │   │   │           ├── random.py
    │   │   │           └── video.py
    │   ├── diffusion_utils.py
    │   ├── dpm_solver.py
    │   ├── edm_sample.py
    │   ├── gaussian_diffusion.py
    │   ├── model_growth_utils.py
    │   ├── nets
    │   │   ├── __init__.py
    │   │   ├── basic_modules.py
    │   │   ├── fastlinear
    │   │   │   ├── develop_triton_ffn.py
    │   │   │   ├── develop_triton_litemla.py
    │   │   │   ├── modules
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── flash_attn.py
    │   │   │   │   ├── lite_mla.py
    │   │   │   │   ├── mb_conv_pre_glu.py
    │   │   │   │   ├── nn
    │   │   │   │   │   ├── act.py
    │   │   │   │   │   ├── conv.py
    │   │   │   │   │   └── norm.py
    │   │   │   │   ├── triton_lite_mla.py
    │   │   │   │   ├── triton_lite_mla_fwd.py
    │   │   │   │   ├── triton_lite_mla_kernels
    │   │   │   │   │   ├── custom_autotune.py
    │   │   │   │   │   ├── linear_relu_fwd.py
    │   │   │   │   │   ├── mm.py
    │   │   │   │   │   ├── pad_vk_mm_fwd.py
    │   │   │   │   │   ├── proj_divide_bwd.py
    │   │   │   │   │   ├── vk_mm_relu_bwd.py
    │   │   │   │   │   ├── vk_q_mm_divide_fwd.py
    │   │   │   │   │   └── vk_q_mm_relu_bwd.py
    │   │   │   │   ├── triton_mb_conv_pre_glu.py
    │   │   │   │   ├── triton_mb_conv_pre_glu_kernels
    │   │   │   │   │   ├── depthwise_conv_fwd.py
    │   │   │   │   │   └── linear_glu_fwd.py
    │   │   │   │   └── utils
    │   │   │   │   │   ├── compare_results.py
    │   │   │   │   │   ├── custom_autotune.py
    │   │   │   │   │   ├── dtype.py
    │   │   │   │   │   ├── export_onnx.py
    │   │   │   │   │   └── model.py
    │   │   │   └── readme.md
    │   │   ├── ladd_blocks.py
    │   │   ├── sana.py
    │   │   ├── sana_U_shape.py
    │   │   ├── sana_U_shape_multi_scale.py
    │   │   ├── sana_blocks.py
    │   │   ├── sana_ladd.py
    │   │   ├── sana_multi_scale.py
    │   │   ├── sana_multi_scale_adaln.py
    │   │   ├── sana_multi_scale_controlnet.py
    │   │   ├── sana_multi_scale_video.py
    │   │   └── sana_others.py
    │   ├── norms.py
    │   ├── qwen
    │   │   └── qwen_vl.py
    │   ├── respace.py
    │   ├── sa_solver.py
    │   ├── timestep_sampler.py
    │   ├── utils.py
    │   ├── wan
    │   │   ├── __init__.py
    │   │   ├── attention.py
    │   │   ├── clip.py
    │   │   ├── fsdp_utils.py
    │   │   ├── model.py
    │   │   ├── model_wrapper.py
    │   │   ├── t5.py
    │   │   ├── tokenizers.py
    │   │   ├── vae.py
    │   │   └── xlm_roberta.py
    │   └── wan2_2
    │   │   └── vae.py
    ├── scheduler
    │   ├── __init__.py
    │   ├── dpm_solver.py
    │   ├── flow_euler_sampler.py
    │   ├── iddpm.py
    │   ├── lcm_scheduler.py
    │   ├── sa_sampler.py
    │   ├── sa_solver_diffusers.py
    │   ├── scm_scheduler.py
    │   └── trigflow_scheduler.py
    └── utils
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── config.py
    │   ├── config_wan.py
    │   ├── data_sampler.py
    │   ├── dist_utils.py
    │   ├── git.py
    │   ├── import_utils.py
    │   ├── logger.py
    │   ├── lr_scheduler.py
    │   ├── misc.py
    │   └── optimizer.py
├── environment_setup.sh
├── inference_video_scripts
    ├── inference_sana_video.py
    └── inference_sana_video.sh
├── pyproject.toml
├── sana
    ├── cli
    │   ├── run.py
    │   └── upload2hf.py
    └── tools
    │   ├── __init__.py
    │   ├── download.py
    │   └── hf_utils.py
├── scripts
    ├── bash_run_inference_metric.sh
    ├── bash_run_inference_metric_dpg.sh
    ├── bash_run_inference_metric_geneval.sh
    ├── bash_run_inference_metric_imagereward.sh
    ├── infer_metric_run_inference_metric.sh
    ├── infer_metric_run_inference_metric_geneval.sh
    ├── infer_run_inference.sh
    ├── infer_run_inference_geneval.sh
    ├── infer_run_inference_geneval_diffusers.sh
    ├── inference.py
    ├── inference_dpg.py
    ├── inference_geneval.py
    ├── inference_geneval_diffusers.py
    ├── inference_image_reward.py
    ├── inference_sana_sprint.py
    ├── inference_sana_sprint_geneval.py
    ├── interface.py
    └── style.css
├── tests
    └── bash
    │   ├── entry.sh
    │   ├── inference
    │       └── test_inference.sh
    │   ├── setup_test_data.sh
    │   └── training
    │       ├── test_training_all.sh
    │       ├── test_training_fsdp.sh
    │       ├── test_training_vae.sh
    │       └── test_training_video.sh
├── tools
    ├── __init__.py
    ├── controlnet
    │   ├── annotator
    │   │   ├── ckpts
    │   │   │   └── ckpts.txt
    │   │   ├── hed
    │   │   │   └── __init__.py
    │   │   └── util.py
    │   ├── inference_controlnet.py
    │   └── utils.py
    ├── convert_ImgDataset_to_WebDatasetMS_format.py
    ├── convert_py_to_yaml.py
    ├── convert_sana_to_diffusers.py
    ├── convert_sana_to_svdquant.py
    ├── convert_sana_video_to_diffusers.py
    ├── create_wids_metadata.py
    ├── download.py
    ├── inference_scaling
    │   ├── nvila_sana_pick.py
    │   └── nvila_sana_pick.sh
    └── metrics
    │   ├── clip-score
    │       ├── .gitignore
    │       ├── LICENSE
    │       ├── README.md
    │       ├── clip_score.py
    │       ├── setup.py
    │       └── src
    │       │   └── clip_score
    │       │       ├── __init__.py
    │       │       ├── __main__.py
    │       │       └── clip_score.py
    │   ├── compute_clipscore.sh
    │   ├── compute_dpg.sh
    │   ├── compute_fid_embedding.sh
    │   ├── compute_geneval.sh
    │   ├── compute_imagereward.sh
    │   ├── dpg_bench
    │       ├── compute_dpg_bench.py
    │       ├── dpg_bench.csv
    │       ├── metadata.json
    │       └── requirements.txt
    │   ├── geneval
    │       ├── LICENSE
    │       ├── README.md
    │       ├── annotations
    │       │   ├── annotations_clip.csv
    │       │   ├── annotations_if-xl.csv
    │       │   ├── annotations_sdv2.csv
    │       │   └── mturk_hit_template.html
    │       ├── environment.yml
    │       ├── evaluation
    │       │   ├── download_models.sh
    │       │   ├── evaluate_images.py
    │       │   ├── object_names.txt
    │       │   └── summary_scores.py
    │       ├── generation
    │       │   └── diffusers_generate.py
    │       ├── geneval_env.md
    │       ├── images
    │       │   └── geneval_figure_1.png
    │       └── prompts
    │       │   ├── create_prompts.py
    │       │   ├── evaluation_metadata.jsonl
    │       │   ├── generation_prompts.txt
    │       │   └── object_names.txt
    │   ├── image_reward
    │       ├── benchmark-prompts-dict.json
    │       └── compute_image_reward.py
    │   ├── pytorch-fid
    │       ├── .gitignore
    │       ├── CHANGELOG.md
    │       ├── LICENSE
    │       ├── README.md
    │       ├── compute_fid.py
    │       ├── noxfile.py
    │       ├── setup.cfg
    │       ├── setup.py
    │       ├── src
    │       │   └── pytorch_fid
    │       │   │   ├── __init__.py
    │       │   │   ├── __main__.py
    │       │   │   ├── fid_score.py
    │       │   │   └── inception.py
    │       └── tests
    │       │   └── test_fid_score.py
    │   └── utils.py
├── train_scripts
    ├── train.py
    ├── train.sh
    ├── train_dreambooth_lora_sana.py
    ├── train_lora.sh
    ├── train_scm_ladd.py
    └── train_scm_ladd.sh
└── train_video_scripts
    ├── train_video_ivjoint.py
    ├── train_video_ivjoint.sh
    ├── train_video_ivjoint_chunk.py
    └── train_video_ivjoint_chunk.sh


/.github/workflows/bot-autolint.yaml:
--------------------------------------------------------------------------------
 1 | name: Auto Lint (triggered by "auto lint" label)
 2 | on:
 3 |   pull_request:
 4 |     types:
 5 |       - opened
 6 |       - edited
 7 |       - closed
 8 |       - reopened
 9 |       - synchronize
10 |       - labeled
11 |       - unlabeled
12 | # run at most one lint workflow per branch / tag at a time.
13 | concurrency:
14 |   group: ci-lint-${{ github.head_ref || github.ref }}
15 |   cancel-in-progress: true
16 | jobs:
17 |   lint-by-label:
18 |     if: contains(github.event.pull_request.labels.*.name, 'lint wanted')
19 |     runs-on: ubuntu-latest
20 |     steps:
21 |       - name: Check out Git repository
22 |         uses: actions/checkout@v4
23 |         with:
24 |           token: ${{ secrets.GITHUB_TOKEN }}
25 |           ref: ${{ github.event.pull_request.head.ref }}
26 |       - name: Set up Python
27 |         uses: actions/setup-python@v5
28 |         with:
29 |           python-version: '3.10'
30 |       - name: Test pre-commit hooks
31 |         continue-on-error: true
32 |         uses: pre-commit/action@v3.0.0 # sync with https://github.com/Efficient-Large-Model/VILA-Internal/blob/main/.github/workflows/pre-commit.yaml
33 |         with:
34 |           extra_args: --all-files
35 |       - name: Check if there are any changes
36 |         id: verify_diff
37 |         run: |
38 |           git diff --quiet . || echo "changed=true" >> $GITHUB_OUTPUT
39 |       - name: Commit files
40 |         if: steps.verify_diff.outputs.changed == 'true'
41 |         run: |
42 |           git config --local user.email "action@github.com"
43 |           git config --local user.name "GitHub Action"
44 |           git add .
45 |           git commit -m "[CI-Lint] Fix code style issues with pre-commit ${{ github.sha }}" -a
46 |           git push
47 |       - name: Remove label(s) after lint
48 |         uses: actions-ecosystem/action-remove-labels@v1
49 |         with:
50 |           labels: lint wanted
51 | 


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | repos:
 2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
 3 |     rev: v5.0.0
 4 |     hooks:
 5 |       - id: trailing-whitespace
 6 |         name: (Common) Remove trailing whitespaces
 7 |       - id: mixed-line-ending
 8 |         name: (Common) Fix mixed line ending
 9 |         args: [--fix=lf]
10 |       - id: end-of-file-fixer
11 |         name: (Common) Remove extra EOF newlines
12 |       - id: check-merge-conflict
13 |         name: (Common) Check for merge conflicts
14 |       - id: requirements-txt-fixer
15 |         name: (Common) Sort "requirements.txt"
16 |       - id: fix-encoding-pragma
17 |         name: (Python) Remove encoding pragmas
18 |         args: [--remove]
19 |         # - id: debug-statements
20 |         #   name: (Python) Check for debugger imports
21 |       - id: check-json
22 |         name: (JSON) Check syntax
23 |       - id: check-yaml
24 |         name: (YAML) Check syntax
25 |       - id: check-toml
26 |         name: (TOML) Check syntax
27 |   # - repo: https://github.com/shellcheck-py/shellcheck-py
28 |   #   rev: v0.10.0.1
29 |   #   hooks:
30 |   #     - id: shellcheck
31 |   - repo: https://github.com/google/yamlfmt
32 |     rev: v0.13.0
33 |     hooks:
34 |       - id: yamlfmt
35 |   - repo: https://github.com/executablebooks/mdformat
36 |     rev: 0.7.16
37 |     hooks:
38 |       - id: mdformat
39 |         name: (Markdown) Format docs with mdformat
40 |   - repo: https://github.com/asottile/pyupgrade
41 |     rev: v3.2.2
42 |     hooks:
43 |       - id: pyupgrade
44 |         name: (Python) Update syntax for newer versions
45 |         args: [--py37-plus]
46 |   - repo: https://github.com/psf/black
47 |     rev: 22.10.0
48 |     hooks:
49 |       - id: black
50 |         name: (Python) Format code with black
51 |   - repo: https://github.com/pycqa/isort
52 |     rev: 5.12.0
53 |     hooks:
54 |       - id: isort
55 |         name: (Python) Sort imports with isort
56 |   - repo: https://github.com/pre-commit/mirrors-clang-format
57 |     rev: v15.0.4
58 |     hooks:
59 |       - id: clang-format
60 |         name: (C/C++/CUDA) Format code with clang-format
61 |         args: [-style=google, -i]
62 |         types_or: [c, c++, cuda]
63 | 


--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
 1 | cff-version: 1.2.0
 2 | title: 'SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer'
 3 | message: >-
 4 |   If you use this software or research, please cite it using the
 5 |   metadata from this file.
 6 | type: misc
 7 | authors:
 8 |   - given-names: Enze
 9 |     family-names: Xie
10 |   - given-names: Junsong
11 |     family-names: Chen
12 |   - given-names: Junyu
13 |     family-names: Chen
14 |   - given-names: Han
15 |     family-names: Cai
16 |   - given-names: Haotian
17 |     family-names: Tang
18 |   - given-names: Yujun
19 |     family-names: Lin
20 |   - given-names: Zhekai
21 |     family-names: Zhang
22 |   - given-names: Muyang
23 |     family-names: Li
24 |   - given-names: Ligeng
25 |     family-names: Zhu
26 |   - given-names: Yao
27 |     family-names: Lu
28 |   - given-names: Song
29 |     family-names: Han
30 | repository-code: 'https://github.com/NVlabs/Sana'
31 | abstract: >-
32 |   SANA proposes an efficient linear Diffusion Transformer (DiT) for high-resolution
33 |   image synthesis, featuring a depth-growth paradigm, model pruning techniques,
34 |   and inference-time scaling strategies to reduce training costs while maintaining
35 |   generation quality. SANA-Sprint also achieves one-step generation of high-resolution images.
36 | keywords:
37 |   - deep-learning
38 |   - diffusion-models
39 |   - transformer
40 |   - image-generation
41 |   - text-to-image
42 |   - efficient-training
43 |   - distillation
44 | license: Apache-2.0
45 | version: 2.0.0
46 | doi: 10.48550/arXiv.2410.10629
47 | date-released: 2024-10-16
48 | 


--------------------------------------------------------------------------------
/CIs/add_license_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "**/*__init__.py" **/*.py
3 | 


--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
 1 | FROM nvcr.io/nvidia/pytorch:24.06-py3
 2 | 
 3 | ENV PATH=/opt/conda/bin:$PATH
 4 | 
 5 | RUN apt-get update && apt-get install -y \
 6 |     libgl1-mesa-glx \
 7 |     libglib2.0-0 \
 8 |     && rm -rf /var/lib/apt/lists/*
 9 | 
10 | WORKDIR /app
11 | 
12 | RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
13 |     && sh ~/miniconda.sh -b -p /opt/conda \
14 |     && rm ~/miniconda.sh
15 | 
16 | COPY pyproject.toml pyproject.toml
17 | COPY diffusion diffusion
18 | COPY configs configs
19 | COPY sana sana
20 | COPY app app
21 | COPY tools tools
22 | 
23 | COPY environment_setup.sh environment_setup.sh
24 | RUN ./environment_setup.sh
25 | 
26 | CMD ["python", "-u", "-W", "ignore", "app/app_sana.py", "--share", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"]
27 | 


--------------------------------------------------------------------------------
/app/safety_check.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #      http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | # This file is copied from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
17 | import json
18 | from functools import lru_cache
19 | from typing import Dict, Sequence
20 | 
21 | import torch
22 | import transformers
23 | from torch.nn.functional import softmax
24 | from transformers import AutoModelForCausalLM, AutoTokenizer
25 | 
26 | __all__ = ["is_dangerous"]
27 | 
28 | __rank, __local_rank, __world_size, __device = (
29 |     0,
30 |     0,
31 |     1,
32 |     "cuda" if torch.cuda.is_available() else "cpu",
33 | )
34 | 
35 | 
36 | safety_policy = """
37 |     The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
38 | """
39 | 
40 | 
41 | def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
42 |     prompt = f"""You are a policy expert trying to help determine whether a user
43 |     prompt is in violation of the defined safety policies.
44 | 
45 |     <start_of_turn>
46 |     Human Question: {user_prompt.strip()}
47 |     <end_of_turn>
48 | 
49 |     Our safety principle is defined in the below:
50 | 
51 |     {safety_policy.strip()}
52 | 
53 |     Does the human question violate the above principle? Your answer must start
54 |     with 'Yes' or 'No'. And then walk through step by step to be sure we answer
55 |     correctly.
56 |     """
57 | 
58 |     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
59 |     with torch.no_grad():
60 |         logits = model(**inputs).logits
61 | 
62 |     # Extract the logits for the Yes and No tokens
63 |     vocab = tokenizer.get_vocab()
64 |     selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
65 | 
66 |     # Convert these logits to a probability with softmax
67 |     probabilities = softmax(selected_logits, dim=0)
68 | 
69 |     # Return probability of 'Yes'
70 |     score = probabilities[0].item()
71 | 
72 |     return score > threshold
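
# Example usage (a hypothetical sketch): the prompt template above follows a
# ShieldGemma-style policy check, so any compatible instruction-tuned guard
# model should work; the checkpoint name below is an assumption, not part of
# this file.
#
#   tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
#   model = AutoModelForCausalLM.from_pretrained("google/shieldgemma-2b").to(__device)
#   if is_dangerous(tokenizer, model, "some user prompt"):
#       raise ValueError("Prompt rejected by safety check")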
73 | 


--------------------------------------------------------------------------------
/asset/Sana.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/Sana.jpg


--------------------------------------------------------------------------------
/asset/app_styles/controlnet_app_style.css:
--------------------------------------------------------------------------------
 1 | @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
 2 | 
 3 | body{align-items: center;}
 4 | .gradio-container{max-width: 1200px !important}
 5 | h1{text-align:center}
 6 | 
 7 | .wrap.svelte-p4aq0j.svelte-p4aq0j {
 8 |     display: none;
 9 | }
10 | 
11 | #column_input, #column_output {
12 |     width: 500px;
13 |     display: flex;
14 |     align-items: center;
15 | }
16 | 
17 | #input_header, #output_header {
18 |     display: flex;
19 |     justify-content: center;
20 |     align-items: center;
21 |     width: 400px;
22 | }
23 | 
24 | #accessibility {
25 |     text-align: center; /* Center-aligns the text */
26 |     margin: auto;       /* Centers the element horizontally */
27 | }
28 | 
29 | #random_seed {height: 71px;}
30 | #run_button {height: 87px;}
31 | 


--------------------------------------------------------------------------------
/asset/apple.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/apple.webp


--------------------------------------------------------------------------------
/asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg


--------------------------------------------------------------------------------
/asset/controlnet/ref_images/a house.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/a house.png


--------------------------------------------------------------------------------
/asset/controlnet/ref_images/a living room.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/a living room.png


--------------------------------------------------------------------------------
/asset/controlnet/ref_images/nvidia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/nvidia.png


--------------------------------------------------------------------------------
/asset/controlnet/samples_controlnet.json:
--------------------------------------------------------------------------------
 1 | [
 2 |     {
 3 |         "prompt": "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
 4 |         "ref_image_path": "asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg"
 5 |     },
 6 |     {
 7 |         "prompt": "an architecture in INDIA,15th-18th style, with a lot of details",
 8 |         "ref_image_path": "asset/controlnet/ref_images/a house.png"
 9 |     },
10 |     {
11 |         "prompt": "An IKEA modern style living room with sofa, coffee table, stairs, etc., a brand new theme.",
12 |         "ref_image_path": "asset/controlnet/ref_images/a living room.png"
13 |     },
14 |     {
15 |         "prompt": "A modern new living room with sofa, coffee table, carpet, stairs, etc., high quality high detail, high resolution.",
16 |         "ref_image_path": "asset/controlnet/ref_images/a living room.png"
17 |     },
18 |     {
19 |         "prompt": "big eye, vibrant colors, intricate details, captivating gaze, surreal, dreamlike, fantasy, enchanting, mysterious, magical, moonlit, mystical, ethereal, enchanting {macro lens, high aperture, low ISO}",
20 |         "ref_image_path": "asset/controlnet/ref_images/nvidia.png"
21 |     },
22 |     {
23 |         "prompt": "shining eye, bright and vivid colors, radiant glow, sparkling reflections, joyful, uplifting, optimistic, hopeful, magical, luminous, celestial, dreamy {zoom lens, high aperture, natural light, vibrant color film}",
24 |         "ref_image_path": "asset/controlnet/ref_images/nvidia.png"
25 |     }
26 | ]
27 | 


--------------------------------------------------------------------------------
/asset/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/cover.png


--------------------------------------------------------------------------------
/asset/docs/ComfyUI/comfyui.md:
--------------------------------------------------------------------------------
 1 | ## 🖌️ Sana-ComfyUI
 2 | 
 3 | [Original Repo](https://github.com/city96/ComfyUI_ExtraModels)
 4 | 
 5 | ### Model info / implementation
 6 | 
 7 | - Uses Gemma2 2B as the text encoder
 8 | - Multiple resolutions and models available
 9 | - Compressed latent space (32 channels, /32 compression); requires a custom VAE (see the sketch below)
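
The /32 spatial compression above fixes the latent geometry; a quick, illustrative sanity check in Python (values taken directly from the list above):

```python
# DC-AE latent geometry for Sana: 32 latent channels, 32x spatial compression
height = width = 1024                            # generated image resolution
latent_shape = (32, height // 32, width // 32)   # -> (32, 32, 32)
print(latent_shape)
```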
10 | 
11 | ### Usage
12 | 
13 | 1. All the checkpoints will be downloaded automatically.
14 | 1. KSampler (Flow Euler) is available for now; Flow DPM-Solver will be available soon.
15 | 
16 | ```bash
17 | git clone https://github.com/comfyanonymous/ComfyUI.git
18 | cd ComfyUI
19 | git clone https://github.com/lawrence-cj/ComfyUI_ExtraModels.git custom_nodes/ComfyUI_ExtraModels
20 | 
21 | python main.py
22 | ```
23 | 
24 | ### A sample workflow for Sana
25 | 
26 | [Sana workflow](Sana_FlowEuler.json)
27 | 
28 | ![Sana](https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/comfyui/sana.jpg)
29 | 
30 | ### A sample for T2I (Sana) + I2V (CogVideoX)
31 | 
32 | [Sana + CogVideoX workflow](Sana_CogVideoX.json)
33 | 
34 | [![Sample T2I + I2V](https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/comfyui/sana-cogvideox.jpg)](https://nvlabs.github.io/Sana/asset/content/comfyui/Sana_CogVideoX_Fun.mp4)
35 | 
36 | ### A sample workflow for a Sana 4096x4096 image (an 18GB GPU is needed)
37 | 
38 | [Sana workflow](Sana_FlowEuler_4K.json)
39 | 
40 | ![Sana](https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/comfyui/Sana_4K_workflow.jpg)
41 | 


--------------------------------------------------------------------------------
/asset/docs/inference_scaling/inference_scaling.md:
--------------------------------------------------------------------------------
 1 | ## Inference Time Scaling for SANA-1.5
 2 | 
 3 | ![results](results.jpg)
 4 | 
 5 | We trained a specialized [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) model to score images, which we named VISA (VIla as SAna verifier). By selecting the top 4 images from 2,048 candidates, we enhanced the GenEval performance of SD1.5 and SANA-1.5-4.8B v2, increasing their scores from 42 to 87 and 81 to 96, respectively.
 6 | 
 7 | ![curve](scaling_curve.jpg)
 8 | 
 9 | Even with a smaller number of candidates, such as 32, we can still push SANA-1.5-4.8B v2 above 90% on GenEval.
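
Conceptually, the procedure is: generate N candidate images per prompt, score each with the verifier, and keep the top k. A minimal sketch of the selection step, assuming per-image verifier scores are already computed (`score_with_verifier` and the directory layout below are illustrative; the actual implementation lives in `tools/inference_scaling/nvila_sana_pick.py`):

```python
from pathlib import Path

def pick_best(image_paths, scores, k=4):
    """Keep the k candidate images with the highest verifier scores."""
    ranked = sorted(zip(image_paths, scores), key=lambda pair: pair[1], reverse=True)
    return [path for path, _ in ranked[:k]]

# e.g. 32 candidates generated for a single GenEval prompt
candidates = sorted(Path("output/geneval_generated_path/prompt_0000").glob("*.png"))
scores = [score_with_verifier(p) for p in candidates]  # hypothetical scoring helper
best4 = pick_best(candidates, scores, k=4)
```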
10 | 
11 | ### Environment Requirement
12 | 
13 | Dependency setups:
14 | 
15 | ```bash
16 | # other transformers versions may also work, but we have not tested them
17 | pip install transformers==4.46
18 | pip install git+https://github.com/bfshi/scaling_on_scales.git
19 | ```
20 | 
21 | ### 1. Generate N images with a .pth checkpoint for the subsequent selection
22 | 
23 | ```bash
24 | # download the checkpoint for the following generation
25 | huggingface-cli download Efficient-Large-Model/Sana_600M_512px --repo-type model --local-dir output/Sana_600M_512px --local-dir-use-symlinks False
26 | # 32 is a relatively small number for testing, but it already pushes GenEval above 90% for the SANA-1.5-4.8B v2 model. Set it to a larger number such as 2048 to probe the upper limit.
27 | n_samples=32
28 | pick_number=4
29 | 
30 | output_dir=output/geneval_generated_path
31 | # example
32 | bash scripts/infer_run_inference_geneval.sh \
33 |     configs/sana_config/512ms/Sana_600M_img512.yaml \
34 |     output/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth \
35 |     --img_nums_per_sample=$n_samples \
36 |     --output_dir=$output_dir
37 | ```
38 | 
39 | ### 2. Use NVILA-Verifier to select from the generated images
40 | 
41 | ```bash
42 | bash tools/inference_scaling/nvila_sana_pick.sh \
43 |     $output_dir \
44 |     $n_samples \
45 |     $pick_number
46 | ```
47 | 
48 | ### 3. Calculate the GenEval metric
49 | 
50 | You need to use the GenEval environment for the final evaluation. Installation instructions can be found [here](../../../tools/metrics/geneval/geneval_env.md).
51 | 
52 | ```bash
53 | # activate geneval env
54 | conda activate geneval
55 | 
56 | DIR_AFTER_PICK="output/nvila_pick/best_${pick_number}_of_${n_samples}/${output_dir}"
57 | 
58 | bash tools/metrics/compute_geneval.sh $(dirname "$DIR_AFTER_PICK") $(basename "$DIR_AFTER_PICK")
59 | ```
60 | 


--------------------------------------------------------------------------------
/asset/docs/inference_scaling/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/docs/inference_scaling/results.jpg


--------------------------------------------------------------------------------
/asset/docs/inference_scaling/scaling_curve.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/docs/inference_scaling/scaling_curve.jpg


--------------------------------------------------------------------------------
/asset/docs/quantize/4bit_sana.md:
--------------------------------------------------------------------------------
 1 | <!--Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | #      http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License. -->
15 | 
16 | # 4bit SanaPipeline
17 | 
18 | ### 1. Environment setup
19 | 
20 | Follow the official [SVDQuant-Nunchaku](https://github.com/mit-han-lab/nunchaku) repository to set up the environment. The guidance can be found [here](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation).
21 | 
22 | ### 1-1. Quantize Sana with SVDQuant-4bit (Optional)
23 | 
24 | 1. Convert the .pth checkpoint to the safetensors format required by SVDQuant
25 | 
26 | ```bash
27 | python tools/convert_sana_to_svdquant.py \
28 |       --orig_ckpt_path Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth \
29 |       --model_type SanaMS1.5_1600M_P1_D20 \
30 |       --dtype bf16 \
31 |       --dump_path output/SANA1.5_1.6B_1024px_svdquant_diffusers \
32 |       --save_full_pipeline
33 | ```
34 | 
35 | 2. Follow the guidance below to compress the model:
36 |    [Quantization guidance](https://github.com/mit-han-lab/deepcompressor/tree/main/examples/diffusion)
37 | 
38 | ### 2. Code snippet for inference
39 | 
40 | Here we show the code snippet for SanaPipeline. For SanaPAGPipeline, please refer to the [SanaPAGPipeline](https://github.com/mit-han-lab/nunchaku/blob/main/examples/sana_1600m_pag.py) section.
41 | 
42 | ```python
43 | import torch
44 | from diffusers import SanaPipeline
45 | 
46 | from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
47 | 
48 | transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
49 | pipe = SanaPipeline.from_pretrained(
50 |     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
51 |     transformer=transformer,
52 |     variant="bf16",
53 |     torch_dtype=torch.bfloat16,
54 | ).to("cuda")
55 | 
56 | pipe.text_encoder.to(torch.bfloat16)
57 | pipe.vae.to(torch.bfloat16)
58 | 
59 | image = pipe(
60 |     prompt="A cute 🐼 eating 🎋, ink drawing style",
61 |     height=1024,
62 |     width=1024,
63 |     guidance_scale=4.5,
64 |     num_inference_steps=20,
65 |     generator=torch.Generator().manual_seed(42),
66 | ).images[0]
67 | image.save("sana_1600m.png")
68 | ```
69 | 
70 | ### 3. Online demo
71 | 
72 | 1). Launch the 4bit Sana demo.
73 | 
74 | ```bash
75 | python app/app_sana_4bit.py
76 | ```
77 | 
78 | 2). Compare with the BF16 version
79 |
80 | Refer to the original [Nunchaku-Sana](https://github.com/mit-han-lab/nunchaku/tree/main/app/sana/t2i) guidance for SanaPAGPipeline.
81 | 
82 | ```bash
83 | python app/app_sana_4bit_compare_bf16.py
84 | ```
85 | 


--------------------------------------------------------------------------------
/asset/docs/sana_controlnet.md:
--------------------------------------------------------------------------------
 1 | <!-- Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | 
 3 | Licensed under the Apache License, Version 2.0 (the "License");
 4 | you may not use this file except in compliance with the License.
 5 | You may obtain a copy of the License at
 6 | 
 7 |      http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | 
15 | SPDX-License-Identifier: Apache-2.0 -->
16 | 
17 | ## 🔥 ControlNet
18 | 
19 | We incorporate a [ControlNet](https://github.com/lllyasviel/ControlNet)-like module that enables fine-grained control over text-to-image diffusion models. We implement a ControlNet-Transformer architecture specifically tailored for Transformers, achieving explicit controllability alongside high-quality image generation.
20 | 
21 | <p align="center">
22 |   <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/controlnet/sana_controlnet.jpg"  height=480>
23 | </p>
24 | 
25 | ## Inference of `Sana + ControlNet`
26 | 
27 | ### 1). Gradio Interface
28 | 
29 | ```bash
30 | python app/app_sana_controlnet_hed.py \
31 |         --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml \
32 |         --model_path hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth
33 | ```
34 | 
35 | <p align="center" border-raduis="10px">
36 |   <img src="https://nvlabs.github.io/Sana/asset/content/controlnet/controlnet_app.jpg" width="90%" alt="teaser_page2"/>
37 | </p>
38 | 
39 | ### 2). Inference with JSON file
40 | 
41 | ```bash
42 | python tools/controlnet/inference_controlnet.py \
43 |         --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml \
44 |         --model_path hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth \
45 |         --json_file asset/controlnet/samples_controlnet.json
46 | ```
47 | 
48 | ### 3). Inference code snippet
49 | 
50 | ```python
51 | import torch
52 | from PIL import Image
53 | from app.sana_controlnet_pipeline import SanaControlNetPipeline
54 | 
55 | device = "cuda" if torch.cuda.is_available() else "cpu"
56 | 
57 | pipe = SanaControlNetPipeline("configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml")
58 | pipe.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth")
59 | 
60 | ref_image = Image.open("asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg")
61 | prompt = "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape."
62 | 
63 | images = pipe(
64 |     prompt=prompt,
65 |     ref_image=ref_image,
66 |     guidance_scale=4.5,
67 |     num_inference_steps=10,
68 |     sketch_thickness=2,
69 |     generator=torch.Generator(device=device).manual_seed(0),
70 | )
71 | ```
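
Assuming the pipeline returns a list of PIL images (as the variable name suggests), the outputs can then be saved, for example:

```python
for i, image in enumerate(images):
    image.save(f"sana_controlnet_{i}.png")
```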
72 | 
73 | ## Training of `Sana + ControlNet`
74 | 
75 | ### Coming soon
76 | 


--------------------------------------------------------------------------------
/asset/example_data/00000000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/example_data/00000000.jpg


--------------------------------------------------------------------------------
/asset/example_data/00000000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/example_data/00000000.png


--------------------------------------------------------------------------------
/asset/example_data/00000000.txt:
--------------------------------------------------------------------------------
1 | a cyberpunk cat with a neon sign that says "Sana".
2 | 


--------------------------------------------------------------------------------
/asset/example_data/00000000_InternVL2-26B.json:
--------------------------------------------------------------------------------
1 | {
2 |     "00000000": {
3 |         "InternVL2-26B": "a cyberpunk cat with a neon sign that says 'Sana'"
4 |     }
5 | }
6 | 


--------------------------------------------------------------------------------
/asset/example_data/00000000_InternVL2-26B_clip_score.json:
--------------------------------------------------------------------------------
1 | {
2 |     "00000000": {
3 |         "InternVL2-26B": "27.1037"
4 |     }
5 | }
6 | 


--------------------------------------------------------------------------------
/asset/example_data/00000000_VILA1-5-13B.json:
--------------------------------------------------------------------------------
1 | {
2 |     "00000000": {
3 |         "VILA1-5-13B": "a cyberpunk cat with a neon sign that says 'Sana'"
4 |     }
5 | }
6 | 


--------------------------------------------------------------------------------
/asset/example_data/00000000_VILA1-5-13B_clip_score.json:
--------------------------------------------------------------------------------
1 | {
2 |     "00000000": {
3 |         "VILA1-5-13B": "27.2321"
4 |     }
5 | }
6 | 


--------------------------------------------------------------------------------
/asset/example_data/00000000_prompt_clip_score.json:
--------------------------------------------------------------------------------
1 | {
2 |     "00000000": {
3 |         "prompt":  "26.7331"
4 |     }
5 | }
6 | 


--------------------------------------------------------------------------------
/asset/example_data/meta_data.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "sana-dev",
3 |     "__kind__": "Sana-ImgDataset",
4 |     "img_names": [
5 |         "00000000", "00000000", "00000000.png", "00000000.jpg"
6 |     ]
7 | }
8 | 


--------------------------------------------------------------------------------
/asset/examples.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
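# Each entry is [prompt, sampler, sampling steps, guidance scale, PAG guidance
# scale]; this field interpretation is an assumption based on how the gradio
# apps consume these examples.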
17 | examples = [
18 |     [
19 |         "A small cactus with a happy face in the Sahara desert.",
20 |         "flow_dpm-solver",
21 |         20,
22 |         5.0,
23 |         2.5,
24 |     ],
25 |     [
26 |         "An extreme close-up of a gray-haired man with a beard in his 60s, he is deep in thought pondering the history "
27 |         "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
28 |         "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
29 |         "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
30 |         "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
31 |         "the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
32 |         "flow_dpm-solver",
33 |         20,
34 |         5.0,
35 |         2.5,
36 |     ],
37 |     [
38 |         "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
39 |         "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
40 |         "The quote 'Find the universe within you' is etched in bold letters across the horizon. "
41 |         "blue and pink, brilliantly illuminated in the background.",
42 |         "flow_dpm-solver",
43 |         20,
44 |         5.0,
45 |         2.5,
46 |     ],
47 |     [
48 |         "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
49 |         "flow_dpm-solver",
50 |         20,
51 |         5.0,
52 |         2.5,
53 |     ],
54 |     [
55 |         "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
56 |         "flow_dpm-solver",
57 |         20,
58 |         5.0,
59 |         2.5,
60 |     ],
61 |     [
62 |         "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
63 |         "national geographic photo, 8k resolution, crayon art, interactive artwork",
64 |         "flow_dpm-solver",
65 |         20,
66 |         5.0,
67 |         2.5,
68 |     ],
69 | ]
70 | 


--------------------------------------------------------------------------------
/asset/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/logo.png


--------------------------------------------------------------------------------
/asset/mit-logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/mit-logo.jpg


--------------------------------------------------------------------------------
/asset/model-incremental.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/model-incremental.jpg


--------------------------------------------------------------------------------
/asset/model_paths.txt:
--------------------------------------------------------------------------------
1 | output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
2 | output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
3 | 


--------------------------------------------------------------------------------
/asset/paper2video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/paper2video.jpg


--------------------------------------------------------------------------------
/asset/samples/sample_i2v.txt:
--------------------------------------------------------------------------------
1 | A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle.<image>asset/samples/i2v-1.png
2 | A majestic brown cow with large curved horns gallops across a dusty field under a clear blue sky. The cow's powerful strides kick up clouds of dust, creating a dynamic sense of motion. The camera captures the animal from a low angle, emphasizing its imposing presence against the vast, open landscape. The sunlight highlights the cow's glossy coat, adding depth and vibrancy to the scene. A slow-motion effect enhances the fluidity of the cow's movement, making it appear almost airborne.<image>asset/samples/i2v-2.png
3 | 


--------------------------------------------------------------------------------
/asset/samples/samples_mini.txt:
--------------------------------------------------------------------------------
 1 | A cyberpunk cat with a neon sign that says 'Sana'.
 2 | A small cactus with a happy face in the Sahara desert.
 3 | The towel was on top of the hard counter.
 4 | A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
 5 | I want to supplement vitamin c, please help me paint related food.
 6 | A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
 7 | an old rusted robot wearing pants and a jacket riding skis in a supermarket.
 8 | professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
 9 | Astronaut in a jungle, cold color palette, muted colors, detailed
10 | a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests.
11 | 


--------------------------------------------------------------------------------
/asset/samples/video_prompts_samples.txt:
--------------------------------------------------------------------------------
 1 | Extreme close-up of a thoughtful, gray-haired professor in his 60s, sitting motionless in a Paris café, dressed in a wool coat and beret, pondering the universe. His subtle closed-mouth smile reveals an answer. Golden light, cinematic depth of field, Paris streets blurred in the background. Cinematic 35mm film.
 2 | A woman surrounded by swirling smoke of vibrant colors, warm light bathing her figure. Medium shot, soft focus.
 3 | "Minecraft with the most gorgeous high-res 8K texture pack ever, showcasing detailed landscapes and characters. Smooth camera pans over lush forests, towering mountains, and serene villages. Epic vistas and intricate textures. Wide and close-up shots."
 4 | Japanese animated film style, a young woman standing on a ship's deck, looking back at the camera with a serene expression. The background shows the ocean and sky stretching out behind her. Medium shot focusing on her profile.
 5 | A hyper-speed train's internal window view passing through an old European city, showing fast-moving buildings and landscapes. Medium shot from inside the train.
 6 | A large orange octopus rests on the ocean floor, blending with sand and rocks, tentacles spread, eyes closed. A brown, spiky king crab creeps closer, claws raised. Wide angle captures the vast, clear, sunlit blue sea, focusing on the octopus and crab with a depth of field blur.
 7 | Wildlife along the Kinabatangan River in Borneo, focusing on diverse animals like orangutans, proboscis monkeys, and crocodiles in their natural habitat. Aerial and ground shots showcasing lush rainforest surroundings.
 8 | A lively pink pig running swiftly towards the camera in a bustling Tokyo alleyway, surrounded by neon lights and signs. Close-up, dynamic shot.
 9 | In the daytime, an anime-style white car drives towards the camera, splashing water from a pond as it passes by, medium shot.
10 | A toy robot in blue jeans and a white t-shirt leisurely walking in Antarctica as the sun sets beautifully. The robot has a friendly expression, moving with smooth steps. Wide shot capturing the vast icy landscape and colorful sky.
11 | 


--------------------------------------------------------------------------------
/configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/toy_data]
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: []
  7 |   external_clipscore_suffixes: []
  8 |   clip_thr_temperature: 0.1
  9 |   clip_thr: 25.0
 10 |   del_img_clip_thr: 22.0
 11 |   load_text_feat: false
 12 |   load_vae_feat: true
 13 |   transform: default_train
 14 |   type: SanaWebDatasetMS
 15 |   sort_dataset: false
 16 | # model config
 17 | model:
 18 |   model: SanaMS_1600M_P1_D20
 19 |   image_size: 1024
 20 |   mixed_precision: bf16
 21 |   fp32_attention: true
 22 |   load_from: hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth
 23 |   aspect_ratio_type: ASPECT_RATIO_1024
 24 |   multi_scale: true
 25 |   attn_type: linear
 26 |   ffn_type: glumbconv
 27 |   mlp_acts:
 28 |     - silu
 29 |     - silu
 30 |     -
 31 |   mlp_ratio: 2.5
 32 |   use_pe: false
 33 |   qk_norm: true
 34 |   cross_norm: true
 35 |   class_dropout_prob: 0.1
 36 | # VAE setting
 37 | vae:
 38 |   vae_type: AutoencoderDC
 39 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 40 |   scale_factor: 0.41407
 41 |   vae_latent_dim: 32
 42 |   vae_downsample_rate: 32
 43 |   sample_posterior: true
 44 | # text encoder
 45 | text_encoder:
 46 |   text_encoder_name: gemma-2-2b-it
 47 |   y_norm: true
 48 |   y_norm_scale_factor: 0.01
 49 |   model_max_length: 300
 50 |   # CHI
 51 |   chi_prompt:
 52 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 53 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 54 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 55 |     - 'Here are examples of how to transform or refine prompts:'
 56 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 57 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 58 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 59 |     - 'User Prompt: '
 60 | # Sana schedule Flow
 61 | scheduler:
 62 |   predict_flow_v: true
 63 |   noise_schedule: linear_flow
 64 |   pred_sigma: false
 65 |   flow_shift: 3.0
 66 |   # logit-normal timestep
 67 |   weighting_scheme: logit_normal
 68 |   logit_mean: 0.0
 69 |   logit_std: 1.0
 70 |   vis_sampler: flow_dpm-solver
 71 | # training setting
 72 | train:
 73 |   use_fsdp: true
 74 |   num_workers: 10
 75 |   seed: 1
 76 |   train_batch_size: 32
 77 |   num_epochs: 100
 78 |   gradient_accumulation_steps: 1
 79 |   grad_checkpointing: true
 80 |   gradient_clip: 0.1
 81 |   optimizer:
 82 |     betas:
 83 |       - 0.9
 84 |       - 0.999
 85 |     eps: 1.0e-10
 86 |     lr: 2.0e-5
 87 |     type: AdamW
 88 |     weight_decay: 0.0
 89 |   lr_schedule: constant
 90 |   lr_schedule_args:
 91 |     num_warmup_steps: 1000
 92 |   local_save_vis: true # whether to save log images locally
 93 |   visualize: true
 94 |   eval_sampling_steps: 500
 95 |   log_interval: 20
 96 |   save_model_epochs: 5
 97 |   save_model_steps: 500
 98 |   work_dir: output/debug
 99 |   online_metric: false
100 |   eval_metric_step: 2000
101 |   online_metric_dir: metric_helper
102 | 
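A minimal sketch (not the repository's trainer) of how the train.optimizer block above maps onto torch.optim.AdamW, assuming PyYAML is available; the single placeholder parameter stands in for model.parameters():

    import torch
    import yaml  # PyYAML

    with open("configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml") as f:
        cfg = yaml.safe_load(f)

    opt_cfg = cfg["train"]["optimizer"]
    params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
    optimizer = torch.optim.AdamW(
        params,
        lr=float(opt_cfg["lr"]),
        betas=tuple(opt_cfg["betas"]),
        eps=float(opt_cfg["eps"]),
        weight_decay=float(opt_cfg["weight_decay"]),
    )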


--------------------------------------------------------------------------------
/configs/sana_app_config/Sana_1600M_app.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: []
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: []
  7 |   external_clipscore_suffixes: []
  8 |   clip_thr_temperature: 0.1
  9 |   clip_thr: 25.0
 10 |   load_text_feat: false
 11 |   load_vae_feat: false
 12 |   transform: default_train
 13 |   type: SanaWebDatasetMS
 15 |   sort_dataset: false
 16 | # model config
 17 | model:
 18 |   model: SanaMS_1600M_P1_D20
 19 |   image_size: 1024
 20 |   mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
 21 |   fp32_attention: true
 22 |   load_from:
 23 |   resume_from:
 24 |   aspect_ratio_type: ASPECT_RATIO_1024
 25 |   multi_scale: true
 26 |   #pe_interpolation: 1.
 27 |   attn_type: linear
 28 |   ffn_type: glumbconv
 29 |   mlp_acts:
 30 |     - silu
 31 |     - silu
 32 |     -
 33 |   mlp_ratio: 2.5
 34 |   use_pe: false
 35 |   qk_norm: false
 36 |   class_dropout_prob: 0.1
 37 |   # CFG & PAG settings
 38 |   pag_applied_layers:
 39 |     - 8
 40 | # VAE setting
 41 | vae:
 42 |   vae_type: AutoencoderDC
 43 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 44 |   scale_factor: 0.41407
 45 |   vae_latent_dim: 32
 46 |   vae_downsample_rate: 32
 47 |   sample_posterior: true
 48 | # text encoder
 49 | text_encoder:
 50 |   text_encoder_name: gemma-2-2b-it
 51 |   y_norm: true
 52 |   y_norm_scale_factor: 0.01
 53 |   model_max_length: 300
 54 |   # CHI
 55 |   chi_prompt:
 56 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 57 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 58 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 59 |     - 'Here are examples of how to transform or refine prompts:'
 60 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 61 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 62 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 63 |     - 'User Prompt: '
 64 | # Sana schedule Flow
 65 | scheduler:
 66 |   predict_flow_v: true
 67 |   noise_schedule: linear_flow
 68 |   pred_sigma: false
 69 |   flow_shift: 3.0
 70 |   # logit-normal timestep
 71 |   weighting_scheme: logit_normal
 72 |   logit_mean: 0.0
 73 |   logit_std: 1.0
 74 |   vis_sampler: flow_dpm-solver
 75 | # training setting
 76 | train:
 77 |   num_workers: 10
 78 |   seed: 1
 79 |   train_batch_size: 64
 80 |   num_epochs: 100
 81 |   gradient_accumulation_steps: 1
 82 |   grad_checkpointing: true
 83 |   gradient_clip: 0.1
 84 |   optimizer:
 85 |     betas:
 86 |       - 0.9
 87 |       - 0.999
 88 |       - 0.9999
 89 |     eps:
 90 |       - 1.0e-30
 91 |       - 1.0e-16
 92 |     lr: 0.0001
 93 |     type: CAMEWrapper
 94 |     weight_decay: 0.0
 95 |   lr_schedule: constant
 96 |   lr_schedule_args:
 97 |     num_warmup_steps: 2000
 98 |   local_save_vis: true # whether to save log images locally
 99 |   visualize: true
100 |   eval_sampling_steps: 500
101 |   log_interval: 20
102 |   save_model_epochs: 5
103 |   save_model_steps: 500
104 |   work_dir: output/debug
105 |   online_metric: false
106 |   eval_metric_step: 2000
107 |   online_metric_dir: metric_helper
108 | 
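A sketch of how a chi_prompt list like the one above is typically consumed: the entries are joined into a single instruction prefix, the user's prompt is appended after the trailing 'User Prompt: ' entry, and the result is tokenized with model_max_length (300) for gemma-2-2b-it. The exact template the Sana pipeline uses may differ; treat this as an illustration:

    # chi_prompt entries as in the YAML above (elided here for brevity)
    chi_prompt = [
        'Given a user prompt, generate an "Enhanced prompt" ...',
        'User Prompt: ',
    ]
    user_prompt = "A cat sleeping"
    full_prompt = "\n".join(chi_prompt) + user_prompt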


--------------------------------------------------------------------------------
/configs/sana_app_config/Sana_600M_app.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: []
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: []
  7 |   external_clipscore_suffixes: []
  8 |   clip_thr_temperature: 0.1
  9 |   clip_thr: 25.0
 10 |   load_text_feat: false
 11 |   load_vae_feat: true
 12 |   transform: default_train
 13 |   type: SanaWebDatasetMS
 14 |   sort_dataset: false
 15 | # model config
 16 | model:
 17 |   model: SanaMS_600M_P1_D28
 18 |   image_size: 1024
 19 |   mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
 20 |   fp32_attention: true
 21 |   load_from:
 22 |   resume_from:
 23 |   aspect_ratio_type: ASPECT_RATIO_1024
 24 |   multi_scale: true
 25 |   attn_type: linear
 26 |   ffn_type: glumbconv
 27 |   mlp_acts:
 28 |     - silu
 29 |     - silu
 30 |     -
 31 |   mlp_ratio: 2.5
 32 |   use_pe: false
 33 |   qk_norm: false
 34 |   class_dropout_prob: 0.1
 35 |   # CFG & PAG settings
 36 |   pag_applied_layers:
 37 |     - 14
 38 | # VAE setting
 39 | vae:
 40 |   vae_type: AutoencoderDC
 41 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 42 |   scale_factor: 0.41407
 43 |   vae_latent_dim: 32
 44 |   vae_downsample_rate: 32
 45 |   sample_posterior: true
 46 | # text encoder
 47 | text_encoder:
 48 |   text_encoder_name: gemma-2-2b-it
 49 |   y_norm: true
 50 |   y_norm_scale_factor: 0.01
 51 |   model_max_length: 300
 52 |   # CHI
 53 |   chi_prompt:
 54 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 55 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 56 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 57 |     - 'Here are examples of how to transform or refine prompts:'
 58 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 59 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 60 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 61 |     - 'User Prompt: '
 62 | # Sana schedule Flow
 63 | scheduler:
 64 |   predict_flow_v: true
 65 |   noise_schedule: linear_flow
 66 |   pred_sigma: false
 67 |   flow_shift: 4.0
 68 |   # logit-normal timestep
 69 |   weighting_scheme: logit_normal
 70 |   logit_mean: 0.0
 71 |   logit_std: 1.0
 72 |   vis_sampler: flow_dpm-solver
 73 | # training setting
 74 | train:
 75 |   num_workers: 10
 76 |   seed: 1
 77 |   train_batch_size: 64
 78 |   num_epochs: 100
 79 |   gradient_accumulation_steps: 1
 80 |   grad_checkpointing: true
 81 |   gradient_clip: 0.1
 82 |   optimizer:
 83 |     betas:
 84 |       - 0.9
 85 |       - 0.999
 86 |       - 0.9999
 87 |     eps:
 88 |       - 1.0e-30
 89 |       - 1.0e-16
 90 |     lr: 0.0001
 91 |     type: CAMEWrapper
 92 |     weight_decay: 0.0
 93 |   lr_schedule: constant
 94 |   lr_schedule_args:
 95 |     num_warmup_steps: 2000
 96 |   local_save_vis: true # whether to save log images locally
 97 |   visualize: true
 98 |   eval_sampling_steps: 500
 99 |   log_interval: 20
100 |   save_model_epochs: 5
101 |   save_model_steps: 500
102 |   work_dir: output/debug
103 |   online_metric: false
104 |   eval_metric_step: 2000
105 |   online_metric_dir: metric_helper
106 | 
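The scheduler block above pairs a logit-normal timestep distribution (logit_mean, logit_std) with a flow shift (4.0 here, 3.0 in the 1600M configs). A minimal sketch of that standard flow-matching construction; the repository's own scheduler may differ in detail:

    import torch

    def sample_timesteps(n, logit_mean=0.0, logit_std=1.0, flow_shift=4.0):
        u = logit_mean + logit_std * torch.randn(n)
        t = torch.sigmoid(u)  # logit-normal sample in (0, 1)
        return flow_shift * t / (1.0 + (flow_shift - 1.0) * t)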


--------------------------------------------------------------------------------
/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/toy_data]
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
  7 |   external_clipscore_suffixes:
  8 |     - _InternVL2-26B_clip_score
  9 |     - _VILA1-5-13B_clip_score
 10 |     - _prompt_clip_score
 11 |   clip_thr_temperature: 0.1
 12 |   clip_thr: 25.0
 13 |   load_text_feat: false
 14 |   load_vae_feat: false
 15 |   transform: default_train
 16 |   type: SanaWebDatasetMS
 17 |   sort_dataset: false
 18 | # model config
 19 | model:
 20 |   model: SanaMS_1600M_P1_D20
 21 |   image_size: 1024
 22 |   mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
 23 |   fp32_attention: true
 24 |   load_from:
 25 |   resume_from:
 26 |   aspect_ratio_type: ASPECT_RATIO_1024
 27 |   multi_scale: true
 28 |   #pe_interpolation: 1.
 29 |   attn_type: linear
 30 |   ffn_type: glumbconv
 31 |   mlp_acts:
 32 |     - silu
 33 |     - silu
 34 |     -
 35 |   mlp_ratio: 2.5
 36 |   use_pe: false
 37 |   qk_norm: false
 38 |   class_dropout_prob: 0.1
 39 |   # PAG
 40 |   pag_applied_layers:
 41 |     - 8
 42 | # VAE setting
 43 | vae:
 44 |   vae_type: AutoencoderDC
 45 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 46 |   scale_factor: 0.41407
 47 |   vae_latent_dim: 32
 48 |   vae_downsample_rate: 32
 49 |   sample_posterior: true
 50 | # text encoder
 51 | text_encoder:
 52 |   text_encoder_name: gemma-2-2b-it
 53 |   y_norm: true
 54 |   y_norm_scale_factor: 0.01
 55 |   model_max_length: 300
 56 |   # CHI
 57 |   chi_prompt:
 58 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 59 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 60 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 61 |     - 'Here are examples of how to transform or refine prompts:'
 62 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 63 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 64 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 65 |     - 'User Prompt: '
 66 | # Sana schedule Flow
 67 | scheduler:
 68 |   predict_flow_v: true
 69 |   noise_schedule: linear_flow
 70 |   pred_sigma: false
 71 |   flow_shift: 3.0
 72 |   # logit-normal timestep
 73 |   weighting_scheme: logit_normal
 74 |   logit_mean: 0.0
 75 |   logit_std: 1.0
 76 |   vis_sampler: flow_dpm-solver
 77 | # training setting
 78 | train:
 79 |   num_workers: 10
 80 |   seed: 1
 81 |   train_batch_size: 64
 82 |   num_epochs: 100
 83 |   gradient_accumulation_steps: 1
 84 |   grad_checkpointing: true
 85 |   gradient_clip: 0.1
 86 |   optimizer:
 87 |     lr: 1.0e-4
 88 |     type: AdamW
 89 |     weight_decay: 0.01
 90 |     eps: 1.0e-8
 91 |     betas: [0.9, 0.999]
 92 |   lr_schedule: constant
 93 |   lr_schedule_args:
 94 |     num_warmup_steps: 2000
 95 |   local_save_vis: true # whether to save log images locally
 96 |   visualize: true
 97 |   eval_sampling_steps: 500
 98 |   log_interval: 20
 99 |   save_model_epochs: 5
100 |   save_model_steps: 500
101 |   work_dir: output/debug
102 |   online_metric: false
103 |   eval_metric_step: 2000
104 |   online_metric_dir: metric_helper
105 | 
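One plausible reading of clip_thr and clip_thr_temperature above (an assumption inferred from the key names, not a quote of the dataset code): caption candidates scoring below clip_thr are filtered out, and one of the remaining captions is sampled with probabilities softmax(score / temperature):

    import torch

    def pick_caption(captions, clip_scores, clip_thr=25.0, temperature=0.1):
        # Hypothetical helper: filter low-scoring captions, then sample
        # among the survivors with a temperature-sharpened softmax.
        scores = torch.tensor(clip_scores)
        keep = scores >= clip_thr
        if not keep.any():
            return captions[int(scores.argmax())]  # fall back to the best caption
        probs = torch.softmax(scores[keep] / temperature, dim=0)
        kept = [c for c, k in zip(captions, keep) if k]
        return kept[int(torch.multinomial(probs, 1))]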


--------------------------------------------------------------------------------
/configs/sana_config/1024ms/Sana_600M_img1024.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/toy_data]
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
  7 |   external_clipscore_suffixes:
  8 |     - _InternVL2-26B_clip_score
  9 |     - _VILA1-5-13B_clip_score
 10 |     - _prompt_clip_score
 11 |   clip_thr_temperature: 0.1
 12 |   clip_thr: 25.0
 13 |   load_text_feat: false
 14 |   load_vae_feat: false
 15 |   transform: default_train
 16 |   type: SanaWebDatasetMS
 17 |   sort_dataset: false
 18 | # model config
 19 | model:
 20 |   model: SanaMS_600M_P1_D28
 21 |   image_size: 1024
 22 |   mixed_precision: fp16
 23 |   fp32_attention: true
 24 |   load_from:
 25 |   resume_from:
 26 |   aspect_ratio_type: ASPECT_RATIO_1024
 27 |   multi_scale: true
 28 |   attn_type: linear
 29 |   ffn_type: glumbconv
 30 |   mlp_acts:
 31 |     - silu
 32 |     - silu
 33 |     -
 34 |   mlp_ratio: 2.5
 35 |   use_pe: false
 36 |   qk_norm: false
 37 |   class_dropout_prob: 0.1
 38 | # VAE setting
 39 | vae:
 40 |   vae_type: AutoencoderDC
 41 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 42 |   scale_factor: 0.41407
 43 |   vae_latent_dim: 32
 44 |   vae_downsample_rate: 32
 45 |   sample_posterior: true
 46 | # text encoder
 47 | text_encoder:
 48 |   text_encoder_name: gemma-2-2b-it
 49 |   y_norm: true
 50 |   y_norm_scale_factor: 0.01
 51 |   model_max_length: 300
 52 |   # CHI
 53 |   chi_prompt:
 54 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 55 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 56 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 57 |     - 'Here are examples of how to transform or refine prompts:'
 58 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 59 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 60 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 61 |     - 'User Prompt: '
 62 | # Sana schedule Flow
 63 | scheduler:
 64 |   predict_flow_v: true
 65 |   noise_schedule: linear_flow
 66 |   pred_sigma: false
 67 |   flow_shift: 4.0
 68 |   # logit-normal timestep
 69 |   weighting_scheme: logit_normal
 70 |   logit_mean: 0.0
 71 |   logit_std: 1.0
 72 |   vis_sampler: flow_dpm-solver
 73 | # training setting
 74 | train:
 75 |   num_workers: 10
 76 |   seed: 1
 77 |   train_batch_size: 64
 78 |   num_epochs: 100
 79 |   gradient_accumulation_steps: 1
 80 |   grad_checkpointing: true
 81 |   gradient_clip: 0.1
 82 |   optimizer:
 83 |     betas:
 84 |       - 0.9
 85 |       - 0.999
 86 |       - 0.9999
 87 |     eps:
 88 |       - 1.0e-30
 89 |       - 1.0e-16
 90 |     lr: 0.0001
 91 |     type: CAMEWrapper
 92 |     weight_decay: 0.0
 93 |   lr_schedule: constant
 94 |   lr_schedule_args:
 95 |     num_warmup_steps: 2000
 96 |   local_save_vis: true # whether to save log images locally
 97 |   visualize: true
 98 |   eval_sampling_steps: 500
 99 |   log_interval: 20
100 |   save_model_epochs: 5
101 |   save_model_steps: 500
102 |   work_dir: output/debug
103 |   online_metric: false
104 |   eval_metric_step: 2000
105 |   online_metric_dir: metric_helper
106 | 
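The mlp_acts triple above names the activations for the three stages of the glumbconv FFN; the empty third entry parses as null and means "no activation". A small sketch using build_act from diffusion/model/act.py (shown later in this dump), which maps None to no activation:

    from diffusion.model.act import build_act

    acts = ["silu", "silu", None]                 # as in mlp_acts above
    modules = [build_act(name) for name in acts]  # [SiLU(), SiLU(), None]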


--------------------------------------------------------------------------------
/configs/sana_config/512ms/Sana_1600M_img512.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/data_public/dir1]
  3 |   image_size: 512
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
  7 |   external_clipscore_suffixes:
  8 |     - _InternVL2-26B_clip_score
  9 |     - _VILA1-5-13B_clip_score
 10 |     - _prompt_clip_score
 11 |   clip_thr_temperature: 0.1
 12 |   clip_thr: 25.0
 13 |   load_text_feat: false
 14 |   load_vae_feat: false
 15 |   transform: default_train
 16 |   type: SanaWebDatasetMS
 17 |   sort_dataset: false
 18 | # model config
 19 | model:
 20 |   model: SanaMS_1600M_P1_D20
 21 |   image_size: 512
 22 |   mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
 23 |   fp32_attention: true
 24 |   load_from:
 25 |   resume_from:
 26 |   aspect_ratio_type: ASPECT_RATIO_512
 27 |   multi_scale: true
 28 |   attn_type: linear
 29 |   ffn_type: glumbconv
 30 |   mlp_acts:
 31 |     - silu
 32 |     - silu
 33 |     -
 34 |   mlp_ratio: 2.5
 35 |   use_pe: false
 36 |   qk_norm: false
 37 |   class_dropout_prob: 0.1
 38 |   # PAG
 39 |   pag_applied_layers:
 40 |     - 8
 41 | # VAE setting
 42 | vae:
 43 |   vae_type: AutoencoderDC
 44 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 45 |   scale_factor: 0.41407
 46 |   vae_latent_dim: 32
 47 |   vae_downsample_rate: 32
 48 |   sample_posterior: true
 49 | # text encoder
 50 | text_encoder:
 51 |   text_encoder_name: gemma-2-2b-it
 52 |   y_norm: true
 53 |   y_norm_scale_factor: 0.01
 54 |   model_max_length: 300
 55 |   # CHI
 56 |   chi_prompt:
 57 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 58 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 59 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 60 |     - 'Here are examples of how to transform or refine prompts:'
 61 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 62 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 63 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 64 |     - 'User Prompt: '
 65 | # Sana schedule Flow
 66 | scheduler:
 67 |   predict_flow_v: true
 68 |   noise_schedule: linear_flow
 69 |   pred_sigma: false
 70 |   flow_shift: 3.0
 71 |   # logit-normal timestep
 72 |   weighting_scheme: logit_normal
 73 |   logit_mean: 0.0
 74 |   logit_std: 1.0
 75 |   vis_sampler: flow_dpm-solver
 76 | # training setting
 77 | train:
 78 |   num_workers: 10
 79 |   seed: 1
 80 |   train_batch_size: 64
 81 |   num_epochs: 100
 82 |   gradient_accumulation_steps: 1
 83 |   grad_checkpointing: true
 84 |   gradient_clip: 0.1
 85 |   optimizer:
 86 |     betas:
 87 |       - 0.9
 88 |       - 0.999
 89 |       - 0.9999
 90 |     eps:
 91 |       - 1.0e-30
 92 |       - 1.0e-16
 93 |     lr: 0.0001
 94 |     type: CAMEWrapper
 95 |     weight_decay: 0.0
 96 |   lr_schedule: constant
 97 |   lr_schedule_args:
 98 |     num_warmup_steps: 2000
 99 |   local_save_vis: true # whether to save log images locally
100 |   visualize: true
101 |   eval_sampling_steps: 500
102 |   log_interval: 20
103 |   save_model_epochs: 5
104 |   save_model_steps: 500
105 |   work_dir: output/debug
106 |   online_metric: false
107 |   eval_metric_step: 2000
108 |   online_metric_dir: metric_helper
109 | 


--------------------------------------------------------------------------------
/configs/sana_config/512ms/Sana_600M_img512.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/data_public/dir1]
  3 |   image_size: 512
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
  7 |   external_clipscore_suffixes:
  8 |     - _InternVL2-26B_clip_score
  9 |     - _VILA1-5-13B_clip_score
 10 |     - _prompt_clip_score
 11 |   clip_thr_temperature: 0.1
 12 |   clip_thr: 25.0
 13 |   load_text_feat: false
 14 |   load_vae_feat: false
 15 |   transform: default_train
 16 |   type: SanaWebDatasetMS
 17 |   sort_dataset: false
 18 | # model config
 19 | model:
 20 |   model: SanaMS_600M_P1_D28
 21 |   image_size: 512
 22 |   mixed_precision: fp16
 23 |   fp32_attention: true
 24 |   load_from:
 25 |   resume_from:
 26 |   aspect_ratio_type: ASPECT_RATIO_512
 27 |   multi_scale: true
 28 |   #pe_interpolation: 1.
 29 |   attn_type: linear
 30 |   linear_head_dim: 32
 31 |   ffn_type: glumbconv
 32 |   mlp_acts:
 33 |     - silu
 34 |     - silu
 35 |     - null
 36 |   mlp_ratio: 2.5
 37 |   use_pe: false
 38 |   qk_norm: false
 39 |   class_dropout_prob: 0.1
 40 | # VAE setting
 41 | vae:
 42 |   vae_type: AutoencoderDC
 43 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 44 |   scale_factor: 0.41407
 45 |   vae_latent_dim: 32
 46 |   vae_downsample_rate: 32
 47 |   sample_posterior: true
 48 | # text encoder
 49 | text_encoder:
 50 |   text_encoder_name: gemma-2-2b-it
 51 |   y_norm: true
 52 |   y_norm_scale_factor: 0.01
 53 |   model_max_length: 300
 54 |   # CHI
 55 |   chi_prompt:
 56 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 57 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 58 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 59 |     - 'Here are examples of how to transform or refine prompts:'
 60 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 61 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 62 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 63 |     - 'User Prompt: '
 64 | # Sana schedule Flow
 65 | scheduler:
 66 |   predict_flow_v: true
 67 |   noise_schedule: linear_flow
 68 |   pred_sigma: false
 69 |   flow_shift: 3.0
 70 |   # logit-normal timestep
 71 |   weighting_scheme: logit_normal
 72 |   logit_mean: 0.0
 73 |   logit_std: 1.0
 74 |   vis_sampler: flow_dpm-solver
 75 | # training setting
 76 | train:
 77 |   num_workers: 10
 78 |   seed: 1
 79 |   train_batch_size: 128
 80 |   num_epochs: 100
 81 |   gradient_accumulation_steps: 1
 82 |   grad_checkpointing: true
 83 |   gradient_clip: 0.1
 84 |   optimizer:
 85 |     betas:
 86 |       - 0.9
 87 |       - 0.999
 88 |       - 0.9999
 89 |     eps:
 90 |       - 1.0e-30
 91 |       - 1.0e-16
 92 |     lr: 0.0001
 93 |     type: CAMEWrapper
 94 |     weight_decay: 0.0
 95 |   lr_schedule: constant
 96 |   lr_schedule_args:
 97 |     num_warmup_steps: 2000
 98 |   local_save_vis: true # whether to save log images locally
 99 |   visualize: true
100 |   eval_sampling_steps: 500
101 |   log_interval: 20
102 |   save_model_epochs: 5
103 |   save_model_steps: 500
104 |   work_dir: output/debug
105 |   online_metric: false
106 |   eval_metric_step: 2000
107 |   online_metric_dir: metric_helper
108 | 


--------------------------------------------------------------------------------
/configs/sana_config/512ms/ci_Sana_600M_img512.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/data_public/vaef32c32_v2_512/dir1]
  3 |   image_size: 512
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
  7 |   external_clipscore_suffixes:
  8 |     - _InternVL2-26B_clip_score
  9 |     - _VILA1-5-13B_clip_score
 10 |     - _prompt_clip_score
 11 |   clip_thr_temperature: 0.1
 12 |   clip_thr: 25.0
 13 |   load_text_feat: false
 14 |   load_vae_feat: false
 15 |   transform: default_train
 16 |   type: SanaWebDatasetMS
 17 |   aspect_ratio_type: ASPECT_RATIO_512
 18 |   sort_dataset: false
 19 | # model config
 20 | model:
 21 |   model: SanaMS_600M_P1_D28
 22 |   image_size: 512
 23 |   mixed_precision: fp16
 24 |   fp32_attention: true
 25 |   load_from:
 26 |   resume_from:
 27 |   multi_scale: true
 28 |   #pe_interpolation: 1.
 29 |   attn_type: linear
 30 |   linear_head_dim: 32
 31 |   ffn_type: glumbconv
 32 |   mlp_acts:
 33 |     - silu
 34 |     - silu
 35 |     - null
 36 |   mlp_ratio: 2.5
 37 |   use_pe: false
 38 |   qk_norm: false
 39 |   class_dropout_prob: 0.1
 40 | # VAE setting
 41 | vae:
 42 |   vae_type: AutoencoderDC
 43 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 44 |   scale_factor: 0.41407
 45 |   vae_latent_dim: 32
 46 |   vae_downsample_rate: 32
 47 |   sample_posterior: true
 48 | # text encoder
 49 | text_encoder:
 50 |   text_encoder_name: gemma-2-2b-it
 51 |   y_norm: true
 52 |   y_norm_scale_factor: 0.01
 53 |   model_max_length: 300
 54 |   # CHI
 55 |   chi_prompt:
 56 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 57 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 58 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 59 |     - 'Here are examples of how to transform or refine prompts:'
 60 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 61 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 62 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 63 |     - 'User Prompt: '
 64 | # Sana schedule Flow
 65 | scheduler:
 66 |   predict_flow_v: true
 67 |   noise_schedule: linear_flow
 68 |   pred_sigma: false
 69 |   flow_shift: 1.0
 70 |   # logit-normal timestep
 71 |   weighting_scheme: logit_normal
 72 |   logit_mean: 0.0
 73 |   logit_std: 1.0
 74 |   vis_sampler: flow_dpm-solver
 75 | # training setting
 76 | train:
 77 |   num_workers: 10
 78 |   seed: 1
 79 |   train_batch_size: 64
 80 |   num_epochs: 1
 81 |   gradient_accumulation_steps: 1
 82 |   grad_checkpointing: true
 83 |   gradient_clip: 0.1
 84 |   optimizer:
 85 |     betas:
 86 |       - 0.9
 87 |       - 0.999
 88 |       - 0.9999
 89 |     eps:
 90 |       - 1.0e-30
 91 |       - 1.0e-16
 92 |     lr: 0.0001
 93 |     type: CAMEWrapper
 94 |     weight_decay: 0.0
 95 |   lr_schedule: constant
 96 |   lr_schedule_args:
 97 |     num_warmup_steps: 2000
 98 |   local_save_vis: true # whether to save log images locally
 99 |   visualize: true
100 |   eval_sampling_steps: 500
101 |   log_interval: 20
102 |   save_model_epochs: 5
103 |   save_model_steps: 500
104 |   work_dir: output/debug
105 |   online_metric: false
106 |   eval_metric_step: 2000
107 |   online_metric_dir: metric_helper
108 | 


--------------------------------------------------------------------------------
/configs/sana_config/512ms/sample_dataset.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [asset/example_data]
  3 |   image_size: 512
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] # json files
  7 |   external_clipscore_suffixes: # json files
  8 |     - _InternVL2-26B_clip_score
  9 |     - _VILA1-5-13B_clip_score
 10 |     - _prompt_clip_score
 11 |   clip_thr_temperature: 0.1
 12 |   clip_thr: 25.0
 13 |   load_text_feat: false
 14 |   load_vae_feat: false
 15 |   transform: default_train
 16 |   type: SanaImgDataset
 17 |   sort_dataset: false
 18 | # model config
 19 | model:
 20 |   model: SanaMS_600M_P1_D28
 21 |   image_size: 512
 22 |   mixed_precision: fp16
 23 |   fp32_attention: true
 24 |   load_from:
 25 |   resume_from:
 26 |   aspect_ratio_type: ASPECT_RATIO_512
 27 |   multi_scale: false
 28 |   #pe_interpolation: 1.
 29 |   attn_type: linear
 30 |   linear_head_dim: 32
 31 |   ffn_type: glumbconv
 32 |   mlp_acts:
 33 |     - silu
 34 |     - silu
 35 |     - null
 36 |   mlp_ratio: 2.5
 37 |   use_pe: false
 38 |   qk_norm: false
 39 |   class_dropout_prob: 0.1
 40 | # VAE setting
 41 | vae:
 42 |   vae_type: AutoencoderDC
 43 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 44 |   scale_factor: 0.41407
 45 |   vae_latent_dim: 32
 46 |   vae_downsample_rate: 32
 47 |   sample_posterior: true
 48 | # text encoder
 49 | text_encoder:
 50 |   text_encoder_name: gemma-2-2b-it
 51 |   y_norm: true
 52 |   y_norm_scale_factor: 0.01
 53 |   model_max_length: 300
 54 |   # CHI
 55 |   chi_prompt:
 56 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 57 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 58 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 59 |     - 'Here are examples of how to transform or refine prompts:'
 60 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 61 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 62 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 63 |     - 'User Prompt: '
 64 | # Sana schedule Flow
 65 | scheduler:
 66 |   predict_flow_v: true
 67 |   noise_schedule: linear_flow
 68 |   pred_sigma: false
 69 |   flow_shift: 1.0
 70 |   # logit-normal timestep
 71 |   weighting_scheme: logit_normal
 72 |   logit_mean: 0.0
 73 |   logit_std: 1.0
 74 |   vis_sampler: flow_dpm-solver
 75 | # training setting
 76 | train:
 77 |   num_workers: 10
 78 |   seed: 1
 79 |   train_batch_size: 128
 80 |   num_epochs: 100
 81 |   gradient_accumulation_steps: 1
 82 |   grad_checkpointing: true
 83 |   gradient_clip: 0.1
 84 |   optimizer:
 85 |     betas:
 86 |       - 0.9
 87 |       - 0.999
 88 |       - 0.9999
 89 |     eps:
 90 |       - 1.0e-30
 91 |       - 1.0e-16
 92 |     lr: 0.0001
 93 |     type: CAMEWrapper
 94 |     weight_decay: 0.0
 95 |   lr_schedule: constant
 96 |   lr_schedule_args:
 97 |     num_warmup_steps: 2000
 98 |   local_save_vis: true # whether to save log images locally
 99 |   visualize: true
100 |   eval_sampling_steps: 500
101 |   log_interval: 20
102 |   save_model_epochs: 5
103 |   save_model_steps: 500
104 |   work_dir: output/debug
105 |   online_metric: false
106 |   eval_metric_step: 2000
107 |   online_metric_dir: metric_helper
108 | 
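A sketch of the naming convention implied by the suffix lists above (naming only; the dataset classes do the actual lookup): each external caption suffix and each clipscore suffix is appended to the image's stem to locate a sidecar JSON file, with the empty caption suffix denoting the image's base prompt:

    stem = "asset/example_data/00000000"
    caption_candidates = [f"{stem}{sfx}.json"
                          for sfx in ("_InternVL2-26B", "_VILA1-5-13B")]
    score_candidates = [f"{stem}{sfx}.json"
                        for sfx in ("_InternVL2-26B_clip_score",
                                    "_VILA1-5-13B_clip_score",
                                    "_prompt_clip_score")]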


--------------------------------------------------------------------------------
/configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/data_public/controlnet_data]
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: []
  7 |   external_clipscore_suffixes: []
  8 |   clip_thr_temperature: 0.1
  9 |   clip_thr: 25.0
 10 |   load_text_feat: false
 11 |   load_vae_feat: false
 12 |   transform: default_train
 13 |   type: SanaWebDatasetMSControl
 14 |   sort_dataset: false
 15 | # model config
 16 | model:
 17 |   model: SanaMSControlNet_1600M_P1_D20
 18 |   image_size: 1024
 19 |   mixed_precision: bf16
 20 |   fp32_attention: true
 21 |   load_from: hf://Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoint/Sana_1600M_1024px_BF16.pth
 22 |   resume_from:
 23 |   aspect_ratio_type: ASPECT_RATIO_1024
 24 |   multi_scale: true
 25 |   attn_type: linear
 26 |   ffn_type: glumbconv
 27 |   mlp_acts:
 28 |     - silu
 29 |     - silu
 30 |     -
 31 |   mlp_ratio: 2.5
 32 |   use_pe: false
 33 |   qk_norm: false
 34 |   class_dropout_prob: 0.1
 35 | # VAE setting
 36 | vae:
 37 |   vae_type: AutoencoderDC
 38 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 39 |   scale_factor: 0.41407
 40 |   vae_latent_dim: 32
 41 |   vae_downsample_rate: 32
 42 |   sample_posterior: true
 43 |   weight_dtype: bf16
 44 | # text encoder
 45 | text_encoder:
 46 |   text_encoder_name: gemma-2-2b-it
 47 |   y_norm: true
 48 |   y_norm_scale_factor: 0.01
 49 |   model_max_length: 300
 50 |   # CHI
 51 |   chi_prompt:
 52 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 53 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 54 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 55 |     - 'Here are examples of how to transform or refine prompts:'
 56 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 57 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 58 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 59 |     - 'User Prompt: '
 60 | # Sana schedule Flow
 61 | scheduler:
 62 |   predict_flow_v: true
 63 |   noise_schedule: linear_flow
 64 |   pred_sigma: false
 65 |   flow_shift: 3.0
 66 |   # logit-normal timestep
 67 |   weighting_scheme: logit_normal
 68 |   logit_mean: 0.0
 69 |   logit_std: 1.0
 70 |   vis_sampler: flow_dpm-solver
 71 | # training setting
 72 | train:
 73 |   num_workers: 10
 74 |   seed: 1
 75 |   train_batch_size: 16
 76 |   num_epochs: 100
 77 |   gradient_accumulation_steps: 1
 78 |   grad_checkpointing: true
 79 |   gradient_clip: 0.1
 80 |   optimizer:
 81 |     betas:
 82 |       - 0.9
 83 |       - 0.999
 84 |       - 0.9999
 85 |     eps:
 86 |       - 1.0e-30
 87 |       - 1.0e-16
 88 |     lr: 0.0001
 89 |     type: CAMEWrapper
 90 |     weight_decay: 0.0
 91 |   lr_schedule: constant
 92 |   lr_schedule_args:
 93 |     num_warmup_steps: 30
 94 |   local_save_vis: true # whether to save log images locally
 95 |   visualize: true
 96 |   eval_sampling_steps: 50
 97 |   log_interval: 20
 98 |   save_model_epochs: 5
 99 |   save_model_steps: 500
100 |   work_dir: output/debug
101 |   online_metric: false
102 |   eval_metric_step: 2000
103 |   online_metric_dir: metric_helper
104 | controlnet:
105 |   control_signal_type: "scribble"
106 | 


--------------------------------------------------------------------------------
/configs/sana_controlnet_config/Sana_600M_img1024_controlnet.yaml:
--------------------------------------------------------------------------------
  1 | data:
  2 |   data_dir: [data/data_public/controlnet_data]
  3 |   image_size: 1024
  4 |   caption_proportion:
  5 |     prompt: 1
  6 |   external_caption_suffixes: []
  7 |   external_clipscore_suffixes: []
  8 |   clip_thr_temperature: 0.1
  9 |   clip_thr: 25.0
 10 |   load_text_feat: false
 11 |   load_vae_feat: false
 12 |   transform: default_train
 13 |   type: SanaWebDatasetMSControl
 14 |   sort_dataset: false
 15 | # model config
 16 | model:
 17 |   model: SanaMSControlNet_600M_P1_D28
 18 |   image_size: 1024
 19 |   mixed_precision: fp16
 20 |   fp32_attention: true
 21 |   load_from: hf://Efficient-Large-Model/Sana_600M_1024px/checkpoint/Sana_600M_1024px.pth
 22 |   resume_from:
 23 |   aspect_ratio_type: ASPECT_RATIO_1024
 24 |   multi_scale: true
 25 |   attn_type: linear
 26 |   ffn_type: glumbconv
 27 |   mlp_acts:
 28 |     - silu
 29 |     - silu
 30 |     -
 31 |   mlp_ratio: 2.5
 32 |   use_pe: false
 33 |   qk_norm: false
 34 |   class_dropout_prob: 0.1
 35 | # VAE setting
 36 | vae:
 37 |   vae_type: AutoencoderDC
 38 |   vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
 39 |   scale_factor: 0.41407
 40 |   vae_latent_dim: 32
 41 |   vae_downsample_rate: 32
 42 |   sample_posterior: true
 43 | # text encoder
 44 | text_encoder:
 45 |   text_encoder_name: gemma-2-2b-it
 46 |   y_norm: true
 47 |   y_norm_scale_factor: 0.01
 48 |   model_max_length: 300
 49 |   # CHI
 50 |   chi_prompt:
 51 |     - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
 52 |     - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
 53 |     - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
 54 |     - 'Here are examples of how to transform or refine prompts:'
 55 |     - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
 56 |     - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
 57 |     - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
 58 |     - 'User Prompt: '
 59 | # Sana schedule Flow
 60 | scheduler:
 61 |   predict_flow_v: true
 62 |   noise_schedule: linear_flow
 63 |   pred_sigma: false
 64 |   flow_shift: 4.0
 65 |   # logit-normal timestep
 66 |   weighting_scheme: logit_normal
 67 |   logit_mean: 0.0
 68 |   logit_std: 1.0
 69 |   vis_sampler: flow_dpm-solver
 70 | # training setting
 71 | train:
 72 |   num_workers: 10
 73 |   seed: 1
 74 |   train_batch_size: 16
 75 |   num_epochs: 100
 76 |   gradient_accumulation_steps: 1
 77 |   grad_checkpointing: true
 78 |   gradient_clip: 0.1
 79 |   optimizer:
 80 |     betas:
 81 |       - 0.9
 82 |       - 0.999
 83 |       - 0.9999
 84 |     eps:
 85 |       - 1.0e-30
 86 |       - 1.0e-16
 87 |     lr: 0.0001
 88 |     type: CAMEWrapper
 89 |     weight_decay: 0.0
 90 |   lr_schedule: constant
 91 |   lr_schedule_args:
 92 |     num_warmup_steps: 30
 93 |   local_save_vis: true # whether to save log images locally
 94 |   visualize: true
 95 |   eval_sampling_steps: 500
 96 |   log_interval: 20
 97 |   save_model_epochs: 5
 98 |   save_model_steps: 500
 99 |   work_dir: output/debug
100 |   online_metric: false
101 |   eval_metric_step: 2000
102 |   online_metric_dir: metric_helper
103 | controlnet:
104 |   control_signal_type: "scribble"
105 | 


--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
 1 | __version__ = "0.2.1.dev0"
 2 | 
 3 | from diffusion.scheduler.dpm_solver import DPMS
 4 | from diffusion.scheduler.flow_euler_sampler import FlowEuler, LTXFlowEuler
 5 | from diffusion.scheduler.iddpm import Scheduler
 6 | from diffusion.scheduler.longlive_flow_euler_sampler import LongLiveFlowEuler
 7 | from diffusion.scheduler.sa_sampler import SASolverSampler
 8 | from diffusion.scheduler.scm_scheduler import SCMScheduler
 9 | from diffusion.scheduler.trigflow_scheduler import TrigFlowScheduler
10 | 


--------------------------------------------------------------------------------
/diffusion/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .transforms import get_transform
3 | 


--------------------------------------------------------------------------------
/diffusion/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .sana_data import SanaImgDataset, SanaWebDataset
2 | from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS
3 | from .utils import *
4 | from .video.sana_video_data import DistributePromptsDataset, SanaZipDataset
5 | 


--------------------------------------------------------------------------------
/diffusion/data/wids/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
 2 | # This file is part of the WebDataset library.
 3 | # See the LICENSE file for licensing terms (BSD-style).
 4 | #
 5 | # flake8: noqa
 6 | 
 7 | from .wids import (
 8 |     ChunkedSampler,
 9 |     DistributedChunkedSampler,
10 |     DistributedLocalSampler,
11 |     DistributedRangedSampler,
12 |     ShardedSampler,
13 |     ShardListDataset,
14 |     ShardListDatasetMulti,
15 |     lru_json_load,
16 | )
17 | 


--------------------------------------------------------------------------------
/diffusion/data/wids/wids_lru.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18 | from collections import OrderedDict
19 | 
20 | 
21 | class LRUCache:
22 |     def __init__(self, capacity: int, release_handler=None):
23 |         """Initialize a new LRU cache with the given capacity."""
24 |         self.capacity = capacity
25 |         self.cache = OrderedDict()
26 |         self.release_handler = release_handler
27 | 
28 |     def __getitem__(self, key):
29 |         """Return the value associated with the given key, or None."""
30 |         if key not in self.cache:
31 |             return None
32 |         self.cache.move_to_end(key)
33 |         return self.cache[key]
34 | 
35 |     def __setitem__(self, key, value):
36 |         """Associate the given value with the given key."""
37 |         if key in self.cache:
38 |             self.cache.move_to_end(key)
39 |         self.cache[key] = value
40 |         if len(self.cache) > self.capacity:
41 |             key, value = self.cache.popitem(last=False)
42 |             if self.release_handler is not None:
43 |                 self.release_handler(key, value)
44 | 
45 |     def __delitem__(self, key):
46 |         """Remove the given key from the cache."""
47 |         if key in self.cache:
48 |             if self.release_handler is not None:
49 |                 value = self.cache[key]
50 |                 self.release_handler(key, value)
51 |             del self.cache[key]
52 | 
53 |     def __len__(self):
54 |         """Return the number of entries in the cache."""
55 |         return len(self.cache)
56 | 
57 |     def __contains__(self, key):
58 |         """Return whether the cache contains the given key."""
59 |         return key in self.cache
60 | 
61 |     def items(self):
62 |         """Return an iterator over the keys of the cache."""
63 |         return self.cache.items()
64 | 
65 |     def keys(self):
66 |         """Return an iterator over the keys of the cache."""
67 |         return self.cache.keys()
68 | 
69 |     def values(self):
70 |         """Return an iterator over the values of the cache."""
71 |         return self.cache.values()
72 | 
73 |     def clear(self):
74 |         for key in list(self.keys()):
75 |             value = self.cache[key]
76 |             if self.release_handler is not None:
77 |                 self.release_handler(key, value)
78 |             del self[key]
79 | 
80 |     def __del__(self):
81 |         self.clear()
82 | 
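A usage sketch for LRUCache: the release_handler fires whenever an entry is evicted on capacity overflow, deleted, or cleared:

    def on_release(key, value):
        print(f"releasing {key}")

    cache = LRUCache(capacity=2, release_handler=on_release)
    cache["a"] = 1
    cache["b"] = 2
    _ = cache["a"]   # touch "a" so it becomes most recently used
    cache["c"] = 3   # evicts "b" (least recently used) and prints "releasing b"
    assert "b" not in cache and cache["a"] == 1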


--------------------------------------------------------------------------------
/diffusion/data/wids/wids_tar.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18 | import io
19 | import os
20 | import os.path
21 | import pickle
22 | import re
23 | import tarfile
24 | 
25 | import numpy as np
26 | 
27 | 
28 | def find_index_file(file):
29 |     prefix, last_ext = os.path.splitext(file)
30 |     if re.match("._[0-9]+_$", last_ext):
31 |         return prefix + ".index"
32 |     else:
33 |         return file + ".index"
34 | 
35 | 
36 | class TarFileReader:
37 |     def __init__(self, file, index_file=find_index_file, verbose=True):
38 |         self.verbose = verbose
39 |         if callable(index_file):
40 |             index_file = index_file(file)
41 |         self.index_file = index_file
42 | 
43 |         # Open the tar file and keep it open
44 |         if isinstance(file, str):
45 |             self.tar_file = tarfile.open(file, "r")
46 |         else:
47 |             self.tar_file = tarfile.open(fileobj=file, mode="r")
48 | 
49 |         # Create the index
50 |         self._create_tar_index()
51 | 
52 |     def _create_tar_index(self):
53 |         if self.index_file is not None and os.path.exists(self.index_file):
54 |             if self.verbose:
55 |                 print("Loading tar index from", self.index_file)
56 |             with open(self.index_file, "rb") as stream:
57 |                 self.fnames, self.index = pickle.load(stream)
58 |             return
59 |         # Create an empty list for the index
60 |         self.fnames = []
61 |         self.index = []
62 | 
63 |         if self.verbose:
64 |             print("Creating tar index for", self.tar_file.name, "at", self.index_file)
65 |         # Iterate over the members of the tar file
66 |         for member in self.tar_file:
67 |             # If the member is a file, add it to the index
68 |             if member.isfile():
69 |                 # Get the file's offset
70 |                 offset = self.tar_file.fileobj.tell()
71 |                 self.fnames.append(member.name)
72 |                 self.index.append([offset, member.size])
73 |         if self.verbose:
74 |             print("Done creating tar index for", self.tar_file.name, "at", self.index_file)
75 |         self.index = np.array(self.index)
76 |         if self.index_file is not None:
77 |             if os.path.exists(self.index_file + ".temp"):
78 |                 os.unlink(self.index_file + ".temp")
79 |             with open(self.index_file + ".temp", "wb") as stream:
80 |                 pickle.dump((self.fnames, self.index), stream)
81 |             os.rename(self.index_file + ".temp", self.index_file)
82 | 
83 |     def names(self):
84 |         return self.fnames
85 | 
86 |     def __len__(self):
87 |         return len(self.index)
88 | 
89 |     def get_file(self, i):
90 |         name = self.fnames[i]
91 |         offset, size = self.index[i]
92 |         self.tar_file.fileobj.seek(offset)
93 |         file_bytes = self.tar_file.fileobj.read(size)
94 |         return name, io.BytesIO(file_bytes)
95 | 
96 |     def close(self):
97 |         # Close the tar file
98 |         self.tar_file.close()
99 | 
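A usage sketch for TarFileReader: the first open scans the archive and writes a pickled .index file next to it (located via find_index_file); subsequent opens reuse that index. The shard path here is hypothetical:

    reader = TarFileReader("shard-000000.tar")
    print(len(reader), reader.names()[:3])
    name, stream = reader.get_file(0)  # (member name, io.BytesIO of its bytes)
    data = stream.read()
    reader.close()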


--------------------------------------------------------------------------------
/diffusion/guiders/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptive_projected_guidance import AdaptiveProjectedGuidance
2 | 
3 | __all__ = ["AdaptiveProjectedGuidance"]
4 | 


--------------------------------------------------------------------------------
/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/__init__.py


--------------------------------------------------------------------------------
/diffusion/model/act.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import copy
18 | from typing import Any, Optional
19 | import torch.nn as nn
20 | 
21 | __all__ = ["build_act", "get_act_name"]
22 | 
23 | # register activation function here
24 | #   name: module, kwargs with default values
25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, Any]]] = {
26 |     "relu": (nn.ReLU, {"inplace": True}),
27 |     "relu6": (nn.ReLU6, {"inplace": True}),
28 |     "hswish": (nn.Hardswish, {"inplace": True}),
29 |     "hsigmoid": (nn.Hardsigmoid, {"inplace": True}),
30 |     "swish": (nn.SiLU, {"inplace": True}),
31 |     "silu": (nn.SiLU, {"inplace": True}),
32 |     "tanh": (nn.Tanh, {}),
33 |     "sigmoid": (nn.Sigmoid, {}),
34 |     "gelu": (nn.GELU, {"approximate": "tanh"}),
35 |     "mish": (nn.Mish, {"inplace": True}),
36 |     "identity": (nn.Identity, {}),
37 | }
38 | 
39 | 
40 | def build_act(name: Optional[str], **kwargs) -> Optional[nn.Module]:
41 |     if name in REGISTERED_ACT_DICT:
42 |         act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name])
43 |         for key in default_args:
44 |             if key in kwargs:
45 |                 default_args[key] = kwargs[key]
46 |         return act_cls(**default_args)
47 |     elif name is None or name.lower() == "none":
48 |         return None
49 |     else:
50 |         raise ValueError(f"do not support: {name}")
51 | 
52 | 
53 | def get_act_name(act: Optional[nn.Module]) -> Optional[str]:
54 |     if act is None:
55 |         return None
56 |     module2name = {}
57 |     for key, config in REGISTERED_ACT_DICT.items():
58 |         module2name[config[0].__name__] = key
59 |     return module2name.get(type(act).__name__, "unknown")
60 | 
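A usage sketch for build_act / get_act_name:

    act = build_act("gelu")                  # nn.GELU(approximate="tanh")
    silu = build_act("silu", inplace=False)  # override a registered default kwarg
    assert get_act_name(act) == "gelu"
    assert build_act(None) is None           # None / "none" disables the activation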


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/dc_ae/efficientvit/__init__.py


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/dc_ae/efficientvit/apps/__init__.py


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/setup.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import time
  3 | from copy import deepcopy
  4 | from typing import Optional
  5 | 
  6 | import torch.backends.cudnn
  7 | import torch.distributed
  8 | import torch.nn as nn
  9 | 
 10 | from ..apps.utils import (
 11 |     dist_init,
 12 |     dump_config,
 13 |     get_dist_local_rank,
 14 |     get_dist_rank,
 15 |     get_dist_size,
 16 |     init_modules,
 17 |     is_master,
 18 |     load_config,
 19 |     partial_update_config,
 20 |     zero_last_gamma,
 21 | )
 22 | from ..models.utils import build_kwargs_from_config, load_state_dict_from_file
 23 | 
 24 | __all__ = [
 25 |     "save_exp_config",
 26 |     "setup_dist_env",
 27 |     "setup_seed",
 28 |     "setup_exp_config",
 29 |     "init_model",
 30 | ]
 31 | 
 32 | 
 33 | def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
 34 |     if not is_master():
 35 |         return
 36 |     dump_config(exp_config, os.path.join(path, name))
 37 | 
 38 | 
 39 | def setup_dist_env(gpu: Optional[str] = None) -> None:
 40 |     if gpu is not None:
 41 |         os.environ["CUDA_VISIBLE_DEVICES"] = gpu
 42 |     if not torch.distributed.is_initialized():
 43 |         dist_init()
 44 |     torch.backends.cudnn.benchmark = True
 45 |     torch.cuda.set_device(get_dist_local_rank())
 46 | 
 47 | 
 48 | def setup_seed(manual_seed: int, resume: bool) -> None:
 49 |     if resume:
 50 |         manual_seed = int(time.time())  # re-randomize on resume instead of replaying the old seed
 51 |     manual_seed = get_dist_rank() + manual_seed  # offset by rank so processes draw different streams
 52 |     torch.manual_seed(manual_seed)
 53 |     torch.cuda.manual_seed_all(manual_seed)
 54 | 
 55 | 
 56 | def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
 57 |     # load config
 58 |     if not os.path.isfile(config_path):
 59 |         raise ValueError(config_path)
 60 | 
 61 |     fpaths = [config_path]
 62 |     if recursive:
 63 |         extension = os.path.splitext(config_path)[1]
 64 |         while os.path.dirname(config_path) != config_path:
 65 |             config_path = os.path.dirname(config_path)
 66 |             fpath = os.path.join(config_path, "default" + extension)
 67 |             if os.path.isfile(fpath):
 68 |                 fpaths.append(fpath)
 69 |         fpaths = fpaths[::-1]
 70 | 
 71 |     default_config = load_config(fpaths[0])
 72 |     exp_config = deepcopy(default_config)
 73 |     for fpath in fpaths[1:]:
 74 |         partial_update_config(exp_config, load_config(fpath))
 75 |     # update config via args
 76 |     if opt_args is not None:
 77 |         partial_update_config(exp_config, opt_args)
 78 | 
 79 |     return exp_config
 80 | 
 81 | 
 82 | def init_model(
 83 |     network: nn.Module,
 84 |     init_from: Optional[str] = None,
 85 |     backbone_init_from: Optional[str] = None,
 86 |     rand_init="trunc_normal",
 87 |     last_gamma=None,
 88 | ) -> None:
 89 |     # initialization
 90 |     init_modules(network, init_type=rand_init)
 91 |     # zero gamma of last bn in each block
 92 |     if last_gamma is not None:
 93 |         zero_last_gamma(network, last_gamma)
 94 | 
 95 |     # load weight
 96 |     if init_from is not None and os.path.isfile(init_from):
 97 |         network.load_state_dict(load_state_dict_from_file(init_from))
 98 |         print(f"Loaded init from {init_from}")
 99 |     elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
100 |         network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
101 |         print(f"Loaded backbone init from {backbone_init_from}")
102 |     else:
103 |         print(f"Random init ({rand_init}) with last gamma {last_gamma}")
104 | 
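The recursive config resolution in `setup_exp_config` is easiest to see on a toy example; `partial_update` below is a simplified stand-in for `partial_update_config`, and the config contents are invented:

```python
# Toy walkthrough of setup_exp_config's merge order (self-contained sketch).
def partial_update(config: dict, partial: dict) -> dict:
    # simplified stand-in for partial_update_config
    for k, v in partial.items():
        if isinstance(v, dict) and isinstance(config.get(k), dict):
            partial_update(config[k], v)
        else:
            config[k] = v
    return config

# With recursive=True the directory walk collects every sibling default.yaml
# above the requested file; fpaths[::-1] puts the outermost default first.
loaded = [
    {"lr": 1e-4, "bs": 32},  # configs/default.yaml      (outermost default)
    {"bs": 64},              # configs/exp/default.yaml
    {"lr": 2e-4},            # configs/exp/run1.yaml     (the requested file)
]
exp_config = dict(loaded[0])
for cfg in loaded[1:]:
    partial_update(exp_config, cfg)
print(exp_config)  # {'lr': 0.0002, 'bs': 64}
```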


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .run_config import *
2 | 


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from .dist import *
 2 | from .ema import *
 3 | 
 4 | # from .export import *
 5 | from .image import *
 6 | from .init import *
 7 | from .lr import *
 8 | from .metric import *
 9 | from .misc import *
10 | from .opt import *
11 | 


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/dist.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import os
18 | from typing import Union
19 | 
20 | import torch
21 | import torch.distributed
22 | 
23 | from ...models.utils.list import list_mean, list_sum
24 | 
25 | __all__ = [
26 |     "dist_init",
27 |     "is_dist_initialized",
28 |     "get_dist_rank",
29 |     "get_dist_size",
30 |     "is_master",
31 |     "dist_barrier",
32 |     "get_dist_local_rank",
33 |     "sync_tensor",
34 | ]
35 | 
36 | 
37 | def dist_init() -> None:
38 |     if is_dist_initialized():
39 |         return
40 |     try:
41 |         torch.distributed.init_process_group(backend="nccl")
42 |         assert torch.distributed.is_initialized()
43 |     except Exception:
44 |         os.environ["RANK"] = "0"
45 |         os.environ["WORLD_SIZE"] = "1"
46 |         os.environ["LOCAL_RANK"] = "0"
47 |         print("warning: torch.distributed not initialized; falling back to single-process mode")
48 | 
49 | 
50 | def is_dist_initialized() -> bool:
51 |     return torch.distributed.is_initialized()
52 | 
53 | 
54 | def get_dist_rank() -> int:
55 |     return int(os.environ["RANK"])
56 | 
57 | 
58 | def get_dist_size() -> int:
59 |     return int(os.environ["WORLD_SIZE"])
60 | 
61 | 
62 | def is_master() -> bool:
63 |     return get_dist_rank() == 0
64 | 
65 | 
66 | def dist_barrier() -> None:
67 |     if is_dist_initialized():
68 |         torch.distributed.barrier()
69 | 
70 | 
71 | def get_dist_local_rank() -> int:
72 |     return int(os.environ["LOCAL_RANK"])
73 | 
74 | 
75 | def sync_tensor(tensor: Union[torch.Tensor, float], reduce="mean") -> Union[torch.Tensor, list[torch.Tensor]]:
76 |     if not is_dist_initialized():
77 |         return tensor
78 |     if not isinstance(tensor, torch.Tensor):
79 |         tensor = torch.Tensor(1).fill_(tensor).cuda()
80 |     tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
81 |     torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
82 |     if reduce == "mean":
83 |         return list_mean(tensor_list)
84 |     elif reduce == "sum":
85 |         return list_sum(tensor_list)
86 |     elif reduce == "cat":
87 |         return torch.cat(tensor_list, dim=0)
88 |     elif reduce == "root":
89 |         return tensor_list[0]
90 |     else:
91 |         return tensor_list
92 | 
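These helpers are written so that single-process runs degrade gracefully; a minimal sketch (the `dist` import path is illustrative):

```python
# Without a launcher, dist_init falls back to RANK=0 / WORLD_SIZE=1 and
# sync_tensor becomes a passthrough, so callers need no special-casing.
import torch

from dist import dist_init, get_dist_size, is_master, sync_tensor  # illustrative path

dist_init()
loss = torch.tensor(0.5)
avg = sync_tensor(loss, reduce="mean")  # == loss when world size is 1
if is_master():
    print(get_dist_size(), avg)  # 1 tensor(0.5000)
```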


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/ema.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import copy
18 | import math
19 | 
20 | import torch
21 | import torch.nn as nn
22 | 
23 | from ...models.utils import is_parallel
24 | 
25 | __all__ = ["EMA"]
26 | 
27 | 
28 | def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None:
29 |     for k, v in ema.state_dict().items():
30 |         if v.dtype.is_floating_point:
31 |             v -= (1.0 - decay) * (v - new_state_dict[k].detach())
32 | 
33 | 
34 | class EMA:
35 |     def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
36 |         self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval()
37 |         self.decay = decay
38 |         self.warmup_steps = warmup_steps
39 | 
40 |         for p in self.shadows.parameters():
41 |             p.requires_grad = False
42 | 
43 |     def step(self, model: nn.Module, global_step: int) -> None:
44 |         with torch.no_grad():
45 |             msd = (model.module if is_parallel(model) else model).state_dict()
46 |             update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps)))
47 | 
48 |     def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
49 |         return {self.decay: self.shadows.state_dict()}
50 | 
51 |     def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
52 |         for decay in state_dict:
53 |             if decay == self.decay:
54 |                 self.shadows.load_state_dict(state_dict[decay])
55 | 
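A hedged usage sketch for the `EMA` class above, on a toy model (the decay and warmup values are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMA(model, decay=0.999, warmup_steps=2000)

for step in range(1, 101):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # effective decay = 0.999 * (1 - exp(-step / 2000)): near 0 early, -> 0.999
    ema.step(model, global_step=step)

torch.save(ema.state_dict(), "ema.pt")  # saved as a {decay: state_dict} mapping
```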


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/export.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import io
18 | import os
19 | from typing import Any
20 | 
21 | import onnx
22 | import torch
23 | import torch.nn as nn
24 | from onnxsim import simplify as simplify_func
25 | 
26 | __all__ = ["export_onnx"]
27 | 
28 | 
29 | def export_onnx(model: nn.Module, export_path: str, sample_inputs: Any, simplify=True, opset=11) -> None:
30 |     """Export a model to a platform-specific onnx format.
31 | 
32 |     Args:
33 |         model: a torch.nn.Module object.
34 |         export_path: export location.
35 |         sample_inputs: Any.
36 |         simplify: a flag to turn on onnx-simplifier
37 |         opset: int
38 |     """
39 |     model.eval()
40 | 
41 |     buffer = io.BytesIO()
42 |     with torch.no_grad():
43 |         torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
44 |         buffer.seek(0, 0)
45 |         if simplify:
46 |             onnx_model = onnx.load_model(buffer)
47 |             onnx_model, success = simplify_func(onnx_model)
48 |             assert success
49 |             new_buffer = io.BytesIO()
50 |             onnx.save(onnx_model, new_buffer)
51 |             buffer = new_buffer
52 |             buffer.seek(0, 0)
53 | 
54 |     if buffer.getbuffer().nbytes > 0:
55 |         save_dir = os.path.dirname(export_path)
56 |         os.makedirs(save_dir, exist_ok=True)
57 |         with open(export_path, "wb") as f:
58 |             f.write(buffer.read())
59 | 
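Usage sketch; this requires the `onnx` and `onnxsim` packages, and the output path is illustrative (the function creates the parent directory itself):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
sample = torch.randn(1, 3, 32, 32)
export_onnx(model, "exports/toy.onnx", sample, simplify=True, opset=11)
```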


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/init.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from typing import Union
18 | 
19 | import torch
20 | import torch.nn as nn
21 | from torch.nn.modules.batchnorm import _BatchNorm
22 | 
23 | __all__ = ["init_modules", "zero_last_gamma"]
24 | 
25 | 
26 | def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_normal") -> None:
27 |     _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
28 | 
29 |     if isinstance(model, list):
30 |         for sub_module in model:
31 |             init_modules(sub_module, init_type)
32 |     else:
33 |         init_params = init_type.split("@")
34 |         init_params = float(init_params[1]) if len(init_params) > 1 else None
35 | 
36 |         if init_type.startswith("trunc_normal"):
37 |             init_func = lambda param: nn.init.trunc_normal_(
38 |                 param, std=(_DEFAULT_INIT_PARAM["trunc_normal"] if init_params is None else init_params)
39 |             )
40 |         else:
41 |             raise NotImplementedError
42 | 
43 |         for m in model.modules():
44 |             if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
45 |                 init_func(m.weight)
46 |                 if m.bias is not None:
47 |                     m.bias.data.zero_()
48 |             elif isinstance(m, nn.Embedding):
49 |                 init_func(m.weight)
50 |             elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
51 |                 m.weight.data.fill_(1)
52 |                 m.bias.data.zero_()
53 |             else:
54 |                 weight = getattr(m, "weight", None)
55 |                 bias = getattr(m, "bias", None)
56 |                 if isinstance(weight, torch.nn.Parameter):
57 |                     init_func(weight)
58 |                 if isinstance(bias, torch.nn.Parameter):
59 |                     bias.data.zero_()
60 | 
61 | 
62 | def zero_last_gamma(model: nn.Module, init_val=0) -> None:
63 |     import efficientvit.models.nn.ops as ops
64 | 
65 |     for m in model.modules():
66 |         if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer):
67 |             if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
68 |                 parent_module = m.main.point_conv
69 |             elif isinstance(m.main, ops.ResBlock):
70 |                 parent_module = m.main.conv2
71 |             elif isinstance(m.main, ops.ConvLayer):
72 |                 parent_module = m.main
73 |             elif isinstance(m.main, (ops.LiteMLA)):
74 |                 parent_module = m.main.proj
75 |             else:
76 |                 parent_module = None
77 |             if parent_module is not None:
78 |                 norm = getattr(parent_module, "norm", None)
79 |                 if norm is not None:
80 |                     nn.init.constant_(norm.weight, init_val)
81 | 
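The `init_type` string accepts an optional `@std` suffix; a quick check of the behavior (assuming `init_modules` is imported from this module):

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
init_modules(net, init_type="trunc_normal@0.01")  # truncated normal, std=0.01
print(round(net[0].weight.std().item(), 3))  # ~0.01
print(net[1].weight.data.unique())           # tensor([1.]) -- norms reset to weight=1, bias=0
```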


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/lr.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import math
18 | from typing import Union
19 | 
20 | import torch
21 | 
22 | from ...models.utils.list import val2list
23 | 
24 | __all__ = ["CosineLRwithWarmup", "ConstantLRwithWarmup"]
25 | 
26 | 
27 | class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
28 |     def __init__(
29 |         self,
30 |         optimizer: torch.optim.Optimizer,
31 |         warmup_steps: int,
32 |         warmup_lr: float,
33 |         decay_steps: Union[int, list[int]],
34 |         last_epoch: int = -1,
35 |     ) -> None:
36 |         self.warmup_steps = warmup_steps
37 |         self.warmup_lr = warmup_lr
38 |         self.decay_steps = val2list(decay_steps)
39 |         super().__init__(optimizer, last_epoch)
40 | 
41 |     def get_lr(self) -> list[float]:
42 |         if self.last_epoch < self.warmup_steps:
43 |             return [
44 |                 (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
45 |                 for base_lr in self.base_lrs
46 |             ]
47 |         else:
48 |             current_steps = self.last_epoch - self.warmup_steps
49 |             decay_steps = [0] + self.decay_steps
50 |             idx = len(decay_steps) - 2
51 |             for i, decay_step in enumerate(decay_steps[:-1]):
52 |                 if decay_step <= current_steps < decay_steps[i + 1]:
53 |                     idx = i
54 |                     break
55 |             current_steps -= decay_steps[idx]
56 |             decay_step = decay_steps[idx + 1] - decay_steps[idx]
57 |             return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs]
58 | 
59 | 
60 | class ConstantLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
61 |     def __init__(
62 |         self,
63 |         optimizer: torch.optim.Optimizer,
64 |         warmup_steps: int,
65 |         warmup_lr: float,
66 |         last_epoch: int = -1,
67 |     ) -> None:
68 |         self.warmup_steps = warmup_steps
69 |         self.warmup_lr = warmup_lr
70 |         super().__init__(optimizer, last_epoch)
71 | 
72 |     def get_lr(self) -> list[float]:
73 |         if self.last_epoch < self.warmup_steps:
74 |             return [
75 |                 (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
76 |                 for base_lr in self.base_lrs
77 |             ]
78 |         else:
79 |             return self.base_lrs
80 | 
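The shape of the warmup-then-cosine schedule, on a toy optimizer (all values illustrative):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
sched = CosineLRwithWarmup(opt, warmup_steps=5, warmup_lr=1e-3, decay_steps=20)

lrs = []
for _ in range(25):
    lrs.append(round(sched.get_last_lr()[0], 4))
    opt.step()
    sched.step()
# linear ramp 1e-3 -> 0.1 over 5 steps, then cosine decay toward 0 over 20
print(lrs[:6], lrs[-1])
```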


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/metric.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from typing import Union
18 | 
19 | import torch
20 | 
21 | from ...apps.utils.dist import sync_tensor
22 | 
23 | __all__ = ["AverageMeter"]
24 | 
25 | 
26 | class AverageMeter:
27 |     """Computes and stores the average and current value."""
28 | 
29 |     def __init__(self, is_distributed=True):
30 |         self.is_distributed = is_distributed
31 |         self.sum = 0
32 |         self.count = 0
33 | 
34 |     def _sync(self, val: Union[torch.Tensor, int, float]) -> Union[torch.Tensor, int, float]:
35 |         return sync_tensor(val, reduce="sum") if self.is_distributed else val
36 | 
37 |     def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
38 |         self.count += self._sync(delta_n)
39 |         self.sum += self._sync(val * delta_n)
40 | 
41 |     def get_count(self) -> Union[torch.Tensor, int, float]:
42 |         return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count
43 | 
44 |     @property
45 |     def avg(self):
46 |         avg = -1 if self.count == 0 else self.sum / self.count
47 |         return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
48 | 
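Single-process usage (passing `is_distributed=False` skips the all-gather path entirely):

```python
meter = AverageMeter(is_distributed=False)
meter.update(0.5, delta_n=8)  # e.g. a batch-mean loss over 8 samples
meter.update(0.7, delta_n=8)
print(meter.get_count(), meter.avg)  # 16 0.6
```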


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/opt.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from typing import Any, Optional
18 | 
19 | import torch
20 | 
21 | __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
22 | 
23 | # register optimizer here
24 | #   name: optimizer, kwargs with default values
25 | REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, Any]]] = {
26 |     "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
27 |     "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
28 |     "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
29 | }
30 | 
31 | 
32 | def build_optimizer(
33 |     net_params, optimizer_name: str, optimizer_params: Optional[dict], init_lr: float
34 | ) -> torch.optim.Optimizer:
35 |     optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
36 |     default_params = dict(default_params)  # copy so overrides don't mutate the registered defaults
37 |     optimizer_params = {} if optimizer_params is None else optimizer_params
38 |     for key in default_params:
39 |         if key in optimizer_params:
40 |             default_params[key] = optimizer_params[key]
41 |     optimizer = optimizer_class(net_params, init_lr, **default_params)
42 |     return optimizer
43 | 
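Usage sketch: overrides are filtered against the registered defaults, so only known keys (e.g. `betas`) are forwarded:

```python
import torch.nn as nn

net = nn.Linear(8, 8)
opt = build_optimizer(net.parameters(), "adamw", {"betas": (0.9, 0.95)}, init_lr=1e-4)
print(type(opt).__name__, opt.defaults["betas"])  # AdamW (0.9, 0.95)
```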


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/dc_ae/efficientvit/models/__init__.py


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/efficientvit/__init__.py:
--------------------------------------------------------------------------------
1 | from .dc_ae import *
2 | from .dc_ae_with_temporal import *
3 | 


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .act import *
2 | from .drop import *
3 | from .norm import *
4 | from .ops import *
5 | from .triton_rms_norm import *
6 | 


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/nn/act.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from functools import partial
18 | from typing import Optional
19 | 
20 | import torch.nn as nn
21 | 
22 | from ...models.utils import build_kwargs_from_config
23 | 
24 | __all__ = ["build_act"]
25 | 
26 | 
27 | # register activation function here
28 | REGISTERED_ACT_DICT: dict[str, type] = {
29 |     "relu": nn.ReLU,
30 |     "relu6": nn.ReLU6,
31 |     "hswish": nn.Hardswish,
32 |     "silu": nn.SiLU,
33 |     "gelu": partial(nn.GELU, approximate="tanh"),
34 | }
35 | 
36 | 
37 | def build_act(name: str, **kwargs) -> Optional[nn.Module]:
38 |     if name in REGISTERED_ACT_DICT:
39 |         act_cls = REGISTERED_ACT_DICT[name]
40 |         args = build_kwargs_from_config(kwargs, act_cls)
41 |         return act_cls(**args)
42 |     else:
43 |         return None
44 | 


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .list import *
2 | from .network import *
3 | from .random import *
4 | from .video import *
5 | 


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/list.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from typing import Any, Optional, Union
18 | 
19 | __all__ = [
20 |     "list_sum",
21 |     "list_mean",
22 |     "weighted_list_sum",
23 |     "list_join",
24 |     "val2list",
25 |     "val2tuple",
26 |     "squeeze_list",
27 | ]
28 | 
29 | 
30 | def list_sum(x: list) -> Any:
31 |     return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
32 | 
33 | 
34 | def list_mean(x: list) -> Any:
35 |     return list_sum(x) / len(x)
36 | 
37 | 
38 | def weighted_list_sum(x: list, weights: list) -> Any:
39 |     assert len(x) == len(weights)
40 |     return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
41 | 
42 | 
43 | def list_join(x: list, sep="\t", format_str="%s") -> str:
44 |     return sep.join([format_str % val for val in x])
45 | 
46 | 
47 | def val2list(x: Union[list, tuple, Any], repeat_time=1) -> list:
48 |     if isinstance(x, (list, tuple)):
49 |         return list(x)
50 |     return [x for _ in range(repeat_time)]
51 | 
52 | 
53 | def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
54 |     x = val2list(x)
55 | 
56 |     # repeat elements if necessary
57 |     if len(x) > 0:
58 |         x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
59 | 
60 |     return tuple(x)
61 | 
62 | 
63 | def squeeze_list(x: Optional[list]) -> Union[list, Any]:
64 |     if x is not None and len(x) == 1:
65 |         return x[0]
66 |     else:
67 |         return x
68 | 
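The shape helpers at a glance (values illustrative):

```python
print(val2list(3, repeat_time=2))    # [3, 3]
print(val2tuple(3, min_len=3))       # (3, 3, 3)
print(val2tuple([1, 2], min_len=4))  # (1, 2, 2, 2) -- repeats the idx_repeat element
print(squeeze_list([7]))             # 7
print(list_mean([1.0, 2.0, 3.0]))    # 2.0
```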


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/random.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from typing import Any, Optional, Union
18 | 
19 | import numpy as np
20 | import torch
21 | 
22 | __all__ = [
23 |     "torch_randint",
24 |     "torch_random",
25 |     "torch_shuffle",
26 |     "torch_uniform",
27 |     "torch_random_choices",
28 | ]
29 | 
30 | 
31 | def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int:
32 |     """uniform: [low, high)"""
33 |     if low == high:
34 |         return low
35 |     else:
36 |         assert low < high
37 |         return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
38 | 
39 | 
40 | def torch_random(generator: Optional[torch.Generator] = None) -> float:
41 |     """uniform distribution on the interval [0, 1)"""
42 |     return float(torch.rand(1, generator=generator))
43 | 
44 | 
45 | def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]:
46 |     rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
47 |     return [src_list[i] for i in rand_indexes]
48 | 
49 | 
50 | def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float:
51 |     """uniform distribution on the interval [low, high)"""
52 |     rand_val = torch_random(generator)
53 |     return (high - low) * rand_val + low
54 | 
55 | 
56 | def torch_random_choices(
57 |     src_list: list[Any],
58 |     generator: Optional[torch.Generator] = None,
59 |     k=1,
60 |     weight_list: Optional[list[float]] = None,
61 | ) -> Union[Any, list]:
62 |     if weight_list is None:
63 |         rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,))
64 |         out_list = [src_list[i] for i in rand_idx]
65 |     else:
66 |         assert len(weight_list) == len(src_list)
67 |         accumulate_weight_list = np.cumsum(weight_list)
68 | 
69 |         out_list = []
70 |         for _ in range(k):
71 |             val = torch_uniform(0, accumulate_weight_list[-1], generator)
72 |             active_id = 0
73 |             for i, weight_val in enumerate(accumulate_weight_list):
74 |                 active_id = i
75 |                 if weight_val > val:
76 |                     break
77 |             out_list.append(src_list[active_id])
78 | 
79 |     return out_list[0] if k == 1 else out_list
80 | 
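All of these accept an optional `torch.Generator`, so sampling is reproducible; a quick demo:

```python
import torch

g = torch.Generator().manual_seed(0)
print(torch_randint(0, 10, g))         # deterministic int in [0, 10)
print(torch_shuffle([1, 2, 3, 4], g))  # deterministic permutation
# weighted sampling draws proportionally to weight_list:
print(torch_random_choices(["a", "b", "c"], g, k=4, weight_list=[1, 1, 8]))
```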


--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/video.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | 
  3 | import torch
  4 | import torch.nn.functional as F
  5 | 
  6 | 
  7 | def chunked_interpolate(x, scale_factor, mode="nearest"):
  8 |     """
  9 |     Interpolate large tensors by chunking along the channel dimension. https://discuss.pytorch.org/t/error-using-f-interpolate-for-large-3d-input/207859
 10 |     Only supports 'nearest' interpolation mode.
 11 | 
 12 |     Args:
 13 |         x (torch.Tensor): Input tensor (B, C, D, H, W)
 14 |         scale_factor: Tuple of scaling factors (d, h, w)
 15 |         mode (str): Interpolation mode; only "nearest" is supported
 16 |     Returns:
 17 |         torch.Tensor: Interpolated tensor
 18 |     """
 19 |     assert (
 20 |         mode == "nearest"
 21 |     ), "Only the nearest mode is supported"  # actually other modes are theoretically supported but not tested
 22 |     if len(x.shape) != 5:
 23 |         raise ValueError("Expected 5D input tensor (B, C, D, H, W)")
 24 | 
 25 |     # Calculate max chunk size to avoid int32 overflow. num_elements < max_int32
 26 |     # Max int32 is 2^31 - 1
 27 |     max_elements_per_chunk = 2**31 - 1
 28 | 
 29 |     # Calculate output spatial dimensions
 30 |     out_d = math.ceil(x.shape[2] * scale_factor[0])
 31 |     out_h = math.ceil(x.shape[3] * scale_factor[1])
 32 |     out_w = math.ceil(x.shape[4] * scale_factor[2])
 33 | 
 34 |     # Calculate max channels per chunk to stay under limit
 35 |     elements_per_channel = out_d * out_h * out_w
 36 |     max_channels = max_elements_per_chunk // (x.shape[0] * elements_per_channel)
 37 | 
 38 |     # Use smaller of max channels or input channels
 39 |     chunk_size = min(max_channels, x.shape[1])
 40 | 
 41 |     # Ensure at least 1 channel per chunk
 42 |     chunk_size = max(1, chunk_size)
 43 | 
 44 |     chunks = []
 45 |     for i in range(0, x.shape[1], chunk_size):
 46 |         start_idx = i
 47 |         end_idx = min(i + chunk_size, x.shape[1])
 48 | 
 49 |         chunk = x[:, start_idx:end_idx, :, :, :]
 50 | 
 51 |         interpolated_chunk = F.interpolate(chunk, scale_factor=scale_factor, mode="nearest")
 52 | 
 53 |         chunks.append(interpolated_chunk)
 54 | 
 55 |     if not chunks:
 56 |         raise ValueError(f"No chunks were generated. Input shape: {x.shape}")
 57 | 
 58 |     # Concatenate chunks along channel dimension
 59 |     return torch.cat(chunks, dim=1)
 60 | 
 61 | 
 62 | def pixel_shuffle_3d(x, upscale_factor):
 63 |     """
 64 |     3D pixelshuffle operation.
 65 |     """
 66 |     B, C, T, H, W = x.shape
 67 |     r = upscale_factor
 68 |     assert C % (r * r * r) == 0, "channel number must be a multiple of the cube of the upsampling factor"
 69 | 
 70 |     C_new = C // (r * r * r)
 71 |     x = x.view(B, C_new, r, r, r, T, H, W)
 72 | 
 73 |     x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
 74 | 
 75 |     y = x.reshape(B, C_new, T * r, H * r, W * r)
 76 |     return y
 77 | 
 78 | 
 79 | def pixel_unshuffle_3d(x, downsample_factor):
 80 |     """
 81 |     3D pixel unshuffle operation.
 82 |     """
 83 |     B, C, T, H, W = x.shape
 84 | 
 85 |     r = downsample_factor
 86 |     assert T % r == 0, f"time dimension must be a multiple of the downsampling factor, got shape {x.shape}"
 87 |     assert H % r == 0, f"height dimension must be a multiple of the downsampling factor, got shape {x.shape}"
 88 |     assert W % r == 0, f"width dimension must be a multiple of the downsampling factor, got shape {x.shape}"
 89 |     T_new = T // r
 90 |     H_new = H // r
 91 |     W_new = W // r
 92 |     C_new = C * (r * r * r)
 93 | 
 94 |     x = x.view(B, C, T_new, r, H_new, r, W_new, r)
 95 |     x = x.permute(0, 1, 3, 5, 7, 2, 4, 6)
 96 |     y = x.reshape(B, C_new, T_new, H_new, W_new)
 97 |     return y
 98 | 
 99 | 
100 | def ceil_to_divisible(n: int, dividend: int) -> int:
101 |     return math.ceil(dividend / (dividend // n))
102 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/__init__.py:
--------------------------------------------------------------------------------
 1 | from .sana import (
 2 |     Sana,
 3 |     SanaBlock,
 4 |     get_1d_sincos_pos_embed_from_grid,
 5 |     get_2d_sincos_pos_embed,
 6 |     get_2d_sincos_pos_embed_from_grid,
 7 | )
 8 | from .sana_multi_scale import (
 9 |     SanaMS,
10 |     SanaMS_600M_P1_D28,
11 |     SanaMS_600M_P2_D28,
12 |     SanaMS_600M_P4_D28,
13 |     SanaMS_1600M_P1_D20,
14 |     SanaMS_1600M_P2_D20,
15 |     SanaMSBlock,
16 | )
17 | from .sana_multi_scale_adaln import (
18 |     SanaMSAdaLN,
19 |     SanaMSAdaLN_600M_P1_D28,
20 |     SanaMSAdaLN_600M_P2_D28,
21 |     SanaMSAdaLN_600M_P4_D28,
22 |     SanaMSAdaLN_1600M_P1_D20,
23 |     SanaMSAdaLN_1600M_P2_D20,
24 |     SanaMSAdaLNBlock,
25 | )
26 | from .sana_multi_scale_controlnet import SanaMSControlNet_600M_P1_D28
27 | from .sana_multi_scale_video import (
28 |     SanaMSVideo,
29 |     SanaMSVideo_600M_P1_D28,
30 |     SanaMSVideo_600M_P2_D28,
31 |     SanaMSVideo_2000M_P1_D20,
32 |     SanaMSVideo_2000M_P2_D20,
33 | )
34 | from .sana_U_shape import (
35 |     SanaU,
36 |     SanaU_600M_P1_D28,
37 |     SanaU_600M_P2_D28,
38 |     SanaU_600M_P4_D28,
39 |     SanaU_1600M_P1_D20,
40 |     SanaU_1600M_P2_D20,
41 |     SanaUBlock,
42 | )
43 | from .sana_U_shape_multi_scale import (
44 |     SanaUMS,
45 |     SanaUMS_600M_P1_D28,
46 |     SanaUMS_600M_P2_D28,
47 |     SanaUMS_600M_P4_D28,
48 |     SanaUMS_1600M_P1_D20,
49 |     SanaUMS_1600M_P2_D20,
50 |     SanaUMSBlock,
51 | )
52 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | from .triton_lite_mla import *
18 | from .triton_lite_mla_fwd import *
19 | from .triton_mb_conv_pre_glu import *
20 | 
21 | # from .flash_attn import *
22 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/flash_attn.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import torch
18 | from flash_attn import flash_attn_func
19 | from torch import nn
20 | from torch.nn import functional as F
21 | 
22 | 
23 | class FlashAttention(nn.Module):
24 |     def __init__(self, dim: int, num_heads: int):
25 |         super().__init__()
26 |         self.dim = dim
27 |         assert dim % num_heads == 0
28 |         self.num_heads = num_heads
29 |         self.head_dim = dim // num_heads
30 | 
31 |         self.qkv = nn.Linear(dim, dim * 3, bias=False)
32 |         self.proj_out = torch.nn.Linear(dim, dim)
33 | 
34 |     def forward(self, x):
35 |         B, N, C = x.shape
36 |         qkv = self.qkv(x).view(B, N, 3, C)  # B, N, 3, C
37 |         q, k, v = qkv.unbind(2)  # B, N, C
38 |         k = k.reshape(B, N, self.num_heads, self.head_dim)
39 |         v = v.reshape(B, N, self.num_heads, self.head_dim)
40 |         q = q.reshape(B, N, self.num_heads, self.head_dim)
41 |         out = flash_attn_func(q, k, v)  # B, N, H, c
42 |         out = self.proj_out(out.view(B, N, C))  # B, N, C
43 |         return out
44 | 
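A hedged usage sketch; this requires the `flash-attn` package and a CUDA device, and `flash_attn_func` only accepts fp16/bf16 inputs:

```python
import torch

attn = FlashAttention(dim=64, num_heads=8).cuda().half()
x = torch.randn(2, 256, 64, device="cuda", dtype=torch.float16)
print(attn(x).shape)  # torch.Size([2, 256, 64])
```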


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/nn/act.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import copy
18 | from typing import Any, Optional
19 | import torch.nn as nn
20 | 
21 | __all__ = ["build_act", "get_act_name"]
22 | 
23 | # register activation function here
24 | #   name: module, kwargs with default values
25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, Any]]] = {
26 |     "relu": (nn.ReLU, {"inplace": True}),
27 |     "relu6": (nn.ReLU6, {"inplace": True}),
28 |     "hswish": (nn.Hardswish, {"inplace": True}),
29 |     "hsigmoid": (nn.Hardsigmoid, {"inplace": True}),
30 |     "swish": (nn.SiLU, {"inplace": True}),
31 |     "silu": (nn.SiLU, {"inplace": True}),
32 |     "tanh": (nn.Tanh, {}),
33 |     "sigmoid": (nn.Sigmoid, {}),
34 |     "gelu": (nn.GELU, {"approximate": "tanh"}),
35 |     "mish": (nn.Mish, {"inplace": True}),
36 |     "identity": (nn.Identity, {}),
37 | }
38 | 
39 | 
40 | def build_act(name: Optional[str], **kwargs) -> Optional[nn.Module]:
41 |     if name in REGISTERED_ACT_DICT:
42 |         act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name])
43 |         for key in default_args:
44 |             if key in kwargs:
45 |                 default_args[key] = kwargs[key]
46 |         return act_cls(**default_args)
47 |     elif name is None or name.lower() == "none":
48 |         return None
49 |     else:
50 |         raise ValueError(f"activation {name} is not supported")
51 | 
52 | 
53 | def get_act_name(act: Optional[nn.Module]) -> Optional[str]:
54 |     if act is None:
55 |         return None
56 |     module2name = {}
57 |     for key, config in REGISTERED_ACT_DICT.items():
58 |         module2name[config[0].__name__] = key
59 |     return module2name.get(type(act).__name__, "unknown")
60 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/nn/conv.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import torch
18 | from torch import nn
19 | 
20 | from ..utils.model import get_same_padding
21 | from .act import build_act, get_act_name
22 | from .norm import build_norm, get_norm_name
23 | 
24 | 
25 | class ConvLayer(nn.Module):
26 |     def __init__(
27 |         self,
28 |         in_dim: int,
29 |         out_dim: int,
30 |         kernel_size=3,
31 |         stride=1,
32 |         dilation=1,
33 |         groups=1,
34 |         padding: int or None = None,
35 |         use_bias=False,
36 |         dropout=0.0,
37 |         norm="bn2d",
38 |         act="relu",
39 |     ):
40 |         super().__init__()
41 |         if padding is None:
42 |             padding = get_same_padding(kernel_size)
43 |             padding *= dilation
44 | 
45 |         self.in_dim = in_dim
46 |         self.out_dim = out_dim
47 |         self.kernel_size = kernel_size
48 |         self.stride = stride
49 |         self.dilation = dilation
50 |         self.groups = groups
51 |         self.padding = padding
52 |         self.use_bias = use_bias
53 | 
54 |         self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
55 |         self.conv = nn.Conv2d(
56 |             in_dim,
57 |             out_dim,
58 |             kernel_size=(kernel_size, kernel_size),
59 |             stride=(stride, stride),
60 |             padding=padding,
61 |             dilation=(dilation, dilation),
62 |             groups=groups,
63 |             bias=use_bias,
64 |         )
65 |         self.norm = build_norm(norm, num_features=out_dim)
66 |         self.act = build_act(act)
67 | 
68 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
69 |         if self.dropout is not None:
70 |             x = self.dropout(x)
71 |         x = self.conv(x)
72 |         if self.norm:
73 |             x = self.norm(x)
74 |         if self.act:
75 |             x = self.act(x)
76 |         return x
77 | 
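Usage sketch: "same" padding is derived from the kernel size, so a stride-2 3x3 conv exactly halves the spatial resolution:

```python
import torch

layer = ConvLayer(3, 16, kernel_size=3, stride=2, norm="bn2d", act="relu")
x = torch.randn(1, 3, 32, 32)
print(layer(x).shape)  # torch.Size([1, 16, 16, 16])
```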


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/compare_results.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import torch
18 | 
19 | 
20 | def compare_results(name: str, result: torch.Tensor, ref_result: torch.Tensor):
21 |     print(f"comparing {name}")
22 |     diff = (result - ref_result).abs().view(-1)
23 |     max_error_pos = diff.argmax()
24 |     print(f"max error: {diff.max()}, mean error: {diff.mean()}")
25 |     print(f"max error pos: {result.view(-1)[max_error_pos]} {ref_result.view(-1)[max_error_pos]}")
26 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/dtype.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import torch
18 | import triton
19 | import triton.language as tl
20 | 
21 | 
22 | def get_dtype_from_str(dtype: str) -> torch.dtype:
23 |     if dtype == "fp32":
24 |         return torch.float32
25 |     if dtype == "fp16":
26 |         return torch.float16
27 |     if dtype == "bf16":
28 |         return torch.bfloat16
29 |     raise NotImplementedError(f"dtype {dtype} is not supported")
30 | 
31 | 
32 | def get_tl_dtype_from_torch_dtype(dtype: torch.dtype) -> tl.dtype:
33 |     if dtype == torch.float32:
34 |         return tl.float32
35 |     if dtype == torch.float16:
36 |         return tl.float16
37 |     if dtype == torch.bfloat16:
38 |         return tl.bfloat16
39 |     raise NotImplementedError(f"dtype {dtype} is not supported")
40 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/export_onnx.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import os
18 | import warnings
19 | from typing import Any, Tuple
20 | 
21 | import torch
22 | 
23 | 
24 | def export_onnx(
25 |     model: torch.nn.Module,
26 |     input_shape: Tuple[int],
27 |     export_path: str,
28 |     opset: int,
29 |     export_dtype: torch.dtype,
30 |     export_device: torch.device,
31 | ) -> None:
32 |     model.eval()
33 | 
34 |     dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)}
35 |     dynamic_axes = {
36 |         "x": {0: "batch_size"},
37 |     }
38 | 
39 |     # _ = model(**dummy_input)
40 | 
41 |     output_names = ["image_embeddings"]
42 | 
43 |     export_dir = os.path.dirname(export_path)
44 |     if not os.path.exists(export_dir):
45 |         os.makedirs(export_dir)
46 | 
47 |     with warnings.catch_warnings():
48 |         warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
49 |         warnings.filterwarnings("ignore", category=UserWarning)
50 |         print(f"Exporting onnx model to {export_path}...")
51 |         with open(export_path, "wb") as f:
52 |             torch.onnx.export(
53 |                 model,
54 |                 tuple(dummy_input.values()),
55 |                 f,
56 |                 export_params=True,
57 |                 verbose=False,
58 |                 opset_version=opset,
59 |                 do_constant_folding=True,
60 |                 input_names=list(dummy_input.keys()),
61 |                 output_names=output_names,
62 |                 dynamic_axes=dynamic_axes,
63 |             )
64 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/model.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 MIT Han Lab
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | 
18 | def val2list(x: list or tuple or any, repeat_time=1) -> list:  # type: ignore
19 |     """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
20 |     if isinstance(x, (list, tuple)):
21 |         return list(x)
22 |     return [x for _ in range(repeat_time)]
23 | 
24 | 
25 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:  # type: ignore
26 |     """Return tuple with min_len by repeating element at idx_repeat."""
27 |     # convert to list first
28 |     x = val2list(x)
29 | 
30 |     # repeat elements if necessary
31 |     if len(x) > 0:
32 |         x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
33 | 
34 |     return tuple(x)
35 | 
36 | 
37 | def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
38 |     if isinstance(kernel_size, tuple):
39 |         return tuple([get_same_padding(ks) for ks in kernel_size])
40 |     else:
41 |         assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
42 |         return kernel_size // 2
43 | 
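Quick checks for the helpers above (odd kernel sizes only, per the assert):

```python
print(get_same_padding(3))           # 1
print(get_same_padding((3, 5)))      # (1, 2)
print(val2tuple([1, 2], min_len=4))  # (1, 2, 2, 2)
```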


--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/readme.md:
--------------------------------------------------------------------------------
 1 | # A fast implementation of linear attention
 2 | 
 3 | ## 64x64, fp16
 4 | 
 5 | ```bash
 6 | # validate correctness
 7 | ## fp16 vs fp32
 8 | python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
 9 | ## triton fp16 vs fp32
10 | python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True
11 | 
12 | # test performance
13 | ## fp16, forward
14 | python -m develop_triton_litemla attn_type=LiteMLA
15 | # each step takes 10.81 ms
16 | # max memory allocated: 2.2984 GB
17 | 
18 | ## triton fp16, forward
19 | python -m develop_triton_litemla attn_type=TritonLiteMLA
20 | # each step takes 4.70 ms
21 | # max memory allocated: 1.6480 GB
22 | 
23 | ## fp16, backward
24 | python -m develop_triton_litemla attn_type=LiteMLA backward=True
25 | # each step takes 35.34 ms
26 | # max memory allocated: 3.4412 GB
27 | 
28 | ## triton fp16, backward
29 | python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True
30 | # each step takes 14.25 ms
31 | # max memory allocated: 2.4704 GB
32 | ```
33 | 


--------------------------------------------------------------------------------
/diffusion/model/nets/sana_ladd.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import torch
18 | import torch.nn as nn
19 | 
20 | from diffusion.model.builder import MODELS
21 | 
22 | from .ladd_blocks import DiscHead
23 | from .sana_multi_scale import SanaMSCM
24 | 
25 | 
26 | @MODELS.register_module()
27 | class SanaMSCMDiscriminator(nn.Module):
28 |     def __init__(self, pretrained_model: SanaMSCM, is_multiscale=False, head_block_ids=None):
29 |         super().__init__()
30 |         self.transformer = pretrained_model
31 |         self.transformer.requires_grad_(False)
32 | 
33 |         if head_block_ids is None or len(head_block_ids) == 0:
34 |             self.block_hooks = {2, 8, 14, 20, 27} if is_multiscale else {self.transformer.depth - 1}
35 |         else:
36 |             self.block_hooks = head_block_ids
37 | 
38 |         heads = []
39 |         for i in range(len(self.block_hooks)):
40 |             heads.append(DiscHead(self.transformer.hidden_size, 0, 0))
41 |         self.heads = nn.ModuleList(heads)
42 | 
43 |     def get_head_inputs(self):
44 |         return self.head_inputs
45 | 
46 |     def forward(self, x, timestep, y=None, data_info=None, mask=None, **kwargs):
47 |         feat_list = []
48 |         self.head_inputs = []
49 | 
50 |         def get_features(module, input, output):
51 |             feat_list.append(output)
52 |             return output
53 | 
54 |         hooks = []
55 |         for i, block in enumerate(self.transformer.blocks):
56 |             if i in self.block_hooks:
57 |                 hooks.append(block.register_forward_hook(get_features))
58 | 
59 |         self.transformer(x, timestep, y=y, mask=mask, data_info=data_info, return_logvar=False, **kwargs)
60 | 
61 |         for hook in hooks:
62 |             hook.remove()
63 | 
64 |         res_list = []
65 |         for feat, head in zip(feat_list, self.heads):
66 |             B, N, C = feat.shape
67 |             feat = feat.transpose(1, 2)  # [B, C, N]
68 |             self.head_inputs.append(feat)
69 |             res_list.append(head(feat, None).reshape(feat.shape[0], -1))
70 | 
71 |         concat_res = torch.cat(res_list, dim=1)
72 | 
73 |         return concat_res
74 | 
75 |     @property
76 |     def model(self):
77 |         return self.transformer
78 | 
79 |     def save_pretrained(self, path):
80 |         torch.save(self.state_dict(), path)
81 | 
82 | 
83 | class DiscHeadModel:
84 |     def __init__(self, disc):
85 |         self.disc = disc
86 | 
87 |     def state_dict(self):
88 |         return {name: param for name, param in self.disc.state_dict().items() if not name.startswith("transformer.")}
89 | 
90 |     def __getattr__(self, name):
91 |         return getattr(self.disc, name)
92 | 
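For context, a minimal sketch of how this discriminator is wired up; `pretrained_scm`, `x_t`, `text_emb`, and `mask` are placeholder names, not APIs from this repo:

```python
# Hypothetical wiring (placeholder tensors and backbone).
disc = SanaMSCMDiscriminator(pretrained_scm, is_multiscale=True)  # backbone is frozen inside
logits = disc(x_t, timestep, y=text_emb, mask=mask)  # [B, concatenated per-head outputs]

# Persist only the trainable heads; DiscHeadModel filters out "transformer.*" keys.
torch.save(DiscHeadModel(disc).state_dict(), "disc_heads.pth")
```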


--------------------------------------------------------------------------------
/diffusion/model/wan/__init__.py:
--------------------------------------------------------------------------------
 1 | from .attention import flash_attention
 2 | from .model import WanLinearAttentionModel, WanModel, init_model_configs
 3 | from .model_wrapper import SanaVideoMSBlock, SanaWanLinearAttentionModel, SanaWanModel
 4 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
 5 | from .tokenizers import HuggingfaceTokenizer
 6 | from .vae import WanVAE
 7 | 
 8 | __all__ = [
 9 |     "WanVAE",
10 |     "WanModel",
11 |     "WanLinearAttentionModel",
12 |     "init_model_configs",
13 |     "T5Model",
14 |     "T5Encoder",
15 |     "T5Decoder",
16 |     "T5EncoderModel",
17 |     "HuggingfaceTokenizer",
18 |     "flash_attention",
19 |     "SanaWanLinearAttentionModel",
20 |     "SanaWanModel",
21 |     "SanaVideoMSBlock",
22 | ]
23 | 


--------------------------------------------------------------------------------
/diffusion/model/wan/fsdp_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import gc
 3 | from functools import partial
 4 | 
 5 | import torch
 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 9 | from torch.distributed.utils import _free_storage
10 | 
11 | 
12 | def shard_model(
13 |     model,
14 |     device_id,
15 |     param_dtype=torch.bfloat16,
16 |     reduce_dtype=torch.float32,
17 |     buffer_dtype=torch.float32,
18 |     process_group=None,
19 |     sharding_strategy=ShardingStrategy.FULL_SHARD,
20 |     sync_module_states=True,
21 | ):
22 |     """
 23 |     FSDP sharding helper; used for the T5 text encoder (each module in model.blocks is wrapped).
24 |     """
25 |     model = FSDP(
26 |         module=model,
27 |         process_group=process_group,
28 |         sharding_strategy=sharding_strategy,
29 |         auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
30 |         mixed_precision=MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype),
31 |         device_id=device_id,
32 |         sync_module_states=sync_module_states,
33 |     )
34 |     return model
35 | 
36 | 
37 | def free_model(model):
38 |     for m in model.modules():
39 |         if isinstance(m, FSDP):
40 |             _free_storage(m._handle.flat_param.data)
41 |     del model
42 |     gc.collect()
43 |     torch.cuda.empty_cache()
44 | 
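A minimal usage sketch, assuming `torch.distributed` is already initialized (e.g. via `torchrun`) and that the wrapped module exposes a `.blocks` ModuleList, as the lambda wrap policy above requires; `build_t5_encoder` is a hypothetical constructor:

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()

model = build_t5_encoder()  # hypothetical; must expose a `.blocks` ModuleList
model = shard_model(model, device_id=local_rank)
# ... run text encoding ...
free_model(model)  # release the FSDP flat-param storage once done
```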


--------------------------------------------------------------------------------
/diffusion/model/wan/model_wrapper.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | 
 3 | from .model import WanAttentionBlock, WanLinearAttentionModel, WanModel
 4 | 
 5 | 
 6 | class SanaVideoMSBlock(WanAttentionBlock):
 7 |     pass
 8 | 
 9 | 
10 | class SanaWanModel(WanModel):
11 |     def __init__(self, *args, **kwargs):
12 |         super().__init__(*args, **kwargs)
13 |         cross_attn_type = "t2v_cross_attn" if self.model_type == "t2v" else "i2v_cross_attn"
14 |         self.blocks = nn.ModuleList(
15 |             [
16 |                 SanaVideoMSBlock(
17 |                     cross_attn_type,
18 |                     self.dim,
19 |                     self.ffn_dim,
20 |                     self.num_heads,
21 |                     self.window_size,
22 |                     self.qk_norm,
23 |                     self.cross_attn_norm,
24 |                     self.eps,
25 |                 )
26 |                 for _ in range(self.num_layers)
27 |             ]
28 |         )
29 | 
30 | 
31 | class SanaWanLinearAttentionModel(WanLinearAttentionModel):
32 |     def __init__(self, *args, **kwargs):
33 |         super().__init__(*args, **kwargs)
34 |         cross_attn_type = "t2v_cross_attn" if self.model_type == "t2v" else "i2v_cross_attn"
35 |         self_attn_types = ["flash"] * self.num_layers
36 |         ffn_types = ["mlp"] * self.num_layers
37 | 
38 |         self.blocks = nn.ModuleList(
39 |             [
40 |                 SanaVideoMSBlock(
41 |                     cross_attn_type,
42 |                     self.dim,
43 |                     self.ffn_dim,
44 |                     self.num_heads,
45 |                     self.window_size,
46 |                     self.qk_norm,
47 |                     self.cross_attn_norm,
48 |                     self.eps,
49 |                     self_attn_types[i],
50 |                     self.rope_after,
51 |                     self.power,
52 |                     ffn_types[i],
53 |                 )
54 |                 for i in range(self.num_layers)
55 |             ]
56 |         )
57 | 


--------------------------------------------------------------------------------
/diffusion/model/wan/tokenizers.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import html
 3 | import string
 4 | 
 5 | import ftfy
 6 | import regex as re
 7 | from transformers import AutoTokenizer
 8 | 
 9 | __all__ = ["HuggingfaceTokenizer"]
10 | 
11 | 
12 | def basic_clean(text):
13 |     text = ftfy.fix_text(text)
14 |     text = html.unescape(html.unescape(text))
15 |     return text.strip()
16 | 
17 | 
18 | def whitespace_clean(text):
19 |     text = re.sub(r"\s+", " ", text)
20 |     text = text.strip()
21 |     return text
22 | 
23 | 
24 | def canonicalize(text, keep_punctuation_exact_string=None):
25 |     text = text.replace("_", " ")
26 |     if keep_punctuation_exact_string:
27 |         text = keep_punctuation_exact_string.join(
28 |             part.translate(str.maketrans("", "", string.punctuation))
29 |             for part in text.split(keep_punctuation_exact_string)
30 |         )
31 |     else:
32 |         text = text.translate(str.maketrans("", "", string.punctuation))
33 |     text = text.lower()
34 |     text = re.sub(r"\s+", " ", text)
35 |     return text.strip()
36 | 
37 | 
38 | class HuggingfaceTokenizer:
39 |     def __init__(self, name, seq_len=None, clean=None, **kwargs):
40 |         assert clean in (None, "whitespace", "lower", "canonicalize")
41 |         self.name = name
42 |         self.seq_len = seq_len
43 |         self.clean = clean
44 | 
45 |         # init tokenizer
46 |         self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47 |         self.vocab_size = self.tokenizer.vocab_size
48 | 
49 |     def __call__(self, sequence, **kwargs):
50 |         return_mask = kwargs.pop("return_mask", False)
51 | 
52 |         # arguments
53 |         _kwargs = {"return_tensors": "pt"}
54 |         if self.seq_len is not None:
55 |             _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
56 |         _kwargs.update(**kwargs)
57 | 
58 |         # tokenization
59 |         if isinstance(sequence, str):
60 |             sequence = [sequence]
61 |         if self.clean:
62 |             sequence = [self._clean(u) for u in sequence]
63 |         ids = self.tokenizer(sequence, **_kwargs)
64 | 
65 |         # output
66 |         if return_mask:
67 |             return ids.input_ids, ids.attention_mask
68 |         else:
69 |             return ids.input_ids
70 | 
71 |     def _clean(self, text):
72 |         if self.clean == "whitespace":
73 |             text = whitespace_clean(basic_clean(text))
74 |         elif self.clean == "lower":
75 |             text = whitespace_clean(basic_clean(text)).lower()
76 |         elif self.clean == "canonicalize":
77 |             text = canonicalize(basic_clean(text))
78 |         return text
79 | 
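A short usage sketch; the repo id follows the umt5-xxl tokenizer that Wan uses upstream, but treat it as an assumption:

```python
tok = HuggingfaceTokenizer("google/umt5-xxl", seq_len=512, clean="whitespace")
ids, mask = tok(["A cat playing piano."], return_mask=True)
print(ids.shape, mask.shape)  # torch.Size([1, 512]) torch.Size([1, 512])
```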


--------------------------------------------------------------------------------
/diffusion/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/scheduler/__init__.py


--------------------------------------------------------------------------------
/diffusion/scheduler/dpm_solver.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import torch
18 | 
19 | from diffusion.model import gaussian_diffusion as gd
20 | from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper
21 | 
22 | 
23 | def DPMS(
24 |     model,
25 |     condition,
26 |     uncondition,
27 |     cfg_scale,
28 |     pag_scale=1.0,
29 |     pag_applied_layers=None,
 30 |     model_type="noise",  # one of: "noise", "x_start", "v", "score", "flow"
31 |     noise_schedule="linear",
32 |     guidance_type="classifier-free",
33 |     model_kwargs=None,
34 |     diffusion_steps=1000,
35 |     schedule="VP",
36 |     interval_guidance=None,
 37 |     condition_as_list=False,  # set to True for the Wan text encoder
38 |     apg=None,
39 |     **kwargs,
40 | ):
41 |     if pag_applied_layers is None:
42 |         pag_applied_layers = []
43 |     if model_kwargs is None:
44 |         model_kwargs = {}
45 |     if interval_guidance is None:
46 |         interval_guidance = [0, 1.0]
47 |     betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
48 | 
49 |     ## 1. Define the noise schedule.
50 |     if schedule == "VP":
51 |         noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
52 |     elif schedule == "FLOW":
53 |         noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
54 | 
 55 |     ## 2. Convert the discrete-time `model` to a continuous-time
 56 |     ## noise-prediction model, wrapping the conditioning inputs and
 57 |     ## guidance (classifier-free / PAG) around it.
58 |     model_fn = model_wrapper(
59 |         model,
60 |         noise_schedule,
61 |         model_type=model_type,
62 |         model_kwargs=model_kwargs,
63 |         guidance_type=guidance_type,
64 |         pag_scale=pag_scale,
65 |         pag_applied_layers=pag_applied_layers,
66 |         condition=condition,
67 |         unconditional_condition=uncondition,
68 |         guidance_scale=cfg_scale,
69 |         interval_guidance=interval_guidance,
70 |         condition_as_list=condition_as_list,
71 |         apg=apg,
72 |         **kwargs,
73 |     )
74 |     ## 3. Define dpm-solver and sample by multistep DPM-Solver.
75 |     return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
76 | 
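`DPMS` only assembles the solver; sampling then goes through `DPM_Solver.sample`. A hedged sketch for flow-matching sampling, where `forward_with_dpmsolver`, `noise`, `y`, and `null_y` are assumed to exist in the calling code and the `sample` arguments follow the reference DPM-Solver API:

```python
solver = DPMS(
    model.forward_with_dpmsolver,  # assumed wrapper returning the raw prediction
    condition=y,
    uncondition=null_y,
    cfg_scale=4.5,
    model_type="flow",
    schedule="FLOW",
)
latents = solver.sample(noise, steps=20, order=2, skip_type="time_uniform_flow", method="multistep")
```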


--------------------------------------------------------------------------------
/diffusion/scheduler/iddpm.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | # Modified from OpenAI's diffusion repos
18 | #     GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
19 | #     ADM:   https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
20 | #     IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
21 | 
22 | from diffusion.model import gaussian_diffusion as gd
23 | from diffusion.model.respace import SpacedDiffusion, space_timesteps
24 | 
25 | 
26 | def Scheduler(
27 |     timestep_respacing,
28 |     noise_schedule="linear",
29 |     use_kl=False,
30 |     sigma_small=False,
31 |     predict_xstart=False,
32 |     predict_flow_v=False,
33 |     learn_sigma=True,
34 |     pred_sigma=True,
35 |     rescale_learned_sigmas=False,
36 |     diffusion_steps=1000,
37 |     snr=False,
38 |     return_startx=False,
39 |     flow_shift=1.0,
40 | ):
41 |     betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
42 |     if use_kl:
43 |         loss_type = gd.LossType.RESCALED_KL
44 |     elif rescale_learned_sigmas:
45 |         loss_type = gd.LossType.RESCALED_MSE
46 |     else:
47 |         loss_type = gd.LossType.MSE
48 |     if timestep_respacing is None or timestep_respacing == "":
49 |         timestep_respacing = [diffusion_steps]
50 |     if predict_xstart:
51 |         model_mean_type = gd.ModelMeanType.START_X
52 |     elif predict_flow_v:
53 |         model_mean_type = gd.ModelMeanType.FLOW_VELOCITY
54 |     else:
55 |         model_mean_type = gd.ModelMeanType.EPSILON
56 |     return SpacedDiffusion(
57 |         use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
58 |         betas=betas,
59 |         model_mean_type=model_mean_type,
60 |         model_var_type=(
61 |             (
62 |                 (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
63 |                 if not learn_sigma
64 |                 else gd.ModelVarType.LEARNED_RANGE
65 |             )
66 |             if pred_sigma
67 |             else None
68 |         ),
69 |         loss_type=loss_type,
70 |         snr=snr,
71 |         return_startx=return_startx,
72 |         # rescale_timesteps=rescale_timesteps,
73 |         flow="flow" in noise_schedule,
74 |         flow_shift=flow_shift,
75 |         diffusion_steps=diffusion_steps,
76 |     )
77 | 
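For illustration, a hedged example of building a flow-matching schedule; the `linear_flow` schedule name is an assumption inferred from the `"flow" in noise_schedule` check above:

```python
diffusion = Scheduler(
    timestep_respacing="",         # empty -> train on the full diffusion_steps trajectory
    noise_schedule="linear_flow",  # any name containing "flow" switches on flow mode
    predict_flow_v=True,           # the model predicts flow velocity
    flow_shift=3.0,
)
```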


--------------------------------------------------------------------------------
/diffusion/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/utils/__init__.py


--------------------------------------------------------------------------------
/environment_setup.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | set -e
 3 | 
 4 | # Check if we should skip environment setup entirely
 5 | if [ "${SKIP_ENV_SETUP}" = "true" ]; then
 6 |     echo "SKIP_ENV_SETUP is set to true. Skipping all environment setup steps."
 7 |     echo "Using default conda environment. Make sure it has all required packages installed."
 8 |     exit 0
 9 | fi
10 | 
11 | CONDA_ENV=${1:-""}
12 | if [ -n "$CONDA_ENV" ]; then
13 |     # This is required to activate conda environment
14 |     eval "$(conda shell.bash hook)"
15 | 
 16 |     conda create -n "$CONDA_ENV" python=3.10.0 -y
 17 |     conda activate "$CONDA_ENV"
18 |     # This is optional if you prefer to use built-in nvcc
19 |     conda install -c nvidia cuda-toolkit=12.8 -y
20 | else
21 |     echo "Skipping conda environment creation. Make sure you have the correct environment activated."
22 | fi
23 | 
 24 | # install a bare torch first to avoid downstream installation errors.
25 | # pip install torch
26 | 
27 | # update pip to latest version for pyproject.toml setup.
28 | pip install -U pip
29 | 
30 | # for fast attn
31 | pip install -U xformers==0.0.32.post2 --index-url https://download.pytorch.org/whl/cu128
32 | 
33 | # install sana
34 | pip install -e .
35 | 
36 | pip install flash-attn==2.8.2 --no-build-isolation
37 | 
38 | # install torchprofile
39 | # pip install git+https://github.com/zhijian-liu/torchprofile
40 | 


--------------------------------------------------------------------------------
/inference_video_scripts/inference_sana_video.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | inference_script=inference_video_scripts/inference_sana_video.py
 5 | config=""
 6 | model_path=""
 7 | np=8
 8 | negative_prompt=""
 9 | other_args=()
10 | while [[ $# -gt 0 ]]; do
11 |   case $1 in
12 |     --config=*)
13 |       config="${1#*=}"
14 |       shift
15 |       ;;
16 |     --config)
17 |       config="$2"
18 |       shift 2
19 |       ;;
20 |     --model_path=*)
21 |       model_path="${1#*=}"
22 |       shift
23 |       ;;
24 |     --model_path)
25 |       model_path="$2"
26 |       shift 2
27 |       ;;
28 |     --inference_script=*)
29 |       inference_script="${1#*=}"
30 |       shift
31 |       ;;
32 |     --inference_script)
33 |       inference_script="$2"
34 |       shift 2
35 |       ;;
36 |     --np=*)
37 |       np="${1#*=}"
38 |       shift
39 |       ;;
40 |     --np)
41 |       np="$2"
42 |       shift 2
43 |       ;;
44 |     --negative_prompt=*)
45 |       negative_prompt="${1#*=}"
46 |       shift
47 |       ;;
48 |     --negative_prompt)
49 |       negative_prompt="$2"
50 |       shift 2
51 |       ;;
52 |     *)
53 |       other_args+=("$1")
54 |       shift
55 |       ;;
56 |   esac
57 | done
58 | 
59 | cmd=(
60 |   accelerate launch --num_processes="$np" --num_machines=1 --mixed_precision=bf16 --main_process_port="$RANDOM"
61 |   "$inference_script"
62 |   --config="$config"
63 |   --model_path="$model_path"
64 |   --txt_file=asset/samples/video_prompts_samples.txt
65 |   --dataset=video_samples
66 | )
67 | 
68 | if [[ -n "$negative_prompt" ]]; then
69 |   cmd+=(--negative_prompt="$negative_prompt")
70 | fi
71 | 
72 | if [[ ${#other_args[@]} -gt 0 ]]; then
73 |   cmd+=("${other_args[@]}")
74 | fi
75 | 
76 | printf -v cmd_str '%q ' "${cmd[@]}"
77 | echo "$cmd_str"
78 | 
79 | "${cmd[@]}"
80 | 


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = ["setuptools>=61.0"]
 3 | build-backend = "setuptools.build_meta"
 4 | 
 5 | [project]
 6 | name = "sana"
 7 | version = "0.2.0"
 8 | description = "SANA"
 9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | classifiers = [
12 |     "Programming Language :: Python :: 3",
13 |     "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 |     "pre-commit",
17 |     "huggingface-hub==0.36.0",
18 |     "accelerate==1.0.1",
19 |     "beautifulsoup4",
20 |     "bs4",
21 |     "came-pytorch",
22 |     "einops",
23 |     "ftfy",
24 |     "wheel",
25 |     "psutil",
26 |     "ninja",
27 |     "diffusers==0.35.0",
28 |     "clip@git+https://github.com/openai/CLIP.git",
29 |     "gradio",
30 |     "image-reward",
31 |     "ipdb",
32 |     "mmcv==1.7.2",
33 |     "omegaconf",
34 |     "opencv-python",
35 |     "optimum",
36 |     "patch_conv",
37 |     "peft==0.17.0",
38 |     "protobuf",
39 |     "pytorch-fid",
40 |     "regex",
41 |     "sentencepiece",
42 |     "tensorboard",
43 |     "tensorboardX",
44 |     "timm==0.6.13",
45 |     "torchaudio==2.8.0",
46 |     "torchvision==0.23.0",
47 |     "transformers==4.57.0",
48 |     "triton==3.4.0",
49 |     "wandb",
50 |     "webdataset",
51 |     "xformers==0.0.32.post2",
52 |     "yapf",
53 |     "spaces",
54 |     "matplotlib",
55 |     "termcolor",
56 |     "pyrallis",
57 |     "bitsandbytes",
58 |     "fire",
59 |     "moviepy",
60 |     "imageio[pyav,ffmpeg]",
61 |     "qwen-vl-utils",
62 | ]
63 | 
64 | 
65 | [project.scripts]
66 | sana-run = "sana.cli.run:main"
67 | sana-upload = "sana.cli.upload2hf:main"
68 | 
69 | [project.optional-dependencies]
70 | 
71 | [project.urls]
72 | 
73 | [tool.pip]
74 | extra-index-url = ["https://download.pytorch.org/whl/cu128"]
75 | 
76 | [tool.black]
77 | line-length = 120
78 | 
79 | [tool.isort]
80 | profile = "black"
81 | multi_line_output = 3
82 | include_trailing_comma = true
83 | force_grid_wrap = 0
84 | use_parentheses = true
85 | ensure_newline_before_comments = true
86 | line_length = 120
87 | 
88 | [tool.setuptools.packages.find]
89 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*", "data*", "output*"]
90 | 
91 | [tool.wheel]
92 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*", "data*", "output*"]
93 | 


--------------------------------------------------------------------------------
/sana/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .download import download_model
2 | from .hf_utils import hf_download_or_fpath
3 | 


--------------------------------------------------------------------------------
/sana/tools/download.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | """
 8 | Functions for downloading pre-trained Sana models
 9 | """
10 | import argparse
11 | import os
12 | 
13 | import torch
14 | from torchvision.datasets.utils import download_url
15 | 
16 | pretrained_models = {}
17 | 
18 | 
19 | def find_model(model_name):
20 |     """
 21 |     Finds a pre-trained Sana model, downloading it if necessary. Alternatively, loads a model from a local path.
22 |     """
 23 |     if model_name in pretrained_models:  # Find/download our pre-trained Sana checkpoints
24 |         return download_model(model_name)
25 |     else:  # Load a custom Sana checkpoint:
26 |         assert os.path.isfile(model_name), f"Could not find Sana checkpoint at {model_name}"
27 |         return torch.load(model_name, map_location=lambda storage, loc: storage)
28 | 
29 | 
30 | def download_model(model_name):
31 |     """
32 |     Downloads a pre-trained Sana model from the web.
33 |     """
34 |     assert model_name in pretrained_models
35 |     local_path = f"output/pretrained_models/{model_name}"
36 |     if not os.path.isfile(local_path):
37 |         hf_endpoint = os.environ.get("HF_ENDPOINT")
38 |         if hf_endpoint is None:
39 |             hf_endpoint = "https://huggingface.co"
40 |         os.makedirs("output/pretrained_models", exist_ok=True)
41 |         web_path = f"{hf_endpoint}/xxx/resolve/main/{model_name}"
42 |         download_url(web_path, "output/pretrained_models/")
43 |     model = torch.load(local_path, map_location=lambda storage, loc: storage)
44 |     return model
45 | 
46 | 
47 | if __name__ == "__main__":
48 |     parser = argparse.ArgumentParser()
49 |     parser.add_argument("--model_names", nargs="+", type=str, default=pretrained_models)
50 |     args = parser.parse_args()
51 |     model_names = args.model_names
52 |     model_names = set(model_names)
53 | 
54 |     # Download Sana checkpoints
55 |     for model in model_names:
56 |         download_model(model)
57 |     print("Done.")
58 | 


--------------------------------------------------------------------------------
/sana/tools/hf_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import os
18 | import os.path as osp
19 | import sys
20 | 
21 | from huggingface_hub import hf_hub_download, snapshot_download
22 | 
23 | 
24 | def hf_download_or_fpath(path):
25 |     if osp.exists(path):
26 |         return path
27 | 
28 |     if path.startswith("hf://"):
29 |         segs = path.replace("hf://", "").split("/")
30 |         repo_id = "/".join(segs[:2])
31 |         filename = "/".join(segs[2:])
32 |         return hf_download_data(repo_id, filename, repo_type="model", download_full_repo=True)
33 | 
34 | 
35 | def hf_download_data(
36 |     repo_id="Efficient-Large-Model/Sana_1600M_1024px",
37 |     filename="checkpoints/Sana_1600M_1024px.pth",
38 |     cache_dir=None,
39 |     repo_type="model",
40 |     download_full_repo=False,
41 | ):
42 |     """
 43 |     Download a file (or a full repo snapshot) from a Hugging Face repository.
44 | 
45 |     Args:
46 |     repo_id (str): The ID of the Hugging Face repository.
47 |     filename (str): The name of the file to download.
48 |     cache_dir (str, optional): The directory to cache the downloaded file.
49 | 
50 |     Returns:
51 |     str: The path to the downloaded file.
52 |     """
53 |     try:
54 |         if download_full_repo:
55 |             # download full repos to fit dc-ae
56 |             snapshot_download(
57 |                 repo_id=repo_id,
58 |                 cache_dir=cache_dir,
59 |                 repo_type=repo_type,
60 |             )
61 |         file_path = hf_hub_download(
62 |             repo_id=repo_id,
63 |             filename=filename,
64 |             cache_dir=cache_dir,
65 |             repo_type=repo_type,
66 |         )
67 |         return file_path
68 |     except Exception as e:
69 |         print(f"Error downloading file: {e}")
70 |         return None
71 | 
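For example, the `hf://` scheme above resolves repo-relative checkpoint paths like the ones used in the test scripts:

```python
local = hf_download_or_fpath(
    "hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"
)
```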


--------------------------------------------------------------------------------
/scripts/infer_run_inference_geneval_diffusers.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | # ================= sampler & data =================
 3 | np=8    # number of GPUs to use
 4 | default_step=20   # 14
 5 | default_sample_nums=553
 6 | default_sampling_algo="dpm-solver"
 7 | default_add_label=''
 8 | 
 9 | # parser
10 | config_file=$1
11 | model_paths=$2
12 | 
13 | for arg in "$@"
14 | do
15 |     case $arg in
16 |         --step=*)
17 |         step="${arg#*=}"
18 |         shift
19 |         ;;
20 |         --sampling_algo=*)
21 |         sampling_algo="${arg#*=}"
22 |         shift
23 |         ;;
24 |         --add_label=*)
25 |         add_label="${arg#*=}"
26 |         shift
27 |         ;;
28 |         --model_path=*)
29 |         model_paths="${arg#*=}"
30 |         shift
31 |         ;;
32 |         --exist_time_prefix=*)
33 |         exist_time_prefix="${arg#*=}"
34 |         shift
35 |         ;;
36 |         --if_save_dirname=*)
37 |         if_save_dirname="${arg#*=}"
38 |         shift
39 |         ;;
40 |         *)
41 |         ;;
42 |     esac
43 | done
44 | 
45 | sample_nums=$default_sample_nums
46 | samples_per_gpu=$((sample_nums / np))
47 | add_label=${add_label:-$default_add_label}
48 | echo "Sample numbers: $sample_nums"
49 | echo "Add label: $add_label"
50 | echo "Exist time prefix: $exist_time_prefix"
51 | 
52 | cmd_template="DPM_TQDM=True python scripts/inference_geneval_diffusers.py \
53 |     --model_path=$model_paths \
54 |     --gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}"
55 | if [ -n "${add_label}" ]; then
56 |     cmd_template="${cmd_template} --add_label ${add_label}"
57 | fi
58 | 
59 | echo "==================== inferencing ===================="
60 | for gpu_id in $(seq 0 $((np - 1))); do
61 |   start_index=$((gpu_id * samples_per_gpu))
62 |   end_index=$((start_index + samples_per_gpu))
63 |   if [ $gpu_id -eq $((np - 1)) ]; then
64 |     end_index=$sample_nums
65 |   fi
66 | 
 67 |   # $config_file is unused by the template; $model_paths is read directly
 68 |   cmd="$cmd_template"
69 |   cmd="${cmd//\{gpu_id\}/$gpu_id}"
70 |   cmd="${cmd//\{start_index\}/$start_index}"
71 |   cmd="${cmd//\{end_index\}/$end_index}"
72 | 
73 |   echo "Running on GPU $gpu_id: samples $start_index to $end_index"
74 |   eval CUDA_VISIBLE_DEVICES=$gpu_id $cmd &
75 | done
76 | wait
77 | 
 78 | echo "inference done"
79 | 


--------------------------------------------------------------------------------
/scripts/style.css:
--------------------------------------------------------------------------------
 1 | /*.gradio-container{width:680px!important}*/
 2 | /* style.css */
 3 | .gradio_group, .gradio_row, .gradio_column {
 4 |     display: flex;
 5 |     flex-direction: row;
 6 |     justify-content: flex-start;
 7 |     align-items: flex-start;
 8 |     flex-wrap: wrap;
 9 | }
10 | 


--------------------------------------------------------------------------------
/tests/bash/entry.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | 
 5 | echo "Testing inference"
 6 | bash tests/bash/inference/test_inference.sh
 7 | 
 8 | echo "Testing training"
 9 | bash tests/bash/training/test_training_all.sh
10 | 


--------------------------------------------------------------------------------
/tests/bash/inference/test_inference.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | python scripts/inference.py \
 5 |     --config=configs/sana_config/1024ms/Sana_600M_img1024.yaml \
 6 |     --model_path=hf://Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth
 7 | 
 8 | python scripts/inference.py \
 9 |     --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
10 |     --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
11 | 
12 | python tools/controlnet/inference_controlnet.py \
13 |     --config=configs/sana_controlnet_config/Sana_600M_img1024_controlnet.yaml \
14 |     --model_path=hf://Efficient-Large-Model/Sana_600M_1024px_ControlNet_HED/checkpoints/Sana_600M_1024px_ControlNet_HED.pth \
15 |     --json_file=asset/controlnet/samples_controlnet.json
16 | 
17 | python scripts/inference_sana_sprint.py \
18 |     --config=configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml \
19 |     --model_path=hf://Lawrence-cj/Sana_Sprint_1600M_1024px/Sana_Sprint_1600M_1024px_36K.pth \
20 |     --txt_file=asset/samples/samples_mini.txt
21 | 
22 | python inference_video_scripts/inference_sana_video.py \
23 |     --config=configs/sana_video_config/Sana_2000M_256px_AdamW_fsdp.yaml \
24 |     --model_path=hf://Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth \
25 |     --debug=true
26 | 


--------------------------------------------------------------------------------
/tests/bash/setup_test_data.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | pip install --upgrade "huggingface-hub<1.0"
 5 | 
 6 | # download test data
 7 | mkdir -p data/data_public
 8 | hf download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public
 9 | hf download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data
10 | hf download Efficient-Large-Model/video_toy_data --repo-type dataset --local-dir ./data/video_toy_data
11 | 
12 | mkdir -p output/pretrained_models
13 | hf download Wan-AI/Wan2.1-T2V-1.3B --repo-type model --local-dir ./output/pretrained_models/Wan2.1-T2V-1.3B
14 | 


--------------------------------------------------------------------------------
/tests/bash/training/test_training_all.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | mkdir -p data/data_public
 5 | hf download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public
 6 | hf download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data
 7 | hf download Efficient-Large-Model/video_toy_data --repo-type dataset --local-dir ./data/video_toy_data
 8 | 
 9 | mkdir -p output/pretrained_models
10 | hf download Efficient-Large-Model/Wan2.1-T2V-1.3B --repo-type model --local-dir ./output/pretrained_models/Wan2.1-T2V-1.3B
11 | 
12 | # test offline vae feature
13 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.load_vae_feat=true
14 | 
15 | # test online vae feature
16 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false
17 | 
18 | # test FSDP training
19 | bash train_scripts/train.sh configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
20 | 
21 | # test SANA-Sprint(sCM + LADD) training
22 | bash train_scripts/train_scm_ladd.sh configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml  --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
23 | 
24 | # test FSDP video training
25 | bash train_video_scripts/train_video_ivjoint.sh configs/sana_video_config/Sana_2000M_256px_AdamW_fsdp.yaml --np=2 --train.num_epochs=1 --train.log_interval=1 --train.train_batch_size=1
26 | 


--------------------------------------------------------------------------------
/tests/bash/training/test_training_fsdp.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | echo "Setting up test data..."
 5 | 
 6 | mkdir -p data/data_public
 7 | hf download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data
 8 | 
 9 | echo "Testing SANA-Sprint(sCM + LADD) training"
10 | bash train_scripts/train_scm_ladd.sh configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml --np=4 --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
11 | 
12 | echo "Testing FSDP training"
13 | bash train_scripts/train.sh configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml --np=2 --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
14 | 


--------------------------------------------------------------------------------
/tests/bash/training/test_training_vae.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | echo "Setting up test data..."
 5 | # bash tests/bash/setup_test_data.sh
 6 | mkdir -p data/data_public
 7 | hf download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public
 8 | 
 9 | echo "Testing offline VAE feature"
10 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --np=4 --data.load_vae_feat=true
11 | 
12 | echo "Testing online VAE feature"
13 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --np=4 --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false
14 | 


--------------------------------------------------------------------------------
/tests/bash/training/test_training_video.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | echo "Setting up test data..."
 5 | # bash tests/bash/setup_test_data.sh
 6 | hf download Efficient-Large-Model/video_toy_data --repo-type dataset --local-dir ./data/video_toy_data
 7 | 
 8 | mkdir -p output/pretrained_models
 9 | hf download Wan-AI/Wan2.1-T2V-1.3B --repo-type model --local-dir ./output/pretrained_models/Wan2.1-T2V-1.3B
10 | 
11 | echo "Testing FSDP video training"
12 | bash train_video_scripts/train_video_ivjoint.sh configs/sana_video_config/Sana_2000M_256px_AdamW_fsdp.yaml --np=2 --train.num_epochs=1 --train.log_interval=1 --train.train_batch_size=1 --train.joint_training_interval=0
13 | 


--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/tools/__init__.py


--------------------------------------------------------------------------------
/tools/controlnet/annotator/ckpts/ckpts.txt:
--------------------------------------------------------------------------------
1 | Weights here.
2 | 


--------------------------------------------------------------------------------
/tools/controlnet/annotator/util.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import random
 3 | 
 4 | import cv2
 5 | import numpy as np
 6 | 
 7 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
 8 | 
 9 | 
10 | def HWC3(x):
11 |     assert x.dtype == np.uint8
12 |     if x.ndim == 2:
13 |         x = x[:, :, None]
14 |     assert x.ndim == 3
15 |     H, W, C = x.shape
16 |     assert C == 1 or C == 3 or C == 4
17 |     if C == 3:
18 |         return x
19 |     if C == 1:
20 |         return np.concatenate([x, x, x], axis=2)
21 |     if C == 4:
22 |         color = x[:, :, 0:3].astype(np.float32)
23 |         alpha = x[:, :, 3:4].astype(np.float32) / 255.0
24 |         y = color * alpha + 255.0 * (1.0 - alpha)
25 |         y = y.clip(0, 255).astype(np.uint8)
26 |         return y
27 | 
28 | 
29 | def resize_image(input_image, resolution):
30 |     H, W, C = input_image.shape
31 |     H = float(H)
32 |     W = float(W)
33 |     k = float(resolution) / min(H, W)
34 |     H *= k
35 |     W *= k
36 |     H = int(np.round(H / 64.0)) * 64
37 |     W = int(np.round(W / 64.0)) * 64
38 |     img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
39 |     return img
40 | 
41 | 
 42 | def nms(x, t, s):  # thin edges: blur with sigma s, keep directional maxima, threshold at t
43 |     x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
44 | 
45 |     f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
46 |     f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
47 |     f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
48 |     f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
49 | 
50 |     y = np.zeros_like(x)
51 | 
52 |     for f in [f1, f2, f3, f4]:
53 |         np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
54 | 
55 |     z = np.zeros_like(y, dtype=np.uint8)
56 |     z[y > t] = 255
57 |     return z
58 | 
59 | 
60 | def make_noise_disk(H, W, C, F):
61 |     noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
62 |     noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
63 |     noise = noise[F : F + H, F : F + W]
64 |     noise -= np.min(noise)
65 |     noise /= np.max(noise)
66 |     if C == 1:
67 |         noise = noise[:, :, None]
68 |     return noise
69 | 
70 | 
71 | def min_max_norm(x):
72 |     x -= np.min(x)
73 |     x /= np.maximum(np.max(x), 1e-5)
74 |     return x
75 | 
76 | 
77 | def safe_step(x, step=2):
78 |     y = x.astype(np.float32) * float(step + 1)
79 |     y = y.astype(np.int32).astype(np.float32) / float(step)
80 |     return y
81 | 
82 | 
83 | def img2mask(img, H, W, low=10, high=90):
84 |     assert img.ndim == 3 or img.ndim == 2
85 |     assert img.dtype == np.uint8
86 | 
87 |     if img.ndim == 3:
88 |         y = img[:, :, random.randrange(0, img.shape[2])]
89 |     else:
90 |         y = img
91 | 
92 |     y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
93 | 
94 |     if random.uniform(0, 1) < 0.5:
95 |         y = 255 - y
96 | 
97 |     return y < np.percentile(y, random.randrange(low, high))
98 | 


--------------------------------------------------------------------------------
/tools/controlnet/utils.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | 
 3 | import cv2
 4 | import numpy as np
 5 | from PIL import Image
 6 | from torchvision import transforms as T
 7 | from torchvision.transforms.functional import InterpolationMode
 8 | 
 9 | from tools.controlnet.annotator.hed import HEDdetector
10 | from tools.controlnet.annotator.util import HWC3, nms, resize_image
11 | 
12 | preprocessor = None
13 | 
14 | 
15 | def transform_control_signal(control_signal, hw):
16 |     if isinstance(control_signal, str):
17 |         control_signal = Image.open(control_signal)
18 |     elif isinstance(control_signal, Image.Image):
 19 |         pass  # already a PIL image
20 |     elif isinstance(control_signal, np.ndarray):
21 |         control_signal = Image.fromarray(control_signal)
22 |     else:
23 |         raise ValueError("control_signal must be a path or a PIL.Image.Image or a numpy array")
24 | 
25 |     transform = T.Compose(
26 |         [
27 |             T.Lambda(lambda img: img.convert("RGB")),
28 |             T.Resize((int(hw[0, 0]), int(hw[0, 1])), interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
29 |             T.CenterCrop((int(hw[0, 0]), int(hw[0, 1]))),
30 |             T.ToTensor(),
31 |             T.Normalize([0.5], [0.5]),
32 |         ]
33 |     )
34 |     return transform(control_signal).unsqueeze(0)
35 | 
36 | 
37 | def get_scribble_map(input_image, det, detect_resolution=512, thickness=None):
38 |     """
39 |     Generate scribble map from input image
40 | 
41 |     Args:
42 |         input_image: Input image (numpy array, HWC format)
43 |         det: Detector type ('Scribble_HED', 'Scribble_PIDI', 'None')
44 |         detect_resolution: Processing resolution
45 |         thickness: Line thickness (between 0-24, None for random)
46 | 
47 |     Returns:
48 |         Processed scribble map
49 |     """
50 |     global preprocessor
51 | 
52 |     # Initialize detector
53 |     if "HED" in det and not isinstance(preprocessor, HEDdetector):
54 |         preprocessor = HEDdetector()
55 | 
56 |     input_image = HWC3(input_image)
57 | 
58 |     if det == "None":
59 |         detected_map = input_image.copy()
60 |     else:
61 |         # Generate scribble map
62 |         detected_map = preprocessor(resize_image(input_image, detect_resolution))
63 |         detected_map = HWC3(detected_map)
64 | 
65 |         # Post-processing
66 |         detected_map = nms(detected_map, 127, 3.0)
67 |         detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
68 |         detected_map[detected_map > 4] = 255
69 |         detected_map[detected_map < 255] = 0
70 | 
71 |         # Control line thickness
72 |         if thickness is None:
73 |             thickness = random.randint(0, 24)  # Random thickness, including 0
74 |         if thickness == 0:
75 |             # Use erosion operation to get thinner lines
76 |             kernel = np.ones((4, 4), np.uint8)
77 |             detected_map = cv2.erode(detected_map, kernel, iterations=1)
78 |         elif thickness > 1:
79 |             kernel_size = thickness // 2
80 |             kernel = np.ones((kernel_size, kernel_size), np.uint8)
81 |             detected_map = cv2.dilate(detected_map, kernel, iterations=1)
82 | 
83 |     return detected_map
84 | 
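A small usage sketch; the image path is illustrative (`asset/apple.webp` ships with the repo):

```python
import numpy as np
from PIL import Image

img = np.asarray(Image.open("asset/apple.webp").convert("RGB"))
scribble = get_scribble_map(img, det="Scribble_HED", detect_resolution=512, thickness=4)
```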


--------------------------------------------------------------------------------
/tools/convert_ImgDataset_to_WebDatasetMS_format.py:
--------------------------------------------------------------------------------
 1 | # @Author:  Pevernow (wzy3450354617@gmail.com)
 2 | # @Date:    2025/1/5
 3 | # @License: (Follow the main project)
 4 | import json
 5 | import os
 6 | import tarfile
 7 | 
 8 | from PIL import Image, PngImagePlugin
 9 | 
10 | PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024  # Increase maximum size for text chunks
11 | 
12 | 
13 | def process_data(input_dir, output_tar_name="output.tar"):
14 |     """
15 |     Processes a directory containing PNG files, generates corresponding JSON files,
16 |     and packages all files into a TAR file. It also counts the number of processed PNG images,
17 |     and saves the height and width of each PNG file to the JSON.
18 | 
19 |     Args:
20 |         input_dir (str): The input directory containing PNG files.
21 |         output_tar_name (str): The name of the output TAR file (default is "output.tar").
22 |     """
23 |     png_count = 0
24 |     json_files_created = []
25 | 
26 |     for filename in os.listdir(input_dir):
27 |         if filename.lower().endswith(".png"):
28 |             png_count += 1
29 |             base_name = filename[:-4]  # Remove the ".png" extension
30 |             txt_filename = os.path.join(input_dir, base_name + ".txt")
31 |             json_filename = base_name + ".json"
32 |             json_filepath = os.path.join(input_dir, json_filename)
33 |             png_filepath = os.path.join(input_dir, filename)
34 | 
35 |             if os.path.exists(txt_filename):
36 |                 try:
37 |                     # Get the dimensions of the PNG image
38 |                     with Image.open(png_filepath) as img:
39 |                         width, height = img.size
40 | 
41 |                     with open(txt_filename, encoding="utf-8") as f:
42 |                         caption_content = f.read().strip()
43 | 
44 |                     data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height}
45 | 
46 |                     with open(json_filepath, "w", encoding="utf-8") as outfile:
47 |                         json.dump(data, outfile, indent=4, ensure_ascii=False)
48 | 
49 |                     print(f"Generated: {json_filename}")
50 |                     json_files_created.append(json_filepath)
51 | 
52 |                 except Exception as e:
53 |                     print(f"Error processing file {filename}: {e}")
54 |             else:
55 |                 print(f"Warning: No corresponding TXT file found for {filename}.")
56 | 
57 |     # Create a TAR file and include all files
58 |     with tarfile.open(output_tar_name, "w") as tar:
59 |         for item in os.listdir(input_dir):
60 |             item_path = os.path.join(input_dir, item)
61 |             tar.add(item_path, arcname=item)  # arcname maintains the relative path of the file in the tar
62 | 
63 |     print(f"\nAll files have been packaged into: {output_tar_name}")
64 |     print(f"Number of PNG images processed: {png_count}")
65 | 
66 | 
67 | if __name__ == "__main__":
68 |     input_directory = input("Please enter the directory path containing PNG and TXT files: ")
69 |     output_tar_filename = (
70 |         input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar"
71 |     )
72 |     process_data(input_directory, output_tar_filename)
73 | 


--------------------------------------------------------------------------------
/tools/convert_py_to_yaml.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import yaml
 4 | 
 5 | 
 6 | def convert_py_to_yaml(py_file_path):
 7 |     with open(py_file_path, encoding="utf-8") as py_file:
 8 |         py_content = py_file.read()
 9 | 
10 |     local_vars = {}
11 |     exec(py_content, {}, local_vars)
12 | 
13 |     yaml_file_path = os.path.splitext(py_file_path)[0] + ".yaml"
14 | 
15 |     with open(yaml_file_path, "w", encoding="utf-8") as yaml_file:
16 |         yaml.dump(local_vars, yaml_file, default_flow_style=False, allow_unicode=True)
17 | 
18 | 
19 | def process_directory(path):
20 |     for root, dirs, files in os.walk(path):
21 |         for filename in files:
22 |             if filename.endswith(".py"):
23 |                 py_file_path = os.path.join(root, filename)
24 |                 convert_py_to_yaml(py_file_path)
25 |                 print(f"convert {py_file_path} to YAML format")
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     process_directory("../configs/")
30 | 


--------------------------------------------------------------------------------
/tools/create_wids_metadata.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | 
17 | import json
18 | import os
19 | import sys
20 | import tarfile
21 | from glob import glob
22 | 
23 | from tqdm.contrib.concurrent import process_map
24 | 
25 | """
26 | python tools/create_wids_metadata.py /path/to/tar/dir > /path/to/wids-meta.json
27 | """
28 | 
29 | d = sys.argv[1]
30 | 
31 | 
32 | def process(t):
 33 |     d = {}  # basename -> 1 if at least two files share it, else 0
34 |     with tarfile.open(t, "r") as tar:
35 |         for f in tar:
36 |             n, e = os.path.splitext(f.name)
37 |             if e == ".jpg" or e == ".jpeg" or e == ".png" or e == ".json" or e == ".npy":
38 |                 if n in d:
39 |                     d[n] = 1
40 |                 else:
41 |                     d[n] = 0
42 |     s = os.path.getsize(t)
 43 |     i = sum(d.values())  # complete samples: basenames with >=2 member files (e.g. image + json)
44 |     t = os.path.basename(t)
45 |     return {"url": t, "nsamples": i, "filesize": s}
46 | 
47 | 
48 | print(
49 |     json.dumps(
50 |         {
51 |             "name": "sana-dev",
52 |             "__kind__": "SANA-WebDataset",
53 |             "wids_version": 1,
54 |             "shardlist": sorted(
55 |                 process_map(process, glob(f"{d}/*.tar"), chunksize=1, max_workers=os.cpu_count()),
56 |                 key=lambda x: x["url"],
57 |             ),
58 |         },
59 |         indent=4,
60 |     ),
61 |     end="",
62 | )
63 | 


--------------------------------------------------------------------------------
/tools/download.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #      http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | """
17 | Functions for downloading pre-trained Sana models
18 | """
19 | import argparse
20 | import os
21 | 
22 | import torch
23 | from termcolor import colored
24 | from torchvision.datasets.utils import download_url
25 | 
26 | from sana.tools import hf_download_or_fpath
27 | 
28 | pretrained_models = {}
29 | 
30 | 
31 | def find_model(model_name):
32 |     """
 33 |     Finds a pre-trained Sana model, downloading it if necessary. Alternatively, loads a model from a local path.
34 |     """
 35 |     if model_name in pretrained_models:  # Find/download our pre-trained Sana checkpoints
36 |         return download_model(model_name)
37 | 
38 |     # Load a custom Sana checkpoint:
39 |     print(colored(f"[Sana] Loading model from {model_name}", attrs=["bold"]))
40 |     model_name = hf_download_or_fpath(model_name)
41 |     assert os.path.isfile(model_name), f"Could not find Sana checkpoint at {model_name}"
42 |     print(colored(f"[Sana] Loaded model from {model_name}", attrs=["bold"]))
43 |     if model_name.endswith(".safetensors"):
44 |         import safetensors
45 | 
46 |         return {"state_dict": safetensors.torch.load_file(model_name, device="cpu")}
47 |     elif model_name.endswith(".safetensors.index.json"):
48 |         import json
49 | 
50 |         import safetensors
51 | 
52 |         index = json.load(open(model_name))["weight_map"]
53 |         safetensors_list = set(index.values())
54 |         state_dict = {}
55 |         for safetensors_path in safetensors_list:
56 |             state_dict.update(
57 |                 safetensors.torch.load_file(os.path.join(os.path.dirname(model_name), safetensors_path), device="cpu")
58 |             )
59 |         return {"state_dict": state_dict}
60 |     else:
61 |         return torch.load(model_name, map_location=lambda storage, loc: storage)
62 | 
63 | 
64 | def download_model(model_name):
65 |     """
66 |     Downloads a pre-trained Sana model from the web.
67 |     """
68 |     assert model_name in pretrained_models
69 |     local_path = f"output/pretrained_models/{model_name}"
70 |     if not os.path.isfile(local_path):
71 |         hf_endpoint = os.environ.get("HF_ENDPOINT")
72 |         if hf_endpoint is None:
73 |             hf_endpoint = "https://huggingface.co"
74 |         os.makedirs("output/pretrained_models", exist_ok=True)
75 |         web_path = f""
76 |         download_url(web_path, "output/pretrained_models/")
77 |     model = torch.load(local_path, map_location=lambda storage, loc: storage)
78 |     return model
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     parser = argparse.ArgumentParser()
83 |     parser.add_argument("--model_names", nargs="+", type=str, default=pretrained_models)
84 |     args = parser.parse_args()
85 |     model_names = args.model_names
86 |     model_names = set(model_names)
87 | 
88 |     # Download Sana checkpoints
89 |     for model in model_names:
90 |         download_model(model)
91 |     print("Done.")
92 | 
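For instance, loading a Hub checkpoint through the `hf://` scheme; whether the loaded object carries a `state_dict` key depends on how the checkpoint was saved, so the fallback below is a hedge:

```python
ckpt = find_model("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
state_dict = ckpt.get("state_dict", ckpt)  # plain state dicts load as-is
```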


--------------------------------------------------------------------------------
/tools/inference_scaling/nvila_sana_pick.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | set -e
 3 | 
 4 | sana_dir=$1
 5 | number_of_files=$2
 6 | pick_number=$3
 7 | # calculate number of GPU to use in this machine
 8 | num_gpu=$(nvidia-smi -L | wc -l)
 9 | echo "sana_dir: $sana_dir, number_of_files: $number_of_files, pick_number: $pick_number, num_gpu: $num_gpu"
10 | # start idx iterate from 0 * (552//8), 1 * (552//8), 2 * (552//8), 3 * (552//8), 4 * (552//8), 5 * (552//8), 6 * (552//8), 7 * (552//8)
11 | # end idx iterate from 1 * (552//8), 2 * (552//8), 3 * (552//8), 4 * (552//8), 5 * (552//8), 6 * (552//8), 7 * (552//8), 552
12 | for idx in $(seq 0 $((num_gpu - 1))); do
13 |     start_idx=$((idx * (552 / num_gpu)))
14 |     end_idx=$((start_idx + 552 / num_gpu))
15 |     if [ $idx -eq $((num_gpu - 1)) ]; then
16 |         end_idx=552
17 |     fi
18 | 
19 |     echo "CUDA_VISIBLE_DEVICES=$idx python tools/inference_scaling/nvila_sana_pick.py --start_idx $start_idx --end_idx $end_idx --base_dir $sana_dir --number_of_files $number_of_files --pick_number $pick_number &"
20 |     CUDA_VISIBLE_DEVICES=$idx python tools/inference_scaling/nvila_sana_pick.py --start_idx $start_idx --end_idx $end_idx --base_dir $sana_dir --number_of_files $number_of_files --pick_number $pick_number &
21 | done
22 | wait
23 | 
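The loop above hands each GPU a contiguous chunk of the 552 prompts, with the last GPU's end index clamped to 552 so the remainder is not dropped. The same arithmetic in a small Python sketch (552 and the GPU count mirror the script's values; the function name is ours):

```python
def split_ranges(total=552, num_gpu=8):
    # each GPU gets total // num_gpu items; the last one also absorbs the remainder
    chunk = total // num_gpu
    return [
        (idx * chunk, total if idx == num_gpu - 1 else (idx + 1) * chunk)
        for idx in range(num_gpu)
    ]


print(split_ranges())  # [(0, 69), (69, 138), ..., (483, 552)]
```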


--------------------------------------------------------------------------------
/tools/metrics/clip-score/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | pip-wheel-metadata/
 24 | share/python-wheels/
 25 | *.egg-info/
 26 | .installed.cfg
 27 | *.egg
 28 | MANIFEST
 29 | 
 30 | # PyInstaller
 31 | #  Usually these files are written by a python script from a template
 32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 | 
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 | 
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .nox/
 44 | .coverage
 45 | .coverage.*
 46 | .cache
 47 | nosetests.xml
 48 | coverage.xml
 49 | *.cover
 50 | .hypothesis/
 51 | .pytest_cache/
 52 | 
 53 | # Translations
 54 | *.mo
 55 | *.pot
 56 | 
 57 | # Django stuff:
 58 | *.log
 59 | local_settings.py
 60 | db.sqlite3
 61 | 
 62 | # Flask stuff:
 63 | instance/
 64 | .webassets-cache
 65 | 
 66 | # Scrapy stuff:
 67 | .scrapy
 68 | 
 69 | # Sphinx documentation
 70 | docs/_build/
 71 | 
 72 | # PyBuilder
 73 | target/
 74 | 
 75 | # Jupyter Notebook
 76 | .ipynb_checkpoints
 77 | 
 78 | # IPython
 79 | profile_default/
 80 | ipython_config.py
 81 | 
 82 | # pyenv
 83 | .python-version
 84 | 
 85 | # celery beat schedule file
 86 | celerybeat-schedule
 87 | 
 88 | # SageMath parsed files
 89 | *.sage.py
 90 | 
 91 | # Environments
 92 | .env
 93 | .venv
 94 | env/
 95 | venv/
 96 | ENV/
 97 | env.bak/
 98 | venv.bak/
 99 | 
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 | 
104 | # Rope project settings
105 | .ropeproject
106 | 
107 | # mkdocs documentation
108 | /site
109 | 
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 | 
115 | # Pyre type checker
116 | .pyre/
117 | 


--------------------------------------------------------------------------------
/tools/metrics/clip-score/setup.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import setuptools
 4 | 
 5 | 
 6 | def read(rel_path):
 7 |     base_path = os.path.abspath(os.path.dirname(__file__))
 8 |     with open(os.path.join(base_path, rel_path)) as f:
 9 |         return f.read()
10 | 
11 | 
12 | def get_version(rel_path):
13 |     for line in read(rel_path).splitlines():
14 |         if line.startswith("__version__"):
15 |             delim = '"' if '"' in line else "'"
16 |             return line.split(delim)[1]
17 | 
18 |     raise RuntimeError("Unable to find version string.")
19 | 
20 | 
21 | if __name__ == "__main__":
22 |     setuptools.setup(
23 |         name="clip-score",
24 |         version=get_version(os.path.join("src", "clip_score", "__init__.py")),
25 |         author="Taited",
26 |         author_email="taited9160@gmail.com",
27 |         description=("Package for calculating CLIP-Score" " using PyTorch"),
28 |         long_description=read("README.md"),
29 |         long_description_content_type="text/markdown",
30 |         url="https://github.com/taited/clip-score",
31 |         package_dir={"": "src"},
32 |         packages=setuptools.find_packages(where="src"),
33 |         classifiers=[
34 |             "Programming Language :: Python :: 3",
35 |             "License :: OSI Approved :: Apache Software License",
36 |         ],
37 |         python_requires=">=3.5",
38 |         entry_points={
39 |             "console_scripts": [
40 |                 "clip-score = clip_score.clip_score:main",
41 |             ],
42 |         },
43 |         install_requires=[
44 |             "numpy",
45 |             "pillow",
46 |             "torch>=1.7.1",
47 |             "torchvision>=0.8.2",
48 |             "ftfy",
49 |             "regex",
50 |             "tqdm",
51 |         ],
52 |         extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]},
53 |     )
54 | 


--------------------------------------------------------------------------------
/tools/metrics/clip-score/src/clip_score/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.1"
2 | 


--------------------------------------------------------------------------------
/tools/metrics/clip-score/src/clip_score/__main__.py:
--------------------------------------------------------------------------------
1 | import clip_score.clip_score
2 | 
3 | clip_score.clip_score.main()
4 | 


--------------------------------------------------------------------------------
/tools/metrics/dpg_bench/requirements.txt:
--------------------------------------------------------------------------------
 1 | accelerate
 2 | addict
 3 | 
 4 | # for modelscope
 5 | cloudpickle
 6 | datasets==2.21.0
 7 | decord>=0.6.0
 8 | diffusers
 9 | ftfy>=6.0.3
10 | librosa==0.10.1
11 | modelscope[multi-modal]
12 | numpy
13 | opencv-python
14 | oss2
15 | pandas
16 | pillow
17 | # compatible with taming-transformers-rom1504
18 | rapidfuzz
19 | # rouge-score was recently updated from 0.0.4 to 0.0.7,
20 | # which introduced compatibility issues that are still being investigated
21 | rouge_score<=0.0.4
22 | safetensors
23 | simplejson
24 | sortedcontainers
25 | # scikit-video
26 | soundfile
27 | taming-transformers-rom1504
28 | tiktoken
29 | timm
30 | tokenizers
31 | torchvision
32 | tqdm
33 | transformers
34 | transformers_stream_generator
35 | unicodedata2
36 | wandb
37 | zhconv
38 | # fairseq needs to be built from source: https://github.com/facebookresearch/fairseq
39 | 


--------------------------------------------------------------------------------
/tools/metrics/geneval/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Dhruba Ghosh
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/tools/metrics/geneval/evaluation/download_models.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Download Mask2Former object detection config and weights
 4 | 
 5 | if [ ! -z "$1" ]
 6 | then
 7 |     mkdir -p "$1"
 8 |     echo "Downloading mask2former for GenEval"
 9 |     wget https://download.openmmlab.com/mmdetection/v2.0/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco_20220504_001756-743b7d99.pth -O "$1/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.pth"
10 | fi
11 | 


--------------------------------------------------------------------------------
/tools/metrics/geneval/evaluation/object_names.txt:
--------------------------------------------------------------------------------
 1 | person
 2 | bicycle
 3 | car
 4 | motorcycle
 5 | airplane
 6 | bus
 7 | train
 8 | truck
 9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | potted plant
60 | bed
61 | dining table
62 | toilet
63 | tv
64 | laptop
65 | computer mouse
66 | tv remote
67 | computer keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 | 


--------------------------------------------------------------------------------
/tools/metrics/geneval/evaluation/summary_scores.py:
--------------------------------------------------------------------------------
 1 | # Get results of evaluation
 2 | 
 3 | import argparse
 4 | import os
 5 | 
 6 | import numpy as np
 7 | import pandas as pd
 8 | 
 9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("filename", type=str)
11 | args = parser.parse_args()
12 | 
13 | # Load classnames
14 | 
15 | with open(os.path.join(os.path.dirname(__file__), "object_names.txt")) as cls_file:
16 |     classnames = [line.strip() for line in cls_file]
17 |     cls_to_idx = {"_".join(cls.split()): idx for idx, cls in enumerate(classnames)}
18 | 
19 | # Load results
20 | 
21 | df = pd.read_json(args.filename, orient="records", lines=True)
22 | 
23 | # Measure overall success
24 | 
25 | print("Summary")
26 | print("=======")
27 | print(f"Total images: {len(df)}")
28 | print(f"Total prompts: {len(df.groupby('metadata'))}")
29 | print(f"% correct images: {df['correct'].mean():.2%}")
30 | print(f"% correct prompts: {df.groupby('metadata')['correct'].any().mean():.2%}")
31 | print()
32 | 
33 | # By group
34 | 
35 | task_scores = []
36 | 
37 | print("Task breakdown")
38 | print("==============")
39 | for tag, task_df in df.groupby("tag", sort=False):
40 |     task_scores.append(task_df["correct"].mean())
41 |     print(f"{tag:<16} = {task_df['correct'].mean():.2%} ({task_df['correct'].sum()} / {len(task_df)})")
42 | print()
43 | 
44 | print(f"Overall score (avg. over tasks): {np.mean(task_scores):.5f}")
45 | 


--------------------------------------------------------------------------------
/tools/metrics/geneval/geneval_env.md:
--------------------------------------------------------------------------------
 1 | The installation process follows https://github.com/djghosh13/geneval/issues/12. Thanks to the community!
 2 | 
 3 | # Cloning the repository
 4 | 
 5 | ```bash
 6 | git clone https://github.com/djghosh13/geneval.git
 7 | 
 8 | cd geneval
 9 | conda create -n geneval python=3.8.10 -y
10 | conda activate geneval
11 | ```
12 | 
13 | # Installing dependencies
14 | 
15 | ```bash
16 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
17 | pip install open-clip-torch==2.26.1
18 | pip install clip-benchmark
19 | pip install -U openmim
20 | pip install einops
21 | python -m pip install lightning
22 | pip install diffusers["torch"] transformers
23 | pip install tomli
24 | pip install platformdirs
25 | pip install --upgrade setuptools
26 | ```
27 | 
28 | # Installing mmengine and mmcv
29 | 
30 | ```bash
31 | mim install mmengine mmcv-full==1.7.2
32 | ```
33 | 
34 | # Installing mmdetection
35 | 
36 | ```bash
37 | git clone https://github.com/open-mmlab/mmdetection.git
38 | cd mmdetection; git checkout 2.x
39 | pip install -v -e .
40 | ```
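After these steps, a quick sanity check (our suggestion, not part of the original guide) confirms the pinned stack imports cleanly:

```python
# verify the GenEval environment before running evaluation
import mmcv
import mmdet
import torch

print("torch:", torch.__version__, "cuda available:", torch.cuda.is_available())
print("mmcv:", mmcv.__version__)    # expected 1.7.2 per the mim command above
print("mmdet:", mmdet.__version__)  # expected a 2.x release
```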
41 | 
42 | <!-- conda create -n geneval python==3.9 -->
43 | 
44 | <!--
45 | conda install -c nvidia cuda-toolkit -y
46 | ./evaluation/download_models.sh output/
47 | 
48 | # install mmdetection
49 | 
50 | git clone https://github.com/open-mmlab/mmdetection.git
51 | cd mmdetection; git checkout 2.x
52 | pip install -v -e .
53 | 
54 | pip install einops
55 | pip install accelerate==0.15.0
56 | pip install torchmetrics==0.6.0
57 | pip install transformers==4.48.0
58 | pip install pandas
59 | pip install open-clip-torch
60 | pip install clip_benchmark
61 | pip install huggingface_hub==0.24.5
62 | pip install pytorch-lightning==1.4.2
63 | 
64 | pip install openmim
65 | mim install mmcv-full==1.7.1 -->
66 | 


--------------------------------------------------------------------------------
/tools/metrics/geneval/images/geneval_figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/tools/metrics/geneval/images/geneval_figure_1.png


--------------------------------------------------------------------------------
/tools/metrics/geneval/prompts/object_names.txt:
--------------------------------------------------------------------------------
 1 | person
 2 | bicycle
 3 | car
 4 | motorcycle
 5 | airplane
 6 | bus
 7 | train
 8 | truck
 9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | potted plant
60 | bed
61 | dining table
62 | toilet
63 | tv
64 | laptop
65 | computer mouse
66 | tv remote
67 | computer keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | pip-wheel-metadata/
 24 | share/python-wheels/
 25 | *.egg-info/
 26 | .installed.cfg
 27 | *.egg
 28 | MANIFEST
 29 | 
 30 | # PyInstaller
 31 | #  Usually these files are written by a python script from a template
 32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 | 
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 | 
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .nox/
 44 | .coverage
 45 | .coverage.*
 46 | .cache
 47 | nosetests.xml
 48 | coverage.xml
 49 | *.cover
 50 | .hypothesis/
 51 | .pytest_cache/
 52 | 
 53 | # Translations
 54 | *.mo
 55 | *.pot
 56 | 
 57 | # Django stuff:
 58 | *.log
 59 | local_settings.py
 60 | db.sqlite3
 61 | 
 62 | # Flask stuff:
 63 | instance/
 64 | .webassets-cache
 65 | 
 66 | # Scrapy stuff:
 67 | .scrapy
 68 | 
 69 | # Sphinx documentation
 70 | docs/_build/
 71 | 
 72 | # PyBuilder
 73 | target/
 74 | 
 75 | # Jupyter Notebook
 76 | .ipynb_checkpoints
 77 | 
 78 | # IPython
 79 | profile_default/
 80 | ipython_config.py
 81 | 
 82 | # pyenv
 83 | .python-version
 84 | 
 85 | # celery beat schedule file
 86 | celerybeat-schedule
 87 | 
 88 | # SageMath parsed files
 89 | *.sage.py
 90 | 
 91 | # Environments
 92 | .env
 93 | .venv
 94 | env/
 95 | venv/
 96 | ENV/
 97 | env.bak/
 98 | venv.bak/
 99 | 
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 | 
104 | # Rope project settings
105 | .ropeproject
106 | 
107 | # mkdocs documentation
108 | /site
109 | 
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 | 
115 | # Pyre type checker
116 | .pyre/
117 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/CHANGELOG.md:
--------------------------------------------------------------------------------
 1 | # Changelog
 2 | 
 3 | ## \[0.3.0\] - 2023-01-05
 4 | 
 5 | ### Added
 6 | 
 7 | - Add argument `--save-stats` allowing to compute dataset statistics and save them as an `.npz` file ([#80](https://github.com/mseitzer/pytorch-fid/pull/80)). The `.npz` file can be used in subsequent FID computations instead of recomputing the dataset statistics. This option can be used in the following way: `python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile`.
 8 | 
 9 | ### Fixed
10 | 
11 | - Do not use `os.sched_getaffinity` to get number of available CPUs on Windows, as it is not available there ([232b3b14](https://github.com/mseitzer/pytorch-fid/commit/232b3b1468800102fcceaf6f2bb8977811fc991a), [#84](https://github.com/mseitzer/pytorch-fid/issues/84)).
12 | - Do not use Inception model argument `pretrained`, as it was deprecated in torchvision 0.13 ([#88](https://github.com/mseitzer/pytorch-fid/pull/88)).
13 | 
14 | ## \[0.2.1\] - 2021-10-10
15 | 
16 | ### Added
17 | 
18 | - Add argument `--num-workers` to select number of dataloader processes ([#66](https://github.com/mseitzer/pytorch-fid/pull/66)). Defaults to 8 or the number of available CPUs if less than 8 CPUs are available.
19 | 
20 | ### Fixed
21 | 
22 | - Fixed package setup to work under Windows ([#55](https://github.com/mseitzer/pytorch-fid/pull/55), [#72](https://github.com/mseitzer/pytorch-fid/issues/72))
23 | 
24 | ## \[0.2.0\] - 2020-11-30
25 | 
26 | ### Added
27 | 
28 | - Load images using a Pytorch dataloader, which should result in a speed-up. ([#47](https://github.com/mseitzer/pytorch-fid/pull/47))
29 | - Support more image extensions ([#53](https://github.com/mseitzer/pytorch-fid/pull/53))
30 | - Improve tooling by setting up Nox, add linting and test support ([#52](https://github.com/mseitzer/pytorch-fid/pull/52))
31 | - Add some unit tests
32 | 
33 | ## \[0.1.1\] - 2020-08-16
34 | 
35 | ### Fixed
36 | 
37 | - Fixed software license string in `setup.py`
38 | 
39 | ## \[0.1.0\] - 2020-08-16
40 | 
41 | Initial release as a pypi package. Use `pip install pytorch-fid` to install.
42 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/noxfile.py:
--------------------------------------------------------------------------------
 1 | import nox
 2 | 
 3 | LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py")
 4 | 
 5 | 
 6 | @nox.session
 7 | def lint(session):
 8 |     session.install("flake8")
 9 |     session.install("flake8-bugbear")
10 |     session.install("flake8-isort")
11 | 
12 |     args = session.posargs or LOCATIONS
13 |     session.run("flake8", *args)
14 | 
15 | 
16 | @nox.session
17 | def tests(session):
18 |     session.install(".")
19 |     session.install("pytest")
20 |     session.install("pytest-mock")
21 |     session.run("pytest", *session.posargs)
22 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select=F,W,E,I,B,B9
3 | ignore=W503,B950
4 | max-line-length=79
5 | 
6 | [isort]
7 | multi_line_output=1
8 | line_length=79
9 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/setup.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import setuptools
 4 | 
 5 | 
 6 | def read(rel_path):
 7 |     base_path = os.path.abspath(os.path.dirname(__file__))
 8 |     with open(os.path.join(base_path, rel_path)) as f:
 9 |         return f.read()
10 | 
11 | 
12 | def get_version(rel_path):
13 |     for line in read(rel_path).splitlines():
14 |         if line.startswith("__version__"):
15 |             # __version__ = "0.9"
16 |             delim = '"' if '"' in line else "'"
17 |             return line.split(delim)[1]
18 | 
19 |     raise RuntimeError("Unable to find version string.")
20 | 
21 | 
22 | if __name__ == "__main__":
23 |     setuptools.setup(
24 |         name="pytorch-fid",
25 |         version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")),
26 |         author="Max Seitzer",
27 |         description=("Package for calculating Frechet Inception Distance (FID)" " using PyTorch"),
28 |         long_description=read("README.md"),
29 |         long_description_content_type="text/markdown",
30 |         url="https://github.com/mseitzer/pytorch-fid",
31 |         package_dir={"": "src"},
32 |         packages=setuptools.find_packages(where="src"),
33 |         classifiers=[
34 |             "Programming Language :: Python :: 3",
35 |             "License :: OSI Approved :: Apache Software License",
36 |         ],
37 |         python_requires=">=3.5",
38 |         entry_points={
39 |             "console_scripts": [
40 |                 "pytorch-fid = pytorch_fid.fid_score:main",
41 |             ],
42 |         },
43 |         install_requires=["numpy", "pillow", "scipy", "torch>=1.0.1", "torchvision>=0.2.2"],
44 |         extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]},
45 |     )
46 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/src/pytorch_fid/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.0"
2 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/src/pytorch_fid/__main__.py:
--------------------------------------------------------------------------------
1 | import pytorch_fid.fid_score
2 | 
3 | pytorch_fid.fid_score.main()
4 | 


--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/tests/test_fid_score.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import pytest
 3 | import torch
 4 | from PIL import Image
 5 | from pytorch_fid import fid_score, inception
 6 | 
 7 | 
 8 | @pytest.fixture
 9 | def device():
10 |     return torch.device("cpu")
11 | 
12 | 
13 | def test_calculate_fid_given_statistics(mocker, tmp_path, device):
14 |     dim = 2048
15 |     m1, m2 = np.zeros((dim,)), np.ones((dim,))
16 |     sigma = np.eye(dim)
17 | 
18 |     def dummy_statistics(path, model, batch_size, dims, device, num_workers):
19 |         if path.endswith("1"):
20 |             return m1, sigma
21 |         elif path.endswith("2"):
22 |             return m2, sigma
23 |         else:
24 |             raise ValueError
25 | 
26 |     mocker.patch("pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics)
27 | 
28 |     dir_names = ["1", "2"]
29 |     paths = []
30 |     for name in dir_names:
31 |         path = tmp_path / name
32 |         path.mkdir()
33 |         paths.append(str(path))
34 | 
35 |     fid_value = fid_score.calculate_fid_given_paths(paths, batch_size=dim, device=device, dims=dim, num_workers=0)
36 | 
37 |     # With equal covariances the FID trace term vanishes, leaving the squared norm of the mean difference
38 |     assert fid_value == np.sum((m1 - m2) ** 2)
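The assertion follows from the Fréchet distance formula; when both covariances equal the same $\Sigma$, the trace term cancels:

$$\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \mathrm{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big) = \lVert \mu_1 - \mu_2 \rVert_2^2 + \mathrm{Tr}(2\Sigma - 2\Sigma) = \lVert \mu_1 - \mu_2 \rVert_2^2.$$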
39 | 
40 | 
41 | def test_compute_statistics_of_path(mocker, tmp_path, device):
42 |     model = mocker.MagicMock(inception.InceptionV3)()
43 |     model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]
44 | 
45 |     size = (4, 4, 3)
46 |     arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
47 |     images = [(arr * 255).astype(np.uint8) for arr in arrays]
48 | 
49 |     paths = []
50 |     for idx, image in enumerate(images):
51 |         paths.append(str(tmp_path / f"{idx}.png"))
52 |         Image.fromarray(image, mode="RGB").save(paths[-1])
53 | 
54 |     stats = fid_score.compute_statistics_of_path(
55 |         str(tmp_path), model, batch_size=len(images), dims=3, device=device, num_workers=0
56 |     )
57 | 
58 |     assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
59 |     assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
60 | 
61 | 
62 | def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
63 |     model = mocker.MagicMock(inception.InceptionV3)()
64 | 
65 |     mu = np.random.randn(5)
66 |     sigma = np.random.randn(5, 5)
67 | 
68 |     path = tmp_path / "stats.npz"
69 |     with path.open("wb") as f:
70 |         np.savez(f, mu=mu, sigma=sigma)
71 | 
72 |     stats = fid_score.compute_statistics_of_path(str(path), model, batch_size=1, dims=5, device=device, num_workers=0)
73 | 
74 |     assert np.allclose(stats[0], mu)
75 |     assert np.allclose(stats[1], sigma)
76 | 
77 | 
78 | def test_image_types(tmp_path):
79 |     in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
80 |     in_image = Image.fromarray(in_arr, mode="RGB")
81 | 
82 |     paths = []
83 |     for ext in fid_score.IMAGE_EXTENSIONS:
84 |         paths.append(str(tmp_path / f"img.{ext}"))
85 |         in_image.save(paths[-1])
86 | 
87 |     dataset = fid_score.ImagePathDataset(paths)
88 | 
89 |     for img in dataset:
90 |         assert np.allclose(np.array(img), in_arr)
91 | 


--------------------------------------------------------------------------------
/tools/metrics/utils.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | 
 3 | 
 4 | def tracker(args, result_dict, label="", pattern="epoch_step", metric="FID"):
 5 |     if args.report_to == "wandb":
 6 |         import wandb
 7 | 
 8 |         wandb_name = f"[{args.log_metric}]_{args.name}"
 9 |         wandb.init(project=args.tracker_project_name, name=wandb_name, resume="allow", id=wandb_name, tags=["metrics"])  # tags must be a list
10 |         run = wandb.run
11 |         if pattern == "step":
12 |             pattern = "sample_steps"
13 |         elif pattern == "epoch_step":
14 |             pattern = "step"
15 |         custom_name = f"custom_{pattern}"
16 |         run.define_metric(custom_name)
17 |         # define which metrics will be plotted against it
18 |         run.define_metric(f"{metric}_{label}", step_metric=custom_name)
19 | 
20 |         steps = []
21 |         results = []
22 | 
23 |         def extract_value(regex, exp_name):
24 |             match = re.search(regex, exp_name)
25 |             if match:
26 |                 return match.group(1)
27 |             else:
28 |                 return "unknown"
29 | 
30 |         for exp_name, result_value in result_dict.items():
31 |             if pattern == "step":
32 |                 regex = r".*step(\d+)_scale.*"
33 |                 custom_x = extract_value(regex, exp_name)
34 |             elif pattern == "sample_steps":
35 |                 regex = r".*step(\d+)_size.*"
36 |                 custom_x = extract_value(regex, exp_name)
37 |             else:
38 |                 regex = rf"{pattern}(\d+(\.\d+)?)"
39 |                 custom_x = extract_value(regex, exp_name)
40 |                 custom_x = 1 if custom_x == "unknown" else custom_x
41 | 
42 |             assert custom_x != "unknown"
43 |             steps.append(float(custom_x))
44 |             results.append(result_value)
45 | 
46 |         # log each (step, metric) pair in ascending step order
47 |         sorted_data = sorted(zip(steps, results))
48 | 
49 |         for step, result in sorted_data:
50 |             run.log({f"{metric}_{label}": result, custom_name: step})
51 |     else:
52 |         print(f"{args.report_to} is not supported")
53 | 


--------------------------------------------------------------------------------
/train_scripts/train.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | work_dir=output/debug
 5 | np=2
 6 | 
 7 | while [[ $# -gt 0 ]]; do
 8 |     case $1 in
 9 |         --np=*)
10 |             np="${1#*=}"
11 |             shift
12 |             ;;
13 |         *.yaml)
14 |             config=$1
15 |             shift
16 |             ;;
17 |         *)
18 |             other_args+=("$1")
19 |             shift
20 |             ;;
21 |     esac
22 | done
23 | 
24 | if [[ -z "$config" ]]; then
25 |     config="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml"
26 |     echo "No yaml file specified. Set to --config_path=$config"
27 | fi
28 | 
29 | cmd="TRITON_PRINT_AUTOTUNING=1 \
30 |     torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000))  \
31 |         train_scripts/train.py \
32 |         --config_path=$config \
33 |         --work_dir=$work_dir \
34 |         --name=tmp \
35 |         --resume_from=latest \
36 |         --report_to=tensorboard \
37 |         --debug=true \
38 |         ${other_args[@]}"
39 | 
40 | echo $cmd
41 | eval $cmd
42 | 


--------------------------------------------------------------------------------
/train_scripts/train_lora.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | 
 3 | export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
 4 | export INSTANCE_DIR="data/dreambooth/dog"
 5 | export OUTPUT_DIR="trained-sana-lora"
 6 | 
 7 | accelerate launch --num_processes 4 --main_process_port 29500 --gpu_ids 0,1,2,3 \
 8 |   train_scripts/train_dreambooth_lora_sana.py \
 9 |   --pretrained_model_name_or_path=$MODEL_NAME  \
10 |   --instance_data_dir=$INSTANCE_DIR \
11 |   --output_dir=$OUTPUT_DIR \
12 |   --mixed_precision="bf16" \
13 |   --instance_prompt="a photo of sks dog" \
14 |   --resolution=1024 \
15 |   --train_batch_size=1 \
16 |   --gradient_accumulation_steps=4 \
17 |   --use_8bit_adam \
18 |   --learning_rate=1e-4 \
19 |   --report_to="wandb" \
20 |   --lr_scheduler="constant" \
21 |   --lr_warmup_steps=0 \
22 |   --max_train_steps=500 \
23 |   --validation_prompt="A photo of sks dog in a pond, yarn art style" \
24 |   --validation_epochs=25 \
25 |   --seed="0" \
26 |   --push_to_hub
27 | 


--------------------------------------------------------------------------------
/train_scripts/train_scm_ladd.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | work_dir=output/debug_sCM_ladd
 5 | np=2
 6 | 
 7 | while [[ $# -gt 0 ]]; do
 8 |     case $1 in
 9 |         --np=*)
10 |             np="${1#*=}"
11 |             shift
12 |             ;;
13 |         *.yaml)
14 |             config=$1
15 |             shift
16 |             ;;
17 |         *)
18 |             other_args+=("$1")
19 |             shift
20 |             ;;
21 |     esac
22 | done
23 | 
24 | if [[ -z "$config" ]]; then
25 |     config="configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml"
26 |     echo "Only support .yaml files, but get $1. Set to --config_path=$config"
27 | fi
28 | 
29 | cmd="TRITON_PRINT_AUTOTUNING=1 \
30 |     torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000)) \
31 |         train_scripts/train_scm_ladd.py \
32 |         --config_path=$config \
33 |         --work_dir=$work_dir \
34 |         --name=tmp \
35 |         --resume_from=latest \
36 |         --report_to=tensorboard \
37 |         --debug=true \
38 |         ${other_args[@]}"
39 | 
40 | echo $cmd
41 | eval $cmd
42 | 


--------------------------------------------------------------------------------
/train_video_scripts/train_video_ivjoint.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | work_dir=output/debug_video
 5 | np=1
 6 | 
 7 | 
 8 | while [[ $# -gt 0 ]]; do
 9 |     case $1 in
10 |         --np=*)
11 |             np="${1#*=}"
12 |             shift
13 |             ;;
14 |         *.yaml)
15 |             config=$1
16 |             shift
17 |             ;;
18 |         *)
19 |             other_args+=("$1")
20 |             shift
21 |             ;;
22 |     esac
23 | done
24 | 
25 | if [[ -z "$config" ]]; then
26 |     config="configs/sana_video_config/Sana_2000M_480px_AdamW_fsdp.yaml"
27 |     echo "No yaml file specified. Set to --config_path=$config"
28 | fi
29 | 
30 | export DISABLE_XFORMERS=1
31 | export DEBUG_MODE=1
32 | 
33 | cmd="TRITON_PRINT_AUTOTUNING=1 \
34 |     torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000))  \
35 |         train_video_scripts/train_video_ivjoint.py \
36 |         --config_path=$config \
37 |         --work_dir=$work_dir \
38 |         --train.log_interval=1 \
39 |         --name=tmp \
40 |         --resume_from=latest \
41 |         --report_to=tensorboard \
42 |         --train.num_workers=0 \
43 |         --train.visualize=False \
44 |         --debug=true \
45 |         ${other_args[@]}"
46 | 
47 | echo $cmd
48 | eval $cmd
49 | 


--------------------------------------------------------------------------------
/train_video_scripts/train_video_ivjoint_chunk.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | 
 4 | work_dir=output/debug_video
 5 | np=1
 6 | 
 7 | 
 8 | if [[ $1 == *.yaml ]]; then
 9 |     config=$1
10 |     shift
11 | else
12 |     config="configs/sana_video_config/480ms/Sana_1600M_480px_adamW_fsdp_chunk.yaml"
13 |     echo "Only support .yaml files, but get $1. Set to --config_path=$config"
14 | fi
15 | 
16 | export DISABLE_XFORMERS=1
17 | export DEBUG_MODE=1
18 | 
19 | cmd="TRITON_PRINT_AUTOTUNING=1 \
20 |     torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000))  \
21 |         train_video_scripts/train_video_ivjoint.py \
22 |         --config_path=$config \
23 |         --work_dir=$work_dir \
24 |         --train.log_interval=1 \
25 |         --name=tmp \
26 |         --resume_from=latest \
27 |         --report_to=tensorboard \
28 |         --train.num_workers=0 \
29 |         --train.visualize=False \
30 |         --debug=true \
31 |         $@"
32 | 
33 | echo $cmd
34 | eval $cmd
35 | 


--------------------------------------------------------------------------------