├── .github ├── dependabot.yml └── workflows │ ├── lint.yml │ ├── test_models.yml │ ├── test_sdxl.yml │ └── test_shark.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── ci-tools └── latest-pkgci.py ├── models ├── README.md ├── pyproject.toml ├── requirements.txt ├── setup.py └── turbine_models │ ├── custom_models │ ├── README.md │ ├── __init__.py │ ├── llama_argmax_td_spec.mlir │ ├── llama_benchmark │ │ ├── README.md │ │ ├── benchmark.mlir │ │ ├── benchmark_forward.mlir │ │ ├── benchmark_module.py │ │ └── stateless_llama_benchmark.py │ ├── llm_optimizations │ │ ├── __init__.py │ │ └── streaming_llm │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ └── modify_llama.py │ ├── llm_runner.py │ ├── pipeline_base.py │ ├── remap_gguf.py │ ├── resnet_18.py │ ├── sd3_inference │ │ ├── sd3_cmd_opts.py │ │ ├── sd3_full.py │ │ ├── sd3_mmdit.py │ │ ├── sd3_mmdit_runner.py │ │ ├── sd3_pipeline.py │ │ ├── sd3_schedulers.py │ │ ├── sd3_text_encoders.py │ │ ├── sd3_text_encoders_runner.py │ │ ├── sd3_vae.py │ │ ├── sd3_vae_runner.py │ │ └── text_encoder_impls.py │ ├── sd_inference │ │ ├── clip.py │ │ ├── clip_runner.py │ │ ├── default_mfma_attn_spec.mlir │ │ ├── schedulers.py │ │ ├── schedulers_runner.py │ │ ├── sd_cmd_opts.py │ │ ├── sd_pipeline.py │ │ ├── tokenization.py │ │ ├── unet.py │ │ ├── unet_runner.py │ │ ├── utils.py │ │ ├── vae.py │ │ └── vae_runner.py │ ├── sdxl_inference │ │ ├── COMMANDS.md │ │ ├── README.md │ │ ├── clip.py │ │ ├── clip_runner.py │ │ ├── pipeline_ir.py │ │ ├── sdxl_benchmark.py │ │ ├── sdxl_cmd_opts.py │ │ ├── sdxl_compiled_pipeline.py │ │ ├── sdxl_prompt_encoder.py │ │ ├── sdxl_prompt_encoder_runner.py │ │ ├── sdxl_scheduled_unet.py │ │ ├── sdxl_scheduled_unet_runner.py │ │ ├── unet.py │ │ ├── unet_runner.py │ │ ├── vae.py │ │ └── vae_runner.py │ └── stateless_llama.py │ ├── gen_external_params │ ├── __init__.py │ └── gen_external_params.py │ ├── model_builder.py │ ├── model_runner.py │ ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── gen_external_params_test.py │ ├── pipeline_test.py │ ├── resnet_test.py │ ├── sd3_test.py │ ├── sd_test.py │ ├── sdxl_test.py │ ├── stateless_llama_test.py │ └── vmfb_comparison_cached_torch_output_f32_unquantized.txt │ ├── turbine_tank │ ├── __init__.py │ └── turbine_tank.py │ └── utils │ ├── benchmark.py │ └── sdxl_benchmark.py ├── mypy-requirements.txt └── version_info.json /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 4 | # Check for updates to GitHub Actions 5 | - package-ecosystem: "github-actions" 6 | directory: "/" 7 | schedule: 8 | interval: "weekly" 9 | groups: 10 | github-actions: 11 | patterns: 12 | - "*" 13 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | black: 12 | name: Python Formatting With Black 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checking out repository 16 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 17 | - name: Setting up python 18 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 19 | - name: Fetching Base Branch 20 | # We have to explicitly fetch the base branch as well 21 | run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}" 22 | - name: 
Install black 23 | run: | 24 | python3 -m pip install black 25 | - name: Check if modified files are formatted 26 | run: | 27 | # The filter lowercase `d` means to exclude deleted files. 28 | git diff "${GITHUB_BASE_REF?}" --name-only --diff-filter=d \ 29 | -- '*.py' \ 30 | | xargs --no-run-if-empty black --check --diff --verbose 31 | - name: Instructions for fixing the above linting errors 32 | if: failure() 33 | run: | 34 | printf "You can fix formatting by running 'black' on the modified python files:\n" 35 | printf " git diff ${GITHUB_BASE_REF?} --name-only -- '*.py' ':!third_party' | xargs black\n" 36 | -------------------------------------------------------------------------------- /.github/workflows/test_models.yml: -------------------------------------------------------------------------------- 1 | name: Test Turbine Models 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | 10 | concurrency: 11 | # A PR number if a pull request and otherwise the commit hash. This cancels 12 | # queued and in-progress runs for the same PR (presubmit) or commit 13 | # (postsubmit). The workflow name is prepended to avoid conflicts between 14 | # different workflows. 15 | group: ${{ github.workflow }}-${{ github.event.number || github.sha }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | test-turbine-models: 20 | strategy: 21 | matrix: 22 | version: [3.11] 23 | os: [nodai-amdgpu-mi250-x86-64] 24 | 25 | runs-on: ${{matrix.os}} 26 | env: 27 | E2E_VENV_DIR: ${{ github.workspace }}/test-suite_venv 28 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 29 | AZ_PRIVATE_CONNECTION: ${{ secrets.ONNXPRIVATESTORAGE_AZ_PRIVATE_CONNECTION }} 30 | ROCR_VISIBLE_DEVICES: ${{ matrix.visible-device }} 31 | TORCH_HOME: /groups/aig_sharks/test-suite-ci-cache 32 | HF_HOME: /groups/aig_sharks/test-suite-ci-cache 33 | TURBINE_TANK_CACHE_DIR: /groups/aig_sharks/test-suite-ci-cache 34 | steps: 35 | # We are using a persistent Gentoo runner here, and this python action is not supported for the arch 36 | # - name: "Setting up Python" 37 | # uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3 38 | # with: 39 | # python-version: ${{matrix.version}} 40 | 41 | - name: "Checkout This Repo" 42 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 43 | 44 | - name: "Checkout iree-turbine" 45 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 46 | with: 47 | repository: iree-org/iree-turbine 48 | # TODO: Let the ref be passed as a parameter to run integration tests. 49 | path: iree-turbine 50 | 51 | - name: Sync source deps 52 | # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile 53 | run: | 54 | python3.11 -m venv turbine_venv 55 | source turbine_venv/bin/activate 56 | python3.11 -m pip install --upgrade pip 57 | # Note: We install in three steps in order to satisfy requirements 58 | # from non default locations first. Installing the PyTorch CPU 59 | # wheels saves multiple minutes and a lot of bandwidth on runner setup. 
60 | pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt 61 | pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt 62 | pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing] 63 | pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html 64 | pip install --no-compile --pre --upgrade -e models -r models/requirements.txt 65 | 66 | - name: Show current free memory 67 | run: | 68 | free -mh 69 | 70 | - name: Run stateless_llama tests 71 | run: | 72 | source turbine_venv/bin/activate 73 | pytest -v models/turbine_models/tests/stateless_llama_test.py 74 | 75 | - name: Run sd tests 76 | run: | 77 | source turbine_venv/bin/activate 78 | 79 | pytest -v models/turbine_models/tests/sd_test.py 80 | pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 81 | pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux 82 | pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default 83 | pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2 84 | pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 85 | -------------------------------------------------------------------------------- /.github/workflows/test_sdxl.yml: -------------------------------------------------------------------------------- 1 | name: SDXL E2E Pipeline CI 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | schedule: 7 | - cron: "*/50 * * * *" 8 | 9 | concurrency: 10 | # A PR number if a pull request and otherwise the commit hash. This cancels 11 | # queued and in-progress runs for the same PR (presubmit) or commit 12 | # (postsubmit). The workflow name is prepended to avoid conflicts between 13 | # different workflows. 
14 | group: ${{ github.workflow }}-${{ github.event.number || github.sha }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | test-sdxl-models: 19 | strategy: 20 | matrix: 21 | version: [3.11] 22 | os: [nodai-amdgpu-mi300-x86-64] 23 | 24 | runs-on: ${{matrix.os}} 25 | env: 26 | IREE_TOKEN: ${{ secrets.IREE_TOKEN }} 27 | steps: 28 | - name: "Setting up Python" 29 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 30 | with: 31 | python-version: ${{matrix.version}} 32 | 33 | - name: "Checkout SHARK-ModelDev" 34 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 35 | with: 36 | ref: bump-punet-tom 37 | 38 | - name: "Checkout iree-turbine" 39 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 40 | with: 41 | repository: iree-org/iree-turbine 42 | path: iree-turbine 43 | 44 | - name: "Checkout iree" 45 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 46 | with: 47 | repository: iree-org/iree 48 | path: iree 49 | 50 | - name: Python deps 51 | run: | 52 | python3.11 -m venv sdxl_venv 53 | source sdxl_venv/bin/activate 54 | python -m pip install --upgrade pip 55 | pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt 56 | pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt 57 | pip install --no-compile --pre --upgrade -e models -r models/requirements.txt 58 | pip uninstall torch torchvision torchaudio -y 59 | pip install https://download.pytorch.org/whl/nightly/pytorch_triton_rocm-3.0.0%2B21eae954ef-cp311-cp311-linux_x86_64.whl 60 | pip install https://download.pytorch.org/whl/nightly/rocm6.1/torch-2.5.0.dev20240710%2Brocm6.1-cp311-cp311-linux_x86_64.whl 61 | pip install https://download.pytorch.org/whl/nightly/rocm6.1/torchvision-0.20.0.dev20240711%2Brocm6.1-cp311-cp311-linux_x86_64.whl 62 | pip install https://download.pytorch.org/whl/nightly/rocm6.1/torchaudio-2.4.0.dev20240711%2Brocm6.1-cp311-cp311-linux_x86_64.whl 63 | pip uninstall iree-compiler iree-runtime iree-base-compiler iree-base-runtime -y 64 | python ci-tools/latest-pkgci.py 65 | cd wheels 66 | unzip *.zip 67 | pip install *.whl 68 | cd .. 69 | rm -rf wheels 70 | 71 | - name: Show current free memory 72 | run: | 73 | free -mh 74 | 75 | - name: Run sdxl tests 76 | run: | 77 | source sdxl_venv/bin/activate 78 | python3 models/turbine_models/custom_models/sd_inference/sd_pipeline.py --device=hip --precision=fp16 --iree_target_triple=gfx942 --external_weights=safetensors --hf_model_name=stabilityai/stable-diffusion-xl-base-1.0 --width=1024 --height=1024 --batch_size=1 --use_i8_punet --attn_spec=punet --vae_decomp_attn --external_weights=safetensors --num_inference_steps=20 --benchmark=all --verbose 79 | -------------------------------------------------------------------------------- /.github/workflows/test_shark.yml: -------------------------------------------------------------------------------- 1 | name: Test SHARK 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | 10 | concurrency: 11 | # A PR number if a pull request and otherwise the commit hash. This cancels 12 | # queued and in-progress runs for the same PR (presubmit) or commit 13 | # (postsubmit). The workflow name is prepended to avoid conflicts between 14 | # different workflows. 
15 | group: ${{ github.workflow }}-${{ github.event.number || github.sha }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | test-shark: 20 | strategy: 21 | matrix: 22 | version: [3.11] 23 | os: [nodai-amdgpu-mi250-x86-64] 24 | 25 | runs-on: ${{matrix.os}} 26 | steps: 27 | - name: "Setting up Python" 28 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 29 | with: 30 | python-version: ${{matrix.version}} 31 | 32 | - name: "Checkout SHARK" 33 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 34 | with: 35 | repository: "nod-ai/SHARK.git" 36 | path: SHARK 37 | ref: "main" 38 | 39 | - name: "Checkout iree-turbine" 40 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 41 | with: 42 | repository: iree-org/iree-turbine 43 | # TODO: Let the ref be passed as a parameter to run integration tests. 44 | path: iree-turbine 45 | 46 | # TODO: Replace with a sh script from shark repo 47 | - name: "Install SHARK" 48 | run: | 49 | cd $GITHUB_WORKSPACE/SHARK 50 | python${{ matrix.version }} -m venv shark.venv 51 | source shark.venv/bin/activate 52 | pip install -r requirements.txt --no-cache-dir 53 | pip install -e . 54 | python apps/shark_studio/tests/api_test.py 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Visual Studio files 2 | .env 3 | .vs/ 4 | .vscode/ 5 | *.sdf 6 | *.opensdf 7 | *.VC.opendb 8 | *.suo 9 | *.user 10 | 11 | # macOS files 12 | .DS_Store 13 | 14 | # CMake artifacts 15 | build/ 16 | build-*/ 17 | 18 | # Python 19 | __pycache__ 20 | _python_build/ 21 | dist/ 22 | wheelhouse 23 | *.egg-info 24 | *.whl 25 | 26 | #Model artifacts 27 | *.pt 28 | *.safetensors 29 | *.gguf 30 | *.vmfb 31 | *.mlir 32 | *.npy 33 | *.png 34 | *tmp* 35 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | include pytorch-cpu-requirements.txt 4 | include version_info.json 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SHARK Turbine 2 | 3 | This repo is Nod-AI's integration repository for various model bringup 4 | activities and CI. In 2023 and early 2024, it played a different role 5 | by being the place where FX/Dynamo based torch-mlir and IREE toolsets 6 | were developed, including: 7 | 8 | * [Torch-MLIR FxImporter](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py) 9 | * [Torch-MLIR ONNX Importer](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/onnx_importer.py) 10 | * [Torch-MLIR's ONNX C Importer](https://github.com/llvm/torch-mlir/tree/main/projects/onnx_c_importer) 11 | * [IREE Turbine](https://github.com/iree-org/iree-turbine) 12 | * [Sharktank and Shortfin](https://github.com/nod-ai/sharktank) 13 | 14 | As these have all found upstream homes, this repo is a bit bare. We will 15 | continue to use it as a staging ground for things that don't have a 16 | more defined spot and as a way to drive certain kinds of upstreaming 17 | activities. 
18 | 19 | 20 | ## Current Projects 21 | 22 | ### turbine-models 23 | 24 | The `turbine-models` project (under models/) contains ports and adaptations 25 | of various (mostly HF) models that we use in various ways. 26 | 27 | ### CI 28 | 29 | Integration CI for a variety of projects is rooted in this repo. 30 | 31 | -------------------------------------------------------------------------------- /ci-tools/latest-pkgci.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | 5 | GITHUB_TOKEN = os.getenv("IREE_TOKEN") 6 | 7 | OWNER = "iree-org" 8 | REPO = "iree" 9 | 10 | API_URL = ( 11 | f"https://api.github.com/repos/{OWNER}/{REPO}/actions/workflows/pkgci.yml/runs" 12 | ) 13 | 14 | 15 | # Get the latest workflow run ID for pkgci.yml 16 | def get_latest_pkgci_workflow_run(): 17 | headers = { 18 | "Accept": "application/vnd.github+json", 19 | "Authorization": f"Bearer {GITHUB_TOKEN}", 20 | "X-GitHub-Api-Version": "2022-11-28", 21 | } 22 | params = { 23 | "per_page": 1, 24 | "event": "push", 25 | "branch": "main", 26 | } 27 | response = requests.get(API_URL, headers=headers, params=params) 28 | 29 | if response.status_code == 200: 30 | data = response.json() 31 | if data["total_count"] > 0: 32 | latest_run = data["workflow_runs"][0] 33 | return latest_run["id"], latest_run["artifacts_url"] 34 | else: 35 | print("No workflow runs found for pkgci.yml.") 36 | return None 37 | else: 38 | print(f"Error fetching workflow runs: {response.status_code}") 39 | return None 40 | 41 | 42 | # Get the artifacts of a specific workflow run 43 | def get_artifacts(workflow_run_id, artifacts_url): 44 | headers = { 45 | "Accept": "application/vnd.github+json", 46 | "Authorization": f"Bearer {GITHUB_TOKEN}", 47 | "X-GitHub-Api-Version": "2022-11-28", 48 | } 49 | response = requests.get(artifacts_url, headers=headers) 50 | 51 | if response.status_code == 200: 52 | artifacts = response.json()["artifacts"] 53 | if artifacts: 54 | print(f"Artifacts for pkgci.yml workflow run {workflow_run_id}:") 55 | for artifact in artifacts: 56 | print(f"- {artifact['name']} (Size: {artifact['size_in_bytes']} bytes)") 57 | download_artifact(artifact["archive_download_url"], artifact["name"]) 58 | else: 59 | print("No artifacts found for the pkgci.yml workflow run.") 60 | else: 61 | print(f"Error fetching artifacts: {response.status_code}") 62 | 63 | 64 | # Download an artifact 65 | def download_artifact(download_url, artifact_name): 66 | headers = { 67 | "Accept": "application/vnd.github+json", 68 | "Authorization": f"Bearer {GITHUB_TOKEN}", 69 | "X-GitHub-Api-Version": "2022-11-28", 70 | } 71 | response = requests.get(download_url, headers=headers, stream=True) 72 | 73 | if response.status_code == 200: 74 | file_name = f"wheels/{artifact_name}.zip" 75 | os.mkdir("wheels") 76 | with open(file_name, "wb") as f: 77 | for chunk in response.iter_content(chunk_size=8192): 78 | if chunk: 79 | f.write(chunk) 80 | print(f"Artifact '{artifact_name}' downloaded successfully as '{file_name}'.") 81 | else: 82 | print(f"Error downloading artifact '{artifact_name}': {response.status_code}") 83 | 84 | 85 | if __name__ == "__main__": 86 | workflow_run_id, artifact_url = get_latest_pkgci_workflow_run() 87 | if workflow_run_id: 88 | get_artifacts(workflow_run_id, artifact_url) 89 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # LLAMA 2 
Inference 2 | 3 | This example require some extra dependencies. Here's an easy way to get it running on a fresh server. 4 | 5 | Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens 6 | 7 | ```bash 8 | #!/bin/bash 9 | 10 | 11 | # if you don't insert it, you will be prompted to log in later; 12 | # you may need to rerun this script after logging in 13 | YOUR_HF_TOKEN="insert token for headless" 14 | 15 | # clone and install dependencies 16 | sudo apt install -y git 17 | git clone https://github.com/nod-ai/SHARK-Turbine.git 18 | cd SHARK-Turbine 19 | pip install -r core/requirements.txt 20 | pip install -r models/requirements.txt 21 | 22 | # do an editable install from the cloned SHARK-Turbine 23 | pip install --editable core models 24 | 25 | # Log in with Hugging Face CLI if token setup is required 26 | if [[ $YOUR_HF_TOKEN == hf_* ]]; then 27 | huggingface login --token $YOUR_HF_TOKEN 28 | echo "Logged in with YOUR_HF_TOKEN." 29 | elif [ -f ~/.cache/huggingface/token ]; then 30 | # Read token from the file 31 | TOKEN_CONTENT=$(cat ~/.cache/huggingface/token) 32 | 33 | # Check if the token starts with "hf_" 34 | if [[ $TOKEN_CONTENT == hf_* ]]; then 35 | echo "Already logged in with a Hugging Face token." 36 | else 37 | echo "Token in file does not start with 'hf_'. Please log into huggingface to download models." 38 | huggingface-cli login 39 | fi 40 | else 41 | echo "Please log into huggingface to download models." 42 | huggingface-cli login 43 | fi 44 | 45 | # Step 7: Run the Python script 46 | python .\python\turbine_models\custom_models\stateless_llama.py --compile_to=torch --external_weights=safetensors --external_weight_file=llama_f32.safetensors 47 | ``` 48 | -------------------------------------------------------------------------------- /models/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /models/requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf 2 | gguf 3 | transformers==4.50.0 4 | torchsde 5 | accelerate 6 | peft 7 | diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark 8 | brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b 9 | # turbine tank downloading/uploading 10 | azure-storage-blob 11 | # microsoft/phi model 12 | einops 13 | pytest 14 | scipy 15 | iree-turbine @ git+https://github.com/iree-org/iree-turbine.git@main 16 | -e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank 17 | -------------------------------------------------------------------------------- /models/setup.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | from setuptools import find_namespace_packages, setup 6 | 7 | 8 | #### TURBINE MODELS SETUP #### 9 | 10 | 11 | TURBINE_MODELS_DIR = os.path.realpath(os.path.dirname(__file__)) 12 | TURBINE_ROOT_DIR = Path(TURBINE_MODELS_DIR).parent 13 | print(TURBINE_ROOT_DIR) 14 | VERSION_INFO_FILE = os.path.join(TURBINE_ROOT_DIR, "version_info.json") 15 | 16 | 17 | with open( 18 | os.path.join( 19 | TURBINE_MODELS_DIR, 20 | "README.md", 21 | ), 22 | "rt", 23 | ) as f: 24 | README = f.read() 25 | 26 | 27 | def load_version_info(): 28 | with 
open(VERSION_INFO_FILE, "rt") as f: 29 | return json.load(f) 30 | 31 | 32 | version_info = load_version_info() 33 | PACKAGE_VERSION = version_info["package-version"] 34 | 35 | setup( 36 | name=f"turbine-models", 37 | version=f"{PACKAGE_VERSION}", 38 | author="SHARK Authors", 39 | author_email="dan@nod.ai", 40 | description="SHARK Turbine Machine Learning Model Zoo", 41 | long_description=README, 42 | long_description_content_type="text/markdown", 43 | url="https://github.com/nod-ai/SHARK-Turbine", 44 | license="Apache-2.0", 45 | classifiers=[ 46 | "Development Status :: 3 - Alpha", 47 | "License :: OSI Approved :: Apache Software License", 48 | "Programming Language :: Python :: 3", 49 | ], 50 | packages=find_namespace_packages( 51 | include=[ 52 | "turbine_models", 53 | "turbine_models.*", 54 | ], 55 | ), 56 | install_requires=[ 57 | "Shark-Turbine", 58 | "protobuf", 59 | "sentencepiece", 60 | "transformers>=4.37.1", 61 | "accelerate", 62 | "diffusers==0.29.0.dev0", 63 | "azure-storage-blob", 64 | "einops", 65 | ], 66 | ) 67 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | Clone and install SHARK-Turbine 4 | ``` 5 | git clone https://github.com/nod-ai/SHARK-Turbine.git 6 | cd SHARK-Turbine 7 | python -m venv turbine_venv && source turbine_venv/bin/activate 8 | 9 | pip install --index-url https://download.pytorch.org/whl/cpu \ 10 | -r core/pytorch-cpu-requirements.txt 11 | pip install --upgrade -r core/requirements.txt 12 | pip install -e core 13 | pip install -e models 14 | ``` 15 | 16 | ## Compiling LLMs 17 | Note: Make sure to replace "your_token" with your actual hf_auth_token for all the commands. 18 | 19 | Now, you can generate the quantized weight file with 20 | ``` 21 | python models/turbine_models/gen_external_params/gen_external_params.py --hf_auth_token=your_token 22 | ``` 23 | The model weights will then be saved in the current directory as `Llama_2_7b_chat_hf_f16_int4.safetensors`. 24 | 25 | To compile to vmfb for llama 26 | ``` 27 | python models/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16" 28 | ``` 29 | By default the vmfb will be saved as `Llama_2_7b_chat_hf.vmfb`. 30 | 31 | ## Running LLMs 32 | There are two ways of running LLMs: 33 | 34 | 1) Single run with predefined prompt to validate correctness. 35 | ``` 36 | python models/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan hf_auth_token=your_hf_token 37 | ``` 38 | 2) Interactive CLI chat mode. 
(just add a --chat_mode flag) 39 | ``` 40 | python models/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan hf_auth_token=your_hf_token --chat_mode 41 | ``` 42 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nod-ai/SHARK-ModelDev/bdf46351281f47843bb0d8c77bcd8ecde7271b60/models/turbine_models/custom_models/__init__.py -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llama_benchmark/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | Clone and install SHARK-Turbine 4 | ``` 5 | git clone https://github.com/nod-ai/SHARK-Turbine.git 6 | cd SHARK-Turbine 7 | python -m venv turbine_venv && source turbine_venv/bin/activate 8 | 9 | pip install --upgrade -r requirements.txt 10 | pip install --upgrade -e .[torch-cpu-nightly,testing] 11 | pip install --upgrade -r turbine-models-requirements.txt 12 | ``` 13 | 14 | Note: Make sure to replace "your_token" with your actual hf_auth_token for all the commands. 15 | 16 | Now, you can generate the quantized weight file with 17 | ``` 18 | python python/turbine_models/gen_external_params/gen_external_params.py --hf_auth_token=your_token 19 | ``` 20 | The model weights will then be saved in the current directory as `Llama_2_7b_chat_hf_f16_int4.safetensors`. 21 | 22 | To compile to vmfb for llama 23 | ``` 24 | python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16" 25 | ``` 26 | By default the vmfb will be saved as `Llama_2_7b_chat_hf.vmfb`. 27 | 28 | There are two options provided for benchmarking: 29 | 30 | 1) Benchmarking the first and second vicuna (run_initialize and run_forward) 31 | 2) Only benchmarking the second vicuna (run_forward) for more accurate tok/s 32 | 33 | To compile to vmfb for benchmark option 1: 34 | ``` 35 | python python/turbine_models/custom_models/llama-benchmark/benchmark_module.py --benchmark_mlir_path=./python/turbine_models/custom_models/llama-benchmark/benchmark.mlir 36 | ``` 37 | By default the vmfb will be saved as `benchmark.vmfb`. 38 | 39 | To compile to vmfb for benchmark option 2: 40 | ``` 41 | python python/turbine_models/custom_models/llama-benchmark/benchmark_module.py --benchmark_mlir_path=./python/turbine_models/custom_models/llama-benchmark/benchmark_forward.mlir 42 | ``` 43 | By default the vmfb will be saved as `benchmark.vmfb`. 44 | 45 | 46 | # Benchmarking 47 | 48 | Set the number of times second vicuna is run (# of tokens to benchmark) using the steps argument in following command. 
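As a rough guide to reading the results: for option 2 the script prints an estimated tokens-per-second figure computed from `--steps` and the reported benchmark time. The sketch below mirrors the formula used in `stateless_llama_benchmark.py` and assumes the time field is a string ending in a three-character ` ms` suffix; the helper name is only illustrative.

```
# Hypothetical helper mirroring the tok/s estimate in stateless_llama_benchmark.py.
# Assumes `time_str` looks like "2500.0 ms" (milliseconds with a 3-character suffix).
def estimate_tokens_per_second(steps: int, time_str: str) -> float:
    seconds = float(time_str[:-3]) / 1000  # strip " ms", convert ms -> seconds
    return steps / seconds

# Example: 10 forward steps measured at 2500 ms -> 4.0 tok/s
print(estimate_tokens_per_second(10, "2500.0 ms"))
```

The commands below are unaffected by this; it only describes how the printed estimate is derived.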
49 | 50 | To run the benchmark, use this command for option 1 (first and second vicuna): 51 | 52 | ``` 53 | python python/turbine_models/custom_models/llama-benchmark/stateless_llama_benchmark.py --hf_auth_token=your_token --benchmark_vmfb_path=benchmark.vmfb --llama_vmfb_path=Llama_2_7b_chat_hf.vmfb --external_weight_file=Llama_2_7b_chat_hf_f16_int4.safetensors --steps=10 54 | ``` 55 | 56 | To run the benchmark, use this command for option 2 (only run_forward): 57 | 58 | ``` 59 | python python/turbine_models/custom_models/llama-benchmark/stateless_llama_benchmark.py --run_forward_only_benchmark --hf_auth_token=your_token --benchmark_vmfb_path=benchmark.vmfb --llama_vmfb_path=Llama_2_7b_chat_hf.vmfb --external_weight_file=Llama_2_7b_chat_hf_f16_int4.safetensors --steps=10 60 | ``` 61 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llama_benchmark/benchmark.mlir: -------------------------------------------------------------------------------- 1 | module { 2 | func.func private @state_update.run_initialize(%arg0: tensor<1x?xi64>) -> tensor<1x1xi64> attributes { 3 | torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]" 4 | } 5 | func.func private @state_update.run_forward(%arg0: tensor<1x1xi64>) -> tensor<1x1xi64> attributes { 6 | torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]" 7 | } 8 | 9 | func.func @run(%input: tensor<1x?xi64>, %steps: tensor) -> tensor<1x1xi64> { 10 | %init = func.call @state_update.run_initialize(%input) : (tensor<1x?xi64>) -> tensor<1x1xi64> 11 | %c0 = arith.constant 0 : index 12 | %c1 = arith.constant 1 : index 13 | %steps_i64 = tensor.extract %steps[] : tensor 14 | %steps_index = arith.index_cast %steps_i64 : i64 to index 15 | %res = scf.for %arg0 = %c0 to %steps_index step %c1 iter_args(%arg = %init) -> (tensor<1x1xi64>) { 16 | %next = func.call @state_update.run_forward(%arg) : (tensor<1x1xi64>) -> tensor<1x1xi64> 17 | scf.yield %next : tensor<1x1xi64> 18 | } 19 | 20 | return %res : tensor<1x1xi64> 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llama_benchmark/benchmark_forward.mlir: -------------------------------------------------------------------------------- 1 | module { 2 | func.func private @state_update.run_forward(%arg0: tensor<1x1xi64>) -> tensor<1x1xi64> attributes { 3 | torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", 
torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]" 4 | } 5 | 6 | func.func @run(%input: tensor<1x1xi64>, %steps: tensor) -> tensor<1x1xi64> { 7 | %c0 = arith.constant 0 : index 8 | %c1 = arith.constant 1 : index 9 | %steps_i64 = tensor.extract %steps[] : tensor 10 | %steps_index = arith.index_cast %steps_i64 : i64 to index 11 | %res = scf.for %arg0 = %c0 to %steps_index step %c1 iter_args(%arg = %input) -> (tensor<1x1xi64>) { 12 | %next = func.call @state_update.run_forward(%arg) : (tensor<1x1xi64>) -> tensor<1x1xi64> 13 | scf.yield %next : tensor<1x1xi64> 14 | } 15 | 16 | return %res : tensor<1x1xi64> 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llama_benchmark/benchmark_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import argparse 8 | import sys 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--benchmark_mlir_path", type=str, default="", help="Path to benchmark mlir module" 13 | ) 14 | parser.add_argument( 15 | "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm" 16 | ) 17 | parser.add_argument( 18 | "--iree_target_triple", 19 | type=str, 20 | default="host", 21 | help="Specify vulkan target triple or rocm/cuda target device.", 22 | ) 23 | parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") 24 | 25 | 26 | def create_benchmark_vmfb(args): 27 | if not args.benchmark_mlir_path: 28 | sys.exit("no benchmark_vmfb_path provided, required for run_benchmark") 29 | 30 | flags = [ 31 | "--iree-input-type=torch", 32 | "--mlir-print-debuginfo", 33 | "--mlir-print-op-on-diagnostic=false", 34 | "--iree-llvmcpu-target-cpu-features=host", 35 | "--iree-llvmcpu-target-triple=x86_64-linux-gnu", 36 | "--iree-stream-resource-index-bits=64", 37 | "--iree-vm-target-index-bits=64", 38 | ] 39 | device = args.device 40 | if device == "cpu" or device == "llvm-cpu": 41 | flags.append("--iree-llvmcpu-enable-ukernels=all") 42 | device = "llvm-cpu" 43 | elif device == "vulkan": 44 | flags.extend( 45 | [ 46 | "--iree-vulkan-target-triple=" + args.iree_target_triple, 47 | "--iree-stream-resource-max-allocation-size=" 48 | + args.vulkan_max_allocation, 49 | ] 50 | ) 51 | elif device == "rocm": 52 | flags.extend( 53 | [ 54 | "--iree-rocm-target-chip=" + args.iree_target_triple, 55 | "--iree-rocm-link-bc=true", 56 | "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", 57 | "--iree-vm-bytecode-module-strip-source-map=true", 58 | "--iree-opt-strip-assertions=true", 59 | "--iree-vm-target-truncate-unsupported-floats", 60 | ] 61 | ) 62 | elif device == "cuda": 63 | flags.extend( 64 | [ 65 | "--iree-hal-cuda-llvm-target-arch=" + args.iree_target_triple, 66 | "--iree-vm-bytecode-module-strip-source-map=true", 67 | "--iree-vm-target-truncate-unsupported-floats", 68 | ] 69 | ) 70 | else: 71 | print("Unknown device kind: ", device) 72 | 73 | import iree.compiler as ireec 74 | 75 | flatbuffer_blob = ireec.compile_file( 76 | input_file=f"{args.benchmark_mlir_path}", 77 | target_backends=[device], 78 | extra_args=flags, 79 | ) 80 | with open(f"benchmark.vmfb", "wb+") as f: 81 | f.write(flatbuffer_blob) 82 | print("saved to benchmark.vmfb") 83 | 
exit() 84 | 85 | 86 | if __name__ == "__main__": 87 | args = parser.parse_args() 88 | create_benchmark_vmfb(args) 89 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llama_benchmark/stateless_llama_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import argparse 8 | import numpy as np 9 | import os 10 | import re 11 | import sys 12 | 13 | from transformers import AutoTokenizer 14 | from iree import runtime as ireert 15 | from turbine_models.utils.benchmark import benchmark_module 16 | import turbine_models.custom_models.stateless_llama as llama 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--hf_auth_token", type=str, help="The Hugging Face auth token, required" 22 | ) 23 | parser.add_argument( 24 | "--hf_model_name", 25 | type=str, 26 | help="HF model name", 27 | default="meta-llama/Llama-2-7b-chat-hf", 28 | ) 29 | parser.add_argument("--external_weight_file", type=str, default="") 30 | parser.add_argument("--benchmark_vmfb_path", type=str, default="") 31 | parser.add_argument("--llama_vmfb_path", type=str, default="") 32 | parser.add_argument( 33 | "--steps", 34 | type=int, 35 | default=10, 36 | help="number of times second vicuna is run (# of tokens to benchmark)", 37 | ) 38 | parser.add_argument( 39 | "--run_forward_only_benchmark", 40 | action="store_true", 41 | help="do not include inititalization in benchmark for accurate tok/s", 42 | ) 43 | 44 | 45 | def run_benchmark(args): 46 | config = ireert.Config("local-task") 47 | 48 | if args.external_weight_file: 49 | index = ireert.ParameterIndex() 50 | index.load(args.external_weight_file) 51 | 52 | if not args.benchmark_vmfb_path: 53 | sys.exit("no benchmark_vmfb_path provided, required for run_benchmark") 54 | benchmark_mod = ireert.VmModule.mmap(config.vm_instance, args.benchmark_vmfb_path) 55 | 56 | if not args.llama_vmfb_path: 57 | sys.exit("no llama_vmfb_path provided, required for run_benchmark") 58 | 59 | tokenizer = AutoTokenizer.from_pretrained( 60 | args.hf_model_name, 61 | use_fast=False, 62 | use_auth_token=args.hf_auth_token, 63 | ) 64 | 65 | initial_input = tokenizer(llama.prompt, return_tensors="pt") 66 | example_input_id = initial_input.input_ids 67 | input = [] 68 | temp = np.asarray(example_input_id, dtype=None, order="C") 69 | input.append(temp) 70 | input.append(np.array(args.steps)) 71 | 72 | vmfbs = [] 73 | vmfbs.append(args.llama_vmfb_path) 74 | vmfbs.append(args.benchmark_vmfb_path) 75 | 76 | if args.external_weight_file: 77 | results = benchmark_module( 78 | benchmark_mod, 79 | "run", 80 | vmfbs, 81 | input, 82 | parameters=f"model={args.external_weight_file}", 83 | ) 84 | else: 85 | results = benchmark_module(benchmark_mod, "run", vmfbs, input) 86 | 87 | for benchmark_result in results: 88 | print( 89 | f"benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}" 90 | ) 91 | 92 | 93 | def run_forward_benchmark(args): 94 | print("HERE") 95 | # Create the config for the IREE runtime environment 96 | config = ireert.Config("local-task") 97 | 98 | if not 
args.benchmark_vmfb_path: 99 | sys.exit("no benchmark_vmfb_path provided, required for run_benchmark") 100 | benchmark_mod = ireert.VmModule.mmap(config.vm_instance, args.benchmark_vmfb_path) 101 | 102 | # Load the external weight file if provided 103 | if args.external_weight_file: 104 | index = ireert.ParameterIndex() 105 | index.load(args.external_weight_file) 106 | 107 | # Ensure model name is in a safe format 108 | safe_name = args.hf_model_name.split("/")[-1].strip() 109 | safe_name = re.sub("-", "_", safe_name) 110 | 111 | # Load the .vmfb model file 112 | if args.llama_vmfb_path: 113 | mod = ireert.VmModule.mmap(config.vm_instance, args.llama_vmfb_path) 114 | elif os.path.exists(f"{safe_name}.vmfb"): 115 | mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") 116 | else: 117 | raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") 118 | 119 | # Prepare the modules for the IREE runtime context 120 | vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)] 121 | 122 | # Include parameter module if external weight file is used 123 | if args.external_weight_file: 124 | param_module = ireert.create_io_parameters_module( 125 | config.vm_instance, index.create_provider(scope="model") 126 | ) 127 | vm_modules.insert(0, param_module) 128 | 129 | # Create the system context with the given configuration and modules 130 | ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) 131 | 132 | # Initialize the tokenizer 133 | tokenizer = AutoTokenizer.from_pretrained( 134 | args.hf_model_name, use_fast=False, use_auth_token=args.hf_auth_token 135 | ) 136 | 137 | # Convert the prompt to input tensor 138 | initial_input = tokenizer(llama.prompt, return_tensors="pt") 139 | example_input_id = initial_input.input_ids 140 | device_inputs = [ireert.asdevicearray(config.device, example_input_id)] 141 | 142 | # Get the compiled module 143 | ModuleCompiled = ctx.modules.state_update 144 | init_val = ModuleCompiled["run_initialize"](*device_inputs) 145 | 146 | input = [] 147 | temp = np.asarray(init_val, dtype=None, order="C") 148 | input.append(temp) 149 | input.append(np.array(args.steps)) 150 | 151 | vmfbs = [] 152 | vmfbs.append(args.llama_vmfb_path) 153 | vmfbs.append(args.benchmark_vmfb_path) 154 | 155 | if args.external_weight_file: 156 | results = benchmark_module( 157 | benchmark_mod, 158 | "run", 159 | vmfbs, 160 | input, 161 | parameters=f"model={args.external_weight_file}", 162 | ) 163 | else: 164 | results = benchmark_module(benchmark_mod, "run", vmfbs, input) 165 | 166 | for benchmark_result in results: 167 | print( 168 | f"benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}" 169 | ) 170 | print( 171 | f"estimate: avg. 
{args.steps/(float(benchmark_result.time[:-3])/1000)} tok/s" 172 | ) 173 | 174 | 175 | # Python Benchmarking Support for multiple modules 176 | 177 | DTYPE_TO_ABI_TYPE = { 178 | np.dtype(np.float32): "f32", 179 | np.dtype(np.int32): "i32", 180 | np.dtype(np.int64): "i64", 181 | np.dtype(np.float64): "f64", 182 | np.dtype(np.int16): "i16", 183 | np.dtype(np.int8): "i8", 184 | np.dtype(np.bool_): "i1", 185 | } 186 | 187 | 188 | if __name__ == "__main__": 189 | args = parser.parse_args() 190 | if args.run_forward_only_benchmark: 191 | run_forward_benchmark(args) 192 | else: 193 | run_benchmark(args) 194 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llm_optimizations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nod-ai/SHARK-ModelDev/bdf46351281f47843bb0d8c77bcd8ecde7271b60/models/turbine_models/custom_models/llm_optimizations/__init__.py -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llm_optimizations/streaming_llm/README.md: -------------------------------------------------------------------------------- 1 | # StreamingLLM 2 | 3 | StreamingLLM is based on the paper *"Efficient Streaming Language Models with Attention Sinks"* by Xiao et al from the MIT Han Lab. Here is the original [[paper](http://arxiv.org/abs/2309.17453)] and [[code](https://github.com/mit-han-lab/streaming-llm)]. 4 | 5 | The modify_llama.py code is highly inspired by the modify_llama.py code in the original repo, but tweaked to work with ToM HuggingFace and compilable through Turbine. 6 | 7 | The work introduces sink attention which in short is a combination of a fixed starting few sequence attention along with a sliding window attention. This is beneficial for these reasons: 8 | 9 | 1) Generate infinitely long context. 10 | 2) Maintain memory under certain threshold (controlled by window_length) 11 | 12 | 13 | ## Compiling LLMs with StreamingLLM 14 | 15 | Just need to add an extra `--streaming_llm` flag when you call stateless_llama when generating your vmfb. For example: 16 | ``` 17 | python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16" --streaming_llm 18 | ``` 19 | 20 | By default the vmfb will still be saved as `Llama_2_7b_chat_hf.vmfb`. 21 | 22 | ## Running LLMs with StreamingLLM 23 | 24 | Similar to compiling, just need to add an extra `--streaming_llm` flag when you call llm_runner.py. For example: 25 | ``` 26 | python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan hf_auth_token=your_hf_token --chat_mode --streaming_llm=true 27 | ``` 28 | 29 | ## Future Work: 30 | - [ ] Make window size configurable through python, everything is there but we'd need to initialize with a default value which would only be possible after we let `_create_initial_value` to take in initial value from GlobalAttribute somewhere [here](https://github.com/nod-ai/SHARK-Turbine/blob/18e8a4100b61adfd9425dd32f780dc5f90017813/python/shark_turbine/aot/support/ir_utils.py#L284-L316) . 31 | - [ ] Get flow.move to enable overlap of sliding window and src of data. 
(Currently need to evict when it's at least 2x size of window) For example by default our streamingLLM window_size is 256, so we evict at ~600(slightly more than 2x for safety) token. 32 | - [ ] Introduce Rerotation of RoPE to as seen [here](https://github.com/huggingface/transformers/blob/c2d283a64a7f33547952e3eb0fa6533fc375bcdd/src/transformers/cache_utils.py#L213-L218) to remove invasive modification of LlamaAttention module for streamingLLM. -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llm_optimizations/streaming_llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nod-ai/SHARK-ModelDev/bdf46351281f47843bb0d8c77bcd8ecde7271b60/models/turbine_models/custom_models/llm_optimizations/streaming_llm/__init__.py -------------------------------------------------------------------------------- /models/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import torch.utils.checkpoint 7 | 8 | import torch.nn.functional as F 9 | 10 | from transformers.models.llama.modeling_llama import ( 11 | LlamaAttention, 12 | rotate_half, 13 | apply_rotary_pos_emb, 14 | repeat_kv, 15 | ) 16 | import types 17 | 18 | __all__ = ["enable_llama_pos_shift_attention"] 19 | 20 | 21 | def apply_rotary_pos_emb_single(x, cos, sin, position_ids): 22 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 23 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 24 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 25 | x_embed = (x * cos) + (rotate_half(x) * sin) 26 | return x_embed 27 | 28 | 29 | def llama_pos_shift_attention_forward( 30 | self, 31 | hidden_states: torch.Tensor, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | position_ids: Optional[torch.LongTensor] = None, 34 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 35 | output_attentions: bool = False, 36 | use_cache: bool = False, 37 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 38 | bsz, q_len, _ = hidden_states.size() 39 | 40 | query_states = self.q_proj(hidden_states) 41 | key_states = self.k_proj(hidden_states) 42 | value_states = self.v_proj(hidden_states) 43 | 44 | query_states = query_states.view( 45 | bsz, q_len, self.num_heads, self.head_dim 46 | ).transpose(1, 2) 47 | key_states = key_states.view( 48 | bsz, q_len, self.num_key_value_heads, self.head_dim 49 | ).transpose(1, 2) 50 | value_states = value_states.view( 51 | bsz, q_len, self.num_key_value_heads, self.head_dim 52 | ).transpose(1, 2) 53 | 54 | kv_seq_len = key_states.shape[-2] 55 | if past_key_value is not None: 56 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 57 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 58 | ### Shift Pos: query pos is min(cache_size, idx) 59 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 60 | query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids) 61 | ### 62 | 63 | if past_key_value is not None: 64 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 65 | key_states, value_states = past_key_value.update( 66 | key_states, value_states, self.layer_idx, cache_kwargs 67 | ) 68 
| 69 | ### Shift Pos: key pos is the pos in cache 70 | key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0) 71 | key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids) 72 | ### 73 | 74 | # repeat k/v heads if n_kv_heads < n_heads 75 | key_states = repeat_kv(key_states, self.num_key_value_groups) 76 | value_states = repeat_kv(value_states, self.num_key_value_groups) 77 | softmax_scale = 1.0 / math.sqrt(self.head_dim) 78 | attn_weights = ( 79 | torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale 80 | ) 81 | 82 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 83 | raise ValueError( 84 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 85 | f" {attn_weights.size()}" 86 | ) 87 | 88 | # For causal mode, we use to get input mask, but now causal mode does not expect a mask 89 | # and we need to generate the causal mask ourselves. 90 | current_is_causal = False 91 | if self.is_causal and attention_mask is None and q_len > 1: 92 | current_is_causal = True 93 | if current_is_causal and attention_mask is None: 94 | bool_attention_mask = torch.ones( 95 | [query_states.shape[-2], key_states.shape[-2]], 96 | device=query_states.device, 97 | dtype=torch.bool, 98 | ).tril() 99 | additive_attention_mask = torch.zeros_like( 100 | bool_attention_mask, dtype=attn_weights.dtype 101 | ).masked_fill(bool_attention_mask.logical_not(), -10000) 102 | attn_weights = attn_weights + additive_attention_mask 103 | 104 | # Legacy support to take in mask for non-causal mode. 105 | if attention_mask is not None: 106 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 107 | raise ValueError( 108 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 109 | ) 110 | attn_weights = attn_weights + attention_mask 111 | 112 | # upcast attention to fp32 113 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( 114 | query_states.dtype 115 | ) 116 | attn_output = torch.matmul(attn_weights, value_states) 117 | 118 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 119 | raise ValueError( 120 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 121 | f" {attn_output.size()}" 122 | ) 123 | attn_output = attn_output.transpose(1, 2).contiguous() 124 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 125 | attn_output = self.o_proj(attn_output) 126 | return attn_output, None, past_key_value 127 | 128 | 129 | def enable_llama_pos_shift_attention(model): 130 | for name, module in reversed(model._modules.items()): 131 | if len(list(module.children())) > 0: 132 | enable_llama_pos_shift_attention( 133 | module, 134 | ) 135 | 136 | if isinstance(module, LlamaAttention): 137 | model._modules[name].forward = types.MethodType( 138 | llama_pos_shift_attention_forward, model._modules[name] 139 | ) 140 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/resnet_18.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | 5 | from transformers import AutoFeatureExtractor, AutoModelForImageClassification 6 | import torch 7 | from iree.turbine.aot import * 8 | from iree.compiler.ir import Context 9 | import iree.runtime as rt 10 | from turbine_models.custom_models.sd_inference import utils 11 | import iree.turbine.ops.iree 
as ops 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument( 17 | "--hf_model_name", 18 | type=str, 19 | help="HF model name", 20 | default="microsoft/resnet-18", 21 | ) 22 | parser.add_argument("--run_vmfb", action="store_true") 23 | parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") 24 | parser.add_argument("--vmfb_path", type=str, default="") 25 | parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") 26 | parser.add_argument( 27 | "--iree_target_triple", 28 | type=str, 29 | default="", 30 | help="Specify vulkan target triple or rocm/cuda target device.", 31 | ) 32 | parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") 33 | 34 | # TODO: Add other resnet models 35 | torch.random.manual_seed(0) 36 | 37 | 38 | class Resnet18Model(torch.nn.Module): 39 | def __init__(self): 40 | super().__init__() 41 | self.model = AutoModelForImageClassification.from_pretrained( 42 | "microsoft/resnet-18" 43 | ) 44 | # self.extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18") 45 | 46 | def forward(self, pixel_values_tensor: torch.Tensor): 47 | logits = self.model.forward(pixel_values_tensor).logits 48 | predicted_id = torch.argmax(logits, -1) 49 | return predicted_id 50 | 51 | 52 | def export_resnet_18_model( 53 | resnet_model, compile_to="torch", device=None, target_triple=None, max_alloc=None 54 | ): 55 | class CompiledResnet18Model(CompiledModule): 56 | params = export_parameters(resnet_model.model) 57 | 58 | def main(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)): 59 | dynamic_shapes = {"arg0_1": {0: torch.export.Dim("dim", max=15)}} 60 | return jittable(resnet_model.forward)(x, dynamic_shapes=dynamic_shapes) 61 | 62 | import_to = "INPUT" if compile_to == "linalg" else "IMPORT" 63 | inst = CompiledResnet18Model(context=Context(), import_to=import_to) 64 | 65 | module_str = str(CompiledModule.get_mlir_module(inst)) 66 | if compile_to != "vmfb": 67 | return module_str 68 | else: 69 | utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, "resnet_18") 70 | 71 | 72 | def export_static_resnet_18_model( 73 | resnet_model, compile_to="torch", device=None, target_triple=None, max_alloc=None 74 | ): 75 | resnet_model = resnet_model.half() 76 | input_args = (torch.empty((5, 3, 224, 224), dtype=torch.float16),) 77 | exported = export(resnet_model, args=input_args) 78 | 79 | module_str = str(exported.mlir_module) 80 | if compile_to != "vmfb": 81 | return module_str 82 | else: 83 | utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, "resnet_18") 84 | 85 | 86 | def run_resnet_18_vmfb_comparison(resnet_model, args): 87 | import numpy as np 88 | 89 | torch_dtype = torch.float32 if args.precision == "fp32" else torch.float16 90 | config = rt.Config(args.device) 91 | 92 | if args.vmfb_path: 93 | mod = rt.VmModule.mmap(config.vm_instance, args.vmfb_path) 94 | elif os.path.exists("resnet_18.vmfb"): 95 | mod = rt.VmModule.mmap(config.vm_instance, "resnet_18.vmfb") 96 | else: 97 | sys.exit("no vmfb_path provided, required for run_vmfb") 98 | 99 | vm_modules = [ 100 | mod, 101 | rt.create_hal_module(config.vm_instance, config.device), 102 | ] 103 | ctx = rt.SystemContext( 104 | vm_modules=vm_modules, 105 | config=config, 106 | ) 107 | inp = torch.rand(5, 3, 224, 224, dtype=torch_dtype) 108 | np.save(f"test_input_{args.precision}.npy", inp.numpy()) 109 | device_inputs = [rt.asdevicearray(config.device, inp)] 110 | 111 | # Turbine output 
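# The exported CompiledResnet18Model is exposed on the runtime context as
# `compiled_resnet18_model`; calling its "main" entry point runs inference on the
# selected device, and to_host() below copies the predicted class ids back for
# comparison against the eager torch forward pass.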
112 | CompModule = ctx.modules.compiled_resnet18_model 113 | turbine_output = CompModule["main"](*device_inputs) 114 | print( 115 | "TURBINE OUTPUT:", 116 | turbine_output.to_host(), 117 | turbine_output.to_host().shape, 118 | turbine_output.to_host().dtype, 119 | ) 120 | 121 | # Torch output 122 | torch_output = resnet_model.forward(inp) 123 | torch_output = torch_output.detach().cpu().numpy() 124 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 125 | np.save(f"resnet18_golden_out.npy", torch_output) 126 | 127 | err = utils.largest_error(torch_output, turbine_output) 128 | print("LARGEST ERROR:", err) 129 | del CompModule 130 | return err 131 | 132 | 133 | if __name__ == "__main__": 134 | args = parser.parse_args() 135 | resnet_model = Resnet18Model() 136 | if args.run_vmfb: 137 | run_resnet_18_vmfb_comparison(resnet_model, args) 138 | else: 139 | mod_str = export_resnet_18_model( 140 | resnet_model, 141 | args.compile_to, 142 | args.device, 143 | args.iree_target_triple, 144 | args.vulkan_max_allocation, 145 | ) 146 | safe_name = "resnet_18" 147 | with open(f"{safe_name}.mlir", "w+") as f: 148 | f.write(mod_str) 149 | print("Saved to", safe_name + ".mlir") 150 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from turbine_models.custom_models.sd_inference import utils, schedulers 4 | from iree import runtime as ireert 5 | import torch 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | from iree.turbine.ops.iree import trace_tensor 9 | 10 | torch.random.manual_seed(0) 11 | 12 | 13 | def run_mmdit_turbine( 14 | hidden_states, 15 | encoder_hidden_states, 16 | pooled_projections, 17 | timestep, 18 | args, 19 | ): 20 | mmdit_runner = vmfbRunner( 21 | args.device, 22 | args.vmfb_path, 23 | args.external_weight_path, 24 | ) 25 | iree_inputs = [ 26 | ireert.asdevicearray(mmdit_runner.config.device, hidden_states), 27 | ireert.asdevicearray(mmdit_runner.config.device, encoder_hidden_states), 28 | ireert.asdevicearray(mmdit_runner.config.device, pooled_projections), 29 | ireert.asdevicearray(mmdit_runner.config.device, timestep), 30 | ] 31 | noise_pred = mmdit_runner.ctx.modules.compiled_mmdit["run_forward"]( 32 | *iree_inputs 33 | ).to_host() 34 | return noise_pred 35 | 36 | 37 | @torch.no_grad() 38 | def run_diffusers_mmdit( 39 | hidden_states, 40 | encoder_hidden_states, 41 | pooled_projections, 42 | timestep, 43 | args, 44 | ): 45 | from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTModel 46 | 47 | mmdit_model = MMDiTModel( 48 | args.hf_model_name, 49 | dtype=torch.float32, 50 | ) 51 | noise_pred = mmdit_model.forward( 52 | hidden_states.float(), 53 | encoder_hidden_states.float(), 54 | pooled_projections.float(), 55 | timestep.float(), 56 | ) 57 | 58 | return noise_pred.numpy() 59 | 60 | 61 | def run_attn_turbine(q, k, v, args): 62 | attn_runner = vmfbRunner( 63 | args.device, 64 | args.vmfb_path, 65 | None, 66 | ) 67 | iree_inputs = [ 68 | ireert.asdevicearray(attn_runner.config.device, q), 69 | ireert.asdevicearray(attn_runner.config.device, k), 70 | ireert.asdevicearray(attn_runner.config.device, v), 71 | ] 72 | attn_output = attn_runner.ctx.modules.compiled_attn["run_forward"]( 73 | *iree_inputs 74 | ).to_host() 75 | return attn_output 76 | 77 | 78 | @torch.no_grad() 79 | 
def run_attn_torch(q, k, v, args): 80 | from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention 81 | 82 | mmdit_attn = MMDiTAttention() 83 | attn_output = mmdit_attn.forward( 84 | torch.tensor(q, dtype=torch.float32), 85 | torch.tensor(k, dtype=torch.float32), 86 | torch.tensor(v, dtype=torch.float32), 87 | ) 88 | 89 | return attn_output.numpy() 90 | 91 | 92 | def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]): 93 | if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2): 94 | if turbine_output.ndim > 0: 95 | orig_dim = dim 96 | for idx, i in enumerate(torch_output): 97 | dim = [*orig_dim, idx] 98 | try: 99 | np.testing.assert_allclose( 100 | turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2 101 | ) 102 | except Exception as e: 103 | err = np.abs(turbine_output[idx] - torch_output[idx]) 104 | failed_dims.append(dim) 105 | errs.append([err, turbine_output[idx], torch_output[idx]]) 106 | failed_dims, errs = find_errs( 107 | turbine_output[idx], torch_output[idx], dim, failed_dims, errs 108 | ) 109 | return (failed_dims, errs) 110 | 111 | 112 | if __name__ == "__main__": 113 | from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args 114 | import numpy as np 115 | import os 116 | 117 | torch.random.manual_seed(0) 118 | 119 | if args.precision == "fp16": 120 | dtype = torch.float16 121 | np_dtype = np.float16 122 | else: 123 | dtype = torch.float32 124 | np_dtype = np.float32 125 | 126 | if args.attn_repro: 127 | qkv_shape = (2, 24, 4250, 64) 128 | example_qkv = [ 129 | np.load("q.npy").astype(np_dtype), 130 | np.load("k.npy").astype(np_dtype), 131 | np.load("v.npy").astype(np_dtype), 132 | ] 133 | turbine_output = run_attn_turbine( 134 | *example_qkv, 135 | args, 136 | ) 137 | torch_output = run_attn_torch(*example_qkv, args).astype(np.float16) 138 | np.save("turbine_attn_output.npy", turbine_output) 139 | np.save("torch_attn_output.npy", torch_output) 140 | failed_dims, errs = find_errs(turbine_output, torch_output) 141 | for idx, dim in enumerate(failed_dims): 142 | if len(dim) == len(torch_output.shape): 143 | print("Failed dimension: ", dim, " with error: ", errs[idx][0]) 144 | print("Turbine output: ", errs[idx][1]) 145 | print("Torch output: ", errs[idx][2]) 146 | print(torch_output.shape) 147 | exit() 148 | 149 | batch_size = args.batch_size * 2 # do classifier free guidance 150 | hidden_states = torch.randn( 151 | (batch_size, 16, args.height // 8, args.width // 8), dtype=dtype 152 | ) 153 | encoder_hidden_states = torch.randn( 154 | (batch_size, args.max_length * 2, 4096), dtype=dtype 155 | ) 156 | pooled_projections = torch.randn((batch_size, 2048), dtype=dtype) 157 | timestep = torch.tensor([0, 0], dtype=dtype) 158 | 159 | turbine_output = run_mmdit_turbine( 160 | hidden_states, 161 | encoder_hidden_states, 162 | pooled_projections, 163 | timestep, 164 | args, 165 | ) 166 | print( 167 | "TURBINE SPLIT OUTPUT:", 168 | turbine_output, 169 | turbine_output.shape, 170 | turbine_output.dtype, 171 | ) 172 | turbine_output = turbine_output 173 | 174 | if args.compare_vs_torch: 175 | print("generating torch output: ") 176 | torch_output = run_diffusers_mmdit( 177 | hidden_states, 178 | encoder_hidden_states, 179 | pooled_projections, 180 | timestep, 181 | args, 182 | ) 183 | np.save("torch_mmdit_output.npy", torch_output.astype(np.float16)) 184 | print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 185 | 186 | print("\n(torch (comfy) image latents to iree image latents): ") 
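        # Elementwise check of the IREE output against the diffusers reference;
        # the loose rtol/atol of 4e-2 leaves headroom for fp16 rounding
        # differences between the two backends.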
187 | 188 | np.testing.assert_allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2) 189 | print("passed!") 190 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import os 8 | import sys 9 | 10 | import safetensors 11 | from iree import runtime as ireert 12 | import iree.compiler as ireec 13 | from iree.compiler.ir import Context 14 | import numpy as np 15 | from iree.turbine.aot import * 16 | from iree.turbine.transforms.general.add_metadata import AddMetadataPass 17 | from turbine_models.custom_models.sd_inference import utils 18 | import torch 19 | from turbine_models.custom_models.sd3_inference.text_encoder_impls import ( 20 | SDClipModel, 21 | SDXLClipG, 22 | T5XXLModel, 23 | load_into, 24 | ) 25 | from huggingface_hub import hf_hub_download 26 | from safetensors import safe_open 27 | 28 | CLIPG_CONFIG = { 29 | "hidden_act": "gelu", 30 | "hidden_size": 1280, 31 | "intermediate_size": 5120, 32 | "num_attention_heads": 20, 33 | "num_hidden_layers": 32, 34 | } 35 | 36 | CLIPL_CONFIG = { 37 | "hidden_act": "quick_gelu", 38 | "hidden_size": 768, 39 | "intermediate_size": 3072, 40 | "num_attention_heads": 12, 41 | "num_hidden_layers": 12, 42 | } 43 | 44 | T5_CONFIG = { 45 | "d_ff": 10240, 46 | "d_model": 4096, 47 | "num_heads": 64, 48 | "num_layers": 24, 49 | "vocab_size": 32128, 50 | } 51 | 52 | 53 | class TextEncoderModule(torch.nn.Module): 54 | @torch.no_grad() 55 | def __init__( 56 | self, 57 | batch_size=1, 58 | ): 59 | super().__init__() 60 | self.dtype = torch.float16 61 | self.clip_l = SDClipModel( 62 | layer="hidden", 63 | layer_idx=-2, 64 | device="cpu", 65 | dtype=self.dtype, 66 | layer_norm_hidden_state=False, 67 | return_projected_pooled=False, 68 | textmodel_json_config=CLIPL_CONFIG, 69 | ).half() 70 | clip_l_weights = hf_hub_download( 71 | repo_id="stabilityai/stable-diffusion-3-medium", 72 | filename="text_encoders/clip_l.safetensors", 73 | ) 74 | with safe_open(clip_l_weights, framework="pt", device="cpu") as f: 75 | load_into(f, self.clip_l.transformer, "", "cpu", self.dtype) 76 | self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half() 77 | clip_g_weights = hf_hub_download( 78 | repo_id="stabilityai/stable-diffusion-3-medium", 79 | filename="text_encoders/clip_g.safetensors", 80 | ) 81 | with safe_open(clip_g_weights, framework="pt", device="cpu") as f: 82 | load_into(f, self.clip_g.transformer, "", "cpu", self.dtype) 83 | self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half() 84 | t5_weights = hf_hub_download( 85 | repo_id="stabilityai/stable-diffusion-3-medium", 86 | filename="text_encoders/t5xxl_fp16.safetensors", 87 | ) 88 | with safe_open(t5_weights, framework="pt", device="cpu") as f: 89 | load_into(f, self.t5xxl.transformer, "", "cpu", self.dtype) 90 | 91 | self.do_classifier_free_guidance = True 92 | self.batch_size = batch_size 93 | 94 | def get_cond(self, tokens_l, tokens_g, tokens_t5xxl): 95 | l_out, l_pooled = self.clip_l.forward(tokens_l) 96 | g_out, g_pooled = self.clip_g.forward(tokens_g) 97 | t5_out, _ = self.t5xxl.forward(tokens_t5xxl) 98 | lg_out = torch.cat([l_out, g_out], dim=-1) 99 | 
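        # CLIP-L (768) and CLIP-G (1280) hidden states concatenate to 2048
        # channels; zero-pad up to the T5-XXL width (d_model = 4096) so the
        # result can be stacked with the T5 output along the sequence dim.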
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) 100 | return torch.cat([lg_out, t5_out], dim=-2), torch.cat( 101 | (l_pooled, g_pooled), dim=-1 102 | ) 103 | 104 | def forward(self, tokens_g, tokens_l, tokens_t5xxl, neg_g, neg_l, neg_t5): 105 | conditioning, cond_pool = self.get_cond(tokens_l, tokens_g, tokens_t5xxl) 106 | neg_cond, neg_cond_pool = self.get_cond(neg_l, neg_g, neg_t5) 107 | 108 | prompt_embeds = torch.cat([neg_cond, conditioning], dim=0) 109 | pooled_prompt_embeds = torch.cat([neg_cond_pool, cond_pool], dim=0) 110 | 111 | return prompt_embeds, pooled_prompt_embeds 112 | 113 | 114 | @torch.no_grad() 115 | def export_text_encoders( 116 | hf_model_name, 117 | max_length=64, 118 | batch_size=1, 119 | precision="fp16", 120 | compile_to="torch", 121 | external_weights=None, 122 | external_weight_path=None, 123 | device=None, 124 | target_triple=None, 125 | ireec_flags=None, 126 | exit_on_vmfb=False, 127 | pipeline_dir=None, 128 | input_mlir=None, 129 | attn_spec=None, 130 | decomp_attn=True, 131 | ): 132 | 133 | safe_name = utils.create_safe_name( 134 | hf_model_name, 135 | f"_bs{batch_size}_{str(max_length)}_{precision}_text_encoders", 136 | ) 137 | if pipeline_dir: 138 | safe_name = os.path.join(pipeline_dir, safe_name) 139 | 140 | if input_mlir: 141 | vmfb_path = utils.compile_to_vmfb( 142 | input_mlir, 143 | device, 144 | target_triple, 145 | ireec_flags, 146 | safe_name, 147 | mlir_source="file", 148 | return_path=not exit_on_vmfb, 149 | const_expr_hoisting=True, 150 | attn_spec=attn_spec, 151 | ) 152 | return vmfb_path 153 | model = TextEncoderModule( 154 | batch_size=batch_size, 155 | ) 156 | mapper = {} 157 | 158 | assert ( 159 | ".safetensors" not in external_weight_path 160 | ), "Original parameters format incompatible with IREE safetensors parser. Use '.irpa' instead." 
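    # The external weights consumed here are produced by save_module_parameters()
    # further below; the required ".irpa" extension corresponds to IREE's
    # parameter-archive format, which the runtime-side vmfbRunner loads directly.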
161 | 162 | input_args = [torch.empty([batch_size, 77, 2], dtype=torch.int64) for x in range(6)] 163 | 164 | decomp_list = [] 165 | if decomp_attn == True: 166 | decomp_list = [ 167 | torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, 168 | torch.ops.aten._scaled_dot_product_flash_attention.default, 169 | torch.ops.aten.scaled_dot_product_attention, 170 | ] 171 | with decompositions.extend_aot_decompositions( 172 | from_current=True, 173 | add_ops=decomp_list, 174 | ): 175 | fxb = FxProgramsBuilder(model) 176 | 177 | @fxb.export_program( 178 | args=(input_args,), 179 | ) 180 | def _forward( 181 | module, 182 | inputs, 183 | ): 184 | return module.forward(*inputs) 185 | 186 | class CompiledTextEncoder(CompiledModule): 187 | encode_tokens = _forward 188 | 189 | if external_weights: 190 | externalize_module_parameters(model) 191 | save_module_parameters(external_weight_path, model) 192 | 193 | inst = CompiledTextEncoder(context=Context(), import_to="IMPORT") 194 | 195 | module = CompiledModule.get_mlir_module(inst) 196 | 197 | model_metadata_forward = { 198 | "model_name": "sd3_clip_t5xxl_text_encoders", 199 | "input_shapes": [(batch_size, max_length, 2) for x in range(6)], 200 | "input_dtypes": ["int64" for x in range(6)], 201 | "output_shapes": [ 202 | (2 * batch_size, max_length * 2, 4096), 203 | (2 * batch_size, 2048), 204 | ], 205 | "output_dtypes": ["float32"], 206 | } 207 | module = AddMetadataPass(module, model_metadata_forward, "forward").run() 208 | module_str = str(module) 209 | if compile_to != "vmfb": 210 | return module_str 211 | else: 212 | vmfb_path = utils.compile_to_vmfb( 213 | module_str, 214 | device, 215 | target_triple, 216 | ireec_flags, 217 | safe_name, 218 | return_path=not exit_on_vmfb, 219 | const_expr_hoisting=True, 220 | attn_spec=attn_spec, 221 | ) 222 | return vmfb_path 223 | 224 | 225 | if __name__ == "__main__": 226 | from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args 227 | 228 | mod_str, _ = export_text_encoders( 229 | args.hf_model_name, 230 | args.max_length, 231 | args.batch_size, 232 | args.precision, 233 | args.compile_to, 234 | args.external_weights, 235 | args.external_weight_path, 236 | args.device, 237 | args.iree_target_triple, 238 | args.ireec_flags + args.clip_flags, 239 | exit_on_vmfb=True, 240 | pipeline_dir=args.pipeline_dir, 241 | input_mlir=args.input_mlir, 242 | attn_spec=args.attn_spec, 243 | ) 244 | if args.input_mlir or args.weights_only or args.compile_to == "vmfb": 245 | exit() 246 | safe_name = utils.create_safe_name( 247 | args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_text_encoders" 248 | ) 249 | with open(f"{safe_name}.mlir", "w+") as f: 250 | f.write(mod_str) 251 | print("Saved to", safe_name + ".mlir") 252 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py: -------------------------------------------------------------------------------- 1 | from turbine_models.model_runner import vmfbRunner 2 | from turbine_models.custom_models.sd3_inference.text_encoder_impls import ( 3 | SD3Tokenizer, 4 | T5XXLTokenizer, 5 | SDXLClipGTokenizer, 6 | ) 7 | from iree import runtime as ireert 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def run_prompt_encoder( 13 | vmfb_path, 14 | device, 15 | external_weight_path, 16 | input_ids, 17 | uncond_input_ids, 18 | ): 19 | prompt_encoder_runner = vmfbRunner(device, vmfb_path, external_weight_path) 20 | # np.save("input0.npy", 
input_ids[0].numpy()) 21 | # np.save("input1.npy", input_ids[1].numpy()) 22 | # np.save("input2.npy", input_ids[2].numpy()) 23 | # np.save("input3.npy", uncond_input_ids[0].numpy()) 24 | # np.save("input4.npy", uncond_input_ids[1].numpy()) 25 | # np.save("input5.npy", uncond_input_ids[2].numpy()) 26 | prompt_encoder_inputs = [ 27 | ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]), 28 | ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]), 29 | ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[2]), 30 | ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[0]), 31 | ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[1]), 32 | ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[2]), 33 | ] 34 | encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_text_encoder[ 35 | "encode_tokens" 36 | ](*prompt_encoder_inputs) 37 | for i in encoded_outputs: 38 | i = i.to_host() 39 | del prompt_encoder_inputs 40 | return encoded_outputs 41 | 42 | 43 | def run_tokenize( 44 | tokenizer, 45 | prompt, 46 | negative_prompt, 47 | ): 48 | prompt_tokens_dict = tokenizer.tokenize_with_weights(prompt) 49 | neg_prompt_tokens_dict = tokenizer.tokenize_with_weights(negative_prompt) 50 | text_input_ids_list = list(prompt_tokens_dict.values()) 51 | uncond_input_ids_list = list(neg_prompt_tokens_dict.values()) 52 | return text_input_ids_list, uncond_input_ids_list 53 | 54 | 55 | if __name__ == "__main__": 56 | from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args 57 | 58 | tokenizer = SD3Tokenizer() 59 | 60 | text_input_ids_list, uncond_input_ids_list = run_tokenize( 61 | tokenizer, 62 | args.prompt, 63 | args.negative_prompt, 64 | ) 65 | turbine_output1, turbine_output2 = run_prompt_encoder( 66 | args.vmfb_path, 67 | args.rt_device, 68 | args.external_weight_path, 69 | text_input_ids_list, 70 | uncond_input_ids_list, 71 | ) 72 | print( 73 | "TURBINE OUTPUT 1:", 74 | turbine_output1.to_host(), 75 | turbine_output1.shape, 76 | turbine_output1.dtype, 77 | ) 78 | 79 | print( 80 | "TURBINE OUTPUT 2:", 81 | turbine_output2.to_host(), 82 | turbine_output2.shape, 83 | turbine_output2.dtype, 84 | ) 85 | 86 | if args.compare_vs_torch: 87 | print("generating torch output: ") 88 | from turbine_models.custom_models.sd_inference import utils 89 | from turbine_models.custom_models.sd3_inference.sd3_text_encoders import ( 90 | TextEncoderModule, 91 | ) 92 | 93 | torch_encoder_model = TextEncoderModule( 94 | args.batch_size, 95 | ) 96 | torch_output1, torch_output2 = torch_encoder_model.forward( 97 | *text_input_ids_list, *uncond_input_ids_list 98 | ) 99 | np.save("torch_output1.npy", torch_output1) 100 | np.save("torch_output2.npy", torch_output2) 101 | print( 102 | "TORCH OUTPUT 1:", torch_output1, torch_output1.shape, torch_output1.dtype 103 | ) 104 | 105 | print( 106 | "TORCH OUTPUT 2:", torch_output2, torch_output2.shape, torch_output2.dtype 107 | ) 108 | rtol = 4e-2 109 | atol = 4e-2 110 | 111 | np.testing.assert_allclose( 112 | torch_output1, turbine_output1.to_host(), rtol, atol, verbose=True 113 | ) 114 | np.testing.assert_allclose( 115 | torch_output2, turbine_output2.to_host(), rtol, atol, verbose=True 116 | ) 117 | print("Passed!") 118 | # TODO: Figure out why we occasionally segfault without unlinking output variables 119 | turbine_output1, turbine_output2 = (None, None) 120 | -------------------------------------------------------------------------------- 
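# Illustrative usage of sd3_text_encoders_runner.py above (a sketch only — the
# vmfb/irpa paths and device string are placeholders, assuming the text encoder
# module has already been exported and compiled):
#
#   from turbine_models.custom_models.sd3_inference import sd3_text_encoders_runner as ter
#   tok = ter.SD3Tokenizer()
#   ids, neg_ids = ter.run_tokenize(tok, "a photo of a dog", "")
#   cond, pooled = ter.run_prompt_encoder(
#       "sd3_text_encoders.vmfb", "local-task", "sd3_text_encoders.irpa", ids, neg_ids
#   )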
/models/turbine_models/custom_models/sd3_inference/sd3_vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import copy 8 | import os 9 | import sys 10 | 11 | from iree import runtime as ireert 12 | from iree.compiler.ir import Context 13 | import numpy as np 14 | from iree.turbine.aot import * 15 | from iree.turbine.dynamo.passes import ( 16 | DEFAULT_DECOMPOSITIONS, 17 | ) 18 | from turbine_models.custom_models.sd_inference import utils 19 | import torch 20 | import torch._dynamo as dynamo 21 | from diffusers import AutoencoderKL 22 | 23 | 24 | class VaeModel(torch.nn.Module): 25 | def __init__( 26 | self, 27 | hf_model_name, 28 | ): 29 | super().__init__() 30 | self.vae = AutoencoderKL.from_pretrained( 31 | hf_model_name, 32 | subfolder="vae", 33 | ) 34 | 35 | def decode(self, inp): 36 | inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor 37 | image = self.vae.decode(inp, return_dict=False)[0] 38 | image = image.float() 39 | image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] 40 | return image 41 | 42 | def encode(self, inp): 43 | image_np = inp / 255.0 44 | image_np = np.moveaxis(image_np, 2, 0) 45 | batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) 46 | image_torch = torch.from_numpy(batch_images) 47 | image_torch = 2.0 * image_torch - 1.0 48 | image_torch = image_torch 49 | latent = self.vae.encode(image_torch) 50 | return latent 51 | 52 | 53 | def export_vae_model( 54 | vae_model, 55 | hf_model_name, 56 | batch_size, 57 | height, 58 | width, 59 | precision, 60 | compile_to="torch", 61 | external_weights=None, 62 | external_weight_path=None, 63 | device=None, 64 | target_triple=None, 65 | ireec_flags=None, 66 | decomp_attn=False, 67 | exit_on_vmfb=False, 68 | pipeline_dir=None, 69 | attn_spec=None, 70 | input_mlir=None, 71 | weights_only=False, 72 | ): 73 | dtype = torch.float16 if precision == "fp16" else torch.float32 74 | safe_name = utils.create_safe_name( 75 | hf_model_name, 76 | f"_bs{batch_size}_{height}x{width}_{precision}_vae", 77 | ) 78 | if pipeline_dir: 79 | safe_name = os.path.join(pipeline_dir, safe_name) 80 | 81 | if input_mlir: 82 | vmfb_path = utils.compile_to_vmfb( 83 | input_mlir, 84 | device, 85 | target_triple, 86 | ireec_flags, 87 | safe_name, 88 | mlir_source="file", 89 | return_path=not exit_on_vmfb, 90 | attn_spec=attn_spec, 91 | ) 92 | return vmfb_path 93 | 94 | if device == "cpu": 95 | decomp_attn = True 96 | 97 | if dtype == torch.float16: 98 | vae_model = vae_model.half() 99 | mapper = {} 100 | utils.save_external_weights( 101 | mapper, vae_model, external_weights, external_weight_path 102 | ) 103 | if weights_only: 104 | return external_weight_path 105 | 106 | input_image_shape = (height, width, 3) 107 | input_latents_shape = (batch_size, 16, height // 8, width // 8) 108 | encode_args = [ 109 | torch.empty( 110 | input_image_shape, 111 | dtype=torch.float32, 112 | ) 113 | ] 114 | decode_args = [ 115 | torch.empty( 116 | input_latents_shape, 117 | dtype=dtype, 118 | ) 119 | ] 120 | decomp_list = [] 121 | if decomp_attn == True: 122 | decomp_list = [ 123 | torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, 124 | torch.ops.aten._scaled_dot_product_flash_attention.default, 125 | torch.ops.aten.scaled_dot_product_attention, 126 
| ] 127 | with decompositions.extend_aot_decompositions( 128 | from_current=True, 129 | add_ops=decomp_list, 130 | ): 131 | fxb = FxProgramsBuilder(vae_model) 132 | 133 | # @fxb.export_program(args=(encode_args,)) 134 | # def _encode(module, inputs,): 135 | # return module.encode(*inputs) 136 | 137 | @fxb.export_program(args=(decode_args,)) 138 | def _decode(module, inputs): 139 | return module.decode(*inputs) 140 | 141 | class CompiledVae(CompiledModule): 142 | decode = _decode 143 | 144 | if external_weights: 145 | externalize_module_parameters(vae_model) 146 | 147 | inst = CompiledVae(context=Context(), import_to="IMPORT") 148 | 149 | module_str = str(CompiledModule.get_mlir_module(inst)) 150 | 151 | if compile_to != "vmfb": 152 | return module_str 153 | else: 154 | vmfb_path = utils.compile_to_vmfb( 155 | module_str, 156 | device, 157 | target_triple, 158 | ireec_flags, 159 | safe_name, 160 | return_path=not exit_on_vmfb, 161 | attn_spec=attn_spec, 162 | ) 163 | return vmfb_path 164 | 165 | 166 | if __name__ == "__main__": 167 | from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args 168 | 169 | if args.input_mlir: 170 | vae_model = None 171 | else: 172 | vae_model = VaeModel( 173 | args.hf_model_name, 174 | ) 175 | mod_str = export_vae_model( 176 | vae_model, 177 | args.hf_model_name, 178 | args.batch_size, 179 | height=args.height, 180 | width=args.width, 181 | precision=args.precision, 182 | compile_to=args.compile_to, 183 | external_weights=args.external_weights, 184 | external_weight_path=args.external_weight_path, 185 | device=args.device, 186 | target_triple=args.iree_target_triple, 187 | ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, 188 | decomp_attn=args.decomp_attn, 189 | attn_spec=args.attn_spec, 190 | input_mlir=args.input_mlir, 191 | ) 192 | if args.input_mlir or (args.compile_to == "vmfb"): 193 | exit() 194 | safe_name = utils.create_safe_name( 195 | args.hf_model_name, 196 | f"_bs{str(args.batch_size)}_{args.height}x{args.width}_{args.precision}_vae", 197 | ) 198 | with open(f"{safe_name}.mlir", "w+") as f: 199 | f.write(mod_str) 200 | print("Saved to", safe_name + ".mlir") 201 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from iree import runtime as ireert 4 | import torch 5 | 6 | torch.random.manual_seed(0) 7 | 8 | 9 | def run_vae( 10 | device, 11 | example_input, 12 | vmfb_path, 13 | hf_model_name, 14 | external_weight_path, 15 | ): 16 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 17 | inputs = [ireert.asdevicearray(runner.config.device, example_input)] 18 | results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host() 19 | results = imagearray_from_vae_out(results) 20 | return results 21 | 22 | 23 | def run_torch_vae(hf_model_name, variant, example_input): 24 | from turbine_models.custom_models.sd_inference.vae import SD3VaeModel 25 | 26 | vae_model = SD3VaeModel( 27 | hf_model_name, 28 | ) 29 | 30 | if variant == "decode": 31 | results = vae_model.decode(example_input) 32 | elif variant == "encode": 33 | results = vae_model.encode(example_input) 34 | np_torch_output = results.detach().cpu().numpy() 35 | np_torch_output = imagearray_from_vae_out(np_torch_output) 36 | return np_torch_output 37 | 38 | 39 | def imagearray_from_vae_out(image): 40 | 
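    """Convert a VAE decode output (CHW or NCHW float in [0, 1]) to an HWC
    uint8 image array: drop the batch dim if present, move channels last,
    then scale to 0-255 and round."""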
if image.ndim == 4: 41 | image = image[0] 42 | image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy() 43 | image = (image * 255).round().astype("uint8") 44 | return image 45 | 46 | 47 | if __name__ == "__main__": 48 | from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args 49 | import numpy as np 50 | from PIL import Image 51 | 52 | dtype = torch.float16 if args.precision == "fp16" else torch.float32 53 | if args.vae_variant == "decode": 54 | example_input = torch.rand( 55 | args.batch_size, 16, args.height // 8, args.width // 8, dtype=dtype 56 | ) 57 | if args.vae_input_path: 58 | example_input = np.load(args.vae_input_path) 59 | if example_input.shape[0] == 2: 60 | example_input = np.split(example_input, 2)[0] 61 | elif args.vae_variant == "encode": 62 | example_input = torch.rand( 63 | args.batch_size, 3, args.height, args.width, dtype=dtype 64 | ) 65 | print("generating turbine output:") 66 | turbine_results = run_vae( 67 | args.device, 68 | example_input, 69 | args.vmfb_path, 70 | args.hf_model_name, 71 | args.external_weight_path, 72 | ) 73 | print( 74 | "TURBINE OUTPUT:", 75 | turbine_results, 76 | turbine_results.shape, 77 | turbine_results.dtype, 78 | ) 79 | if args.compare_vs_torch: 80 | print("generating torch output: ") 81 | from turbine_models.custom_models.sd_inference import utils 82 | 83 | torch_output = run_torch_vae( 84 | args.hf_model_name, args.vae_variant, torch.tensor(example_input).float() 85 | ) 86 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 87 | if args.vae_input_path: 88 | out_image_torch = Image.fromarray(torch_output) 89 | out_image_torch.save("vae_test_output_torch.png") 90 | out_image_turbine = Image.fromarray(turbine_results) 91 | out_image_turbine.save("vae_test_output_turbine.png") 92 | # Allow a small amount of wiggle room for rounding errors (1) 93 | 94 | np.testing.assert_allclose(turbine_results, torch_output, rtol=1, atol=1) 95 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import os 8 | import re 9 | 10 | from iree.compiler.ir import Context 11 | from iree.turbine.aot import * 12 | from iree.turbine.transforms.general.add_metadata import AddMetadataPass 13 | from turbine_models.custom_models.sd_inference import utils 14 | import torch 15 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor 16 | from turbine_models.turbine_tank import turbine_tank 17 | 18 | 19 | @torch.no_grad() 20 | def export_clip_model( 21 | hf_model_name, 22 | batch_size: int = 1, 23 | max_length: int = 64, 24 | precision: str = "fp16", 25 | compile_to: str = "torch", 26 | external_weights: str = None, 27 | external_weight_path: str = None, 28 | device: str = "llvm-cpu", 29 | target: str = "x86_64-linux-gnu", 30 | ireec_flags: str = None, 31 | exit_on_vmfb: bool = False, 32 | pipeline_dir: str = None, 33 | input_mlir: str = None, 34 | attn_spec: str = None, 35 | weights_only: bool = False, 36 | upload_ir: bool = False, 37 | decomp_attn: bool = False, 38 | ): 39 | input_len = max_length 40 | safe_name = utils.create_safe_name( 41 | hf_model_name, f"_bs{batch_size}_{str(max_length)}-{precision}-clip" 42 | ) 43 | if pipeline_dir not in [None, ""]: 44 | safe_name = os.path.join(pipeline_dir, safe_name) 45 | if input_mlir: 46 | vmfb_path = utils.compile_to_vmfb( 47 | input_mlir, 48 | device, 49 | target, 50 | ireec_flags, 51 | safe_name, 52 | mlir_source="file", 53 | return_path=not exit_on_vmfb, 54 | const_expr_hoisting=True, 55 | attn_spec=attn_spec, 56 | ) 57 | return vmfb_path 58 | if "google/t5" in hf_model_name: 59 | from transformers import T5Tokenizer, T5Model 60 | 61 | tokenizer = T5Tokenizer.from_pretrained(hf_model_name) 62 | text_encoder_model = T5Model.from_pretrained(hf_model_name) 63 | input_len = 512 64 | 65 | else: 66 | # TODO: Add better filtering mechanism for things that require CLIPProcessor 67 | if "openai" in hf_model_name: 68 | tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 69 | hf_subfolder = "" # CLIPProcessor does not have a subfolder 70 | input_len = 10 71 | else: 72 | # Load the tokenizer and text encoder to tokenize and encode the text. 
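            # Standard diffusers checkpoints keep these in the "tokenizer"
            # and "text_encoder" subfolders of the HF repo, so both are
            # loaded by subfolder name here.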
73 | tokenizer = CLIPTokenizer.from_pretrained( 74 | hf_model_name, 75 | subfolder="tokenizer", 76 | ) 77 | hf_subfolder = "text_encoder" 78 | 79 | text_encoder_model = CLIPTextModel.from_pretrained( 80 | hf_model_name, 81 | subfolder=hf_subfolder, 82 | ) 83 | if precision == "fp16": 84 | text_encoder_model = text_encoder_model.half() 85 | mapper = {} 86 | utils.save_external_weights( 87 | mapper, text_encoder_model, external_weights, external_weight_path 88 | ) 89 | if weights_only: 90 | return external_weight_path 91 | 92 | if "google/t5" in hf_model_name: 93 | input_shapes = [(batch_size, input_len), (batch_size, input_len)] 94 | 95 | class CompiledTextEncoder(CompiledModule): 96 | if external_weights: 97 | params = export_parameters( 98 | text_encoder_model, 99 | external=True, 100 | external_scope="", 101 | name_mapper=mapper.get, 102 | ) 103 | else: 104 | params = export_parameters(text_encoder_model) 105 | 106 | def encode_tokens( 107 | self, 108 | inp=AbstractTensor(1, input_len, dtype=torch.int64), 109 | decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64), 110 | ): 111 | return jittable(text_encoder_model.forward)( 112 | input_ids=inp, decoder_input_ids=decoder_input_ids 113 | ) 114 | 115 | else: 116 | input_shapes = [str((batch_size, input_len)), str((batch_size, input_len))] 117 | 118 | class CompiledTextEncoder(CompiledModule): 119 | if external_weights: 120 | params = export_parameters( 121 | text_encoder_model, 122 | external=True, 123 | external_scope="", 124 | name_mapper=mapper.get, 125 | ) 126 | else: 127 | params = export_parameters(text_encoder_model) 128 | 129 | def encode_tokens_attn_mask( 130 | self, 131 | inp=AbstractTensor(1, input_len, dtype=torch.int64), 132 | attn_mask=AbstractTensor(1, input_len, dtype=torch.int64), 133 | ): 134 | return jittable(text_encoder_model.forward)( 135 | input_ids=inp, attention_mask=attn_mask 136 | ) 137 | 138 | def encode_tokens( 139 | self, 140 | inp=AbstractTensor(1, input_len, dtype=torch.int64), 141 | ): 142 | return jittable(text_encoder_model.forward)(input_ids=inp) 143 | 144 | import_to = "INPUT" if compile_to == "linalg" else "IMPORT" 145 | inst = CompiledTextEncoder(context=Context(), import_to=import_to) 146 | module = CompiledModule.get_mlir_module(inst) 147 | 148 | model_metadata_attn_mask = { 149 | "model_name": hf_model_name + "_text_encoder", 150 | "input_shapes": input_shapes, 151 | "input_dtypes": ["int64", "int64"], 152 | "use_attention_mask": True, 153 | } 154 | model_metadata_encode = { 155 | "model_name": hf_model_name + "_text_encoder", 156 | "input_shapes": input_shapes[0], 157 | "input_dtypes": ["int64"], 158 | "use_attention_mask": False, 159 | } 160 | module = AddMetadataPass( 161 | module, model_metadata_attn_mask, "encode_tokens_attn_mask" 162 | ).run() 163 | module = AddMetadataPass(module, model_metadata_encode, "encode_tokens").run() 164 | 165 | module_str = str(module) 166 | if compile_to != "vmfb": 167 | return module_str 168 | else: 169 | vmfb_path = utils.compile_to_vmfb( 170 | module_str, 171 | device, 172 | target, 173 | ireec_flags, 174 | safe_name, 175 | return_path=not exit_on_vmfb, 176 | const_expr_hoisting=True, 177 | attn_spec=attn_spec, 178 | ) 179 | return vmfb_path 180 | 181 | 182 | if __name__ == "__main__": 183 | from turbine_models.custom_models.sd_inference.sd_cmd_opts import args 184 | 185 | mod_str, _ = export_clip_model( 186 | args.hf_model_name, 187 | args.max_length, 188 | args.precision, 189 | args.compile_to, 190 | args.external_weights, 191 | 
args.external_weight_path, 192 | args.device, 193 | args.iree_target_triple, 194 | args.ireec_flags + args.clip_flags, 195 | exit_on_vmfb=True, 196 | pipeline_dir=args.pipeline_dir, 197 | input_mlir=args.input_mlir, 198 | attn_spec=args.attn_spec, 199 | weights_only=False, 200 | upload_ir=False, 201 | ) 202 | if args.input_mlir: 203 | exit() 204 | safe_name = utils.create_safe_name( 205 | args.hf_model_name, f"{str(args.max_length)}_{args.precision}_clip" 206 | ) 207 | with open(f"{safe_name}.mlir", "w+") as f: 208 | f.write(mod_str) 209 | print("Saved to", safe_name + ".mlir") 210 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/clip_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from transformers import CLIPTokenizer 4 | from iree import runtime as ireert 5 | import torch 6 | 7 | 8 | def run_clip( 9 | device, prompt, vmfb_path, hf_model_name, hf_auth_token, external_weight_path 10 | ): 11 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 12 | 13 | if "google/t5" in hf_model_name: 14 | from transformers import T5Tokenizer, T5Model 15 | 16 | tokenizer = T5Tokenizer.from_pretrained(hf_model_name) 17 | text_input = tokenizer( 18 | prompt, 19 | padding="max_length", 20 | max_length=tokenizer.model_max_length, 21 | truncation=True, 22 | return_tensors="pt", 23 | ) 24 | # TODO: Integrate with HFTransformerBuilder 25 | else: 26 | if "openai" in hf_model_name: 27 | from transformers import CLIPProcessor 28 | import requests 29 | 30 | tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 31 | text_input = tokenizer( 32 | text=prompt, 33 | truncation=True, 34 | padding=True, 35 | return_tensors="pt", 36 | ) 37 | else: 38 | hf_subfolder = "tokenizer" 39 | 40 | tokenizer = CLIPTokenizer.from_pretrained( 41 | hf_model_name, 42 | subfolder=hf_subfolder, 43 | token=hf_auth_token, 44 | ) 45 | 46 | text_input = tokenizer( 47 | prompt, 48 | padding="max_length", 49 | max_length=tokenizer.model_max_length, 50 | truncation=True, 51 | return_tensors="pt", 52 | ) 53 | example_input = text_input.input_ids 54 | inp = [ireert.asdevicearray(runner.config.device, example_input)] 55 | 56 | if "google/t5" in hf_model_name: 57 | inp += [ireert.asdevicearray(runner.config.device, example_input)] 58 | results = runner.ctx.modules.compiled_text_encoder["encode_tokens"](*inp) 59 | return results 60 | 61 | 62 | def run_torch_clip(hf_model_name, hf_auth_token, prompt): 63 | if "google/t5" in hf_model_name: 64 | from transformers import T5Tokenizer, T5Model 65 | 66 | tokenizer = T5Tokenizer.from_pretrained(hf_model_name) 67 | model = T5Model.from_pretrained(hf_model_name) 68 | text_input = tokenizer( 69 | prompt, 70 | padding="max_length", 71 | max_length=tokenizer.model_max_length, 72 | truncation=True, 73 | return_tensors="pt", 74 | ) 75 | # TODO: Integrate with HFTransformerBuilder 76 | else: 77 | if hf_model_name == "openai/clip-vit-large-patch14": 78 | from transformers import CLIPProcessor 79 | 80 | tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 81 | hf_subfolder = "" # CLIPProcessor does not have a subfolder 82 | from transformers import CLIPTextModel 83 | 84 | model = CLIPTextModel.from_pretrained( 85 | hf_model_name, 86 | subfolder=hf_subfolder, 87 | token=hf_auth_token, 88 | ) 89 | text_input = tokenizer( 90 | text=prompt, 91 | truncation=True, 92 | 
padding=True, 93 | return_tensors="pt", 94 | ) 95 | else: 96 | hf_subfolder = "text_encoder" 97 | 98 | tokenizer = CLIPTokenizer.from_pretrained( 99 | hf_model_name, 100 | subfolder="tokenizer", 101 | token=hf_auth_token, 102 | ) 103 | 104 | from transformers import CLIPTextModel 105 | 106 | model = CLIPTextModel.from_pretrained( 107 | hf_model_name, 108 | subfolder=hf_subfolder, 109 | token=hf_auth_token, 110 | ) 111 | text_input = tokenizer( 112 | prompt, 113 | padding="max_length", 114 | max_length=tokenizer.model_max_length, 115 | truncation=True, 116 | return_tensors="pt", 117 | ) 118 | example_input = text_input.input_ids 119 | 120 | if "google/t5" in hf_model_name: 121 | results = model.forward(example_input, decoder_input_ids=example_input)[0] 122 | else: 123 | results = model.forward(example_input)[0] 124 | np_torch_output = results.detach().cpu().numpy() 125 | return np_torch_output 126 | 127 | 128 | if __name__ == "__main__": 129 | from turbine_models.custom_models.sd_inference.sd_cmd_opts import args 130 | 131 | turbine_output = run_clip( 132 | args.device, 133 | args.prompt, 134 | args.vmfb_path, 135 | args.hf_model_name, 136 | args.hf_auth_token, 137 | args.external_weight_path, 138 | ) 139 | print( 140 | "TURBINE OUTPUT:", 141 | turbine_output[0].to_host(), 142 | turbine_output[0].to_host().shape, 143 | turbine_output[0].to_host().dtype, 144 | ) 145 | if args.compare_vs_torch: 146 | print("generating torch output: ") 147 | from turbine_models.custom_models.sd_inference import utils 148 | 149 | torch_output = run_torch_clip( 150 | args.hf_model_name, args.hf_auth_token, args.prompt 151 | ) 152 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 153 | err = utils.largest_error(torch_output, turbine_output[0]) 154 | print("Largest Error: ", err) 155 | assert err < 9e-5 156 | # TODO: Figure out why we occasionally segfault without unlinking output variables 157 | turbine_output = None 158 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/schedulers_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Advanced Micro Devices, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | from turbine_models.model_runner import vmfbRunner 8 | from iree import runtime as ireert 9 | import torch 10 | from diffusers import ( 11 | UNet2DConditionModel, 12 | ) 13 | 14 | 15 | def run_scheduler( 16 | device, 17 | sample, 18 | encoder_hidden_states, 19 | vmfb_path, 20 | hf_model_name, 21 | hf_auth_token, 22 | external_weight_path, 23 | ): 24 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 25 | 26 | inputs = [ 27 | ireert.asdevicearray(runner.config.device, sample), 28 | ireert.asdevicearray(runner.config.device, encoder_hidden_states), 29 | ] 30 | results = runner.ctx.modules.compiled_scheduler["main"](*inputs) 31 | return results 32 | 33 | 34 | def run_sdxl_scheduler( 35 | device, 36 | sample, 37 | prompt_embeds, 38 | text_embeds, 39 | time_ids, 40 | vmfb_path, 41 | hf_model_name, 42 | hf_auth_token, 43 | external_weight_path, 44 | ): 45 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 46 | 47 | inputs = [ 48 | ireert.asdevicearray(runner.config.device, sample), 49 | ireert.asdevicearray(runner.config.device, prompt_embeds), 50 | ireert.asdevicearray(runner.config.device, text_embeds), 51 | ireert.asdevicearray(runner.config.device, time_ids), 52 | ] 53 | results = runner.ctx.modules.compiled_scheduler["main"](*inputs) 54 | return results 55 | 56 | 57 | def run_torch_scheduler( 58 | hf_model_name, 59 | scheduler, 60 | num_inference_steps, 61 | sample, 62 | prompt_embeds, 63 | text_embeds, 64 | time_ids, 65 | ): 66 | class SDXLScheduler(torch.nn.Module): 67 | def __init__( 68 | self, 69 | hf_model_name, 70 | num_inference_steps, 71 | scheduler, 72 | hf_auth_token=None, 73 | precision="fp32", 74 | ): 75 | super().__init__() 76 | self.scheduler = scheduler 77 | self.scheduler.set_timesteps(num_inference_steps) 78 | self.guidance_scale = 7.5 79 | if precision == "fp16": 80 | try: 81 | self.unet = UNet2DConditionModel.from_pretrained( 82 | hf_model_name, 83 | subfolder="unet", 84 | auth_token=hf_auth_token, 85 | low_cpu_mem_usage=False, 86 | variant="fp16", 87 | ) 88 | except: 89 | self.unet = UNet2DConditionModel.from_pretrained( 90 | hf_model_name, 91 | subfolder="unet", 92 | auth_token=hf_auth_token, 93 | low_cpu_mem_usage=False, 94 | ) 95 | else: 96 | self.unet = UNet2DConditionModel.from_pretrained( 97 | hf_model_name, 98 | subfolder="unet", 99 | auth_token=hf_auth_token, 100 | low_cpu_mem_usage=False, 101 | ) 102 | 103 | def forward(self, sample, prompt_embeds, text_embeds, time_ids): 104 | sample = sample * self.scheduler.init_noise_sigma 105 | for t in self.scheduler.timesteps: 106 | with torch.no_grad(): 107 | added_cond_kwargs = { 108 | "text_embeds": text_embeds, 109 | "time_ids": time_ids, 110 | } 111 | latent_model_input = torch.cat([sample] * 2) 112 | t = t.unsqueeze(0) 113 | # print('UNSQUEEZE T:', t) 114 | latent_model_input = self.scheduler.scale_model_input( 115 | latent_model_input, timestep=t 116 | ) 117 | noise_pred = self.unet.forward( 118 | latent_model_input, 119 | t, 120 | encoder_hidden_states=prompt_embeds, 121 | cross_attention_kwargs=None, 122 | added_cond_kwargs=added_cond_kwargs, 123 | return_dict=False, 124 | )[0] 125 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 126 | noise_pred = noise_pred_uncond + self.guidance_scale * ( 127 | noise_pred_text - noise_pred_uncond 128 | ) 129 | sample = self.scheduler.step( 130 | noise_pred, t, sample, return_dict=False 131 | )[0] 132 | return sample 133 | 134 | scheduler_module = SDXLScheduler( 135 | 
hf_model_name, 136 | num_inference_steps, 137 | scheduler, 138 | hf_auth_token=None, 139 | precision="fp16", 140 | ) 141 | results = scheduler_module.forward(sample, prompt_embeds, text_embeds, time_ids) 142 | np_torch_output = results.detach().cpu().numpy() 143 | return np_torch_output 144 | 145 | 146 | if __name__ == "__main__": 147 | from turbine_models.custom_models.sd_inference.sd_cmd_opts import args 148 | 149 | sample = torch.rand( 150 | args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 151 | ) 152 | if args.hf_model_name == "CompVis/stable-diffusion-v1-4": 153 | encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) 154 | elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": 155 | encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) 156 | 157 | sample = torch.rand(args.batch_size, 4, args.height // 8, args.width // 8) 158 | prompt_embeds = torch.rand(2, 77, 2048) 159 | text_embeds = torch.rand(2, 1280) 160 | time_ids = torch.rand(2, 6) 161 | turbine_output = run_sdxl_scheduler( 162 | args.device, 163 | sample, 164 | prompt_embeds, 165 | text_embeds, 166 | time_ids, 167 | args.vmfb_path, 168 | args.hf_model_name, 169 | args.hf_auth_token, 170 | args.external_weight_path, 171 | ) 172 | print( 173 | "TURBINE OUTPUT:", 174 | turbine_output.to_host(), 175 | turbine_output.to_host().shape, 176 | turbine_output.to_host().dtype, 177 | ) 178 | 179 | if args.compare_vs_torch: 180 | print("generating torch output: ") 181 | from turbine_models.custom_models.sd_inference import utils 182 | 183 | schedulers = utils.get_schedulers(args.hf_model_name) 184 | scheduler = schedulers[args.scheduler_id] 185 | torch_output = run_torch_scheduler( 186 | args.hf_model_name, 187 | scheduler, 188 | args.num_inference_steps, 189 | sample, 190 | prompt_embeds, 191 | text_embeds, 192 | time_ids, 193 | ) 194 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 195 | err = utils.largest_error(torch_output, turbine_output) 196 | print("Largest Error: ", err) 197 | assert err < 9e-3 198 | 199 | # TODO: Figure out why we occasionally segfault without unlinking output variables 200 | turbine_output = None 201 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/tokenization.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | from iree import runtime as ireert 3 | import re 4 | import torch 5 | import numpy as np 6 | import warnings 7 | 8 | 9 | # The following is copied from Diffusers' "encode_prompt" function in the StableDiffusion pipeline. 10 | # It has been lightly augmented to work with the SHARK-Turbine pipeline. 11 | def encode_prompt( 12 | pipe, 13 | prompt, 14 | negative_prompt=None, 15 | num_images_per_prompt=1, 16 | do_classifier_free_guidance=True, 17 | prompt_embeds: Optional[torch.Tensor] = None, 18 | negative_prompt_embeds: Optional[torch.Tensor] = None, 19 | lora_scale: Optional[float] = None, 20 | clip_skip: Optional[int] = None, 21 | ): 22 | r""" 23 | Encodes the prompt into text encoder hidden states. 
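    In this Turbine port, `pipe.text_encoder` is a compiled IREE module: it is
    called with an entry-point name ("encode_tokens" or "encode_tokens_attn_mask")
    plus a list of inputs, and its exported metadata decides whether an
    attention mask is passed.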
24 | 25 | Args: 26 | prompt (`str` or `List[str]`, *optional*): 27 | prompt to be encoded 28 | num_images_per_prompt (`int`): 29 | number of images that should be generated per prompt 30 | do_classifier_free_guidance (`bool`): 31 | whether to use classifier free guidance or not 32 | negative_prompt (`str` or `List[str]`, *optional*): 33 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 34 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 35 | less than `1`). 36 | prompt_embeds (`torch.Tensor`, *optional*): 37 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 38 | provided, text embeddings will be generated from `prompt` input argument. 39 | negative_prompt_embeds (`torch.Tensor`, *optional*): 40 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 41 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 42 | argument. 43 | lora_scale (`float`, *optional*): 44 | A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 45 | clip_skip (`int`, *optional*): 46 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 47 | the output of the pre-final layer will be used for computing the prompt embeddings. 48 | """ 49 | # set lora scale so that monkey patched LoRA 50 | # function of text encoder can correctly access it 51 | # if lora_scale is not None and pipe.use_lora: 52 | # pipe._lora_scale = lora_scale 53 | 54 | # # dynamically adjust the LoRA scale 55 | # if not USE_PEFT_BACKEND: 56 | # adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale) 57 | # else: 58 | # scale_lora_layers(pipe.text_encoder, lora_scale) 59 | 60 | if prompt is not None and isinstance(prompt, str): 61 | batch_size = 1 62 | elif prompt is not None and isinstance(prompt, list): 63 | batch_size = len(prompt) 64 | else: 65 | batch_size = prompt_embeds.shape[0] 66 | 67 | if prompt_embeds is None: 68 | # textual inversion: process multi-vector tokens if necessary 69 | # if pipe.use_textual_inversion: 70 | # prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer) 71 | 72 | text_inputs = pipe.tokenizer( 73 | prompt, 74 | padding="max_length", 75 | max_length=pipe.model_max_length, 76 | truncation=True, 77 | return_tensors="pt", 78 | ) 79 | text_input_ids = text_inputs.input_ids 80 | untruncated_ids = pipe.tokenizer( 81 | prompt, padding="longest", return_tensors="pt" 82 | ).input_ids 83 | 84 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 85 | text_input_ids, untruncated_ids 86 | ): 87 | removed_text = pipe.tokenizer.batch_decode( 88 | untruncated_ids[:, pipe.model_max_length - 1 : -1] 89 | ) 90 | warnings.warn( 91 | "The following text was removed due to truncation: " + removed_text 92 | ) 93 | if pipe.text_encoder.metadata.get("use_attention_mask"): 94 | attention_mask = text_inputs.attention_mask 95 | prompt_embeds = pipe.text_encoder( 96 | "encode_tokens_attn_mask", [text_input_ids, attention_mask] 97 | ) 98 | else: 99 | attention_mask = None 100 | prompt_embeds = pipe.text_encoder("encode_tokens", [text_input_ids]) 101 | prompt_embeds = prompt_embeds[0] 102 | bs_embed, seq_len, _ = prompt_embeds.shape 103 | # duplicate text embeddings for each generation per prompt, using mps friendly method 104 | prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt, 1) 105 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 106 | 107 | # get unconditional embeddings for classifier free guidance 108 | if do_classifier_free_guidance and negative_prompt_embeds is None: 109 | uncond_tokens: List[str] 110 | if negative_prompt is None: 111 | uncond_tokens = [""] * batch_size 112 | elif prompt is not None and type(prompt) is not type(negative_prompt): 113 | raise TypeError( 114 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 115 | f" {type(prompt)}." 116 | ) 117 | elif isinstance(negative_prompt, str): 118 | uncond_tokens = [negative_prompt] 119 | elif batch_size != len(negative_prompt): 120 | raise ValueError( 121 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 122 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 123 | " the batch size of `prompt`." 124 | ) 125 | else: 126 | uncond_tokens = negative_prompt 127 | 128 | # textual inversion: process multi-vector tokens if necessary 129 | # if pipe.use_textual_inversion: 130 | # uncond_tokens = pipe.maybe_convert_prompt(uncond_tokens, pipe.tokenizer) 131 | 132 | max_length = prompt_embeds.shape[1] 133 | uncond_input = pipe.tokenizer( 134 | uncond_tokens, 135 | padding="max_length", 136 | max_length=max_length, 137 | truncation=True, 138 | return_tensors="pt", 139 | ) 140 | 141 | if pipe.text_encoder.metadata.get("use_attention_mask"): 142 | attention_mask = uncond_input.attention_mask 143 | negative_prompt_embeds = pipe.text_encoder( 144 | "encode_tokens_attn_mask", 145 | [ 146 | uncond_input.input_ids, 147 | attention_mask, 148 | ], 149 | ) 150 | else: 151 | attention_mask = None 152 | negative_prompt_embeds = pipe.text_encoder( 153 | "encode_tokens", 154 | [ 155 | uncond_input.input_ids, 156 | ], 157 | ) 158 | 159 | negative_prompt_embeds = negative_prompt_embeds[0] 160 | 161 | if do_classifier_free_guidance: 162 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 163 | seq_len = negative_prompt_embeds.shape[1] 164 | 165 | negative_prompt_embeds = negative_prompt_embeds.repeat( 166 | 1, num_images_per_prompt, 1 167 | ) 168 | negative_prompt_embeds = negative_prompt_embeds.view( 169 | batch_size * num_images_per_prompt, seq_len, -1 170 | ) 171 | 172 | # if pipe.use_lora: 173 | # Retrieve the original scale by scaling back the LoRA layers 174 | # unimplemented 175 | # unscale_lora_layers(pipe.text_encoder, lora_scale) 176 | 177 | return prompt_embeds, negative_prompt_embeds 178 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import os 8 | import sys 9 | import copy 10 | 11 | from iree import runtime as ireert 12 | from iree.compiler.ir import Context 13 | import numpy as np 14 | from iree.turbine.aot import * 15 | from iree.turbine.dynamo.passes import ( 16 | DEFAULT_DECOMPOSITIONS, 17 | ) 18 | from iree.turbine.transforms.general.add_metadata import AddMetadataPass 19 | from turbine_models.custom_models.sd_inference import utils 20 | import torch 21 | import torch._dynamo as dynamo 22 | from diffusers import UNet2DConditionModel 23 | 24 | import safetensors 25 | import argparse 26 | from turbine_models.turbine_tank import turbine_tank 27 | 28 | 29 | class UnetModel(torch.nn.Module): 30 | def __init__(self, hf_model_name): 31 | super().__init__() 32 | self.do_classifier_free_guidance = True 33 | self.unet = UNet2DConditionModel.from_pretrained( 34 | hf_model_name, 35 | subfolder="unet", 36 | ) 37 | 38 | def forward( 39 | self, latent_model_input, timestep, encoder_hidden_states, guidance_scale 40 | ): 41 | noise_pred = self.unet.forward( 42 | latent_model_input, timestep, encoder_hidden_states, return_dict=False 43 | )[0] 44 | if self.do_classifier_free_guidance: 45 | noise_preds = noise_pred.chunk(2) 46 | noise_pred = noise_preds[0] + guidance_scale * ( 47 | noise_preds[1] - noise_preds[0] 48 | ) 49 | return noise_pred 50 | 51 | 52 | def export_unet_model( 53 | hf_model_name, 54 | batch_size, 55 | height, 56 | width, 57 | precision="fp32", 58 | max_length=77, 59 | compile_to="torch", 60 | external_weights=None, 61 | external_weight_path=None, 62 | device=None, 63 | target=None, 64 | ireec_flags=None, 65 | decomp_attn=False, 66 | exit_on_vmfb=False, 67 | pipeline_dir=None, 68 | attn_spec=None, 69 | input_mlir=None, 70 | weights_only=False, 71 | upload_ir=False, 72 | ): 73 | if input_mlir: 74 | unet_model = None 75 | else: 76 | unet_model = UnetModel( 77 | hf_model_name, 78 | ) 79 | dtype = torch.float16 if precision == "fp16" else torch.float32 80 | np_dtype = "float16" if precision == "fp16" else "float32" 81 | safe_name = utils.create_safe_name( 82 | hf_model_name, 83 | f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_unet", 84 | ) 85 | if decomp_attn: 86 | safe_name += "_decomp_attn" 87 | if pipeline_dir: 88 | safe_name = os.path.join(pipeline_dir, safe_name) 89 | 90 | if input_mlir: 91 | vmfb_path = utils.compile_to_vmfb( 92 | input_mlir, 93 | device, 94 | target, 95 | ireec_flags, 96 | safe_name, 97 | mlir_source="file", 98 | return_path=not exit_on_vmfb, 99 | attn_spec=attn_spec, 100 | ) 101 | return vmfb_path 102 | 103 | mapper = {} 104 | 105 | if precision == "fp16": 106 | unet_model = unet_model.half() 107 | 108 | utils.save_external_weights( 109 | mapper, unet_model, external_weights, external_weight_path 110 | ) 111 | 112 | if weights_only: 113 | return external_weight_path 114 | 115 | sample = ( 116 | batch_size * 2, 117 | unet_model.unet.config.in_channels, 118 | height // 8, 119 | width // 8, 120 | ) 121 | encoder_hidden_states_sizes = ( 122 | unet_model.unet.config.layers_per_block, 123 | max_length, 124 | unet_model.unet.config.cross_attention_dim, 125 | ) 126 | example_forward_args = [ 127 | torch.empty(sample, dtype=dtype), 128 | torch.empty(1, dtype=dtype), 129 | torch.empty(encoder_hidden_states_sizes, dtype=dtype), 130 | torch.empty(1, dtype=dtype), 131 | ] 132 | decomp_list = [] 133 | if decomp_attn: 134 | decomp_list = [ 135 | torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, 136 | 
torch.ops.aten._scaled_dot_product_flash_attention.default, 137 | torch.ops.aten.scaled_dot_product_attention, 138 | ] 139 | with decompositions.extend_aot_decompositions( 140 | from_current=True, 141 | add_ops=decomp_list, 142 | ): 143 | fxb = FxProgramsBuilder(unet_model) 144 | 145 | @fxb.export_program( 146 | args=(example_forward_args,), 147 | ) 148 | def _forward( 149 | module, 150 | inputs, 151 | ): 152 | return module.forward(*inputs) 153 | 154 | class CompiledUnet(CompiledModule): 155 | run_forward = _forward 156 | 157 | if external_weights: 158 | externalize_module_parameters(unet_model) 159 | 160 | inst = CompiledUnet(context=Context(), import_to="IMPORT") 161 | 162 | module = CompiledModule.get_mlir_module(inst) 163 | 164 | model_metadata_run_forward = { 165 | "model_name": "sd_unet", 166 | "input_shapes": [ 167 | sample, 168 | (1,), 169 | encoder_hidden_states_sizes, 170 | (1,), 171 | ], 172 | "input_dtypes": [np_dtype for x in range(4)], 173 | "output_shapes": [sample], 174 | "output_dtypes": [np_dtype], 175 | } 176 | 177 | module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() 178 | module_str = str(module) 179 | if compile_to != "vmfb": 180 | return module_str 181 | else: 182 | vmfb_path = utils.compile_to_vmfb( 183 | module_str, 184 | device, 185 | target, 186 | ireec_flags, 187 | safe_name, 188 | return_path=True, 189 | attn_spec=attn_spec, 190 | ) 191 | if exit_on_vmfb: 192 | exit() 193 | return vmfb_path 194 | 195 | 196 | if __name__ == "__main__": 197 | from turbine_models.custom_models.sd_inference.sd_cmd_opts import args 198 | 199 | mod_str = export_unet_model( 200 | args.hf_model_name, 201 | args.batch_size, 202 | args.height, 203 | args.width, 204 | args.precision, 205 | args.max_length, 206 | args.compile_to, 207 | args.external_weights, 208 | args.external_weight_path, 209 | args.device, 210 | args.iree_target_triple, 211 | args.ireec_flags + args.attn_flags + args.unet_flags, 212 | args.decomp_attn, 213 | attn_spec=args.attn_spec, 214 | input_mlir=args.input_mlir, 215 | ) 216 | if args.input_mlir: 217 | exit() 218 | safe_name = utils.create_safe_name( 219 | args.hf_model_name, 220 | f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet", 221 | ) 222 | with open(f"{safe_name}.mlir", "w+") as f: 223 | f.write(mod_str) 224 | print("Saved to", safe_name + ".mlir") 225 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/unet_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from transformers import CLIPTokenizer 4 | from iree import runtime as ireert 5 | import torch 6 | 7 | 8 | def run_unet( 9 | device, 10 | sample, 11 | timestep, 12 | encoder_hidden_states, 13 | guidance_scale, 14 | vmfb_path, 15 | hf_model_name, 16 | hf_auth_token, 17 | external_weight_path, 18 | iree_dtype, 19 | ): 20 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 21 | inputs = [ 22 | ireert.asdevicearray(runner.config.device, sample, dtype=iree_dtype), 23 | ireert.asdevicearray(runner.config.device, timestep, dtype=iree_dtype), 24 | ireert.asdevicearray( 25 | runner.config.device, encoder_hidden_states, dtype=iree_dtype 26 | ), 27 | ireert.asdevicearray(runner.config.device, guidance_scale, dtype=iree_dtype), 28 | ] 29 | results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) 30 | return results 31 | 32 | 33 | 
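# Illustrative call into run_unet above (a sketch only — the device string,
# vmfb/weight paths, and tensor sizes are placeholders for a 512x512
# CompVis/stable-diffusion-v1-4 export at batch size 1):
#
#   sample = torch.rand(2, 4, 64, 64, dtype=torch.float32)
#   timestep = torch.zeros(1, dtype=torch.float32)
#   embeds = torch.rand(2, 77, 768, dtype=torch.float32)
#   guidance = torch.tensor([7.5], dtype=torch.float32)
#   out = run_unet("local-task", sample, timestep, embeds, guidance,
#                  "sd_unet.vmfb", "CompVis/stable-diffusion-v1-4", None,
#                  "sd_unet_weights.safetensors", "float32")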
def run_torch_unet( 34 | hf_model_name, 35 | hf_auth_token, 36 | sample, 37 | timestep, 38 | encoder_hidden_states, 39 | guidance_scale, 40 | ): 41 | from turbine_models.custom_models.sd_inference.unet import UnetModel 42 | 43 | unet_model = UnetModel( 44 | hf_model_name, 45 | ) 46 | results = unet_model.forward( 47 | sample, timestep, encoder_hidden_states, guidance_scale 48 | ) 49 | np_torch_output = results.detach().cpu().numpy() 50 | return np_torch_output 51 | 52 | 53 | if __name__ == "__main__": 54 | args = parser.parse_args() 55 | iree_dtypes = { 56 | "fp16": "float16", 57 | "fp32": "float32", 58 | } 59 | sample = torch.rand( 60 | args.batch_size * 2, 4, args.height // 8, args.width // 8, dtype=torch.float32 61 | ) 62 | timestep = torch.zeros(1, dtype=torch.float32) 63 | guidance_scale = torch.Tensor([7.5], dtype=torch.float32) 64 | if args.hf_model_name == "CompVis/stable-diffusion-v1-4": 65 | encoder_hidden_states = torch.rand(2, args.max_length, 768, dtype=torch.float32) 66 | elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": 67 | encoder_hidden_states = torch.rand( 68 | 2, args.max_length, 1024, dtype=torch.float32 69 | ) 70 | 71 | turbine_output = run_unet( 72 | args.device, 73 | sample, 74 | timestep, 75 | encoder_hidden_states, 76 | guidance_scale, 77 | args.vmfb_path, 78 | args.hf_model_name, 79 | args.hf_auth_token, 80 | args.external_weight_path, 81 | iree_dtypes[args.precision], 82 | ) 83 | print( 84 | "TURBINE OUTPUT:", 85 | turbine_output.to_host(), 86 | turbine_output.to_host().shape, 87 | turbine_output.to_host().dtype, 88 | ) 89 | 90 | if args.compare_vs_torch: 91 | print("generating torch output: ") 92 | from turbine_models.custom_models.sd_inference import utils 93 | from turbine_models.custom_models.sd_inference.sd_cmd_opts import args 94 | 95 | torch_output = run_torch_unet( 96 | args.hf_model_name, 97 | args.hf_auth_token, 98 | sample, 99 | timestep, 100 | encoder_hidden_states, 101 | guidance_scale, 102 | ) 103 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 104 | err = utils.largest_error(torch_output, turbine_output) 105 | print("Largest Error: ", err) 106 | assert err < 9e-5 107 | 108 | # TODO: Figure out why we occasionally segfault without unlinking output variables 109 | turbine_output = None 110 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sd_inference/vae_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from transformers import CLIPTokenizer 4 | from iree import runtime as ireert 5 | import torch 6 | 7 | 8 | def run_vae_decode( 9 | device, example_input, vmfb_path, hf_model_name, external_weight_path 10 | ): 11 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 12 | 13 | inputs = [ireert.asdevicearray(runner.config.device, example_input)] 14 | 15 | results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host() 16 | 17 | return results 18 | 19 | 20 | def run_torch_vae_decode(hf_model_name, variant, example_input): 21 | from diffusers import AutoencoderKL 22 | 23 | class VaeModel(torch.nn.Module): 24 | def __init__( 25 | self, 26 | hf_model_name, 27 | base_vae=False, 28 | custom_vae="", 29 | low_cpu_mem_usage=False, 30 | hf_auth_token="", 31 | ): 32 | super().__init__() 33 | self.vae = None 34 | if custom_vae == "": 35 | self.vae = AutoencoderKL.from_pretrained( 36 | hf_model_name, 37 | 
subfolder="vae", 38 | low_cpu_mem_usage=low_cpu_mem_usage, 39 | hf_auth_token=hf_auth_token, 40 | ) 41 | elif not isinstance(custom_vae, dict): 42 | self.vae = AutoencoderKL.from_pretrained( 43 | custom_vae, 44 | subfolder="vae", 45 | low_cpu_mem_usage=low_cpu_mem_usage, 46 | hf_auth_token=hf_auth_token, 47 | ) 48 | else: 49 | self.vae = AutoencoderKL.from_pretrained( 50 | hf_model_name, 51 | subfolder="vae", 52 | low_cpu_mem_usage=low_cpu_mem_usage, 53 | hf_auth_token=hf_auth_token, 54 | ) 55 | self.vae.load_state_dict(custom_vae) 56 | self.base_vae = base_vae 57 | 58 | def decode_inp(self, input): 59 | with torch.no_grad(): 60 | input = 1 / 0.18215 * input 61 | x = self.vae.decode(input, return_dict=False)[0] 62 | return (x / 2 + 0.5).clamp(0, 1) 63 | 64 | def encode_inp(self, inp): 65 | latents = self.vae.encode(inp).latent_dist.sample() 66 | return 0.18215 * latents 67 | 68 | vae_model = VaeModel( 69 | hf_model_name, 70 | ) 71 | 72 | if variant == "decode": 73 | results = vae_model.decode_inp(example_input) 74 | elif variant == "encode": 75 | results = vae_model.encode_inp(example_input) 76 | np_torch_output = results.detach().cpu().numpy() 77 | return np_torch_output 78 | 79 | 80 | if __name__ == "__main__": 81 | from turbine_models.custom_models.sd_inference.sd_cmd_opts import args 82 | 83 | if args.variant == "decode": 84 | example_input = torch.rand( 85 | args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 86 | ) 87 | elif args.variant == "encode": 88 | example_input = torch.rand( 89 | args.batch_size, 3, args.height, args.width, dtype=torch.float32 90 | ) 91 | print("generating turbine output:") 92 | turbine_results = run_vae_decode( 93 | args.device, 94 | example_input, 95 | args.vmfb_path, 96 | args.hf_model_name, 97 | args.external_weight_path, 98 | ) 99 | print( 100 | "TURBINE OUTPUT:", 101 | turbine_results.to_host(), 102 | turbine_results.to_host().shape, 103 | turbine_results.to_host().dtype, 104 | ) 105 | if args.compare_vs_torch: 106 | print("generating torch output: ") 107 | from turbine_models.custom_models.sd_inference import utils 108 | 109 | torch_output = run_torch_vae_decode( 110 | args.hf_model_name, args.hf_auth_token, args.variant, example_input 111 | ) 112 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 113 | err = utils.largest_error(torch_output, turbine_results) 114 | print("Largest Error: ", err) 115 | assert err < 3e-3 116 | 117 | # TODO: Figure out why we occasionally segfault without unlinking output variables 118 | turbine_results = None 119 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/COMMANDS.md: -------------------------------------------------------------------------------- 1 | 2 | # SHARK-Turbine SDXL CLI usage (ROCM) 3 | 4 | ## Pipeline (txt2img): 5 | 6 | Note: These commands are generally for unix, and use `$WEIGHTS_DIR`, `$PIPELINE_DIR`, and `$TARGET_TRIPLE` in place of actual values. You can set these env variables or replace them in the commands as desired. 
7 | 8 | ```shell 9 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py --precision=fp16 --external_weights=irpa --device=rocm --rt_device=rocm --iree_target_triple=$TARGET_TRIPLE --scheduler_id=PNDM --num_inference_steps=30 --pipeline_dir=$PIPELINE_DIR --external_weights_dir=$WEIGHTS_DIR --attn_spec=default --compiled_pipeline 10 | 11 | iree-benchmark-module \ 12 | --module=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder_rocm.vmfb \ 13 | --parameters=model=$WEIGHTS_DIR/prompt_encoder.irpa \ 14 | --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb \ 15 | --parameters=model=$WEIGHTS_DIR/unet.irpa \ 16 | --module=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode_rocm.vmfb \ 17 | --parameters=model=$WEIGHTS_DIR/vae_decode.irpa \ 18 | --module=$PWD/sdxl_pipeline_fp16_$TARGET_TRIPLE.vmfb \ 19 | --function=tokens_to_image \ 20 | --input=1x4x128x128xf16 \ 21 | --input=1xf16 \ 22 | --input=1x64xi64 \ 23 | --input=1x64xi64 \ 24 | --input=1x64xi64 \ 25 | --input=1x64xi64 \ 26 | --device_allocator=caching \ 27 | --benchmark_repetitions=1 \ 28 | --device=rocm 29 | ``` 30 | Note: you can either manually compile the pipeline vmfb from the .mlir in sdxl_inference, or by running the sdxl_scheduled_unet.py script. 31 | The sdxl_compiled_pipeline script will do this for you, and you can switch between the segmented pipeline and the 'tokens->image' one-shot pipeline using `--compiled_pipeline` (if present, script will run the latter.) 32 | 33 | ## Scheduled UNet 34 | 35 | ``` 36 | # Import to MLIR: 37 | 38 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=mlir --external_weight_path=$WEIGHTS_DIR/unet.safetensors 39 | 40 | # Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): 41 | 42 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_PNDM_64_1024x1024_fp16_unet_30.mlir 43 | 44 | # Test numerics (validate against pytorch cpu): 45 | 46 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py --compare_vs_torch --precision=fp16 --device=rocm --external_weight_path=$WEIGHTS_DIR/scheduled_unet.irpa --max_length=64 --pipeline_vmfb_path=./sdxl_pipeline_fp16_$TARGET_TRIPLE.vmfb --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb 47 | 48 | # Benchmark with IREE CLI: 49 | 50 | iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb --parameters=model=$WEIGHTS_DIR/scheduled_unet.irpa --function=run_forward --input=1x4x128x128xf16 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=1xf16 --input=1xi64 --device_allocator=caching 51 | 52 | iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb --module=$PWD/sdxl_pipeline_fp16_$TARGET_TRIPLE.vmfb --parameters=model=$WEIGHTS_DIR/scheduled_unet.irpa --function=run_forward --input=1x4x128x128xf16 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 
--input=1xf16 --input=1xi64 --device_allocator=caching 53 | ``` 54 | 55 | ## UNet 56 | 57 | ``` 58 | # Import to MLIR: 59 | 60 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=mlir 61 | 62 | # Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): 63 | 64 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir 65 | 66 | # Convert weights to IREE parameter archive format: 67 | 68 | iree-convert-parameters --parameters=$WEIGHTS_DIR/unet.safetensors --output=$WEIGHTS_DIR/scheduled_unet.irpa 69 | 70 | # Test numerics (validate against pytorch cpu): 71 | 72 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet_runner.py --compare_vs_torch --precision=fp16 --device=rocm --external_weight_path=$WEIGHTS_DIR/scheduled_unet.irpa --max_length=64 --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_rocm.vmfb 73 | 74 | # Benchmark with IREE CLI: 75 | 76 | iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_rocm.vmfb --parameters=model=$WEIGHTS_DIR/scheduled_unet.irpa --function=main --input=1x4x128x128xf16 --input=1xi64 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=1xf16 --device_allocator=caching 77 | ``` 78 | 79 | ## CLIP 80 | 81 | ``` 82 | # Import to MLIR: 83 | 84 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --compile_to=mlir --iree_target_triple=$TARGET_TRIPLE --external_weight_path=$WEIGHTS_DIR/prompt_encoder.safetensors 85 | 86 | # Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): 87 | 88 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder.mlir 89 | 90 | # Convert weights to IREE parameter archive format: 91 | 92 | iree-convert-parameters --parameters=$WEIGHTS_DIR/prompt_encoder.safetensors --output=$WEIGHTS_DIR/prompt_encoder.irpa 93 | 94 | # Test numerics (validate against pytorch cpu): 95 | 96 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py --compare_vs_torch --precision=fp16 --device=rocm --external_weight_path=$WEIGHTS_DIR/prompt_encoder.irpa --max_length=64 --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder_rocm.vmfb 97 | 98 | # Benchmark with IREE CLI: 99 | 100 | iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder_rocm.vmfb --parameters=model=$WEIGHTS_DIR/prompt_encoder.irpa --function=encode_prompts --input=1x64xi64 --input=1x64xi64 --input=1x64xi64 --input=1x64xi64 --device_allocator=caching 101 | ``` 102 | 103 | 104 | ## VAE 105 | 106 | ``` 107 | # Import to MLIR: 108 | 109 | python 
/home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/vae.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=mlir --iree_target_triple=$TARGET_TRIPLE --external_weight_path=$WEIGHTS_DIR/vae_decode.safetensors
110 | 
111 | # Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once):
112 | 
113 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/vae.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode.mlir
114 | 
115 | # Convert weights to IREE parameter archive format:
116 | 
117 | iree-convert-parameters --parameters=$WEIGHTS_DIR/vae_decode.safetensors --output=$WEIGHTS_DIR/vae_decode.irpa
118 | 
119 | # Test numerics (validate against pytorch cpu):
120 | 
121 | python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/vae_runner.py --precision=fp16 --external_weights=irpa --device=rocm --iree_target_triple=$TARGET_TRIPLE --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode_rocm.vmfb --external_weight_path=$WEIGHTS_DIR/vae_decode.irpa --compare_vs_torch
122 | 
123 | # Benchmark with IREE CLI:
124 | 
125 | iree-benchmark-module --benchmark_repetitions=5 --module=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode_rocm.vmfb --parameters=model=$WEIGHTS_DIR/vae_decode.irpa --device=rocm --input=1x4x128x128xf16 --device_allocator=caching --function=main
126 | ```
127 | 
-------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/README.md: --------------------------------------------------------------------------------
1 | # Stable Diffusion XL with SHARK-Turbine
2 | 
3 | ## Support
4 | 
5 | The following table shows the current status of Turbine SDXL inference support for a few AMDGPU targets. This is not an exhaustive list of supported targets.
6 | 
7 | | Target Chip | Attention Decomposed? | CLIP | UNet | VAE Decode | Txt2Img |
8 | |-------------|-----------------------|---------------|--------------------------------|--------------------------------|----------------|
9 | | gfx1100 | Yes | 💚 | 💛 (numerics with vector distribution)| 💚 | 💚 |
10 | | | No | | 💔 (Attn lowering) | 💔 (Attn lowering) | 💔 |
11 | | gfx90a | Yes | 💚 | 💚 | 💚 | 💚 |
12 | | | No | | 💛 (Numerics with mfma) | 💚 | 💛 |
13 | | gfx942 | Yes | 💚 | 💚 | 💚 | 💚 |
14 | | | No | | 💚 | 💚 | 💚 |
15 | 
16 | ## Setup SHARK-Turbine for importing and running the SDXL pipeline or submodels
17 | 
18 | Linux:
19 | ```shell
20 | python -m venv turbine_venv
21 | source turbine_venv/bin/activate
22 | python -m pip install --upgrade pip
23 | cd ..
24 | git clone https://github.com/iree-org/iree-turbine
25 | cd iree-turbine
26 | pip install -r pytorch-cpu-requirements.txt
27 | pip install -e .
28 | cd ../SHARK-Turbine
29 | pip install --pre --upgrade -e models -r models/requirements.txt
30 | ```
31 | 
32 | Windows:
33 | ```shell
34 | python -m venv turbine_venv
35 | turbine_venv/Scripts/activate
36 | python -m pip install --upgrade pip
37 | cd ..
38 | git clone https://github.com/iree-org/iree-turbine
39 | cd iree-turbine
40 | pip install -r pytorch-cpu-requirements.txt
41 | pip install -e .
42 | cd ../SHARK-Turbine
43 | pip install --pre --upgrade -e models -r models/requirements.txt
44 | ```
45 | 
46 | ## Run tests
47 | ROCM:
48 | ```
49 | pytest models/turbine_models/tests/sdxl_test.py --device=rocm --rt_device= --iree_target_triple=gfx --external_weights=safetensors
50 | ```
51 | 
52 | CPU:
53 | ```
54 | pytest models/turbine_models/tests/sdxl_test.py --device=cpu --rt_device=local-task --iree_target_triple=x86_64-linux-gnu --external_weights=safetensors --precision=fp32
55 | ```
56 | 
57 | ## Run image generation pipeline
58 | 
59 | ROCM:
60 | ```
61 | python models\turbine_models\custom_models\sdxl_inference\sdxl_compiled_pipeline.py --iree_target_triple=gfx1100 --device=rocm --rt_device=hip --external_weights=safetensors
62 | ```
63 | For mfma-capable hardware, use `--attn_spec=default` to lower attention ops to MFMA instructions.
64 | 
65 | CPU:
66 | ```
67 | pytest models/turbine_models/tests/sdxl_test.py --device=cpu --rt_device=local-task --iree_target_triple=x86_64-linux-gnu --external_weights=safetensors --precision=fp32
68 | ```
69 | 
70 | ## Shared CLI options
71 | - `--iree_target_triple`: use gfx1100 for 7900xt, gfx90a for MI210/MI250, gfx940 for MI300A, gfx942 for MI300X. For CPU, use x86_64-linux-gnu if you aren't sure. On Vulkan, this is something like `rdna3-7900-windows`.
72 | - `--rt_device`: if using pip install, `hip` will work correctly, but `rocm` will not. Source builds of IREE can support `rocm` with the CMake options `-DIREE_HAL_DRIVER_ROCM=ON -DIREE_EXTERNAL_HAL_DRIVERS="rocm"`, but that option is soon to be deprecated in favor of the HIP driver.
73 | - `--compiled_pipeline`: run one-shot SDXL in an MLIR wrapper, removing model glue from the Python execution layer.
74 | - `--pipeline_dir`: directory in which to save or look for .vmfb files.
75 | - `--external_weights_dir`: directory in which to save or look for weights.
76 | - `--ireec_flags`: extra ireec flags to use for _all_ submodels.
77 | - `--unet_flags / --vae_flags / --clip_flags`: extra ireec flags for individual submodels.
78 | - `--precision`: fp16 or fp32. The default is fp16; use fp32 only for CPU.
79 | - `--num_inference_steps`: (default 30) number of unet iterations to run.
80 | - `--batch_count`: Not compatible with `--compiled_pipeline`. Uses the same clip output to generate a set of images in a batch, with different image latents.
81 | - `--prompt / --negative_prompt`: prompts for stable diffusion inference.
82 | 
83 | 
84 | Note: the following "prompt_encoder_f16.irpa" contains weights for both clip1 and clip2.
85 | The pipeline script will look for these filenames in the specified "external_weights_dir" under "prompt_encoder.irpa", "vae_decode.irpa", "scheduled_unet.irpa".
86 | This is not ideal in its current state, but it will be smoothed out now that the general pipeline structure and file management needs are stable.
87 | - [prompt_encoder_f16.irpa](https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/prompt_encoder_fp16.irpa) 88 | - [scheduled_unet_f16.irpa](https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/scheduled_unet_f16.irpa) 89 | - [vae_decode_f16.irpa](https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/vae_encode_fp16.irpa) 90 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import os 8 | import sys 9 | 10 | from iree import runtime as ireert 11 | import iree.compiler as ireec 12 | from iree.compiler.ir import Context 13 | import numpy as np 14 | from iree.turbine.aot import * 15 | from turbine_models.custom_models.sd_inference import utils 16 | import torch 17 | from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer 18 | 19 | 20 | class ClipModel(torch.nn.Module): 21 | def __init__(self, hf_model_name, hf_auth_token=None, index=1): 22 | super().__init__() 23 | if index == 1: 24 | self.text_encoder_model = CLIPTextModel.from_pretrained( 25 | hf_model_name, 26 | subfolder="text_encoder", 27 | token=hf_auth_token, 28 | ) 29 | if index == 2: 30 | self.text_encoder_model = CLIPTextModelWithProjection.from_pretrained( 31 | hf_model_name, 32 | subfolder="text_encoder_2", 33 | token=hf_auth_token, 34 | ) 35 | 36 | def forward(self, input): 37 | with torch.no_grad(): 38 | prompt_embeds = self.text_encoder_model( 39 | input, 40 | output_hidden_states=True, 41 | ) 42 | # We are only ALWAYS interested in the pooled output of the final text encoder 43 | pooled_prompt_embeds = prompt_embeds[0] 44 | prompt_embeds = prompt_embeds.hidden_states[-2] 45 | return prompt_embeds, pooled_prompt_embeds 46 | 47 | 48 | def export_clip_model( 49 | hf_model_name, 50 | hf_auth_token=None, 51 | max_length=77, 52 | precision="fp16", 53 | compile_to="torch", 54 | external_weights=None, 55 | external_weight_path=None, 56 | device=None, 57 | target_triple=None, 58 | ireec_flags=None, 59 | index=1, 60 | exit_on_vmfb=True, 61 | pipeline_dir=None, 62 | input_mlir=None, 63 | attn_spec=None, 64 | weights_only=False, 65 | ): 66 | if pipeline_dir not in [None, ""]: 67 | safe_name = os.path.join(pipeline_dir, "clip_" + str(index)) 68 | else: 69 | safe_name = utils.create_safe_name( 70 | hf_model_name, f"_{str(max_length)}-{precision}-clip-{index}-{device}" 71 | ) 72 | if input_mlir: 73 | vmfb_path = utils.compile_to_vmfb( 74 | input_mlir, 75 | device, 76 | target_triple, 77 | ireec_flags, 78 | safe_name, 79 | mlir_source="file", 80 | return_path=not exit_on_vmfb, 81 | const_expr_hoisting=True, 82 | attn_spec=attn_spec, 83 | ) 84 | return vmfb_path 85 | # Load the tokenizer and text encoder to tokenize and encode the text. 
86 | if index == 1: 87 | tokenizer = CLIPTokenizer.from_pretrained( 88 | hf_model_name, 89 | subfolder="tokenizer", 90 | token=hf_auth_token, 91 | model_max_length=max_length, 92 | ) 93 | elif index == 2: 94 | tokenizer = CLIPTokenizer.from_pretrained( 95 | hf_model_name, 96 | subfolder="tokenizer_2", 97 | token=hf_auth_token, 98 | model_max_length=max_length, 99 | ) 100 | text_encoder_model = ClipModel(hf_model_name, hf_auth_token, index=index) 101 | if compile_to == "tokenizer_only": 102 | return None, tokenizer 103 | if precision == "fp16": 104 | text_encoder_model = text_encoder_model.half() 105 | mapper = {} 106 | if external_weight_path: 107 | weights_path = ( 108 | external_weight_path.split(f".{external_weights}")[0] 109 | + f"_{index}" 110 | + f".{external_weights}" 111 | ) 112 | else: 113 | weights_path = None 114 | 115 | utils.save_external_weights( 116 | mapper, text_encoder_model, external_weights, weights_path 117 | ) 118 | 119 | if weights_only: 120 | return weights_path 121 | 122 | class CompiledClip(CompiledModule): 123 | if external_weights: 124 | params = export_parameters( 125 | text_encoder_model, 126 | external=True, 127 | external_scope="", 128 | name_mapper=mapper.get, 129 | ) 130 | else: 131 | params = export_parameters(text_encoder_model) 132 | 133 | def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)): 134 | return jittable(text_encoder_model.forward)(inp) 135 | 136 | import_to = "INPUT" if compile_to == "linalg" else "IMPORT" 137 | inst = CompiledClip(context=Context(), import_to=import_to) 138 | 139 | module_str = str(CompiledModule.get_mlir_module(inst)) 140 | 141 | if compile_to != "vmfb": 142 | return module_str, tokenizer 143 | else: 144 | vmfb_path = utils.compile_to_vmfb( 145 | module_str, 146 | device, 147 | target_triple, 148 | ireec_flags, 149 | safe_name, 150 | return_path=not exit_on_vmfb, 151 | const_expr_hoisting=True, 152 | attn_spec=attn_spec, 153 | ) 154 | return None, vmfb_path 155 | 156 | 157 | if __name__ == "__main__": 158 | from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args 159 | 160 | mod_1_str, _ = export_clip_model( 161 | args.hf_model_name, 162 | args.hf_auth_token, 163 | args.max_length, 164 | args.precision, 165 | args.compile_to, 166 | args.external_weights, 167 | args.external_weight_path, 168 | args.device, 169 | args.iree_target_triple, 170 | args.ireec_flags + args.clip_flags, 171 | 1, 172 | exit_on_vmfb=False, 173 | pipeline_dir=args.pipeline_dir, 174 | input_mlir=args.input_mlir, 175 | attn_spec=args.attn_spec, 176 | ) 177 | mod_2_str, _ = export_clip_model( 178 | args.hf_model_name, 179 | args.hf_auth_token, 180 | args.max_length, 181 | args.precision, 182 | args.compile_to, 183 | args.external_weights, 184 | args.external_weight_path, 185 | args.device, 186 | args.iree_target_triple, 187 | args.ireec_flags + args.clip_flags, 188 | 2, 189 | exit_on_vmfb=True, 190 | pipeline_dir=args.pipeline_dir, 191 | input_mlir=args.input_mlir, 192 | attn_spec=args.attn_spec, 193 | ) 194 | if args.input_mlir: 195 | exit() 196 | safe_name_1 = safe_name = utils.create_safe_name( 197 | args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_clip_1" 198 | ) 199 | safe_name_2 = safe_name = utils.create_safe_name( 200 | args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_clip_2" 201 | ) 202 | with open(f"{safe_name_1}.mlir", "w+") as f: 203 | f.write(mod_1_str) 204 | print("Saved to", safe_name_1 + ".mlir") 205 | with open(f"{safe_name_2}.mlir", "w+") as f: 206 | f.write(mod_2_str) 207 | 
print("Saved to", safe_name_2 + ".mlir") 208 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/sdxl_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | 8 | import numpy as np 9 | import torch 10 | import sys 11 | 12 | from iree import runtime as ireert 13 | from turbine_models.utils.benchmark import benchmark_module 14 | 15 | 16 | def run_benchmark(args): 17 | config = ireert.Config(args.rt_device) 18 | 19 | if args.external_weight_file: 20 | index = ireert.ParameterIndex() 21 | index.load(args.external_weight_file) 22 | 23 | if not args.benchmark_vmfb_path: 24 | sys.exit("no --benchmark_vmfb_path provided, required for run_benchmark") 25 | benchmark_mod = ireert.VmModule.mmap(config.vm_instance, args.benchmark_vmfb_path) 26 | 27 | if not args.scheduled_unet_vmfb_path: 28 | sys.exit("no --scheduled_unet_vmfb_path provided, required for run_benchmark") 29 | 30 | dtype = np.float16 if args.precision == "fp16" else np.float32 31 | sample = np.random.randn( 32 | args.batch_size, 4, args.height // 8, args.width // 8 33 | ).astype(dtype) 34 | prompt_embeds = np.random.randn(2 * args.batch_size, args.max_length, 2048).astype( 35 | dtype 36 | ) 37 | text_embeds = np.random.randn(2 * args.batch_size, 1280).astype(dtype) 38 | guidance_scale = np.array([7.5], dtype=dtype) 39 | num_iters = np.array(args.num_inference_steps) 40 | input = [ 41 | sample, 42 | prompt_embeds, 43 | text_embeds, 44 | guidance_scale, 45 | num_iters, 46 | ] 47 | 48 | vmfbs = [] 49 | vmfbs.append(args.scheduled_unet_vmfb_path) 50 | vmfbs.append(args.benchmark_vmfb_path) 51 | 52 | if args.external_weight_file: 53 | results = benchmark_module( 54 | benchmark_mod, 55 | "produce_image_latents", 56 | vmfbs, 57 | input, 58 | parameters=f"model={args.external_weight_file}", 59 | ) 60 | else: 61 | results = benchmark_module(benchmark_mod, "produce_image_latents", vmfbs, input) 62 | for benchmark_result in results: 63 | print( 64 | f"benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}" 65 | ) 66 | 67 | 68 | # Python Benchmarking Support for multiple modules 69 | 70 | if __name__ == "__main__": 71 | from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args 72 | 73 | run_benchmark(args) 74 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py: -------------------------------------------------------------------------------- 1 | from turbine_models.model_runner import vmfbRunner 2 | from transformers import CLIPTokenizer 3 | from iree import runtime as ireert 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def run_prompt_encoder( 9 | vmfb_path, 10 | device, 11 | external_weight_path, 12 | input_ids, 13 | uncond_input_ids, 14 | ): 15 | prompt_encoder_runner = vmfbRunner(device, vmfb_path, external_weight_path) 16 | # np.save("input0.npy", input_ids[0].numpy()) 17 | # np.save("input1.npy", input_ids[1].numpy()) 18 | # np.save("input2.npy", uncond_input_ids[0].numpy()) 19 | # 
np.save("input3.npy", uncond_input_ids[1].numpy()) 20 | prompt_encoder_inputs = [ 21 | ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]), 22 | ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]), 23 | ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[0]), 24 | ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[1]), 25 | ] 26 | encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_clip["encode_prompts"]( 27 | *prompt_encoder_inputs 28 | ) 29 | for i in encoded_outputs: 30 | i = i.to_host() 31 | del prompt_encoder_inputs 32 | return encoded_outputs 33 | 34 | 35 | def run_tokenize( 36 | tokenizer_1, 37 | tokenizer_2, 38 | prompt, 39 | negative_prompt, 40 | max_length=64, 41 | ): 42 | text_input_ids_list = [] 43 | uncond_input_ids_list = [] 44 | 45 | # Tokenize prompt and negative prompt. 46 | tokenizers = [tokenizer_1, tokenizer_2] 47 | for tokenizer in tokenizers: 48 | text_inputs = tokenizer( 49 | prompt, 50 | padding="max_length", 51 | max_length=max_length, 52 | truncation=True, 53 | return_tensors="pt", 54 | ) 55 | uncond_input = tokenizer( 56 | negative_prompt, 57 | padding="max_length", 58 | max_length=max_length, 59 | truncation=True, 60 | return_tensors="pt", 61 | ) 62 | text_input_ids = text_inputs.input_ids 63 | uncond_input_ids = uncond_input.input_ids 64 | 65 | text_input_ids_list.extend([text_input_ids]) 66 | uncond_input_ids_list.extend([uncond_input_ids]) 67 | return text_input_ids_list, uncond_input_ids_list 68 | 69 | 70 | if __name__ == "__main__": 71 | from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args 72 | 73 | tokenizer_1 = CLIPTokenizer.from_pretrained( 74 | args.hf_model_name, 75 | subfolder="tokenizer", 76 | token=args.hf_auth_token, 77 | ) 78 | tokenizer_2 = CLIPTokenizer.from_pretrained( 79 | args.hf_model_name, 80 | subfolder="tokenizer_2", 81 | token=args.hf_auth_token, 82 | ) 83 | 84 | text_input_ids_list, uncond_input_ids_list = run_tokenize( 85 | tokenizer_1, 86 | tokenizer_2, 87 | args.prompt, 88 | args.negative_prompt, 89 | args.max_length, 90 | ) 91 | turbine_output1, turbine_output2 = run_prompt_encoder( 92 | args.vmfb_path, 93 | args.rt_device, 94 | args.external_weight_path, 95 | text_input_ids_list, 96 | uncond_input_ids_list, 97 | ) 98 | print( 99 | "TURBINE OUTPUT 1:", 100 | turbine_output1.to_host(), 101 | turbine_output1.shape, 102 | turbine_output1.dtype, 103 | ) 104 | 105 | print( 106 | "TURBINE OUTPUT 2:", 107 | turbine_output2.to_host(), 108 | turbine_output2.shape, 109 | turbine_output2.dtype, 110 | ) 111 | 112 | if args.compare_vs_torch: 113 | print("generating torch output: ") 114 | from turbine_models.custom_models.sd_inference import utils 115 | from turbine_models.custom_models.sdxl_inference.sdxl_prompt_encoder import ( 116 | PromptEncoderModule, 117 | ) 118 | 119 | torch_encoder_model = PromptEncoderModule( 120 | args.hf_model_name, args.precision, args.hf_auth_token 121 | ) 122 | torch_output1, torch_output2 = torch_encoder_model.forward( 123 | *text_input_ids_list, *uncond_input_ids_list 124 | ) 125 | np.save("torch_output1.npy", torch_output1) 126 | np.save("torch_output2.npy", torch_output2) 127 | print( 128 | "TORCH OUTPUT 1:", torch_output1, torch_output1.shape, torch_output1.dtype 129 | ) 130 | 131 | print( 132 | "TORCH OUTPUT 2:", torch_output2, torch_output2.shape, torch_output2.dtype 133 | ) 134 | rtol = 4e-2 135 | atol = 4e-2 136 | 137 | np.testing.assert_allclose( 138 | torch_output1, 
turbine_output1.to_host(), rtol, atol, verbose=True 139 | ) 140 | np.testing.assert_allclose( 141 | torch_output2, turbine_output2.to_host(), rtol, atol, verbose=True 142 | ) 143 | print("Passed!") 144 | # TODO: Figure out why we occasionally segfault without unlinking output variables 145 | turbine_output1, turbine_output2 = (None, None) 146 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/unet_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from iree import runtime as ireert 4 | import torch 5 | import numpy as np 6 | from tqdm.auto import tqdm 7 | 8 | torch.random.manual_seed(0) 9 | 10 | 11 | def run_unet( 12 | device, 13 | sample, 14 | timestep, 15 | prompt_embeds, 16 | text_embeds, 17 | time_ids, 18 | guidance_scale, 19 | vmfb_path, 20 | hf_model_name, 21 | hf_auth_token, 22 | external_weight_path, 23 | runner=None, 24 | ): 25 | if runner is None: 26 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 27 | 28 | inputs = [ 29 | ireert.asdevicearray(runner.config.device, sample), 30 | ireert.asdevicearray(runner.config.device, timestep), 31 | ireert.asdevicearray(runner.config.device, prompt_embeds), 32 | ireert.asdevicearray(runner.config.device, text_embeds), 33 | ireert.asdevicearray(runner.config.device, time_ids), 34 | ireert.asdevicearray(runner.config.device, guidance_scale), 35 | ] 36 | results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) 37 | 38 | return results 39 | 40 | 41 | def run_unet_steps( 42 | device, 43 | sample, 44 | scheduler, 45 | prompt_embeds, 46 | text_embeds, 47 | time_ids, 48 | guidance_scale, 49 | vmfb_path, 50 | external_weight_path, 51 | ): 52 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 53 | timestep = torch.zeros(1, dtype=torch.int64) 54 | inputs = [ 55 | ireert.asdevicearray(runner.config.device, sample), 56 | ireert.asdevicearray(runner.config.device, timestep), 57 | ireert.asdevicearray(runner.config.device, prompt_embeds), 58 | ireert.asdevicearray(runner.config.device, text_embeds), 59 | ireert.asdevicearray(runner.config.device, time_ids), 60 | ireert.asdevicearray(runner.config.device, guidance_scale), 61 | ] 62 | for i, t in tqdm(enumerate(scheduler.timesteps)): 63 | timestep = t 64 | latent_model_input = scheduler.scale_model_input(sample, timestep) 65 | 66 | inputs[0] = latent_model_input = ireert.asdevicearray( 67 | runner.config.device, latent_model_input 68 | ) 69 | inputs[1] = timestep = ireert.asdevicearray( 70 | runner.config.device, (timestep,), dtype="int64" 71 | ) 72 | noise_pred = runner.ctx.modules.compiled_unet["run_forward"](*inputs).to_host() 73 | sample = scheduler.step( 74 | torch.from_numpy(noise_pred).cpu(), 75 | timestep, 76 | sample, 77 | generator=None, 78 | return_dict=False, 79 | )[0] 80 | return sample 81 | 82 | 83 | def run_torch_unet( 84 | hf_model_name, 85 | hf_auth_token, 86 | sample, 87 | timestep, 88 | prompt_embeds, 89 | text_embeds, 90 | time_ids, 91 | guidance_scale, 92 | precision="fp32", 93 | ): 94 | from turbine_models.custom_models.sdxl_inference.unet import UnetModel 95 | 96 | unet_model = UnetModel( 97 | hf_model_name, 98 | hf_auth_token, 99 | precision="fp32", 100 | ) 101 | results = unet_model.forward( 102 | sample, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale 103 | ) 104 | np_torch_output = results.detach().cpu().numpy() 105 | return 
np_torch_output 106 | 107 | 108 | if __name__ == "__main__": 109 | from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args 110 | 111 | if args.precision == "fp16": 112 | dtype = torch.float16 113 | else: 114 | dtype = torch.float32 115 | 116 | save_inputs = True 117 | 118 | sample = torch.rand( 119 | args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype 120 | ) 121 | timestep = torch.ones(1, dtype=dtype) 122 | prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) 123 | text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) 124 | time_ids = torch.rand(2 * args.batch_size, 6, dtype=dtype) 125 | guidance_scale = torch.tensor([7.5], dtype=dtype) 126 | 127 | if save_inputs: 128 | import os 129 | 130 | inputs_dir = "sdxl_unet_inputs_" + args.precision 131 | if not os.path.exists(inputs_dir): 132 | os.mkdir(inputs_dir) 133 | np.save("input1.npy", sample) 134 | np.save("input2.npy", timestep) 135 | np.save("input3.npy", prompt_embeds) 136 | np.save("input4.npy", text_embeds) 137 | np.save("input5.npy", time_ids) 138 | np.save("input6.npy", guidance_scale) 139 | 140 | turbine_output = run_unet( 141 | args.device, 142 | sample, 143 | timestep, 144 | prompt_embeds, 145 | text_embeds, 146 | time_ids, 147 | guidance_scale, 148 | args.vmfb_path, 149 | args.hf_model_name, 150 | args.hf_auth_token, 151 | args.external_weight_path, 152 | ).to_host() 153 | print( 154 | "TURBINE OUTPUT:", 155 | turbine_output, 156 | turbine_output.shape, 157 | turbine_output.dtype, 158 | ) 159 | 160 | if args.compare_vs_torch: 161 | print("generating torch output: ") 162 | from turbine_models.custom_models.sd_inference import utils 163 | 164 | # comment out .float for fp16... sorry. 165 | torch_output = run_torch_unet( 166 | args.hf_model_name, 167 | args.hf_auth_token, 168 | sample.float(), 169 | timestep, 170 | prompt_embeds.float(), 171 | text_embeds.float(), 172 | time_ids.float(), 173 | guidance_scale.float(), 174 | # precision="fp16", 175 | ) 176 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 177 | if save_inputs: 178 | np.save("golden_out.npy", torch_output) 179 | atol = 4e-2 180 | rtol = 4e-1 181 | np.testing.assert_allclose(turbine_output, torch_output, atol=atol, rtol=rtol) 182 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import copy 8 | import os 9 | import sys 10 | 11 | from iree import runtime as ireert 12 | from iree.compiler.ir import Context 13 | import numpy as np 14 | from iree.turbine.aot import * 15 | from iree.turbine.dynamo.passes import ( 16 | DEFAULT_DECOMPOSITIONS, 17 | ) 18 | from turbine_models.custom_models.sd_inference import utils 19 | import torch 20 | import torch._dynamo as dynamo 21 | from diffusers import AutoencoderKL 22 | import safetensors 23 | 24 | 25 | class VaeModel(torch.nn.Module): 26 | def __init__( 27 | self, 28 | hf_model_name, 29 | custom_vae="", 30 | ): 31 | super().__init__() 32 | self.vae = None 33 | if custom_vae in ["", None]: 34 | self.vae = AutoencoderKL.from_pretrained( 35 | hf_model_name, 36 | subfolder="vae", 37 | ) 38 | elif "safetensors" in custom_vae: 39 | custom_vae = safetensors.torch.load_file(custom_vae) 40 | # custom vae as a HF state dict 41 | self.vae = AutoencoderKL.from_pretrained( 42 | hf_model_name, 43 | subfolder="vae", 44 | ) 45 | self.vae.load_state_dict(custom_vae) 46 | elif not isinstance(custom_vae, dict): 47 | try: 48 | # custom HF repo with no vae subfolder 49 | self.vae = AutoencoderKL.from_pretrained( 50 | custom_vae, 51 | ) 52 | except: 53 | # some larger repo with vae subfolder 54 | self.vae = AutoencoderKL.from_pretrained( 55 | custom_vae, 56 | subfolder="vae", 57 | ) 58 | 59 | def decode(self, inp): 60 | img = 1 / 0.13025 * inp 61 | x = self.vae.decode(img, return_dict=False)[0] 62 | return (x / 2 + 0.5).clamp(0, 1) 63 | 64 | def encode(self, inp): 65 | latents = self.vae.encode(inp).latent_dist.sample() 66 | return 0.13025 * latents 67 | 68 | 69 | def export_vae_model( 70 | vae_model, 71 | hf_model_name, 72 | batch_size, 73 | height, 74 | width, 75 | precision, 76 | compile_to="torch", 77 | external_weights=None, 78 | external_weight_path=None, 79 | device=None, 80 | target_triple=None, 81 | ireec_flags=None, 82 | variant="decode", 83 | decomp_attn=False, 84 | exit_on_vmfb=False, 85 | pipeline_dir=None, 86 | attn_spec=None, 87 | input_mlir=None, 88 | weights_only=False, 89 | ): 90 | safe_name = utils.create_safe_name( 91 | hf_model_name, 92 | f"_bs{batch_size}_{height}x{width}_{precision}_vae_{variant}", 93 | ) 94 | if pipeline_dir: 95 | safe_name = os.path.join(pipeline_dir, safe_name) 96 | 97 | if input_mlir: 98 | vmfb_path = utils.compile_to_vmfb( 99 | input_mlir, 100 | device, 101 | target_triple, 102 | ireec_flags, 103 | safe_name + "_" + target_triple, 104 | mlir_source="file", 105 | return_path=not exit_on_vmfb, 106 | attn_spec=attn_spec, 107 | ) 108 | return vmfb_path 109 | # if precision == "fp32" and device == "rocm": 110 | # decomp_attn = True 111 | # external_weights = None 112 | # print("Decomposing attention and inlining weights for fp32 VAE on ROCm") 113 | if device == "cpu": 114 | decomp_attn = True 115 | 116 | dtype = torch.float16 if precision == "fp16" else torch.float32 117 | if precision == "fp16": 118 | vae_model = vae_model.half() 119 | 120 | mapper = {} 121 | 122 | utils.save_external_weights( 123 | mapper, vae_model, external_weights, external_weight_path 124 | ) 125 | if weights_only: 126 | return external_weight_path 127 | 128 | input_image_shape = (height, width, 3) 129 | input_latents_shape = (batch_size, 4, height // 8, width // 8) 130 | encode_args = [ 131 | torch.empty( 132 | input_image_shape, 133 | dtype=torch.float32, 134 | ) 135 | ] 136 | decode_args = [ 137 | torch.empty( 138 | input_latents_shape, 139 | dtype=dtype, 140 | ) 
141 | ] 142 | decomp_list = [] 143 | if decomp_attn == True: 144 | safe_name += "_decomp" 145 | decomp_list = [ 146 | torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, 147 | torch.ops.aten._scaled_dot_product_flash_attention.default, 148 | torch.ops.aten.scaled_dot_product_attention, 149 | ] 150 | with decompositions.extend_aot_decompositions( 151 | from_current=True, 152 | add_ops=decomp_list, 153 | ): 154 | fxb = FxProgramsBuilder(vae_model) 155 | 156 | # @fxb.export_program(args=(encode_args,)) 157 | # def _encode(module, inputs,): 158 | # return module.encode(*inputs) 159 | 160 | @fxb.export_program(args=(decode_args,)) 161 | def _decode(module, inputs): 162 | return module.decode(*inputs) 163 | 164 | class CompiledVae(CompiledModule): 165 | main = _decode 166 | 167 | if external_weights: 168 | externalize_module_parameters(vae_model) 169 | 170 | inst = CompiledVae(context=Context(), import_to="IMPORT") 171 | 172 | module_str = str(CompiledModule.get_mlir_module(inst)) 173 | 174 | if compile_to != "vmfb": 175 | return module_str 176 | else: 177 | vmfb_path = utils.compile_to_vmfb( 178 | module_str, 179 | device, 180 | target_triple, 181 | ireec_flags, 182 | safe_name + "_" + target_triple, 183 | return_path=not exit_on_vmfb, 184 | attn_spec=attn_spec, 185 | ) 186 | return vmfb_path 187 | 188 | 189 | if __name__ == "__main__": 190 | from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args 191 | 192 | if args.precision == "fp16": 193 | custom_vae = "madebyollin/sdxl-vae-fp16-fix" 194 | else: 195 | custom_vae = "" 196 | 197 | if args.input_mlir: 198 | vae_model = None 199 | else: 200 | vae_model = VaeModel( 201 | args.hf_model_name, 202 | custom_vae=custom_vae, 203 | ) 204 | mod_str = export_vae_model( 205 | vae_model, 206 | args.hf_model_name, 207 | args.batch_size, 208 | height=args.height, 209 | width=args.width, 210 | precision=args.precision, 211 | compile_to=args.compile_to, 212 | external_weights=args.external_weights, 213 | external_weight_path=args.external_weight_path, 214 | device=args.device, 215 | target_triple=args.iree_target_triple, 216 | ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, 217 | variant=args.vae_variant, 218 | decomp_attn=args.decomp_attn, 219 | attn_spec=args.attn_spec, 220 | input_mlir=args.input_mlir, 221 | ) 222 | if args.input_mlir or (args.compile_to == "vmfb"): 223 | exit() 224 | safe_name = utils.create_safe_name( 225 | args.hf_model_name, 226 | f"_bs{str(args.batch_size)}_{args.height}x{args.width}_{args.precision}_vae_{args.vae_variant}", 227 | ) 228 | with open(f"{safe_name}.mlir", "w+") as f: 229 | f.write(mod_str) 230 | print("Saved to", safe_name + ".mlir") 231 | -------------------------------------------------------------------------------- /models/turbine_models/custom_models/sdxl_inference/vae_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from turbine_models.model_runner import vmfbRunner 3 | from iree import runtime as ireert 4 | import torch 5 | 6 | torch.random.manual_seed(0) 7 | 8 | 9 | def run_vae( 10 | device, 11 | example_input, 12 | vmfb_path, 13 | hf_model_name, 14 | external_weight_path, 15 | ): 16 | runner = vmfbRunner(device, vmfb_path, external_weight_path) 17 | inputs = [ireert.asdevicearray(runner.config.device, example_input)] 18 | results = runner.ctx.modules.compiled_vae["decode"](*inputs) 19 | 20 | return results 21 | 22 | 23 | def run_torch_vae(hf_model_name, custom_vae, variant, example_input): 24 | from 
turbine_models.custom_models.sd_inference.vae import VaeModel 25 | 26 | vae_model = VaeModel( 27 | hf_model_name, 28 | ) 29 | 30 | if variant == "decode": 31 | results = vae_model.decode(example_input) 32 | elif variant == "encode": 33 | results = vae_model.encode(example_input) 34 | np_torch_output = results.detach().cpu().numpy() 35 | return np_torch_output 36 | 37 | 38 | if __name__ == "__main__": 39 | from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args 40 | 41 | if args.precision == "fp16": 42 | dtype = torch.float16 43 | custom_vae = "madebyollin/sdxl-vae-fp16-fix" 44 | else: 45 | dtype = torch.float32 46 | custom_vae = "" 47 | if args.vae_variant == "decode": 48 | example_input = torch.rand( 49 | args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype 50 | ) 51 | elif args.vae_variant == "encode": 52 | example_input = torch.rand( 53 | args.batch_size, 3, args.height, args.width, dtype=dtype 54 | ) 55 | print("generating turbine output:") 56 | turbine_results = run_vae( 57 | args.device, 58 | example_input, 59 | args.vmfb_path, 60 | args.hf_model_name, 61 | args.external_weight_path, 62 | ) 63 | print( 64 | "TURBINE OUTPUT:", 65 | turbine_results.to_host(), 66 | turbine_results.to_host().shape, 67 | turbine_results.to_host().dtype, 68 | ) 69 | if args.compare_vs_torch: 70 | print("generating torch output: ") 71 | from turbine_models.custom_models.sd_inference import utils 72 | 73 | torch_output = run_torch_vae( 74 | args.hf_model_name, custom_vae, args.vae_variant, example_input.float() 75 | ) 76 | print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) 77 | err = utils.largest_error(torch_output, turbine_results) 78 | print("Largest Error: ", err) 79 | assert err < 2e-3 80 | 81 | # TODO: Figure out why we occasionally segfault without unlinking output variables 82 | turbine_results = None 83 | -------------------------------------------------------------------------------- /models/turbine_models/gen_external_params/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nod-ai/SHARK-ModelDev/bdf46351281f47843bb0d8c77bcd8ecde7271b60/models/turbine_models/gen_external_params/__init__.py -------------------------------------------------------------------------------- /models/turbine_models/gen_external_params/gen_external_params.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Literal 3 | from turbine_models.model_builder import HFTransformerBuilder 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | import torch 6 | 7 | import argparse 8 | import sys 9 | 10 | parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") 11 | 12 | parser.add_argument( 13 | "--hf_model_name", 14 | type=str, 15 | default="meta-llama/Llama-2-7b-chat-hf", 16 | help="The Hugging Face model name ID.", 17 | ) 18 | parser.add_argument( 19 | "--quantization", 20 | type=str, 21 | default="int4", 22 | choices=["unquantized", "int4", "int8"], 23 | help="Type of quantization to apply.", 24 | ) 25 | parser.add_argument( 26 | "--weight_path", 27 | type=str, 28 | default="", 29 | help="Path to save the quantized model weights.", 30 | ) 31 | parser.add_argument( 32 | "--hf_auth_token", 33 | type=str, 34 | default=None, 35 | help="The Hugging Face auth token required for some models.", 36 | ) 37 | parser.add_argument( 38 | "--precision", 39 | type=str, 40 | default="f16", 41 | 
choices=["f16", "f32"], 42 | help="Data type of model.", 43 | ) 44 | 45 | 46 | def quantize(model, quantization, dtype): 47 | accumulates = dtype 48 | int_weights = {} 49 | if quantization in ["int4", "int8"]: 50 | from brevitas_examples.common.generative.quantize import quantize_model 51 | from brevitas_examples.llm.llm_quant.run_utils import get_model_impl 52 | 53 | print("Applying weight quantization...") 54 | weight_bit_width = 4 if quantization == "int4" else 8 55 | quantize_model( 56 | get_model_impl(model).layers, 57 | dtype=accumulates, 58 | weight_bit_width=weight_bit_width, 59 | weight_param_method="stats", 60 | weight_scale_precision="float_scale", 61 | weight_quant_type="asym", 62 | weight_quant_granularity="per_group", 63 | weight_group_size=128, # TODO: make adjustable 64 | quantize_weight_zero_point=False, 65 | ) 66 | from brevitas_examples.llm.llm_quant.export import LinearWeightBlockQuantHandler 67 | from brevitas.nn.quant_linear import QuantLinear 68 | 69 | class DummyLinearWeightBlockQuantHandler(LinearWeightBlockQuantHandler): 70 | def forward(self, x): 71 | raise NotImplementedError 72 | 73 | for prefix, layer in model.named_modules(): 74 | if isinstance(layer, QuantLinear): 75 | print(f"Exporting layer {prefix}") 76 | exporter = DummyLinearWeightBlockQuantHandler() 77 | exporter.prepare_for_export(layer) 78 | print( 79 | f" weight = ({exporter.int_weight.shape}, {exporter.int_weight.dtype}), " 80 | f"scale=({exporter.scale.shape}, {exporter.scale.dtype}), " 81 | f"zero=({exporter.zero_point.shape}, {exporter.zero_point.dtype})" 82 | ) 83 | int_weights[f"{prefix}.weight"] = exporter.int_weight 84 | int_weights[f"{prefix}.weight_scale"] = exporter.scale 85 | int_weights[f"{prefix}.weight_zp"] = exporter.zero_point 86 | 87 | all_weights = dict(model.named_parameters()) 88 | for k in list(all_weights.keys()): 89 | if "wrapped_scaling_impl" in k or "wrapped_zero_point_impl" in k: 90 | del all_weights[k] 91 | 92 | if len(int_weights) != 0: 93 | all_weights.update(int_weights) 94 | return all_weights 95 | 96 | 97 | def gen_external_params( 98 | hf_model_name: str = "meta-llama/Llama-2-7b-chat-hf", 99 | quantization: Literal["unquantized", "int4", "int8"] = "int4", 100 | weight_path: str = "", 101 | hf_auth_token: str = None, 102 | precision: str = "f16", 103 | ): 104 | """ 105 | Main function to run the model quantization and saving process. 106 | 107 | :param hf_model_name: The Hugging Face model name ID. 108 | :param quantization: Type of quantization to apply ('int4' or 'int8'). 109 | :param weight_path: Path to save the quantized model weights. 110 | :param hf_auth_token: The Hugging Face auth token required for some models. 111 | :param precision: Data type of model ('f16' or 'f32'). 
112 | """ 113 | SUPPORTED_QUANTIZATIONS = ["unquantized", "int4", "int8"] 114 | if quantization not in SUPPORTED_QUANTIZATIONS: 115 | if ( 116 | quantization is None 117 | or quantization.lower() == "none" 118 | or quantization.lower() == "unquantized" 119 | ): 120 | quantization = "unquantized" 121 | else: 122 | raise ValueError(f"Invalid quantization, {quantization} not supported.") 123 | 124 | model_builder = HFTransformerBuilder( 125 | example_input=None, 126 | hf_id=hf_model_name, 127 | auto_model=AutoModelForCausalLM, 128 | hf_auth_token=hf_auth_token, 129 | ) 130 | 131 | if precision == "f16": 132 | model = model_builder.model.half() 133 | dtype = torch.float16 134 | elif precision == "f32": 135 | model = model_builder.model 136 | dtype = torch.float32 137 | else: 138 | sys.exit("Invalid precision, f16 or f32 supported") 139 | 140 | quant_weights = quantize(model, quantization, dtype) 141 | 142 | if weight_path == "": 143 | save_path = hf_model_name.split("/")[-1].strip() 144 | save_path = re.sub("-", "_", save_path) 145 | save_path = save_path + "_" + precision + "_" + quantization + ".safetensors" 146 | else: 147 | save_path = weight_path 148 | 149 | import safetensors 150 | 151 | safetensors.torch.save_file(quant_weights, save_path) 152 | print("Saved safetensor output to ", save_path) 153 | 154 | 155 | if __name__ == "__main__": 156 | args = parser.parse_args() 157 | try: 158 | gen_external_params( 159 | hf_model_name=args.hf_model_name, 160 | quantization=args.quantization, 161 | weight_path=args.weight_path, 162 | hf_auth_token=args.hf_auth_token, 163 | precision=args.precision, 164 | ) 165 | except Exception as e: 166 | print(f"Error: {e}", file=sys.stderr) 167 | sys.exit(1) 168 | -------------------------------------------------------------------------------- /models/turbine_models/model_builder.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer, AutoConfig 2 | import torch 3 | import iree.turbine.aot as aot 4 | from turbine_models.turbine_tank import turbine_tank 5 | import os 6 | import re 7 | 8 | 9 | class HFTransformerBuilder: 10 | """ 11 | A model builder that uses Hugging Face's transformers library to build a PyTorch model. 12 | 13 | Args: 14 | example_input (torch.Tensor): An example input tensor to the model. 15 | hf_id (str): The Hugging Face model ID. 16 | auto_model (AutoModel): The AutoModel class to use for loading the model. 17 | auto_tokenizer (AutoTokenizer): The AutoTokenizer class to use for loading the tokenizer. 18 | auto_config (AutoConfig): The AutoConfig class to use for loading the model configuration. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | example_input: torch.Tensor, 24 | hf_id: str = None, 25 | auto_model: AutoModel = AutoModel, 26 | auto_tokenizer: AutoTokenizer = None, 27 | auto_config: AutoConfig = None, 28 | hf_auth_token=None, 29 | upload_ir=False, 30 | model=None, 31 | model_type: str = None, 32 | compile_to_vmfb: bool = None, 33 | tokenizer=None, 34 | ) -> None: 35 | self.example_input = example_input 36 | self.hf_id = hf_id 37 | self.auto_model = auto_model 38 | self.auto_tokenizer = auto_tokenizer 39 | self.auto_config = auto_config 40 | self.hf_auth_token = hf_auth_token 41 | self.model = model 42 | self.tokenizer = tokenizer 43 | self.upload_ir = upload_ir 44 | self.model_type = model_type 45 | self.compile_to_vmfb = compile_to_vmfb 46 | if self.model == None: 47 | self.build_model() 48 | 49 | def build_model(self) -> None: 50 | """ 51 | Builds a PyTorch model using Hugging Face's transformers library. 52 | """ 53 | # TODO: check cloud storage for existing ir 54 | if self.hf_id: 55 | self.model = self.auto_model.from_pretrained( 56 | self.hf_id, token=self.hf_auth_token, config=self.auto_config 57 | ) 58 | if self.auto_tokenizer is not None: 59 | self.tokenizer = self.auto_tokenizer.from_pretrained( 60 | self.hf_id, token=self.hf_auth_token 61 | ) 62 | else: 63 | self.tokenizer = None 64 | 65 | def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: 66 | """ 67 | Compiles the PyTorch model into a compiled module using SHARK-Turbine's AOT compiler. 68 | 69 | Args: 70 | save_to (str): one of: input (Torch IR) or import (linalg). 71 | 72 | Returns: 73 | aot.CompiledModule: The compiled module binary. 74 | """ 75 | if self.model_type and self.model_type == "hf_seq2seq": 76 | module = aot.export(self.model, *self.example_input) 77 | else: 78 | module = aot.export(self.model, self.example_input) 79 | if self.hf_id: 80 | module_str = str(module.mlir_module) 81 | safe_name = self.hf_id.split("/")[-1].strip() 82 | safe_name = re.sub("-", "_", safe_name) 83 | if self.upload_ir: 84 | with open(f"{safe_name}.mlir", "w+") as f: 85 | f.write(module_str) 86 | model_name_upload = self.hf_id.replace("/", "_") 87 | turbine_tank.uploadToBlobStorage( 88 | str(os.path.abspath(f"{safe_name}.mlir")), 89 | f"{model_name_upload}/{model_name_upload}.mlir", 90 | ) 91 | os.remove(f"{safe_name}.mlir") 92 | if self.compile_to_vmfb and not self.compile_to_vmfb: 93 | return 94 | compiled_binary = module.compile(save_to=save_to) 95 | return compiled_binary 96 | -------------------------------------------------------------------------------- /models/turbine_models/model_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from iree import runtime as ireert 4 | from iree.runtime._binding import create_hal_driver 5 | 6 | 7 | class vmfbRunner: 8 | def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None): 9 | flags = [] 10 | 11 | # If an extra plugin is requested, add a global flag to load the plugin 12 | # and create the driver using the non-caching creation function, as 13 | # the caching creation function may ignore the flag. 14 | if extra_plugin: 15 | ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") 16 | haldriver = create_hal_driver(device) 17 | 18 | # No plugin requested: create the driver with the caching create 19 | # function. 
20 | else: 21 | haldriver = ireert.get_driver(device) 22 | if "://" in device: 23 | try: 24 | device_idx = int(device.split("://")[-1]) 25 | device_uri = None 26 | except: 27 | device_idx = None 28 | device_uri = device.split("://")[-1] 29 | else: 30 | device_idx = 0 31 | device_uri = None 32 | if device_uri: 33 | if not any(x in device for x in ["cpu", "task"]): 34 | allocators = ["caching"] 35 | haldevice = haldriver.create_device_by_uri( 36 | device_uri, allocators=allocators 37 | ) 38 | else: 39 | haldevice = haldriver.create_device_by_uri(device_uri) 40 | else: 41 | hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"] 42 | if not any(x in device for x in ["cpu", "task"]): 43 | allocators = ["caching"] 44 | haldevice = haldriver.create_device( 45 | hal_device_id, allocators=allocators 46 | ) 47 | else: 48 | haldevice = haldriver.create_device(hal_device_id) 49 | 50 | self.config = ireert.Config(device=haldevice) 51 | mods = [] 52 | if not isinstance(vmfb_path, list): 53 | vmfb_path = [vmfb_path] 54 | for path in vmfb_path: 55 | mods.append(ireert.VmModule.mmap(self.config.vm_instance, path)) 56 | vm_modules = [ 57 | *mods, 58 | ireert.create_hal_module(self.config.vm_instance, self.config.device), 59 | ] 60 | 61 | # TODO: Enable multiple weight files 62 | if external_weight_path: 63 | index = ireert.ParameterIndex() 64 | if not isinstance(external_weight_path, list): 65 | external_weight_path = [external_weight_path] 66 | for i, path in enumerate(external_weight_path): 67 | if path in ["", None]: 68 | continue 69 | index.load(path) 70 | # TODO: extend scope 71 | param_module = ireert.create_io_parameters_module( 72 | self.config.vm_instance, index.create_provider(scope="model") 73 | ) 74 | vm_modules.insert(i, param_module) 75 | del param_module 76 | del index 77 | 78 | self.ctx = ireert.SystemContext( 79 | vm_modules=vm_modules, 80 | config=self.config, 81 | ) 82 | 83 | def unload(self): 84 | self.ctx = None 85 | self.config = None 86 | -------------------------------------------------------------------------------- /models/turbine_models/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nod-ai/SHARK-ModelDev/bdf46351281f47843bb0d8c77bcd8ecde7271b60/models/turbine_models/tests/__init__.py -------------------------------------------------------------------------------- /models/turbine_models/tests/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_addoption(parser): 2 | # Huggingface Options 3 | parser.addoption("--hf_auth_token", action="store", default=None) 4 | parser.addoption( 5 | "--hf_model_name", 6 | action="store", 7 | default="stabilityai/stable-diffusion-xl-base-1.0", 8 | ) 9 | parser.addoption("--scheduler_id", action="store", default="PNDM") 10 | # Inference Options 11 | parser.addoption( 12 | "--prompt", 13 | action="store", 14 | default="a photograph of an astronaut riding a horse", 15 | ) 16 | parser.addoption( 17 | "--negative_prompt", 18 | action="store", 19 | default="blurry, unsaturated, watermark, noisy, grainy, out of focus", 20 | ) 21 | parser.addoption("--num_inference_steps", type=int, action="store", default=2) 22 | parser.addoption("--guidance_scale", type=float, action="store", default=7.5) 23 | parser.addoption("--seed", type=float, action="store", default=0.0) 24 | parser.addoption("--vmfb_path", action="store", default="") 25 | parser.addoption("--external_weight_path", action="store", 
default="") 26 | parser.addoption("--external_weight_dir", action="store", default="") 27 | parser.addoption("--external_weight_file", action="store", default="") 28 | parser.addoption("--pipeline_dir", action="store", default=".") 29 | # Modelling Options 30 | parser.addoption("--batch_size", type=int, action="store", default=1) 31 | parser.addoption("--height", type=int, action="store", default=1024) 32 | parser.addoption("--width", type=int, action="store", default=1024) 33 | parser.addoption("--precision", action="store", default="fp32") 34 | parser.addoption("--max_length", type=int, action="store", default=64) 35 | parser.addoption("--run_vmfb", action="store", default=True) 36 | # General Options 37 | parser.addoption("--compile_to", action="store", default=None) 38 | parser.addoption("--external_weights", action="store", default="safetensors") 39 | parser.addoption("--decomp_attn", action="store", default=False) 40 | parser.addoption("--vae_decomp_attn", action="store", default=False) 41 | parser.addoption("--attn_spec", action="store", default="") 42 | # Compiler Options 43 | parser.addoption("--device", action="store", default="cpu") 44 | parser.addoption("--rt_device", action="store", default="local-task") 45 | parser.addoption( 46 | "--iree_target_triple", type=str, action="store", default="x86_64-linux-gnu" 47 | ) 48 | parser.addoption("--ireec_flags", action="store", default="") 49 | parser.addoption("--attn_flags", action="store", default="") 50 | # Test Options 51 | parser.addoption("--in_channels", type=int, action="store", default=4) 52 | parser.addoption("--benchmark", action="store_true", default=False) 53 | parser.addoption("--tracy_profile", action="store_true", default=False) 54 | parser.addoption("--compiled_pipeline", type=bool, default=False) 55 | parser.addoption("--model_path", type=str, action="store", default=None) 56 | parser.addoption("--vae_model_path", type=str, action="store", default=None) 57 | parser.addoption("--pipeline_vmfb_path", type=str, action="store", default=None) 58 | parser.addoption("--scheduler_vmfb_path", type=str, action="store", default=None) 59 | parser.addoption("--split_scheduler", action="store_true", default=True) 60 | parser.addoption("--cpu_scheduling", action="store_true", default=True) 61 | parser.addoption("--npu_delegate_path", type=str, action="store", default=None) 62 | parser.addoption("--clip_precision", type=str, action="store", default=None) 63 | parser.addoption("--mmdit_precision", type=str, action="store", default=None) 64 | parser.addoption("--unet_precision", type=str, action="store", default=None) 65 | parser.addoption("--vae_precision", type=str, action="store", default=None) 66 | parser.addoption("--shift", type=float, action="store", default=None) 67 | parser.addoption("--denoise", action="store_true", default=None) 68 | -------------------------------------------------------------------------------- /models/turbine_models/tests/gen_external_params_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Nod Labs, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import logging 8 | from turbine_models.gen_external_params.gen_external_params import quantize 9 | from turbine_models.model_builder import HFTransformerBuilder 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | import unittest 12 | import os 13 | import torch 14 | import pytest 15 | 16 | 17 | class ExternalParamsTest(unittest.TestCase): 18 | def testQuantizeF32(self): 19 | model_builder = HFTransformerBuilder( 20 | example_input=None, 21 | hf_id="facebook/opt-350m", 22 | auto_model=AutoModelForCausalLM, 23 | ) 24 | model_builder.build_model() 25 | quant_weights = quantize(model_builder.model, "", torch.float32) 26 | for weight in quant_weights: 27 | self.assertNotIn("weight_zp", weight) 28 | self.assertNotIn("weight_scale", weight) 29 | assert quant_weights[weight].dtype in [torch.float32] 30 | 31 | def testQuantizeF32I8(self): 32 | model_builder = HFTransformerBuilder( 33 | example_input=None, 34 | hf_id="facebook/opt-350m", 35 | auto_model=AutoModelForCausalLM, 36 | ) 37 | model_builder.build_model() 38 | quant_weights = quantize(model_builder.model, "int8", torch.float32) 39 | named_params = dict(model_builder.model.named_parameters()) 40 | for weight in quant_weights: 41 | if "weight_scale" not in weight and "weight_zp" not in weight: 42 | if "layers" in weight and "weight" in weight and "norm" not in weight: 43 | assert quant_weights[weight].dtype in [torch.uint8] 44 | assert named_params[weight].size(dim=1) == quant_weights[ 45 | weight 46 | ].size(dim=1) 47 | else: 48 | assert quant_weights[weight].dtype in [torch.float32] 49 | else: 50 | assert quant_weights[weight].dtype in [torch.float32] 51 | 52 | def testQuantizeF32I4(self): 53 | model_builder = HFTransformerBuilder( 54 | example_input=None, 55 | hf_id="facebook/opt-350m", 56 | auto_model=AutoModelForCausalLM, 57 | ) 58 | model_builder.build_model() 59 | quant_weights = quantize(model_builder.model, "int4", torch.float32) 60 | named_params = dict(model_builder.model.named_parameters()) 61 | for weight in quant_weights: 62 | if "weight_scale" not in weight and "weight_zp" not in weight: 63 | if "layers" in weight and "weight" in weight and "norm" not in weight: 64 | assert quant_weights[weight].dtype in [torch.uint8] 65 | assert named_params[weight].size(dim=1) == 2 * quant_weights[ 66 | weight 67 | ].size(dim=1) 68 | else: 69 | assert quant_weights[weight].dtype in [torch.float32] 70 | else: 71 | assert quant_weights[weight].dtype in [torch.float32] 72 | 73 | def testQuantizeF16(self): 74 | model_builder = HFTransformerBuilder( 75 | example_input=None, 76 | hf_id="facebook/opt-350m", 77 | auto_model=AutoModelForCausalLM, 78 | ) 79 | model_builder.build_model() 80 | quant_weights = quantize(model_builder.model.half(), "", torch.float16) 81 | for weight in quant_weights: 82 | self.assertNotIn("weight_zp", weight) 83 | self.assertNotIn("weight_scale", weight) 84 | assert quant_weights[weight].dtype in [torch.float16] 85 | 86 | @pytest.mark.xfail(reason="brevitas issue with f16 int8 quanttensor") 87 | def testQuantizeF16I8(self): 88 | model_builder = HFTransformerBuilder( 89 | example_input=None, 90 | hf_id="facebook/opt-350m", 91 | auto_model=AutoModelForCausalLM, 92 | ) 93 | model_builder.build_model() 94 | quant_weights = quantize(model_builder.model.half(), "int8", torch.float16) 95 | named_params = dict(model_builder.model.named_parameters()) 96 | for weight in quant_weights: 97 | if "weight_scale" not in weight and "weight_zp" not in 
weight: 98 | if "layers" in weight and "weight" in weight and "norm" not in weight: 99 | assert quant_weights[weight].dtype in [torch.uint8] 100 | assert named_params[weight].size(dim=1) == quant_weights[ 101 | weight 102 | ].size(dim=1) 103 | else: 104 | assert quant_weights[weight].dtype in [torch.float16] 105 | else: 106 | assert quant_weights[weight].dtype in [torch.float16] 107 | 108 | def testQuantizeF16I4(self): 109 | model_builder = HFTransformerBuilder( 110 | example_input=None, 111 | hf_id="facebook/opt-350m", 112 | auto_model=AutoModelForCausalLM, 113 | ) 114 | model_builder.build_model() 115 | quant_weights = quantize(model_builder.model.half(), "int4", torch.float16) 116 | named_params = dict(model_builder.model.named_parameters()) 117 | for weight in quant_weights: 118 | if "weight_scale" not in weight and "weight_zp" not in weight: 119 | if "layers" in weight and "weight" in weight and "norm" not in weight: 120 | assert quant_weights[weight].dtype in [torch.uint8] 121 | assert named_params[weight].size(dim=1) == 2 * quant_weights[ 122 | weight 123 | ].size(dim=1) 124 | else: 125 | assert quant_weights[weight].dtype in [torch.float16] 126 | else: 127 | assert quant_weights[weight].dtype in [torch.float16] 128 | 129 | 130 | if __name__ == "__main__": 131 | logging.basicConfig(level=logging.DEBUG) 132 | unittest.main() 133 | -------------------------------------------------------------------------------- /models/turbine_models/tests/pipeline_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Advanced Micro Devices, inc. 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | import logging 8 | import pytest 9 | import unittest 10 | import torch 11 | import os 12 | import numpy as np 13 | from iree.compiler.ir import Context 14 | from iree.turbine.aot import * 15 | from turbine_models.custom_models.sd_inference import utils 16 | from turbine_models.custom_models.pipeline_base import ( 17 | PipelineComponent, 18 | TurbinePipelineBase, 19 | ) 20 | from iree.turbine.transforms.general.add_metadata import AddMetadataPass 21 | 22 | model_metadata_forward = { 23 | "model_name": "TestModel2xLinear", 24 | "input_shapes": [10], 25 | "input_dtypes": ["float32"], 26 | "output_shapes": [10], 27 | "output_dtypes": ["float32"], 28 | "test_kwarg_1": "test_kwarg_1_value", 29 | "test_kwarg_2": "test_kwarg_2_value", 30 | } 31 | 32 | 33 | class TestModule(torch.nn.Module): 34 | def __init__(self): 35 | super().__init__() 36 | self.fc1 = torch.nn.Linear(10, 10) 37 | self.fc2 = torch.nn.Linear(10, 10) 38 | 39 | def forward(self, x): 40 | x = self.fc1(x) 41 | x = self.fc2(x) 42 | return x 43 | 44 | 45 | torch.no_grad() 46 | 47 | 48 | def export_dummy_model(): 49 | model = TestModule() 50 | target = "x86_64-unknown-linux-gnu" 51 | device = "llvm-cpu" 52 | 53 | dummy_input = torch.empty(10) 54 | safe_keys = [ 55 | model_metadata_forward["model_name"], 56 | "fp32", 57 | "bs1", 58 | ] 59 | safe_name = "_".join(safe_keys) 60 | vmfb_path = f"./{safe_name}.vmfb" 61 | 62 | fxb = FxProgramsBuilder(model) 63 | 64 | @fxb.export_program(args=(dummy_input,)) 65 | def _forward(module, inputs): 66 | return module.forward(inputs) 67 | 68 | class CompiledTester(CompiledModule): 69 | forward = _forward 70 | 71 | inst = CompiledTester(context=Context(), import_to="IMPORT") 72 | mlir_module = 
CompiledModule.get_mlir_module(inst) 73 | mlir_module = AddMetadataPass(mlir_module, model_metadata_forward, "forward").run() 74 | vmfb_path = utils.compile_to_vmfb( 75 | str(mlir_module), 76 | device, 77 | target, 78 | None, 79 | safe_name + "_" + target, 80 | return_path=True, 81 | ) 82 | return vmfb_path 83 | 84 | 85 | class TestPipeline(TurbinePipelineBase): 86 | def __init__( 87 | self, 88 | **base_args, 89 | ): 90 | super().__init__(**base_args) 91 | 92 | def run(self, inputs: list): 93 | return self.test_model_1("forward", *inputs) 94 | 95 | 96 | class PipelineTest(unittest.TestCase): 97 | def setUp(self): 98 | model_map = { 99 | "test_model_1": { 100 | "model_name": "TestModel1", 101 | "external_weights": None, 102 | "module_name": "compiled_tester", 103 | "safe_name": "TestModel2xLinear", 104 | "keywords": ["Test", "Model", "2x", "Linear"], 105 | "export_fn": export_dummy_model, 106 | } 107 | } 108 | self.pipe = TestPipeline( 109 | model_map=model_map, 110 | device="cpu", 111 | target="x86_64-unknown-linux-gnu", 112 | pipeline_dir="./", 113 | precision="fp32", 114 | ) 115 | self.pipe.prepare_all() 116 | self.pipe.load_map() 117 | self.test_input = [torch.ones(10)] 118 | 119 | def test_pipeline(self): 120 | output = self.pipe.run(self.test_input).to_host() 121 | print(output) 122 | 123 | def test_pipeline_benchmark(self): 124 | self.pipe.test_model_1.benchmark = True 125 | output = self.pipe.run(self.test_input).to_host() 126 | print(output) 127 | 128 | def test_pipeline_metadata(self): 129 | metadata = self.pipe.test_model_1.get_metadata("forward") 130 | expected = model_metadata_forward 131 | for i in expected.keys(): 132 | expected[i] = str(expected[i]) 133 | assert expected == metadata, "Metadata mismatch: expected {}, got {}".format( 134 | expected, metadata 135 | ) 136 | 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /models/turbine_models/tests/resnet_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from turbine_models.custom_models import resnet_18 4 | import unittest 5 | import os 6 | 7 | resnet_model = resnet_18.Resnet18Model() 8 | 9 | 10 | class Resnet18Test(unittest.TestCase): 11 | def testExportResnet18ModelCPU(self): 12 | from turbine_models.tests.testing_cmd_opts import args 13 | 14 | arguments = { 15 | "run_vmfb": True, 16 | "compile_to": "vmfb", 17 | "vmfb_path": "", 18 | "device": "local-task", 19 | "target_triple": "x86_64-unknown-linux-gnu", 20 | "vulkan_max_allocation": "4294967296", 21 | "precision": "fp32", 22 | } 23 | resnet_18.export_resnet_18_model( 24 | resnet_model, 25 | "vmfb", 26 | "cpu", 27 | ) 28 | namespace = AttributeDict(arguments) 29 | err = resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace) 30 | assert err < 1e-5 31 | 32 | def testExportResnet18ModelStaticGFX1100(self): 33 | arguments = { 34 | "run_vmfb": True, 35 | "compile_to": "vmfb", 36 | "vmfb_path": "", 37 | "device": "rocm", 38 | "target_triple": "gfx1100", 39 | "vulkan_max_allocation": "4294967296", 40 | "precision": "fp16", 41 | } 42 | resnet_18.export_static_resnet_18_model( 43 | resnet_model, 44 | "vmfb", 45 | "rocm", 46 | arguments["target_triple"], 47 | ) 48 | namespace = AttributeDict(arguments) 49 | rocm_err = resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace) 50 | namespace.device = "hip" 51 | hip_err = resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace) 52 
| print("ROCM ERROR:", rocm_err) 53 | print("HIP ERROR:", hip_err) 54 | assert rocm_err < 1e-5 55 | assert hip_err < 1e-5 56 | 57 | # def tearDown(self): 58 | # if os.path.exists("resnet_18.vmfb"): 59 | # os.remove("resnet_18.vmfb") 60 | # if os.path.exists("resnet_18.mlir"): 61 | # os.remove("resnet_18.mlir") 62 | 63 | 64 | class AttributeDict(dict): 65 | def __getattr__(self, attr): 66 | return self[attr] 67 | 68 | def __setattr__(self, attr, value): 69 | self[attr] = value 70 | 71 | 72 | if __name__ == "__main__": 73 | logging.basicConfig(level=logging.DEBUG) 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /models/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt: -------------------------------------------------------------------------------- 1 | Hello! I'm just an AI assistant, I don't have personal experiences or feelings. I'm here to help answer your questions to the best of my ability, but I can't provide false information. If a question doesn't make sense or is not factually coherent, I will let you know. Please feel free to ask me anything! -------------------------------------------------------------------------------- /models/turbine_models/turbine_tank/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nod-ai/SHARK-ModelDev/bdf46351281f47843bb0d8c77bcd8ecde7271b60/models/turbine_models/turbine_tank/__init__.py -------------------------------------------------------------------------------- /models/turbine_models/turbine_tank/turbine_tank.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Advanced Micro Devices, Inc 2 | # 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | 7 | from azure.storage.blob import BlobServiceClient 8 | 9 | import subprocess 10 | import datetime 11 | import os 12 | from pathlib import Path 13 | from functools import cmp_to_key 14 | import logging 15 | 16 | custom_path = os.getenv("TURBINE_TANK_CACHE_DIR") 17 | if custom_path is not None: 18 | if not os.path.exists(custom_path): 19 | os.mkdir(custom_path) 20 | 21 | WORKDIR = custom_path 22 | 23 | logging.info(f"Using {WORKDIR} as local turbine_tank cache directory.") 24 | else: 25 | WORKDIR = os.path.join(str(Path.home()), ".local/turbine_tank/") 26 | logging.info( 27 | f"turbine_tank local cache is located at {WORKDIR} . You may change this by assigning the TURBINE_TANK_CACHE_DIR environment variable." 
28 | ) 29 | os.makedirs(WORKDIR, exist_ok=True) 30 | 31 | connection_string = os.environ.get("AZURE_CONNECTION_STRING") 32 | CONTAINER_NAME = os.environ.get("AZURE_CONTAINER_NAME") 33 | 34 | 35 | def get_short_git_sha() -> str: 36 | try: 37 | return ( 38 | subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) 39 | .decode("utf-8") 40 | .strip() 41 | ) 42 | except FileNotFoundError: 43 | return None 44 | 45 | 46 | def changeBlobName(old_blob_name, new_blob_name): 47 | blob_service_client = BlobServiceClient.from_connection_string(connection_string) 48 | blob_client = blob_service_client.get_blob_client(CONTAINER_NAME, old_blob_name) 49 | new_blob_client = blob_service_client.get_blob_client(CONTAINER_NAME, new_blob_name) 50 | 51 | blob = blob_client.from_connection_string( 52 | conn_str=connection_string, 53 | container_name=CONTAINER_NAME, 54 | blob_name=blob_client.blob_name, 55 | ) 56 | if not blob.exists(): 57 | logging.warning("Blob to be renamed does not exist.") 58 | return 59 | 60 | # Copy the blob to the new name 61 | new_blob_client.start_copy_from_url(blob_client.url) 62 | 63 | # Delete the original blob 64 | blob_client.delete_blob() 65 | logging.info(f"The blob was renamed successfully: {new_blob_name}") 66 | 67 | 68 | def uploadToBlobStorage(file_path, file_name): 69 | # create our prefix (we use this to keep track of when and what version of turbine is being used) 70 | today = str(datetime.date.today()) 71 | commit = get_short_git_sha() 72 | prefix = today + "_" + commit 73 | blob_service_client = BlobServiceClient.from_connection_string(connection_string) 74 | blob_client = blob_service_client.get_blob_client( 75 | container=CONTAINER_NAME, blob=prefix + "/" + file_name 76 | ) 77 | blob = blob_client.from_connection_string( 78 | conn_str=connection_string, 79 | container_name=CONTAINER_NAME, 80 | blob_name=blob_client.blob_name, 81 | ) 82 | # we check to see if we already uploaded the blob (don't want to duplicate) 83 | if blob.exists(): 84 | logging.info( 85 | f"model artifacts have already been uploaded for {today} on the same github commit ({commit})" 86 | ) 87 | return 88 | # upload to azure storage container tankturbine 89 | with open(file_path, "rb") as data: 90 | blob_client.upload_blob(data) 91 | logging.info(f"Uploaded {file_name}.") 92 | return blob_client.blob_name 93 | 94 | 95 | def checkAndRemoveIfDownloadedOld(model_name: str, model_dir: str, prefix: str): 96 | if os.path.isdir(model_dir) and len(os.listdir(model_dir)) > 0: 97 | for item in os.listdir(model_dir): 98 | item_path = os.path.join(model_dir, item) 99 | # model artifacts already downloaded and up to date 100 | # we check if model artifacts are behind using the prefix (day + git_sha) 101 | if os.path.isdir(item_path) and item == prefix: 102 | return True 103 | # model artifacts are behind, so remove for new download 104 | if os.path.isdir(item_path) and os.path.isfile( 105 | os.path.join(item_path, model_name + ".mlir") 106 | ): 107 | os.remove(os.path.join(item_path, model_name + ".mlir")) 108 | os.rmdir(item_path) 109 | return False 110 | if os.path.isdir(item_path) and os.path.isfile( 111 | os.path.join(item_path, model_name + "-param.mlir") 112 | ): 113 | os.remove(os.path.join(item_path, model_name + "-param.mlir")) 114 | os.rmdir(item_path) 115 | return False 116 | # have not downloaded this model's artifacts yet 117 | return False 118 | 119 | 120 | def download_public_folder( 121 | model_name: str, prefix: str, model_dir: str, container_name=CONTAINER_NAME 122 | ): 123 | """Downloads a folder of blobs in azure container."""
124 | blob_service_client = BlobServiceClient.from_connection_string(connection_string) 125 | container_client = blob_service_client.get_container_client( 126 | container=container_name 127 | ) 128 | blob_list = container_client.list_blobs(name_starts_with=prefix) 129 | empty = True 130 | 131 | # go through the blobs with our target prefix 132 | # example prefix: "2024-02-13_26d6428/CompVis_stable-diffusion-v1-4-clip" 133 | for blob in blob_list: 134 | empty = False 135 | blob_client = blob_service_client.get_blob_client( 136 | container=container_name, blob=blob.name 137 | ) 138 | # create path if directory doesn't exist locally 139 | dest_path = model_dir 140 | if not os.path.isdir(dest_path): 141 | os.makedirs(dest_path) 142 | # download blob into local turbine tank cache 143 | if "param" in blob.name: 144 | file_path = os.path.join(model_dir, model_name + "-param.mlir") 145 | else: 146 | file_path = os.path.join(model_dir, model_name + ".mlir") 147 | with open(file=file_path, mode="wb") as sample_blob: 148 | download_stream = blob_client.download_blob() 149 | sample_blob.write(download_stream.readall()) 150 | 151 | if empty: 152 | logging.warning(f"Model ({model_name}) has not been uploaded yet") 153 | return True 154 | 155 | return False 156 | 157 | 158 | # sort blobs by last modified 159 | def compare(item1, item2): 160 | if item1.last_modified < item2.last_modified: 161 | return -1 162 | elif item1.last_modified > item2.last_modified: 163 | return 1 164 | else: 165 | return 0 166 | 167 | 168 | def downloadModelArtifacts(model_name: str, container_name=CONTAINER_NAME) -> str: 169 | model_name = model_name.replace("/", "_") 170 | container_client = BlobServiceClient.from_connection_string( 171 | connection_string 172 | ).get_container_client(container=container_name) 173 | blob_list = container_client.list_blobs() 174 | # get the latest blob uploaded to turbine tank (can't use [] notation for blob_list) 175 | blob_list = sorted(blob_list, key=cmp_to_key(compare)) 176 | for blob in blob_list: 177 | latest_blob = blob 178 | # get the prefix for the latest blob (2024-02-13_26d6428) 179 | download_latest_prefix = latest_blob.name.split("/")[0] 180 | model_dir = os.path.join(WORKDIR, model_name) 181 | # check if we already downloaded the model artifacts for this day + commit 182 | exists = checkAndRemoveIfDownloadedOld( 183 | model_name=model_name, model_dir=model_dir, prefix=download_latest_prefix 184 | ) 185 | if exists: 186 | logging.info("Already downloaded most recent version") 187 | return "NA" 188 | # download the model artifacts (passing in the model name, path in azure storage to model artifacts, local directory to store) 189 | blobDNE = download_public_folder( 190 | model_name, 191 | download_latest_prefix + "/" + model_name, 192 | os.path.join(model_dir, download_latest_prefix), 193 | ) 194 | if blobDNE: 195 | return 196 | model_dir = os.path.join(WORKDIR, model_name + "/" + download_latest_prefix) 197 | mlir_filename = os.path.join(model_dir, model_name + "-param.mlir") 198 | logging.info( 199 | f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
200 | ) 201 | assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}" 202 | 203 | return mlir_filename 204 | -------------------------------------------------------------------------------- /models/turbine_models/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from collections import namedtuple 3 | import iree.runtime as ireert 4 | import numpy as np 5 | 6 | 7 | BenchmarkResult = namedtuple( 8 | "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters" 9 | ) 10 | 11 | 12 | class BenchmarkToolError(Exception): 13 | """Benchmark exception that preserves the command line and error output.""" 14 | 15 | def __init__(self, message): 16 | self.message = message 17 | super().__init__(self.message) 18 | 19 | 20 | class BenchmarkTimeoutError(Exception): 21 | """Exception raised if the benchmark is cancelled by the user specified timeout.""" 22 | 23 | pass 24 | 25 | 26 | DTYPE_TO_ABI_TYPE = { 27 | np.dtype(np.float32): "f32", 28 | np.dtype(np.float16): "f16", 29 | np.dtype(np.int32): "i32", 30 | np.dtype(np.int64): "i64", 31 | np.dtype(np.float64): "f64", 32 | np.dtype(np.int16): "i16", 33 | np.dtype(np.int8): "i8", 34 | np.dtype(np.bool_): "i1", 35 | } 36 | 37 | 38 | def benchmark_module( 39 | module, 40 | entry_function=None, 41 | vmfbs=[], 42 | inputs=[], 43 | tracy_profile=False, 44 | timeout=None, 45 | **kwargs, 46 | ): 47 | funcs = [a for a in module.function_names if a != "__init"] 48 | if entry_function is None: 49 | if len(funcs) > 1: 50 | raise ValueError(f"No function specified with multiple options {funcs}") 51 | entry_function = funcs[0] 52 | if entry_function not in funcs: 53 | raise ValueError( 54 | f"Attempted to benchmark unknown function {entry_function} of options {funcs}" 55 | ) 56 | 57 | args = [] 58 | if tracy_profile: 59 | args.append("TRACY_NO_EXIT=1") 60 | # TODO: run iree-tracy-capture subprocess 61 | args.append(ireert.benchmark_exe()) 62 | args.append(f"--function={entry_function}") 63 | 64 | for inp in inputs: 65 | if isinstance(inp, str): 66 | args.append(f"--input={inp}") 67 | continue 68 | shape = "x".join([str(d) for d in inp.shape]) 69 | abitype = DTYPE_TO_ABI_TYPE[inp.dtype] 70 | values = inp.flatten() 71 | if np.all(values[0] == values): 72 | values = str(values[0]) 73 | else: 74 | values = ",".join([str(v) for v in values]) 75 | input_arg = f"--input={shape}x{abitype}={values}" 76 | if len(input_arg) > 256: 77 | print( 78 | f"Randomizing {input_arg.split('=')[0]} because it is too long for subprocess.run" 79 | ) 80 | input_arg = f"--input={shape}x{abitype}" 81 | args.append(input_arg) 82 | print(args) 83 | 84 | for k in kwargs: 85 | v = kwargs[k] 86 | args.append(f"--{k}={v}") 87 | 88 | for vmfb in vmfbs: 89 | args.append(f"--module={vmfb}") 90 | 91 | try: 92 | benchmark_process = subprocess.run( 93 | args=args, 94 | # input=flatbuffer, 95 | timeout=timeout, 96 | stdout=subprocess.PIPE, 97 | stderr=subprocess.PIPE, 98 | ) 99 | except subprocess.TimeoutExpired: 100 | raise BenchmarkTimeoutError(f"Benchmark timed out after {timeout} seconds") 101 | out = benchmark_process.stdout 102 | err = benchmark_process.stderr 103 | 104 | err = err.decode() 105 | if "INVALID_ARGUMENT;" in err: 106 | raise ValueError("Invalid inputs specified for benchmarking") 107 | 108 | # In the event benchmarking runs but encounteres an internal error, 109 | # return the internal error instead of benchmark results. 
110 | if "INTERNAL; CUDA driver error" in str(out): 111 | raise BenchmarkToolError(str(out)) 112 | 113 | # Grab individual results by line (skip header lines) 114 | bench_lines = out.decode().split("\n")[3:] 115 | benchmark_results = [] 116 | for line in bench_lines: 117 | split = line.split() 118 | if len(split) == 0: 119 | continue 120 | benchmark_name = split[0] 121 | time = " ".join(split[1:3]) 122 | cpu_time = " ".join(split[3:5]) 123 | iterations = split[5] 124 | user_counters = None 125 | if len(split) > 5: 126 | user_counters = split[6] 127 | benchmark_results.append( 128 | BenchmarkResult( 129 | benchmark_name=benchmark_name, 130 | time=time, 131 | cpu_time=cpu_time, 132 | iterations=iterations, 133 | user_counters=user_counters, 134 | ) 135 | ) 136 | 137 | return benchmark_results 138 | -------------------------------------------------------------------------------- /models/turbine_models/utils/sdxl_benchmark.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from iree import runtime as ireert 3 | from turbine_models.utils.benchmark import benchmark_module 4 | 5 | 6 | DTYPE_MAP = { 7 | "fp16": "f16", 8 | "fp32": "f32", 9 | } 10 | 11 | 12 | def run_benchmark( 13 | model, 14 | vmfb_path, 15 | weights_path, 16 | device, 17 | max_length=None, 18 | height=None, 19 | width=None, 20 | batch_size=None, 21 | in_channels=None, 22 | precision=None, 23 | tracy_profile=False, 24 | ): 25 | config = ireert.Config(device) 26 | 27 | if not vmfb_path: 28 | sys.exit("no vmfb_path provided, required for run_benchmark") 29 | benchmark_mod = ireert.VmModule.mmap(config.vm_instance, vmfb_path) 30 | 31 | if weights_path: 32 | index = ireert.ParameterIndex() 33 | index.load(weights_path) 34 | 35 | vmfbs = [] 36 | vmfbs.append(vmfb_path) 37 | 38 | inputs = [] 39 | match model: 40 | case "clip_1": 41 | inputs.append(f"1x{max_length}xi64") 42 | case "clip_2": 43 | inputs.append(f"1x{max_length}xi64") 44 | case "prompt_encoder": 45 | inputs.extend([f"1x{max_length}xi64"] * 4) 46 | case "unet": 47 | inputs.append( 48 | f"{batch_size}x{in_channels}x{height//8}x{width//8}x{DTYPE_MAP[precision]}" 49 | ) 50 | inputs.append(f"1x{DTYPE_MAP[precision]}") 51 | inputs.append(f"{2*batch_size}x{max_length}x2048x{DTYPE_MAP[precision]}") 52 | inputs.append(f"{2*batch_size}x1280x{DTYPE_MAP[precision]}") 53 | inputs.append(f"{2*batch_size}x6x{DTYPE_MAP[precision]}") 54 | inputs.append(f"1x{DTYPE_MAP[precision]}") 55 | case "vae_decode": 56 | inputs.append(f"1x4x{height//8}x{width//8}x{DTYPE_MAP[precision]}") 57 | case "vae_encode": 58 | inputs.append(f"1x3x{height}x{width}x{DTYPE_MAP[precision]}") 59 | case _: 60 | sys.exit("model name doesn't match for inputs") 61 | 62 | if weights_path: 63 | results = benchmark_module( 64 | benchmark_mod, 65 | "main", 66 | vmfbs, 67 | inputs, 68 | tracy_profile, 69 | parameters=f"model={weights_path}", 70 | ) 71 | else: 72 | results = benchmark_module(benchmark_mod, "main", vmfbs, inputs, tracy_profile) 73 | 74 | for benchmark_result in results: 75 | print( 76 | f"model: {model}, benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}" 77 | ) 78 | -------------------------------------------------------------------------------- /mypy-requirements.txt: -------------------------------------------------------------------------------- 1 | # Typing packages needed for full mypy execution 
at the project level. 2 | mypy==1.8.0 3 | types-requests 4 | -------------------------------------------------------------------------------- /version_info.json: -------------------------------------------------------------------------------- 1 | {"core-version": "2.3.0rc20240410", "package-version": "0.9.7.dev1"} --------------------------------------------------------------------------------