├── .github
└── workflows
│ ├── bot-autolint.yaml
│ └── ci.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CIs
└── add_license_all.sh
├── Dockerfile
├── LICENSE
├── README.md
├── app
├── app_sana.py
├── app_sana_4bit.py
├── app_sana_4bit_compare_bf16.py
├── app_sana_controlnet_hed.py
├── app_sana_inpaint.py
├── app_sana_multithread.py
├── app_sana_sprint.py
├── safety_check.py
├── sana_controlnet_pipeline.py
├── sana_pipeline.py
├── sana_pipeline_inpaint.py
└── sana_sprint_pipeline.py
├── asset
├── Sana.jpg
├── app_styles
│ └── controlnet_app_style.css
├── apple.webp
├── controlnet
│ ├── ref_images
│ │ ├── A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg
│ │ ├── a house.png
│ │ ├── a living room.png
│ │ └── nvidia.png
│ └── samples_controlnet.json
├── cover.png
├── docs
│ ├── ComfyUI
│ │ ├── SANA-1.5_FlowEuler.json
│ │ ├── SANA-Sprint.json
│ │ ├── Sana_CogVideoX.json
│ │ ├── Sana_FlowEuler.json
│ │ ├── Sana_FlowEuler_2K.json
│ │ ├── Sana_FlowEuler_4K.json
│ │ └── comfyui.md
│ ├── inference_scaling
│ │ ├── inference_scaling.md
│ │ ├── results.jpg
│ │ └── scaling_curve.jpg
│ ├── metrics_toolkit.md
│ ├── model_zoo.md
│ ├── quantize
│ │ ├── 4bit_sana.md
│ │ └── 8bit_sana.md
│ ├── sana_controlnet.md
│ ├── sana_lora_dreambooth.md
│ ├── sana_sprint.md
│ └── sana_video.md
├── example_data
│ ├── 00000000.jpg
│ ├── 00000000.png
│ ├── 00000000.txt
│ ├── 00000000_InternVL2-26B.json
│ ├── 00000000_InternVL2-26B_clip_score.json
│ ├── 00000000_VILA1-5-13B.json
│ ├── 00000000_VILA1-5-13B_clip_score.json
│ ├── 00000000_prompt_clip_score.json
│ └── meta_data.json
├── examples.py
├── logo.png
├── mit-logo.jpg
├── model-incremental.jpg
├── model_paths.txt
├── paper2video.jpg
└── samples
│ ├── sample_i2v.txt
│ ├── samples.txt
│ ├── samples_mini.txt
│ └── video_prompts_samples.txt
├── configs
├── sana1-5_config
│ └── 1024ms
│ │ ├── Sana_1600M_1024px_AdamW_fsdp.yaml
│ │ ├── Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml
│ │ ├── Sana_3200M_1024px_came8bit_grow_constant_allqknorm_bf16_lr2e5.yaml
│ │ └── Sana_4800M_1024px_came8bit_grow_constant_allqknorm_bf16_lr2e5.yaml
├── sana_app_config
│ ├── Sana_1600M_app.yaml
│ └── Sana_600M_app.yaml
├── sana_base.yaml
├── sana_config
│ ├── 1024ms
│ │ ├── Sana_1600M_img1024.yaml
│ │ ├── Sana_1600M_img1024_AdamW.yaml
│ │ ├── Sana_1600M_img1024_CAME8bit.yaml
│ │ └── Sana_600M_img1024.yaml
│ ├── 2048ms
│ │ └── Sana_1600M_img2048_bf16.yaml
│ ├── 4096ms
│ │ └── Sana_1600M_img4096_bf16.yaml
│ └── 512ms
│ │ ├── Sana_1600M_img512.yaml
│ │ ├── Sana_600M_img512.yaml
│ │ ├── ci_Sana_600M_img512.yaml
│ │ └── sample_dataset.yaml
├── sana_controlnet_config
│ ├── Sana_1600M_1024px_controlnet_bf16.yaml
│ └── Sana_600M_img1024_controlnet.yaml
├── sana_sprint_config
│ └── 1024ms
│ │ ├── SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml
│ │ ├── SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd_dc_ae_lite.yaml
│ │ ├── SanaSprint_1600M_img1024_bf16_normT_allqknorm_teacher_ft.yaml
│ │ ├── SanaSprint_600M_1024px_allqknorm_bf16_scm_ladd.yaml
│ │ └── SanaSprint_600M_1024px_allqknorm_bf16_scm_ladd_dc_ae_lite.yaml
└── sana_video_config
│ ├── Sana_2000M_256px_AdamW_fsdp.yaml
│ └── Sana_2000M_480px_AdamW_fsdp.yaml
├── diffusion
├── __init__.py
├── data
│ ├── __init__.py
│ ├── builder.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── sana_data.py
│ │ ├── sana_data_multi_scale.py
│ │ ├── utils.py
│ │ └── video
│ │ │ └── sana_video_data.py
│ ├── transforms.py
│ └── wids
│ │ ├── __init__.py
│ │ ├── wids.py
│ │ ├── wids_dl.py
│ │ ├── wids_lru.py
│ │ ├── wids_mmtar.py
│ │ ├── wids_specs.py
│ │ └── wids_tar.py
├── guiders
│ ├── __init__.py
│ └── adaptive_projected_guidance.py
├── model
│ ├── __init__.py
│ ├── act.py
│ ├── builder.py
│ ├── dc_ae
│ │ └── efficientvit
│ │ │ ├── __init__.py
│ │ │ ├── ae_model_zoo.py
│ │ │ ├── apps
│ │ │ ├── __init__.py
│ │ │ ├── setup.py
│ │ │ ├── trainer
│ │ │ │ ├── __init__.py
│ │ │ │ └── run_config.py
│ │ │ └── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dist.py
│ │ │ │ ├── ema.py
│ │ │ │ ├── export.py
│ │ │ │ ├── image.py
│ │ │ │ ├── init.py
│ │ │ │ ├── lr.py
│ │ │ │ ├── metric.py
│ │ │ │ ├── misc.py
│ │ │ │ └── opt.py
│ │ │ └── models
│ │ │ ├── __init__.py
│ │ │ ├── efficientvit
│ │ │ ├── __init__.py
│ │ │ ├── dc_ae.py
│ │ │ └── dc_ae_with_temporal.py
│ │ │ ├── nn
│ │ │ ├── __init__.py
│ │ │ ├── act.py
│ │ │ ├── drop.py
│ │ │ ├── norm.py
│ │ │ ├── ops.py
│ │ │ ├── ops_3d.py
│ │ │ └── triton_rms_norm.py
│ │ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── list.py
│ │ │ ├── network.py
│ │ │ ├── random.py
│ │ │ └── video.py
│ ├── diffusion_utils.py
│ ├── dpm_solver.py
│ ├── edm_sample.py
│ ├── gaussian_diffusion.py
│ ├── model_growth_utils.py
│ ├── nets
│ │ ├── __init__.py
│ │ ├── basic_modules.py
│ │ ├── fastlinear
│ │ │ ├── develop_triton_ffn.py
│ │ │ ├── develop_triton_litemla.py
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── flash_attn.py
│ │ │ │ ├── lite_mla.py
│ │ │ │ ├── mb_conv_pre_glu.py
│ │ │ │ ├── nn
│ │ │ │ │ ├── act.py
│ │ │ │ │ ├── conv.py
│ │ │ │ │ └── norm.py
│ │ │ │ ├── triton_lite_mla.py
│ │ │ │ ├── triton_lite_mla_fwd.py
│ │ │ │ ├── triton_lite_mla_kernels
│ │ │ │ │ ├── custom_autotune.py
│ │ │ │ │ ├── linear_relu_fwd.py
│ │ │ │ │ ├── mm.py
│ │ │ │ │ ├── pad_vk_mm_fwd.py
│ │ │ │ │ ├── proj_divide_bwd.py
│ │ │ │ │ ├── vk_mm_relu_bwd.py
│ │ │ │ │ ├── vk_q_mm_divide_fwd.py
│ │ │ │ │ └── vk_q_mm_relu_bwd.py
│ │ │ │ ├── triton_mb_conv_pre_glu.py
│ │ │ │ ├── triton_mb_conv_pre_glu_kernels
│ │ │ │ │ ├── depthwise_conv_fwd.py
│ │ │ │ │ └── linear_glu_fwd.py
│ │ │ │ └── utils
│ │ │ │ │ ├── compare_results.py
│ │ │ │ │ ├── custom_autotune.py
│ │ │ │ │ ├── dtype.py
│ │ │ │ │ ├── export_onnx.py
│ │ │ │ │ └── model.py
│ │ │ └── readme.md
│ │ ├── ladd_blocks.py
│ │ ├── sana.py
│ │ ├── sana_U_shape.py
│ │ ├── sana_U_shape_multi_scale.py
│ │ ├── sana_blocks.py
│ │ ├── sana_ladd.py
│ │ ├── sana_multi_scale.py
│ │ ├── sana_multi_scale_adaln.py
│ │ ├── sana_multi_scale_controlnet.py
│ │ ├── sana_multi_scale_video.py
│ │ └── sana_others.py
│ ├── norms.py
│ ├── qwen
│ │ └── qwen_vl.py
│ ├── respace.py
│ ├── sa_solver.py
│ ├── timestep_sampler.py
│ ├── utils.py
│ ├── wan
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── clip.py
│ │ ├── fsdp_utils.py
│ │ ├── model.py
│ │ ├── model_wrapper.py
│ │ ├── t5.py
│ │ ├── tokenizers.py
│ │ ├── vae.py
│ │ └── xlm_roberta.py
│ └── wan2_2
│ │ └── vae.py
├── scheduler
│ ├── __init__.py
│ ├── dpm_solver.py
│ ├── flow_euler_sampler.py
│ ├── iddpm.py
│ ├── lcm_scheduler.py
│ ├── sa_sampler.py
│ ├── sa_solver_diffusers.py
│ ├── scm_scheduler.py
│ └── trigflow_scheduler.py
└── utils
│ ├── __init__.py
│ ├── checkpoint.py
│ ├── config.py
│ ├── config_wan.py
│ ├── data_sampler.py
│ ├── dist_utils.py
│ ├── git.py
│ ├── import_utils.py
│ ├── logger.py
│ ├── lr_scheduler.py
│ ├── misc.py
│ └── optimizer.py
├── environment_setup.sh
├── inference_video_scripts
├── inference_sana_video.py
└── inference_sana_video.sh
├── pyproject.toml
├── sana
├── cli
│ ├── run.py
│ └── upload2hf.py
└── tools
│ ├── __init__.py
│ ├── download.py
│ └── hf_utils.py
├── scripts
├── bash_run_inference_metric.sh
├── bash_run_inference_metric_dpg.sh
├── bash_run_inference_metric_geneval.sh
├── bash_run_inference_metric_imagereward.sh
├── infer_metric_run_inference_metric.sh
├── infer_metric_run_inference_metric_geneval.sh
├── infer_run_inference.sh
├── infer_run_inference_geneval.sh
├── infer_run_inference_geneval_diffusers.sh
├── inference.py
├── inference_dpg.py
├── inference_geneval.py
├── inference_geneval_diffusers.py
├── inference_image_reward.py
├── inference_sana_sprint.py
├── inference_sana_sprint_geneval.py
├── interface.py
└── style.css
├── tests
└── bash
│ ├── entry.sh
│ ├── inference
│ └── test_inference.sh
│ ├── setup_test_data.sh
│ └── training
│ ├── test_training_all.sh
│ ├── test_training_fsdp.sh
│ ├── test_training_vae.sh
│ └── test_training_video.sh
├── tools
├── __init__.py
├── controlnet
│ ├── annotator
│ │ ├── ckpts
│ │ │ └── ckpts.txt
│ │ ├── hed
│ │ │ └── __init__.py
│ │ └── util.py
│ ├── inference_controlnet.py
│ └── utils.py
├── convert_ImgDataset_to_WebDatasetMS_format.py
├── convert_py_to_yaml.py
├── convert_sana_to_diffusers.py
├── convert_sana_to_svdquant.py
├── convert_sana_video_to_diffusers.py
├── create_wids_metadata.py
├── download.py
├── inference_scaling
│ ├── nvila_sana_pick.py
│ └── nvila_sana_pick.sh
└── metrics
│ ├── clip-score
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── clip_score.py
│ ├── setup.py
│ └── src
│ │ └── clip_score
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ └── clip_score.py
│ ├── compute_clipscore.sh
│ ├── compute_dpg.sh
│ ├── compute_fid_embedding.sh
│ ├── compute_geneval.sh
│ ├── compute_imagereward.sh
│ ├── dpg_bench
│ ├── compute_dpg_bench.py
│ ├── dpg_bench.csv
│ ├── metadata.json
│ └── requirements.txt
│ ├── geneval
│ ├── LICENSE
│ ├── README.md
│ ├── annotations
│ │ ├── annotations_clip.csv
│ │ ├── annotations_if-xl.csv
│ │ ├── annotations_sdv2.csv
│ │ └── mturk_hit_template.html
│ ├── environment.yml
│ ├── evaluation
│ │ ├── download_models.sh
│ │ ├── evaluate_images.py
│ │ ├── object_names.txt
│ │ └── summary_scores.py
│ ├── generation
│ │ └── diffusers_generate.py
│ ├── geneval_env.md
│ ├── images
│ │ └── geneval_figure_1.png
│ └── prompts
│ │ ├── create_prompts.py
│ │ ├── evaluation_metadata.jsonl
│ │ ├── generation_prompts.txt
│ │ └── object_names.txt
│ ├── image_reward
│ ├── benchmark-prompts-dict.json
│ └── compute_image_reward.py
│ ├── pytorch-fid
│ ├── .gitignore
│ ├── CHANGELOG.md
│ ├── LICENSE
│ ├── README.md
│ ├── compute_fid.py
│ ├── noxfile.py
│ ├── setup.cfg
│ ├── setup.py
│ ├── src
│ │ └── pytorch_fid
│ │ │ ├── __init__.py
│ │ │ ├── __main__.py
│ │ │ ├── fid_score.py
│ │ │ └── inception.py
│ └── tests
│ │ └── test_fid_score.py
│ └── utils.py
├── train_scripts
├── train.py
├── train.sh
├── train_dreambooth_lora_sana.py
├── train_lora.sh
├── train_scm_ladd.py
└── train_scm_ladd.sh
└── train_video_scripts
├── train_video_ivjoint.py
├── train_video_ivjoint.sh
├── train_video_ivjoint_chunk.py
└── train_video_ivjoint_chunk.sh
/.github/workflows/bot-autolint.yaml:
--------------------------------------------------------------------------------
1 | name: Auto Lint (triggered by "lint wanted" label)
2 | on:
3 | pull_request:
4 | types:
5 | - opened
6 | - edited
7 | - closed
8 | - reopened
9 | - synchronize
10 | - labeled
11 | - unlabeled
12 | # run only one lint workflow at a time per branch / tag.
13 | concurrency:
14 | group: ci-lint-${{ github.head_ref || github.ref }}
15 | cancel-in-progress: true
16 | jobs:
17 | lint-by-label:
18 | if: contains(github.event.pull_request.labels.*.name, 'lint wanted')
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: Check out Git repository
22 | uses: actions/checkout@v4
23 | with:
24 | token: ${{ secrets.GITHUB_TOKEN }}
25 | ref: ${{ github.event.pull_request.head.ref }}
26 | - name: Set up Python
27 | uses: actions/setup-python@v5
28 | with:
29 | python-version: '3.10'
30 | - name: Test pre-commit hooks
31 | continue-on-error: true
32 | uses: pre-commit/action@v3.0.0 # sync with https://github.com/Efficient-Large-Model/VILA-Internal/blob/main/.github/workflows/pre-commit.yaml
33 | with:
34 | extra_args: --all-files
35 | - name: Check if there are any changes
36 | id: verify_diff
37 | run: |
38 | git diff --quiet . || echo "changed=true" >> $GITHUB_OUTPUT
39 | - name: Commit files
40 | if: steps.verify_diff.outputs.changed == 'true'
41 | run: |
42 | git config --local user.email "action@github.com"
43 | git config --local user.name "GitHub Action"
44 | git add .
45 | git commit -m "[CI-Lint] Fix code style issues with pre-commit ${{ github.sha }}" -a
46 | git push
47 | - name: Remove label(s) after lint
48 | uses: actions-ecosystem/action-remove-labels@v1
49 | with:
50 | labels: lint wanted
51 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: trailing-whitespace
6 | name: (Common) Remove trailing whitespaces
7 | - id: mixed-line-ending
8 | name: (Common) Fix mixed line ending
9 | args: [--fix=lf]
10 | - id: end-of-file-fixer
11 | name: (Common) Remove extra EOF newlines
12 | - id: check-merge-conflict
13 | name: (Common) Check for merge conflicts
14 | - id: requirements-txt-fixer
15 | name: (Common) Sort "requirements.txt"
16 | - id: fix-encoding-pragma
17 | name: (Python) Remove encoding pragmas
18 | args: [--remove]
19 | # - id: debug-statements
20 | # name: (Python) Check for debugger imports
21 | - id: check-json
22 | name: (JSON) Check syntax
23 | - id: check-yaml
24 | name: (YAML) Check syntax
25 | - id: check-toml
26 | name: (TOML) Check syntax
27 | # - repo: https://github.com/shellcheck-py/shellcheck-py
28 | # rev: v0.10.0.1
29 | # hooks:
30 | # - id: shellcheck
31 | - repo: https://github.com/google/yamlfmt
32 | rev: v0.13.0
33 | hooks:
34 | - id: yamlfmt
35 | - repo: https://github.com/executablebooks/mdformat
36 | rev: 0.7.16
37 | hooks:
38 | - id: mdformat
39 | name: (Markdown) Format docs with mdformat
40 | - repo: https://github.com/asottile/pyupgrade
41 | rev: v3.2.2
42 | hooks:
43 | - id: pyupgrade
44 | name: (Python) Update syntax for newer versions
45 | args: [--py37-plus]
46 | - repo: https://github.com/psf/black
47 | rev: 22.10.0
48 | hooks:
49 | - id: black
50 | name: (Python) Format code with black
51 | - repo: https://github.com/pycqa/isort
52 | rev: 5.12.0
53 | hooks:
54 | - id: isort
55 | name: (Python) Sort imports with isort
56 | - repo: https://github.com/pre-commit/mirrors-clang-format
57 | rev: v15.0.4
58 | hooks:
59 | - id: clang-format
60 | name: (C/C++/CUDA) Format code with clang-format
61 | args: [-style=google, -i]
62 | types_or: [c, c++, cuda]
63 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: 'SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer'
3 | message: >-
4 | If you use this software or research, please cite it using the
5 | metadata from this file.
6 | type: misc
7 | authors:
8 | - given-names: Enze
9 | family-names: Xie
10 | - given-names: Junsong
11 | family-names: Chen
12 | - given-names: Junyu
13 | family-names: Chen
14 | - given-names: Han
15 | family-names: Cai
16 | - given-names: Haotian
17 | family-names: Tang
18 | - given-names: Yujun
19 | family-names: Lin
20 | - given-names: Zhekai
21 | family-names: Zhang
22 | - given-names: Muyang
23 | family-names: Li
24 | - given-names: Ligeng
25 | family-names: Zhu
26 | - given-names: Yao
27 | family-names: Lu
28 | - given-names: Song
29 | family-names: Han
30 | repository-code: 'https://github.com/NVlabs/Sana'
31 | abstract: >-
32 | SANA proposes an efficient linear Diffusion Transformer (DiT) for high-resolution
33 | image synthesis, featuring a depth-growth paradigm, model pruning techniques,
34 | and inference-time scaling strategies to reduce training costs while maintaining
35 | generation quality. SANA-Sprint also achieves one-step generation of high-resolution images
36 | keywords:
37 | - deep-learning
38 | - diffusion-models
39 | - transformer
40 | - image-generation
41 | - text-to-image
42 | - efficient-training
43 | - distillation
44 | license: Apache-2.0
45 | version: 2.0.0
46 | doi: 10.48550/arXiv.2410.10629
47 | date-released: 2024-10-16
48 |
--------------------------------------------------------------------------------
/CIs/add_license_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "**/*__init__.py" **/*.py
3 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:24.06-py3
2 |
3 | ENV PATH=/opt/conda/bin:$PATH
4 |
5 | RUN apt-get update && apt-get install -y \
6 | libgl1-mesa-glx \
7 | libglib2.0-0 \
8 | && rm -rf /var/lib/apt/lists/*
9 |
10 | WORKDIR /app
11 |
12 | RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
13 | && sh ~/miniconda.sh -b -p /opt/conda \
14 | && rm ~/miniconda.sh
15 |
16 | COPY pyproject.toml pyproject.toml
17 | COPY diffusion diffusion
18 | COPY configs configs
19 | COPY sana sana
20 | COPY app app
21 | COPY tools tools
22 |
23 | COPY environment_setup.sh environment_setup.sh
24 | RUN ./environment_setup.sh
25 |
26 | CMD ["python", "-u", "-W", "ignore", "app/app_sana.py", "--share", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"]
27 |
--------------------------------------------------------------------------------
/app/safety_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | # This file is copied from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
17 | import json
18 | from functools import lru_cache
19 | from typing import Dict, Sequence
20 |
21 | import torch
22 | import transformers
23 | from torch.nn.functional import softmax
24 | from transformers import AutoModelForCausalLM, AutoTokenizer
25 |
26 | __all__ = ["is_dangerous"]
27 |
28 | __rank, __local_rank, __world_size, __device = (
29 | 0,
30 | 0,
31 | 1,
32 | "cuda" if torch.cuda.is_available() else "cpu",
33 | )
34 |
35 |
36 | safety_policy = """
37 | The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
38 | """
39 |
40 |
41 | def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
42 | prompt = f"""You are a policy expert trying to help determine whether a user
43 | prompt is in violation of the defined safety policies.
44 |
45 | <start_of_turn>
46 | Human Question: {user_prompt.strip()}
47 | <end_of_turn>
48 |
49 | Our safety principle is defined in the below:
50 |
51 | {safety_policy.strip()}
52 |
53 | Does the human question violate the above principle? Your answer must start
54 | with 'Yes' or 'No'. And then walk through step by step to be sure we answer
55 | correctly.
56 | """
57 |
58 | inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
59 | with torch.no_grad():
60 | logits = model(**inputs).logits
61 |
62 | # Extract the logits for the Yes and No tokens
63 | vocab = tokenizer.get_vocab()
64 | selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
65 |
66 | # Convert these logits to a probability with softmax
67 | probabilities = softmax(selected_logits, dim=0)
68 |
69 | # Return probability of 'Yes'
70 | score = probabilities[0].item()
71 |
72 | return score > threshold
73 |
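A minimal usage sketch for `is_dangerous`, assuming a ShieldGemma-style guard model; the checkpoint name below is an assumption for illustration, not necessarily what the Sana apps load. Note that `is_dangerous` moves its inputs to CUDA, so a GPU is required as the file is written.

```python
# Hedged usage sketch (run from the repository root). The guard model name is an
# assumption, not necessarily the checkpoint used by app/app_sana.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app.safety_check import is_dangerous

guard_id = "google/shieldgemma-2b"  # assumed ShieldGemma-style safety classifier
tokenizer = AutoTokenizer.from_pretrained(guard_id)
model = AutoModelForCausalLM.from_pretrained(guard_id, torch_dtype=torch.bfloat16).to("cuda")

user_prompt = "A cute panda eating bamboo, ink drawing style"
if is_dangerous(tokenizer, model, user_prompt):
    print("Prompt rejected by the safety checker.")
else:
    print("Prompt accepted.")
```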
--------------------------------------------------------------------------------
/asset/Sana.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/Sana.jpg
--------------------------------------------------------------------------------
/asset/app_styles/controlnet_app_style.css:
--------------------------------------------------------------------------------
1 | @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
2 |
3 | body{align-items: center;}
4 | .gradio-container{max-width: 1200px !important}
5 | h1{text-align:center}
6 |
7 | .wrap.svelte-p4aq0j.svelte-p4aq0j {
8 | display: none;
9 | }
10 |
11 | #column_input, #column_output {
12 | width: 500px;
13 | display: flex;
14 | align-items: center;
15 | }
16 |
17 | #input_header, #output_header {
18 | display: flex;
19 | justify-content: center;
20 | align-items: center;
21 | width: 400px;
22 | }
23 |
24 | #accessibility {
25 | text-align: center; /* Center-aligns the text */
26 | margin: auto; /* Centers the element horizontally */
27 | }
28 |
29 | #random_seed {height: 71px;}
30 | #run_button {height: 87px;}
31 |
--------------------------------------------------------------------------------
/asset/apple.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/apple.webp
--------------------------------------------------------------------------------
/asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg
--------------------------------------------------------------------------------
/asset/controlnet/ref_images/a house.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/a house.png
--------------------------------------------------------------------------------
/asset/controlnet/ref_images/a living room.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/a living room.png
--------------------------------------------------------------------------------
/asset/controlnet/ref_images/nvidia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/controlnet/ref_images/nvidia.png
--------------------------------------------------------------------------------
/asset/controlnet/samples_controlnet.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "prompt": "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
4 | "ref_image_path": "asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg"
5 | },
6 | {
7 | "prompt": "an architecture in INDIA,15th-18th style, with a lot of details",
8 | "ref_image_path": "asset/controlnet/ref_images/a house.png"
9 | },
10 | {
11 | "prompt": "An IKEA modern style living room with sofa, coffee table, stairs, etc., a brand new theme.",
12 | "ref_image_path": "asset/controlnet/ref_images/a living room.png"
13 | },
14 | {
15 | "prompt": "A modern new living room with sofa, coffee table, carpet, stairs, etc., high quality high detail, high resolution.",
16 | "ref_image_path": "asset/controlnet/ref_images/a living room.png"
17 | },
18 | {
19 | "prompt": "big eye, vibrant colors, intricate details, captivating gaze, surreal, dreamlike, fantasy, enchanting, mysterious, magical, moonlit, mystical, ethereal, enchanting {macro lens, high aperture, low ISO}",
20 | "ref_image_path": "asset/controlnet/ref_images/nvidia.png"
21 | },
22 | {
23 | "prompt": "shining eye, bright and vivid colors, radiant glow, sparkling reflections, joyful, uplifting, optimistic, hopeful, magical, luminous, celestial, dreamy {zoom lens, high aperture, natural light, vibrant color film}",
24 | "ref_image_path": "asset/controlnet/ref_images/nvidia.png"
25 | }
26 | ]
27 |
--------------------------------------------------------------------------------
/asset/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/cover.png
--------------------------------------------------------------------------------
/asset/docs/ComfyUI/comfyui.md:
--------------------------------------------------------------------------------
1 | ## 🖌️ Sana-ComfyUI
2 |
3 | [Original Repo](https://github.com/city96/ComfyUI_ExtraModels)
4 |
5 | ### Model info / implementation
6 |
7 | - Uses Gemma2 2B as the text encoder
8 | - Multiple resolutions and models available
9 | - Compressed latent space (32 channels, /32 compression) - needs custom VAE
10 |
11 | ### Usage
12 |
13 | 1. All the checkpoints will be downloaded automatically.
14 | 1. KSampler (Flow Euler) is available for now; Flow DPM-Solver will be available soon.
15 |
16 | ```bash
17 | git clone https://github.com/comfyanonymous/ComfyUI.git
18 | cd ComfyUI
19 | git clone https://github.com/lawrence-cj/ComfyUI_ExtraModels.git custom_nodes/ComfyUI_ExtraModels
20 |
21 | python main.py
22 | ```
23 |
24 | ### A sample workflow for Sana
25 |
26 | [Sana workflow](Sana_FlowEuler.json)
27 |
28 | 
29 |
30 | ### A sample for T2I (Sana) + I2V (CogVideoX)
31 |
32 | [Sana + CogVideoX workflow](Sana_CogVideoX.json)
33 |
34 | [](https://nvlabs.github.io/Sana/asset/content/comfyui/Sana_CogVideoX_Fun.mp4)
35 |
36 | ### A sample workflow for a Sana 4096x4096 image (an 18GB GPU is needed)
37 |
38 | [Sana workflow](Sana_FlowEuler_4K.json)
39 |
40 | 
41 |
--------------------------------------------------------------------------------
/asset/docs/inference_scaling/inference_scaling.md:
--------------------------------------------------------------------------------
1 | ## Inference Time Scaling for SANA-1.5
2 |
3 | 
4 |
5 | We trained a specialized [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) model to score images, which we named VISA (VIla as SAna verifier). By selecting the top 4 images from 2,048 candidates, we enhanced the GenEval performance of SD1.5 and SANA-1.5-4.8B v2, increasing their scores from 42 to 87 and 81 to 96, respectively.
6 |
7 | 
8 |
9 | Even with a smaller number of candidates, such as 32, we can still push the GenEval performance of SANA-1.5-4.8B v2 above 90%.
10 |
11 | ### Environment Requirement
12 |
13 | Dependency setups:
14 |
15 | ```bash
16 | # other transformers versions may also work, but we have not tested them
17 | pip install transformers==4.46
18 | pip install git+https://github.com/bfshi/scaling_on_scales.git
19 | ```
20 |
21 | ### 1. Generate N candidate images from a .pth checkpoint for the selection step
22 |
23 | ```bash
24 | # download the checkpoint for the following generation
25 | huggingface-cli download Efficient-Large-Model/Sana_600M_512px --repo-type model --local-dir output/Sana_600M_512px --local-dir-use-symlinks False
26 | # 32 is a relatively small number for testing, but it already pushes GenEval above 90% when verifying the SANA-1.5-4.8B v2 model. Set it to a larger number like 2048 to approach the upper limit.
27 | n_samples=32
28 | pick_number=4
29 |
30 | output_dir=output/geneval_generated_path
31 | # example
32 | bash scripts/infer_run_inference_geneval.sh \
33 | configs/sana_config/512ms/Sana_600M_img512.yaml \
34 | output/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth \
35 | --img_nums_per_sample=$n_samples \
36 | --output_dir=$output_dir
37 | ```
38 |
39 | ### 2. Use NVILA-Verifier to select from the generated images
40 |
41 | ```bash
42 | bash tools/inference_scaling/nvila_sana_pick.sh \
43 | $output_dir \
44 | $n_samples \
45 | $pick_number
46 | ```
47 |
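Conceptually, this picking step is a best-of-N selection: every candidate image generated for a prompt is scored by the verifier and only the top `pick_number` are kept. Below is a minimal sketch of that selection logic, assuming a generic `score_image(path, prompt)` verifier interface; the actual implementation lives in `tools/inference_scaling/nvila_sana_pick.py`.

```python
# Best-of-N selection sketch. `score_image` is a stand-in for the NVILA-Verifier;
# the real scoring and selection live in tools/inference_scaling/nvila_sana_pick.py.
from pathlib import Path
from typing import Callable, List


def pick_best_images(
    candidates: List[Path],
    prompt: str,
    score_image: Callable[[Path, str], float],  # hypothetical verifier interface
    pick_number: int = 4,
) -> List[Path]:
    """Return the pick_number highest-scoring candidate images for this prompt."""
    ranked = sorted(candidates, key=lambda p: score_image(p, prompt), reverse=True)
    return ranked[:pick_number]
```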
48 | ### 3. Calculate the GenEval metric
49 |
50 | You need to use the GenEval environment for the final evaluation. Installation instructions can be found [here](../../../tools/metrics/geneval/geneval_env.md).
51 |
52 | ```bash
53 | # activate geneval env
54 | conda activate geneval
55 |
56 | DIR_AFTER_PICK="output/nvila_pick/best_${pick_number}_of_${n_samples}/${output_dir}"
57 |
58 | bash tools/metrics/compute_geneval.sh $(dirname "$DIR_AFTER_PICK") $(basename "$DIR_AFTER_PICK")
59 | ```
60 |
--------------------------------------------------------------------------------
/asset/docs/inference_scaling/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/docs/inference_scaling/results.jpg
--------------------------------------------------------------------------------
/asset/docs/inference_scaling/scaling_curve.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/docs/inference_scaling/scaling_curve.jpg
--------------------------------------------------------------------------------
/asset/docs/quantize/4bit_sana.md:
--------------------------------------------------------------------------------
1 | <!--Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License. -->
15 |
16 | # 4bit SanaPipeline
17 |
18 | ### 1. Environment setup
19 |
20 | Follow the official [SVDQuant-Nunchaku](https://github.com/mit-han-lab/nunchaku) repository to set up the environment. The guidance can be found [here](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation).
21 |
22 | ### 1-1. Quantize Sana with SVDQuant-4bit (Optional)
23 |
24 | 1. Convert the .pth checkpoint to the safetensors format required by SVDQuant
25 |
26 | ```bash
27 | python tools/convert_sana_to_svdquant.py \
28 | --orig_ckpt_path Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth \
29 | --model_type SanaMS1.5_1600M_P1_D20 \
30 | --dtype bf16 \
31 | --dump_path output/SANA1.5_1.6B_1024px_svdquant_diffusers \
32 | --save_full_pipeline
33 | ```
34 |
35 | 2. Follow the guidance below to compress the model:
36 |    [Quantization guidance](https://github.com/mit-han-lab/deepcompressor/tree/main/examples/diffusion)
37 |
38 | ### 2. Code snippet for inference
39 |
40 | Here we show the code snippet for SanaPipeline. For SanaPAGPipeline, please refer to the [SanaPAGPipeline](https://github.com/mit-han-lab/nunchaku/blob/main/examples/sana_1600m_pag.py) section.
41 |
42 | ```python
43 | import torch
44 | from diffusers import SanaPipeline
45 |
46 | from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
47 |
48 | transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
49 | pipe = SanaPipeline.from_pretrained(
50 | "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
51 | transformer=transformer,
52 | variant="bf16",
53 | torch_dtype=torch.bfloat16,
54 | ).to("cuda")
55 |
56 | pipe.text_encoder.to(torch.bfloat16)
57 | pipe.vae.to(torch.bfloat16)
58 |
59 | image = pipe(
60 | prompt="A cute 🐼 eating 🎋, ink drawing style",
61 | height=1024,
62 | width=1024,
63 | guidance_scale=4.5,
64 | num_inference_steps=20,
65 | generator=torch.Generator().manual_seed(42),
66 | ).images[0]
67 | image.save("sana_1600m.png")
68 | ```
69 |
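As a quick, optional sanity check of the 4-bit savings, you can read back the peak GPU memory after running the snippet above (a generic PyTorch check, not part of the official example):

```python
# Optional: report peak GPU memory after generation. With the 4-bit transformer this
# should come in noticeably lower than the full BF16 pipeline on the same prompt.
import torch

if torch.cuda.is_available():
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_gib:.2f} GiB")
```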
70 | ### 3. Online demo
71 |
72 | 1). Launch the 4-bit Sana demo.
73 |
74 | ```bash
75 | python app/app_sana_4bit.py
76 | ```
77 |
78 | 2). Compare with the BF16 version
79 |
80 | Refer to the original [Nunchaku-Sana](https://github.com/mit-han-lab/nunchaku/tree/main/app/sana/t2i) guidance for SanaPAGPipeline.
81 |
82 | ```bash
83 | python app/app_sana_4bit_compare_bf16.py
84 | ```
85 |
--------------------------------------------------------------------------------
/asset/docs/sana_controlnet.md:
--------------------------------------------------------------------------------
1 | <!-- Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
15 | SPDX-License-Identifier: Apache-2.0 -->
16 |
17 | ## 🔥 ControlNet
18 |
19 | We incorporate a [ControlNet](https://github.com/lllyasviel/ControlNet)-like module that enables fine-grained control over text-to-image diffusion models. We implement a ControlNet-Transformer architecture specifically tailored for Transformers, achieving explicit controllability alongside high-quality image generation.
20 |
21 | <p align="center">
22 | <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/controlnet/sana_controlnet.jpg" height=480>
23 | </p>
24 |
25 | ## Inference of `Sana + ControlNet`
26 |
27 | ### 1). Gradio Interface
28 |
29 | ```bash
30 | python app/app_sana_controlnet_hed.py \
31 | --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml \
32 | --model_path hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth
33 | ```
34 |
35 | <p align="center" border-raduis="10px">
36 | <img src="https://nvlabs.github.io/Sana/asset/content/controlnet/controlnet_app.jpg" width="90%" alt="teaser_page2"/>
37 | </p>
38 |
39 | ### 2). Inference with JSON file
40 |
41 | ```bash
42 | python tools/controlnet/inference_controlnet.py \
43 | --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml \
44 | --model_path hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth \
45 | --json_file asset/controlnet/samples_controlnet.json
46 | ```
47 |
48 | ### 3). Inference code snippet
49 |
50 | ```python
51 | import torch
52 | from PIL import Image
53 | from app.sana_controlnet_pipeline import SanaControlNetPipeline
54 |
55 | device = "cuda" if torch.cuda.is_available() else "cpu"
56 |
57 | pipe = SanaControlNetPipeline("configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml")
58 | pipe.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth")
59 |
60 | ref_image = Image.open("asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg")
61 | prompt = "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape."
62 |
63 | images = pipe(
64 | prompt=prompt,
65 | ref_image=ref_image,
66 | guidance_scale=4.5,
67 | num_inference_steps=10,
68 | sketch_thickness=2,
69 | generator=torch.Generator(device=device).manual_seed(0),
70 | )
71 | ```
72 |
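The snippet above leaves the generated result in `images` without saving it. Below is a hedged way to persist the first output, covering both return types the Sana pipelines commonly use (PIL images, or tensors in [-1, 1]); check the actual return type of `SanaControlNetPipeline` and adapt.

```python
# Persist the first output; whether the pipeline returns PIL images or image tensors
# is an assumption here -- check SanaControlNetPipeline's return type and adapt.
from PIL import Image
from torchvision.utils import save_image

output = images[0] if isinstance(images, (list, tuple)) else images
if isinstance(output, Image.Image):
    output.save("controlnet_sample.png")
else:  # assume a CHW tensor in [-1, 1], as in the other Sana pipeline examples
    save_image(output, "controlnet_sample.png", normalize=True, value_range=(-1, 1))
```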
73 | ## Training of `Sana + ControlNet`
74 |
75 | ### Coming soon
76 |
--------------------------------------------------------------------------------
/asset/example_data/00000000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/example_data/00000000.jpg
--------------------------------------------------------------------------------
/asset/example_data/00000000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/example_data/00000000.png
--------------------------------------------------------------------------------
/asset/example_data/00000000.txt:
--------------------------------------------------------------------------------
1 | a cyberpunk cat with a neon sign that says "Sana".
2 |
--------------------------------------------------------------------------------
/asset/example_data/00000000_InternVL2-26B.json:
--------------------------------------------------------------------------------
1 | {
2 | "00000000": {
3 | "InternVL2-26B": "a cyberpunk cat with a neon sign that says 'Sana'"
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/asset/example_data/00000000_InternVL2-26B_clip_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "00000000": {
3 | "InternVL2-26B": "27.1037"
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/asset/example_data/00000000_VILA1-5-13B.json:
--------------------------------------------------------------------------------
1 | {
2 | "00000000": {
3 | "VILA1-5-13B": "a cyberpunk cat with a neon sign that says 'Sana'"
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/asset/example_data/00000000_VILA1-5-13B_clip_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "00000000": {
3 | "VILA1-5-13B": "27.2321"
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/asset/example_data/00000000_prompt_clip_score.json:
--------------------------------------------------------------------------------
1 | {
2 | "00000000": {
3 | "prompt": "26.7331"
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/asset/example_data/meta_data.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sana-dev",
3 | "__kind__": "Sana-ImgDataset",
4 | "img_names": [
5 | "00000000", "00000000", "00000000.png", "00000000.jpg"
6 | ]
7 | }
8 |
--------------------------------------------------------------------------------
/asset/examples.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | examples = [
18 | [
19 | "A small cactus with a happy face in the Sahara desert.",
20 | "flow_dpm-solver",
21 | 20,
22 | 5.0,
23 | 2.5,
24 | ],
25 | [
26 | "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history"
27 | "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
28 | "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
29 | "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
30 | "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
31 | "the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
32 | "flow_dpm-solver",
33 | 20,
34 | 5.0,
35 | 2.5,
36 | ],
37 | [
38 | "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
39 | "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
40 | "The quote 'Find the universe within you' is etched in bold letters across the horizon."
41 | "blue and pink, brilliantly illuminated in the background.",
42 | "flow_dpm-solver",
43 | 20,
44 | 5.0,
45 | 2.5,
46 | ],
47 | [
48 | "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
49 | "flow_dpm-solver",
50 | 20,
51 | 5.0,
52 | 2.5,
53 | ],
54 | [
55 | "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
56 | "flow_dpm-solver",
57 | 20,
58 | 5.0,
59 | 2.5,
60 | ],
61 | [
62 | "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
63 | "national geographic photo, 8k resolution, crayon art, interactive artwork",
64 | "flow_dpm-solver",
65 | 20,
66 | 5.0,
67 | 2.5,
68 | ],
69 | ]
70 |
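For reference, each example row above appears to map onto the Gradio app's input widgets; the field interpretation below is inferred from the app inputs, not documented in this file.

```python
# Inferred field layout of each example row (an assumption based on the Gradio app
# inputs): [prompt, sampler name, sampling steps, guidance scale, PAG guidance scale].
# Run from the repository root.
from asset.examples import examples

for prompt, sampler, steps, cfg_scale, pag_scale in examples:
    print(f"{sampler:>16} | steps={steps} cfg={cfg_scale} pag={pag_scale} | {prompt[:60]}...")
```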
--------------------------------------------------------------------------------
/asset/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/logo.png
--------------------------------------------------------------------------------
/asset/mit-logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/mit-logo.jpg
--------------------------------------------------------------------------------
/asset/model-incremental.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/model-incremental.jpg
--------------------------------------------------------------------------------
/asset/model_paths.txt:
--------------------------------------------------------------------------------
1 | output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
2 | output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
3 |
--------------------------------------------------------------------------------
/asset/paper2video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/asset/paper2video.jpg
--------------------------------------------------------------------------------
/asset/samples/sample_i2v.txt:
--------------------------------------------------------------------------------
1 | A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle.<image>asset/samples/i2v-1.png
2 | A majestic brown cow with large curved horns gallops across a dusty field under a clear blue sky. The cow's powerful strides kick up clouds of dust, creating a dynamic sense of motion. The camera captures the animal from a low angle, emphasizing its imposing presence against the vast, open landscape. The sunlight highlights the cow's glossy coat, adding depth and vibrancy to the scene. A slow-motion effect enhances the fluidity of the cow's movement, making it appear almost airborne.<image>asset/samples/i2v-2.png
3 |
--------------------------------------------------------------------------------
/asset/samples/samples_mini.txt:
--------------------------------------------------------------------------------
1 | A cyberpunk cat with a neon sign that says 'Sana'.
2 | A small cactus with a happy face in the Sahara desert.
3 | The towel was on top of the hard counter.
4 | A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
5 | I want to supplement vitamin c, please help me paint related food.
6 | A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
7 | an old rusted robot wearing pants and a jacket riding skis in a supermarket.
8 | professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
9 | Astronaut in a jungle, cold color palette, muted colors, detailed
10 | a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests.
11 |
--------------------------------------------------------------------------------
/asset/samples/video_prompts_samples.txt:
--------------------------------------------------------------------------------
1 | Extreme close-up of a thoughtful, gray-haired professor in his 60s, sitting motionless in a Paris café, dressed in a wool coat and beret, pondering the universe. His subtle closed-mouth smile reveals an answer. Golden light, cinematic depth of field, Paris streets blurred in the background. Cinematic 35mm film.
2 | A woman surrounded by swirling smoke of vibrant colors, warm light bathing her figure. Medium shot, soft focus.
3 | "Minecraft with the most gorgeous high-res 8K texture pack ever, showcasing detailed landscapes and characters. Smooth camera pans over lush forests, towering mountains, and serene villages. Epic vistas and intricate textures. Wide and close-up shots."
4 | Japanese animated film style, a young woman standing on a ship's deck, looking back at the camera with a serene expression. The background shows the ocean and sky stretching out behind her. Medium shot focusing on her profile.
5 | A hyper-speed train's internal window view passing through an old European city, showing fast-moving buildings and landscapes. Medium shot from inside the train.
6 | A large orange octopus rests on the ocean floor, blending with sand and rocks, tentacles spread, eyes closed. A brown, spiky king crab creeps closer, claws raised. Wide angle captures the vast, clear, sunlit blue sea, focusing on the octopus and crab with a depth of field blur.
7 | Wildlife along the Kinabatangan River in Borneo, focusing on diverse animals like orangutans, proboscis monkeys, and crocodiles in their natural habitat. Aerial and ground shots showcasing lush rainforest surroundings.
8 | A lively pink pig running swiftly towards the camera in a bustling Tokyo alleyway, surrounded by neon lights and signs. Close-up, dynamic shot.
9 | In the daytime, an anime-style white car drives towards the camera, splashing water from a pond as it passes by, medium shot.
10 | A toy robot in blue jeans and a white t-shirt leisurely walking in Antarctica as the sun sets beautifully. The robot has a friendly expression, moving with smooth steps. Wide shot capturing the vast icy landscape and colorful sky.
11 |
--------------------------------------------------------------------------------
/configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/toy_data]
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: []
7 | external_clipscore_suffixes: []
8 | clip_thr_temperature: 0.1
9 | clip_thr: 25.0
10 | del_img_clip_thr: 22.0
11 | load_text_feat: false
12 | load_vae_feat: true
13 | transform: default_train
14 | type: SanaWebDatasetMS
15 | sort_dataset: false
16 | # model config
17 | model:
18 | model: SanaMS_1600M_P1_D20
19 | image_size: 1024
20 | mixed_precision: bf16
21 | fp32_attention: true
22 | load_from: hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth
23 | aspect_ratio_type: ASPECT_RATIO_1024
24 | multi_scale: true
25 | attn_type: linear
26 | ffn_type: glumbconv
27 | mlp_acts:
28 | - silu
29 | - silu
30 | -
31 | mlp_ratio: 2.5
32 | use_pe: false
33 | qk_norm: true
34 | cross_norm: true
35 | class_dropout_prob: 0.1
36 | # VAE setting
37 | vae:
38 | vae_type: AutoencoderDC
39 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
40 | scale_factor: 0.41407
41 | vae_latent_dim: 32
42 | vae_downsample_rate: 32
43 | sample_posterior: true
44 | # text encoder
45 | text_encoder:
46 | text_encoder_name: gemma-2-2b-it
47 | y_norm: true
48 | y_norm_scale_factor: 0.01
49 | model_max_length: 300
50 | # CHI
51 | chi_prompt:
52 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
53 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
54 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
55 | - 'Here are examples of how to transform or refine prompts:'
56 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
57 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
58 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
59 | - 'User Prompt: '
60 | # Sana schedule Flow
61 | scheduler:
62 | predict_flow_v: true
63 | noise_schedule: linear_flow
64 | pred_sigma: false
65 | flow_shift: 3.0
66 | # logit-normal timestep
67 | weighting_scheme: logit_normal
68 | logit_mean: 0.0
69 | logit_std: 1.0
70 | vis_sampler: flow_dpm-solver
71 | # training setting
72 | train:
73 | use_fsdp: true
74 | num_workers: 10
75 | seed: 1
76 | train_batch_size: 32
77 | num_epochs: 100
78 | gradient_accumulation_steps: 1
79 | grad_checkpointing: true
80 | gradient_clip: 0.1
81 | optimizer:
82 | betas:
83 | - 0.9
84 | - 0.999
85 | eps: 1.0e-10
86 | lr: 2.0e-5
87 | type: AdamW
88 | weight_decay: 0.0
89 | lr_schedule: constant
90 | lr_schedule_args:
91 | num_warmup_steps: 1000
92 |   local_save_vis: true # whether to save log images locally
93 | visualize: true
94 | eval_sampling_steps: 500
95 | log_interval: 20
96 | save_model_epochs: 5
97 | save_model_steps: 500
98 | work_dir: output/debug
99 | online_metric: false
100 | eval_metric_step: 2000
101 | online_metric_dir: metric_helper
102 |
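These YAML files are consumed by the project's own config machinery (see `diffusion/utils/config.py` and `configs/sana_base.yaml`); for a quick, loader-agnostic look at a config, plain PyYAML is enough:

```python
# Quick inspection with plain PyYAML; the project's loader in diffusion/utils/config.py
# presumably layers these files on top of configs/sana_base.yaml.
import yaml

with open("configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["model"])      # SanaMS_1600M_P1_D20
print(cfg["model"]["load_from"])  # hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/...
print(cfg["train"]["optimizer"]["type"], cfg["train"]["optimizer"]["lr"])  # AdamW 2e-05
```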
--------------------------------------------------------------------------------
/configs/sana_app_config/Sana_1600M_app.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: []
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: []
7 | external_clipscore_suffixes: []
8 | clip_thr_temperature: 0.1
9 | clip_thr: 25.0
10 | load_text_feat: false
11 | load_vae_feat: false
12 | transform: default_train
13 | type: SanaWebDatasetMS
14 | data:
15 | sort_dataset: false
16 | # model config
17 | model:
18 | model: SanaMS_1600M_P1_D20
19 | image_size: 1024
20 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
21 | fp32_attention: true
22 | load_from:
23 | resume_from:
24 | aspect_ratio_type: ASPECT_RATIO_1024
25 | multi_scale: true
26 | #pe_interpolation: 1.
27 | attn_type: linear
28 | ffn_type: glumbconv
29 | mlp_acts:
30 | - silu
31 | - silu
32 | -
33 | mlp_ratio: 2.5
34 | use_pe: false
35 | qk_norm: false
36 | class_dropout_prob: 0.1
37 | # CFG & PAG settings
38 | pag_applied_layers:
39 | - 8
40 | # VAE setting
41 | vae:
42 | vae_type: AutoencoderDC
43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
44 | scale_factor: 0.41407
45 | vae_latent_dim: 32
46 | vae_downsample_rate: 32
47 | sample_posterior: true
48 | # text encoder
49 | text_encoder:
50 | text_encoder_name: gemma-2-2b-it
51 | y_norm: true
52 | y_norm_scale_factor: 0.01
53 | model_max_length: 300
54 | # CHI
55 | chi_prompt:
56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59 | - 'Here are examples of how to transform or refine prompts:'
60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63 | - 'User Prompt: '
64 | # Sana schedule Flow
65 | scheduler:
66 | predict_flow_v: true
67 | noise_schedule: linear_flow
68 | pred_sigma: false
69 | flow_shift: 3.0
70 | # logit-normal timestep
71 | weighting_scheme: logit_normal
72 | logit_mean: 0.0
73 | logit_std: 1.0
74 | vis_sampler: flow_dpm-solver
75 | # training setting
76 | train:
77 | num_workers: 10
78 | seed: 1
79 | train_batch_size: 64
80 | num_epochs: 100
81 | gradient_accumulation_steps: 1
82 | grad_checkpointing: true
83 | gradient_clip: 0.1
84 | optimizer:
85 | betas:
86 | - 0.9
87 | - 0.999
88 | - 0.9999
89 | eps:
90 | - 1.0e-30
91 | - 1.0e-16
92 | lr: 0.0001
93 | type: CAMEWrapper
94 | weight_decay: 0.0
95 | lr_schedule: constant
96 | lr_schedule_args:
97 | num_warmup_steps: 2000
98 |   local_save_vis: true # whether to save log images locally
99 | visualize: true
100 | eval_sampling_steps: 500
101 | log_interval: 20
102 | save_model_epochs: 5
103 | save_model_steps: 500
104 | work_dir: output/debug
105 | online_metric: false
106 | eval_metric_step: 2000
107 | online_metric_dir: metric_helper
108 |
--------------------------------------------------------------------------------
/configs/sana_app_config/Sana_600M_app.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: []
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: []
7 | external_clipscore_suffixes: []
8 | clip_thr_temperature: 0.1
9 | clip_thr: 25.0
10 | load_text_feat: false
11 | load_vae_feat: true
12 | transform: default_train
13 | type: SanaWebDatasetMS
14 | sort_dataset: false
15 | # model config
16 | model:
17 | model: SanaMS_600M_P1_D28
18 | image_size: 1024
19 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
20 | fp32_attention: true
21 | load_from:
22 | resume_from:
23 | aspect_ratio_type: ASPECT_RATIO_1024
24 | multi_scale: true
25 | attn_type: linear
26 | ffn_type: glumbconv
27 | mlp_acts:
28 | - silu
29 | - silu
30 | -
31 | mlp_ratio: 2.5
32 | use_pe: false
33 | qk_norm: false
34 | class_dropout_prob: 0.1
35 | # CFG & PAG settings
36 | pag_applied_layers:
37 | - 14
38 | # VAE setting
39 | vae:
40 | vae_type: AutoencoderDC
41 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
42 | scale_factor: 0.41407
43 | vae_latent_dim: 32
44 | vae_downsample_rate: 32
45 | sample_posterior: true
46 | # text encoder
47 | text_encoder:
48 | text_encoder_name: gemma-2-2b-it
49 | y_norm: true
50 | y_norm_scale_factor: 0.01
51 | model_max_length: 300
52 | # CHI
53 | chi_prompt:
54 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
55 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
56 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
57 | - 'Here are examples of how to transform or refine prompts:'
58 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
59 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
60 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
61 | - 'User Prompt: '
62 | # Sana schedule Flow
63 | scheduler:
64 | predict_flow_v: true
65 | noise_schedule: linear_flow
66 | pred_sigma: false
67 | flow_shift: 4.0
68 | # logit-normal timestep
69 | weighting_scheme: logit_normal
70 | logit_mean: 0.0
71 | logit_std: 1.0
72 | vis_sampler: flow_dpm-solver
73 | # training setting
74 | train:
75 | num_workers: 10
76 | seed: 1
77 | train_batch_size: 64
78 | num_epochs: 100
79 | gradient_accumulation_steps: 1
80 | grad_checkpointing: true
81 | gradient_clip: 0.1
82 | optimizer:
83 | betas:
84 | - 0.9
85 | - 0.999
86 | - 0.9999
87 | eps:
88 | - 1.0e-30
89 | - 1.0e-16
90 | lr: 0.0001
91 | type: CAMEWrapper
92 | weight_decay: 0.0
93 | lr_schedule: constant
94 | lr_schedule_args:
95 | num_warmup_steps: 2000
96 | local_save_vis: true # whether to save visualization images locally
97 | visualize: true
98 | eval_sampling_steps: 500
99 | log_interval: 20
100 | save_model_epochs: 5
101 | save_model_steps: 500
102 | work_dir: output/debug
103 | online_metric: false
104 | eval_metric_step: 2000
105 | online_metric_dir: metric_helper
106 |
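The app configs above share one layout: data, model, vae, text_encoder (with the chi_prompt list), scheduler, and train sections. Below is a minimal sketch, assuming PyYAML and not necessarily the repo's own config loader, of reading such a file and pulling a few nested fields.

import yaml  # minimal sketch; the repo may use its own config machinery

with open("configs/sana_app_config/Sana_600M_app.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["model"])             # SanaMS_600M_P1_D28
print(cfg["scheduler"]["flow_shift"])    # 4.0
print(cfg["train"]["train_batch_size"])  # 64

# chi_prompt is a list of instruction lines under text_encoder; joined with
# newlines it forms the prefix placed before the raw user prompt.
chi_prefix = "\n".join(cfg["text_encoder"]["chi_prompt"])
full_prompt = chi_prefix + "A cat sleeping"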
--------------------------------------------------------------------------------
/configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/toy_data]
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7 | external_clipscore_suffixes:
8 | - _InternVL2-26B_clip_score
9 | - _VILA1-5-13B_clip_score
10 | - _prompt_clip_score
11 | clip_thr_temperature: 0.1
12 | clip_thr: 25.0
13 | load_text_feat: false
14 | load_vae_feat: false
15 | transform: default_train
16 | type: SanaWebDatasetMS
17 | sort_dataset: false
18 | # model config
19 | model:
20 | model: SanaMS_1600M_P1_D20
21 | image_size: 1024
22 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
23 | fp32_attention: true
24 | load_from:
25 | resume_from:
26 | aspect_ratio_type: ASPECT_RATIO_1024
27 | multi_scale: true
28 | #pe_interpolation: 1.
29 | attn_type: linear
30 | ffn_type: glumbconv
31 | mlp_acts:
32 | - silu
33 | - silu
34 | -
35 | mlp_ratio: 2.5
36 | use_pe: false
37 | qk_norm: false
38 | class_dropout_prob: 0.1
39 | # PAG
40 | pag_applied_layers:
41 | - 8
42 | # VAE setting
43 | vae:
44 | vae_type: AutoencoderDC
45 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
46 | scale_factor: 0.41407
47 | vae_latent_dim: 32
48 | vae_downsample_rate: 32
49 | sample_posterior: true
50 | # text encoder
51 | text_encoder:
52 | text_encoder_name: gemma-2-2b-it
53 | y_norm: true
54 | y_norm_scale_factor: 0.01
55 | model_max_length: 300
56 | # CHI
57 | chi_prompt:
58 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
59 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
60 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
61 | - 'Here are examples of how to transform or refine prompts:'
62 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
63 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
64 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
65 | - 'User Prompt: '
66 | # Sana schedule Flow
67 | scheduler:
68 | predict_flow_v: true
69 | noise_schedule: linear_flow
70 | pred_sigma: false
71 | flow_shift: 3.0
72 | # logit-normal timestep
73 | weighting_scheme: logit_normal
74 | logit_mean: 0.0
75 | logit_std: 1.0
76 | vis_sampler: flow_dpm-solver
77 | # training setting
78 | train:
79 | num_workers: 10
80 | seed: 1
81 | train_batch_size: 64
82 | num_epochs: 100
83 | gradient_accumulation_steps: 1
84 | grad_checkpointing: true
85 | gradient_clip: 0.1
86 | optimizer:
87 | lr: 1.0e-4
88 | type: AdamW
89 | weight_decay: 0.01
90 | eps: 1.0e-8
91 | betas: [0.9, 0.999]
92 | lr_schedule: constant
93 | lr_schedule_args:
94 | num_warmup_steps: 2000
95 | local_save_vis: true # whether to save visualization images locally
96 | visualize: true
97 | eval_sampling_steps: 500
98 | log_interval: 20
99 | save_model_epochs: 5
100 | save_model_steps: 500
101 | work_dir: output/debug
102 | online_metric: false
103 | eval_metric_step: 2000
104 | online_metric_dir: metric_helper
105 |
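Unlike most configs here, this one trains with plain AdamW plus a constant schedule with 2000 warmup steps. Below is a self-contained sketch wiring those exact hyperparameters to torch.optim.AdamW and a LambdaLR warmup; the model and loop are placeholders, not the repo's trainer.

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(32, 32)  # placeholder for the SanaMS model

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1.0e-4,
    weight_decay=0.01,
    eps=1.0e-8,
    betas=(0.9, 0.999),
)

num_warmup_steps = 2000

def constant_with_warmup(step: int) -> float:
    # Linearly ramp the LR multiplier from 0 to 1 over the warmup, then hold constant.
    return min(1.0, (step + 1) / num_warmup_steps)

scheduler = LambdaLR(optimizer, lr_lambda=constant_with_warmup)

for step in range(5):  # stand-in for the real training loop
    optimizer.zero_grad()
    loss = model(torch.randn(4, 32)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()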
--------------------------------------------------------------------------------
/configs/sana_config/1024ms/Sana_600M_img1024.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/toy_data]
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7 | external_clipscore_suffixes:
8 | - _InternVL2-26B_clip_score
9 | - _VILA1-5-13B_clip_score
10 | - _prompt_clip_score
11 | clip_thr_temperature: 0.1
12 | clip_thr: 25.0
13 | load_text_feat: false
14 | load_vae_feat: false
15 | transform: default_train
16 | type: SanaWebDatasetMS
17 | sort_dataset: false
18 | # model config
19 | model:
20 | model: SanaMS_600M_P1_D28
21 | image_size: 1024
22 | mixed_precision: fp16
23 | fp32_attention: true
24 | load_from:
25 | resume_from:
26 | aspect_ratio_type: ASPECT_RATIO_1024
27 | multi_scale: true
28 | attn_type: linear
29 | ffn_type: glumbconv
30 | mlp_acts:
31 | - silu
32 | - silu
33 | -
34 | mlp_ratio: 2.5
35 | use_pe: false
36 | qk_norm: false
37 | class_dropout_prob: 0.1
38 | # VAE setting
39 | vae:
40 | vae_type: AutoencoderDC
41 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
42 | scale_factor: 0.41407
43 | vae_latent_dim: 32
44 | vae_downsample_rate: 32
45 | sample_posterior: true
46 | # text encoder
47 | text_encoder:
48 | text_encoder_name: gemma-2-2b-it
49 | y_norm: true
50 | y_norm_scale_factor: 0.01
51 | model_max_length: 300
52 | # CHI
53 | chi_prompt:
54 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
55 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
56 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
57 | - 'Here are examples of how to transform or refine prompts:'
58 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
59 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
60 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
61 | - 'User Prompt: '
62 | # Sana schedule Flow
63 | scheduler:
64 | predict_flow_v: true
65 | noise_schedule: linear_flow
66 | pred_sigma: false
67 | flow_shift: 4.0
68 | # logit-normal timestep
69 | weighting_scheme: logit_normal
70 | logit_mean: 0.0
71 | logit_std: 1.0
72 | vis_sampler: flow_dpm-solver
73 | # training setting
74 | train:
75 | num_workers: 10
76 | seed: 1
77 | train_batch_size: 64
78 | num_epochs: 100
79 | gradient_accumulation_steps: 1
80 | grad_checkpointing: true
81 | gradient_clip: 0.1
82 | optimizer:
83 | betas:
84 | - 0.9
85 | - 0.999
86 | - 0.9999
87 | eps:
88 | - 1.0e-30
89 | - 1.0e-16
90 | lr: 0.0001
91 | type: CAMEWrapper
92 | weight_decay: 0.0
93 | lr_schedule: constant
94 | lr_schedule_args:
95 | num_warmup_steps: 2000
96 | local_save_vis: true # whether to save visualization images locally
97 | visualize: true
98 | eval_sampling_steps: 500
99 | log_interval: 20
100 | save_model_epochs: 5
101 | save_model_steps: 500
102 | work_dir: output/debug
103 | online_metric: false
104 | eval_metric_step: 2000
105 | online_metric_dir: metric_helper
106 |
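The CAMEWrapper optimizer block above takes three betas and a pair of eps values, unlike AdamW's single pair. Below is a sketch of the same hyperparameters on the open-source came-pytorch implementation; CAMEWrapper is the repo's own wrapper, so treat this only as an approximation of its behaviour.

import torch
from came_pytorch import CAME  # pip install came-pytorch; assumed stand-in for CAMEWrapper

model = torch.nn.Linear(32, 32)  # placeholder for the SanaMS model

optimizer = CAME(
    model.parameters(),
    lr=1.0e-4,
    weight_decay=0.0,
    betas=(0.9, 0.999, 0.9999),  # the third beta smooths the confidence term
    eps=(1e-30, 1e-16),
)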
--------------------------------------------------------------------------------
/configs/sana_config/512ms/Sana_1600M_img512.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/data_public/dir1]
3 | image_size: 512
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7 | external_clipscore_suffixes:
8 | - _InternVL2-26B_clip_score
9 | - _VILA1-5-13B_clip_score
10 | - _prompt_clip_score
11 | clip_thr_temperature: 0.1
12 | clip_thr: 25.0
13 | load_text_feat: false
14 | load_vae_feat: false
15 | transform: default_train
16 | type: SanaWebDatasetMS
17 | sort_dataset: false
18 | # model config
19 | model:
20 | model: SanaMS_1600M_P1_D20
21 | image_size: 512
22 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
23 | fp32_attention: true
24 | load_from:
25 | resume_from:
26 | aspect_ratio_type: ASPECT_RATIO_512
27 | multi_scale: true
28 | attn_type: linear
29 | ffn_type: glumbconv
30 | mlp_acts:
31 | - silu
32 | - silu
33 | -
34 | mlp_ratio: 2.5
35 | use_pe: false
36 | qk_norm: false
37 | class_dropout_prob: 0.1
38 | # PAG
39 | pag_applied_layers:
40 | - 8
41 | # VAE setting
42 | vae:
43 | vae_type: AutoencoderDC
44 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
45 | scale_factor: 0.41407
46 | vae_latent_dim: 32
47 | vae_downsample_rate: 32
48 | sample_posterior: true
49 | # text encoder
50 | text_encoder:
51 | text_encoder_name: gemma-2-2b-it
52 | y_norm: true
53 | y_norm_scale_factor: 0.01
54 | model_max_length: 300
55 | # CHI
56 | chi_prompt:
57 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
58 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
59 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
60 | - 'Here are examples of how to transform or refine prompts:'
61 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
62 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
63 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
64 | - 'User Prompt: '
65 | # Sana schedule Flow
66 | scheduler:
67 | predict_flow_v: true
68 | noise_schedule: linear_flow
69 | pred_sigma: false
70 | flow_shift: 3.0
71 | # logit-normal timestep
72 | weighting_scheme: logit_normal
73 | logit_mean: 0.0
74 | logit_std: 1.0
75 | vis_sampler: flow_dpm-solver
76 | # training setting
77 | train:
78 | num_workers: 10
79 | seed: 1
80 | train_batch_size: 64
81 | num_epochs: 100
82 | gradient_accumulation_steps: 1
83 | grad_checkpointing: true
84 | gradient_clip: 0.1
85 | optimizer:
86 | betas:
87 | - 0.9
88 | - 0.999
89 | - 0.9999
90 | eps:
91 | - 1.0e-30
92 | - 1.0e-16
93 | lr: 0.0001
94 | type: CAMEWrapper
95 | weight_decay: 0.0
96 | lr_schedule: constant
97 | lr_schedule_args:
98 | num_warmup_steps: 2000
99 | local_save_vis: true # whether to save visualization images locally
100 | visualize: true
101 | eval_sampling_steps: 500
102 | log_interval: 20
103 | save_model_epochs: 5
104 | save_model_steps: 500
105 | work_dir: output/debug
106 | online_metric: false
107 | eval_metric_step: 2000
108 | online_metric_dir: metric_helper
109 |
--------------------------------------------------------------------------------
/configs/sana_config/512ms/Sana_600M_img512.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/data_public/dir1]
3 | image_size: 512
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7 | external_clipscore_suffixes:
8 | - _InternVL2-26B_clip_score
9 | - _VILA1-5-13B_clip_score
10 | - _prompt_clip_score
11 | clip_thr_temperature: 0.1
12 | clip_thr: 25.0
13 | load_text_feat: false
14 | load_vae_feat: false
15 | transform: default_train
16 | type: SanaWebDatasetMS
17 | sort_dataset: false
18 | # model config
19 | model:
20 | model: SanaMS_600M_P1_D28
21 | image_size: 512
22 | mixed_precision: fp16
23 | fp32_attention: true
24 | load_from:
25 | resume_from:
26 | aspect_ratio_type: ASPECT_RATIO_512
27 | multi_scale: true
28 | #pe_interpolation: 1.
29 | attn_type: linear
30 | linear_head_dim: 32
31 | ffn_type: glumbconv
32 | mlp_acts:
33 | - silu
34 | - silu
35 | - null
36 | mlp_ratio: 2.5
37 | use_pe: false
38 | qk_norm: false
39 | class_dropout_prob: 0.1
40 | # VAE setting
41 | vae:
42 | vae_type: AutoencoderDC
43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
44 | scale_factor: 0.41407
45 | vae_latent_dim: 32
46 | vae_downsample_rate: 32
47 | sample_posterior: true
48 | # text encoder
49 | text_encoder:
50 | text_encoder_name: gemma-2-2b-it
51 | y_norm: true
52 | y_norm_scale_factor: 0.01
53 | model_max_length: 300
54 | # CHI
55 | chi_prompt:
56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59 | - 'Here are examples of how to transform or refine prompts:'
60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63 | - 'User Prompt: '
64 | # Sana schedule Flow
65 | scheduler:
66 | predict_flow_v: true
67 | noise_schedule: linear_flow
68 | pred_sigma: false
69 | flow_shift: 3.0
70 | # logit-normal timestep
71 | weighting_scheme: logit_normal
72 | logit_mean: 0.0
73 | logit_std: 1.0
74 | vis_sampler: flow_dpm-solver
75 | # training setting
76 | train:
77 | num_workers: 10
78 | seed: 1
79 | train_batch_size: 128
80 | num_epochs: 100
81 | gradient_accumulation_steps: 1
82 | grad_checkpointing: true
83 | gradient_clip: 0.1
84 | optimizer:
85 | betas:
86 | - 0.9
87 | - 0.999
88 | - 0.9999
89 | eps:
90 | - 1.0e-30
91 | - 1.0e-16
92 | lr: 0.0001
93 | type: CAMEWrapper
94 | weight_decay: 0.0
95 | lr_schedule: constant
96 | lr_schedule_args:
97 | num_warmup_steps: 2000
98 | local_save_vis: true # whether to save visualization images locally
99 | visualize: true
100 | eval_sampling_steps: 500
101 | log_interval: 20
102 | save_model_epochs: 5
103 | save_model_steps: 500
104 | work_dir: output/debug
105 | online_metric: false
106 | eval_metric_step: 2000
107 | online_metric_dir: metric_helper
108 |
--------------------------------------------------------------------------------
/configs/sana_config/512ms/ci_Sana_600M_img512.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/data_public/vaef32c32_v2_512/dir1]
3 | image_size: 512
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7 | external_clipscore_suffixes:
8 | - _InternVL2-26B_clip_score
9 | - _VILA1-5-13B_clip_score
10 | - _prompt_clip_score
11 | clip_thr_temperature: 0.1
12 | clip_thr: 25.0
13 | load_text_feat: false
14 | load_vae_feat: false
15 | transform: default_train
16 | type: SanaWebDatasetMS
17 | aspect_ratio_type: ASPECT_RATIO_512
18 | sort_dataset: false
19 | # model config
20 | model:
21 | model: SanaMS_600M_P1_D28
22 | image_size: 512
23 | mixed_precision: fp16
24 | fp32_attention: true
25 | load_from:
26 | resume_from:
27 | multi_scale: true
28 | #pe_interpolation: 1.
29 | attn_type: linear
30 | linear_head_dim: 32
31 | ffn_type: glumbconv
32 | mlp_acts:
33 | - silu
34 | - silu
35 | - null
36 | mlp_ratio: 2.5
37 | use_pe: false
38 | qk_norm: false
39 | class_dropout_prob: 0.1
40 | # VAE setting
41 | vae:
42 | vae_type: AutoencoderDC
43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
44 | scale_factor: 0.41407
45 | vae_latent_dim: 32
46 | vae_downsample_rate: 32
47 | sample_posterior: true
48 | # text encoder
49 | text_encoder:
50 | text_encoder_name: gemma-2-2b-it
51 | y_norm: true
52 | y_norm_scale_factor: 0.01
53 | model_max_length: 300
54 | # CHI
55 | chi_prompt:
56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59 | - 'Here are examples of how to transform or refine prompts:'
60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63 | - 'User Prompt: '
64 | # Sana schedule Flow
65 | scheduler:
66 | predict_flow_v: true
67 | noise_schedule: linear_flow
68 | pred_sigma: false
69 | flow_shift: 1.0
70 | # logit-normal timestep
71 | weighting_scheme: logit_normal
72 | logit_mean: 0.0
73 | logit_std: 1.0
74 | vis_sampler: flow_dpm-solver
75 | # training setting
76 | train:
77 | num_workers: 10
78 | seed: 1
79 | train_batch_size: 64
80 | num_epochs: 1
81 | gradient_accumulation_steps: 1
82 | grad_checkpointing: true
83 | gradient_clip: 0.1
84 | optimizer:
85 | betas:
86 | - 0.9
87 | - 0.999
88 | - 0.9999
89 | eps:
90 | - 1.0e-30
91 | - 1.0e-16
92 | lr: 0.0001
93 | type: CAMEWrapper
94 | weight_decay: 0.0
95 | lr_schedule: constant
96 | lr_schedule_args:
97 | num_warmup_steps: 2000
98 | local_save_vis: true # whether to save visualization images locally
99 | visualize: true
100 | eval_sampling_steps: 500
101 | log_interval: 20
102 | save_model_epochs: 5
103 | save_model_steps: 500
104 | work_dir: output/debug
105 | online_metric: false
106 | eval_metric_step: 2000
107 | online_metric_dir: metric_helper
108 |
--------------------------------------------------------------------------------
/configs/sana_config/512ms/sample_dataset.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [asset/example_data]
3 | image_size: 512
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] # json files
7 | external_clipscore_suffixes: # json files
8 | - _InternVL2-26B_clip_score
9 | - _VILA1-5-13B_clip_score
10 | - _prompt_clip_score
11 | clip_thr_temperature: 0.1
12 | clip_thr: 25.0
13 | load_text_feat: false
14 | load_vae_feat: false
15 | transform: default_train
16 | type: SanaImgDataset
17 | sort_dataset: false
18 | # model config
19 | model:
20 | model: SanaMS_600M_P1_D28
21 | image_size: 512
22 | mixed_precision: fp16
23 | fp32_attention: true
24 | load_from:
25 | resume_from:
26 | aspect_ratio_type: ASPECT_RATIO_512
27 | multi_scale: false
28 | #pe_interpolation: 1.
29 | attn_type: linear
30 | linear_head_dim: 32
31 | ffn_type: glumbconv
32 | mlp_acts:
33 | - silu
34 | - silu
35 | - null
36 | mlp_ratio: 2.5
37 | use_pe: false
38 | qk_norm: false
39 | class_dropout_prob: 0.1
40 | # VAE setting
41 | vae:
42 | vae_type: AutoencoderDC
43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
44 | scale_factor: 0.41407
45 | vae_latent_dim: 32
46 | vae_downsample_rate: 32
47 | sample_posterior: true
48 | # text encoder
49 | text_encoder:
50 | text_encoder_name: gemma-2-2b-it
51 | y_norm: true
52 | y_norm_scale_factor: 0.01
53 | model_max_length: 300
54 | # CHI
55 | chi_prompt:
56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59 | - 'Here are examples of how to transform or refine prompts:'
60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63 | - 'User Prompt: '
64 | # Sana schedule Flow
65 | scheduler:
66 | predict_flow_v: true
67 | noise_schedule: linear_flow
68 | pred_sigma: false
69 | flow_shift: 1.0
70 | # logit-normal timestep
71 | weighting_scheme: logit_normal
72 | logit_mean: 0.0
73 | logit_std: 1.0
74 | vis_sampler: flow_dpm-solver
75 | # training setting
76 | train:
77 | num_workers: 10
78 | seed: 1
79 | train_batch_size: 128
80 | num_epochs: 100
81 | gradient_accumulation_steps: 1
82 | grad_checkpointing: true
83 | gradient_clip: 0.1
84 | optimizer:
85 | betas:
86 | - 0.9
87 | - 0.999
88 | - 0.9999
89 | eps:
90 | - 1.0e-30
91 | - 1.0e-16
92 | lr: 0.0001
93 | type: CAMEWrapper
94 | weight_decay: 0.0
95 | lr_schedule: constant
96 | lr_schedule_args:
97 | num_warmup_steps: 2000
98 | local_save_vis: true # whether to save visualization images locally
99 | visualize: true
100 | eval_sampling_steps: 500
101 | log_interval: 20
102 | save_model_epochs: 5
103 | save_model_steps: 500
104 | work_dir: output/debug
105 | online_metric: false
106 | eval_metric_step: 2000
107 | online_metric_dir: metric_helper
108 |
--------------------------------------------------------------------------------
/configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/data_public/controlnet_data]
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: []
7 | external_clipscore_suffixes: []
8 | clip_thr_temperature: 0.1
9 | clip_thr: 25.0
10 | load_text_feat: false
11 | load_vae_feat: false
12 | transform: default_train
13 | type: SanaWebDatasetMSControl
14 | sort_dataset: false
15 | # model config
16 | model:
17 | model: SanaMSControlNet_1600M_P1_D20
18 | image_size: 1024
19 | mixed_precision: bf16
20 | fp32_attention: true
21 | load_from: hf://Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoint/Sana_1600M_1024px_BF16.pth
22 | resume_from:
23 | aspect_ratio_type: ASPECT_RATIO_1024
24 | multi_scale: true
25 | attn_type: linear
26 | ffn_type: glumbconv
27 | mlp_acts:
28 | - silu
29 | - silu
30 | -
31 | mlp_ratio: 2.5
32 | use_pe: false
33 | qk_norm: false
34 | class_dropout_prob: 0.1
35 | # VAE setting
36 | vae:
37 | vae_type: AutoencoderDC
38 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
39 | scale_factor: 0.41407
40 | vae_latent_dim: 32
41 | vae_downsample_rate: 32
42 | sample_posterior: true
43 | weight_dtype: bf16
44 | # text encoder
45 | text_encoder:
46 | text_encoder_name: gemma-2-2b-it
47 | y_norm: true
48 | y_norm_scale_factor: 0.01
49 | model_max_length: 300
50 | # CHI
51 | chi_prompt:
52 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
53 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
54 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
55 | - 'Here are examples of how to transform or refine prompts:'
56 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
57 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
58 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
59 | - 'User Prompt: '
60 | # Sana schedule Flow
61 | scheduler:
62 | predict_flow_v: true
63 | noise_schedule: linear_flow
64 | pred_sigma: false
65 | flow_shift: 3.0
66 | # logit-normal timestep
67 | weighting_scheme: logit_normal
68 | logit_mean: 0.0
69 | logit_std: 1.0
70 | vis_sampler: flow_dpm-solver
71 | # training setting
72 | train:
73 | num_workers: 10
74 | seed: 1
75 | train_batch_size: 16
76 | num_epochs: 100
77 | gradient_accumulation_steps: 1
78 | grad_checkpointing: true
79 | gradient_clip: 0.1
80 | optimizer:
81 | betas:
82 | - 0.9
83 | - 0.999
84 | - 0.9999
85 | eps:
86 | - 1.0e-30
87 | - 1.0e-16
88 | lr: 0.0001
89 | type: CAMEWrapper
90 | weight_decay: 0.0
91 | lr_schedule: constant
92 | lr_schedule_args:
93 | num_warmup_steps: 30
94 | local_save_vis: true # whether to save visualization images locally
95 | visualize: true
96 | eval_sampling_steps: 50
97 | log_interval: 20
98 | save_model_epochs: 5
99 | save_model_steps: 500
100 | work_dir: output/debug
101 | online_metric: false
102 | eval_metric_step: 2000
103 | online_metric_dir: metric_helper
104 | controlnet:
105 | control_signal_type: "scribble"
106 |
--------------------------------------------------------------------------------
/configs/sana_controlnet_config/Sana_600M_img1024_controlnet.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: [data/data_public/controlnet_data]
3 | image_size: 1024
4 | caption_proportion:
5 | prompt: 1
6 | external_caption_suffixes: []
7 | external_clipscore_suffixes: []
8 | clip_thr_temperature: 0.1
9 | clip_thr: 25.0
10 | load_text_feat: false
11 | load_vae_feat: false
12 | transform: default_train
13 | type: SanaWebDatasetMSControl
14 | sort_dataset: false
15 | # model config
16 | model:
17 | model: SanaMSControlNet_600M_P1_D28
18 | image_size: 1024
19 | mixed_precision: fp16
20 | fp32_attention: true
21 | load_from: hf://Efficient-Large-Model/Sana_600M_1024px/checkpoint/Sana_600M_1024px.pth
22 | resume_from:
23 | aspect_ratio_type: ASPECT_RATIO_1024
24 | multi_scale: true
25 | attn_type: linear
26 | ffn_type: glumbconv
27 | mlp_acts:
28 | - silu
29 | - silu
30 | -
31 | mlp_ratio: 2.5
32 | use_pe: false
33 | qk_norm: false
34 | class_dropout_prob: 0.1
35 | # VAE setting
36 | vae:
37 | vae_type: AutoencoderDC
38 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
39 | scale_factor: 0.41407
40 | vae_latent_dim: 32
41 | vae_downsample_rate: 32
42 | sample_posterior: true
43 | # text encoder
44 | text_encoder:
45 | text_encoder_name: gemma-2-2b-it
46 | y_norm: true
47 | y_norm_scale_factor: 0.01
48 | model_max_length: 300
49 | # CHI
50 | chi_prompt:
51 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
52 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
53 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
54 | - 'Here are examples of how to transform or refine prompts:'
55 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
56 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
57 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
58 | - 'User Prompt: '
59 | # Sana schedule Flow
60 | scheduler:
61 | predict_flow_v: true
62 | noise_schedule: linear_flow
63 | pred_sigma: false
64 | flow_shift: 4.0
65 | # logit-normal timestep
66 | weighting_scheme: logit_normal
67 | logit_mean: 0.0
68 | logit_std: 1.0
69 | vis_sampler: flow_dpm-solver
70 | # training setting
71 | train:
72 | num_workers: 10
73 | seed: 1
74 | train_batch_size: 16
75 | num_epochs: 100
76 | gradient_accumulation_steps: 1
77 | grad_checkpointing: true
78 | gradient_clip: 0.1
79 | optimizer:
80 | betas:
81 | - 0.9
82 | - 0.999
83 | - 0.9999
84 | eps:
85 | - 1.0e-30
86 | - 1.0e-16
87 | lr: 0.0001
88 | type: CAMEWrapper
89 | weight_decay: 0.0
90 | lr_schedule: constant
91 | lr_schedule_args:
92 | num_warmup_steps: 30
93 | local_save_vis: true # whether to save visualization images locally
94 | visualize: true
95 | eval_sampling_steps: 500
96 | log_interval: 20
97 | save_model_epochs: 5
98 | save_model_steps: 500
99 | work_dir: output/debug
100 | online_metric: false
101 | eval_metric_step: 2000
102 | online_metric_dir: metric_helper
103 | controlnet:
104 | control_signal_type: "scribble"
105 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.1.dev0"
2 |
3 | from diffusion.scheduler.dpm_solver import DPMS
4 | from diffusion.scheduler.flow_euler_sampler import FlowEuler, LTXFlowEuler
5 | from diffusion.scheduler.iddpm import Scheduler
6 | from diffusion.scheduler.longlive_flow_euler_sampler import LongLiveFlowEuler
7 | from diffusion.scheduler.sa_sampler import SASolverSampler
8 | from diffusion.scheduler.scm_scheduler import SCMScheduler
9 | from diffusion.scheduler.trigflow_scheduler import TrigFlowScheduler
10 |
--------------------------------------------------------------------------------
/diffusion/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .transforms import get_transform
3 |
--------------------------------------------------------------------------------
/diffusion/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .sana_data import SanaImgDataset, SanaWebDataset
2 | from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS
3 | from .utils import *
4 | from .video.sana_video_data import DistributePromptsDataset, SanaZipDataset
5 |
--------------------------------------------------------------------------------
/diffusion/data/wids/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
2 | # This file is part of the WebDataset library.
3 | # See the LICENSE file for licensing terms (BSD-style).
4 | #
5 | # flake8: noqa
6 |
7 | from .wids import (
8 | ChunkedSampler,
9 | DistributedChunkedSampler,
10 | DistributedLocalSampler,
11 | DistributedRangedSampler,
12 | ShardedSampler,
13 | ShardListDataset,
14 | ShardListDatasetMulti,
15 | lru_json_load,
16 | )
17 |
--------------------------------------------------------------------------------
/diffusion/data/wids/wids_lru.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18 | from collections import OrderedDict
19 |
20 |
21 | class LRUCache:
22 | def __init__(self, capacity: int, release_handler=None):
23 | """Initialize a new LRU cache with the given capacity."""
24 | self.capacity = capacity
25 | self.cache = OrderedDict()
26 | self.release_handler = release_handler
27 |
28 | def __getitem__(self, key):
29 | """Return the value associated with the given key, or None."""
30 | if key not in self.cache:
31 | return None
32 | self.cache.move_to_end(key)
33 | return self.cache[key]
34 |
35 | def __setitem__(self, key, value):
36 | """Associate the given value with the given key."""
37 | if key in self.cache:
38 | self.cache.move_to_end(key)
39 | self.cache[key] = value
40 | if len(self.cache) > self.capacity:
41 | key, value = self.cache.popitem(last=False)
42 | if self.release_handler is not None:
43 | self.release_handler(key, value)
44 |
45 | def __delitem__(self, key):
46 | """Remove the given key from the cache."""
47 | if key in self.cache:
48 | if self.release_handler is not None:
49 | value = self.cache[key]
50 | self.release_handler(key, value)
51 | del self.cache[key]
52 |
53 | def __len__(self):
54 | """Return the number of entries in the cache."""
55 | return len(self.cache)
56 |
57 | def __contains__(self, key):
58 | """Return whether the cache contains the given key."""
59 | return key in self.cache
60 |
61 | def items(self):
62 | """Return an iterator over the keys of the cache."""
63 | return self.cache.items()
64 |
65 | def keys(self):
66 | """Return an iterator over the keys of the cache."""
67 | return self.cache.keys()
68 |
69 | def values(self):
70 | """Return an iterator over the values of the cache."""
71 | return self.cache.values()
72 |
73 | def clear(self):
74 | for key in list(self.keys()):
75 | value = self.cache[key]
76 | if self.release_handler is not None:
77 | self.release_handler(key, value)
78 | del self[key]
79 |
80 | def __del__(self):
81 | self.clear()
82 |
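A short usage example for LRUCache: when the capacity is exceeded, the least-recently-used entry is evicted and the optional release_handler is called with it. Names and values below are illustrative.

def on_release(key, value):
    print(f"evicted {key} -> {value}")

cache = LRUCache(capacity=2, release_handler=on_release)
cache["a"] = 1
cache["b"] = 2
cache["a"]        # touching "a" makes "b" the least recently used
cache["c"] = 3    # over capacity: prints "evicted b -> 2"
assert "b" not in cache and "a" in cache and "c" in cache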
--------------------------------------------------------------------------------
/diffusion/data/wids/wids_tar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18 | import io
19 | import os
20 | import os.path
21 | import pickle
22 | import re
23 | import tarfile
24 |
25 | import numpy as np
26 |
27 |
28 | def find_index_file(file):
29 | prefix, last_ext = os.path.splitext(file)
30 | if re.match("._[0-9]+_$", last_ext):
31 | return prefix + ".index"
32 | else:
33 | return file + ".index"
34 |
35 |
36 | class TarFileReader:
37 | def __init__(self, file, index_file=find_index_file, verbose=True):
38 | self.verbose = verbose
39 | if callable(index_file):
40 | index_file = index_file(file)
41 | self.index_file = index_file
42 |
43 | # Open the tar file and keep it open
44 | if isinstance(file, str):
45 | self.tar_file = tarfile.open(file, "r")
46 | else:
47 | self.tar_file = tarfile.open(fileobj=file, mode="r")
48 |
49 | # Create the index
50 | self._create_tar_index()
51 |
52 | def _create_tar_index(self):
53 | if self.index_file is not None and os.path.exists(self.index_file):
54 | if self.verbose:
55 | print("Loading tar index from", self.index_file)
56 | with open(self.index_file, "rb") as stream:
57 | self.fnames, self.index = pickle.load(stream)
58 | return
59 | # Create an empty list for the index
60 | self.fnames = []
61 | self.index = []
62 |
63 | if self.verbose:
64 | print("Creating tar index for", self.tar_file.name, "at", self.index_file)
65 | # Iterate over the members of the tar file
66 | for member in self.tar_file:
67 | # If the member is a file, add it to the index
68 | if member.isfile():
69 | # Get the file's offset
70 | offset = self.tar_file.fileobj.tell()
71 | self.fnames.append(member.name)
72 | self.index.append([offset, member.size])
73 | if self.verbose:
74 | print("Done creating tar index for", self.tar_file.name, "at", self.index_file)
75 | self.index = np.array(self.index)
76 | if self.index_file is not None:
77 | if os.path.exists(self.index_file + ".temp"):
78 | os.unlink(self.index_file + ".temp")
79 | with open(self.index_file + ".temp", "wb") as stream:
80 | pickle.dump((self.fnames, self.index), stream)
81 | os.rename(self.index_file + ".temp", self.index_file)
82 |
83 | def names(self):
84 | return self.fnames
85 |
86 | def __len__(self):
87 | return len(self.index)
88 |
89 | def get_file(self, i):
90 | name = self.fnames[i]
91 | offset, size = self.index[i]
92 | self.tar_file.fileobj.seek(offset)
93 | file_bytes = self.tar_file.fileobj.read(size)
94 | return name, io.BytesIO(file_bytes)
95 |
96 | def close(self):
97 | # Close the tar file
98 | self.tar_file.close()
99 |
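A usage sketch for TarFileReader: on first use it writes a <tar>.index pickle next to the archive, then serves members by position without re-scanning the tar. The shard path below is a placeholder.

reader = TarFileReader("shard-000000.tar", verbose=False)  # placeholder path
print(len(reader), "members")
for i in range(min(3, len(reader))):
    name, stream = reader.get_file(i)  # returns (member name, io.BytesIO of raw bytes)
    print(name, len(stream.getbuffer()), "bytes")
reader.close()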
--------------------------------------------------------------------------------
/diffusion/guiders/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptive_projected_guidance import AdaptiveProjectedGuidance
2 |
3 | __all__ = ["AdaptiveProjectedGuidance"]
4 |
--------------------------------------------------------------------------------
/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/__init__.py
--------------------------------------------------------------------------------
/diffusion/model/act.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import copy
18 |
19 | import torch.nn as nn
20 |
21 | __all__ = ["build_act", "get_act_name"]
22 |
23 | # register activation function here
24 | # name: module, kwargs with default values
25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = {
26 | "relu": (nn.ReLU, {"inplace": True}),
27 | "relu6": (nn.ReLU6, {"inplace": True}),
28 | "hswish": (nn.Hardswish, {"inplace": True}),
29 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}),
30 | "swish": (nn.SiLU, {"inplace": True}),
31 | "silu": (nn.SiLU, {"inplace": True}),
32 | "tanh": (nn.Tanh, {}),
33 | "sigmoid": (nn.Sigmoid, {}),
34 | "gelu": (nn.GELU, {"approximate": "tanh"}),
35 | "mish": (nn.Mish, {"inplace": True}),
36 | "identity": (nn.Identity, {}),
37 | }
38 |
39 |
40 | def build_act(name: str or None, **kwargs) -> nn.Module or None:
41 | if name in REGISTERED_ACT_DICT:
42 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name])
43 | for key in default_args:
44 | if key in kwargs:
45 | default_args[key] = kwargs[key]
46 | return act_cls(**default_args)
47 | elif name is None or name.lower() == "none":
48 | return None
49 | else:
50 | raise ValueError(f"do not support: {name}")
51 |
52 |
53 | def get_act_name(act: nn.Module or None) -> str or None:
54 | if act is None:
55 | return None
56 | module2name = {}
57 | for key, config in REGISTERED_ACT_DICT.items():
58 | module2name[config[0].__name__] = key
59 | return module2name.get(type(act).__name__, "unknown")
60 |
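A short usage example for build_act / get_act_name, assuming the repo root is on PYTHONPATH.

from diffusion.model.act import build_act, get_act_name

act = build_act("silu")                        # nn.SiLU(inplace=True)
gelu = build_act("gelu", approximate="none")   # override a registered default kwarg
assert build_act(None) is None                 # None / "none" maps to no activation

print(get_act_name(act))   # "silu" (SiLU is registered under "swish" and "silu";
                           #  the later "silu" entry wins in the reverse lookup)
print(get_act_name(None))  # None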
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/dc_ae/efficientvit/__init__.py
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/dc_ae/efficientvit/apps/__init__.py
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from copy import deepcopy
4 | from typing import Optional
5 |
6 | import torch.backends.cudnn
7 | import torch.distributed
8 | import torch.nn as nn
9 |
10 | from ..apps.utils import (
11 | dist_init,
12 | dump_config,
13 | get_dist_local_rank,
14 | get_dist_rank,
15 | get_dist_size,
16 | init_modules,
17 | is_master,
18 | load_config,
19 | partial_update_config,
20 | zero_last_gamma,
21 | )
22 | from ..models.utils import build_kwargs_from_config, load_state_dict_from_file
23 |
24 | __all__ = [
25 | "save_exp_config",
26 | "setup_dist_env",
27 | "setup_seed",
28 | "setup_exp_config",
29 | "init_model",
30 | ]
31 |
32 |
33 | def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
34 | if not is_master():
35 | return
36 | dump_config(exp_config, os.path.join(path, name))
37 |
38 |
39 | def setup_dist_env(gpu: Optional[str] = None) -> None:
40 | if gpu is not None:
41 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu
42 | if not torch.distributed.is_initialized():
43 | dist_init()
44 | torch.backends.cudnn.benchmark = True
45 | torch.cuda.set_device(get_dist_local_rank())
46 |
47 |
48 | def setup_seed(manual_seed: int, resume: bool) -> None:
49 | if resume:
50 | manual_seed = int(time.time())
51 | manual_seed = get_dist_rank() + manual_seed
52 | torch.manual_seed(manual_seed)
53 | torch.cuda.manual_seed_all(manual_seed)
54 |
55 |
56 | def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
57 | # load config
58 | if not os.path.isfile(config_path):
59 | raise ValueError(config_path)
60 |
61 | fpaths = [config_path]
62 | if recursive:
63 | extension = os.path.splitext(config_path)[1]
64 | while os.path.dirname(config_path) != config_path:
65 | config_path = os.path.dirname(config_path)
66 | fpath = os.path.join(config_path, "default" + extension)
67 | if os.path.isfile(fpath):
68 | fpaths.append(fpath)
69 | fpaths = fpaths[::-1]
70 |
71 | default_config = load_config(fpaths[0])
72 | exp_config = deepcopy(default_config)
73 | for fpath in fpaths[1:]:
74 | partial_update_config(exp_config, load_config(fpath))
75 | # update config via args
76 | if opt_args is not None:
77 | partial_update_config(exp_config, opt_args)
78 |
79 | return exp_config
80 |
81 |
82 | def init_model(
83 | network: nn.Module,
84 | init_from: Optional[str] = None,
85 | backbone_init_from: Optional[str] = None,
86 | rand_init="trunc_normal",
87 | last_gamma=None,
88 | ) -> None:
89 | # initialization
90 | init_modules(network, init_type=rand_init)
91 | # zero gamma of last bn in each block
92 | if last_gamma is not None:
93 | zero_last_gamma(network, last_gamma)
94 |
95 | # load weight
96 | if init_from is not None and os.path.isfile(init_from):
97 | network.load_state_dict(load_state_dict_from_file(init_from))
98 | print(f"Loaded init from {init_from}")
99 | elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
100 | network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
101 | print(f"Loaded backbone init from {backbone_init_from}")
102 | else:
103 | print(f"Random init ({rand_init}) with last gamma {last_gamma}")
104 |
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .run_config import *
2 |
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dist import *
2 | from .ema import *
3 |
4 | # from .export import *
5 | from .image import *
6 | from .init import *
7 | from .lr import *
8 | from .metric import *
9 | from .misc import *
10 | from .opt import *
11 |
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import os
18 | from typing import Union
19 |
20 | import torch
21 | import torch.distributed
22 |
23 | from ...models.utils.list import list_mean, list_sum
24 |
25 | __all__ = [
26 | "dist_init",
27 | "is_dist_initialized",
28 | "get_dist_rank",
29 | "get_dist_size",
30 | "is_master",
31 | "dist_barrier",
32 | "get_dist_local_rank",
33 | "sync_tensor",
34 | ]
35 |
36 |
37 | def dist_init() -> None:
38 | if is_dist_initialized():
39 | return
40 | try:
41 | torch.distributed.init_process_group(backend="nccl")
42 | assert torch.distributed.is_initialized()
43 | except Exception:
44 | os.environ["RANK"] = "0"
45 | os.environ["WORLD_SIZE"] = "1"
46 | os.environ["LOCAL_RANK"] = "0"
47 | print("warning: dist not init")
48 |
49 |
50 | def is_dist_initialized() -> bool:
51 | return torch.distributed.is_initialized()
52 |
53 |
54 | def get_dist_rank() -> int:
55 | return int(os.environ["RANK"])
56 |
57 |
58 | def get_dist_size() -> int:
59 | return int(os.environ["WORLD_SIZE"])
60 |
61 |
62 | def is_master() -> bool:
63 | return get_dist_rank() == 0
64 |
65 |
66 | def dist_barrier() -> None:
67 | if is_dist_initialized():
68 | torch.distributed.barrier()
69 |
70 |
71 | def get_dist_local_rank() -> int:
72 | return int(os.environ["LOCAL_RANK"])
73 |
74 |
75 | def sync_tensor(tensor: Union[torch.Tensor, float], reduce="mean") -> Union[torch.Tensor, list[torch.Tensor]]:
76 | if not is_dist_initialized():
77 | return tensor
78 | if not isinstance(tensor, torch.Tensor):
79 | tensor = torch.Tensor(1).fill_(tensor).cuda()
80 | tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
81 | torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
82 | if reduce == "mean":
83 | return list_mean(tensor_list)
84 | elif reduce == "sum":
85 | return list_sum(tensor_list)
86 | elif reduce == "cat":
87 | return torch.cat(tensor_list, dim=0)
88 | elif reduce == "root":
89 | return tensor_list[0]
90 | else:
91 | return tensor_list
92 |
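sync_tensor all-gathers a scalar or tensor across ranks and then reduces it; when torch.distributed is not initialized it returns its input unchanged, so the same call works in single-process runs. A small sketch with a made-up per-rank value, assuming the repo root is importable.

import torch
from diffusion.model.dc_ae.efficientvit.apps.utils.dist import sync_tensor

local_loss = torch.tensor(0.25)                      # hypothetical per-rank value
mean_loss = sync_tensor(local_loss, reduce="mean")   # identity outside distributed runs
print(float(mean_loss))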
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/ema.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import copy
18 | import math
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 | from ...models.utils import is_parallel
24 |
25 | __all__ = ["EMA"]
26 |
27 |
28 | def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None:
29 | for k, v in ema.state_dict().items():
30 | if v.dtype.is_floating_point:
31 | v -= (1.0 - decay) * (v - new_state_dict[k].detach())
32 |
33 |
34 | class EMA:
35 | def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
36 | self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval()
37 | self.decay = decay
38 | self.warmup_steps = warmup_steps
39 |
40 | for p in self.shadows.parameters():
41 | p.requires_grad = False
42 |
43 | def step(self, model: nn.Module, global_step: int) -> None:
44 | with torch.no_grad():
45 | msd = (model.module if is_parallel(model) else model).state_dict()
46 | update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps)))
47 |
48 | def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
49 | return {self.decay: self.shadows.state_dict()}
50 |
51 | def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
52 | for decay in state_dict:
53 | if decay == self.decay:
54 | self.shadows.load_state_dict(state_dict[decay])
55 |
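A usage sketch for the EMA helper: the effective decay is warmed up by the exponential factor in step(), and state_dict() keys the shadow weights by the decay value. The model, optimizer, and step count are placeholders.

import torch
import torch.nn as nn
from diffusion.model.dc_ae.efficientvit.apps.utils.ema import EMA

model = nn.Linear(8, 8)
ema = EMA(model, decay=0.9999, warmup_steps=2000)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for global_step in range(1, 11):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model, global_step)   # effective decay = 0.9999 * (1 - exp(-step / 2000))

ema_weights = ema.state_dict()[0.9999]   # shadow state dict, keyed by decay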
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/export.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import io
18 | import os
19 | from typing import Any
20 |
21 | import onnx
22 | import torch
23 | import torch.nn as nn
24 | from onnxsim import simplify as simplify_func
25 |
26 | __all__ = ["export_onnx"]
27 |
28 |
29 | def export_onnx(model: nn.Module, export_path: str, sample_inputs: Any, simplify=True, opset=11) -> None:
30 | """Export a model to a platform-specific onnx format.
31 |
32 | Args:
33 | model: a torch.nn.Module object.
34 | export_path: export location.
35 | sample_inputs: Any.
36 | simplify: a flag to turn on onnx-simplifier
37 | opset: int
38 | """
39 | model.eval()
40 |
41 | buffer = io.BytesIO()
42 | with torch.no_grad():
43 | torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
44 | buffer.seek(0, 0)
45 | if simplify:
46 | onnx_model = onnx.load_model(buffer)
47 | onnx_model, success = simplify_func(onnx_model)
48 | assert success
49 | new_buffer = io.BytesIO()
50 | onnx.save(onnx_model, new_buffer)
51 | buffer = new_buffer
52 | buffer.seek(0, 0)
53 |
54 | if buffer.getbuffer().nbytes > 0:
55 | save_dir = os.path.dirname(export_path)
56 | os.makedirs(save_dir, exist_ok=True)
57 | with open(export_path, "wb") as f:
58 | f.write(buffer.read())
59 |
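A usage sketch for export_onnx with a toy model; the output path is a placeholder, and the onnx / onnxsim packages must be installed.

import torch
import torch.nn as nn
from diffusion.model.dc_ae.efficientvit.apps.utils.export import export_onnx

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())  # toy network
sample = torch.randn(1, 3, 64, 64)
export_onnx(model, "output/debug/toy.onnx", sample, simplify=True, opset=11)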
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/init.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from typing import Union
18 |
19 | import torch
20 | import torch.nn as nn
21 | from torch.nn.modules.batchnorm import _BatchNorm
22 |
23 | __all__ = ["init_modules", "zero_last_gamma"]
24 |
25 |
26 | def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_normal") -> None:
27 | _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
28 |
29 | if isinstance(model, list):
30 | for sub_module in model:
31 | init_modules(sub_module, init_type)
32 | else:
33 | init_params = init_type.split("@")
34 | init_params = float(init_params[1]) if len(init_params) > 1 else None
35 |
36 | if init_type.startswith("trunc_normal"):
37 | init_func = lambda param: nn.init.trunc_normal_(
38 | param, std=(_DEFAULT_INIT_PARAM["trunc_normal"] if init_params is None else init_params)
39 | )
40 | else:
41 | raise NotImplementedError
42 |
43 | for m in model.modules():
44 | if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
45 | init_func(m.weight)
46 | if m.bias is not None:
47 | m.bias.data.zero_()
48 | elif isinstance(m, nn.Embedding):
49 | init_func(m.weight)
50 | elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
51 | m.weight.data.fill_(1)
52 | m.bias.data.zero_()
53 | else:
54 | weight = getattr(m, "weight", None)
55 | bias = getattr(m, "bias", None)
56 | if isinstance(weight, torch.nn.Parameter):
57 | init_func(weight)
58 | if isinstance(bias, torch.nn.Parameter):
59 | bias.data.zero_()
60 |
61 |
62 | def zero_last_gamma(model: nn.Module, init_val=0) -> None:
63 | import efficientvit.models.nn.ops as ops
64 |
65 | for m in model.modules():
66 | if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer):
67 | if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
68 | parent_module = m.main.point_conv
69 | elif isinstance(m.main, ops.ResBlock):
70 | parent_module = m.main.conv2
71 | elif isinstance(m.main, ops.ConvLayer):
72 | parent_module = m.main
73 | elif isinstance(m.main, (ops.LiteMLA)):
74 | parent_module = m.main.proj
75 | else:
76 | parent_module = None
77 | if parent_module is not None:
78 | norm = getattr(parent_module, "norm", None)
79 | if norm is not None:
80 | nn.init.constant_(norm.weight, init_val)
81 |
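A minimal usage sketch of `init_modules` above; the toy network is illustrative, and the optional `@<std>` suffix overrides the default truncation std of 0.02:

```python
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 10))
# conv/linear weights get trunc_normal(std=0.05), biases zeroed, norm layers reset to weight=1 / bias=0
init_modules(net, init_type="trunc_normal@0.05")
```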
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/lr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import math
18 | from typing import Union
19 |
20 | import torch
21 |
22 | from ...models.utils.list import val2list
23 |
24 | __all__ = ["CosineLRwithWarmup", "ConstantLRwithWarmup"]
25 |
26 |
27 | class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
28 | def __init__(
29 | self,
30 | optimizer: torch.optim.Optimizer,
31 | warmup_steps: int,
32 | warmup_lr: float,
33 | decay_steps: Union[int, list[int]],
34 | last_epoch: int = -1,
35 | ) -> None:
36 | self.warmup_steps = warmup_steps
37 | self.warmup_lr = warmup_lr
38 | self.decay_steps = val2list(decay_steps)
39 | super().__init__(optimizer, last_epoch)
40 |
41 | def get_lr(self) -> list[float]:
42 | if self.last_epoch < self.warmup_steps:
43 | return [
44 | (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
45 | for base_lr in self.base_lrs
46 | ]
47 | else:
48 | current_steps = self.last_epoch - self.warmup_steps
49 | decay_steps = [0] + self.decay_steps
50 | idx = len(decay_steps) - 2
51 | for i, decay_step in enumerate(decay_steps[:-1]):
52 | if decay_step <= current_steps < decay_steps[i + 1]:
53 | idx = i
54 | break
55 | current_steps -= decay_steps[idx]
56 | decay_step = decay_steps[idx + 1] - decay_steps[idx]
57 | return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs]
58 |
59 |
60 | class ConstantLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
61 | def __init__(
62 | self,
63 | optimizer: torch.optim.Optimizer,
64 | warmup_steps: int,
65 | warmup_lr: float,
66 | last_epoch: int = -1,
67 | ) -> None:
68 | self.warmup_steps = warmup_steps
69 | self.warmup_lr = warmup_lr
70 | super().__init__(optimizer, last_epoch)
71 |
72 | def get_lr(self) -> list[float]:
73 | if self.last_epoch < self.warmup_steps:
74 | return [
75 | (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
76 | for base_lr in self.base_lrs
77 | ]
78 | else:
79 | return self.base_lrs
80 |
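A minimal usage sketch for the schedulers above (optimizer and step counts are illustrative). Both warm up linearly from `warmup_lr` to each param group's base LR over `warmup_steps` steps; `CosineLRwithWarmup` then decays with a cosine curve over `decay_steps`:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=2e-5)
scheduler = CosineLRwithWarmup(optimizer, warmup_steps=500, warmup_lr=1e-7, decay_steps=10_000)

for _ in range(1_000):
    optimizer.step()
    scheduler.step()  # stepped once per iteration, not once per epoch
```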
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/metric.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from typing import Union
18 |
19 | import torch
20 |
21 | from ...apps.utils.dist import sync_tensor
22 |
23 | __all__ = ["AverageMeter"]
24 |
25 |
26 | class AverageMeter:
27 | """Computes and stores the average and current value."""
28 |
29 | def __init__(self, is_distributed=True):
30 | self.is_distributed = is_distributed
31 | self.sum = 0
32 | self.count = 0
33 |
34 | def _sync(self, val: Union[torch.Tensor, int, float]) -> Union[torch.Tensor, int, float]:
35 | return sync_tensor(val, reduce="sum") if self.is_distributed else val
36 |
37 | def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
38 | self.count += self._sync(delta_n)
39 | self.sum += self._sync(val * delta_n)
40 |
41 | def get_count(self) -> Union[torch.Tensor, int, float]:
42 | return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count
43 |
44 | @property
45 | def avg(self):
46 | avg = -1 if self.count == 0 else self.sum / self.count
47 | return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
48 |
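A minimal single-process usage sketch (hence `is_distributed=False`); the loss values and batch sizes are illustrative:

```python
meter = AverageMeter(is_distributed=False)
meter.update(0.25, delta_n=8)   # e.g. mean loss over a batch of 8 samples
meter.update(0.75, delta_n=8)
print(meter.get_count(), meter.avg)  # 16 0.5
```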
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/apps/utils/opt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from typing import Any, Optional
18 |
19 | import torch
20 |
21 | __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
22 |
23 | # register optimizer here
24 | # name: optimizer, kwargs with default values
25 | REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, Any]]] = {
26 | "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
27 | "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
28 | "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
29 | }
30 |
31 |
32 | def build_optimizer(
33 | net_params, optimizer_name: str, optimizer_params: Optional[dict], init_lr: float
34 | ) -> torch.optim.Optimizer:
35 |     optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
36 |     optimizer_params = {} if optimizer_params is None else optimizer_params
37 | 
38 |     # merge into a copy so per-call overrides do not mutate the shared registry defaults
39 |     kwargs = dict(default_params)
40 |     kwargs.update({key: optimizer_params[key] for key in default_params if key in optimizer_params})
41 |     optimizer = optimizer_class(net_params, init_lr, **kwargs)
42 |     return optimizer
43 |
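A minimal usage sketch of `build_optimizer`; the model and hyper-parameters are illustrative, and keys passed in `optimizer_params` override the registered defaults:

```python
import torch.nn as nn

net = nn.Linear(8, 2)
optimizer = build_optimizer(net.parameters(), "adamw", {"betas": (0.9, 0.95)}, init_lr=1e-4)
```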
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/model/dc_ae/efficientvit/models/__init__.py
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/efficientvit/__init__.py:
--------------------------------------------------------------------------------
1 | from .dc_ae import *
2 | from .dc_ae_with_temporal import *
3 |
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .act import *
2 | from .drop import *
3 | from .norm import *
4 | from .ops import *
5 | from .triton_rms_norm import *
6 |
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/nn/act.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from functools import partial
18 | from typing import Optional
19 |
20 | import torch.nn as nn
21 |
22 | from ...models.utils import build_kwargs_from_config
23 |
24 | __all__ = ["build_act"]
25 |
26 |
27 | # register activation function here
28 | REGISTERED_ACT_DICT: dict[str, type] = {
29 | "relu": nn.ReLU,
30 | "relu6": nn.ReLU6,
31 | "hswish": nn.Hardswish,
32 | "silu": nn.SiLU,
33 | "gelu": partial(nn.GELU, approximate="tanh"),
34 | }
35 |
36 |
37 | def build_act(name: str, **kwargs) -> Optional[nn.Module]:
38 | if name in REGISTERED_ACT_DICT:
39 | act_cls = REGISTERED_ACT_DICT[name]
40 | args = build_kwargs_from_config(kwargs, act_cls)
41 | return act_cls(**args)
42 | else:
43 | return None
44 |
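A minimal usage sketch; names missing from `REGISTERED_ACT_DICT` return `None`, which callers treat as "no activation":

```python
act = build_act("silu")
print(act)                   # SiLU()
print(build_act("unknown"))  # None
```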
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .list import *
2 | from .network import *
3 | from .random import *
4 | from .video import *
5 |
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/list.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from typing import Any, Optional, Union
18 |
19 | __all__ = [
20 | "list_sum",
21 | "list_mean",
22 | "weighted_list_sum",
23 | "list_join",
24 | "val2list",
25 | "val2tuple",
26 | "squeeze_list",
27 | ]
28 |
29 |
30 | def list_sum(x: list) -> Any:
31 | return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
32 |
33 |
34 | def list_mean(x: list) -> Any:
35 | return list_sum(x) / len(x)
36 |
37 |
38 | def weighted_list_sum(x: list, weights: list) -> Any:
39 | assert len(x) == len(weights)
40 | return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
41 |
42 |
43 | def list_join(x: list, sep="\t", format_str="%s") -> str:
44 | return sep.join([format_str % val for val in x])
45 |
46 |
47 | def val2list(x: Union[list, tuple, Any], repeat_time=1) -> list:
48 | if isinstance(x, (list, tuple)):
49 | return list(x)
50 | return [x for _ in range(repeat_time)]
51 |
52 |
53 | def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
54 | x = val2list(x)
55 |
56 | # repeat elements if necessary
57 | if len(x) > 0:
58 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
59 |
60 | return tuple(x)
61 |
62 |
63 | def squeeze_list(x: Optional[list]) -> Union[list, Any]:
64 | if x is not None and len(x) == 1:
65 | return x[0]
66 | else:
67 | return x
68 |
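A few illustrative calls for the helpers above (assuming they are imported into scope):

```python
val2list(3, repeat_time=2)     # [3, 3]
val2tuple([32], min_len=3)     # (32, 32, 32)
squeeze_list([7])              # 7
list_mean([1.0, 2.0, 3.0])     # 2.0
list_join([1, 2, 3], sep=",")  # "1,2,3"
```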
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/random.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from typing import Any, Optional, Union
18 |
19 | import numpy as np
20 | import torch
21 |
22 | __all__ = [
23 | "torch_randint",
24 | "torch_random",
25 | "torch_shuffle",
26 | "torch_uniform",
27 | "torch_random_choices",
28 | ]
29 |
30 |
31 | def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int:
32 | """uniform: [low, high)"""
33 | if low == high:
34 | return low
35 | else:
36 | assert low < high
37 | return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
38 |
39 |
40 | def torch_random(generator: Optional[torch.Generator] = None) -> float:
41 | """uniform distribution on the interval [0, 1)"""
42 | return float(torch.rand(1, generator=generator))
43 |
44 |
45 | def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]:
46 | rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
47 | return [src_list[i] for i in rand_indexes]
48 |
49 |
50 | def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float:
51 | """uniform distribution on the interval [low, high)"""
52 | rand_val = torch_random(generator)
53 | return (high - low) * rand_val + low
54 |
55 |
56 | def torch_random_choices(
57 | src_list: list[Any],
58 | generator: Optional[torch.Generator] = None,
59 | k=1,
60 | weight_list: Optional[list[float]] = None,
61 | ) -> Union[Any, list]:
62 | if weight_list is None:
63 | rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,))
64 | out_list = [src_list[i] for i in rand_idx]
65 | else:
66 | assert len(weight_list) == len(src_list)
67 | accumulate_weight_list = np.cumsum(weight_list)
68 |
69 | out_list = []
70 | for _ in range(k):
71 | val = torch_uniform(0, accumulate_weight_list[-1], generator)
72 | active_id = 0
73 | for i, weight_val in enumerate(accumulate_weight_list):
74 | active_id = i
75 | if weight_val > val:
76 | break
77 | out_list.append(src_list[active_id])
78 |
79 | return out_list[0] if k == 1 else out_list
80 |
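A minimal usage sketch with a seeded generator (list contents and weights are illustrative):

```python
import torch

g = torch.Generator().manual_seed(0)
torch_randint(0, 10, generator=g)            # int in [0, 10)
torch_uniform(-1.0, 1.0, generator=g)        # float in [-1.0, 1.0)
torch_shuffle(["a", "b", "c"], generator=g)  # permuted copy of the list
torch_random_choices(["x", "y"], generator=g, k=3, weight_list=[0.8, 0.2])
```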
--------------------------------------------------------------------------------
/diffusion/model/dc_ae/efficientvit/models/utils/video.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def chunked_interpolate(x, scale_factor, mode="nearest"):
8 | """
9 | Interpolate large tensors by chunking along the channel dimension. https://discuss.pytorch.org/t/error-using-f-interpolate-for-large-3d-input/207859
10 | Only supports 'nearest' interpolation mode.
11 |
12 | Args:
13 | x (torch.Tensor): Input tensor (B, C, D, H, W)
14 |         scale_factor: Tuple of scaling factors (d, h, w)
15 |         mode: Interpolation mode; only "nearest" is supported
16 | Returns:
17 | torch.Tensor: Interpolated tensor
18 | """
19 | assert (
20 | mode == "nearest"
21 | ), "Only the nearest mode is supported" # actually other modes are theoretically supported but not tested
22 | if len(x.shape) != 5:
23 | raise ValueError("Expected 5D input tensor (B, C, D, H, W)")
24 |
25 | # Calculate max chunk size to avoid int32 overflow. num_elements < max_int32
26 | # Max int32 is 2^31 - 1
27 | max_elements_per_chunk = 2**31 - 1
28 |
29 | # Calculate output spatial dimensions
30 | out_d = math.ceil(x.shape[2] * scale_factor[0])
31 | out_h = math.ceil(x.shape[3] * scale_factor[1])
32 | out_w = math.ceil(x.shape[4] * scale_factor[2])
33 |
34 | # Calculate max channels per chunk to stay under limit
35 | elements_per_channel = out_d * out_h * out_w
36 | max_channels = max_elements_per_chunk // (x.shape[0] * elements_per_channel)
37 |
38 | # Use smaller of max channels or input channels
39 | chunk_size = min(max_channels, x.shape[1])
40 |
41 | # Ensure at least 1 channel per chunk
42 | chunk_size = max(1, chunk_size)
43 |
44 | chunks = []
45 | for i in range(0, x.shape[1], chunk_size):
46 | start_idx = i
47 | end_idx = min(i + chunk_size, x.shape[1])
48 |
49 | chunk = x[:, start_idx:end_idx, :, :, :]
50 |
51 | interpolated_chunk = F.interpolate(chunk, scale_factor=scale_factor, mode="nearest")
52 |
53 | chunks.append(interpolated_chunk)
54 |
55 | if not chunks:
56 | raise ValueError(f"No chunks were generated. Input shape: {x.shape}")
57 |
58 | # Concatenate chunks along channel dimension
59 | return torch.cat(chunks, dim=1)
60 |
61 |
62 | def pixel_shuffle_3d(x, upscale_factor):
63 | """
64 | 3D pixelshuffle operation.
65 | """
66 | B, C, T, H, W = x.shape
67 | r = upscale_factor
68 | assert C % (r * r * r) == 0, "channel number must be a multiple of the cube of the upsampling factor"
69 |
70 | C_new = C // (r * r * r)
71 | x = x.view(B, C_new, r, r, r, T, H, W)
72 |
73 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
74 |
75 | y = x.reshape(B, C_new, T * r, H * r, W * r)
76 | return y
77 |
78 |
79 | def pixel_unshuffle_3d(x, downsample_factor):
80 | """
81 | 3D pixel unshuffle operation.
82 | """
83 | B, C, T, H, W = x.shape
84 |
85 | r = downsample_factor
86 | assert T % r == 0, f"time dimension must be a multiple of the downsampling factor, got shape {x.shape}"
87 | assert H % r == 0, f"height dimension must be a multiple of the downsampling factor, got shape {x.shape}"
88 | assert W % r == 0, f"width dimension must be a multiple of the downsampling factor, got shape {x.shape}"
89 | T_new = T // r
90 | H_new = H // r
91 | W_new = W // r
92 | C_new = C * (r * r * r)
93 |
94 | x = x.view(B, C, T_new, r, H_new, r, W_new, r)
95 | x = x.permute(0, 1, 3, 5, 7, 2, 4, 6)
96 | y = x.reshape(B, C_new, T_new, H_new, W_new)
97 | return y
98 |
99 |
100 | def ceil_to_divisible(n: int, dividend: int) -> int:
101 | return math.ceil(dividend / (dividend // n))
102 |
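A minimal round-trip sketch for the helpers above (tensor sizes are illustrative):

```python
import torch

x = torch.randn(1, 8, 4, 16, 16)                     # (B, C, D, H, W)
up = chunked_interpolate(x, scale_factor=(2, 2, 2))  # -> (1, 8, 8, 32, 32)

y = pixel_unshuffle_3d(x, downsample_factor=2)       # -> (1, 64, 2, 8, 8)
assert torch.equal(pixel_shuffle_3d(y, upscale_factor=2), x)  # exact inverse
```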
--------------------------------------------------------------------------------
/diffusion/model/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .sana import (
2 | Sana,
3 | SanaBlock,
4 | get_1d_sincos_pos_embed_from_grid,
5 | get_2d_sincos_pos_embed,
6 | get_2d_sincos_pos_embed_from_grid,
7 | )
8 | from .sana_multi_scale import (
9 | SanaMS,
10 | SanaMS_600M_P1_D28,
11 | SanaMS_600M_P2_D28,
12 | SanaMS_600M_P4_D28,
13 | SanaMS_1600M_P1_D20,
14 | SanaMS_1600M_P2_D20,
15 | SanaMSBlock,
16 | )
17 | from .sana_multi_scale_adaln import (
18 | SanaMSAdaLN,
19 | SanaMSAdaLN_600M_P1_D28,
20 | SanaMSAdaLN_600M_P2_D28,
21 | SanaMSAdaLN_600M_P4_D28,
22 | SanaMSAdaLN_1600M_P1_D20,
23 | SanaMSAdaLN_1600M_P2_D20,
24 | SanaMSAdaLNBlock,
25 | )
26 | from .sana_multi_scale_controlnet import SanaMSControlNet_600M_P1_D28
27 | from .sana_multi_scale_video import (
28 | SanaMSVideo,
29 | SanaMSVideo_600M_P1_D28,
30 | SanaMSVideo_600M_P2_D28,
31 | SanaMSVideo_2000M_P1_D20,
32 | SanaMSVideo_2000M_P2_D20,
33 | )
34 | from .sana_U_shape import (
35 | SanaU,
36 | SanaU_600M_P1_D28,
37 | SanaU_600M_P2_D28,
38 | SanaU_600M_P4_D28,
39 | SanaU_1600M_P1_D20,
40 | SanaU_1600M_P2_D20,
41 | SanaUBlock,
42 | )
43 | from .sana_U_shape_multi_scale import (
44 | SanaUMS,
45 | SanaUMS_600M_P1_D28,
46 | SanaUMS_600M_P2_D28,
47 | SanaUMS_600M_P4_D28,
48 | SanaUMS_1600M_P1_D20,
49 | SanaUMS_1600M_P2_D20,
50 | SanaUMSBlock,
51 | )
52 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | from .triton_lite_mla import *
18 | from .triton_lite_mla_fwd import *
19 | from .triton_mb_conv_pre_glu import *
20 |
21 | # from .flash_attn import *
22 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/flash_attn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import torch
18 | from flash_attn import flash_attn_func
19 | from torch import nn
20 | from torch.nn import functional as F
21 |
22 |
23 | class FlashAttention(nn.Module):
24 | def __init__(self, dim: int, num_heads: int):
25 | super().__init__()
26 | self.dim = dim
27 | assert dim % num_heads == 0
28 | self.num_heads = num_heads
29 | self.head_dim = dim // num_heads
30 |
31 | self.qkv = nn.Linear(dim, dim * 3, bias=False)
32 | self.proj_out = torch.nn.Linear(dim, dim)
33 |
34 | def forward(self, x):
35 | B, N, C = x.shape
36 | qkv = self.qkv(x).view(B, N, 3, C) # B, N, 3, C
37 | q, k, v = qkv.unbind(2) # B, N, C
38 | k = k.reshape(B, N, self.num_heads, self.head_dim)
39 | v = v.reshape(B, N, self.num_heads, self.head_dim)
40 | q = q.reshape(B, N, self.num_heads, self.head_dim)
41 | out = flash_attn_func(q, k, v) # B, N, H, c
42 | out = self.proj_out(out.view(B, N, C)) # B, N, C
43 | return out
44 |
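A minimal usage sketch; `flash_attn` requires a CUDA device and fp16/bf16 tensors, so this only runs on a GPU with flash-attn installed (shapes are illustrative):

```python
import torch

attn = FlashAttention(dim=256, num_heads=8).cuda().half()
x = torch.randn(2, 1024, 256, device="cuda", dtype=torch.float16)
out = attn(x)  # (2, 1024, 256)
```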
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/nn/act.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import copy
18 |
19 | import torch.nn as nn
20 |
21 | __all__ = ["build_act", "get_act_name"]
22 |
23 | # register activation function here
24 | # name: module, kwargs with default values
25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = {
26 | "relu": (nn.ReLU, {"inplace": True}),
27 | "relu6": (nn.ReLU6, {"inplace": True}),
28 | "hswish": (nn.Hardswish, {"inplace": True}),
29 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}),
30 | "swish": (nn.SiLU, {"inplace": True}),
31 | "silu": (nn.SiLU, {"inplace": True}),
32 | "tanh": (nn.Tanh, {}),
33 | "sigmoid": (nn.Sigmoid, {}),
34 | "gelu": (nn.GELU, {"approximate": "tanh"}),
35 | "mish": (nn.Mish, {"inplace": True}),
36 | "identity": (nn.Identity, {}),
37 | }
38 |
39 |
40 | def build_act(name: str or None, **kwargs) -> nn.Module or None:
41 | if name in REGISTERED_ACT_DICT:
42 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name])
43 | for key in default_args:
44 | if key in kwargs:
45 | default_args[key] = kwargs[key]
46 | return act_cls(**default_args)
47 | elif name is None or name.lower() == "none":
48 | return None
49 | else:
50 | raise ValueError(f"do not support: {name}")
51 |
52 |
53 | def get_act_name(act: nn.Module or None) -> str or None:
54 | if act is None:
55 | return None
56 | module2name = {}
57 | for key, config in REGISTERED_ACT_DICT.items():
58 | module2name[config[0].__name__] = key
59 | return module2name.get(type(act).__name__, "unknown")
60 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/nn/conv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from ..utils.model import get_same_padding
21 | from .act import build_act, get_act_name
22 | from .norm import build_norm, get_norm_name
23 |
24 |
25 | class ConvLayer(nn.Module):
26 | def __init__(
27 | self,
28 | in_dim: int,
29 | out_dim: int,
30 | kernel_size=3,
31 | stride=1,
32 | dilation=1,
33 | groups=1,
34 | padding: int or None = None,
35 | use_bias=False,
36 | dropout=0.0,
37 | norm="bn2d",
38 | act="relu",
39 | ):
40 | super().__init__()
41 | if padding is None:
42 | padding = get_same_padding(kernel_size)
43 | padding *= dilation
44 |
45 | self.in_dim = in_dim
46 | self.out_dim = out_dim
47 | self.kernel_size = kernel_size
48 | self.stride = stride
49 | self.dilation = dilation
50 | self.groups = groups
51 | self.padding = padding
52 | self.use_bias = use_bias
53 |
54 | self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
55 | self.conv = nn.Conv2d(
56 | in_dim,
57 | out_dim,
58 | kernel_size=(kernel_size, kernel_size),
59 | stride=(stride, stride),
60 | padding=padding,
61 | dilation=(dilation, dilation),
62 | groups=groups,
63 | bias=use_bias,
64 | )
65 | self.norm = build_norm(norm, num_features=out_dim)
66 | self.act = build_act(act)
67 |
68 | def forward(self, x: torch.Tensor) -> torch.Tensor:
69 | if self.dropout is not None:
70 | x = self.dropout(x)
71 | x = self.conv(x)
72 | if self.norm:
73 | x = self.norm(x)
74 | if self.act:
75 | x = self.act(x)
76 | return x
77 |
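A minimal usage sketch (channel sizes are illustrative), assuming the default "bn2d" and "relu" names are registered by the accompanying `build_norm`/`build_act` helpers:

```python
import torch

layer = ConvLayer(in_dim=3, out_dim=16, kernel_size=3, stride=2, norm="bn2d", act="relu")
out = layer(torch.randn(1, 3, 32, 32))  # -> (1, 16, 16, 16)
```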
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/compare_results.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import torch
18 |
19 |
20 | def compare_results(name: str, result: torch.Tensor, ref_result: torch.Tensor):
21 | print(f"comparing {name}")
22 | diff = (result - ref_result).abs().view(-1)
23 | max_error_pos = diff.argmax()
24 | print(f"max error: {diff.max()}, mean error: {diff.mean()}")
25 |     print(f"values at max-error position: {result.view(-1)[max_error_pos]} vs {ref_result.view(-1)[max_error_pos]}")
26 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/dtype.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import torch
18 | import triton
19 | import triton.language as tl
20 |
21 |
22 | def get_dtype_from_str(dtype: str) -> torch.dtype:
23 | if dtype == "fp32":
24 | return torch.float32
25 | if dtype == "fp16":
26 | return torch.float16
27 | if dtype == "bf16":
28 | return torch.bfloat16
29 | raise NotImplementedError(f"dtype {dtype} is not supported")
30 |
31 |
32 | def get_tl_dtype_from_torch_dtype(dtype: torch.dtype) -> tl.dtype:
33 | if dtype == torch.float32:
34 | return tl.float32
35 | if dtype == torch.float16:
36 | return tl.float16
37 | if dtype == torch.bfloat16:
38 | return tl.bfloat16
39 | raise NotImplementedError(f"dtype {dtype} is not supported")
40 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/export_onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import os
18 | import warnings
19 | from typing import Any, Tuple
20 |
21 | import torch
22 |
23 |
24 | def export_onnx(
25 | model: torch.nn.Module,
26 |     input_shape: Tuple[int, ...],
27 | export_path: str,
28 | opset: int,
29 | export_dtype: torch.dtype,
30 | export_device: torch.device,
31 | ) -> None:
32 | model.eval()
33 |
34 | dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)}
35 | dynamic_axes = {
36 | "x": {0: "batch_size"},
37 | }
38 |
39 | # _ = model(**dummy_input)
40 |
41 | output_names = ["image_embeddings"]
42 |
43 | export_dir = os.path.dirname(export_path)
44 | if not os.path.exists(export_dir):
45 | os.makedirs(export_dir)
46 |
47 | with warnings.catch_warnings():
48 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
49 | warnings.filterwarnings("ignore", category=UserWarning)
50 | print(f"Exporting onnx model to {export_path}...")
51 | with open(export_path, "wb") as f:
52 | torch.onnx.export(
53 | model,
54 | tuple(dummy_input.values()),
55 | f,
56 | export_params=True,
57 | verbose=False,
58 | opset_version=opset,
59 | do_constant_folding=True,
60 | input_names=list(dummy_input.keys()),
61 | output_names=output_names,
62 | dynamic_axes=dynamic_axes,
63 | )
64 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/modules/utils/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 MIT Han Lab
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 |
18 | def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
19 | """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
20 | if isinstance(x, (list, tuple)):
21 | return list(x)
22 | return [x for _ in range(repeat_time)]
23 |
24 |
25 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
26 | """Return tuple with min_len by repeating element at idx_repeat."""
27 | # convert to list first
28 | x = val2list(x)
29 |
30 | # repeat elements if necessary
31 | if len(x) > 0:
32 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
33 |
34 | return tuple(x)
35 |
36 |
37 | def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
38 | if isinstance(kernel_size, tuple):
39 | return tuple([get_same_padding(ks) for ks in kernel_size])
40 | else:
41 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
42 | return kernel_size // 2
43 |
--------------------------------------------------------------------------------
/diffusion/model/nets/fastlinear/readme.md:
--------------------------------------------------------------------------------
1 | # A fast implementation of linear attention
2 |
3 | ## 64x64, fp16
4 |
5 | ```bash
6 | # validate correctness
7 | ## fp16 vs fp32
8 | python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
9 | ## triton fp16 vs fp32
10 | python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True
11 |
12 | # test performance
13 | ## fp16, forward
14 | python -m develop_triton_litemla attn_type=LiteMLA
15 | each step takes 10.81 ms
16 | max memory allocated: 2.2984 GB
17 |
18 | ## triton fp16, forward
19 | python -m develop_triton_litemla attn_type=TritonLiteMLA
20 | each step takes 4.70 ms
21 | max memory allocated: 1.6480 GB
22 |
23 | ## fp16, backward
24 | python -m develop_triton_litemla attn_type=LiteMLA backward=True
25 | each step takes 35.34 ms
26 | max memory allocated: 3.4412 GB
27 |
28 | ## triton fp16, backward
29 | python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True
30 | each step takes 14.25 ms
31 | max memory allocated: 2.4704 GB
32 | ```
33 |
--------------------------------------------------------------------------------
/diffusion/model/nets/sana_ladd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 | from diffusion.model.builder import MODELS
21 |
22 | from .ladd_blocks import DiscHead
23 | from .sana_multi_scale import SanaMSCM
24 |
25 |
26 | @MODELS.register_module()
27 | class SanaMSCMDiscriminator(nn.Module):
28 | def __init__(self, pretrained_model: SanaMSCM, is_multiscale=False, head_block_ids=None):
29 | super().__init__()
30 | self.transformer = pretrained_model
31 | self.transformer.requires_grad_(False)
32 |
33 | if head_block_ids is None or len(head_block_ids) == 0:
34 | self.block_hooks = {2, 8, 14, 20, 27} if is_multiscale else {self.transformer.depth - 1}
35 | else:
36 | self.block_hooks = head_block_ids
37 |
38 | heads = []
39 | for i in range(len(self.block_hooks)):
40 | heads.append(DiscHead(self.transformer.hidden_size, 0, 0))
41 | self.heads = nn.ModuleList(heads)
42 |
43 | def get_head_inputs(self):
44 | return self.head_inputs
45 |
46 | def forward(self, x, timestep, y=None, data_info=None, mask=None, **kwargs):
47 | feat_list = []
48 | self.head_inputs = []
49 |
50 | def get_features(module, input, output):
51 | feat_list.append(output)
52 | return output
53 |
54 | hooks = []
55 | for i, block in enumerate(self.transformer.blocks):
56 | if i in self.block_hooks:
57 | hooks.append(block.register_forward_hook(get_features))
58 |
59 | self.transformer(x, timestep, y=y, mask=mask, data_info=data_info, return_logvar=False, **kwargs)
60 |
61 | for hook in hooks:
62 | hook.remove()
63 |
64 | res_list = []
65 | for feat, head in zip(feat_list, self.heads):
66 | B, N, C = feat.shape
67 | feat = feat.transpose(1, 2) # [B, C, N]
68 | self.head_inputs.append(feat)
69 | res_list.append(head(feat, None).reshape(feat.shape[0], -1))
70 |
71 | concat_res = torch.cat(res_list, dim=1)
72 |
73 | return concat_res
74 |
75 | @property
76 | def model(self):
77 | return self.transformer
78 |
79 | def save_pretrained(self, path):
80 | torch.save(self.state_dict(), path)
81 |
82 |
83 | class DiscHeadModel:
84 | def __init__(self, disc):
85 | self.disc = disc
86 |
87 | def state_dict(self):
88 | return {name: param for name, param in self.disc.state_dict().items() if not name.startswith("transformer.")}
89 |
90 | def __getattr__(self, name):
91 | return getattr(self.disc, name)
92 |
--------------------------------------------------------------------------------
/diffusion/model/wan/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import flash_attention
2 | from .model import WanLinearAttentionModel, WanModel, init_model_configs
3 | from .model_wrapper import SanaVideoMSBlock, SanaWanLinearAttentionModel, SanaWanModel
4 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5 | from .tokenizers import HuggingfaceTokenizer
6 | from .vae import WanVAE
7 |
8 | __all__ = [
9 | "WanVAE",
10 | "WanModel",
11 | "WanLinearAttentionModel",
12 | "init_model_configs",
13 | "T5Model",
14 | "T5Encoder",
15 | "T5Decoder",
16 | "T5EncoderModel",
17 | "HuggingfaceTokenizer",
18 | "flash_attention",
19 | "SanaWanLinearAttentionModel",
20 | "SanaWanModel",
21 | "SanaVideoMSBlock",
22 | ]
23 |
--------------------------------------------------------------------------------
/diffusion/model/wan/fsdp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import gc
3 | from functools import partial
4 |
5 | import torch
6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9 | from torch.distributed.utils import _free_storage
10 |
11 |
12 | def shard_model(
13 | model,
14 | device_id,
15 | param_dtype=torch.bfloat16,
16 | reduce_dtype=torch.float32,
17 | buffer_dtype=torch.float32,
18 | process_group=None,
19 | sharding_strategy=ShardingStrategy.FULL_SHARD,
20 | sync_module_states=True,
21 | ):
22 | """
23 | This is the shard function for T5 model.
24 | """
25 | model = FSDP(
26 | module=model,
27 | process_group=process_group,
28 | sharding_strategy=sharding_strategy,
29 | auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
30 | mixed_precision=MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype),
31 | device_id=device_id,
32 | sync_module_states=sync_module_states,
33 | )
34 | return model
35 |
36 |
37 | def free_model(model):
38 | for m in model.modules():
39 | if isinstance(m, FSDP):
40 | _free_storage(m._handle.flat_param.data)
41 | del model
42 | gc.collect()
43 | torch.cuda.empty_cache()
44 |
--------------------------------------------------------------------------------
/diffusion/model/wan/model_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .model import WanAttentionBlock, WanLinearAttentionModel, WanModel
4 |
5 |
6 | class SanaVideoMSBlock(WanAttentionBlock):
7 | pass
8 |
9 |
10 | class SanaWanModel(WanModel):
11 | def __init__(self, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | cross_attn_type = "t2v_cross_attn" if self.model_type == "t2v" else "i2v_cross_attn"
14 | self.blocks = nn.ModuleList(
15 | [
16 | SanaVideoMSBlock(
17 | cross_attn_type,
18 | self.dim,
19 | self.ffn_dim,
20 | self.num_heads,
21 | self.window_size,
22 | self.qk_norm,
23 | self.cross_attn_norm,
24 | self.eps,
25 | )
26 | for _ in range(self.num_layers)
27 | ]
28 | )
29 |
30 |
31 | class SanaWanLinearAttentionModel(WanLinearAttentionModel):
32 | def __init__(self, *args, **kwargs):
33 | super().__init__(*args, **kwargs)
34 | cross_attn_type = "t2v_cross_attn" if self.model_type == "t2v" else "i2v_cross_attn"
35 | self_attn_types = ["flash"] * self.num_layers
36 | ffn_types = ["mlp"] * self.num_layers
37 |
38 | self.blocks = nn.ModuleList(
39 | [
40 | SanaVideoMSBlock(
41 | cross_attn_type,
42 | self.dim,
43 | self.ffn_dim,
44 | self.num_heads,
45 | self.window_size,
46 | self.qk_norm,
47 | self.cross_attn_norm,
48 | self.eps,
49 | self_attn_types[i],
50 | self.rope_after,
51 | self.power,
52 | ffn_types[i],
53 | )
54 | for i in range(self.num_layers)
55 | ]
56 | )
57 |
--------------------------------------------------------------------------------
/diffusion/model/wan/tokenizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import html
3 | import string
4 |
5 | import ftfy
6 | import regex as re
7 | from transformers import AutoTokenizer
8 |
9 | __all__ = ["HuggingfaceTokenizer"]
10 |
11 |
12 | def basic_clean(text):
13 | text = ftfy.fix_text(text)
14 | text = html.unescape(html.unescape(text))
15 | return text.strip()
16 |
17 |
18 | def whitespace_clean(text):
19 | text = re.sub(r"\s+", " ", text)
20 | text = text.strip()
21 | return text
22 |
23 |
24 | def canonicalize(text, keep_punctuation_exact_string=None):
25 | text = text.replace("_", " ")
26 | if keep_punctuation_exact_string:
27 | text = keep_punctuation_exact_string.join(
28 | part.translate(str.maketrans("", "", string.punctuation))
29 | for part in text.split(keep_punctuation_exact_string)
30 | )
31 | else:
32 | text = text.translate(str.maketrans("", "", string.punctuation))
33 | text = text.lower()
34 | text = re.sub(r"\s+", " ", text)
35 | return text.strip()
36 |
37 |
38 | class HuggingfaceTokenizer:
39 | def __init__(self, name, seq_len=None, clean=None, **kwargs):
40 | assert clean in (None, "whitespace", "lower", "canonicalize")
41 | self.name = name
42 | self.seq_len = seq_len
43 | self.clean = clean
44 |
45 | # init tokenizer
46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47 | self.vocab_size = self.tokenizer.vocab_size
48 |
49 | def __call__(self, sequence, **kwargs):
50 | return_mask = kwargs.pop("return_mask", False)
51 |
52 | # arguments
53 | _kwargs = {"return_tensors": "pt"}
54 | if self.seq_len is not None:
55 | _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
56 | _kwargs.update(**kwargs)
57 |
58 | # tokenization
59 | if isinstance(sequence, str):
60 | sequence = [sequence]
61 | if self.clean:
62 | sequence = [self._clean(u) for u in sequence]
63 | ids = self.tokenizer(sequence, **_kwargs)
64 |
65 | # output
66 | if return_mask:
67 | return ids.input_ids, ids.attention_mask
68 | else:
69 | return ids.input_ids
70 |
71 | def _clean(self, text):
72 | if self.clean == "whitespace":
73 | text = whitespace_clean(basic_clean(text))
74 | elif self.clean == "lower":
75 | text = whitespace_clean(basic_clean(text)).lower()
76 | elif self.clean == "canonicalize":
77 | text = canonicalize(basic_clean(text))
78 | return text
79 |
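A minimal usage sketch; the checkpoint name is illustrative and any Hugging Face tokenizer repo works:

```python
tokenizer = HuggingfaceTokenizer("google/umt5-xxl", seq_len=512, clean="whitespace")
ids, mask = tokenizer(["A cat playing the piano"], return_mask=True)
print(ids.shape, mask.shape)  # both torch.Size([1, 512]), padded/truncated to seq_len
```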
--------------------------------------------------------------------------------
/diffusion/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/scheduler/__init__.py
--------------------------------------------------------------------------------
/diffusion/scheduler/dpm_solver.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import torch
18 |
19 | from diffusion.model import gaussian_diffusion as gd
20 | from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper
21 |
22 |
23 | def DPMS(
24 | model,
25 | condition,
26 | uncondition,
27 | cfg_scale,
28 | pag_scale=1.0,
29 | pag_applied_layers=None,
30 | model_type="noise", # or "x_start" or "v" or "score", "flow"
31 | noise_schedule="linear",
32 | guidance_type="classifier-free",
33 | model_kwargs=None,
34 | diffusion_steps=1000,
35 | schedule="VP",
36 | interval_guidance=None,
37 | condition_as_list=False, # for wan text encoder, set to true
38 | apg=None,
39 | **kwargs,
40 | ):
41 | if pag_applied_layers is None:
42 | pag_applied_layers = []
43 | if model_kwargs is None:
44 | model_kwargs = {}
45 | if interval_guidance is None:
46 | interval_guidance = [0, 1.0]
47 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
48 |
49 | ## 1. Define the noise schedule.
50 | if schedule == "VP":
51 | noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
52 | elif schedule == "FLOW":
53 | noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
54 |
55 | ## 2. Convert your discrete-time `model` to the continuous-time
56 | ## noise prediction model. Here is an example for a diffusion model
57 |     ## `model` with the noise prediction type ("noise").
58 | model_fn = model_wrapper(
59 | model,
60 | noise_schedule,
61 | model_type=model_type,
62 | model_kwargs=model_kwargs,
63 | guidance_type=guidance_type,
64 | pag_scale=pag_scale,
65 | pag_applied_layers=pag_applied_layers,
66 | condition=condition,
67 | unconditional_condition=uncondition,
68 | guidance_scale=cfg_scale,
69 | interval_guidance=interval_guidance,
70 | condition_as_list=condition_as_list,
71 | apg=apg,
72 | **kwargs,
73 | )
74 | ## 3. Define dpm-solver and sample by multistep DPM-Solver.
75 | return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
76 |
--------------------------------------------------------------------------------
/diffusion/scheduler/iddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | # Modified from OpenAI's diffusion repos
18 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
19 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
20 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
21 |
22 | from diffusion.model import gaussian_diffusion as gd
23 | from diffusion.model.respace import SpacedDiffusion, space_timesteps
24 |
25 |
26 | def Scheduler(
27 | timestep_respacing,
28 | noise_schedule="linear",
29 | use_kl=False,
30 | sigma_small=False,
31 | predict_xstart=False,
32 | predict_flow_v=False,
33 | learn_sigma=True,
34 | pred_sigma=True,
35 | rescale_learned_sigmas=False,
36 | diffusion_steps=1000,
37 | snr=False,
38 | return_startx=False,
39 | flow_shift=1.0,
40 | ):
41 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
42 | if use_kl:
43 | loss_type = gd.LossType.RESCALED_KL
44 | elif rescale_learned_sigmas:
45 | loss_type = gd.LossType.RESCALED_MSE
46 | else:
47 | loss_type = gd.LossType.MSE
48 | if timestep_respacing is None or timestep_respacing == "":
49 | timestep_respacing = [diffusion_steps]
50 | if predict_xstart:
51 | model_mean_type = gd.ModelMeanType.START_X
52 | elif predict_flow_v:
53 | model_mean_type = gd.ModelMeanType.FLOW_VELOCITY
54 | else:
55 | model_mean_type = gd.ModelMeanType.EPSILON
56 | return SpacedDiffusion(
57 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
58 | betas=betas,
59 | model_mean_type=model_mean_type,
60 | model_var_type=(
61 | (
62 | (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
63 | if not learn_sigma
64 | else gd.ModelVarType.LEARNED_RANGE
65 | )
66 | if pred_sigma
67 | else None
68 | ),
69 | loss_type=loss_type,
70 | snr=snr,
71 | return_startx=return_startx,
72 | # rescale_timesteps=rescale_timesteps,
73 | flow="flow" in noise_schedule,
74 | flow_shift=flow_shift,
75 | diffusion_steps=diffusion_steps,
76 | )
77 |
--------------------------------------------------------------------------------
/diffusion/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/diffusion/utils/__init__.py
--------------------------------------------------------------------------------
/environment_setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | # Check if we should skip environment setup entirely
5 | if [ "${SKIP_ENV_SETUP}" = "true" ]; then
6 | echo "SKIP_ENV_SETUP is set to true. Skipping all environment setup steps."
7 | echo "Using default conda environment. Make sure it has all required packages installed."
8 | exit 0
9 | fi
10 |
11 | CONDA_ENV=${1:-""}
12 | if [ -n "$CONDA_ENV" ]; then
13 | # This is required to activate conda environment
14 | eval "$(conda shell.bash hook)"
15 |
16 |     conda create -n "$CONDA_ENV" python=3.10.0 -y
17 |     conda activate "$CONDA_ENV"
18 | # This is optional if you prefer to use built-in nvcc
19 | conda install -c nvidia cuda-toolkit=12.8 -y
20 | else
21 | echo "Skipping conda environment creation. Make sure you have the correct environment activated."
22 | fi
23 |
24 | # init a raw torch to avoid installation errors.
25 | # pip install torch
26 |
27 | # update pip to latest version for pyproject.toml setup.
28 | pip install -U pip
29 |
30 | # for fast attn
31 | pip install -U xformers==0.0.32.post2 --index-url https://download.pytorch.org/whl/cu128
32 |
33 | # install sana
34 | pip install -e .
35 |
36 | pip install flash-attn==2.8.2 --no-build-isolation
37 |
38 | # install torchprofile
39 | # pip install git+https://github.com/zhijian-liu/torchprofile
40 |
--------------------------------------------------------------------------------
/inference_video_scripts/inference_sana_video.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | inference_script=inference_video_scripts/inference_sana_video.py
5 | config=""
6 | model_path=""
7 | np=8
8 | negative_prompt=""
9 |
10 | while [[ $# -gt 0 ]]; do
11 | case $1 in
12 | --config=*)
13 | config="${1#*=}"
14 | shift
15 | ;;
16 | --config)
17 | config="$2"
18 | shift 2
19 | ;;
20 | --model_path=*)
21 | model_path="${1#*=}"
22 | shift
23 | ;;
24 | --model_path)
25 | model_path="$2"
26 | shift 2
27 | ;;
28 | --inference_script=*)
29 | inference_script="${1#*=}"
30 | shift
31 | ;;
32 | --inference_script)
33 | inference_script="$2"
34 | shift 2
35 | ;;
36 | --np=*)
37 | np="${1#*=}"
38 | shift
39 | ;;
40 | --np)
41 | np="$2"
42 | shift 2
43 | ;;
44 | --negative_prompt=*)
45 | negative_prompt="${1#*=}"
46 | shift
47 | ;;
48 | --negative_prompt)
49 | negative_prompt="$2"
50 | shift 2
51 | ;;
52 | *)
53 | other_args+=("$1")
54 | shift
55 | ;;
56 | esac
57 | done
58 |
59 | cmd=(
60 | accelerate launch --num_processes="$np" --num_machines=1 --mixed_precision=bf16 --main_process_port="$RANDOM"
61 | "$inference_script"
62 | --config="$config"
63 | --model_path="$model_path"
64 | --txt_file=asset/samples/video_prompts_samples.txt
65 | --dataset=video_samples
66 | )
67 |
68 | if [[ -n "$negative_prompt" ]]; then
69 | cmd+=(--negative_prompt="$negative_prompt")
70 | fi
71 |
72 | if [[ ${#other_args[@]} -gt 0 ]]; then
73 | cmd+=("${other_args[@]}")
74 | fi
75 |
76 | printf -v cmd_str '%q ' "${cmd[@]}"
77 | echo "$cmd_str"
78 |
79 | "${cmd[@]}"
80 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "sana"
7 | version = "0.2.0"
8 | description = "SANA"
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "pre-commit",
17 | "huggingface-hub==0.36.0",
18 | "accelerate==1.0.1",
19 | "beautifulsoup4",
20 | "bs4",
21 | "came-pytorch",
22 | "einops",
23 | "ftfy",
24 | "wheel",
25 | "psutil",
26 | "ninja",
27 | "diffusers==0.35.0",
28 | "clip@git+https://github.com/openai/CLIP.git",
29 | "gradio",
30 | "image-reward",
31 | "ipdb",
32 | "mmcv==1.7.2",
33 | "omegaconf",
34 | "opencv-python",
35 | "optimum",
36 | "patch_conv",
37 | "peft==0.17.0",
38 | "protobuf",
39 | "pytorch-fid",
40 | "regex",
41 | "sentencepiece",
42 | "tensorboard",
43 | "tensorboardX",
44 | "timm==0.6.13",
45 | "torchaudio==2.8.0",
46 | "torchvision==0.23.0",
47 | "transformers==4.57.0",
48 | "triton==3.4.0",
49 | "wandb",
50 | "webdataset",
51 | "xformers==0.0.32.post2",
52 | "yapf",
53 | "spaces",
54 | "matplotlib",
55 | "termcolor",
56 | "pyrallis",
57 | "bitsandbytes",
58 | "fire",
59 | "moviepy",
60 | "imageio[pyav,ffmpeg]",
61 | "qwen-vl-utils",
62 | ]
63 |
64 |
65 | [project.scripts]
66 | sana-run = "sana.cli.run:main"
67 | sana-upload = "sana.cli.upload2hf:main"
68 |
69 | [project.optional-dependencies]
70 |
71 | [project.urls]
72 |
73 | [tool.pip]
74 | extra-index-url = ["https://download.pytorch.org/whl/cu128"]
75 |
76 | [tool.black]
77 | line-length = 120
78 |
79 | [tool.isort]
80 | profile = "black"
81 | multi_line_output = 3
82 | include_trailing_comma = true
83 | force_grid_wrap = 0
84 | use_parentheses = true
85 | ensure_newline_before_comments = true
86 | line_length = 120
87 |
88 | [tool.setuptools.packages.find]
89 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*", "data*", "output*"]
90 |
91 | [tool.wheel]
92 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*", "data*", "output*"]
93 |
--------------------------------------------------------------------------------
/sana/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .download import download_model
2 | from .hf_utils import hf_download_or_fpath
3 |
--------------------------------------------------------------------------------
/sana/tools/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Functions for downloading pre-trained Sana models
9 | """
10 | import argparse
11 | import os
12 |
13 | import torch
14 | from torchvision.datasets.utils import download_url
15 |
16 | pretrained_models = {}
17 |
18 |
19 | def find_model(model_name):
20 | """
21 |     Finds a pre-trained Sana model, downloading it if necessary. Alternatively, loads a model from a local path.
22 |     """
23 |     if model_name in pretrained_models:  # Find/download our pre-trained Sana checkpoints
24 | return download_model(model_name)
25 | else: # Load a custom Sana checkpoint:
26 | assert os.path.isfile(model_name), f"Could not find Sana checkpoint at {model_name}"
27 | return torch.load(model_name, map_location=lambda storage, loc: storage)
28 |
29 |
30 | def download_model(model_name):
31 | """
32 | Downloads a pre-trained Sana model from the web.
33 | """
34 | assert model_name in pretrained_models
35 | local_path = f"output/pretrained_models/{model_name}"
36 | if not os.path.isfile(local_path):
37 | hf_endpoint = os.environ.get("HF_ENDPOINT")
38 | if hf_endpoint is None:
39 | hf_endpoint = "https://huggingface.co"
40 | os.makedirs("output/pretrained_models", exist_ok=True)
41 | web_path = f"{hf_endpoint}/xxx/resolve/main/{model_name}"
42 | download_url(web_path, "output/pretrained_models/")
43 | model = torch.load(local_path, map_location=lambda storage, loc: storage)
44 | return model
45 |
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("--model_names", nargs="+", type=str, default=pretrained_models)
50 | args = parser.parse_args()
51 | model_names = args.model_names
52 | model_names = set(model_names)
53 |
54 | # Download Sana checkpoints
55 | for model in model_names:
56 | download_model(model)
57 | print("Done.")
58 |
--------------------------------------------------------------------------------
/sana/tools/hf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import os
18 | import os.path as osp
19 | import sys
20 |
21 | from huggingface_hub import hf_hub_download, snapshot_download
22 |
23 |
24 | def hf_download_or_fpath(path):
25 | if osp.exists(path):
26 | return path
27 |
28 | if path.startswith("hf://"):
29 | segs = path.replace("hf://", "").split("/")
30 | repo_id = "/".join(segs[:2])
31 | filename = "/".join(segs[2:])
32 | return hf_download_data(repo_id, filename, repo_type="model", download_full_repo=True)
33 |
34 |
35 | def hf_download_data(
36 | repo_id="Efficient-Large-Model/Sana_1600M_1024px",
37 | filename="checkpoints/Sana_1600M_1024px.pth",
38 | cache_dir=None,
39 | repo_type="model",
40 | download_full_repo=False,
41 | ):
42 | """
43 |     Download a file from a Hugging Face repository, optionally snapshotting the full repo first.
44 |
45 | Args:
46 | repo_id (str): The ID of the Hugging Face repository.
47 | filename (str): The name of the file to download.
48 | cache_dir (str, optional): The directory to cache the downloaded file.
49 |
50 | Returns:
51 |         str: The path to the downloaded file, or None if the download failed.
52 | """
53 | try:
54 | if download_full_repo:
55 |             # download the full repo so companion files (e.g. the dc-ae autoencoder) are also cached
56 | snapshot_download(
57 | repo_id=repo_id,
58 | cache_dir=cache_dir,
59 | repo_type=repo_type,
60 | )
61 | file_path = hf_hub_download(
62 | repo_id=repo_id,
63 | filename=filename,
64 | cache_dir=cache_dir,
65 | repo_type=repo_type,
66 | )
67 | return file_path
68 | except Exception as e:
69 | print(f"Error downloading file: {e}")
70 | return None
71 |
--------------------------------------------------------------------------------
/scripts/infer_run_inference_geneval_diffusers.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # ================= sampler & data =================
3 | np=8    # number of GPUs to use
4 | default_step=20 # 14
5 | default_sample_nums=553
6 | default_sampling_algo="dpm-solver"
7 | default_add_label=''
8 |
9 | # parser
10 | config_file=$1
11 | model_paths=$2
12 |
13 | for arg in "$@"
14 | do
15 | case $arg in
16 | --step=*)
17 | step="${arg#*=}"
18 | shift
19 | ;;
20 | --sampling_algo=*)
21 | sampling_algo="${arg#*=}"
22 | shift
23 | ;;
24 | --add_label=*)
25 | add_label="${arg#*=}"
26 | shift
27 | ;;
28 | --model_path=*)
29 | model_paths="${arg#*=}"
30 | shift
31 | ;;
32 | --exist_time_prefix=*)
33 | exist_time_prefix="${arg#*=}"
34 | shift
35 | ;;
36 | --if_save_dirname=*)
37 | if_save_dirname="${arg#*=}"
38 | shift
39 | ;;
40 | *)
41 | ;;
42 | esac
43 | done
44 |
45 | sample_nums=$default_sample_nums
46 | samples_per_gpu=$((sample_nums / np))
47 | add_label=${add_label:-$default_add_label}
48 | echo "Sample numbers: $sample_nums"
49 | echo "Add label: $add_label"
50 | echo "Exist time prefix: $exist_time_prefix"
51 |
52 | cmd_template="DPM_TQDM=True python scripts/inference_geneval_diffusers.py \
53 | --model_path=$model_paths \
54 | --gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}"
55 | if [ -n "${add_label}" ]; then
56 | cmd_template="${cmd_template} --add_label ${add_label}"
57 | fi
58 |
59 | echo "==================== inferencing ===================="
60 | for gpu_id in $(seq 0 $((np - 1))); do
61 | start_index=$((gpu_id * samples_per_gpu))
62 | end_index=$((start_index + samples_per_gpu))
63 | if [ $gpu_id -eq $((np - 1)) ]; then
64 | end_index=$sample_nums
65 | fi
66 |
67 | cmd="${cmd_template//\{config_file\}/$config_file}"
68 | cmd="${cmd//\{model_path\}/$model_paths}"
69 | cmd="${cmd//\{gpu_id\}/$gpu_id}"
70 | cmd="${cmd//\{start_index\}/$start_index}"
71 | cmd="${cmd//\{end_index\}/$end_index}"
72 |
73 | echo "Running on GPU $gpu_id: samples $start_index to $end_index"
74 | eval CUDA_VISIBLE_DEVICES=$gpu_id $cmd &
75 | done
76 | wait
77 |
78 | echo "All inference jobs finished."
79 |
--------------------------------------------------------------------------------
/scripts/style.css:
--------------------------------------------------------------------------------
1 | /*.gradio-container{width:680px!important}*/
2 | /* style.css */
3 | .gradio_group, .gradio_row, .gradio_column {
4 | display: flex;
5 | flex-direction: row;
6 | justify-content: flex-start;
7 | align-items: flex-start;
8 | flex-wrap: wrap;
9 | }
10 |
--------------------------------------------------------------------------------
/tests/bash/entry.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 |
5 | echo "Testing inference"
6 | bash tests/bash/inference/test_inference.sh
7 |
8 | echo "Testing training"
9 | bash tests/bash/training/test_training_all.sh
10 |
--------------------------------------------------------------------------------
/tests/bash/inference/test_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | python scripts/inference.py \
5 | --config=configs/sana_config/1024ms/Sana_600M_img1024.yaml \
6 | --model_path=hf://Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth
7 |
8 | python scripts/inference.py \
9 | --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
10 | --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
11 |
12 | python tools/controlnet/inference_controlnet.py \
13 | --config=configs/sana_controlnet_config/Sana_600M_img1024_controlnet.yaml \
14 | --model_path=hf://Efficient-Large-Model/Sana_600M_1024px_ControlNet_HED/checkpoints/Sana_600M_1024px_ControlNet_HED.pth \
15 | --json_file=asset/controlnet/samples_controlnet.json
16 |
17 | python scripts/inference_sana_sprint.py \
18 | --config=configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml \
19 | --model_path=hf://Lawrence-cj/Sana_Sprint_1600M_1024px/Sana_Sprint_1600M_1024px_36K.pth \
20 | --txt_file=asset/samples/samples_mini.txt
21 |
22 | python inference_video_scripts/inference_sana_video.py \
23 | --config=configs/sana_video_config/Sana_2000M_256px_AdamW_fsdp.yaml \
24 | --model_path=hf://Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth \
25 | --debug=true
26 |
--------------------------------------------------------------------------------
/tests/bash/setup_test_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | pip install --upgrade "huggingface-hub<1.0"
5 |
6 | # download test data
7 | mkdir -p data/data_public
8 | hf download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public
9 | hf download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data
10 | hf download Efficient-Large-Model/video_toy_data --repo-type dataset --local-dir ./data/video_toy_data
11 |
12 | mkdir -p output/pretrained_models
13 | hf download Wan-AI/Wan2.1-T2V-1.3B --repo-type model --local-dir ./output/pretrained_models/Wan2.1-T2V-1.3B
14 |
--------------------------------------------------------------------------------
/tests/bash/training/test_training_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | mkdir -p data/data_public
5 | hf download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public
6 | hf download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data
7 | hf download Efficient-Large-Model/video_toy_data --repo-type dataset --local-dir ./data/video_toy_data
8 |
9 | mkdir -p output/pretrained_models
10 | hf download Efficient-Large-Model/Wan2.1-T2V-1.3B --repo-type model --local-dir ./output/pretrained_models/Wan2.1-T2V-1.3B
11 |
12 | # test offline vae feature
13 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.load_vae_feat=true
14 |
15 | # test online vae feature
16 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false
17 |
18 | # test FSDP training
19 | bash train_scripts/train.sh configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
20 |
21 | # test SANA-Sprint(sCM + LADD) training
22 | bash train_scripts/train_scm_ladd.sh configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
23 |
24 | # test FSDP video training
25 | bash train_video_scripts/train_video_ivjoint.sh configs/sana_video_config/Sana_2000M_256px_AdamW_fsdp.yaml --np=2 --train.num_epochs=1 --train.log_interval=1 --train.train_batch_size=1
26 |
--------------------------------------------------------------------------------
/tests/bash/training/test_training_fsdp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | echo "Setting up test data..."
5 |
6 | mkdir -p data/data_public
7 | hf download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data
8 |
9 | echo "Testing SANA-Sprint(sCM + LADD) training"
10 | bash train_scripts/train_scm_ladd.sh configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml --np=4 --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
11 |
12 | echo "Testing FSDP training"
13 | bash train_scripts/train.sh configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml --np=2 --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1
14 |
--------------------------------------------------------------------------------
/tests/bash/training/test_training_vae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | echo "Setting up test data..."
5 | # bash tests/bash/setup_test_data.sh
6 | mkdir -p data/data_public
7 | hf download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public
8 |
9 | echo "Testing offline VAE feature"
10 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --np=4 --data.load_vae_feat=true
11 |
12 | echo "Testing online VAE feature"
13 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --np=4 --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false
14 |
--------------------------------------------------------------------------------
/tests/bash/training/test_training_video.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | echo "Setting up test data..."
5 | # bash tests/bash/setup_test_data.sh
6 | hf download Efficient-Large-Model/video_toy_data --repo-type dataset --local-dir ./data/video_toy_data
7 |
8 | mkdir -p output/pretrained_models
9 | hf download Wan-AI/Wan2.1-T2V-1.3B --repo-type model --local-dir ./output/pretrained_models/Wan2.1-T2V-1.3B
10 |
11 | echo "Testing FSDP video training"
12 | bash train_video_scripts/train_video_ivjoint.sh configs/sana_video_config/Sana_2000M_256px_AdamW_fsdp.yaml --np=2 --train.num_epochs=1 --train.log_interval=1 --train.train_batch_size=1 --train.joint_training_interval=0
13 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/tools/__init__.py
--------------------------------------------------------------------------------
/tools/controlnet/annotator/ckpts/ckpts.txt:
--------------------------------------------------------------------------------
1 | Weights here.
2 |
--------------------------------------------------------------------------------
/tools/controlnet/annotator/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
8 |
9 |
10 | def HWC3(x):
11 | assert x.dtype == np.uint8
12 | if x.ndim == 2:
13 | x = x[:, :, None]
14 | assert x.ndim == 3
15 | H, W, C = x.shape
16 | assert C == 1 or C == 3 or C == 4
17 | if C == 3:
18 | return x
19 | if C == 1:
20 | return np.concatenate([x, x, x], axis=2)
21 | if C == 4:
22 | color = x[:, :, 0:3].astype(np.float32)
23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
24 | y = color * alpha + 255.0 * (1.0 - alpha)
25 | y = y.clip(0, 255).astype(np.uint8)
26 | return y
27 |
28 |
29 | def resize_image(input_image, resolution):
30 | H, W, C = input_image.shape
31 | H = float(H)
32 | W = float(W)
33 | k = float(resolution) / min(H, W)
34 | H *= k
35 | W *= k
36 | H = int(np.round(H / 64.0)) * 64
37 | W = int(np.round(W / 64.0)) * 64
38 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
39 | return img
40 |
41 |
42 | def nms(x, t, s):
43 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
44 |
45 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
46 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
47 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
48 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
49 |
50 | y = np.zeros_like(x)
51 |
52 | for f in [f1, f2, f3, f4]:
53 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
54 |
55 | z = np.zeros_like(y, dtype=np.uint8)
56 | z[y > t] = 255
57 | return z
58 |
59 |
60 | def make_noise_disk(H, W, C, F):
61 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
62 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
63 | noise = noise[F : F + H, F : F + W]
64 | noise -= np.min(noise)
65 | noise /= np.max(noise)
66 | if C == 1:
67 | noise = noise[:, :, None]
68 | return noise
69 |
70 |
71 | def min_max_norm(x):
72 | x -= np.min(x)
73 | x /= np.maximum(np.max(x), 1e-5)
74 | return x
75 |
76 |
77 | def safe_step(x, step=2):
78 | y = x.astype(np.float32) * float(step + 1)
79 | y = y.astype(np.int32).astype(np.float32) / float(step)
80 | return y
81 |
82 |
83 | def img2mask(img, H, W, low=10, high=90):
84 | assert img.ndim == 3 or img.ndim == 2
85 | assert img.dtype == np.uint8
86 |
87 | if img.ndim == 3:
88 | y = img[:, :, random.randrange(0, img.shape[2])]
89 | else:
90 | y = img
91 |
92 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
93 |
94 | if random.uniform(0, 1) < 0.5:
95 | y = 255 - y
96 |
97 | return y < np.percentile(y, random.randrange(low, high))
98 |
--------------------------------------------------------------------------------
/tools/controlnet/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 | from torchvision import transforms as T
7 | from torchvision.transforms.functional import InterpolationMode
8 |
9 | from tools.controlnet.annotator.hed import HEDdetector
10 | from tools.controlnet.annotator.util import HWC3, nms, resize_image
11 |
12 | preprocessor = None
13 |
14 |
15 | def transform_control_signal(control_signal, hw):
16 | if isinstance(control_signal, str):
17 | control_signal = Image.open(control_signal)
18 | elif isinstance(control_signal, Image.Image):
19 | control_signal = control_signal
20 | elif isinstance(control_signal, np.ndarray):
21 | control_signal = Image.fromarray(control_signal)
22 | else:
23 | raise ValueError("control_signal must be a path or a PIL.Image.Image or a numpy array")
24 |
25 | transform = T.Compose(
26 | [
27 | T.Lambda(lambda img: img.convert("RGB")),
28 | T.Resize((int(hw[0, 0]), int(hw[0, 1])), interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC
29 | T.CenterCrop((int(hw[0, 0]), int(hw[0, 1]))),
30 | T.ToTensor(),
31 | T.Normalize([0.5], [0.5]),
32 | ]
33 | )
34 | return transform(control_signal).unsqueeze(0)
35 |
36 |
37 | def get_scribble_map(input_image, det, detect_resolution=512, thickness=None):
38 | """
39 | Generate scribble map from input image
40 |
41 | Args:
42 | input_image: Input image (numpy array, HWC format)
43 |         det: Detector type ('Scribble_HED' or 'None'; only the HED detector is initialized by this helper)
44 |         detect_resolution: Processing resolution
45 |         thickness: Line thickness (0-24; None picks a random value)
46 |
47 | Returns:
48 | Processed scribble map
49 | """
50 | global preprocessor
51 |
52 | # Initialize detector
53 | if "HED" in det and not isinstance(preprocessor, HEDdetector):
54 | preprocessor = HEDdetector()
55 |
56 | input_image = HWC3(input_image)
57 |
58 | if det == "None":
59 | detected_map = input_image.copy()
60 | else:
61 | # Generate scribble map
62 | detected_map = preprocessor(resize_image(input_image, detect_resolution))
63 | detected_map = HWC3(detected_map)
64 |
65 | # Post-processing
66 | detected_map = nms(detected_map, 127, 3.0)
67 | detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
68 | detected_map[detected_map > 4] = 255
69 | detected_map[detected_map < 255] = 0
70 |
71 | # Control line thickness
72 | if thickness is None:
73 | thickness = random.randint(0, 24) # Random thickness, including 0
74 | if thickness == 0:
75 | # Use erosion operation to get thinner lines
76 | kernel = np.ones((4, 4), np.uint8)
77 | detected_map = cv2.erode(detected_map, kernel, iterations=1)
78 | elif thickness > 1:
79 | kernel_size = thickness // 2
80 | kernel = np.ones((kernel_size, kernel_size), np.uint8)
81 | detected_map = cv2.dilate(detected_map, kernel, iterations=1)
82 |
83 | return detected_map
84 |
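A small sketch showing how these two helpers might be chained, assuming the repository root is on the import path; the image path below is a placeholder, and `HEDdetector` fetches its weights on first use:

import numpy as np
from PIL import Image

from tools.controlnet.utils import get_scribble_map, transform_control_signal

# Load any RGB image as an HWC uint8 array (placeholder path).
image = np.array(Image.open("path/to/image.png").convert("RGB"))

# HED-based scribble extraction at 512px processing resolution, fixed line thickness.
scribble = get_scribble_map(image, det="Scribble_HED", detect_resolution=512, thickness=6)

# hw mirrors the (1, 2) array of target (height, width) that the pipeline passes in.
hw = np.array([[1024, 1024]])
control = transform_control_signal(scribble, hw)  # tensor of shape (1, 3, 1024, 1024)
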
--------------------------------------------------------------------------------
/tools/convert_ImgDataset_to_WebDatasetMS_format.py:
--------------------------------------------------------------------------------
1 | # @Author: Pevernow (wzy3450354617@gmail.com)
2 | # @Date: 2025/1/5
3 | # @License: (Follow the main project)
4 | import json
5 | import os
6 | import tarfile
7 |
8 | from PIL import Image, PngImagePlugin
9 |
10 | PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024 # Increase maximum size for text chunks
11 |
12 |
13 | def process_data(input_dir, output_tar_name="output.tar"):
14 | """
15 | Processes a directory containing PNG files, generates corresponding JSON files,
16 | and packages all files into a TAR file. It also counts the number of processed PNG images,
17 | and saves the height and width of each PNG file to the JSON.
18 |
19 | Args:
20 | input_dir (str): The input directory containing PNG files.
21 | output_tar_name (str): The name of the output TAR file (default is "output.tar").
22 | """
23 | png_count = 0
24 | json_files_created = []
25 |
26 | for filename in os.listdir(input_dir):
27 | if filename.lower().endswith(".png"):
28 | png_count += 1
29 | base_name = filename[:-4] # Remove the ".png" extension
30 | txt_filename = os.path.join(input_dir, base_name + ".txt")
31 | json_filename = base_name + ".json"
32 | json_filepath = os.path.join(input_dir, json_filename)
33 | png_filepath = os.path.join(input_dir, filename)
34 |
35 | if os.path.exists(txt_filename):
36 | try:
37 | # Get the dimensions of the PNG image
38 | with Image.open(png_filepath) as img:
39 | width, height = img.size
40 |
41 | with open(txt_filename, encoding="utf-8") as f:
42 | caption_content = f.read().strip()
43 |
44 | data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height}
45 |
46 | with open(json_filepath, "w", encoding="utf-8") as outfile:
47 | json.dump(data, outfile, indent=4, ensure_ascii=False)
48 |
49 | print(f"Generated: {json_filename}")
50 | json_files_created.append(json_filepath)
51 |
52 | except Exception as e:
53 | print(f"Error processing file {filename}: {e}")
54 | else:
55 | print(f"Warning: No corresponding TXT file found for {filename}.")
56 |
57 | # Create a TAR file and include all files
58 | with tarfile.open(output_tar_name, "w") as tar:
59 | for item in os.listdir(input_dir):
60 | item_path = os.path.join(input_dir, item)
61 | tar.add(item_path, arcname=item) # arcname maintains the relative path of the file in the tar
62 |
63 | print(f"\nAll files have been packaged into: {output_tar_name}")
64 | print(f"Number of PNG images processed: {png_count}")
65 |
66 |
67 | if __name__ == "__main__":
68 | input_directory = input("Please enter the directory path containing PNG and TXT files: ")
69 | output_tar_filename = (
70 | input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar"
71 | )
72 | process_data(input_directory, output_tar_filename)
73 |
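The `__main__` block above prompts interactively; when scripting, `process_data` can be called directly (the input directory below is a placeholder and must contain matching `<name>.png` / `<name>.txt` pairs):

from tools.convert_ImgDataset_to_WebDatasetMS_format import process_data

# Writes <name>.json next to each PNG/TXT pair, then packs the whole directory
# into a single WebDataset-style tar shard in the current working directory.
process_data("path/to/png_txt_dir", output_tar_name="00000.tar")
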
--------------------------------------------------------------------------------
/tools/convert_py_to_yaml.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import yaml
4 |
5 |
6 | def convert_py_to_yaml(py_file_path):
7 | with open(py_file_path, encoding="utf-8") as py_file:
8 | py_content = py_file.read()
9 |
10 | local_vars = {}
11 | exec(py_content, {}, local_vars)
12 |
13 | yaml_file_path = os.path.splitext(py_file_path)[0] + ".yaml"
14 |
15 | with open(yaml_file_path, "w", encoding="utf-8") as yaml_file:
16 | yaml.dump(local_vars, yaml_file, default_flow_style=False, allow_unicode=True)
17 |
18 |
19 | def process_directory(path):
20 | for root, dirs, files in os.walk(path):
21 | for filename in files:
22 | if filename.endswith(".py"):
23 | py_file_path = os.path.join(root, filename)
24 | convert_py_to_yaml(py_file_path)
25 |                 print(f"Converted {py_file_path} to YAML format")
26 |
27 |
28 | if __name__ == "__main__":
29 | process_directory("../configs/")
30 |
--------------------------------------------------------------------------------
/tools/create_wids_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 | import json
18 | import os
19 | import sys
20 | import tarfile
21 | from glob import glob
22 |
23 | from tqdm.contrib.concurrent import process_map
24 |
25 | """
26 | python tools/create_wids_metadata.py /path/to/tar/dir > /path/to/wids-meta.json
27 | """
28 |
29 | d = sys.argv[1]
30 |
31 |
32 | def process(t):
33 | d = {}
34 | with tarfile.open(t, "r") as tar:
35 | for f in tar:
36 | n, e = os.path.splitext(f.name)
37 | if e == ".jpg" or e == ".jpeg" or e == ".png" or e == ".json" or e == ".npy":
38 | if n in d:
39 | d[n] = 1
40 | else:
41 | d[n] = 0
42 | s = os.path.getsize(t)
43 | i = sum(d.values())
44 | t = os.path.basename(t)
45 | return {"url": t, "nsamples": i, "filesize": s}
46 |
47 |
48 | print(
49 | json.dumps(
50 | {
51 | "name": "sana-dev",
52 | "__kind__": "SANA-WebDataset",
53 | "wids_version": 1,
54 | "shardlist": sorted(
55 | process_map(process, glob(f"{d}/*.tar"), chunksize=1, max_workers=os.cpu_count()),
56 | key=lambda x: x["url"],
57 | ),
58 | },
59 | indent=4,
60 | ),
61 | end="",
62 | )
63 |
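The emitted metadata can be spot-checked with a short snippet like this (the output path matches the redirection shown in the usage comment above):

import json

with open("/path/to/wids-meta.json") as f:
    meta = json.load(f)

# Each shard entry records its tar name, sample count, and size in bytes.
total_samples = sum(shard["nsamples"] for shard in meta["shardlist"])
print(meta["name"], len(meta["shardlist"]), "shards,", total_samples, "samples")
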
--------------------------------------------------------------------------------
/tools/download.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | """
17 | Functions for downloading pre-trained Sana models
18 | """
19 | import argparse
20 | import os
21 |
22 | import torch
23 | from termcolor import colored
24 | from torchvision.datasets.utils import download_url
25 |
26 | from sana.tools import hf_download_or_fpath
27 |
28 | pretrained_models = {}
29 |
30 |
31 | def find_model(model_name):
32 | """
33 |     Finds a pre-trained Sana model, downloading it if necessary. Alternatively, loads a model from a local path.
34 |     """
35 |     if model_name in pretrained_models:  # Find/download our pre-trained Sana checkpoints
36 | return download_model(model_name)
37 |
38 | # Load a custom Sana checkpoint:
39 | print(colored(f"[Sana] Loading model from {model_name}", attrs=["bold"]))
40 | model_name = hf_download_or_fpath(model_name)
41 | assert os.path.isfile(model_name), f"Could not find Sana checkpoint at {model_name}"
42 | print(colored(f"[Sana] Loaded model from {model_name}", attrs=["bold"]))
43 | if model_name.endswith(".safetensors"):
44 |         import safetensors.torch
45 |
46 | return {"state_dict": safetensors.torch.load_file(model_name, device="cpu")}
47 | elif model_name.endswith(".safetensors.index.json"):
48 | import json
49 |
50 |         import safetensors.torch
51 |
52 | index = json.load(open(model_name))["weight_map"]
53 | safetensors_list = set(index.values())
54 | state_dict = {}
55 | for safetensors_path in safetensors_list:
56 | state_dict.update(
57 | safetensors.torch.load_file(os.path.join(os.path.dirname(model_name), safetensors_path), device="cpu")
58 | )
59 | return {"state_dict": state_dict}
60 | else:
61 | return torch.load(model_name, map_location=lambda storage, loc: storage)
62 |
63 |
64 | def download_model(model_name):
65 | """
66 | Downloads a pre-trained Sana model from the web.
67 | """
68 | assert model_name in pretrained_models
69 | local_path = f"output/pretrained_models/{model_name}"
70 | if not os.path.isfile(local_path):
71 | hf_endpoint = os.environ.get("HF_ENDPOINT")
72 | if hf_endpoint is None:
73 | hf_endpoint = "https://huggingface.co"
74 | os.makedirs("output/pretrained_models", exist_ok=True)
75 | web_path = f""
76 | download_url(web_path, "output/pretrained_models/")
77 | model = torch.load(local_path, map_location=lambda storage, loc: storage)
78 | return model
79 |
80 |
81 | if __name__ == "__main__":
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument("--model_names", nargs="+", type=str, default=pretrained_models)
84 | args = parser.parse_args()
85 | model_names = args.model_names
86 | model_names = set(model_names)
87 |
88 | # Download Sana checkpoints
89 | for model in model_names:
90 | download_model(model)
91 | print("Done.")
92 |
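A minimal sketch of `find_model` above resolving an `hf://` checkpoint (the URI is the one used in `tests/bash/inference/test_inference.sh`); `.pth` files come back as whatever `torch.load` returns, while `.safetensors` files are wrapped as `{"state_dict": ...}`:

from tools.download import find_model

ckpt = find_model(
    "hf://Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth"
)
# Assumption: .pth checkpoints may wrap their weights in a "state_dict" entry; fall back otherwise.
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
print(type(ckpt).__name__, len(state_dict))
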
--------------------------------------------------------------------------------
/tools/inference_scaling/nvila_sana_pick.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | set -e
3 |
4 | sana_dir=$1
5 | number_of_files=$2
6 | pick_number=$3
7 | # detect the number of GPUs available on this machine
8 | num_gpu=$(nvidia-smi -L | wc -l)
9 | echo "sana_dir: $sana_dir, number_of_files: $number_of_files, pick_number: $pick_number, num_gpu: $num_gpu"
10 | # split the 552 samples evenly across the GPUs: GPU i handles indices [i * (552 / num_gpu), (i + 1) * (552 / num_gpu))
11 | # the last GPU extends its end index to 552 so any remainder is covered
12 | for idx in $(seq 0 $((num_gpu - 1))); do
13 | start_idx=$((idx * (552 / num_gpu)))
14 | end_idx=$((start_idx + 552 / num_gpu))
15 | if [ $idx -eq $((num_gpu - 1)) ]; then
16 | end_idx=552
17 | fi
18 |
19 | echo "CUDA_VISIBLE_DEVICES=$idx python tools/inference_scaling/nvila_sana_pick.py --start_idx $start_idx --end_idx $end_idx --base_dir $sana_dir --number_of_files $number_of_files --pick_number $pick_number &"
20 | CUDA_VISIBLE_DEVICES=$idx python tools/inference_scaling/nvila_sana_pick.py --start_idx $start_idx --end_idx $end_idx --base_dir $sana_dir --number_of_files $number_of_files --pick_number $pick_number &
21 | done
22 | wait
23 |
--------------------------------------------------------------------------------
/tools/metrics/clip-score/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # Environments
92 | .env
93 | .venv
94 | env/
95 | venv/
96 | ENV/
97 | env.bak/
98 | venv.bak/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 |
115 | # Pyre type checker
116 | .pyre/
117 |
--------------------------------------------------------------------------------
/tools/metrics/clip-score/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import setuptools
4 |
5 |
6 | def read(rel_path):
7 | base_path = os.path.abspath(os.path.dirname(__file__))
8 | with open(os.path.join(base_path, rel_path)) as f:
9 | return f.read()
10 |
11 |
12 | def get_version(rel_path):
13 | for line in read(rel_path).splitlines():
14 | if line.startswith("__version__"):
15 | delim = '"' if '"' in line else "'"
16 | return line.split(delim)[1]
17 |
18 | raise RuntimeError("Unable to find version string.")
19 |
20 |
21 | if __name__ == "__main__":
22 | setuptools.setup(
23 | name="clip-score",
24 | version=get_version(os.path.join("src", "clip_score", "__init__.py")),
25 | author="Taited",
26 | author_email="taited9160@gmail.com",
27 | description=("Package for calculating CLIP-Score" " using PyTorch"),
28 | long_description=read("README.md"),
29 | long_description_content_type="text/markdown",
30 | url="https://github.com/taited/clip-score",
31 | package_dir={"": "src"},
32 | packages=setuptools.find_packages(where="src"),
33 | classifiers=[
34 | "Programming Language :: Python :: 3",
35 | "License :: OSI Approved :: Apache Software License",
36 | ],
37 | python_requires=">=3.5",
38 | entry_points={
39 | "console_scripts": [
40 | "clip-score = clip_score.clip_score:main",
41 | ],
42 | },
43 | install_requires=[
44 | "numpy",
45 | "pillow",
46 | "torch>=1.7.1",
47 | "torchvision>=0.8.2",
48 | "ftfy",
49 | "regex",
50 | "tqdm",
51 | ],
52 | extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]},
53 | )
54 |
--------------------------------------------------------------------------------
/tools/metrics/clip-score/src/clip_score/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.1"
2 |
--------------------------------------------------------------------------------
/tools/metrics/clip-score/src/clip_score/__main__.py:
--------------------------------------------------------------------------------
1 | import clip_score.clip_score
2 |
3 | clip_score.clip_score.main()
4 |
--------------------------------------------------------------------------------
/tools/metrics/dpg_bench/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | addict
3 |
4 | # for modelscope
5 | cloudpickle
6 | datasets==2.21.0
7 | decord>=0.6.0
8 | diffusers
9 | ftfy>=6.0.3
10 | librosa==0.10.1
11 | modelscope[multi-modal]
12 | numpy
13 | opencv-python
14 | oss2
15 | pandas
16 | pillow
17 | # compatible with taming-transformers-rom1504
18 | rapidfuzz
19 | # rouge-score was recently updated from 0.0.4 to 0.0.7,
20 | # which introduced compatibility issues that are being investigated
21 | rouge_score<=0.0.4
22 | safetensors
23 | simplejson
24 | sortedcontainers
25 | # scikit-video
26 | soundfile
27 | taming-transformers-rom1504
28 | tiktoken
29 | timm
30 | tokenizers
31 | torchvision
32 | tqdm
33 | transformers
34 | transformers_stream_generator
35 | unicodedata2
36 | wandb
37 | zhconv
38 | # fairseq need to be build from source code: https://github.com/facebookresearch/fairseq
39 |
--------------------------------------------------------------------------------
/tools/metrics/geneval/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Dhruba Ghosh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tools/metrics/geneval/evaluation/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Download Mask2Former object detection config and weights
4 |
5 | if [ ! -z "$1" ]
6 | then
7 | mkdir -p "$1"
8 | echo "Downloading mask2former for GenEval"
9 | wget https://download.openmmlab.com/mmdetection/v2.0/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco_20220504_001756-743b7d99.pth -O "$1/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.pth"
10 | fi
11 |
--------------------------------------------------------------------------------
/tools/metrics/geneval/evaluation/object_names.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorcycle
5 | airplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | potted plant
60 | bed
61 | dining table
62 | toilet
63 | tv
64 | laptop
65 | computer mouse
66 | tv remote
67 | computer keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 |
--------------------------------------------------------------------------------
/tools/metrics/geneval/evaluation/summary_scores.py:
--------------------------------------------------------------------------------
1 | # Get results of evaluation
2 |
3 | import argparse
4 | import os
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("filename", type=str)
11 | args = parser.parse_args()
12 |
13 | # Load classnames
14 |
15 | with open(os.path.join(os.path.dirname(__file__), "object_names.txt")) as cls_file:
16 | classnames = [line.strip() for line in cls_file]
17 | cls_to_idx = {"_".join(cls.split()): idx for idx, cls in enumerate(classnames)}
18 |
19 | # Load results
20 |
21 | df = pd.read_json(args.filename, orient="records", lines=True)
22 |
23 | # Measure overall success
24 |
25 | print("Summary")
26 | print("=======")
27 | print(f"Total images: {len(df)}")
28 | print(f"Total prompts: {len(df.groupby('metadata'))}")
29 | print(f"% correct images: {df['correct'].mean():.2%}")
30 | print(f"% correct prompts: {df.groupby('metadata')['correct'].any().mean():.2%}")
31 | print()
32 |
33 | # By group
34 |
35 | task_scores = []
36 |
37 | print("Task breakdown")
38 | print("==============")
39 | for tag, task_df in df.groupby("tag", sort=False):
40 | task_scores.append(task_df["correct"].mean())
41 | print(f"{tag:<16} = {task_df['correct'].mean():.2%} ({task_df['correct'].sum()} / {len(task_df)})")
42 | print()
43 |
44 | print(f"Overall score (avg. over tasks): {np.mean(task_scores):.5f}")
45 |
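The script reads a JSON-lines results file from the GenEval evaluation step; each record needs at least a `metadata` field (one value per prompt), a `tag` (task group), and a boolean `correct`. A tiny synthetic file in that shape (purely illustrative values) is enough to exercise it:

import json

records = [
    {"metadata": "a photo of a dog", "tag": "single_object", "correct": True},
    {"metadata": "a photo of a dog", "tag": "single_object", "correct": False},
    {"metadata": "a red cup and a blue bowl", "tag": "color_attr", "correct": False},
]
with open("results.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# then: python tools/metrics/geneval/evaluation/summary_scores.py results.jsonl
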
--------------------------------------------------------------------------------
/tools/metrics/geneval/geneval_env.md:
--------------------------------------------------------------------------------
1 | The installation process follows https://github.com/djghosh13/geneval/issues/12. Thanks to the community!
2 |
3 | # Cloning the repository
4 |
5 | ```bash
6 | git clone https://github.com/djghosh13/geneval.git
7 |
8 | cd geneval
9 | conda create -n geneval python=3.8.10 -y
10 | conda activate geneval
11 | ```
12 |
13 | # Installing dependencies
14 |
15 | ```bash
16 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
17 | pip install open-clip-torch==2.26.1
18 | pip install clip-benchmark
19 | pip install -U openmim
20 | pip install einops
21 | python -m pip install lightning
22 | pip install diffusers["torch"] transformers
23 | pip install tomli
24 | pip install platformdirs
25 | pip install --upgrade setuptools
26 | ```
27 |
28 | # mmengine and mmcv dependency installation
29 |
30 | ```bash
31 | mim install mmengine mmcv-full==1.7.2
32 | ```
33 |
34 | # mmdet installation
35 |
36 | ```bash
37 | git clone https://github.com/open-mmlab/mmdetection.git
38 | cd mmdetection; git checkout 2.x
39 | pip install -v -e .
40 | ```
41 |
42 | <!-- conda create -n geneval python==3.9 -->
43 |
44 | <!--
45 | conda install -c nvidia cuda-toolkit -y
46 | ./evaluation/download_models.sh output/
47 |
48 | # install mmdetection
49 |
50 | git clone https://github.com/open-mmlab/mmdetection.git
51 | cd mmdetection; git checkout 2.x
52 | pip install -v -e .
53 |
54 | pip install einops
55 | pip install accelerate==0.15.0
56 | pip install torchmetrics==0.6.0
57 | pip install transformers==4.48.0
58 | pip install pandas
59 | pip install open-clip-torch
60 | pip install clip_benchmark
61 | pip install huggingface_hub==0.24.5
62 | pip install pytorch-lightning==1.4.2
63 |
64 | pip install openmim
65 | mim install mmcv-full==1.7.1 -->
66 |
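A quick import check for the environment assembled above (a sketch; the exact versions depend on what pip/mim actually resolved):

import mmcv
import mmdet
import torch

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("mmcv", mmcv.__version__, "| mmdet", mmdet.__version__)
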
--------------------------------------------------------------------------------
/tools/metrics/geneval/images/geneval_figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Sana/9a13b9ccaa3671c654ae9ab3f462811bb04c03ac/tools/metrics/geneval/images/geneval_figure_1.png
--------------------------------------------------------------------------------
/tools/metrics/geneval/prompts/object_names.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorcycle
5 | airplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | potted plant
60 | bed
61 | dining table
62 | toilet
63 | tv
64 | laptop
65 | computer mouse
66 | tv remote
67 | computer keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # Environments
92 | .env
93 | .venv
94 | env/
95 | venv/
96 | ENV/
97 | env.bak/
98 | venv.bak/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 |
115 | # Pyre type checker
116 | .pyre/
117 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## \[0.3.0\] - 2023-01-05
4 |
5 | ### Added
6 |
7 | - Add argument `--save-stats` allowing to compute dataset statistics and save them as an `.npz` file ([#80](https://github.com/mseitzer/pytorch-fid/pull/80)). The `.npz` file can be used in subsequent FID computations instead of recomputing the dataset statistics. This option can be used in the following way: `python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile`.
8 |
9 | ### Fixed
10 |
11 | - Do not use `os.sched_getaffinity` to get number of available CPUs on Windows, as it is not available there ([232b3b14](https://github.com/mseitzer/pytorch-fid/commit/232b3b1468800102fcceaf6f2bb8977811fc991a), [#84](https://github.com/mseitzer/pytorch-fid/issues/84)).
12 | - Do not use Inception model argument `pretrained`, as it was deprecated in torchvision 0.13 ([#88](https://github.com/mseitzer/pytorch-fid/pull/88)).
13 |
14 | ## \[0.2.1\] - 2021-10-10
15 |
16 | ### Added
17 |
18 | - Add argument `--num-workers` to select number of dataloader processes ([#66](https://github.com/mseitzer/pytorch-fid/pull/66)). Defaults to 8 or the number of available CPUs if less than 8 CPUs are available.
19 |
20 | ### Fixed
21 |
22 | - Fixed package setup to work under Windows ([#55](https://github.com/mseitzer/pytorch-fid/pull/55), [#72](https://github.com/mseitzer/pytorch-fid/issues/72))
23 |
24 | ## \[0.2.0\] - 2020-11-30
25 |
26 | ### Added
27 |
28 | - Load images using a Pytorch dataloader, which should result in a speed-up. ([#47](https://github.com/mseitzer/pytorch-fid/pull/47))
29 | - Support more image extensions ([#53](https://github.com/mseitzer/pytorch-fid/pull/53))
30 | - Improve tooling by setting up Nox, add linting and test support ([#52](https://github.com/mseitzer/pytorch-fid/pull/52))
31 | - Add some unit tests
32 |
33 | ## \[0.1.1\] - 2020-08-16
34 |
35 | ### Fixed
36 |
37 | - Fixed software license string in `setup.py`
38 |
39 | ## \[0.1.0\] - 2020-08-16
40 |
41 | Initial release as a pypi package. Use `pip install pytorch-fid` to install.
42 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/noxfile.py:
--------------------------------------------------------------------------------
1 | import nox
2 |
3 | LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py")
4 |
5 |
6 | @nox.session
7 | def lint(session):
8 | session.install("flake8")
9 | session.install("flake8-bugbear")
10 | session.install("flake8-isort")
11 |
12 | args = session.posargs or LOCATIONS
13 | session.run("flake8", *args)
14 |
15 |
16 | @nox.session
17 | def tests(session):
18 | session.install(".")
19 | session.install("pytest")
20 | session.install("pytest-mock")
21 | session.run("pytest", *session.posargs)
22 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select=F,W,E,I,B,B9
3 | ignore=W503,B950
4 | max-line-length=79
5 |
6 | [isort]
7 | multi_line_output=1
8 | line_length=79
9 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import setuptools
4 |
5 |
6 | def read(rel_path):
7 | base_path = os.path.abspath(os.path.dirname(__file__))
8 | with open(os.path.join(base_path, rel_path)) as f:
9 | return f.read()
10 |
11 |
12 | def get_version(rel_path):
13 | for line in read(rel_path).splitlines():
14 | if line.startswith("__version__"):
15 | # __version__ = "0.9"
16 | delim = '"' if '"' in line else "'"
17 | return line.split(delim)[1]
18 |
19 | raise RuntimeError("Unable to find version string.")
20 |
21 |
22 | if __name__ == "__main__":
23 | setuptools.setup(
24 | name="pytorch-fid",
25 | version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")),
26 | author="Max Seitzer",
27 | description=("Package for calculating Frechet Inception Distance (FID)" " using PyTorch"),
28 | long_description=read("README.md"),
29 | long_description_content_type="text/markdown",
30 | url="https://github.com/mseitzer/pytorch-fid",
31 | package_dir={"": "src"},
32 | packages=setuptools.find_packages(where="src"),
33 | classifiers=[
34 | "Programming Language :: Python :: 3",
35 | "License :: OSI Approved :: Apache Software License",
36 | ],
37 | python_requires=">=3.5",
38 | entry_points={
39 | "console_scripts": [
40 | "pytorch-fid = pytorch_fid.fid_score:main",
41 | ],
42 | },
43 | install_requires=["numpy", "pillow", "scipy", "torch>=1.0.1", "torchvision>=0.2.2"],
44 | extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]},
45 | )
46 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/src/pytorch_fid/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.0"
2 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/src/pytorch_fid/__main__.py:
--------------------------------------------------------------------------------
1 | import pytorch_fid.fid_score
2 |
3 | pytorch_fid.fid_score.main()
4 |
--------------------------------------------------------------------------------
/tools/metrics/pytorch-fid/tests/test_fid_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from PIL import Image
5 | from pytorch_fid import fid_score, inception
6 |
7 |
8 | @pytest.fixture
9 | def device():
10 | return torch.device("cpu")
11 |
12 |
13 | def test_calculate_fid_given_statistics(mocker, tmp_path, device):
14 | dim = 2048
15 | m1, m2 = np.zeros((dim,)), np.ones((dim,))
16 | sigma = np.eye(dim)
17 |
18 | def dummy_statistics(path, model, batch_size, dims, device, num_workers):
19 | if path.endswith("1"):
20 | return m1, sigma
21 | elif path.endswith("2"):
22 | return m2, sigma
23 | else:
24 | raise ValueError
25 |
26 | mocker.patch("pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics)
27 |
28 | dir_names = ["1", "2"]
29 | paths = []
30 | for name in dir_names:
31 | path = tmp_path / name
32 | path.mkdir()
33 | paths.append(str(path))
34 |
35 | fid_value = fid_score.calculate_fid_given_paths(paths, batch_size=dim, device=device, dims=dim, num_workers=0)
36 |
37 | # Given equal covariance, FID is just the squared norm of difference
38 | assert fid_value == np.sum((m1 - m2) ** 2)
39 |
40 |
41 | def test_compute_statistics_of_path(mocker, tmp_path, device):
42 | model = mocker.MagicMock(inception.InceptionV3)()
43 | model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]
44 |
45 | size = (4, 4, 3)
46 | arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
47 | images = [(arr * 255).astype(np.uint8) for arr in arrays]
48 |
49 | paths = []
50 | for idx, image in enumerate(images):
51 | paths.append(str(tmp_path / f"{idx}.png"))
52 | Image.fromarray(image, mode="RGB").save(paths[-1])
53 |
54 | stats = fid_score.compute_statistics_of_path(
55 | str(tmp_path), model, batch_size=len(images), dims=3, device=device, num_workers=0
56 | )
57 |
58 | assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
59 | assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
60 |
61 |
62 | def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
63 | model = mocker.MagicMock(inception.InceptionV3)()
64 |
65 | mu = np.random.randn(5)
66 | sigma = np.random.randn(5, 5)
67 |
68 | path = tmp_path / "stats.npz"
69 | with path.open("wb") as f:
70 | np.savez(f, mu=mu, sigma=sigma)
71 |
72 | stats = fid_score.compute_statistics_of_path(str(path), model, batch_size=1, dims=5, device=device, num_workers=0)
73 |
74 | assert np.allclose(stats[0], mu)
75 | assert np.allclose(stats[1], sigma)
76 |
77 |
78 | def test_image_types(tmp_path):
79 | in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
80 | in_image = Image.fromarray(in_arr, mode="RGB")
81 |
82 | paths = []
83 | for ext in fid_score.IMAGE_EXTENSIONS:
84 | paths.append(str(tmp_path / f"img.{ext}"))
85 | in_image.save(paths[-1])
86 |
87 | dataset = fid_score.ImagePathDataset(paths)
88 |
89 | for img in dataset:
90 | assert np.allclose(np.array(img), in_arr)
91 |
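92 | # The device fixture pins these tests to CPU; run them with pytest directly or
93 | # via the "tests" session defined in this package's noxfile.py.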
--------------------------------------------------------------------------------
/tools/metrics/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def tracker(args, result_dict, label="", pattern="epoch_step", metric="FID"):
5 | if args.report_to == "wandb":
6 | import wandb
7 |
8 | wandb_name = f"[{args.log_metric}]_{args.name}"
9 | wandb.init(project=args.tracker_project_name, name=wandb_name, resume="allow", id=wandb_name, tags=["metrics"])
10 | run = wandb.run
11 | if pattern == "step":
12 | pattern = "sample_steps"
13 | elif pattern == "epoch_step":
14 | pattern = "step"
15 | custom_name = f"custom_{pattern}"
16 | run.define_metric(custom_name)
17 | # define which metrics will be plotted against it
18 | run.define_metric(f"{metric}_{label}", step_metric=custom_name)
19 |
20 | steps = []
21 | results = []
22 |
23 | def extract_value(regex, exp_name):
24 | match = re.search(regex, exp_name)
25 | if match:
26 | return match.group(1)
27 | else:
28 | return "unknown"
29 |
30 | for exp_name, result_value in result_dict.items():
31 | if pattern == "step":
32 | regex = r".*step(\d+)_scale.*"
33 | custom_x = extract_value(regex, exp_name)
34 | elif pattern == "sample_steps":
35 | regex = r".*step(\d+)_size.*"
36 | custom_x = extract_value(regex, exp_name)
37 | else:
38 | regex = rf"{pattern}(\d+(\.\d+)?)"
39 | custom_x = extract_value(regex, exp_name)
40 | custom_x = 1 if custom_x == "unknown" else custom_x
41 |
42 | assert custom_x != "unknown"
43 | steps.append(float(custom_x))
44 | results.append(result_value)
45 |
46 | sorted_data = sorted(zip(steps, results))
47 | steps, results = zip(*sorted_data)
48 |
49 | for step, result in zip(steps, results):
50 | run.log({f"{metric}_{label}": result, custom_name: step})
51 | else:
52 | print(f"{args.report_to} is not supported")
53 |
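54 | # Minimal usage sketch (hypothetical values; `args` normally comes from the metrics
55 | # CLI and must provide report_to, log_metric, name and tracker_project_name):
56 | #
57 | #   args = argparse.Namespace(report_to="wandb", log_metric="FID", name="sana_1600m",
58 | #                             tracker_project_name="sana-metrics")
59 | #   results = {"exp_step20000_scale4.5": 6.1, "exp_step40000_scale4.5": 5.4}
60 | #   tracker(args, results, label="1024px", pattern="epoch_step", metric="FID")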
--------------------------------------------------------------------------------
/train_scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | work_dir=output/debug
5 | np=2
6 |
7 | while [[ $# -gt 0 ]]; do
8 | case $1 in
9 | --np=*)
10 | np="${1#*=}"
11 | shift
12 | ;;
13 | *.yaml)
14 | config=$1
15 | shift
16 | ;;
17 | *)
18 | other_args+=("$1")
19 | shift
20 | ;;
21 | esac
22 | done
23 |
24 | if [[ -z "$config" ]]; then
25 | config="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml"
26 | echo "No yaml file specified. Set to --config_path=$config"
27 | fi
28 |
29 | cmd="TRITON_PRINT_AUTOTUNING=1 \
30 | torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000)) \
31 | train_scripts/train.py \
32 | --config_path=$config \
33 | --work_dir=$work_dir \
34 | --name=tmp \
35 | --resume_from=latest \
36 | --report_to=tensorboard \
37 | --debug=true \
38 | ${other_args[@]}"
39 |
40 | echo $cmd
41 | eval $cmd
42 |
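43 | # Illustrative invocation (a *.yaml positional sets the config, --np=N sets the GPU
44 | # count, and any other flags are forwarded to train_scripts/train.py):
45 | #   bash train_scripts/train.sh \
46 | #     configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml --np=8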
--------------------------------------------------------------------------------
/train_scripts/train_lora.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
4 | export INSTANCE_DIR="data/dreambooth/dog"
5 | export OUTPUT_DIR="trained-sana-lora"
6 |
7 | accelerate launch --num_processes 4 --main_process_port 29500 --gpu_ids 0,1,2,3 \
8 | train_scripts/train_dreambooth_lora_sana.py \
9 | --pretrained_model_name_or_path=$MODEL_NAME \
10 | --instance_data_dir=$INSTANCE_DIR \
11 | --output_dir=$OUTPUT_DIR \
12 | --mixed_precision="bf16" \
13 | --instance_prompt="a photo of sks dog" \
14 | --resolution=1024 \
15 | --train_batch_size=1 \
16 | --gradient_accumulation_steps=4 \
17 | --use_8bit_adam \
18 | --learning_rate=1e-4 \
19 | --report_to="wandb" \
20 | --lr_scheduler="constant" \
21 | --lr_warmup_steps=0 \
22 | --max_train_steps=500 \
23 | --validation_prompt="A photo of sks dog in a pond, yarn art style" \
24 | --validation_epochs=25 \
25 | --seed="0" \
26 | --push_to_hub
27 |
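28 | # Note: --report_to="wandb" and --push_to_hub assume prior `wandb login` and
29 | # Hugging Face authentication; drop them for a purely local run.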
--------------------------------------------------------------------------------
/train_scripts/train_scm_ladd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | work_dir=output/debug_sCM_ladd
5 | np=2
6 |
7 | while [[ $# -gt 0 ]]; do
8 | case $1 in
9 | --np=*)
10 | np="${1#*=}"
11 | shift
12 | ;;
13 | *.yaml)
14 | config=$1
15 | shift
16 | ;;
17 | *)
18 | other_args+=("$1")
19 | shift
20 | ;;
21 | esac
22 | done
23 |
24 | if [[ -z "$config" ]]; then
25 | config="configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml"
26 | echo "No yaml file specified. Set to --config_path=$config"
27 | fi
28 |
29 | cmd="TRITON_PRINT_AUTOTUNING=1 \
30 | torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000)) \
31 | train_scripts/train_scm_ladd.py \
32 | --config_path=$config \
33 | --work_dir=$work_dir \
34 | --name=tmp \
35 | --resume_from=latest \
36 | --report_to=tensorboard \
37 | --debug=true \
38 | ${other_args[@]}"
39 |
40 | echo $cmd
41 | eval $cmd
42 |
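43 | # Illustrative invocation (same argument handling as train_scripts/train.sh):
44 | #   bash train_scripts/train_scm_ladd.sh \
45 | #     configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml --np=8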
--------------------------------------------------------------------------------
/train_video_scripts/train_video_ivjoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | work_dir=output/debug_video
5 | np=1
6 |
7 |
8 | while [[ $# -gt 0 ]]; do
9 | case $1 in
10 | --np=*)
11 | np="${1#*=}"
12 | shift
13 | ;;
14 | *.yaml)
15 | config=$1
16 | shift
17 | ;;
18 | *)
19 | other_args+=("$1")
20 | shift
21 | ;;
22 | esac
23 | done
24 |
25 | if [[ -z "$config" ]]; then
26 | config="configs/sana_video_config/Sana_2000M_480px_AdamW_fsdp.yaml"
27 | echo "No yaml file specified. Set to --config_path=$config"
28 | fi
29 |
30 | export DISABLE_XFORMERS=1
31 | export DEBUG_MODE=1
32 |
33 | cmd="TRITON_PRINT_AUTOTUNING=1 \
34 | torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000)) \
35 | train_video_scripts/train_video_ivjoint.py \
36 | --config_path=$config \
37 | --work_dir=$work_dir \
38 | --train.log_interval=1 \
39 | --name=tmp \
40 | --resume_from=latest \
41 | --report_to=tensorboard \
42 | --train.num_workers=0 \
43 | --train.visualize=False \
44 | --debug=true \
45 | ${other_args[@]}"
46 |
47 | echo $cmd
48 | eval $cmd
49 |
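50 | # Illustrative invocation:
51 | #   bash train_video_scripts/train_video_ivjoint.sh \
52 | #     configs/sana_video_config/Sana_2000M_480px_AdamW_fsdp.yaml --np=8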
--------------------------------------------------------------------------------
/train_video_scripts/train_video_ivjoint_chunk.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | work_dir=output/debug_video
5 | np=1
6 |
7 |
8 | if [[ $1 == *.yaml ]]; then
9 | config=$1
10 | shift
11 | else
12 | config="configs/sana_video_config/480ms/Sana_1600M_480px_adamW_fsdp_chunk.yaml"
13 | echo "Only .yaml config files are supported, but got '$1'. Set to --config_path=$config"
14 | fi
15 |
16 | export DISABLE_XFORMERS=1
17 | export DEBUG_MODE=1
18 |
19 | cmd="TRITON_PRINT_AUTOTUNING=1 \
20 | torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000)) \
21 | train_video_scripts/train_video_ivjoint.py \
22 | --config_path=$config \
23 | --work_dir=$work_dir \
24 | --train.log_interval=1 \
25 | --name=tmp \
26 | --resume_from=latest \
27 | --report_to=tensorboard \
28 | --train.num_workers=0 \
29 | --train.visualize=False \
30 | --debug=true \
31 | $@"
32 |
33 | echo $cmd
34 | eval $cmd
35 |
--------------------------------------------------------------------------------