├── .github └── workflows │ ├── bot-autolint.yaml │ └── ci.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CIs └── add_license_all.sh ├── Dockerfile ├── LICENSE ├── README.md ├── app ├── app_sana.py ├── app_sana_4bit.py ├── app_sana_4bit_compare_bf16.py ├── app_sana_controlnet_hed.py ├── app_sana_multithread.py ├── app_sana_sprint.py ├── safety_check.py ├── sana_controlnet_pipeline.py ├── sana_pipeline.py └── sana_sprint_pipeline.py ├── asset ├── Sana.jpg ├── app_styles │ └── controlnet_app_style.css ├── apple.webp ├── controlnet │ ├── ref_images │ │ ├── A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg │ │ ├── a house.png │ │ ├── a living room.png │ │ └── nvidia.png │ └── samples_controlnet.json ├── docs │ ├── ComfyUI │ │ ├── SANA-1.5_FlowEuler.json │ │ ├── Sana_CogVideoX.json │ │ ├── Sana_FlowEuler.json │ │ ├── Sana_FlowEuler_2K.json │ │ ├── Sana_FlowEuler_4K.json │ │ └── comfyui.md │ ├── inference_scaling │ │ ├── inference_scaling.md │ │ ├── results.jpg │ │ └── scaling_curve.jpg │ ├── metrics_toolkit.md │ ├── model_zoo.md │ ├── quantize │ │ ├── 4bit_sana.md │ │ └── 8bit_sana.md │ ├── sana_controlnet.md │ ├── sana_lora_dreambooth.md │ └── sana_sprint.md ├── example_data │ ├── 00000000.jpg │ ├── 00000000.png │ ├── 00000000.txt │ ├── 00000000_InternVL2-26B.json │ ├── 00000000_InternVL2-26B_clip_score.json │ ├── 00000000_VILA1-5-13B.json │ ├── 00000000_VILA1-5-13B_clip_score.json │ ├── 00000000_prompt_clip_score.json │ └── meta_data.json ├── examples.py ├── logo.png ├── mit-logo.jpg ├── model-incremental.jpg ├── model_paths.txt └── samples │ ├── samples.txt │ └── samples_mini.txt ├── configs ├── sana1-5_config │ └── 1024ms │ │ ├── Sana_1600M_1024px_AdamW_fsdp.yaml │ │ ├── Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml │ │ ├── Sana_3200M_1024px_came8bit_grow_constant_allqknorm_bf16_lr2e5.yaml │ │ └── Sana_4800M_1024px_came8bit_grow_constant_allqknorm_bf16_lr2e5.yaml ├── sana_app_config │ ├── Sana_1600M_app.yaml │ └── Sana_600M_app.yaml ├── sana_base.yaml ├── sana_config │ ├── 1024ms │ │ ├── Sana_1600M_img1024.yaml │ │ ├── Sana_1600M_img1024_AdamW.yaml │ │ ├── Sana_1600M_img1024_CAME8bit.yaml │ │ └── Sana_600M_img1024.yaml │ ├── 2048ms │ │ └── Sana_1600M_img2048_bf16.yaml │ ├── 4096ms │ │ └── Sana_1600M_img4096_bf16.yaml │ └── 512ms │ │ ├── Sana_1600M_img512.yaml │ │ ├── Sana_600M_img512.yaml │ │ ├── ci_Sana_600M_img512.yaml │ │ └── sample_dataset.yaml ├── sana_controlnet_config │ ├── Sana_1600M_1024px_controlnet_bf16.yaml │ └── Sana_600M_img1024_controlnet.yaml └── sana_sprint_config │ └── 1024ms │ ├── SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml │ ├── SanaSprint_1600M_img1024_bf16_normT_allqknorm_teacher_ft.yaml │ └── SanaSprint_600M_1024px_allqknorm_bf16_scm_ladd.yaml ├── diffusion ├── __init__.py ├── data │ ├── __init__.py │ ├── builder.py │ ├── datasets │ │ ├── __init__.py │ │ ├── sana_data.py │ │ ├── sana_data_multi_scale.py │ │ └── utils.py │ ├── transforms.py │ └── wids │ │ ├── __init__.py │ │ ├── wids.py │ │ ├── wids_dl.py │ │ ├── wids_lru.py │ │ ├── wids_mmtar.py │ │ ├── wids_specs.py │ │ └── wids_tar.py ├── model │ ├── __init__.py │ ├── act.py │ ├── builder.py │ ├── dc_ae │ │ └── efficientvit │ │ │ ├── __init__.py │ │ │ ├── ae_model_zoo.py │ │ │ ├── apps │ │ │ ├── __init__.py │ │ │ ├── setup.py │ │ │ ├── trainer │ │ │ │ ├── __init__.py │ │ │ │ └── run_config.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── dist.py │ │ │ │ ├── ema.py │ │ │ │ ├── export.py │ │ │ │ ├── image.py │ │ │ │ ├── init.py │ 
│ │ │ ├── lr.py │ │ │ │ ├── metric.py │ │ │ │ ├── misc.py │ │ │ │ └── opt.py │ │ │ └── models │ │ │ ├── __init__.py │ │ │ ├── efficientvit │ │ │ ├── __init__.py │ │ │ └── dc_ae.py │ │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── act.py │ │ │ ├── drop.py │ │ │ ├── norm.py │ │ │ ├── ops.py │ │ │ └── triton_rms_norm.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── list.py │ │ │ ├── network.py │ │ │ └── random.py │ ├── diffusion_utils.py │ ├── dpm_solver.py │ ├── edm_sample.py │ ├── gaussian_diffusion.py │ ├── model_growth_utils.py │ ├── nets │ │ ├── __init__.py │ │ ├── basic_modules.py │ │ ├── fastlinear │ │ │ ├── develop_triton_ffn.py │ │ │ ├── develop_triton_litemla.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── flash_attn.py │ │ │ │ ├── lite_mla.py │ │ │ │ ├── mb_conv_pre_glu.py │ │ │ │ ├── nn │ │ │ │ │ ├── act.py │ │ │ │ │ ├── conv.py │ │ │ │ │ └── norm.py │ │ │ │ ├── triton_lite_mla.py │ │ │ │ ├── triton_lite_mla_fwd.py │ │ │ │ ├── triton_lite_mla_kernels │ │ │ │ │ ├── custom_autotune.py │ │ │ │ │ ├── linear_relu_fwd.py │ │ │ │ │ ├── mm.py │ │ │ │ │ ├── pad_vk_mm_fwd.py │ │ │ │ │ ├── proj_divide_bwd.py │ │ │ │ │ ├── vk_mm_relu_bwd.py │ │ │ │ │ ├── vk_q_mm_divide_fwd.py │ │ │ │ │ └── vk_q_mm_relu_bwd.py │ │ │ │ ├── triton_mb_conv_pre_glu.py │ │ │ │ ├── triton_mb_conv_pre_glu_kernels │ │ │ │ │ ├── depthwise_conv_fwd.py │ │ │ │ │ └── linear_glu_fwd.py │ │ │ │ └── utils │ │ │ │ │ ├── compare_results.py │ │ │ │ │ ├── custom_autotune.py │ │ │ │ │ ├── dtype.py │ │ │ │ │ ├── export_onnx.py │ │ │ │ │ └── model.py │ │ │ └── readme.md │ │ ├── ladd_blocks.py │ │ ├── sana.py │ │ ├── sana_U_shape.py │ │ ├── sana_U_shape_multi_scale.py │ │ ├── sana_blocks.py │ │ ├── sana_ladd.py │ │ ├── sana_multi_scale.py │ │ ├── sana_multi_scale_adaln.py │ │ ├── sana_multi_scale_controlnet.py │ │ └── sana_others.py │ ├── norms.py │ ├── respace.py │ ├── sa_solver.py │ ├── timestep_sampler.py │ └── utils.py ├── scheduler │ ├── __init__.py │ ├── dpm_solver.py │ ├── flow_euler_sampler.py │ ├── iddpm.py │ ├── lcm_scheduler.py │ ├── sa_sampler.py │ ├── sa_solver_diffusers.py │ ├── scm_scheduler.py │ └── trigflow_scheduler.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── config.py │ ├── data_sampler.py │ ├── dist_utils.py │ ├── import_utils.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── misc.py │ └── optimizer.py ├── environment_setup.sh ├── pyproject.toml ├── sana ├── cli │ ├── run.py │ └── upload2hf.py └── tools │ ├── __init__.py │ ├── download.py │ └── hf_utils.py ├── scripts ├── bash_run_inference_metric.sh ├── bash_run_inference_metric_dpg.sh ├── bash_run_inference_metric_geneval.sh ├── bash_run_inference_metric_imagereward.sh ├── infer_metric_run_inference_metric.sh ├── infer_metric_run_inference_metric_geneval.sh ├── infer_run_inference.sh ├── infer_run_inference_geneval.sh ├── infer_run_inference_geneval_diffusers.sh ├── inference.py ├── inference_dpg.py ├── inference_geneval.py ├── inference_geneval_diffusers.py ├── inference_image_reward.py ├── inference_sana_sprint.py ├── inference_sana_sprint_geneval.py ├── interface.py └── style.css ├── tests └── bash │ ├── entry.sh │ ├── test_inference.sh │ └── test_training_1epoch.sh ├── tools ├── __init__.py ├── controlnet │ ├── annotator │ │ ├── ckpts │ │ │ └── ckpts.txt │ │ ├── hed │ │ │ └── __init__.py │ │ └── util.py │ ├── inference_controlnet.py │ └── utils.py ├── convert_ImgDataset_to_WebDatasetMS_format.py ├── convert_py_to_yaml.py ├── convert_sana_to_diffusers.py ├── convert_sana_to_svdquant.py ├── create_wids_metadata.py ├── download.py ├── 
inference_scaling │ ├── nvila_sana_pick.py │ └── nvila_sana_pick.sh └── metrics │ ├── clip-score │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── clip_score.py │ ├── setup.py │ └── src │ │ └── clip_score │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── clip_score.py │ ├── compute_clipscore.sh │ ├── compute_dpg.sh │ ├── compute_fid_embedding.sh │ ├── compute_geneval.sh │ ├── compute_imagereward.sh │ ├── dpg_bench │ ├── compute_dpg_bench.py │ ├── dpg_bench.csv │ ├── metadata.json │ └── requirements.txt │ ├── geneval │ ├── LICENSE │ ├── README.md │ ├── annotations │ │ ├── annotations_clip.csv │ │ ├── annotations_if-xl.csv │ │ ├── annotations_sdv2.csv │ │ └── mturk_hit_template.html │ ├── environment.yml │ ├── evaluation │ │ ├── download_models.sh │ │ ├── evaluate_images.py │ │ ├── object_names.txt │ │ └── summary_scores.py │ ├── generation │ │ └── diffusers_generate.py │ ├── geneval_env.md │ ├── images │ │ └── geneval_figure_1.png │ └── prompts │ │ ├── create_prompts.py │ │ ├── evaluation_metadata.jsonl │ │ ├── generation_prompts.txt │ │ └── object_names.txt │ ├── image_reward │ ├── benchmark-prompts-dict.json │ └── compute_image_reward.py │ ├── pytorch-fid │ ├── .gitignore │ ├── CHANGELOG.md │ ├── LICENSE │ ├── README.md │ ├── compute_fid.py │ ├── noxfile.py │ ├── setup.cfg │ ├── setup.py │ ├── src │ │ └── pytorch_fid │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ ├── fid_score.py │ │ │ └── inception.py │ └── tests │ │ └── test_fid_score.py │ └── utils.py └── train_scripts ├── train.py ├── train.sh ├── train_dreambooth_lora_sana.py ├── train_lora.sh ├── train_scm_ladd.py └── train_scm_ladd.sh /.github/workflows/bot-autolint.yaml: -------------------------------------------------------------------------------- 1 | name: Auto Lint (triggered by "auto lint" label) 2 | on: 3 | pull_request: 4 | types: 5 | - opened 6 | - edited 7 | - closed 8 | - reopened 9 | - synchronize 10 | - labeled 11 | - unlabeled 12 | # run only one unit test for a branch / tag. 13 | concurrency: 14 | group: ci-lint-${{ github.head_ref || github.ref }} 15 | cancel-in-progress: true 16 | jobs: 17 | lint-by-label: 18 | if: contains(github.event.pull_request.labels.*.name, 'lint wanted') 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Check out Git repository 22 | uses: actions/checkout@v4 23 | with: 24 | token: ${{ secrets.GITHUB_TOKEN }} 25 | ref: ${{ github.event.pull_request.head.ref }} 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | - name: Test pre-commit hooks 31 | continue-on-error: true 32 | uses: pre-commit/action@v3.0.0 # sync with https://github.com/Efficient-Large-Model/VILA-Internal/blob/main/.github/workflows/pre-commit.yaml 33 | with: 34 | extra_args: --all-files 35 | - name: Check if there are any changes 36 | id: verify_diff 37 | run: | 38 | git diff --quiet . || echo "changed=true" >> $GITHUB_OUTPUT 39 | - name: Commit files 40 | if: steps.verify_diff.outputs.changed == 'true' 41 | run: | 42 | git config --local user.email "action@github.com" 43 | git config --local user.name "GitHub Action" 44 | git add . 
45 | git commit -m "[CI-Lint] Fix code style issues with pre-commit ${{ github.sha }}" -a 46 | git push 47 | - name: Remove label(s) after lint 48 | uses: actions-ecosystem/action-remove-labels@v1 49 | with: 50 | labels: lint wanted 51 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | pull_request: 4 | push: 5 | branches: [main, feat/Sana-public, feat/Sana-public-for-NVLab] 6 | concurrency: 7 | group: ci-${{ github.workflow }}-${{ github.ref }} 8 | cancel-in-progress: true 9 | # if: ${{ github.repository == 'Efficient-Large-Model/Sana' }} 10 | jobs: 11 | pre-commit: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out Git repository 15 | uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: 3.10.10 20 | - name: Test pre-commit hooks 21 | uses: pre-commit/action@v3.0.1 22 | tests-bash: 23 | # needs: pre-commit 24 | runs-on: self-hosted 25 | steps: 26 | - name: Check out Git repository 27 | uses: actions/checkout@v4 28 | - name: Set up Python 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: 3.10.10 32 | - name: Set up the environment 33 | run: | 34 | bash environment_setup.sh 35 | - name: Run tests with Slurm 36 | run: | 37 | sana-run --pty -m ci -J tests-bash bash tests/bash/entry.sh 38 | 39 | # tests-python: 40 | # needs: pre-commit 41 | # runs-on: self-hosted 42 | # steps: 43 | # - name: Check out Git repository 44 | # uses: actions/checkout@v4 45 | # - name: Set up Python 46 | # uses: actions/setup-python@v5 47 | # with: 48 | # python-version: 3.10.10 49 | # - name: Set up the environment 50 | # run: | 51 | # ./environment_setup.sh 52 | # - name: Run tests with Slurm 53 | # run: | 54 | # sana-run --pty -m ci -J tests-python pytest tests/python 55 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | name: (Common) Remove trailing whitespaces 7 | - id: mixed-line-ending 8 | name: (Common) Fix mixed line ending 9 | args: [--fix=lf] 10 | - id: end-of-file-fixer 11 | name: (Common) Remove extra EOF newlines 12 | - id: check-merge-conflict 13 | name: (Common) Check for merge conflicts 14 | - id: requirements-txt-fixer 15 | name: (Common) Sort "requirements.txt" 16 | - id: fix-encoding-pragma 17 | name: (Python) Remove encoding pragmas 18 | args: [--remove] 19 | # - id: debug-statements 20 | # name: (Python) Check for debugger imports 21 | - id: check-json 22 | name: (JSON) Check syntax 23 | - id: check-yaml 24 | name: (YAML) Check syntax 25 | - id: check-toml 26 | name: (TOML) Check syntax 27 | # - repo: https://github.com/shellcheck-py/shellcheck-py 28 | # rev: v0.10.0.1 29 | # hooks: 30 | # - id: shellcheck 31 | - repo: https://github.com/google/yamlfmt 32 | rev: v0.13.0 33 | hooks: 34 | - id: yamlfmt 35 | - repo: https://github.com/executablebooks/mdformat 36 | rev: 0.7.16 37 | hooks: 38 | - id: mdformat 39 | name: (Markdown) Format docs with mdformat 40 | - repo: https://github.com/asottile/pyupgrade 41 | rev: v3.2.2 42 | hooks: 43 | - id: pyupgrade 44 | name: (Python) Update syntax for newer versions 45 | args: [--py37-plus] 46 | - repo: https://github.com/psf/black 47 
| rev: 22.10.0 48 | hooks: 49 | - id: black 50 | name: (Python) Format code with black 51 | - repo: https://github.com/pycqa/isort 52 | rev: 5.12.0 53 | hooks: 54 | - id: isort 55 | name: (Python) Sort imports with isort 56 | - repo: https://github.com/pre-commit/mirrors-clang-format 57 | rev: v15.0.4 58 | hooks: 59 | - id: clang-format 60 | name: (C/C++/CUDA) Format code with clang-format 61 | args: [-style=google, -i] 62 | types_or: [c, c++, cuda] 63 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: 'SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer' 3 | message: >- 4 | If you use this software or research, please cite it using the 5 | metadata from this file. 6 | type: misc 7 | authors: 8 | - given-names: Enze 9 | family-names: Xie 10 | - given-names: Junsong 11 | family-names: Chen 12 | - given-names: Junyu 13 | family-names: Chen 14 | - given-names: Han 15 | family-names: Cai 16 | - given-names: Haotian 17 | family-names: Tang 18 | - given-names: Yujun 19 | family-names: Lin 20 | - given-names: Zhekai 21 | family-names: Zhang 22 | - given-names: Muyang 23 | family-names: Li 24 | - given-names: Ligeng 25 | family-names: Zhu 26 | - given-names: Yao 27 | family-names: Lu 28 | - given-names: Song 29 | family-names: Han 30 | repository-code: 'https://github.com/NVlabs/Sana' 31 | abstract: >- 32 | SANA proposes an efficient linear Diffusion Transformer (DiT) for high-resolution 33 | image synthesis, featuring a depth-growth paradigm, model pruning techniques, 34 | and inference-time scaling strategies to reduce training costs while maintaining 35 | generation quality. 
SANA-Sprint also achieves one-step generation of high-resolution images. 36 | keywords: 37 | - deep-learning 38 | - diffusion-models 39 | - transformer 40 | - image-generation 41 | - text-to-image 42 | - efficient-training 43 | - distillation 44 | license: Apache-2.0 45 | version: 1.5.0 46 | doi: 10.48550/arXiv.2410.10629 47 | date-released: 2024-10-16 48 | -------------------------------------------------------------------------------- /CIs/add_license_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "**/*__init__.py" **/*.py 3 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.06-py3 2 | 3 | ENV PATH=/opt/conda/bin:$PATH 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | libgl1-mesa-glx \ 7 | libglib2.0-0 \ 8 | && rm -rf /var/lib/apt/lists/* 9 | 10 | WORKDIR /app 11 | 12 | RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \ 13 | && sh ~/miniconda.sh -b -p /opt/conda \ 14 | && rm ~/miniconda.sh 15 | 16 | COPY pyproject.toml pyproject.toml 17 | COPY diffusion diffusion 18 | COPY configs configs 19 | COPY sana sana 20 | COPY app app 21 | COPY tools tools 22 | 23 | COPY environment_setup.sh environment_setup.sh 24 | RUN ./environment_setup.sh 25 | 26 | CMD ["python", "-u", "-W", "ignore", "app/app_sana.py", "--share", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"] 27 | -------------------------------------------------------------------------------- /app/safety_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # This file is copied from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py 17 | import json 18 | from functools import lru_cache 19 | from typing import Dict, Sequence 20 | 21 | import torch 22 | import transformers 23 | from torch.nn.functional import softmax 24 | from transformers import AutoModelForCausalLM, AutoTokenizer 25 | 26 | __all__ = ["is_dangerous"] 27 | 28 | __rank, __local_rank, __world_size, __device = ( 29 | 0, 30 | 0, 31 | 1, 32 | "cuda" if torch.cuda.is_available() else "cpu", 33 | ) 34 | 35 | 36 | safety_policy = """ 37 | The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
38 | """ 39 | 40 | 41 | def is_dangerous(tokenizer, model, user_prompt, threshold=0.5): 42 | prompt = f"""You are a policy expert trying to help determine whether a user 43 | prompt is in violation of the defined safety policies. 44 | 45 | 46 | Human Question: {user_prompt.strip()} 47 | 48 | 49 | Our safety principle is defined in the below: 50 | 51 | {safety_policy.strip()} 52 | 53 | Does the human question violate the above principle? Your answer must start 54 | with 'Yes' or 'No'. And then walk through step by step to be sure we answer 55 | correctly. 56 | """ 57 | 58 | inputs = tokenizer(prompt, return_tensors="pt").to("cuda") 59 | with torch.no_grad(): 60 | logits = model(**inputs).logits 61 | 62 | # Extract the logits for the Yes and No tokens 63 | vocab = tokenizer.get_vocab() 64 | selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]] 65 | 66 | # Convert these logits to a probability with softmax 67 | probabilities = softmax(selected_logits, dim=0) 68 | 69 | # Return probability of 'Yes' 70 | score = probabilities[0].item() 71 | 72 | return score > threshold 73 | -------------------------------------------------------------------------------- /asset/Sana.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/Sana.jpg -------------------------------------------------------------------------------- /asset/app_styles/controlnet_app_style.css: -------------------------------------------------------------------------------- 1 | @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css'); 2 | 3 | body{align-items: center;} 4 | .gradio-container{max-width: 1200px !important} 5 | h1{text-align:center} 6 | 7 | .wrap.svelte-p4aq0j.svelte-p4aq0j { 8 | display: none; 9 | } 10 | 11 | #column_input, #column_output { 12 | width: 500px; 13 | display: flex; 14 | align-items: center; 15 | } 16 | 17 | #input_header, #output_header { 18 | display: flex; 19 | justify-content: center; 20 | align-items: center; 21 | width: 400px; 22 | } 23 | 24 | #accessibility { 25 | text-align: center; /* Center-aligns the text */ 26 | margin: auto; /* Centers the element horizontally */ 27 | } 28 | 29 | #random_seed {height: 71px;} 30 | #run_button {height: 87px;} 31 | -------------------------------------------------------------------------------- /asset/apple.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/apple.webp -------------------------------------------------------------------------------- /asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. 
The sculpture is in front of a painting of a la.jpg -------------------------------------------------------------------------------- /asset/controlnet/ref_images/a house.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/controlnet/ref_images/a house.png -------------------------------------------------------------------------------- /asset/controlnet/ref_images/a living room.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/controlnet/ref_images/a living room.png -------------------------------------------------------------------------------- /asset/controlnet/ref_images/nvidia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/controlnet/ref_images/nvidia.png -------------------------------------------------------------------------------- /asset/controlnet/samples_controlnet.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "prompt": "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.", 4 | "ref_image_path": "asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg" 5 | }, 6 | { 7 | "prompt": "an architecture in INDIA,15th-18th style, with a lot of details", 8 | "ref_image_path": "asset/controlnet/ref_images/a house.png" 9 | }, 10 | { 11 | "prompt": "An IKEA modern style living room with sofa, coffee table, stairs, etc., a brand new theme.", 12 | "ref_image_path": "asset/controlnet/ref_images/a living room.png" 13 | }, 14 | { 15 | "prompt": "A modern new living room with sofa, coffee table, carpet, stairs, etc., high quality high detail, high resolution.", 16 | "ref_image_path": "asset/controlnet/ref_images/a living room.png" 17 | }, 18 | { 19 | "prompt": "big eye, vibrant colors, intricate details, captivating gaze, surreal, dreamlike, fantasy, enchanting, mysterious, magical, moonlit, mystical, ethereal, enchanting {macro lens, high aperture, low ISO}", 20 | "ref_image_path": "asset/controlnet/ref_images/nvidia.png" 21 | }, 22 | { 23 | "prompt": "shining eye, bright and vivid colors, radiant glow, sparkling reflections, joyful, uplifting, optimistic, hopeful, magical, luminous, celestial, dreamy {zoom lens, high aperture, natural light, vibrant color film}", 24 | "ref_image_path": "asset/controlnet/ref_images/nvidia.png" 25 | } 26 | ] 27 | -------------------------------------------------------------------------------- /asset/docs/ComfyUI/comfyui.md: -------------------------------------------------------------------------------- 1 | ## 🖌️ Sana-ComfyUI 2 | 3 | [Original Repo](https://github.com/city96/ComfyUI_ExtraModels) 4 | 5 | ### Model info / implementation 6 | 7 | - Uses Gemma2 2B as the text encoder 8 | - Multiple resolutions and models available 9 | - Compressed latent space (32 channels, /32 compression) - needs custom VAE 10 | 11 | ### Usage 12 | 13 | 1. All the checkpoints will be downloaded automatically. 14 | 1. KSampler(Flow Euler) is available for now; Flow DPM-Solver will be available soon. 
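To set things up, clone ComfyUI, place the ExtraModels custom node under `custom_nodes/`, and launch the server with the commands below: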
15 | 16 | ```bash 17 | git clone https://github.com/comfyanonymous/ComfyUI.git 18 | cd ComfyUI 19 | git clone https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels.git custom_nodes/ComfyUI_ExtraModels 20 | 21 | python main.py 22 | ``` 23 | 24 | ### A sample workflow for Sana 25 | 26 | [Sana workflow](Sana_FlowEuler.json) 27 | 28 | ![Sana](https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/comfyui/sana.jpg) 29 | 30 | ### A sample for T2I(Sana) + I2V(CogVideoX) 31 | 32 | [Sana + CogVideoX workflow](Sana_CogVideoX.json) 33 | 34 | [![Sample T2I + I2V](https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/comfyui/sana-cogvideox.jpg)](https://nvlabs.github.io/Sana/asset/content/comfyui/Sana_CogVideoX_Fun.mp4) 35 | 36 | ### A sample workflow for Sana 4096x4096 image (18GB GPU is needed) 37 | 38 | [Sana workflow](Sana_FlowEuler_4K.json) 39 | 40 | ![Sana](https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/page/asset/content/comfyui/Sana_4K_workflow.jpg) 41 | -------------------------------------------------------------------------------- /asset/docs/inference_scaling/inference_scaling.md: -------------------------------------------------------------------------------- 1 | ## Inference-Time Scaling for SANA-1.5 2 | 3 | ![results](results.jpg) 4 | 5 | We trained a specialized [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) model to score images, which we named VISA (VIla as SAna verifier). By selecting the top 4 images from 2,048 candidates, we enhanced the GenEval performance of SD1.5 and SANA-1.5-4.8B v2, increasing their scores from 42 to 87 and from 81 to 96, respectively. 6 | 7 | ![curve](scaling_curve.jpg) 8 | 9 | Even with a smaller number of candidates, such as 32, we can still push the performance above 90% for SANA-1.5-4.8B v2 on GenEval. 10 | 11 | ### Environment Requirement 12 | 13 | Dependency setup: 14 | 15 | ```bash 16 | # other transformers versions may also work, but we have not tested them 17 | pip install transformers==4.46 18 | pip install git+https://github.com/bfshi/scaling_on_scales.git 19 | ``` 20 | 21 | ### 1. Generate N images with a .pth checkpoint for the selection step 22 | 23 | ```bash 24 | # download the checkpoint for the following generation 25 | huggingface-cli download Efficient-Large-Model/Sana_600M_512px --repo-type model --local-dir output/Sana_600M_512px --local-dir-use-symlinks False 26 | # 32 is a relatively small number for testing but can already push GenEval above 90% when verifying the SANA-1.5-4.8B v2 model. Set it to a larger number such as 2048 to push the limit. 27 | n_samples=32 28 | pick_number=4 29 | 30 | output_dir=output/geneval_generated_path 31 | # example 32 | bash scripts/infer_run_inference_geneval.sh \ 33 | configs/sana_config/512ms/Sana_600M_img512.yaml \ 34 | output/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth \ 35 | --img_nums_per_sample=$n_samples \ 36 | --output_dir=$output_dir 37 | ``` 38 | 39 | ### 2. Use NVILA-Verifier to select from the generated images 40 | 41 | ```bash 42 | bash tools/inference_scaling/nvila_sana_pick.sh \ 43 | $output_dir \ 44 | $n_samples \ 45 | $pick_number 46 | ``` 47 | 48 | ### 3. Calculate the GenEval metric 49 | 50 | You need the GenEval environment for the final evaluation; the installation instructions can be found [here](../../../tools/metrics/geneval/geneval_env.md).
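For intuition, the selection performed in step 2 above is conceptually a best-of-N filter: score every candidate image and keep only the top picks. Below is a minimal sketch; `score_image` and `pick_best` are hypothetical names, and the actual logic lives in `tools/inference_scaling/nvila_sana_pick.py`.

```python
# Minimal best-of-N selection sketch (illustrative; `score_image` is a
# hypothetical stand-in for the NVILA-Verifier).
import shutil
from pathlib import Path


def pick_best(sample_dir: str, out_dir: str, score_image, top_k: int = 4) -> None:
    """Keep the `top_k` highest-scoring candidate images for one prompt."""
    candidates = sorted(Path(sample_dir).glob("*.png"))
    ranked = sorted(candidates, key=score_image, reverse=True)  # best first
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for img in ranked[:top_k]:
        shutil.copy(img, out_dir)  # the selected images go on to GenEval scoring
```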
51 | 52 | ```bash 53 | # activate geneval env 54 | conda activate geneval 55 | 56 | DIR_AFTER_PICK="output/nvila_pick/best_${pick_number}_of_${n_samples}/${output_dir}" 57 | 58 | bash tools/metrics/compute_geneval.sh $(dirname "$DIR_AFTER_PICK") $(basename "$DIR_AFTER_PICK") 59 | ``` 60 | -------------------------------------------------------------------------------- /asset/docs/inference_scaling/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/docs/inference_scaling/results.jpg -------------------------------------------------------------------------------- /asset/docs/inference_scaling/scaling_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/docs/inference_scaling/scaling_curve.jpg -------------------------------------------------------------------------------- /asset/docs/quantize/4bit_sana.md: -------------------------------------------------------------------------------- 1 | 15 | 16 | # 4bit SanaPipeline 17 | 18 | ### 1. Environment setup 19 | 20 | Follow the official [SVDQuant-Nunchaku](https://github.com/mit-han-lab/nunchaku) repository to set up the environment. The guidance can be found [here](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation). 21 | 22 | ### 1-1. Quantize Sana with SVDQuant-4bit (Optional) 23 | 24 | 1. Convert the .pth checkpoint to the safetensors format required by SVDQuant 25 | 26 | ``` 27 | python tools/convert_sana_to_svdquant.py \ 28 | --orig_ckpt_path Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth \ 29 | --model_type SanaMS1.5_1600M_P1_D20 \ 30 | --dtype bf16 \ 31 | --dump_path output/SANA1.5_1.6B_1024px_svdquant_diffusers \ 32 | --save_full_pipeline 33 | ``` 34 | 35 | 2. Follow the guidance to compress the model: 36 | [Quantization guidance](https://github.com/mit-han-lab/deepcompressor/tree/main/examples/diffusion) 37 | 38 | ### 2. Code snippet for inference 39 | 40 | Here we show the code snippet for SanaPipeline. For SanaPAGPipeline, please refer to the [SanaPAGPipeline](https://github.com/mit-han-lab/nunchaku/blob/main/examples/sana_1600m_pag.py) section. 41 | 42 | ```python 43 | import torch 44 | from diffusers import SanaPipeline 45 | 46 | from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel 47 | 48 | transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m") 49 | pipe = SanaPipeline.from_pretrained( 50 | "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", 51 | transformer=transformer, 52 | variant="bf16", 53 | torch_dtype=torch.bfloat16, 54 | ).to("cuda") 55 | 56 | pipe.text_encoder.to(torch.bfloat16) 57 | pipe.vae.to(torch.bfloat16) 58 | 59 | image = pipe( 60 | prompt="A cute 🐼 eating 🎋, ink drawing style", 61 | height=1024, 62 | width=1024, 63 | guidance_scale=4.5, 64 | num_inference_steps=20, 65 | generator=torch.Generator().manual_seed(42), 66 | ).images[0] 67 | image.save("sana_1600m.png") 68 | ``` 69 | 70 | ### 3. Online demo 71 | 72 | 1). Launch the 4bit Sana demo. 73 | 74 | ```bash 75 | python app/app_sana_4bit.py 76 | ``` 77 | 78 | 2).
Compare with the BF16 version 79 | 80 | Refer to the original [Nunchaku-Sana](https://github.com/mit-han-lab/nunchaku/tree/main/app/sana/t2i) guidance for the SanaPAGPipeline. 81 | 82 | ```bash 83 | python app/app_sana_4bit_compare_bf16.py 84 | ``` 85 | -------------------------------------------------------------------------------- /asset/docs/sana_controlnet.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | ## 🔥 ControlNet 18 | 19 | We incorporate a [ControlNet](https://github.com/lllyasviel/ControlNet)-like module that enables fine-grained control over text-to-image diffusion models. We implement a ControlNet-Transformer architecture, specifically tailored for Transformers, achieving explicit controllability alongside high-quality image generation. 20 | 21 |
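Conceptually, a ControlNet-Transformer freezes the pretrained blocks, attaches a trainable copy of each one, and injects the copy's output through a zero-initialized projection, so training starts exactly at the pretrained model. The sketch below is illustrative only, with simplified, hypothetical interfaces; the repository's actual implementation lives in `diffusion/model/nets/sana_multi_scale_controlnet.py`.

```python
# Illustrative ControlNet-style conditioning for one transformer block.
import copy

import torch.nn as nn


class ControlledBlock(nn.Module):
    def __init__(self, base_block: nn.Module, hidden_dim: int):
        super().__init__()
        self.base_block = base_block  # frozen pretrained block
        self.control_block = copy.deepcopy(base_block)  # trainable copy
        self.zero_linear = nn.Linear(hidden_dim, hidden_dim)
        nn.init.zeros_(self.zero_linear.weight)  # zero init: at step 0 the
        nn.init.zeros_(self.zero_linear.bias)    # output equals the base model

    def forward(self, x, control):
        # The control signal (e.g., HED edge features) steers the trainable copy;
        # its contribution is added back through the zero-initialized linear.
        return self.base_block(x) + self.zero_linear(self.control_block(x + control))
```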
22 | 23 |
24 | 25 | ## Inference of `Sana + ControlNet` 26 | 27 | ### 1). Gradio Interface 28 | 29 | ```bash 30 | python app/app_sana_controlnet_hed.py \ 31 | --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml \ 32 | --model_path hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth 33 | ``` 34 | 35 |
36 | (figure: teaser_page2) 37 |
38 | 39 | ### 2). Inference with JSON file 40 | 41 | ```bash 42 | python tools/controlnet/inference_controlnet.py \ 43 | --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml \ 44 | --model_path hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth \ 45 | --json_file asset/controlnet/samples_controlnet.json 46 | ``` 47 | 48 | ### 3). Inference code snippet 49 | 50 | ```python 51 | import torch 52 | from PIL import Image 53 | from app.sana_controlnet_pipeline import SanaControlNetPipeline 54 | 55 | device = "cuda" if torch.cuda.is_available() else "cpu" 56 | 57 | pipe = SanaControlNetPipeline("configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml") 58 | pipe.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED/checkpoints/Sana_1600M_1024px_BF16_ControlNet_HED.pth") 59 | 60 | ref_image = Image.open("asset/controlnet/ref_images/A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a la.jpg") 61 | prompt = "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape." 62 | 63 | images = pipe( 64 | prompt=prompt, 65 | ref_image=ref_image, 66 | guidance_scale=4.5, 67 | num_inference_steps=10, 68 | sketch_thickness=2, 69 | generator=torch.Generator(device=device).manual_seed(0), 70 | ) 71 | ``` 72 | 73 | ## Training of `Sana + ControlNet` 74 | 75 | ### Coming soon 76 | -------------------------------------------------------------------------------- /asset/example_data/00000000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/example_data/00000000.jpg -------------------------------------------------------------------------------- /asset/example_data/00000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/example_data/00000000.png -------------------------------------------------------------------------------- /asset/example_data/00000000.txt: -------------------------------------------------------------------------------- 1 | a cyberpunk cat with a neon sign that says "Sana".
2 | -------------------------------------------------------------------------------- /asset/example_data/00000000_InternVL2-26B.json: -------------------------------------------------------------------------------- 1 | { 2 | "00000000": { 3 | "InternVL2-26B": "a cyberpunk cat with a neon sign that says 'Sana'" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /asset/example_data/00000000_InternVL2-26B_clip_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "00000000": { 3 | "InternVL2-26B": "27.1037" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /asset/example_data/00000000_VILA1-5-13B.json: -------------------------------------------------------------------------------- 1 | { 2 | "00000000": { 3 | "VILA1-5-13B": "a cyberpunk cat with a neon sign that says 'Sana'" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /asset/example_data/00000000_VILA1-5-13B_clip_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "00000000": { 3 | "VILA1-5-13B": "27.2321" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /asset/example_data/00000000_prompt_clip_score.json: -------------------------------------------------------------------------------- 1 | { 2 | "00000000": { 3 | "prompt": "26.7331" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /asset/example_data/meta_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sana-dev", 3 | "__kind__": "Sana-ImgDataset", 4 | "img_names": [ 5 | "00000000", "00000000", "00000000.png", "00000000.jpg" 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /asset/examples.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | examples = [ 18 | [ 19 | "A small cactus with a happy face in the Sahara desert.", 20 | "flow_dpm-solver", 21 | 20, 22 | 5.0, 23 | 2.5, 24 | ], 25 | [ 26 | "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history" 27 | "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits " 28 | "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret " 29 | "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile " 30 | "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and " 31 | "the Parisian streets and city in the background, depth of field, cinematic 35mm film.", 32 | "flow_dpm-solver", 33 | 20, 34 | 5.0, 35 | 2.5, 36 | ], 37 | [ 38 | "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. " 39 | "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. " 40 | "The quote 'Find the universe within you' is etched in bold letters across the horizon." 41 | "blue and pink, brilliantly illuminated in the background.", 42 | "flow_dpm-solver", 43 | 20, 44 | 5.0, 45 | 2.5, 46 | ], 47 | [ 48 | "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.", 49 | "flow_dpm-solver", 50 | 20, 51 | 5.0, 52 | 2.5, 53 | ], 54 | [ 55 | "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.", 56 | "flow_dpm-solver", 57 | 20, 58 | 5.0, 59 | 2.5, 60 | ], 61 | [ 62 | "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, " 63 | "national geographic photo, 8k resolution, crayon art, interactive artwork", 64 | "flow_dpm-solver", 65 | 20, 66 | 5.0, 67 | 2.5, 68 | ], 69 | ] 70 | -------------------------------------------------------------------------------- /asset/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/logo.png -------------------------------------------------------------------------------- /asset/mit-logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/mit-logo.jpg -------------------------------------------------------------------------------- /asset/model-incremental.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/asset/model-incremental.jpg -------------------------------------------------------------------------------- /asset/model_paths.txt: -------------------------------------------------------------------------------- 1 | output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth 2 | output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth 3 | -------------------------------------------------------------------------------- /asset/samples/samples_mini.txt: -------------------------------------------------------------------------------- 1 | A cyberpunk cat with a neon sign that says 'Sana'. 2 | A small cactus with a happy face in the Sahara desert. 3 | The towel was on top of the hard counter. 
4 | A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds. 5 | I want to supplement vitamin c, please help me paint related food. 6 | A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape. 7 | an old rusted robot wearing pants and a jacket riding skis in a supermarket. 8 | professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest. 9 | Astronaut in a jungle, cold color palette, muted colors, detailed 10 | a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests. 11 | -------------------------------------------------------------------------------- /configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/toy_data] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: [] 7 | external_clipscore_suffixes: [] 8 | clip_thr_temperature: 0.1 9 | clip_thr: 25.0 10 | del_img_clip_thr: 22.0 11 | load_text_feat: false 12 | load_vae_feat: true 13 | transform: default_train 14 | type: SanaWebDatasetMS 15 | sort_dataset: false 16 | # model config 17 | model: 18 | model: SanaMS_1600M_P1_D20 19 | image_size: 1024 20 | mixed_precision: bf16 21 | fp32_attention: true 22 | load_from: hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth 23 | aspect_ratio_type: ASPECT_RATIO_1024 24 | multi_scale: true 25 | attn_type: linear 26 | ffn_type: glumbconv 27 | mlp_acts: 28 | - silu 29 | - silu 30 | - 31 | mlp_ratio: 2.5 32 | use_pe: false 33 | qk_norm: true 34 | cross_norm: true 35 | class_dropout_prob: 0.1 36 | # VAE setting 37 | vae: 38 | vae_type: AutoencoderDC 39 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 40 | scale_factor: 0.41407 41 | vae_latent_dim: 32 42 | vae_downsample_rate: 32 43 | sample_posterior: true 44 | # text encoder 45 | text_encoder: 46 | text_encoder_name: gemma-2-2b-it 47 | y_norm: true 48 | y_norm_scale_factor: 0.01 49 | model_max_length: 300 50 | # CHI 51 | chi_prompt: 52 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 53 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 54 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 55 | - 'Here are examples of how to transform or refine prompts:' 56 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 57 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 
58 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 59 | - 'User Prompt: ' 60 | # Sana schedule Flow 61 | scheduler: 62 | predict_flow_v: true 63 | noise_schedule: linear_flow 64 | pred_sigma: false 65 | flow_shift: 3.0 66 | # logit-normal timestep 67 | weighting_scheme: logit_normal 68 | logit_mean: 0.0 69 | logit_std: 1.0 70 | vis_sampler: flow_dpm-solver 71 | # training setting 72 | train: 73 | use_fsdp: true 74 | num_workers: 10 75 | seed: 1 76 | train_batch_size: 32 77 | num_epochs: 100 78 | gradient_accumulation_steps: 1 79 | grad_checkpointing: true 80 | gradient_clip: 0.1 81 | optimizer: 82 | betas: 83 | - 0.9 84 | - 0.999 85 | eps: 1.0e-10 86 | lr: 2.0e-5 87 | type: AdamW 88 | weight_decay: 0.0 89 | lr_schedule: constant 90 | lr_schedule_args: 91 | num_warmup_steps: 1000 92 | local_save_vis: true # if save log image locally 93 | visualize: true 94 | eval_sampling_steps: 500 95 | log_interval: 20 96 | save_model_epochs: 5 97 | save_model_steps: 500 98 | work_dir: output/debug 99 | online_metric: false 100 | eval_metric_step: 2000 101 | online_metric_dir: metric_helper 102 | -------------------------------------------------------------------------------- /configs/sana_app_config/Sana_1600M_app.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: [] 7 | external_clipscore_suffixes: [] 8 | clip_thr_temperature: 0.1 9 | clip_thr: 25.0 10 | load_text_feat: false 11 | load_vae_feat: false 12 | transform: default_train 13 | type: SanaWebDatasetMS 14 | data: 15 | sort_dataset: false 16 | # model config 17 | model: 18 | model: SanaMS_1600M_P1_D20 19 | image_size: 1024 20 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16'] 21 | fp32_attention: true 22 | load_from: 23 | resume_from: 24 | aspect_ratio_type: ASPECT_RATIO_1024 25 | multi_scale: true 26 | #pe_interpolation: 1. 27 | attn_type: linear 28 | ffn_type: glumbconv 29 | mlp_acts: 30 | - silu 31 | - silu 32 | - 33 | mlp_ratio: 2.5 34 | use_pe: false 35 | qk_norm: false 36 | class_dropout_prob: 0.1 37 | # CFG & PAG settings 38 | pag_applied_layers: 39 | - 8 40 | # VAE setting 41 | vae: 42 | vae_type: AutoencoderDC 43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 44 | scale_factor: 0.41407 45 | vae_latent_dim: 32 46 | vae_downsample_rate: 32 47 | sample_posterior: true 48 | # text encoder 49 | text_encoder: 50 | text_encoder_name: gemma-2-2b-it 51 | y_norm: true 52 | y_norm_scale_factor: 0.01 53 | model_max_length: 300 54 | # CHI 55 | chi_prompt: 56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 59 | - 'Here are examples of how to transform or refine prompts:' 60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 
61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 63 | - 'User Prompt: ' 64 | # Sana schedule Flow 65 | scheduler: 66 | predict_flow_v: true 67 | noise_schedule: linear_flow 68 | pred_sigma: false 69 | flow_shift: 3.0 70 | # logit-normal timestep 71 | weighting_scheme: logit_normal 72 | logit_mean: 0.0 73 | logit_std: 1.0 74 | vis_sampler: flow_dpm-solver 75 | # training setting 76 | train: 77 | num_workers: 10 78 | seed: 1 79 | train_batch_size: 64 80 | num_epochs: 100 81 | gradient_accumulation_steps: 1 82 | grad_checkpointing: true 83 | gradient_clip: 0.1 84 | optimizer: 85 | betas: 86 | - 0.9 87 | - 0.999 88 | - 0.9999 89 | eps: 90 | - 1.0e-30 91 | - 1.0e-16 92 | lr: 0.0001 93 | type: CAMEWrapper 94 | weight_decay: 0.0 95 | lr_schedule: constant 96 | lr_schedule_args: 97 | num_warmup_steps: 2000 98 | local_save_vis: true # if save log image locally 99 | visualize: true 100 | eval_sampling_steps: 500 101 | log_interval: 20 102 | save_model_epochs: 5 103 | save_model_steps: 500 104 | work_dir: output/debug 105 | online_metric: false 106 | eval_metric_step: 2000 107 | online_metric_dir: metric_helper 108 | -------------------------------------------------------------------------------- /configs/sana_app_config/Sana_600M_app.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: [] 7 | external_clipscore_suffixes: [] 8 | clip_thr_temperature: 0.1 9 | clip_thr: 25.0 10 | load_text_feat: false 11 | load_vae_feat: true 12 | transform: default_train 13 | type: SanaWebDatasetMS 14 | sort_dataset: false 15 | # model config 16 | model: 17 | model: SanaMS_600M_P1_D28 18 | image_size: 1024 19 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16'] 20 | fp32_attention: true 21 | load_from: 22 | resume_from: 23 | aspect_ratio_type: ASPECT_RATIO_1024 24 | multi_scale: true 25 | attn_type: linear 26 | ffn_type: glumbconv 27 | mlp_acts: 28 | - silu 29 | - silu 30 | - 31 | mlp_ratio: 2.5 32 | use_pe: false 33 | qk_norm: false 34 | class_dropout_prob: 0.1 35 | # CFG & PAG settings 36 | pag_applied_layers: 37 | - 14 38 | # VAE setting 39 | vae: 40 | vae_type: AutoencoderDC 41 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 42 | scale_factor: 0.41407 43 | vae_latent_dim: 32 44 | vae_downsample_rate: 32 45 | sample_posterior: true 46 | # text encoder 47 | text_encoder: 48 | text_encoder_name: gemma-2-2b-it 49 | y_norm: true 50 | y_norm_scale_factor: 0.01 51 | model_max_length: 300 52 | # CHI 53 | chi_prompt: 54 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 55 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 56 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 
57 | - 'Here are examples of how to transform or refine prompts:' 58 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 59 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 60 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 61 | - 'User Prompt: ' 62 | # Sana schedule Flow 63 | scheduler: 64 | predict_flow_v: true 65 | noise_schedule: linear_flow 66 | pred_sigma: false 67 | flow_shift: 4.0 68 | # logit-normal timestep 69 | weighting_scheme: logit_normal 70 | logit_mean: 0.0 71 | logit_std: 1.0 72 | vis_sampler: flow_dpm-solver 73 | # training setting 74 | train: 75 | num_workers: 10 76 | seed: 1 77 | train_batch_size: 64 78 | num_epochs: 100 79 | gradient_accumulation_steps: 1 80 | grad_checkpointing: true 81 | gradient_clip: 0.1 82 | optimizer: 83 | betas: 84 | - 0.9 85 | - 0.999 86 | - 0.9999 87 | eps: 88 | - 1.0e-30 89 | - 1.0e-16 90 | lr: 0.0001 91 | type: CAMEWrapper 92 | weight_decay: 0.0 93 | lr_schedule: constant 94 | lr_schedule_args: 95 | num_warmup_steps: 2000 96 | local_save_vis: true # if save log image locally 97 | visualize: true 98 | eval_sampling_steps: 500 99 | log_interval: 20 100 | save_model_epochs: 5 101 | save_model_steps: 500 102 | work_dir: output/debug 103 | online_metric: false 104 | eval_metric_step: 2000 105 | online_metric_dir: metric_helper 106 | -------------------------------------------------------------------------------- /configs/sana_config/1024ms/Sana_1600M_img1024.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/toy_data] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] 7 | external_clipscore_suffixes: 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: default_train 16 | type: SanaWebDatasetMS 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_1600M_P1_D20 21 | image_size: 1024 22 | mixed_precision: bf16 # ['fp16', 'fp32', 'bf16'] 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_1024 27 | multi_scale: true 28 | #pe_interpolation: 1. 29 | attn_type: linear 30 | ffn_type: glumbconv 31 | mlp_acts: 32 | - silu 33 | - silu 34 | - 35 | mlp_ratio: 2.5 36 | use_pe: false 37 | qk_norm: false 38 | class_dropout_prob: 0.1 39 | # PAG 40 | pag_applied_layers: 41 | - 8 42 | # VAE setting 43 | vae: 44 | vae_type: AutoencoderDC 45 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 46 | scale_factor: 0.41407 47 | vae_latent_dim: 32 48 | vae_downsample_rate: 32 49 | sample_posterior: true 50 | # text encoder 51 | text_encoder: 52 | text_encoder_name: gemma-2-2b-it 53 | y_norm: true 54 | y_norm_scale_factor: 0.01 55 | model_max_length: 300 56 | # CHI 57 | chi_prompt: 58 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. 
Evaluate the level of detail in the user prompt:' 59 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 60 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 61 | - 'Here are examples of how to transform or refine prompts:' 62 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 63 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 64 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 65 | - 'User Prompt: ' 66 | # Sana schedule Flow 67 | scheduler: 68 | predict_flow_v: true 69 | noise_schedule: linear_flow 70 | pred_sigma: false 71 | flow_shift: 3.0 72 | # logit-normal timestep 73 | weighting_scheme: logit_normal 74 | logit_mean: 0.0 75 | logit_std: 1.0 76 | vis_sampler: flow_dpm-solver 77 | # training setting 78 | train: 79 | num_workers: 10 80 | seed: 1 81 | train_batch_size: 64 82 | num_epochs: 100 83 | gradient_accumulation_steps: 1 84 | grad_checkpointing: true 85 | gradient_clip: 0.1 86 | optimizer: 87 | betas: 88 | - 0.9 89 | - 0.999 90 | - 0.9999 91 | eps: 92 | - 1.0e-30 93 | - 1.0e-16 94 | lr: 0.0001 95 | type: CAMEWrapper 96 | weight_decay: 0.0 97 | lr_schedule: constant 98 | lr_schedule_args: 99 | num_warmup_steps: 2000 100 | local_save_vis: true # if save log image locally 101 | visualize: true 102 | eval_sampling_steps: 500 103 | log_interval: 20 104 | save_model_epochs: 5 105 | save_model_steps: 500 106 | work_dir: output/debug 107 | online_metric: false 108 | eval_metric_step: 2000 109 | online_metric_dir: metric_helper 110 | -------------------------------------------------------------------------------- /configs/sana_config/1024ms/Sana_1600M_img1024_AdamW.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/toy_data] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] 7 | external_clipscore_suffixes: 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: default_train 16 | type: SanaWebDatasetMS 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_1600M_P1_D20 21 | image_size: 1024 22 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16'] 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_1024 27 | multi_scale: true 28 | #pe_interpolation: 1. 
29 | attn_type: linear 30 | ffn_type: glumbconv 31 | mlp_acts: 32 | - silu 33 | - silu 34 | - 35 | mlp_ratio: 2.5 36 | use_pe: false 37 | qk_norm: false 38 | class_dropout_prob: 0.1 39 | # PAG 40 | pag_applied_layers: 41 | - 8 42 | # VAE setting 43 | vae: 44 | vae_type: AutoencoderDC 45 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 46 | scale_factor: 0.41407 47 | vae_latent_dim: 32 48 | vae_downsample_rate: 32 49 | sample_posterior: true 50 | # text encoder 51 | text_encoder: 52 | text_encoder_name: gemma-2-2b-it 53 | y_norm: true 54 | y_norm_scale_factor: 0.01 55 | model_max_length: 300 56 | # CHI 57 | chi_prompt: 58 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 59 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 60 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 61 | - 'Here are examples of how to transform or refine prompts:' 62 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 63 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 64 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 65 | - 'User Prompt: ' 66 | # Sana schedule Flow 67 | scheduler: 68 | predict_flow_v: true 69 | noise_schedule: linear_flow 70 | pred_sigma: false 71 | flow_shift: 3.0 72 | # logit-normal timestep 73 | weighting_scheme: logit_normal 74 | logit_mean: 0.0 75 | logit_std: 1.0 76 | vis_sampler: flow_dpm-solver 77 | # training setting 78 | train: 79 | num_workers: 10 80 | seed: 1 81 | train_batch_size: 64 82 | num_epochs: 100 83 | gradient_accumulation_steps: 1 84 | grad_checkpointing: true 85 | gradient_clip: 0.1 86 | optimizer: 87 | lr: 1.0e-4 88 | type: AdamW 89 | weight_decay: 0.01 90 | eps: 1.0e-8 91 | betas: [0.9, 0.999] 92 | lr_schedule: constant 93 | lr_schedule_args: 94 | num_warmup_steps: 2000 95 | local_save_vis: true # if save log image locally 96 | visualize: true 97 | eval_sampling_steps: 500 98 | log_interval: 20 99 | save_model_epochs: 5 100 | save_model_steps: 500 101 | work_dir: output/debug 102 | online_metric: false 103 | eval_metric_step: 2000 104 | online_metric_dir: metric_helper 105 | -------------------------------------------------------------------------------- /configs/sana_config/1024ms/Sana_600M_img1024.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/toy_data] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] 7 | external_clipscore_suffixes: 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: default_train 16 | type: SanaWebDatasetMS 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_600M_P1_D28 21 | image_size: 1024 22 | 
mixed_precision: fp16 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_1024 27 | multi_scale: true 28 | attn_type: linear 29 | ffn_type: glumbconv 30 | mlp_acts: 31 | - silu 32 | - silu 33 | - 34 | mlp_ratio: 2.5 35 | use_pe: false 36 | qk_norm: false 37 | class_dropout_prob: 0.1 38 | # VAE setting 39 | vae: 40 | vae_type: AutoencoderDC 41 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 42 | scale_factor: 0.41407 43 | vae_latent_dim: 32 44 | vae_downsample_rate: 32 45 | sample_posterior: true 46 | # text encoder 47 | text_encoder: 48 | text_encoder_name: gemma-2-2b-it 49 | y_norm: true 50 | y_norm_scale_factor: 0.01 51 | model_max_length: 300 52 | # CHI 53 | chi_prompt: 54 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 55 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 56 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 57 | - 'Here are examples of how to transform or refine prompts:' 58 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 59 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 60 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 61 | - 'User Prompt: ' 62 | # Sana schedule Flow 63 | scheduler: 64 | predict_flow_v: true 65 | noise_schedule: linear_flow 66 | pred_sigma: false 67 | flow_shift: 4.0 68 | # logit-normal timestep 69 | weighting_scheme: logit_normal 70 | logit_mean: 0.0 71 | logit_std: 1.0 72 | vis_sampler: flow_dpm-solver 73 | # training setting 74 | train: 75 | num_workers: 10 76 | seed: 1 77 | train_batch_size: 64 78 | num_epochs: 100 79 | gradient_accumulation_steps: 1 80 | grad_checkpointing: true 81 | gradient_clip: 0.1 82 | optimizer: 83 | betas: 84 | - 0.9 85 | - 0.999 86 | - 0.9999 87 | eps: 88 | - 1.0e-30 89 | - 1.0e-16 90 | lr: 0.0001 91 | type: CAMEWrapper 92 | weight_decay: 0.0 93 | lr_schedule: constant 94 | lr_schedule_args: 95 | num_warmup_steps: 2000 96 | local_save_vis: true # if save log image locally 97 | visualize: true 98 | eval_sampling_steps: 500 99 | log_interval: 20 100 | save_model_epochs: 5 101 | save_model_steps: 500 102 | work_dir: output/debug 103 | online_metric: false 104 | eval_metric_step: 2000 105 | online_metric_dir: metric_helper 106 | -------------------------------------------------------------------------------- /configs/sana_config/512ms/Sana_1600M_img512.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/data_public/dir1] 3 | image_size: 512 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] 7 | external_clipscore_suffixes: 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: 
default_train 16 | type: SanaWebDatasetMS 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_1600M_P1_D20 21 | image_size: 512 22 | mixed_precision: fp16 # ['fp16', 'fp32', 'bf16'] 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_512 27 | multi_scale: true 28 | attn_type: linear 29 | ffn_type: glumbconv 30 | mlp_acts: 31 | - silu 32 | - silu 33 | - 34 | mlp_ratio: 2.5 35 | use_pe: false 36 | qk_norm: false 37 | class_dropout_prob: 0.1 38 | # PAG 39 | pag_applied_layers: 40 | - 8 41 | # VAE setting 42 | vae: 43 | vae_type: AutoencoderDC 44 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 45 | scale_factor: 0.41407 46 | vae_latent_dim: 32 47 | vae_downsample_rate: 32 48 | sample_posterior: true 49 | # text encoder 50 | text_encoder: 51 | text_encoder_name: gemma-2-2b-it 52 | y_norm: true 53 | y_norm_scale_factor: 0.01 54 | model_max_length: 300 55 | # CHI 56 | chi_prompt: 57 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 58 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 59 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 60 | - 'Here are examples of how to transform or refine prompts:' 61 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 62 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 
63 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 64 | - 'User Prompt: ' 65 | # Sana schedule Flow 66 | scheduler: 67 | predict_flow_v: true 68 | noise_schedule: linear_flow 69 | pred_sigma: false 70 | flow_shift: 3.0 71 | # logit-normal timestep 72 | weighting_scheme: logit_normal 73 | logit_mean: 0.0 74 | logit_std: 1.0 75 | vis_sampler: flow_dpm-solver 76 | # training setting 77 | train: 78 | num_workers: 10 79 | seed: 1 80 | train_batch_size: 64 81 | num_epochs: 100 82 | gradient_accumulation_steps: 1 83 | grad_checkpointing: true 84 | gradient_clip: 0.1 85 | optimizer: 86 | betas: 87 | - 0.9 88 | - 0.999 89 | - 0.9999 90 | eps: 91 | - 1.0e-30 92 | - 1.0e-16 93 | lr: 0.0001 94 | type: CAMEWrapper 95 | weight_decay: 0.0 96 | lr_schedule: constant 97 | lr_schedule_args: 98 | num_warmup_steps: 2000 99 | local_save_vis: true # if save log image locally 100 | visualize: true 101 | eval_sampling_steps: 500 102 | log_interval: 20 103 | save_model_epochs: 5 104 | save_model_steps: 500 105 | work_dir: output/debug 106 | online_metric: false 107 | eval_metric_step: 2000 108 | online_metric_dir: metric_helper 109 | -------------------------------------------------------------------------------- /configs/sana_config/512ms/Sana_600M_img512.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/data_public/dir1] 3 | image_size: 512 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] 7 | external_clipscore_suffixes: 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: default_train 16 | type: SanaWebDatasetMS 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_600M_P1_D28 21 | image_size: 512 22 | mixed_precision: fp16 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_512 27 | multi_scale: true 28 | #pe_interpolation: 1. 29 | attn_type: linear 30 | linear_head_dim: 32 31 | ffn_type: glumbconv 32 | mlp_acts: 33 | - silu 34 | - silu 35 | - null 36 | mlp_ratio: 2.5 37 | use_pe: false 38 | qk_norm: false 39 | class_dropout_prob: 0.1 40 | # VAE setting 41 | vae: 42 | vae_type: AutoencoderDC 43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 44 | scale_factor: 0.41407 45 | vae_latent_dim: 32 46 | vae_downsample_rate: 32 47 | sample_posterior: true 48 | # text encoder 49 | text_encoder: 50 | text_encoder_name: gemma-2-2b-it 51 | y_norm: true 52 | y_norm_scale_factor: 0.01 53 | model_max_length: 300 54 | # CHI 55 | chi_prompt: 56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 59 | - 'Here are examples of how to transform or refine prompts:' 60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 
61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 63 | - 'User Prompt: ' 64 | # Sana schedule Flow 65 | scheduler: 66 | predict_flow_v: true 67 | noise_schedule: linear_flow 68 | pred_sigma: false 69 | flow_shift: 3.0 70 | # logit-normal timestep 71 | weighting_scheme: logit_normal 72 | logit_mean: 0.0 73 | logit_std: 1.0 74 | vis_sampler: flow_dpm-solver 75 | # training setting 76 | train: 77 | num_workers: 10 78 | seed: 1 79 | train_batch_size: 128 80 | num_epochs: 100 81 | gradient_accumulation_steps: 1 82 | grad_checkpointing: true 83 | gradient_clip: 0.1 84 | optimizer: 85 | betas: 86 | - 0.9 87 | - 0.999 88 | - 0.9999 89 | eps: 90 | - 1.0e-30 91 | - 1.0e-16 92 | lr: 0.0001 93 | type: CAMEWrapper 94 | weight_decay: 0.0 95 | lr_schedule: constant 96 | lr_schedule_args: 97 | num_warmup_steps: 2000 98 | local_save_vis: true # if save log image locally 99 | visualize: true 100 | eval_sampling_steps: 500 101 | log_interval: 20 102 | save_model_epochs: 5 103 | save_model_steps: 500 104 | work_dir: output/debug 105 | online_metric: false 106 | eval_metric_step: 2000 107 | online_metric_dir: metric_helper 108 | -------------------------------------------------------------------------------- /configs/sana_config/512ms/ci_Sana_600M_img512.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/data_public/vaef32c32_v2_512/dir1] 3 | image_size: 512 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] 7 | external_clipscore_suffixes: 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: default_train 16 | type: SanaWebDatasetMS 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_600M_P1_D28 21 | image_size: 512 22 | mixed_precision: fp16 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_512 27 | multi_scale: true 28 | #pe_interpolation: 1. 29 | attn_type: linear 30 | linear_head_dim: 32 31 | ffn_type: glumbconv 32 | mlp_acts: 33 | - silu 34 | - silu 35 | - null 36 | mlp_ratio: 2.5 37 | use_pe: false 38 | qk_norm: false 39 | class_dropout_prob: 0.1 40 | # VAE setting 41 | vae: 42 | vae_type: AutoencoderDC 43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 44 | scale_factor: 0.41407 45 | vae_latent_dim: 32 46 | vae_downsample_rate: 32 47 | sample_posterior: true 48 | # text encoder 49 | text_encoder: 50 | text_encoder_name: gemma-2-2b-it 51 | y_norm: true 52 | y_norm_scale_factor: 0.01 53 | model_max_length: 300 54 | # CHI 55 | chi_prompt: 56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 
59 | - 'Here are examples of how to transform or refine prompts:' 60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 63 | - 'User Prompt: ' 64 | # Sana schedule Flow 65 | scheduler: 66 | predict_flow_v: true 67 | noise_schedule: linear_flow 68 | pred_sigma: false 69 | flow_shift: 1.0 70 | # logit-normal timestep 71 | weighting_scheme: logit_normal 72 | logit_mean: 0.0 73 | logit_std: 1.0 74 | vis_sampler: flow_dpm-solver 75 | # training setting 76 | train: 77 | num_workers: 10 78 | seed: 1 79 | train_batch_size: 64 80 | num_epochs: 1 81 | gradient_accumulation_steps: 1 82 | grad_checkpointing: true 83 | gradient_clip: 0.1 84 | optimizer: 85 | betas: 86 | - 0.9 87 | - 0.999 88 | - 0.9999 89 | eps: 90 | - 1.0e-30 91 | - 1.0e-16 92 | lr: 0.0001 93 | type: CAMEWrapper 94 | weight_decay: 0.0 95 | lr_schedule: constant 96 | lr_schedule_args: 97 | num_warmup_steps: 2000 98 | local_save_vis: true # if save log image locally 99 | visualize: true 100 | eval_sampling_steps: 500 101 | log_interval: 20 102 | save_model_epochs: 5 103 | save_model_steps: 500 104 | work_dir: output/debug 105 | online_metric: false 106 | eval_metric_step: 2000 107 | online_metric_dir: metric_helper 108 | -------------------------------------------------------------------------------- /configs/sana_config/512ms/sample_dataset.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [asset/example_data] 3 | image_size: 512 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] # json files 7 | external_clipscore_suffixes: # json files 8 | - _InternVL2-26B_clip_score 9 | - _VILA1-5-13B_clip_score 10 | - _prompt_clip_score 11 | clip_thr_temperature: 0.1 12 | clip_thr: 25.0 13 | load_text_feat: false 14 | load_vae_feat: false 15 | transform: default_train 16 | type: SanaImgDataset 17 | sort_dataset: false 18 | # model config 19 | model: 20 | model: SanaMS_600M_P1_D28 21 | image_size: 512 22 | mixed_precision: fp16 23 | fp32_attention: true 24 | load_from: 25 | resume_from: 26 | aspect_ratio_type: ASPECT_RATIO_512 27 | multi_scale: false 28 | #pe_interpolation: 1. 29 | attn_type: linear 30 | linear_head_dim: 32 31 | ffn_type: glumbconv 32 | mlp_acts: 33 | - silu 34 | - silu 35 | - null 36 | mlp_ratio: 2.5 37 | use_pe: false 38 | qk_norm: false 39 | class_dropout_prob: 0.1 40 | # VAE setting 41 | vae: 42 | vae_type: AutoencoderDC 43 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 44 | scale_factor: 0.41407 45 | vae_latent_dim: 32 46 | vae_downsample_rate: 32 47 | sample_posterior: true 48 | # text encoder 49 | text_encoder: 50 | text_encoder_name: gemma-2-2b-it 51 | y_norm: true 52 | y_norm_scale_factor: 0.01 53 | model_max_length: 300 54 | # CHI 55 | chi_prompt: 56 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation.
Evaluate the level of detail in the user prompt:' 57 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 58 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 59 | - 'Here are examples of how to transform or refine prompts:' 60 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 61 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 62 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 63 | - 'User Prompt: ' 64 | # Sana schedule Flow 65 | scheduler: 66 | predict_flow_v: true 67 | noise_schedule: linear_flow 68 | pred_sigma: false 69 | flow_shift: 1.0 70 | # logit-normal timestep 71 | weighting_scheme: logit_normal 72 | logit_mean: 0.0 73 | logit_std: 1.0 74 | vis_sampler: flow_dpm-solver 75 | # training setting 76 | train: 77 | num_workers: 10 78 | seed: 1 79 | train_batch_size: 128 80 | num_epochs: 100 81 | gradient_accumulation_steps: 1 82 | grad_checkpointing: true 83 | gradient_clip: 0.1 84 | optimizer: 85 | betas: 86 | - 0.9 87 | - 0.999 88 | - 0.9999 89 | eps: 90 | - 1.0e-30 91 | - 1.0e-16 92 | lr: 0.0001 93 | type: CAMEWrapper 94 | weight_decay: 0.0 95 | lr_schedule: constant 96 | lr_schedule_args: 97 | num_warmup_steps: 2000 98 | local_save_vis: true # if save log image locally 99 | visualize: true 100 | eval_sampling_steps: 500 101 | log_interval: 20 102 | save_model_epochs: 5 103 | save_model_steps: 500 104 | work_dir: output/debug 105 | online_metric: false 106 | eval_metric_step: 2000 107 | online_metric_dir: metric_helper 108 | -------------------------------------------------------------------------------- /configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/data_public/controlnet_data] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: [] 7 | external_clipscore_suffixes: [] 8 | clip_thr_temperature: 0.1 9 | clip_thr: 25.0 10 | load_text_feat: false 11 | load_vae_feat: false 12 | transform: default_train 13 | type: SanaWebDatasetMSControl 14 | sort_dataset: false 15 | # model config 16 | model: 17 | model: SanaMSControlNet_1600M_P1_D20 18 | image_size: 1024 19 | mixed_precision: bf16 20 | fp32_attention: true 21 | load_from: hf://Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoint/Sana_1600M_1024px_BF16.pth 22 | resume_from: 23 | aspect_ratio_type: ASPECT_RATIO_1024 24 | multi_scale: true 25 | attn_type: linear 26 | ffn_type: glumbconv 27 | mlp_acts: 28 | - silu 29 | - silu 30 | - 31 | mlp_ratio: 2.5 32 | use_pe: false 33 | qk_norm: false 34 | class_dropout_prob: 0.1 35 | # VAE setting 36 | vae: 37 | vae_type: AutoencoderDC 38 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 39 | scale_factor: 0.41407 40 | vae_latent_dim: 32 41 | vae_downsample_rate: 32 42 | sample_posterior: true 43 | weight_dtype: bf16 44 | # text encoder 45 | text_encoder: 46 | text_encoder_name: gemma-2-2b-it 47 | y_norm: true 48 | 
y_norm_scale_factor: 0.01 49 | model_max_length: 300 50 | # CHI 51 | chi_prompt: 52 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 53 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 54 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 55 | - 'Here are examples of how to transform or refine prompts:' 56 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 57 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 58 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 59 | - 'User Prompt: ' 60 | # Sana schedule Flow 61 | scheduler: 62 | predict_flow_v: true 63 | noise_schedule: linear_flow 64 | pred_sigma: false 65 | flow_shift: 3.0 66 | # logit-normal timestep 67 | weighting_scheme: logit_normal 68 | logit_mean: 0.0 69 | logit_std: 1.0 70 | vis_sampler: flow_dpm-solver 71 | # training setting 72 | train: 73 | num_workers: 10 74 | seed: 1 75 | train_batch_size: 16 76 | num_epochs: 100 77 | gradient_accumulation_steps: 1 78 | grad_checkpointing: true 79 | gradient_clip: 0.1 80 | optimizer: 81 | betas: 82 | - 0.9 83 | - 0.999 84 | - 0.9999 85 | eps: 86 | - 1.0e-30 87 | - 1.0e-16 88 | lr: 0.0001 89 | type: CAMEWrapper 90 | weight_decay: 0.0 91 | lr_schedule: constant 92 | lr_schedule_args: 93 | num_warmup_steps: 30 94 | local_save_vis: true # if save log image locally 95 | visualize: true 96 | eval_sampling_steps: 50 97 | log_interval: 20 98 | save_model_epochs: 5 99 | save_model_steps: 500 100 | work_dir: output/debug 101 | online_metric: false 102 | eval_metric_step: 2000 103 | online_metric_dir: metric_helper 104 | controlnet: 105 | control_signal_type: "scribble" 106 | -------------------------------------------------------------------------------- /configs/sana_controlnet_config/Sana_600M_img1024_controlnet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: [data/data_public/controlnet_data] 3 | image_size: 1024 4 | caption_proportion: 5 | prompt: 1 6 | external_caption_suffixes: [] 7 | external_clipscore_suffixes: [] 8 | clip_thr_temperature: 0.1 9 | clip_thr: 25.0 10 | load_text_feat: false 11 | load_vae_feat: false 12 | transform: default_train 13 | type: SanaWebDatasetMSControl 14 | sort_dataset: false 15 | # model config 16 | model: 17 | model: SanaMSControlNet_600M_P1_D28 18 | image_size: 1024 19 | mixed_precision: fp16 20 | fp32_attention: true 21 | load_from: hf://Efficient-Large-Model/Sana_600M_1024px/checkpoint/Sana_600M_1024px.pth 22 | resume_from: 23 | aspect_ratio_type: ASPECT_RATIO_1024 24 | multi_scale: true 25 | attn_type: linear 26 | ffn_type: glumbconv 27 | mlp_acts: 28 | - silu 29 | - silu 30 | - 31 | mlp_ratio: 2.5 32 | use_pe: false 33 | qk_norm: false 34 | class_dropout_prob: 0.1 35 | # VAE setting 36 | vae: 37 | vae_type: AutoencoderDC 38 | vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers 39 | 
scale_factor: 0.41407 40 | vae_latent_dim: 32 41 | vae_downsample_rate: 32 42 | sample_posterior: true 43 | # text encoder 44 | text_encoder: 45 | text_encoder_name: gemma-2-2b-it 46 | y_norm: true 47 | y_norm_scale_factor: 0.01 48 | model_max_length: 300 49 | # CHI 50 | chi_prompt: 51 | - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:' 52 | - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.' 53 | - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.' 54 | - 'Here are examples of how to transform or refine prompts:' 55 | - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.' 56 | - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.' 57 | - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:' 58 | - 'User Prompt: ' 59 | # Sana schedule Flow 60 | scheduler: 61 | predict_flow_v: true 62 | noise_schedule: linear_flow 63 | pred_sigma: false 64 | flow_shift: 4.0 65 | # logit-normal timestep 66 | weighting_scheme: logit_normal 67 | logit_mean: 0.0 68 | logit_std: 1.0 69 | vis_sampler: flow_dpm-solver 70 | # training setting 71 | train: 72 | num_workers: 10 73 | seed: 1 74 | train_batch_size: 16 75 | num_epochs: 100 76 | gradient_accumulation_steps: 1 77 | grad_checkpointing: true 78 | gradient_clip: 0.1 79 | optimizer: 80 | betas: 81 | - 0.9 82 | - 0.999 83 | - 0.9999 84 | eps: 85 | - 1.0e-30 86 | - 1.0e-16 87 | lr: 0.0001 88 | type: CAMEWrapper 89 | weight_decay: 0.0 90 | lr_schedule: constant 91 | lr_schedule_args: 92 | num_warmup_steps: 30 93 | local_save_vis: true # if save log image locally 94 | visualize: true 95 | eval_sampling_steps: 500 96 | log_interval: 20 97 | save_model_epochs: 5 98 | save_model_steps: 500 99 | work_dir: output/debug 100 | online_metric: false 101 | eval_metric_step: 2000 102 | online_metric_dir: metric_helper 103 | controlnet: 104 | control_signal_type: "scribble" 105 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | from .scheduler.dpm_solver import DPMS 8 | from .scheduler.flow_euler_sampler import FlowEuler 9 | from .scheduler.iddpm import Scheduler 10 | from .scheduler.sa_sampler import SASolverSampler 11 | from .scheduler.scm_scheduler import SCMScheduler 12 | from .scheduler.trigflow_scheduler import TrigFlowScheduler 13 | -------------------------------------------------------------------------------- /diffusion/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .transforms import get_transform 3 | -------------------------------------------------------------------------------- /diffusion/data/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import os 18 | import time 19 | 20 | from mmcv import Registry, build_from_cfg 21 | from termcolor import colored 22 | from torch.utils.data import DataLoader 23 | 24 | from diffusion.data.transforms import get_transform 25 | from diffusion.utils.logger import get_root_logger 26 | 27 | DATASETS = Registry("datasets") 28 | 29 | DATA_ROOT = "data" 30 | 31 | 32 | def set_data_root(data_root): 33 | global DATA_ROOT 34 | DATA_ROOT = data_root 35 | 36 | 37 | def get_data_path(data_dir): 38 | if os.path.isabs(data_dir): 39 | return data_dir 40 | global DATA_ROOT 41 | return os.path.join(DATA_ROOT, data_dir) 42 | 43 | 44 | def get_data_root_and_path(data_dir): 45 | if os.path.isabs(data_dir): 46 | return data_dir 47 | global DATA_ROOT 48 | return DATA_ROOT, os.path.join(DATA_ROOT, data_dir) 49 | 50 | 51 | def build_dataset(cfg, resolution=224, **kwargs): 52 | logger = get_root_logger() 53 | 54 | dataset_type = cfg.get("type") 55 | logger.info(f"Constructing dataset {dataset_type}...") 56 | t = time.time() 57 | transform = cfg.pop("transform", "default_train") 58 | transform = get_transform(transform, resolution) 59 | dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs)) 60 | logger.info( 61 | f"{colored(f'Dataset {dataset_type} constructed: ', 'green', attrs=['bold'])}" 62 | f"time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}" 63 | ) 64 | return dataset 65 | 66 | 67 | def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs): 68 | if "batch_sampler" in kwargs: 69 | dataloader = DataLoader( 70 | dataset, batch_sampler=kwargs["batch_sampler"], num_workers=num_workers, pin_memory=True 71 | ) 72 | else: 73 | dataloader = DataLoader( 74 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, **kwargs 75 | ) 76 | return dataloader 77 | -------------------------------------------------------------------------------- /diffusion/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .sana_data import SanaImgDataset, SanaWebDataset 2 | from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS 3 | from .utils import * 4 | -------------------------------------------------------------------------------- /diffusion/data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA 
CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torchvision.transforms as T 18 | 19 | TRANSFORMS = dict() 20 | 21 | 22 | def register_transform(transform): 23 | name = transform.__name__ 24 | if name in TRANSFORMS: 25 | raise RuntimeError(f"Transform {name} has already been registered.") 26 | TRANSFORMS.update({name: transform}) 27 | 28 | 29 | def get_transform(type, resolution): 30 | transform = TRANSFORMS[type](resolution) 31 | transform = T.Compose(transform) 32 | transform.image_size = resolution 33 | return transform 34 | 35 | 36 | @register_transform 37 | def default_train(n_px): 38 | transform = [ 39 | T.Lambda(lambda img: img.convert("RGB")), 40 | T.Resize(n_px), # Image.BICUBIC 41 | T.CenterCrop(n_px), 42 | # T.RandomHorizontalFlip(), 43 | T.ToTensor(), 44 | T.Normalize([0.5], [0.5]), 45 | ] 46 | return transform 47 | -------------------------------------------------------------------------------- /diffusion/data/wids/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. 2 | # This file is part of the WebDataset library. 3 | # See the LICENSE file for licensing terms (BSD-style). 4 | # 5 | # flake8: noqa 6 | 7 | from .wids import ( 8 | ChunkedSampler, 9 | DistributedChunkedSampler, 10 | DistributedLocalSampler, 11 | DistributedRangedSampler, 12 | ShardedSampler, 13 | ShardListDataset, 14 | ShardListDatasetMulti, 15 | lru_json_load, 16 | ) 17 | -------------------------------------------------------------------------------- /diffusion/data/wids/wids_lru.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids 18 | from collections import OrderedDict 19 | 20 | 21 | class LRUCache: 22 | def __init__(self, capacity: int, release_handler=None): 23 | """Initialize a new LRU cache with the given capacity.""" 24 | self.capacity = capacity 25 | self.cache = OrderedDict() 26 | self.release_handler = release_handler 27 | 28 | def __getitem__(self, key): 29 | """Return the value associated with the given key, or None.""" 30 | if key not in self.cache: 31 | return None 32 | self.cache.move_to_end(key) 33 | return self.cache[key] 34 | 35 | def __setitem__(self, key, value): 36 | """Associate the given value with the given key.""" 37 | if key in self.cache: 38 | self.cache.move_to_end(key) 39 | self.cache[key] = value 40 | if len(self.cache) > self.capacity: 41 | key, value = self.cache.popitem(last=False) 42 | if self.release_handler is not None: 43 | self.release_handler(key, value) 44 | 45 | def __delitem__(self, key): 46 | """Remove the given key from the cache.""" 47 | if key in self.cache: 48 | if self.release_handler is not None: 49 | value = self.cache[key] 50 | self.release_handler(key, value) 51 | del self.cache[key] 52 | 53 | def __len__(self): 54 | """Return the number of entries in the cache.""" 55 | return len(self.cache) 56 | 57 | def __contains__(self, key): 58 | """Return whether the cache contains the given key.""" 59 | return key in self.cache 60 | 61 | def items(self): 62 | """Return an iterator over the (key, value) pairs of the cache.""" 63 | return self.cache.items() 64 | 65 | def keys(self): 66 | """Return an iterator over the keys of the cache.""" 67 | return self.cache.keys() 68 | 69 | def values(self): 70 | """Return an iterator over the values of the cache.""" 71 | return self.cache.values() 72 | 73 | def clear(self): 74 | for key in list(self.keys()): 75 | value = self.cache[key] 76 | if self.release_handler is not None: 77 | self.release_handler(key, value) 78 | del self[key] 79 | 80 | def __del__(self): 81 | self.clear() 82 | -------------------------------------------------------------------------------- /diffusion/data/wids/wids_tar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids 18 | import io 19 | import os 20 | import os.path 21 | import pickle 22 | import re 23 | import tarfile 24 | 25 | import numpy as np 26 | 27 | 28 | def find_index_file(file): 29 | prefix, last_ext = os.path.splitext(file) 30 | if re.match("._[0-9]+_$", last_ext): 31 | return prefix + ".index" 32 | else: 33 | return file + ".index" 34 | 35 | 36 | class TarFileReader: 37 | def __init__(self, file, index_file=find_index_file, verbose=True): 38 | self.verbose = verbose 39 | if callable(index_file): 40 | index_file = index_file(file) 41 | self.index_file = index_file 42 | 43 | # Open the tar file and keep it open 44 | if isinstance(file, str): 45 | self.tar_file = tarfile.open(file, "r") 46 | else: 47 | self.tar_file = tarfile.open(fileobj=file, mode="r") 48 | 49 | # Create the index 50 | self._create_tar_index() 51 | 52 | def _create_tar_index(self): 53 | if self.index_file is not None and os.path.exists(self.index_file): 54 | if self.verbose: 55 | print("Loading tar index from", self.index_file) 56 | with open(self.index_file, "rb") as stream: 57 | self.fnames, self.index = pickle.load(stream) 58 | return 59 | # Create an empty list for the index 60 | self.fnames = [] 61 | self.index = [] 62 | 63 | if self.verbose: 64 | print("Creating tar index for", self.tar_file.name, "at", self.index_file) 65 | # Iterate over the members of the tar file 66 | for member in self.tar_file: 67 | # If the member is a file, add it to the index 68 | if member.isfile(): 69 | # Get the file's offset 70 | offset = self.tar_file.fileobj.tell() 71 | self.fnames.append(member.name) 72 | self.index.append([offset, member.size]) 73 | if self.verbose: 74 | print("Done creating tar index for", self.tar_file.name, "at", self.index_file) 75 | self.index = np.array(self.index) 76 | if self.index_file is not None: 77 | if os.path.exists(self.index_file + ".temp"): 78 | os.unlink(self.index_file + ".temp") 79 | with open(self.index_file + ".temp", "wb") as stream: 80 | pickle.dump((self.fnames, self.index), stream) 81 | os.rename(self.index_file + ".temp", self.index_file) 82 | 83 | def names(self): 84 | return self.fnames 85 | 86 | def __len__(self): 87 | return len(self.index) 88 | 89 | def get_file(self, i): 90 | name = self.fnames[i] 91 | offset, size = self.index[i] 92 | self.tar_file.fileobj.seek(offset) 93 | file_bytes = self.tar_file.fileobj.read(size) 94 | return name, io.BytesIO(file_bytes) 95 | 96 | def close(self): 97 | # Close the tar file 98 | self.tar_file.close() 99 | -------------------------------------------------------------------------------- /diffusion/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/diffusion/model/__init__.py -------------------------------------------------------------------------------- /diffusion/model/act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import copy 18 | 19 | import torch.nn as nn 20 | 21 | __all__ = ["build_act", "get_act_name"] 22 | 23 | # register activation function here 24 | # name: module, kwargs with default values 25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = { 26 | "relu": (nn.ReLU, {"inplace": True}), 27 | "relu6": (nn.ReLU6, {"inplace": True}), 28 | "hswish": (nn.Hardswish, {"inplace": True}), 29 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}), 30 | "swish": (nn.SiLU, {"inplace": True}), 31 | "silu": (nn.SiLU, {"inplace": True}), 32 | "tanh": (nn.Tanh, {}), 33 | "sigmoid": (nn.Sigmoid, {}), 34 | "gelu": (nn.GELU, {"approximate": "tanh"}), 35 | "mish": (nn.Mish, {"inplace": True}), 36 | "identity": (nn.Identity, {}), 37 | } 38 | 39 | 40 | def build_act(name: str or None, **kwargs) -> nn.Module or None: 41 | if name in REGISTERED_ACT_DICT: 42 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name]) 43 | for key in default_args: 44 | if key in kwargs: 45 | default_args[key] = kwargs[key] 46 | return act_cls(**default_args) 47 | elif name is None or name.lower() == "none": 48 | return None 49 | else: 50 | raise ValueError(f"do not support: {name}") 51 | 52 | 53 | def get_act_name(act: nn.Module or None) -> str or None: 54 | if act is None: 55 | return None 56 | module2name = {} 57 | for key, config in REGISTERED_ACT_DICT.items(): 58 | module2name[config[0].__name__] = key 59 | return module2name.get(type(act).__name__, "unknown") 60 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/diffusion/model/dc_ae/efficientvit/__init__.py -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/ae_model_zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Callable, Optional 18 | 19 | import diffusers 20 | import torch 21 | from huggingface_hub import PyTorchModelHubMixin 22 | from torch import nn 23 | 24 | from ..efficientvit.models.efficientvit.dc_ae import DCAE, DCAEConfig, dc_ae_f32c32, dc_ae_f64c128, dc_ae_f128c512 25 | 26 | __all__ = ["create_dc_ae_model_cfg", "DCAE_HF", "AutoencoderKL"] 27 | 28 | 29 | REGISTERED_DCAE_MODEL: dict[str, tuple[Callable, Optional[str]]] = { 30 | "dc-ae-f32c32-in-1.0": (dc_ae_f32c32, None), 31 | "dc-ae-f64c128-in-1.0": (dc_ae_f64c128, None), 32 | "dc-ae-f128c512-in-1.0": (dc_ae_f128c512, None), 33 | ################################################################################################# 34 | "dc-ae-f32c32-mix-1.0": (dc_ae_f32c32, None), 35 | "dc-ae-f64c128-mix-1.0": (dc_ae_f64c128, None), 36 | "dc-ae-f128c512-mix-1.0": (dc_ae_f128c512, None), 37 | ################################################################################################# 38 | "dc-ae-f32c32-sana-1.0": (dc_ae_f32c32, None), 39 | "dc-ae-f32c32-sana-1.1": (dc_ae_f32c32, None), 40 | } 41 | 42 | 43 | def create_dc_ae_model_cfg(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig: 44 | assert name in REGISTERED_DCAE_MODEL, f"{name} is not supported" 45 | dc_ae_cls, default_pt_path = REGISTERED_DCAE_MODEL[name] 46 | pretrained_path = default_pt_path if pretrained_path is None else pretrained_path 47 | model_cfg = dc_ae_cls(name, pretrained_path) 48 | return model_cfg 49 | 50 | 51 | class DCAE_HF(DCAE, PyTorchModelHubMixin): 52 | def __init__(self, model_name: str): 53 | cfg = create_dc_ae_model_cfg(model_name) 54 | DCAE.__init__(self, cfg) 55 | 56 | 57 | class AutoencoderKL(nn.Module): 58 | def __init__(self, model_name: str): 59 | super().__init__() 60 | self.model_name = model_name 61 | if self.model_name in ["stabilityai/sd-vae-ft-ema"]: 62 | self.model = diffusers.models.AutoencoderKL.from_pretrained(self.model_name) 63 | self.spatial_compression_ratio = 8 64 | elif self.model_name == "flux-vae": 65 | from diffusers import FluxPipeline 66 | 67 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 68 | self.model = diffusers.models.AutoencoderKL.from_pretrained(pipe.vae.config._name_or_path) 69 | self.spatial_compression_ratio = 8 70 | else: 71 | raise ValueError(f"{self.model_name} is not supported for AutoencoderKL") 72 | 73 | def encode(self, x: torch.Tensor) -> torch.Tensor: 74 | if self.model_name in ["stabilityai/sd-vae-ft-ema", "flux-vae"]: 75 | return self.model.encode(x).latent_dist.sample() 76 | else: 77 | raise ValueError(f"{self.model_name} is not supported for AutoencoderKL") 78 | 79 | def decode(self, latent: torch.Tensor) -> torch.Tensor: 80 | if self.model_name in ["stabilityai/sd-vae-ft-ema", "flux-vae"]: 81 | return self.model.decode(latent).sample 82 | else: 83 | raise ValueError(f"{self.model_name} is not supported for AutoencoderKL") 84 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/diffusion/model/dc_ae/efficientvit/apps/__init__.py -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from copy import deepcopy 4 | from typing import Optional 5 | 6 | import torch.backends.cudnn 7 | import torch.distributed 8 | import torch.nn as nn 9 | 10 | from ..apps.utils import ( 11 | dist_init, 12 | dump_config, 13 | get_dist_local_rank, 14 | get_dist_rank, 15 | get_dist_size, 16 | init_modules, 17 | is_master, 18 | load_config, 19 | partial_update_config, 20 | zero_last_gamma, 21 | ) 22 | from ..models.utils import build_kwargs_from_config, load_state_dict_from_file 23 | 24 | __all__ = [ 25 | "save_exp_config", 26 | "setup_dist_env", 27 | "setup_seed", 28 | "setup_exp_config", 29 | "init_model", 30 | ] 31 | 32 | 33 | def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None: 34 | if not is_master(): 35 | return 36 | dump_config(exp_config, os.path.join(path, name)) 37 | 38 | 39 | def setup_dist_env(gpu: Optional[str] = None) -> None: 40 | if gpu is not None: 41 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu 42 | if not torch.distributed.is_initialized(): 43 | dist_init() 44 | torch.backends.cudnn.benchmark = True 45 | torch.cuda.set_device(get_dist_local_rank()) 46 | 47 | 48 | def setup_seed(manual_seed: int, resume: bool) -> None: 49 | if resume: 50 | manual_seed = int(time.time()) 51 | manual_seed = get_dist_rank() + manual_seed 52 | torch.manual_seed(manual_seed) 53 | torch.cuda.manual_seed_all(manual_seed) 54 | 55 | 56 | def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict: 57 | # load config 58 | if not os.path.isfile(config_path): 59 | raise ValueError(config_path) 60 | 61 | fpaths = [config_path] 62 | if recursive: 63 | extension = os.path.splitext(config_path)[1] 64 | while os.path.dirname(config_path) != config_path: 65 | config_path = os.path.dirname(config_path) 66 | fpath = os.path.join(config_path, "default" + extension) 67 | if os.path.isfile(fpath): 68 | fpaths.append(fpath) 69 | fpaths = fpaths[::-1] 70 | 71 | default_config = load_config(fpaths[0]) 72 | exp_config = deepcopy(default_config) 73 | for fpath in fpaths[1:]: 74 | partial_update_config(exp_config, load_config(fpath)) 75 | # update config via args 76 | if opt_args is not None: 77 | partial_update_config(exp_config, opt_args) 78 | 79 | return exp_config 80 | 81 | 82 | def init_model( 83 | network: nn.Module, 84 | init_from: Optional[str] = None, 85 | backbone_init_from: Optional[str] = None, 86 | rand_init="trunc_normal", 87 | last_gamma=None, 88 | ) -> None: 89 | # initialization 90 | init_modules(network, init_type=rand_init) 91 | # zero gamma of last bn in each block 92 | if last_gamma is not None: 93 | zero_last_gamma(network, last_gamma) 94 | 95 | # load weight 96 | if init_from is not None and os.path.isfile(init_from): 97 | network.load_state_dict(load_state_dict_from_file(init_from)) 98 | print(f"Loaded init from {init_from}") 99 | elif backbone_init_from is not None and os.path.isfile(backbone_init_from): 100 | network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from)) 101 | print(f"Loaded backbone init from {backbone_init_from}") 102 | else: 103 | print(f"Random init ({rand_init}) with last gamma {last_gamma}") 104 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .run_config import * 2 | 
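The `setup_exp_config` helper in `apps/setup.py` above builds an experiment config by loading the top-most `default.yaml` found while walking up from the given config path, layering nearer defaults and the config file itself on top, and finally applying `opt_args` overrides. A minimal usage sketch; the YAML path and the override values here are hypothetical, not files that ship with the repo:

```python
# Hypothetical driver for setup_exp_config / setup_seed from apps/setup.py.
# The config path and the override dict are illustrative assumptions.
from diffusion.model.dc_ae.efficientvit.apps.setup import setup_exp_config, setup_seed

exp_config = setup_exp_config(
    "configs/dc_ae/my_run.yaml",            # hypothetical experiment config
    recursive=True,                         # also merge default.yaml files up the tree
    opt_args={"run_config": {"seed": 42}},  # partial override, applied last
)
setup_seed(manual_seed=42, resume=False)    # per-rank seeding; expects RANK env var set
```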
-------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist import * 2 | from .ema import * 3 | 4 | # from .export import * 5 | from .image import * 6 | from .init import * 7 | from .lr import * 8 | from .metric import * 9 | from .misc import * 10 | from .opt import * 11 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import os 18 | from typing import Union 19 | 20 | import torch 21 | import torch.distributed 22 | 23 | from ...models.utils.list import list_mean, list_sum 24 | 25 | __all__ = [ 26 | "dist_init", 27 | "is_dist_initialized", 28 | "get_dist_rank", 29 | "get_dist_size", 30 | "is_master", 31 | "dist_barrier", 32 | "get_dist_local_rank", 33 | "sync_tensor", 34 | ] 35 | 36 | 37 | def dist_init() -> None: 38 | if is_dist_initialized(): 39 | return 40 | try: 41 | torch.distributed.init_process_group(backend="nccl") 42 | assert torch.distributed.is_initialized() 43 | except Exception: 44 | os.environ["RANK"] = "0" 45 | os.environ["WORLD_SIZE"] = "1" 46 | os.environ["LOCAL_RANK"] = "0" 47 | print("warning: dist not init") 48 | 49 | 50 | def is_dist_initialized() -> bool: 51 | return torch.distributed.is_initialized() 52 | 53 | 54 | def get_dist_rank() -> int: 55 | return int(os.environ["RANK"]) 56 | 57 | 58 | def get_dist_size() -> int: 59 | return int(os.environ["WORLD_SIZE"]) 60 | 61 | 62 | def is_master() -> bool: 63 | return get_dist_rank() == 0 64 | 65 | 66 | def dist_barrier() -> None: 67 | if is_dist_initialized(): 68 | torch.distributed.barrier() 69 | 70 | 71 | def get_dist_local_rank() -> int: 72 | return int(os.environ["LOCAL_RANK"]) 73 | 74 | 75 | def sync_tensor(tensor: Union[torch.Tensor, float], reduce="mean") -> Union[torch.Tensor, list[torch.Tensor]]: 76 | if not is_dist_initialized(): 77 | return tensor 78 | if not isinstance(tensor, torch.Tensor): 79 | tensor = torch.Tensor(1).fill_(tensor).cuda() 80 | tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())] 81 | torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False) 82 | if reduce == "mean": 83 | return list_mean(tensor_list) 84 | elif reduce == "sum": 85 | return list_sum(tensor_list) 86 | elif reduce == "cat": 87 | return torch.cat(tensor_list, dim=0) 88 | elif reduce == "root": 89 | return tensor_list[0] 90 | else: 91 | return tensor_list 92 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/ema.py: -------------------------------------------------------------------------------- 1 
| # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import copy 18 | import math 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | from ...models.utils import is_parallel 24 | 25 | __all__ = ["EMA"] 26 | 27 | 28 | def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None: 29 | for k, v in ema.state_dict().items(): 30 | if v.dtype.is_floating_point: 31 | v -= (1.0 - decay) * (v - new_state_dict[k].detach()) 32 | 33 | 34 | class EMA: 35 | def __init__(self, model: nn.Module, decay: float, warmup_steps=2000): 36 | self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval() 37 | self.decay = decay 38 | self.warmup_steps = warmup_steps 39 | 40 | for p in self.shadows.parameters(): 41 | p.requires_grad = False 42 | 43 | def step(self, model: nn.Module, global_step: int) -> None: 44 | with torch.no_grad(): 45 | msd = (model.module if is_parallel(model) else model).state_dict() 46 | update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps))) 47 | 48 | def state_dict(self) -> dict[float, dict[str, torch.Tensor]]: 49 | return {self.decay: self.shadows.state_dict()} 50 | 51 | def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None: 52 | for decay in state_dict: 53 | if decay == self.decay: 54 | self.shadows.load_state_dict(state_dict[decay]) 55 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/export.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import io 18 | import os 19 | from typing import Any 20 | 21 | import onnx 22 | import torch 23 | import torch.nn as nn 24 | from onnxsim import simplify as simplify_func 25 | 26 | __all__ = ["export_onnx"] 27 | 28 | 29 | def export_onnx(model: nn.Module, export_path: str, sample_inputs: Any, simplify=True, opset=11) -> None: 30 | """Export a model to a platform-specific onnx format. 31 | 32 | Args: 33 | model: a torch.nn.Module object. 34 | export_path: export location. 35 | sample_inputs: Any. 
36 | simplify: if True, run onnx-simplifier on the exported graph. 37 | opset: ONNX opset version used for the export. 38 | """ 39 | model.eval() 40 | 41 | buffer = io.BytesIO() 42 | with torch.no_grad(): 43 | torch.onnx.export(model, sample_inputs, buffer, opset_version=opset) 44 | buffer.seek(0, 0) 45 | if simplify: 46 | onnx_model = onnx.load_model(buffer) 47 | onnx_model, success = simplify_func(onnx_model) 48 | assert success 49 | new_buffer = io.BytesIO() 50 | onnx.save(onnx_model, new_buffer) 51 | buffer = new_buffer 52 | buffer.seek(0, 0) 53 | 54 | if buffer.getbuffer().nbytes > 0: 55 | save_dir = os.path.dirname(export_path) 56 | os.makedirs(save_dir, exist_ok=True) 57 | with open(export_path, "wb") as f: 58 | f.write(buffer.read()) 59 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Union 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn.modules.batchnorm import _BatchNorm 22 | 23 | __all__ = ["init_modules", "zero_last_gamma"] 24 | 25 | 26 | def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_normal") -> None: 27 | _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02} 28 | 29 | if isinstance(model, list): 30 | for sub_module in model: 31 | init_modules(sub_module, init_type) 32 | else: 33 | init_params = init_type.split("@") 34 | init_params = float(init_params[1]) if len(init_params) > 1 else None 35 | 36 | if init_type.startswith("trunc_normal"): 37 | init_func = lambda param: nn.init.trunc_normal_( 38 | param, std=(_DEFAULT_INIT_PARAM["trunc_normal"] if init_params is None else init_params) 39 | ) 40 | else: 41 | raise NotImplementedError 42 | 43 | for m in model.modules(): 44 | if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)): 45 | init_func(m.weight) 46 | if m.bias is not None: 47 | m.bias.data.zero_() 48 | elif isinstance(m, nn.Embedding): 49 | init_func(m.weight) 50 | elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): 51 | m.weight.data.fill_(1) 52 | m.bias.data.zero_() 53 | else: 54 | weight = getattr(m, "weight", None) 55 | bias = getattr(m, "bias", None) 56 | if isinstance(weight, torch.nn.Parameter): 57 | init_func(weight) 58 | if isinstance(bias, torch.nn.Parameter): 59 | bias.data.zero_() 60 | 61 | 62 | def zero_last_gamma(model: nn.Module, init_val=0) -> None: 63 | from ...models.nn import ops  # relative import; this efficientvit copy is vendored under diffusion.model.dc_ae 64 | 65 | for m in model.modules(): 66 | if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer): 67 | if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)): 68 | parent_module = m.main.point_conv 69 | elif isinstance(m.main, ops.ResBlock): 70 | parent_module = m.main.conv2 71 | elif isinstance(m.main, ops.ConvLayer): 72 | parent_module = m.main 73 | elif
isinstance(m.main, (ops.LiteMLA)): 74 | parent_module = m.main.proj 75 | else: 76 | parent_module = None 77 | if parent_module is not None: 78 | norm = getattr(parent_module, "norm", None) 79 | if norm is not None: 80 | nn.init.constant_(norm.weight, init_val) 81 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/lr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import math 18 | from typing import Union 19 | 20 | import torch 21 | 22 | from ...models.utils.list import val2list 23 | 24 | __all__ = ["CosineLRwithWarmup", "ConstantLRwithWarmup"] 25 | 26 | 27 | class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler): 28 | def __init__( 29 | self, 30 | optimizer: torch.optim.Optimizer, 31 | warmup_steps: int, 32 | warmup_lr: float, 33 | decay_steps: Union[int, list[int]], 34 | last_epoch: int = -1, 35 | ) -> None: 36 | self.warmup_steps = warmup_steps 37 | self.warmup_lr = warmup_lr 38 | self.decay_steps = val2list(decay_steps) 39 | super().__init__(optimizer, last_epoch) 40 | 41 | def get_lr(self) -> list[float]: 42 | if self.last_epoch < self.warmup_steps: 43 | return [ 44 | (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr 45 | for base_lr in self.base_lrs 46 | ] 47 | else: 48 | current_steps = self.last_epoch - self.warmup_steps 49 | decay_steps = [0] + self.decay_steps 50 | idx = len(decay_steps) - 2 51 | for i, decay_step in enumerate(decay_steps[:-1]): 52 | if decay_step <= current_steps < decay_steps[i + 1]: 53 | idx = i 54 | break 55 | current_steps -= decay_steps[idx] 56 | decay_step = decay_steps[idx + 1] - decay_steps[idx] 57 | return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs] 58 | 59 | 60 | class ConstantLRwithWarmup(torch.optim.lr_scheduler._LRScheduler): 61 | def __init__( 62 | self, 63 | optimizer: torch.optim.Optimizer, 64 | warmup_steps: int, 65 | warmup_lr: float, 66 | last_epoch: int = -1, 67 | ) -> None: 68 | self.warmup_steps = warmup_steps 69 | self.warmup_lr = warmup_lr 70 | super().__init__(optimizer, last_epoch) 71 | 72 | def get_lr(self) -> list[float]: 73 | if self.last_epoch < self.warmup_steps: 74 | return [ 75 | (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr 76 | for base_lr in self.base_lrs 77 | ] 78 | else: 79 | return self.base_lrs 80 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Union 18 | 19 | import torch 20 | 21 | from ...apps.utils.dist import sync_tensor 22 | 23 | __all__ = ["AverageMeter"] 24 | 25 | 26 | class AverageMeter: 27 | """Computes and stores the average and current value.""" 28 | 29 | def __init__(self, is_distributed=True): 30 | self.is_distributed = is_distributed 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def _sync(self, val: Union[torch.Tensor, int, float]) -> Union[torch.Tensor, int, float]: 35 | return sync_tensor(val, reduce="sum") if self.is_distributed else val 36 | 37 | def update(self, val: Union[torch.Tensor, int, float], delta_n=1): 38 | self.count += self._sync(delta_n) 39 | self.sum += self._sync(val * delta_n) 40 | 41 | def get_count(self) -> Union[torch.Tensor, int, float]: 42 | return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count 43 | 44 | @property 45 | def avg(self): 46 | avg = -1 if self.count == 0 else self.sum / self.count 47 | return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg 48 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/apps/utils/opt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Any, Optional 18 | 19 | import torch 20 | 21 | __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"] 22 | 23 | # register optimizer here 24 | # name: optimizer, kwargs with default values 25 | REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, Any]]] = { 26 | "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}), 27 | "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), 28 | "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), 29 | } 30 | 31 | 32 | def build_optimizer( 33 | net_params, optimizer_name: str, optimizer_params: Optional[dict], init_lr: float 34 | ) -> torch.optim.Optimizer: 35 | optimizer_class, registered_defaults = REGISTERED_OPTIMIZER_DICT[optimizer_name] 36 | optimizer_params = {} if optimizer_params is None else optimizer_params 37 | default_params = dict(registered_defaults)  # copy so user overrides do not mutate the global registry 38 | for key in default_params: 39 | if key in optimizer_params: 40 | default_params[key] = optimizer_params[key] 41 | optimizer = optimizer_class(net_params, init_lr, **default_params) 42 | return optimizer 43 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/diffusion/model/dc_ae/efficientvit/models/__init__.py -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/efficientvit/__init__.py: -------------------------------------------------------------------------------- 1 | from .dc_ae import * 2 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .act import * 2 | from .drop import * 3 | from .norm import * 4 | from .ops import * 5 | from .triton_rms_norm import * 6 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/nn/act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from functools import partial 18 | from typing import Optional 19 | 20 | import torch.nn as nn 21 | 22 | from ...models.utils import build_kwargs_from_config 23 | 24 | __all__ = ["build_act"] 25 | 26 | 27 | # register activation function here 28 | REGISTERED_ACT_DICT: dict[str, type] = { 29 | "relu": nn.ReLU, 30 | "relu6": nn.ReLU6, 31 | "hswish": nn.Hardswish, 32 | "silu": nn.SiLU, 33 | "gelu": partial(nn.GELU, approximate="tanh"), 34 | } 35 | 36 | 37 | def build_act(name: str, **kwargs) -> Optional[nn.Module]: 38 | if name in REGISTERED_ACT_DICT: 39 | act_cls = REGISTERED_ACT_DICT[name] 40 | args = build_kwargs_from_config(kwargs, act_cls) 41 | return act_cls(**args) 42 | else: 43 | return None 44 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .list import * 2 | from .network import * 3 | from .random import * 4 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/utils/list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Any, Optional, Union 18 | 19 | __all__ = [ 20 | "list_sum", 21 | "list_mean", 22 | "weighted_list_sum", 23 | "list_join", 24 | "val2list", 25 | "val2tuple", 26 | "squeeze_list", 27 | ] 28 | 29 | 30 | def list_sum(x: list) -> Any: 31 | return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) 32 | 33 | 34 | def list_mean(x: list) -> Any: 35 | return list_sum(x) / len(x) 36 | 37 | 38 | def weighted_list_sum(x: list, weights: list) -> Any: 39 | assert len(x) == len(weights) 40 | return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:]) 41 | 42 | 43 | def list_join(x: list, sep="\t", format_str="%s") -> str: 44 | return sep.join([format_str % val for val in x]) 45 | 46 | 47 | def val2list(x: Union[list, tuple, Any], repeat_time=1) -> list: 48 | if isinstance(x, (list, tuple)): 49 | return list(x) 50 | return [x for _ in range(repeat_time)] 51 | 52 | 53 | def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple: 54 | x = val2list(x) 55 | 56 | # repeat elements if necessary 57 | if len(x) > 0: 58 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 59 | 60 | return tuple(x) 61 | 62 | 63 | def squeeze_list(x: Optional[list]) -> Union[list, Any]: 64 | if x is not None and len(x) == 1: 65 | return x[0] 66 | else: 67 | return x 68 | -------------------------------------------------------------------------------- /diffusion/model/dc_ae/efficientvit/models/utils/random.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Any, Optional, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | __all__ = [ 23 | "torch_randint", 24 | "torch_random", 25 | "torch_shuffle", 26 | "torch_uniform", 27 | "torch_random_choices", 28 | ] 29 | 30 | 31 | def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int: 32 | """uniform: [low, high)""" 33 | if low == high: 34 | return low 35 | else: 36 | assert low < high 37 | return int(torch.randint(low=low, high=high, generator=generator, size=(1,))) 38 | 39 | 40 | def torch_random(generator: Optional[torch.Generator] = None) -> float: 41 | """uniform distribution on the interval [0, 1)""" 42 | return float(torch.rand(1, generator=generator)) 43 | 44 | 45 | def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]: 46 | rand_indexes = torch.randperm(len(src_list), generator=generator).tolist() 47 | return [src_list[i] for i in rand_indexes] 48 | 49 | 50 | def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float: 51 | """uniform distribution on the interval [low, high)""" 52 | rand_val = torch_random(generator) 53 | return (high - low) * rand_val + low 54 | 55 | 56 | def torch_random_choices( 57 | src_list: list[Any], 58 | generator: Optional[torch.Generator] = None, 59 | k=1, 60 | weight_list: Optional[list[float]] = None, 61 | ) -> Union[Any, list]: 62 | if weight_list is None: 63 | rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,)) 64 | out_list = [src_list[i] for i in rand_idx] 65 | else: 66 | assert len(weight_list) == len(src_list) 67 | accumulate_weight_list = np.cumsum(weight_list) 68 | 69 | out_list = [] 70 | for _ in range(k): 71 | val = torch_uniform(0, accumulate_weight_list[-1], generator) 72 | active_id = 0 73 | for i, weight_val in enumerate(accumulate_weight_list): 74 | active_id = i 75 | if weight_val > val: 76 | break 77 | out_list.append(src_list[active_id]) 78 | 79 | return out_list[0] if k == 1 else out_list 80 | -------------------------------------------------------------------------------- /diffusion/model/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .sana import ( 2 | Sana, 3 | SanaBlock, 4 | get_1d_sincos_pos_embed_from_grid, 5 | get_2d_sincos_pos_embed, 6 | get_2d_sincos_pos_embed_from_grid, 7 | ) 8 | from .sana_multi_scale import ( 9 | SanaMS, 10 | SanaMS_600M_P1_D28, 11 | SanaMS_600M_P2_D28, 12 | SanaMS_600M_P4_D28, 13 | SanaMS_1600M_P1_D20, 14 | SanaMS_1600M_P2_D20, 15 | SanaMSBlock, 16 | ) 17 | from .sana_multi_scale_adaln import ( 18 | SanaMSAdaLN, 19 | SanaMSAdaLN_600M_P1_D28, 20 | SanaMSAdaLN_600M_P2_D28, 21 | SanaMSAdaLN_600M_P4_D28, 22 | SanaMSAdaLN_1600M_P1_D20, 23 | SanaMSAdaLN_1600M_P2_D20, 24 | SanaMSAdaLNBlock, 25 | ) 26 | from .sana_multi_scale_controlnet import SanaMSControlNet_600M_P1_D28 27 | from .sana_U_shape import ( 28 | SanaU, 29 | SanaU_600M_P1_D28, 30 | SanaU_600M_P2_D28, 31 | SanaU_600M_P4_D28, 32 | SanaU_1600M_P1_D20, 33 | SanaU_1600M_P2_D20, 34 | SanaUBlock, 35 | ) 36 | from .sana_U_shape_multi_scale import ( 37 | SanaUMS, 38 | SanaUMS_600M_P1_D28, 39 | SanaUMS_600M_P2_D28, 40 | SanaUMS_600M_P4_D28, 41 | SanaUMS_1600M_P1_D20, 42 | SanaUMS_1600M_P2_D20, 43 | SanaUMSBlock, 44 | ) 45 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from .triton_lite_mla import * 18 | from .triton_lite_mla_fwd import * 19 | from .triton_mb_conv_pre_glu import * 20 | 21 | # from .flash_attn import * 22 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/flash_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | from flash_attn import flash_attn_func 19 | from torch import nn 20 | from torch.nn import functional as F 21 | 22 | 23 | class FlashAttention(nn.Module): 24 | def __init__(self, dim: int, num_heads: int): 25 | super().__init__() 26 | self.dim = dim 27 | assert dim % num_heads == 0 28 | self.num_heads = num_heads 29 | self.head_dim = dim // num_heads 30 | 31 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 32 | self.proj_out = torch.nn.Linear(dim, dim) 33 | 34 | def forward(self, x): 35 | B, N, C = x.shape 36 | qkv = self.qkv(x).view(B, N, 3, C) # B, N, 3, C 37 | q, k, v = qkv.unbind(2) # B, N, C 38 | k = k.reshape(B, N, self.num_heads, self.head_dim) 39 | v = v.reshape(B, N, self.num_heads, self.head_dim) 40 | q = q.reshape(B, N, self.num_heads, self.head_dim) 41 | out = flash_attn_func(q, k, v) # B, N, H, c 42 | out = self.proj_out(out.view(B, N, C)) # B, N, C 43 | return out 44 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/mb_conv_pre_glu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from .nn.act import build_act, get_act_name 21 | from .nn.conv import ConvLayer 22 | from .nn.norm import build_norm, get_norm_name 23 | from .utils.model import get_same_padding, val2tuple 24 | 25 | 26 | class MBConvPreGLU(nn.Module): 27 | def __init__( 28 | self, 29 | in_dim: int, 30 | out_dim: int, 31 | kernel_size=3, 32 | stride=1, 33 | mid_dim=None, 34 | expand=6, 35 | padding: int or None = None, 36 | use_bias=False, 37 | norm=(None, None, "ln2d"), 38 | act=("silu", "silu", None), 39 | ): 40 | super().__init__() 41 | use_bias = val2tuple(use_bias, 3) 42 | norm = val2tuple(norm, 3) 43 | act = val2tuple(act, 3) 44 | 45 | mid_dim = mid_dim or round(in_dim * expand) 46 | 47 | self.inverted_conv = ConvLayer( 48 | in_dim, 49 | mid_dim * 2, 50 | 1, 51 | use_bias=use_bias[0], 52 | norm=norm[0], 53 | act=None, 54 | ) 55 | self.glu_act = build_act(act[0], inplace=False) 56 | self.depth_conv = ConvLayer( 57 | mid_dim, 58 | mid_dim, 59 | kernel_size, 60 | stride=stride, 61 | groups=mid_dim, 62 | padding=padding, 63 | use_bias=use_bias[1], 64 | norm=norm[1], 65 | act=act[1], 66 | ) 67 | self.point_conv = ConvLayer( 68 | mid_dim, 69 | out_dim, 70 | 1, 71 | use_bias=use_bias[2], 72 | norm=norm[2], 73 | act=act[2], 74 | ) 75 | 76 | def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor: 77 | B, N, C = x.shape 78 | if HW is None: 79 | H = W = int(N**0.5) 80 | else: 81 | H, W = HW 82 | 83 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) 84 | 85 | x = self.inverted_conv(x) 86 | x, gate = torch.chunk(x, 2, dim=1) 87 | gate = self.glu_act(gate) 88 | x = x * gate 89 | 90 | x = self.depth_conv(x) 91 | x = self.point_conv(x) 92 | 93 | x = x.reshape(B, C, N).permute(0, 2, 1) 94 | return x 95 | 96 | @property 97 | def module_str(self) -> str: 98 | _str = f"{self.depth_conv.kernel_size}{type(self).__name__}(" 99 | _str += f"in={self.inverted_conv.in_dim},mid={self.depth_conv.in_dim},out={self.point_conv.out_dim},s={self.depth_conv.stride}" 100 | _str += ( 101 | f",norm={get_norm_name(self.inverted_conv.norm)}" 102 | f"+{get_norm_name(self.depth_conv.norm)}" 103 | f"+{get_norm_name(self.point_conv.norm)}" 104 | ) 105 | _str += ( 106 | f",act={get_act_name(self.inverted_conv.act)}" 107 | f"+{get_act_name(self.depth_conv.act)}" 108 | f"+{get_act_name(self.point_conv.act)}" 109 | ) 110 | _str += f",glu_act={get_act_name(self.glu_act)})" 111 | return _str 112 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/nn/act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import copy 18 | from typing import Any, Optional 19 | import torch.nn as nn 20 | 21 | __all__ = ["build_act", "get_act_name"] 22 | 23 | # register activation function here 24 | # name: module, kwargs with default values 25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, Any]]] = { 26 | "relu": (nn.ReLU, {"inplace": True}), 27 | "relu6": (nn.ReLU6, {"inplace": True}), 28 | "hswish": (nn.Hardswish, {"inplace": True}), 29 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}), 30 | "swish": (nn.SiLU, {"inplace": True}), 31 | "silu": (nn.SiLU, {"inplace": True}), 32 | "tanh": (nn.Tanh, {}), 33 | "sigmoid": (nn.Sigmoid, {}), 34 | "gelu": (nn.GELU, {"approximate": "tanh"}), 35 | "mish": (nn.Mish, {"inplace": True}), 36 | "identity": (nn.Identity, {}), 37 | } 38 | 39 | 40 | def build_act(name: Optional[str], **kwargs) -> Optional[nn.Module]: 41 | if name in REGISTERED_ACT_DICT: 42 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name]) 43 | for key in default_args: 44 | if key in kwargs: 45 | default_args[key] = kwargs[key] 46 | return act_cls(**default_args) 47 | elif name is None or name.lower() == "none": 48 | return None 49 | else: 50 | raise ValueError(f"activation function {name} is not supported") 51 | 52 | 53 | def get_act_name(act: Optional[nn.Module]) -> Optional[str]: 54 | if act is None: 55 | return None 56 | module2name = {} 57 | for key, config in REGISTERED_ACT_DICT.items(): 58 | module2name[config[0].__name__] = key 59 | return module2name.get(type(act).__name__, "unknown") 60 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/nn/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from ..utils.model import get_same_padding 21 | from .act import build_act, get_act_name 22 | from .norm import build_norm, get_norm_name 23 | 24 | 25 | class ConvLayer(nn.Module): 26 | def __init__( 27 | self, 28 | in_dim: int, 29 | out_dim: int, 30 | kernel_size=3, 31 | stride=1, 32 | dilation=1, 33 | groups=1, 34 | padding: int or None = None, 35 | use_bias=False, 36 | dropout=0.0, 37 | norm="bn2d", 38 | act="relu", 39 | ): 40 | super().__init__() 41 | if padding is None: 42 | padding = get_same_padding(kernel_size) 43 | padding *= dilation 44 | 45 | self.in_dim = in_dim 46 | self.out_dim = out_dim 47 | self.kernel_size = kernel_size 48 | self.stride = stride 49 | self.dilation = dilation 50 | self.groups = groups 51 | self.padding = padding 52 | self.use_bias = use_bias 53 | 54 | self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None 55 | self.conv = nn.Conv2d( 56 | in_dim, 57 | out_dim, 58 | kernel_size=(kernel_size, kernel_size), 59 | stride=(stride, stride), 60 | padding=padding, 61 | dilation=(dilation, dilation), 62 | groups=groups, 63 | bias=use_bias, 64 | ) 65 | self.norm = build_norm(norm, num_features=out_dim) 66 | self.act = build_act(act) 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | if self.dropout is not None: 70 | x = self.dropout(x) 71 | x = self.conv(x) 72 | if self.norm: 73 | x = self.norm(x) 74 | if self.act: 75 | x = self.act(x) 76 | return x 77 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/utils/compare_results.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | 19 | 20 | def compare_results(name: str, result: torch.Tensor, ref_result: torch.Tensor): 21 | print(f"comparing {name}") 22 | diff = (result - ref_result).abs().view(-1) 23 | max_error_pos = diff.argmax() 24 | print(f"max error: {diff.max()}, mean error: {diff.mean()}") 25 | print(f"max error pos: {result.view(-1)[max_error_pos]} {ref_result.view(-1)[max_error_pos]}") 26 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | import triton 19 | import triton.language as tl 20 | 21 | 22 | def get_dtype_from_str(dtype: str) -> torch.dtype: 23 | if dtype == "fp32": 24 | return torch.float32 25 | if dtype == "fp16": 26 | return torch.float16 27 | if dtype == "bf16": 28 | return torch.bfloat16 29 | raise NotImplementedError(f"dtype {dtype} is not supported") 30 | 31 | 32 | def get_tl_dtype_from_torch_dtype(dtype: torch.dtype) -> tl.dtype: 33 | if dtype == torch.float32: 34 | return tl.float32 35 | if dtype == torch.float16: 36 | return tl.float16 37 | if dtype == torch.bfloat16: 38 | return tl.bfloat16 39 | raise NotImplementedError(f"dtype {dtype} is not supported") 40 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/utils/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import os 18 | import warnings 19 | from typing import Any, Tuple 20 | 21 | import torch 22 | 23 | 24 | def export_onnx( 25 | model: torch.nn.Module, 26 | input_shape: Tuple[int], 27 | export_path: str, 28 | opset: int, 29 | export_dtype: torch.dtype, 30 | export_device: torch.device, 31 | ) -> None: 32 | model.eval() 33 | 34 | dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)} 35 | dynamic_axes = { 36 | "x": {0: "batch_size"}, 37 | } 38 | 39 | # _ = model(**dummy_input) 40 | 41 | output_names = ["image_embeddings"] 42 | 43 | export_dir = os.path.dirname(export_path) 44 | if not os.path.exists(export_dir): 45 | os.makedirs(export_dir) 46 | 47 | with warnings.catch_warnings(): 48 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 49 | warnings.filterwarnings("ignore", category=UserWarning) 50 | print(f"Exporting onnx model to {export_path}...") 51 | with open(export_path, "wb") as f: 52 | torch.onnx.export( 53 | model, 54 | tuple(dummy_input.values()), 55 | f, 56 | export_params=True, 57 | verbose=False, 58 | opset_version=opset, 59 | do_constant_folding=True, 60 | input_names=list(dummy_input.keys()), 61 | output_names=output_names, 62 | dynamic_axes=dynamic_axes, 63 | ) 64 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/modules/utils/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 MIT Han Lab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | 18 | def val2list(x: list or tuple or any, repeat_time=1) -> list:  # type: ignore 19 | """Return `x` as a list: lists/tuples are converted directly, other values are repeated `repeat_time` times.""" 20 | if isinstance(x, (list, tuple)): 21 | return list(x) 22 | return [x for _ in range(repeat_time)] 23 | 24 | 25 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:  # type: ignore 26 | """Return `x` as a tuple padded to at least `min_len` elements by repeating the element at `idx_repeat`.""" 27 | # convert to list first 28 | x = val2list(x) 29 | 30 | # repeat elements if necessary 31 | if len(x) > 0: 32 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 33 | 34 | return tuple(x) 35 | 36 | 37 | def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]: 38 | if isinstance(kernel_size, tuple): 39 | return tuple([get_same_padding(ks) for ks in kernel_size]) 40 | else: 41 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be an odd number" 42 | return kernel_size // 2 43 | -------------------------------------------------------------------------------- /diffusion/model/nets/fastlinear/readme.md: -------------------------------------------------------------------------------- 1 | # A fast implementation of linear attention 2 | 3 | ## 64x64, fp16 4 | 5 | ```bash 6 | # validate correctness 7 | ## fp16 vs fp32 8 | python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True 9 | ## triton fp16 vs fp32 10 | python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True 11 | 12 | # test performance 13 | ## fp16, forward 14 | python -m develop_triton_litemla attn_type=LiteMLA 15 | # -> each step takes 10.81 ms 16 | # -> max memory allocated: 2.2984 GB 17 | 18 | ## triton fp16, forward 19 | python -m develop_triton_litemla attn_type=TritonLiteMLA 20 | # -> each step takes 4.70 ms 21 | # -> max memory allocated: 1.6480 GB 22 | 23 | ## fp16, backward 24 | python -m develop_triton_litemla attn_type=LiteMLA backward=True 25 | # -> each step takes 35.34 ms 26 | # -> max memory allocated: 3.4412 GB 27 | 28 | ## triton fp16, backward 29 | python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True 30 | # -> each step takes 14.25 ms 31 | # -> max memory allocated: 2.4704 GB 32 | ``` 33 | -------------------------------------------------------------------------------- /diffusion/model/nets/sana_ladd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from diffusion.model.builder import MODELS 21 | 22 | from .ladd_blocks import DiscHead 23 | from .sana_multi_scale import SanaMSCM 24 | 25 | 26 | @MODELS.register_module() 27 | class SanaMSCMDiscriminator(nn.Module): 28 | def __init__(self, pretrained_model: SanaMSCM, is_multiscale=False, head_block_ids=None): 29 | super().__init__() 30 | self.transformer = pretrained_model 31 | self.transformer.requires_grad_(False) 32 | 33 | if head_block_ids is None or len(head_block_ids) == 0: 34 | self.block_hooks = {2, 8, 14, 20, 27} if is_multiscale else {self.transformer.depth - 1} 35 | else: 36 | self.block_hooks = head_block_ids 37 | 38 | heads = [] 39 | for i in range(len(self.block_hooks)): 40 | heads.append(DiscHead(self.transformer.hidden_size, 0, 0)) 41 | self.heads = nn.ModuleList(heads) 42 | 43 | def get_head_inputs(self): 44 | return self.head_inputs 45 | 46 | def forward(self, x, timestep, y=None, data_info=None, mask=None, **kwargs): 47 | feat_list = [] 48 | self.head_inputs = [] 49 | 50 | def get_features(module, input, output): 51 | feat_list.append(output) 52 | return output 53 | 54 | hooks = [] 55 | for i, block in enumerate(self.transformer.blocks): 56 | if i in self.block_hooks: 57 | hooks.append(block.register_forward_hook(get_features)) 58 | 59 | self.transformer(x, timestep, y=y, mask=mask, data_info=data_info, return_logvar=False, **kwargs) 60 | 61 | for hook in hooks: 62 | hook.remove() 63 | 64 | res_list = [] 65 | for feat, head in zip(feat_list, self.heads): 66 | B, N, C = feat.shape 67 | feat = feat.transpose(1, 2) # [B, C, N] 68 | self.head_inputs.append(feat) 69 | res_list.append(head(feat, None).reshape(feat.shape[0], -1)) 70 | 71 | concat_res = torch.cat(res_list, dim=1) 72 | 73 | return concat_res 74 | 75 | @property 76 | def model(self): 77 | return self.transformer 78 | 79 | def save_pretrained(self, path): 80 | torch.save(self.state_dict(), path) 81 | 82 | 83 | class DiscHeadModel: 84 | def __init__(self, disc): 85 | self.disc = disc 86 | 87 | def state_dict(self): 88 | return {name: param for name, param in self.disc.state_dict().items() if not name.startswith("transformer.")} 89 | 90 | def __getattr__(self, name): 91 | return getattr(self.disc, name) 92 | -------------------------------------------------------------------------------- /diffusion/scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/diffusion/scheduler/__init__.py -------------------------------------------------------------------------------- /diffusion/scheduler/dpm_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | 19 | from diffusion.model import gaussian_diffusion as gd 20 | from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper 21 | 22 | 23 | def DPMS( 24 | model, 25 | condition, 26 | uncondition, 27 | cfg_scale, 28 | pag_scale=1.0, 29 | pag_applied_layers=None, 30 | model_type="noise", # or "x_start" or "v" or "score", "flow" 31 | noise_schedule="linear", 32 | guidance_type="classifier-free", 33 | model_kwargs=None, 34 | diffusion_steps=1000, 35 | schedule="VP", 36 | interval_guidance=None, 37 | ): 38 | if pag_applied_layers is None: 39 | pag_applied_layers = [] 40 | if model_kwargs is None: 41 | model_kwargs = {} 42 | if interval_guidance is None: 43 | interval_guidance = [0, 1.0] 44 | betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) 45 | 46 | ## 1. Define the noise schedule. 47 | if schedule == "VP": 48 | noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas) 49 | elif schedule == "FLOW": 50 | noise_schedule = NoiseScheduleFlow(schedule="discrete_flow") 51 | 52 | ## 2. Convert your discrete-time `model` to the continuous-time 53 | ## noise prediction model. Here is an example for a diffusion model 54 | ## `model` with the noise prediction type ("noise") . 55 | model_fn = model_wrapper( 56 | model, 57 | noise_schedule, 58 | model_type=model_type, 59 | model_kwargs=model_kwargs, 60 | guidance_type=guidance_type, 61 | pag_scale=pag_scale, 62 | pag_applied_layers=pag_applied_layers, 63 | condition=condition, 64 | unconditional_condition=uncondition, 65 | guidance_scale=cfg_scale, 66 | interval_guidance=interval_guidance, 67 | ) 68 | ## 3. Define dpm-solver and sample by multistep DPM-Solver. 69 | return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") 70 | -------------------------------------------------------------------------------- /diffusion/scheduler/flow_euler_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import os 18 | 19 | import torch 20 | from diffusers import FlowMatchEulerDiscreteScheduler 21 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 22 | from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps 23 | from tqdm import tqdm 24 | 25 | 26 | class FlowEuler: 27 | def __init__(self, model_fn, condition, uncondition, cfg_scale, model_kwargs): 28 | self.model = model_fn 29 | self.condition = condition 30 | self.uncondition = uncondition 31 | self.cfg_scale = cfg_scale 32 | self.model_kwargs = model_kwargs 33 | # repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" 34 | self.scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) 35 | 36 | def sample(self, latents, steps=28): 37 | device = self.condition.device 38 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, steps, device, None) 39 | do_classifier_free_guidance = True 40 | 41 | prompt_embeds = self.condition 42 | if do_classifier_free_guidance: 43 | prompt_embeds = torch.cat([self.uncondition, self.condition], dim=0) 44 | 45 | for i, t in tqdm(list(enumerate(timesteps)), disable=os.getenv("DPM_TQDM", "False") == "True"): 46 | 47 | # expand the latents if we are doing classifier free guidance 48 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 49 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 50 | timestep = t.expand(latent_model_input.shape[0]) 51 | 52 | noise_pred = self.model( 53 | latent_model_input, 54 | timestep, 55 | prompt_embeds, 56 | **self.model_kwargs, 57 | ) 58 | 59 | if isinstance(noise_pred, Transformer2DModelOutput): 60 | noise_pred = noise_pred[0] 61 | 62 | # perform guidance 63 | if do_classifier_free_guidance: 64 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 65 | noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) 66 | 67 | # compute the previous noisy sample x_t -> x_t-1 68 | latents_dtype = latents.dtype 69 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 70 | 71 | if latents.dtype != latents_dtype: 72 | latents = latents.to(latents_dtype) 73 | 74 | return latents 75 | -------------------------------------------------------------------------------- /diffusion/scheduler/iddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # Modified from OpenAI's diffusion repos 18 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 19 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 20 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 21 | 22 | from diffusion.model import gaussian_diffusion as gd 23 | from diffusion.model.respace import SpacedDiffusion, space_timesteps 24 | 25 | 26 | def Scheduler( 27 | timestep_respacing, 28 | noise_schedule="linear", 29 | use_kl=False, 30 | sigma_small=False, 31 | predict_xstart=False, 32 | predict_flow_v=False, 33 | learn_sigma=True, 34 | pred_sigma=True, 35 | rescale_learned_sigmas=False, 36 | diffusion_steps=1000, 37 | snr=False, 38 | return_startx=False, 39 | flow_shift=1.0, 40 | ): 41 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 42 | if use_kl: 43 | loss_type = gd.LossType.RESCALED_KL 44 | elif rescale_learned_sigmas: 45 | loss_type = gd.LossType.RESCALED_MSE 46 | else: 47 | loss_type = gd.LossType.MSE 48 | if timestep_respacing is None or timestep_respacing == "": 49 | timestep_respacing = [diffusion_steps] 50 | if predict_xstart: 51 | model_mean_type = gd.ModelMeanType.START_X 52 | elif predict_flow_v: 53 | model_mean_type = gd.ModelMeanType.FLOW_VELOCITY 54 | else: 55 | model_mean_type = gd.ModelMeanType.EPSILON 56 | return SpacedDiffusion( 57 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 58 | betas=betas, 59 | model_mean_type=model_mean_type, 60 | model_var_type=( 61 | ( 62 | (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) 63 | if not learn_sigma 64 | else gd.ModelVarType.LEARNED_RANGE 65 | ) 66 | if pred_sigma 67 | else None 68 | ), 69 | loss_type=loss_type, 70 | snr=snr, 71 | return_startx=return_startx, 72 | # rescale_timesteps=rescale_timesteps, 73 | flow="flow" in noise_schedule, 74 | flow_shift=flow_shift, 75 | diffusion_steps=diffusion_steps, 76 | ) 77 | -------------------------------------------------------------------------------- /diffusion/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/diffusion/utils/__init__.py -------------------------------------------------------------------------------- /diffusion/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import logging 3 | import warnings 4 | 5 | import importlib_metadata 6 | from packaging import version 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | _xformers_available = importlib.util.find_spec("xformers") is not None 11 | try: 12 | if _xformers_available: 13 | _xformers_version = importlib_metadata.version("xformers") 14 | _torch_version = importlib_metadata.version("torch") 15 | if version.Version(_torch_version) < version.Version("1.12"): 16 | raise ValueError("xformers is installed but requires PyTorch >= 1.12") 17 | logger.debug(f"Successfully imported xformers version {_xformers_version}") 18 | except importlib_metadata.PackageNotFoundError: 19 | _xformers_available = False 20 | 21 | _triton_modules_available = importlib.util.find_spec("triton") is not None 22 | try: 23 | if _triton_modules_available: 24 | _triton_version = importlib_metadata.version("triton") 25 | if version.Version(_triton_version) < 
version.Version("3.0.0"): 26 | raise ValueError("triton is installed but requires Triton >= 3.0.0") 27 | logger.debug(f"Successfully imported triton version {_triton_version}") 28 | except ImportError: 29 | _triton_modules_available = False 30 | warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.") 31 | 32 | 33 | def is_xformers_available(): 34 | return _xformers_available 35 | 36 | 37 | def is_triton_module_available(): 38 | return _triton_modules_available 39 | 40 | 41 | import inspect 42 | import warnings 43 | from typing import Any, Dict, Optional, Union 44 | 45 | from packaging import version 46 | 47 | 48 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): 49 | from .. import __version__ 50 | 51 | deprecated_kwargs = take_from 52 | values = () 53 | if not isinstance(args[0], tuple): 54 | args = (args,) 55 | 56 | for attribute, version_name, message in args: 57 | if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): 58 | raise ValueError( 59 | f"The deprecation tuple {(attribute, version_name, message)} should be removed since sana's" 60 | f" version {__version__} is >= {version_name}" 61 | ) 62 | 63 | warning = None 64 | if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: 65 | values += (deprecated_kwargs.pop(attribute),) 66 | warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." 67 | elif hasattr(deprecated_kwargs, attribute): 68 | values += (getattr(deprecated_kwargs, attribute),) 69 | warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." 70 | elif deprecated_kwargs is None: 71 | warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." 72 | 73 | if warning is not None: 74 | warning = warning + " " if standard_warn else "" 75 | warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel) 76 | 77 | if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: 78 | call_frame = inspect.getouterframes(inspect.currentframe())[1] 79 | filename = call_frame.filename 80 | line_number = call_frame.lineno 81 | function = call_frame.function 82 | key, value = next(iter(deprecated_kwargs.items())) 83 | raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") 84 | 85 | if len(values) == 0: 86 | return 87 | elif len(values) == 1: 88 | return values[0] 89 | return values 90 | -------------------------------------------------------------------------------- /environment_setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | CONDA_ENV=${1:-""} 5 | if [ -n "$CONDA_ENV" ]; then 6 | # This is required to activate conda environment 7 | eval "$(conda shell.bash hook)" 8 | 9 | conda create -n $CONDA_ENV python=3.10.0 -y 10 | conda activate $CONDA_ENV 11 | # This is optional if you prefer to use built-in nvcc 12 | conda install -c nvidia cuda-toolkit=12.4 -y 13 | else 14 | echo "Skipping conda environment creation. Make sure you have the correct environment activated." 15 | fi 16 | 17 | # init a raw torch to avoid installation errors. 18 | # pip install torch 19 | 20 | # update pip to latest version for pyproject.toml setup. 
21 | pip install -U pip 22 | 23 | # for fast attn 24 | pip install -U xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121 25 | 26 | # install sana 27 | pip install -e . 28 | 29 | # install torchprofile 30 | # pip install git+https://github.com/zhijian-liu/torchprofile 31 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sana" 7 | version = "0.0.1" 8 | description = "SANA" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "pre-commit", 17 | "accelerate", 18 | "beautifulsoup4", 19 | "bs4", 20 | "came-pytorch", 21 | "einops", 22 | "ftfy", 23 | "diffusers@git+https://github.com/huggingface/diffusers", 24 | "clip@git+https://github.com/openai/CLIP.git", 25 | "gradio", 26 | "image-reward", 27 | "ipdb", 28 | "mmcv==1.7.2", 29 | "omegaconf", 30 | "opencv-python", 31 | "optimum", 32 | "patch_conv", 33 | "peft", 34 | "protobuf", 35 | "pytorch-fid", 36 | "regex", 37 | "sentencepiece", 38 | "tensorboard", 39 | "tensorboardX", 40 | "timm", 41 | "torchaudio==2.4.0", 42 | "torchvision==0.19", 43 | "transformers", 44 | "triton==3.0.0", 45 | "wandb", 46 | "webdataset", 47 | "xformers==0.0.27.post2", 48 | "yapf", 49 | "spaces", 50 | "matplotlib", 51 | "termcolor", 52 | "pyrallis", 53 | "bitsandbytes", 54 | ] 55 | 56 | 57 | [project.scripts] 58 | sana-run = "sana.cli.run:main" 59 | sana-upload = "sana.cli.upload2hf:main" 60 | 61 | [project.optional-dependencies] 62 | 63 | [project.urls] 64 | 65 | [tool.pip] 66 | extra-index-url = ["https://download.pytorch.org/whl/cu121"] 67 | 68 | [tool.black] 69 | line-length = 120 70 | 71 | [tool.isort] 72 | profile = "black" 73 | multi_line_output = 3 74 | include_trailing_comma = true 75 | force_grid_wrap = 0 76 | use_parentheses = true 77 | ensure_newline_before_comments = true 78 | line_length = 120 79 | 80 | [tool.setuptools.packages.find] 81 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 82 | 83 | [tool.wheel] 84 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 85 | -------------------------------------------------------------------------------- /sana/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import download_model 2 | from .hf_utils import hf_download_or_fpath 3 | -------------------------------------------------------------------------------- /sana/tools/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Functions for downloading pre-trained Sana models 9 | """ 10 | import argparse 11 | import os 12 | 13 | import torch 14 | from torchvision.datasets.utils import download_url 15 | 16 | pretrained_models = {} 17 | 18 | 19 | def find_model(model_name): 20 | """ 21 | Finds a pre-trained Sana model, downloading it if necessary. Alternatively, loads a model from a local path.
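For example, find_model("output/pretrained_models/Sana_1600M_1024px.pth") returns the loaded state dict of a local checkpoint (the path here is illustrative).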
22 | """ 23 | if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints 24 | return download_model(model_name) 25 | else: # Load a custom Sana checkpoint: 26 | assert os.path.isfile(model_name), f"Could not find Sana checkpoint at {model_name}" 27 | return torch.load(model_name, map_location=lambda storage, loc: storage) 28 | 29 | 30 | def download_model(model_name): 31 | """ 32 | Downloads a pre-trained Sana model from the web. 33 | """ 34 | assert model_name in pretrained_models 35 | local_path = f"output/pretrained_models/{model_name}" 36 | if not os.path.isfile(local_path): 37 | hf_endpoint = os.environ.get("HF_ENDPOINT") 38 | if hf_endpoint is None: 39 | hf_endpoint = "https://huggingface.co" 40 | os.makedirs("output/pretrained_models", exist_ok=True) 41 | web_path = f"{hf_endpoint}/xxx/resolve/main/{model_name}" 42 | download_url(web_path, "output/pretrained_models/") 43 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 44 | return model 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--model_names", nargs="+", type=str, default=pretrained_models) 50 | args = parser.parse_args() 51 | model_names = args.model_names 52 | model_names = set(model_names) 53 | 54 | # Download Sana checkpoints 55 | for model in model_names: 56 | download_model(model) 57 | print("Done.") 58 | -------------------------------------------------------------------------------- /sana/tools/hf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import os 18 | import os.path as osp 19 | import sys 20 | 21 | from huggingface_hub import hf_hub_download, snapshot_download 22 | 23 | 24 | def hf_download_or_fpath(path): 25 | if osp.exists(path): 26 | return path 27 | 28 | if path.startswith("hf://"): 29 | segs = path.replace("hf://", "").split("/") 30 | repo_id = "/".join(segs[:2]) 31 | filename = "/".join(segs[2:]) 32 | return hf_download_data(repo_id, filename, repo_type="model", download_full_repo=True) 33 | 34 | 35 | def hf_download_data( 36 | repo_id="Efficient-Large-Model/Sana_1600M_1024px", 37 | filename="checkpoints/Sana_1600M_1024px.pth", 38 | cache_dir=None, 39 | repo_type="model", 40 | download_full_repo=False, 41 | ): 42 | """ 43 | Download dummy data from a Hugging Face repository. 44 | 45 | Args: 46 | repo_id (str): The ID of the Hugging Face repository. 47 | filename (str): The name of the file to download. 48 | cache_dir (str, optional): The directory to cache the downloaded file. 49 | 50 | Returns: 51 | str: The path to the downloaded file. 
52 | """ 53 | try: 54 | if download_full_repo: 55 | # download full repos to fit dc-ae 56 | snapshot_download( 57 | repo_id=repo_id, 58 | cache_dir=cache_dir, 59 | repo_type=repo_type, 60 | ) 61 | file_path = hf_hub_download( 62 | repo_id=repo_id, 63 | filename=filename, 64 | cache_dir=cache_dir, 65 | repo_type=repo_type, 66 | ) 67 | return file_path 68 | except Exception as e: 69 | print(f"Error downloading file: {e}") 70 | return None 71 | -------------------------------------------------------------------------------- /scripts/infer_run_inference_geneval_diffusers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ================= sampler & data ================= 3 | np=8 # number of GPU to use 4 | default_step=20 # 14 5 | default_sample_nums=553 6 | default_sampling_algo="dpm-solver" 7 | default_add_label='' 8 | 9 | # parser 10 | config_file=$1 11 | model_paths=$2 12 | 13 | for arg in "$@" 14 | do 15 | case $arg in 16 | --step=*) 17 | step="${arg#*=}" 18 | shift 19 | ;; 20 | --sampling_algo=*) 21 | sampling_algo="${arg#*=}" 22 | shift 23 | ;; 24 | --add_label=*) 25 | add_label="${arg#*=}" 26 | shift 27 | ;; 28 | --model_path=*) 29 | model_paths="${arg#*=}" 30 | shift 31 | ;; 32 | --exist_time_prefix=*) 33 | exist_time_prefix="${arg#*=}" 34 | shift 35 | ;; 36 | --if_save_dirname=*) 37 | if_save_dirname="${arg#*=}" 38 | shift 39 | ;; 40 | *) 41 | ;; 42 | esac 43 | done 44 | 45 | sample_nums=$default_sample_nums 46 | samples_per_gpu=$((sample_nums / np)) 47 | add_label=${add_label:-$default_add_label} 48 | echo "Sample numbers: $sample_nums" 49 | echo "Add label: $add_label" 50 | echo "Exist time prefix: $exist_time_prefix" 51 | 52 | cmd_template="DPM_TQDM=True python scripts/inference_geneval_diffusers.py \ 53 | --model_path=$model_paths \ 54 | --gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}" 55 | if [ -n "${add_label}" ]; then 56 | cmd_template="${cmd_template} --add_label ${add_label}" 57 | fi 58 | 59 | echo "==================== inferencing ====================" 60 | for gpu_id in $(seq 0 $((np - 1))); do 61 | start_index=$((gpu_id * samples_per_gpu)) 62 | end_index=$((start_index + samples_per_gpu)) 63 | if [ $gpu_id -eq $((np - 1)) ]; then 64 | end_index=$sample_nums 65 | fi 66 | 67 | cmd="${cmd_template//\{config_file\}/$config_file}" 68 | cmd="${cmd//\{model_path\}/$model_paths}" 69 | cmd="${cmd//\{gpu_id\}/$gpu_id}" 70 | cmd="${cmd//\{start_index\}/$start_index}" 71 | cmd="${cmd//\{end_index\}/$end_index}" 72 | 73 | echo "Running on GPU $gpu_id: samples $start_index to $end_index" 74 | eval CUDA_VISIBLE_DEVICES=$gpu_id $cmd & 75 | done 76 | wait 77 | 78 | echo infer finally done 79 | -------------------------------------------------------------------------------- /scripts/style.css: -------------------------------------------------------------------------------- 1 | /*.gradio-container{width:680px!important}*/ 2 | /* style.css */ 3 | .gradio_group, .gradio_row, .gradio_column { 4 | display: flex; 5 | flex-direction: row; 6 | justify-content: flex-start; 7 | align-items: flex-start; 8 | flex-wrap: wrap; 9 | } 10 | -------------------------------------------------------------------------------- /tests/bash/entry.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | set -e 3 | 4 | 5 | echo "Testing inference" 6 | bash tests/bash/test_inference.sh 7 | 8 | echo "Testing training" 9 | bash tests/bash/test_training_1epoch.sh 10 | 
-------------------------------------------------------------------------------- /tests/bash/test_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python scripts/inference.py \ 5 | --config=configs/sana_config/1024ms/Sana_600M_img1024.yaml \ 6 | --model_path=hf://Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth 7 | 8 | 9 | python scripts/inference.py \ 10 | --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \ 11 | --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth 12 | 13 | python tools/controlnet/inference_controlnet.py \ 14 | --config=configs/sana_controlnet_config/Sana_600M_img1024_controlnet.yaml \ 15 | --model_path=hf://Efficient-Large-Model/Sana_600M_1024px_ControlNet_HED/checkpoints/Sana_600M_1024px_ControlNet_HED.pth \ 16 | --json_file=asset/controlnet/samples_controlnet.json 17 | 18 | python scripts/inference_sana_sprint.py \ 19 | --config=configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml \ 20 | --model_path=hf://Lawrence-cj/Sana_Sprint_1600M_1024px/Sana_Sprint_1600M_1024px_36K.pth \ 21 | --txt_file=asset/samples/samples_mini.txt 22 | -------------------------------------------------------------------------------- /tests/bash/test_training_1epoch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | mkdir -p data/data_public 5 | huggingface-cli download Efficient-Large-Model/sana_data_public --repo-type dataset --local-dir ./data/data_public --local-dir-use-symlinks False 6 | huggingface-cli download Efficient-Large-Model/toy_data --repo-type dataset --local-dir ./data/toy_data --local-dir-use-symlinks False 7 | 8 | # test offline VAE features 9 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.load_vae_feat=true 10 | 11 | # test online VAE features 12 | bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false 13 | 14 | # test FSDP training 15 | bash train_scripts/train.sh configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1 16 | 17 | # test SANA-Sprint (sCM + LADD) training 18 | bash train_scripts/train_scm_ladd.sh configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml --data.data_dir="[data/toy_data]" --data.load_vae_feat=true --train.num_epochs=1 --train.log_interval=1 19 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/tools/__init__.py -------------------------------------------------------------------------------- /tools/controlnet/annotator/ckpts/ckpts.txt: -------------------------------------------------------------------------------- 1 | Weights here.
2 | -------------------------------------------------------------------------------- /tools/controlnet/annotator/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts") 8 | 9 | 10 | def HWC3(x): 11 | assert x.dtype == np.uint8 12 | if x.ndim == 2: 13 | x = x[:, :, None] 14 | assert x.ndim == 3 15 | H, W, C = x.shape 16 | assert C == 1 or C == 3 or C == 4 17 | if C == 3: 18 | return x 19 | if C == 1: 20 | return np.concatenate([x, x, x], axis=2) 21 | if C == 4: 22 | color = x[:, :, 0:3].astype(np.float32) 23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 24 | y = color * alpha + 255.0 * (1.0 - alpha) 25 | y = y.clip(0, 255).astype(np.uint8) 26 | return y 27 | 28 | 29 | def resize_image(input_image, resolution): 30 | H, W, C = input_image.shape 31 | H = float(H) 32 | W = float(W) 33 | k = float(resolution) / min(H, W) 34 | H *= k 35 | W *= k 36 | H = int(np.round(H / 64.0)) * 64 37 | W = int(np.round(W / 64.0)) * 64 38 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 39 | return img 40 | 41 | 42 | def nms(x, t, s): 43 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 44 | 45 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 46 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 47 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 48 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 49 | 50 | y = np.zeros_like(x) 51 | 52 | for f in [f1, f2, f3, f4]: 53 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 54 | 55 | z = np.zeros_like(y, dtype=np.uint8) 56 | z[y > t] = 255 57 | return z 58 | 59 | 60 | def make_noise_disk(H, W, C, F): 61 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 62 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 63 | noise = noise[F : F + H, F : F + W] 64 | noise -= np.min(noise) 65 | noise /= np.max(noise) 66 | if C == 1: 67 | noise = noise[:, :, None] 68 | return noise 69 | 70 | 71 | def min_max_norm(x): 72 | x -= np.min(x) 73 | x /= np.maximum(np.max(x), 1e-5) 74 | return x 75 | 76 | 77 | def safe_step(x, step=2): 78 | y = x.astype(np.float32) * float(step + 1) 79 | y = y.astype(np.int32).astype(np.float32) / float(step) 80 | return y 81 | 82 | 83 | def img2mask(img, H, W, low=10, high=90): 84 | assert img.ndim == 3 or img.ndim == 2 85 | assert img.dtype == np.uint8 86 | 87 | if img.ndim == 3: 88 | y = img[:, :, random.randrange(0, img.shape[2])] 89 | else: 90 | y = img 91 | 92 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 93 | 94 | if random.uniform(0, 1) < 0.5: 95 | y = 255 - y 96 | 97 | return y < np.percentile(y, random.randrange(low, high)) 98 | -------------------------------------------------------------------------------- /tools/controlnet/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | from torchvision.transforms.functional import InterpolationMode 8 | 9 | from tools.controlnet.annotator.hed import HEDdetector 10 | from tools.controlnet.annotator.util import HWC3, nms, resize_image 11 | 12 | preprocessor = None 13 | 14 | 15 | def transform_control_signal(control_signal, hw): 16 | if 
isinstance(control_signal, str): 17 | control_signal = Image.open(control_signal) 18 | elif isinstance(control_signal, Image.Image): 19 | control_signal = control_signal 20 | elif isinstance(control_signal, np.ndarray): 21 | control_signal = Image.fromarray(control_signal) 22 | else: 23 | raise ValueError("control_signal must be a path or a PIL.Image.Image or a numpy array") 24 | 25 | transform = T.Compose( 26 | [ 27 | T.Lambda(lambda img: img.convert("RGB")), 28 | T.Resize((int(hw[0, 0]), int(hw[0, 1])), interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC 29 | T.CenterCrop((int(hw[0, 0]), int(hw[0, 1]))), 30 | T.ToTensor(), 31 | T.Normalize([0.5], [0.5]), 32 | ] 33 | ) 34 | return transform(control_signal).unsqueeze(0) 35 | 36 | 37 | def get_scribble_map(input_image, det, detect_resolution=512, thickness=None): 38 | """ 39 | Generate scribble map from input image 40 | 41 | Args: 42 | input_image: Input image (numpy array, HWC format) 43 | det: Detector type ('Scribble_HED', 'Scribble_PIDI', 'None') 44 | detect_resolution: Processing resolution 45 | thickness: Line thickness (between 0-24, None for random) 46 | 47 | Returns: 48 | Processed scribble map 49 | """ 50 | global preprocessor 51 | 52 | # Initialize detector 53 | if "HED" in det and not isinstance(preprocessor, HEDdetector): 54 | preprocessor = HEDdetector() 55 | 56 | input_image = HWC3(input_image) 57 | 58 | if det == "None": 59 | detected_map = input_image.copy() 60 | else: 61 | # Generate scribble map 62 | detected_map = preprocessor(resize_image(input_image, detect_resolution)) 63 | detected_map = HWC3(detected_map) 64 | 65 | # Post-processing 66 | detected_map = nms(detected_map, 127, 3.0) 67 | detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) 68 | detected_map[detected_map > 4] = 255 69 | detected_map[detected_map < 255] = 0 70 | 71 | # Control line thickness 72 | if thickness is None: 73 | thickness = random.randint(0, 24) # Random thickness, including 0 74 | if thickness == 0: 75 | # Use erosion operation to get thinner lines 76 | kernel = np.ones((4, 4), np.uint8) 77 | detected_map = cv2.erode(detected_map, kernel, iterations=1) 78 | elif thickness > 1: 79 | kernel_size = thickness // 2 80 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 81 | detected_map = cv2.dilate(detected_map, kernel, iterations=1) 82 | 83 | return detected_map 84 | -------------------------------------------------------------------------------- /tools/convert_ImgDataset_to_WebDatasetMS_format.py: -------------------------------------------------------------------------------- 1 | # @Author: Pevernow (wzy3450354617@gmail.com) 2 | # @Date: 2025/1/5 3 | # @License: (Follow the main project) 4 | import json 5 | import os 6 | import tarfile 7 | 8 | from PIL import Image, PngImagePlugin 9 | 10 | PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024 # Increase maximum size for text chunks 11 | 12 | 13 | def process_data(input_dir, output_tar_name="output.tar"): 14 | """ 15 | Processes a directory containing PNG files, generates corresponding JSON files, 16 | and packages all files into a TAR file. It also counts the number of processed PNG images, 17 | and saves the height and width of each PNG file to the JSON. 18 | 19 | Args: 20 | input_dir (str): The input directory containing PNG files. 21 | output_tar_name (str): The name of the output TAR file (default is "output.tar"). 
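    Example (illustrative paths):

        process_data("./pngs_with_captions", "my_webdataset.tar")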
22 | """ 23 | png_count = 0 24 | json_files_created = [] 25 | 26 | for filename in os.listdir(input_dir): 27 | if filename.lower().endswith(".png"): 28 | png_count += 1 29 | base_name = filename[:-4] # Remove the ".png" extension 30 | txt_filename = os.path.join(input_dir, base_name + ".txt") 31 | json_filename = base_name + ".json" 32 | json_filepath = os.path.join(input_dir, json_filename) 33 | png_filepath = os.path.join(input_dir, filename) 34 | 35 | if os.path.exists(txt_filename): 36 | try: 37 | # Get the dimensions of the PNG image 38 | with Image.open(png_filepath) as img: 39 | width, height = img.size 40 | 41 | with open(txt_filename, encoding="utf-8") as f: 42 | caption_content = f.read().strip() 43 | 44 | data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height} 45 | 46 | with open(json_filepath, "w", encoding="utf-8") as outfile: 47 | json.dump(data, outfile, indent=4, ensure_ascii=False) 48 | 49 | print(f"Generated: {json_filename}") 50 | json_files_created.append(json_filepath) 51 | 52 | except Exception as e: 53 | print(f"Error processing file {filename}: {e}") 54 | else: 55 | print(f"Warning: No corresponding TXT file found for {filename}.") 56 | 57 | # Create a TAR file and include all files 58 | with tarfile.open(output_tar_name, "w") as tar: 59 | for item in os.listdir(input_dir): 60 | item_path = os.path.join(input_dir, item) 61 | tar.add(item_path, arcname=item) # arcname maintains the relative path of the file in the tar 62 | 63 | print(f"\nAll files have been packaged into: {output_tar_name}") 64 | print(f"Number of PNG images processed: {png_count}") 65 | 66 | 67 | if __name__ == "__main__": 68 | input_directory = input("Please enter the directory path containing PNG and TXT files: ") 69 | output_tar_filename = ( 70 | input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar" 71 | ) 72 | process_data(input_directory, output_tar_filename) 73 | -------------------------------------------------------------------------------- /tools/convert_py_to_yaml.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import yaml 4 | 5 | 6 | def convert_py_to_yaml(py_file_path): 7 | with open(py_file_path, encoding="utf-8") as py_file: 8 | py_content = py_file.read() 9 | 10 | local_vars = {} 11 | exec(py_content, {}, local_vars) 12 | 13 | yaml_file_path = os.path.splitext(py_file_path)[0] + ".yaml" 14 | 15 | with open(yaml_file_path, "w", encoding="utf-8") as yaml_file: 16 | yaml.dump(local_vars, yaml_file, default_flow_style=False, allow_unicode=True) 17 | 18 | 19 | def process_directory(path): 20 | for root, dirs, files in os.walk(path): 21 | for filename in files: 22 | if filename.endswith(".py"): 23 | py_file_path = os.path.join(root, filename) 24 | convert_py_to_yaml(py_file_path) 25 | print(f"convert {py_file_path} to YAML format") 26 | 27 | 28 | if __name__ == "__main__": 29 | process_directory("../configs/") 30 | -------------------------------------------------------------------------------- /tools/create_wids_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import json 18 | import os 19 | import sys 20 | import tarfile 21 | from glob import glob 22 | 23 | from tqdm.contrib.concurrent import process_map 24 | 25 | """ 26 | python tools/create_wids_metadata.py /path/to/tar/dir > /path/to/wids-meta.json 27 | """ 28 | 29 | d = sys.argv[1] 30 | 31 | 32 | def process(t): 33 | d = {} 34 | with tarfile.open(t, "r") as tar: 35 | for f in tar: 36 | n, e = os.path.splitext(f.name) 37 | if e == ".jpg" or e == ".jpeg" or e == ".png" or e == ".json" or e == ".npy": 38 | if n in d: 39 | d[n] = 1 40 | else: 41 | d[n] = 0 42 | s = os.path.getsize(t) 43 | i = sum(d.values()) 44 | t = os.path.basename(t) 45 | return {"url": t, "nsamples": i, "filesize": s} 46 | 47 | 48 | print( 49 | json.dumps( 50 | { 51 | "name": "sana-dev", 52 | "__kind__": "SANA-WebDataset", 53 | "wids_version": 1, 54 | "shardlist": sorted( 55 | process_map(process, glob(f"{d}/*.tar"), chunksize=1, max_workers=os.cpu_count()), 56 | key=lambda x: x["url"], 57 | ), 58 | }, 59 | indent=4, 60 | ), 61 | end="", 62 | ) 63 | -------------------------------------------------------------------------------- /tools/download.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | """ 17 | Functions for downloading pre-trained Sana models 18 | """ 19 | import argparse 20 | import os 21 | 22 | import torch 23 | from termcolor import colored 24 | from torchvision.datasets.utils import download_url 25 | 26 | from sana.tools import hf_download_or_fpath 27 | 28 | pretrained_models = {} 29 | 30 | 31 | def find_model(model_name): 32 | """ 33 | Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path. 34 | """ 35 | if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints 36 | return download_model(model_name) 37 | 38 | # Load a custom Sana checkpoint: 39 | model_name = hf_download_or_fpath(model_name) 40 | assert os.path.isfile(model_name), f"Could not find Sana checkpoint at {model_name}" 41 | print(colored(f"[Sana] Loading model from {model_name}", attrs=["bold"])) 42 | return torch.load(model_name, map_location=lambda storage, loc: storage) 43 | 44 | 45 | def download_model(model_name): 46 | """ 47 | Downloads a pre-trained Sana model from the web. 
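    Note: the download URL in the body below is an empty placeholder; point it at a
    real hosting endpoint (e.g. an HF `.../resolve/main/...` URL) before relying on
    this helper.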
48 | """ 49 | assert model_name in pretrained_models 50 | local_path = f"output/pretrained_models/{model_name}" 51 | if not os.path.isfile(local_path): 52 | hf_endpoint = os.environ.get("HF_ENDPOINT") 53 | if hf_endpoint is None: 54 | hf_endpoint = "https://huggingface.co" 55 | os.makedirs("output/pretrained_models", exist_ok=True) 56 | web_path = f"" 57 | download_url(web_path, "output/pretrained_models/") 58 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 59 | return model 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--model_names", nargs="+", type=str, default=pretrained_models) 65 | args = parser.parse_args() 66 | model_names = args.model_names 67 | model_names = set(model_names) 68 | 69 | # Download Sana checkpoints 70 | for model in model_names: 71 | download_model(model) 72 | print("Done.") 73 | -------------------------------------------------------------------------------- /tools/inference_scaling/nvila_sana_pick.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | set -e 3 | 4 | sana_dir=$1 5 | number_of_files=$2 6 | pick_number=$3 7 | # calculate number of GPU to use in this machine 8 | num_gpu=$(nvidia-smi -L | wc -l) 9 | echo "sana_dir: $sana_dir, number_of_files: $number_of_files, pick_number: $pick_number, num_gpu: $num_gpu" 10 | # start idx iterate from 0 * (552//8), 1 * (552//8), 2 * (552//8), 3 * (552//8), 4 * (552//8), 5 * (552//8), 6 * (552//8), 7 * (552//8) 11 | # end idx iterate from 1 * (552//8), 2 * (552//8), 3 * (552//8), 4 * (552//8), 5 * (552//8), 6 * (552//8), 7 * (552//8), 552 12 | for idx in $(seq 0 $((num_gpu - 1))); do 13 | start_idx=$((idx * (552 / num_gpu))) 14 | end_idx=$((start_idx + 552 / num_gpu)) 15 | if [ $idx -eq $((num_gpu - 1)) ]; then 16 | end_idx=552 17 | fi 18 | 19 | echo "CUDA_VISIBLE_DEVICES=$idx python tools/inference_scaling/nvila_sana_pick.py --start_idx $start_idx --end_idx $end_idx --base_dir $sana_dir --number_of_files $number_of_files --pick_number $pick_number &" 20 | CUDA_VISIBLE_DEVICES=$idx python tools/inference_scaling/nvila_sana_pick.py --start_idx $start_idx --end_idx $end_idx --base_dir $sana_dir --number_of_files $number_of_files --pick_number $pick_number & 21 | done 22 | wait 23 | -------------------------------------------------------------------------------- /tools/metrics/clip-score/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | -------------------------------------------------------------------------------- /tools/metrics/clip-score/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import setuptools 4 | 5 | 6 | def read(rel_path): 7 | base_path = os.path.abspath(os.path.dirname(__file__)) 8 | with open(os.path.join(base_path, rel_path)) as f: 9 | return f.read() 10 | 11 | 12 | def get_version(rel_path): 13 | for line in read(rel_path).splitlines(): 14 | if line.startswith("__version__"): 15 | delim = '"' if '"' in line else "'" 16 | return line.split(delim)[1] 17 | 18 | raise RuntimeError("Unable to find version string.") 19 | 20 | 21 | if __name__ == "__main__": 22 | setuptools.setup( 23 | name="clip-score", 24 | version=get_version(os.path.join("src", "clip_score", "__init__.py")), 25 | author="Taited", 26 | author_email="taited9160@gmail.com", 27 | description=("Package for calculating CLIP-Score" " using PyTorch"), 28 | long_description=read("README.md"), 29 | long_description_content_type="text/markdown", 30 | url="https://github.com/taited/clip-score", 31 | package_dir={"": "src"}, 32 | packages=setuptools.find_packages(where="src"), 33 | classifiers=[ 34 | "Programming Language :: Python :: 3", 35 | "License :: OSI Approved :: Apache Software License", 36 | ], 37 | python_requires=">=3.5", 38 | entry_points={ 39 | "console_scripts": [ 40 | "clip-score = clip_score.clip_score:main", 41 | ], 42 | }, 43 | install_requires=[ 44 | "numpy", 45 | "pillow", 46 | "torch>=1.7.1", 47 | "torchvision>=0.8.2", 48 | "ftfy", 49 | "regex", 50 | "tqdm", 51 | ], 52 | extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]}, 53 | ) 54 | -------------------------------------------------------------------------------- /tools/metrics/clip-score/src/clip_score/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.1" 2 | -------------------------------------------------------------------------------- /tools/metrics/clip-score/src/clip_score/__main__.py: -------------------------------------------------------------------------------- 1 | import clip_score.clip_score 2 | 3 | clip_score.clip_score.main() 4 | 
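For reference, once the clip-score package above is installed, scoring runs from the command line; a hedged usage sketch (both paths are placeholders for a directory of images and a matching directory of text prompts):

```bash
# Console entry point declared in setup.py above:
clip-score path/to/images path/to/prompts

# Equivalent module invocation, routed through __main__.py:
python -m clip_score path/to/images path/to/prompts
```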
-------------------------------------------------------------------------------- /tools/metrics/dpg_bench/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | addict 3 | 4 | # for modelscope 5 | cloudpickle 6 | datasets==2.21.0 7 | decord>=0.6.0 8 | diffusers 9 | ftfy>=6.0.3 10 | librosa==0.10.1 11 | modelscope[multi-modal] 12 | numpy 13 | opencv-python 14 | oss2 15 | pandas 16 | pillow 17 | # compatible with taming-transformers-rom1504 18 | rapidfuzz 19 | # rouge-score was recently updated from 0.0.4 to 0.0.7, 20 | # which introduced compatibility issues that are being investigated 21 | rouge_score<=0.0.4 22 | safetensors 23 | simplejson 24 | sortedcontainers 25 | # scikit-video 26 | soundfile 27 | taming-transformers-rom1504 28 | tiktoken 29 | timm 30 | tokenizers 31 | torchvision 32 | tqdm 33 | transformers 34 | transformers_stream_generator 35 | unicodedata2 36 | wandb 37 | zhconv 38 | # fairseq needs to be built from source: https://github.com/facebookresearch/fairseq -------------------------------------------------------------------------------- /tools/metrics/geneval/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dhruba Ghosh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tools/metrics/geneval/evaluation/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download Mask2Former object detection config and weights 4 | 5 | if [ !
-z "$1" ] 6 | then 7 | mkdir -p "$1" 8 | echo "Downloading mask2former for GenEval" 9 | wget https://download.openmmlab.com/mmdetection/v2.0/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco_20220504_001756-743b7d99.pth -O "$1/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.pth" 10 | fi 11 | -------------------------------------------------------------------------------- /tools/metrics/geneval/evaluation/object_names.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | potted plant 60 | bed 61 | dining table 62 | toilet 63 | tv 64 | laptop 65 | computer mouse 66 | tv remote 67 | computer keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /tools/metrics/geneval/evaluation/summary_scores.py: -------------------------------------------------------------------------------- 1 | # Get results of evaluation 2 | 3 | import argparse 4 | import os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("filename", type=str) 11 | args = parser.parse_args() 12 | 13 | # Load classnames 14 | 15 | with open(os.path.join(os.path.dirname(__file__), "object_names.txt")) as cls_file: 16 | classnames = [line.strip() for line in cls_file] 17 | cls_to_idx = {"_".join(cls.split()): idx for idx, cls in enumerate(classnames)} 18 | 19 | # Load results 20 | 21 | df = pd.read_json(args.filename, orient="records", lines=True) 22 | 23 | # Measure overall success 24 | 25 | print("Summary") 26 | print("=======") 27 | print(f"Total images: {len(df)}") 28 | print(f"Total prompts: {len(df.groupby('metadata'))}") 29 | print(f"% correct images: {df['correct'].mean():.2%}") 30 | print(f"% correct prompts: {df.groupby('metadata')['correct'].any().mean():.2%}") 31 | print() 32 | 33 | # By group 34 | 35 | task_scores = [] 36 | 37 | print("Task breakdown") 38 | print("==============") 39 | for tag, task_df in df.groupby("tag", sort=False): 40 | task_scores.append(task_df["correct"].mean()) 41 | print(f"{tag:<16} = {task_df['correct'].mean():.2%} ({task_df['correct'].sum()} / {len(task_df)})") 42 | print() 43 | 44 | print(f"Overall score (avg. over tasks): {np.mean(task_scores):.5f}") 45 | -------------------------------------------------------------------------------- /tools/metrics/geneval/geneval_env.md: -------------------------------------------------------------------------------- 1 | The installation process refers to https://github.com/djghosh13/geneval/issues/12 Thanks for the community! 
2 | 3 | # Cloning the repository 4 | 5 | ```bash 6 | git clone https://github.com/djghosh13/geneval.git 7 | 8 | cd geneval 9 | conda create -n geneval python=3.8.10 -y 10 | conda activate geneval 11 | ``` 12 | 13 | # Installing dependencies 14 | 15 | ```bash 16 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121 17 | pip install open-clip-torch==2.26.1 18 | pip install clip-benchmark 19 | pip install -U openmim 20 | pip install einops 21 | python -m pip install lightning 22 | pip install diffusers["torch"] transformers 23 | pip install tomli 24 | pip install platformdirs 25 | pip install --upgrade setuptools 26 | ``` 27 | 28 | # mmengine and mmcv dependency installation 29 | 30 | ```bash 31 | mim install mmengine mmcv-full==1.7.2 32 | ``` 33 | 34 | # mmdet installation 35 | 36 | ```bash 37 | git clone https://github.com/open-mmlab/mmdetection.git 38 | cd mmdetection; git checkout 2.x 39 | pip install -v -e . 40 | ``` 41 | 42 | 43 | 44 | 66 | -------------------------------------------------------------------------------- /tools/metrics/geneval/images/geneval_figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Sana/70459f414474c10c509e8b58f3f9442738f85577/tools/metrics/geneval/images/geneval_figure_1.png -------------------------------------------------------------------------------- /tools/metrics/geneval/prompts/object_names.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | potted plant 60 | bed 61 | dining table 62 | toilet 63 | tv 64 | laptop 65 | computer mouse 66 | tv remote 67 | computer keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## \[0.3.0\] - 2023-01-05 4 | 5 | ### Added 6 | 7 | - Add argument `--save-stats` allowing to compute dataset statistics and save them as an `.npz` file ([#80](https://github.com/mseitzer/pytorch-fid/pull/80)). The `.npz` file can be used in subsequent FID computations instead of recomputing the dataset statistics. This option can be used in the following way: `python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile`. 8 | 9 | ### Fixed 10 | 11 | - Do not use `os.sched_getaffinity` to get number of available CPUs on Windows, as it is not available there ([232b3b14](https://github.com/mseitzer/pytorch-fid/commit/232b3b1468800102fcceaf6f2bb8977811fc991a), [#84](https://github.com/mseitzer/pytorch-fid/issues/84)). 12 | - Do not use Inception model argument `pretrained`, as it was deprecated in torchvision 0.13 ([#88](https://github.com/mseitzer/pytorch-fid/pull/88)). 13 | 14 | ## \[0.2.1\] - 2021-10-10 15 | 16 | ### Added 17 | 18 | - Add argument `--num-workers` to select number of dataloader processes ([#66](https://github.com/mseitzer/pytorch-fid/pull/66)). Defaults to 8 or the number of available CPUs if less than 8 CPUs are available. 19 | 20 | ### Fixed 21 | 22 | - Fixed package setup to work under Windows ([#55](https://github.com/mseitzer/pytorch-fid/pull/55), [#72](https://github.com/mseitzer/pytorch-fid/issues/72)) 23 | 24 | ## \[0.2.0\] - 2020-11-30 25 | 26 | ### Added 27 | 28 | - Load images using a Pytorch dataloader, which should result in a speed-up. ([#47](https://github.com/mseitzer/pytorch-fid/pull/47)) 29 | - Support more image extensions ([#53](https://github.com/mseitzer/pytorch-fid/pull/53)) 30 | - Improve tooling by setting up Nox, add linting and test support ([#52](https://github.com/mseitzer/pytorch-fid/pull/52)) 31 | - Add some unit tests 32 | 33 | ## \[0.1.1\] - 2020-08-16 34 | 35 | ### Fixed 36 | 37 | - Fixed software license string in `setup.py` 38 | 39 | ## \[0.1.0\] - 2020-08-16 40 | 41 | Initial release as a pypi package. 
Use `pip install pytorch-fid` to install. 42 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/noxfile.py: -------------------------------------------------------------------------------- 1 | import nox 2 | 3 | LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py") 4 | 5 | 6 | @nox.session 7 | def lint(session): 8 | session.install("flake8") 9 | session.install("flake8-bugbear") 10 | session.install("flake8-isort") 11 | 12 | args = session.posargs or LOCATIONS 13 | session.run("flake8", *args) 14 | 15 | 16 | @nox.session 17 | def tests(session): 18 | session.install(".") 19 | session.install("pytest") 20 | session.install("pytest-mock") 21 | session.run("pytest", *session.posargs) 22 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select=F,W,E,I,B,B9 3 | ignore=W503,B950 4 | max-line-length=79 5 | 6 | [isort] 7 | multi_line_output=1 8 | line_length=79 9 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import setuptools 4 | 5 | 6 | def read(rel_path): 7 | base_path = os.path.abspath(os.path.dirname(__file__)) 8 | with open(os.path.join(base_path, rel_path)) as f: 9 | return f.read() 10 | 11 | 12 | def get_version(rel_path): 13 | for line in read(rel_path).splitlines(): 14 | if line.startswith("__version__"): 15 | # __version__ = "0.9" 16 | delim = '"' if '"' in line else "'" 17 | return line.split(delim)[1] 18 | 19 | raise RuntimeError("Unable to find version string.") 20 | 21 | 22 | if __name__ == "__main__": 23 | setuptools.setup( 24 | name="pytorch-fid", 25 | version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")), 26 | author="Max Seitzer", 27 | description=("Package for calculating Frechet Inception Distance (FID)" " using PyTorch"), 28 | long_description=read("README.md"), 29 | long_description_content_type="text/markdown", 30 | url="https://github.com/mseitzer/pytorch-fid", 31 | package_dir={"": "src"}, 32 | packages=setuptools.find_packages(where="src"), 33 | classifiers=[ 34 | "Programming Language :: Python :: 3", 35 | "License :: OSI Approved :: Apache Software License", 36 | ], 37 | python_requires=">=3.5", 38 | entry_points={ 39 | "console_scripts": [ 40 | "pytorch-fid = pytorch_fid.fid_score:main", 41 | ], 42 | }, 43 | install_requires=["numpy", "pillow", "scipy", "torch>=1.0.1", "torchvision>=0.2.2"], 44 | extras_require={"dev": ["flake8", "flake8-bugbear", "flake8-isort", "nox"]}, 45 | ) 46 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/src/pytorch_fid/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.0" 2 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/src/pytorch_fid/__main__.py: -------------------------------------------------------------------------------- 1 | import pytorch_fid.fid_score 2 | 3 | pytorch_fid.fid_score.main() 4 | -------------------------------------------------------------------------------- /tools/metrics/pytorch-fid/tests/test_fid_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import pytest 3 | import torch 4 | from PIL import Image 5 | from pytorch_fid import fid_score, inception 6 | 7 | 8 | @pytest.fixture 9 | def device(): 10 | return torch.device("cpu") 11 | 12 | 13 | def test_calculate_fid_given_statistics(mocker, tmp_path, device): 14 | dim = 2048 15 | m1, m2 = np.zeros((dim,)), np.ones((dim,)) 16 | sigma = np.eye(dim) 17 | 18 | def dummy_statistics(path, model, batch_size, dims, device, num_workers): 19 | if path.endswith("1"): 20 | return m1, sigma 21 | elif path.endswith("2"): 22 | return m2, sigma 23 | else: 24 | raise ValueError 25 | 26 | mocker.patch("pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics) 27 | 28 | dir_names = ["1", "2"] 29 | paths = [] 30 | for name in dir_names: 31 | path = tmp_path / name 32 | path.mkdir() 33 | paths.append(str(path)) 34 | 35 | fid_value = fid_score.calculate_fid_given_paths(paths, batch_size=dim, device=device, dims=dim, num_workers=0) 36 | 37 | # Given equal covariance, FID is just the squared norm of difference 38 | assert fid_value == np.sum((m1 - m2) ** 2) 39 | 40 | 41 | def test_compute_statistics_of_path(mocker, tmp_path, device): 42 | model = mocker.MagicMock(inception.InceptionV3)() 43 | model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)] 44 | 45 | size = (4, 4, 3) 46 | arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)] 47 | images = [(arr * 255).astype(np.uint8) for arr in arrays] 48 | 49 | paths = [] 50 | for idx, image in enumerate(images): 51 | paths.append(str(tmp_path / f"{idx}.png")) 52 | Image.fromarray(image, mode="RGB").save(paths[-1]) 53 | 54 | stats = fid_score.compute_statistics_of_path( 55 | str(tmp_path), model, batch_size=len(images), dims=3, device=device, num_workers=0 56 | ) 57 | 58 | assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3) 59 | assert np.allclose(stats[1], np.ones((3, 3)) * 0.25) 60 | 61 | 62 | def test_compute_statistics_of_path_from_file(mocker, tmp_path, device): 63 | model = mocker.MagicMock(inception.InceptionV3)() 64 | 65 | mu = np.random.randn(5) 66 | sigma = np.random.randn(5, 5) 67 | 68 | path = tmp_path / "stats.npz" 69 | with path.open("wb") as f: 70 | np.savez(f, mu=mu, sigma=sigma) 71 | 72 | stats = fid_score.compute_statistics_of_path(str(path), model, batch_size=1, dims=5, device=device, num_workers=0) 73 | 74 | assert np.allclose(stats[0], mu) 75 | assert np.allclose(stats[1], sigma) 76 | 77 | 78 | def test_image_types(tmp_path): 79 | in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255 80 | in_image = Image.fromarray(in_arr, mode="RGB") 81 | 82 | paths = [] 83 | for ext in fid_score.IMAGE_EXTENSIONS: 84 | paths.append(str(tmp_path / f"img.{ext}")) 85 | in_image.save(paths[-1]) 86 | 87 | dataset = fid_score.ImagePathDataset(paths) 88 | 89 | for img in dataset: 90 | assert np.allclose(np.array(img), in_arr) 91 | -------------------------------------------------------------------------------- /tools/metrics/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def tracker(args, result_dict, label="", pattern="epoch_step", metric="FID"): 5 | if args.report_to == "wandb": 6 | import wandb 7 | 8 | wandb_name = f"[{args.log_metric}]_{args.name}" 9 | wandb.init(project=args.tracker_project_name, name=wandb_name, resume="allow", id=wandb_name, tags="metrics") 10 | run = wandb.run 11 | if pattern == "step": 12 | pattern = "sample_steps" 13 | elif pattern == "epoch_step": 14 | pattern = "step" 15 | custom_name = f"custom_{pattern}" 16 | 
run.define_metric(custom_name) 17 | # define which metrics will be plotted against it 18 | run.define_metric(f"{metric}_{label}", step_metric=custom_name) 19 | 20 | steps = [] 21 | results = [] 22 | 23 | def extract_value(regex, exp_name): 24 | match = re.search(regex, exp_name) 25 | if match: 26 | return match.group(1) 27 | else: 28 | return "unknown" 29 | 30 | for exp_name, result_value in result_dict.items(): 31 | if pattern == "step": 32 | regex = r".*step(\d+)_scale.*" 33 | custom_x = extract_value(regex, exp_name) 34 | elif pattern == "sample_steps": 35 | regex = r".*step(\d+)_size.*" 36 | custom_x = extract_value(regex, exp_name) 37 | else: 38 | regex = rf"{pattern}(\d+(\.\d+)?)" 39 | custom_x = extract_value(regex, exp_name) 40 | custom_x = 1 if custom_x == "unknown" else custom_x 41 | 42 | assert custom_x != "unknown" 43 | steps.append(float(custom_x)) 44 | results.append(result_value) 45 | 46 | sorted_data = sorted(zip(steps, results)) 47 | steps, results = zip(*sorted_data) 48 | 49 | for step, result in sorted(zip(steps, results)): 50 | run.log({f"{metric}_{label}": result, custom_name: step}) 51 | else: 52 | print(f"{args.report_to} is not supported") 53 | -------------------------------------------------------------------------------- /train_scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | work_dir=output/debug 5 | np=2 6 | 7 | 8 | if [[ $1 == *.yaml ]]; then 9 | config=$1 10 | shift 11 | else 12 | config="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml" 13 | # config="configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml" # FSDP config file 14 | echo "Only .yaml config files are supported, but got $1. Set to --config_path=$config" 15 | fi 16 | 17 | TRITON_PRINT_AUTOTUNING=1 \ 18 | torchrun --nproc_per_node=$np --master_port=15432 \ 19 | train_scripts/train.py \ 20 | --config_path=$config \ 21 | --work_dir=$work_dir \ 22 | --name=tmp \ 23 | --resume_from=latest \ 24 | --report_to=tensorboard \ 25 | --debug=true \ 26 | "$@" 27 | -------------------------------------------------------------------------------- /train_scripts/train_lora.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers" 4 | export INSTANCE_DIR="data/dreambooth/dog" 5 | export OUTPUT_DIR="trained-sana-lora" 6 | 7 | accelerate launch --num_processes 4 --main_process_port 29500 --gpu_ids 0,1,2,3 \ 8 | train_scripts/train_dreambooth_lora_sana.py \ 9 | --pretrained_model_name_or_path=$MODEL_NAME \ 10 | --instance_data_dir=$INSTANCE_DIR \ 11 | --output_dir=$OUTPUT_DIR \ 12 | --mixed_precision="bf16" \ 13 | --instance_prompt="a photo of sks dog" \ 14 | --resolution=1024 \ 15 | --train_batch_size=1 \ 16 | --gradient_accumulation_steps=4 \ 17 | --use_8bit_adam \ 18 | --learning_rate=1e-4 \ 19 | --report_to="wandb" \ 20 | --lr_scheduler="constant" \ 21 | --lr_warmup_steps=0 \ 22 | --max_train_steps=500 \ 23 | --validation_prompt="A photo of sks dog in a pond, yarn art style" \ 24 | --validation_epochs=25 \ 25 | --seed="0" \ 26 | --push_to_hub 27 | -------------------------------------------------------------------------------- /train_scripts/train_scm_ladd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | work_dir=output/debug_sCM_ladd 5 | np=2 6 | 7 | 8 | if [[ $1 == *.yaml ]]; then 9 | config=$1 10 | shift 11 | else 12 | config="configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml" 13 | echo "Only .yaml config files are supported, but got $1. Set to --config_path=$config" 14 | fi 15 | 16 | cmd="TRITON_PRINT_AUTOTUNING=1 \ 17 | torchrun --nproc_per_node=$np --master_port=$((RANDOM % 10000 + 20000)) \ 18 | train_scripts/train_scm_ladd.py \ 19 | --config_path=$config \ 20 | --work_dir=$work_dir \ 21 | --name=tmp \ 22 | --resume_from=latest \ 23 | --report_to=tensorboard \ 24 | --debug=true \ 25 | $@" 26 | 27 | echo "$cmd" 28 | eval $cmd 29 | --------------------------------------------------------------------------------
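For reference, the two launchers above share the same calling convention; a usage sketch taken from the CI scripts earlier in this repo (the config paths are the scripts' defaults, and extra flags are forwarded to the underlying Python entry points):

```bash
# Standard training (torchrun with 2 processes by default):
bash train_scripts/train.sh configs/sana_config/512ms/ci_Sana_600M_img512.yaml --data.load_vae_feat=true

# SANA-Sprint (sCM + LADD) distillation:
bash train_scripts/train_scm_ladd.sh configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml --data.load_vae_feat=true
```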