├── .dockerignore ├── .github └── workflows │ └── evaluate.yaml ├── .gitignore ├── .platform ├── DOCKERFILE ├── README.md ├── docker-compose.yaml ├── merge │ ├── .python-version │ ├── README.md │ ├── pyproject.toml │ └── src │ │ └── merge │ │ ├── __init__.py │ │ └── main.py └── static │ ├── .nojekyll │ ├── index.html │ ├── script.js │ └── style.css ├── .pre-commit-config.yaml ├── README.md ├── RULES.md ├── pyproject.toml ├── src └── eval │ ├── evaluate.py │ ├── submission │ ├── README.md │ ├── __init__.py │ └── model.py │ └── utils.py ├── uv.lock └── zizmor.yml /.dockerignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | .github/ 3 | **/.venv/ 4 | -------------------------------------------------------------------------------- /.github/workflows/evaluate.yaml: -------------------------------------------------------------------------------- 1 | name: Evaluation Workflow 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | pr_number: 7 | description: "Pull Request Number" 8 | required: true 9 | type: string 10 | pr_repo: 11 | description: "Pull Request Repository (format: owner/repo)" 12 | required: true 13 | type: string 14 | pr_branch: 15 | description: "Pull Request Branch" 16 | required: true 17 | type: string 18 | pr_user: 19 | description: "Pull Request Author" 20 | required: true 21 | type: string 22 | pr_title: 23 | description: "Pull Request Title" 24 | required: true 25 | type: string 26 | 27 | concurrency: 28 | group: ${{ github.workflow }}-${{ github.event.pull_request.user.login }} 29 | cancel-in-progress: true 30 | 31 | jobs: 32 | security-check: 33 | name: Security Check 34 | runs-on: vand2025-runner 35 | permissions: 36 | contents: read 37 | steps: 38 | - name: Checkout code 39 | uses: actions/checkout@v4 40 | with: 41 | persist-credentials: false 42 | - name: Run Zizmor security scan 43 | run: | 44 | uv tool install zizmor 45 | uvx zizmor --config zizmor.yml .github/workflows/evaluate.yaml 46 | 47 | evaluate: 48 | needs: security-check 49 | runs-on: vand2025-runner 50 | permissions: 51 | contents: read 52 | pull-requests: read 53 | steps: 54 | - name: Debug PR info 55 | run: | 56 | echo "Branch: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_branch || github.event.pull_request.head.ref }}" 57 | echo "Repo: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_repo || github.event.pull_request.head.repo.full_name }}" 58 | - name: Checkout PR code for evaluation 59 | uses: actions/checkout@v4 60 | with: 61 | repository: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_repo || github.event.pull_request.head.repo.full_name }} 62 | ref: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_branch || github.event.pull_request.head.ref }} 63 | persist-credentials: false 64 | - name: Save PR User and Number 65 | env: 66 | PR_USER: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_user || github.event.pull_request.user.login }} 67 | PR_NUMBER: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_number || github.event.pull_request.number }} 68 | PR_NAME: ${{ github.event_name == 'workflow_dispatch' && inputs.pr_title || github.event.pull_request.title }} 69 | run: | 70 | echo $PR_USER > pr_user.txt 71 | echo $PR_NUMBER > pr_number.txt 72 | echo $PR_NAME > pr_name.txt 73 | - name: Run evaluation 74 | timeout-minutes: 300 # 5 hours timeout 75 | run: | 76 | GIT_LFS_SKIP_SMUDGE=1 uv sync 77 | uv run eval --dataset_path=/home/user/datasets/mvtec_loco 78 | - name: Upload evaluation results and SHA 
79 | uses: actions/upload-artifact@v4 80 | with: 81 | name: evaluation-results 82 | path: | 83 | results.json 84 | metrics.csv 85 | pr_user.txt 86 | pr_number.txt 87 | pr_name.txt 88 | retention-days: 30 89 | 90 | publish: 91 | needs: evaluate 92 | runs-on: vand2025-runner 93 | permissions: 94 | contents: write 95 | pull-requests: write 96 | strategy: 97 | max-parallel: 1 98 | steps: 99 | - name: Checkout PR code 100 | uses: actions/checkout@v4 101 | with: 102 | persist-credentials: false 103 | - name: Configure Git 104 | run: | 105 | git config --global user.name "GitHub Actions Bot" 106 | git config --global user.email "actions@github.com" 107 | - name: Save static files and SHA 108 | run: | 109 | # Clear temporary directories if they exist 110 | rm -rf /tmp/static 111 | rm -rf /tmp/merge 112 | # Create a temporary directory and copy static files for the leaderboard website 113 | mkdir -p /tmp/static 114 | cp -r .platform/static/* /tmp/static/ 115 | # Save the PR SHA 116 | echo $(git rev-parse HEAD) > /tmp/pr_sha 117 | # Save merger code 118 | cp -r .platform/merge /tmp/merge 119 | - name: Checkout gh-pages branch 120 | run: | 121 | # Try to fetch gh-pages branch 122 | if ! git fetch origin gh-pages; then 123 | # If branch doesn't exist, create an orphan branch 124 | git checkout --orphan gh-pages 125 | git rm -rf . 126 | touch .nojekyll 127 | git add .nojekyll 128 | git commit -m "Initialize gh-pages branch" 129 | # Removed initial push, the action below will handle it. 130 | else 131 | git checkout gh-pages 132 | fi 133 | - name: Download evaluation results 134 | uses: actions/download-artifact@v4 135 | with: 136 | name: evaluation-results 137 | path: /tmp/merge 138 | - name: Merge results 139 | run: | 140 | PR_NAME=$(cat /tmp/merge/pr_name.txt) 141 | PR_NUMBER=$(cat /tmp/merge/pr_number.txt) 142 | PR_CREATOR=$(cat /tmp/merge/pr_user.txt) 143 | # Copy results.csv to the merge directory if it exists 144 | if [ -f results.csv ]; then 145 | cp results.csv /tmp/merge/results.csv 146 | fi 147 | # Run the merge script 148 | cd /tmp/merge 149 | uv sync 150 | 151 | TIMESTAMP=$(date -u +"%Y-%m-%dT%H:%M:%SZ") 152 | PR_SHA=$(cat /tmp/pr_sha) 153 | # Use environment variables for the uv run command 154 | uv run merge --pr_name="${PR_NAME}" --pr_number="${PR_NUMBER}" --pr_author="${PR_CREATOR}" --timestamp="${TIMESTAMP}" --pr_sha="${PR_SHA}" 155 | 156 | # Copy the merged results.csv back to the gh-pages branch 157 | cp results.csv /tmp/static/results.csv 158 | - name: Copy Static files back to gh-pages 159 | run: | 160 | # Copy the saved static files to the gh-pages branch 161 | cp -r /tmp/static/* . 162 | - name: Commit results # Renamed from Push New Results 163 | run: | 164 | PR_NAME=$(cat /tmp/merge/pr_name.txt) 165 | PR_NUMBER=$(cat /tmp/merge/pr_number.txt) 166 | PR_CREATOR=$(cat /tmp/merge/pr_user.txt) 167 | # Commit changes 168 | git add . 
169 | git commit -m "Add evaluation results for PR \"${PR_NAME}\" (#${PR_NUMBER}) by ${PR_CREATOR}" 170 | # Removed git push command 171 | - name: Push to gh-pages 172 | uses: ad-m/github-push-action@d91a481090679876dfc4178fef17f286781251df 173 | with: 174 | github_token: ${{ secrets.GITHUB_TOKEN }} 175 | branch: gh-pages 176 | 177 | add-results-comment: 178 | needs: evaluate 179 | runs-on: vand2025-runner 180 | permissions: 181 | contents: read 182 | pull-requests: write 183 | issues: write 184 | steps: 185 | - name: Download evaluation results 186 | uses: actions/download-artifact@v4 187 | with: 188 | name: evaluation-results 189 | path: /tmp/merge 190 | - name: Add results comment 191 | uses: actions/github-script@v7 192 | with: 193 | github-token: ${{ secrets.GITHUB_TOKEN }} 194 | script: | 195 | const fs = require('fs'); 196 | const PR_USER_LOGIN = fs.readFileSync('/tmp/merge/pr_user.txt', 'utf8'); 197 | const PR_NUMBER = fs.readFileSync('/tmp/merge/pr_number.txt', 'utf8'); 198 | 199 | const result = JSON.parse(fs.readFileSync('/tmp/merge/results.json', 'utf8')); 200 | let result_string = ""; 201 | for (const [key, value] of Object.entries(result)) { 202 | result_string += `${key}: ${value}\n`; 203 | } 204 | const raw_metrics = fs.readFileSync('/tmp/merge/metrics.csv', 'utf8'); 205 | raw_metrics_string = "||seed|k_shot|category|image_score|pixel_score|\n"; 206 | raw_metrics_string += "|---|---|---|---|---|---|\n"; 207 | raw_metrics.split('\n').forEach(line => { 208 | raw_metrics_string += `|${line.split(',')[0]}|${line.split(',')[1]}|${line.split(',')[2]}|${line.split(',')[3]}|${line.split(',')[4]}|${line.split(',')[5]}|\n`; 209 | }); 210 | 211 | const comment = `## Evaluation Results 212 | ${result_string} 213 | ${raw_metrics_string} 214 | Created by: ${PR_USER_LOGIN}`; 215 | 216 | await github.rest.issues.createComment({ 217 | owner: context.repo.owner, 218 | repo: context.repo.repo, 219 | issue_number: PR_NUMBER, 220 | body: comment 221 | }); 222 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | *.pyo 4 | *.pyd 5 | *.pyw 6 | *.pyz 7 | 8 | **/.venv/ 9 | *.egg-info/ 10 | .env 11 | datasets 12 | -------------------------------------------------------------------------------- /.platform/DOCKERFILE: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.6.1-devel-ubuntu20.04 AS vand_runner 2 | 3 | ENV DEBIAN_FRONTEND="noninteractive" 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | curl \ 7 | software-properties-common \ 8 | cmake \ 9 | pkg-config \ 10 | ffmpeg 11 | 12 | 13 | # Install latest git for github actions 14 | RUN add-apt-repository ppa:git-core/ppa &&\ 15 | apt-get update && \ 16 | apt-get install --no-install-recommends -y git 17 | 18 | # Prettier requires atleast nodejs 10 and actions/checkout requires nodejs 16 19 | RUN curl -sL https://deb.nodesource.com/setup_current.x > nodesetup.sh && \ 20 | bash - nodesetup.sh && \ 21 | apt-get install --no-install-recommends -y nodejs && \ 22 | apt-get clean && \ 23 | rm -rf /var/lib/apt/lists/* 24 | # the list is cleared on the final stage 25 | 26 | # Create a non-root user 27 | RUN useradd -m user 28 | USER user 29 | 30 | # Install uv 31 | RUN curl -LsSf https://astral.sh/uv/install.sh | sh 32 | # Ensure the installed binary is on the `PATH` 33 | ENV PATH="/home/user/.local/bin/:$PATH" 34 | 35 | # Copy codebase and 
install dependencies 36 | # This helps in caching the dependencies 37 | COPY --chown=user:user ./eval /home/user/eval 38 | RUN cd /home/user/eval/ && uv sync 39 | 40 | # Download gh-actions runner 41 | RUN mkdir -p /home/user/actions-runner 42 | WORKDIR /home/user/actions-runner 43 | RUN curl -o actions-runner-linux-x64-2.322.0.tar.gz -L https://github.com/actions/runner/releases/download/v2.322.0/actions-runner-linux-x64-2.322.0.tar.gz && \ 44 | tar xzf ./actions-runner-linux-x64-2.322.0.tar.gz && \ 45 | rm ./actions-runner-linux-x64-2.322.0.tar.gz 46 | 47 | WORKDIR /home/user/ 48 | -------------------------------------------------------------------------------- /.platform/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | ## To only build the docker image 4 | 5 | ``` 6 | docker build . -t vand2025-runner -f .platform/DOCKERFILE 7 | ``` 8 | 9 | ## To run the docker image with id 10 | 11 | ``` 12 | docker run --gpus '"device="' --shm-size 2G --memory 142g \ 13 | -i -t --mount type=bind,source=./datasets,target=/home/user/datasets,readonly \ 14 | -d --name vand2025-runner-container- vand2025-runner 15 | ``` 16 | 17 | ## Configure the runner in detached mode 18 | 19 | ``` 20 | docker exec -it -d vand2025-runner-container- /bin/bash -c ' 21 | if [ ! -f /home/user/actions-runner/.credentials ]; then 22 | ./actions-runner/config.sh --url https://github.com/samet-akcay/vand-cvpr --token $RUNNER_TOKEN --labels vand2025-runner --unattended 23 | fi 24 | ./actions-runner/run.sh 25 | ' 26 | ``` 27 | 28 | ## To use GPU 3 29 | 30 | ensure that you cd into .platform/ 31 | 32 | ``` 33 | GPU_ID=0 RUNNER_TOKEN=your_token docker compose up 34 | ``` 35 | -------------------------------------------------------------------------------- /.platform/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | runner-0: # change this to runner-1, runner-2, etc. 3 | build: 4 | context: .. 5 | dockerfile: .platform/DOCKERFILE 6 | deploy: 7 | resources: 8 | reservations: 9 | devices: 10 | - driver: nvidia 11 | count: 1 12 | capabilities: [gpu] 13 | limits: 14 | memory: 142g # 1/7 of the total memory in the server 15 | mem_limit: 142g 16 | shm_size: 2G 17 | environment: 18 | - RUNNER_TOKEN=${RUNNER_TOKEN} # Pass GitHub runner token as env variable 19 | - NVIDIA_VISIBLE_DEVICES=${GPU_ID:-0} # Specify which GPU to use, default to 0 20 | - HTTP_PROXY=${http_proxy:-} 21 | - HTTPS_PROXY=${https_proxy:-} 22 | - NO_PROXY=${no_proxy:-} 23 | volumes: 24 | - /mnt/data/datasets/datasets/:/home/user/datasets:ro # Mount datasets as read-only 25 | entrypoint: | 26 | /bin/bash -c ' 27 | if [ ! 
-f /home/user/actions-runner/.credentials ]; then 28 | ./actions-runner/config.sh --url https://github.com/samet-akcay/vand-cvpr --token $RUNNER_TOKEN --labels vand2025-runner --unattended 29 | fi 30 | ./actions-runner/run.sh 31 | ' 32 | -------------------------------------------------------------------------------- /.platform/merge/.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /.platform/merge/README.md: -------------------------------------------------------------------------------- 1 | # Merge results 2 | 3 | Python script to merge the results.json file into the results.csv file 4 | -------------------------------------------------------------------------------- /.platform/merge/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "merge" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = ["pandas"] 8 | 9 | [tool.hatch.build.targets.wheel] 10 | packages = ["src/merge"] 11 | 12 | [project.scripts] 13 | merge = "merge.main:run" 14 | 15 | [build-system] 16 | requires = ["hatchling"] 17 | build-backend = "hatchling.build" 18 | -------------------------------------------------------------------------------- /.platform/merge/src/merge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvpr-vand/challenge/dd189038f4c6974033e5ef705b6b68f918e87899/.platform/merge/src/merge/__init__.py -------------------------------------------------------------------------------- /.platform/merge/src/merge/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | 7 | 8 | def main(pr_name: str, pr_number: int, pr_author: str, timestamp: str, pr_sha: str): 9 | with open("results.json", "r") as f: 10 | results = json.load(f) 11 | 12 | results["pr_name"] = pr_name 13 | results["pr_number"] = pr_number 14 | results["pr_author"] = pr_author 15 | results["timestamp"] = timestamp 16 | results["pr_sha"] = pr_sha 17 | 18 | if not Path("results.csv").exists(): 19 | df = pd.DataFrame(results, index=[0]) 20 | else: 21 | df = pd.read_csv("results.csv") 22 | existing_entry = df.query(f"pr_author == '{pr_author}' and pr_number == {pr_number}") 23 | if existing_entry.empty: 24 | df = pd.concat([df, pd.DataFrame([results])], ignore_index=True) 25 | else: 26 | df.loc[existing_entry.index[0], :] = pd.Series(results) 27 | 28 | # Sort by avg_image_score then by normalized_aufc, and then aufc in descending order 29 | df = df.sort_values(by=["avg_image_score", "normalized_aufc", "aufc"], ascending=False) 30 | df.to_csv("results.csv", index=False) 31 | 32 | 33 | def run(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--pr_name", type=str, required=True) 36 | parser.add_argument("--pr_number", type=int, required=True) 37 | parser.add_argument("--pr_author", type=str, required=True) 38 | parser.add_argument("--timestamp", type=str, required=True) 39 | parser.add_argument("--pr_sha", type=str, required=True) 40 | args = parser.parse_args() 41 | main(args.pr_name, args.pr_number, args.pr_author, args.timestamp, args.pr_sha) 42 | 43 | 44 | if __name__ == "__main__": 45 | run() 46 | 
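For orientation, a small illustrative sketch of the data this merge script handles: the metric keys come from the evaluation step (`src/eval/evaluate.py` writes them to `results.json`), the PR metadata is appended here before the row lands in `results.csv`, and every concrete value below is a made-up placeholder.

```python
# Hypothetical leaderboard row, for illustration only.
# Aggregated metrics produced by `uv run eval` (written to results.json):
results = {
    "aufc": 4.2,              # area under the image F1Max vs. k-shot curve
    "normalized_aufc": 0.60,  # same curve with k rescaled to [0, 1]
    "avg_image_score": 0.60,  # mean image F1Max across k-shot values
}

# Metadata appended by merge/main.py before the row is written to results.csv,
# which the static leaderboard page then renders as one row per pull request:
results.update(
    {
        "pr_name": "Example submission",
        "pr_number": 123,
        "pr_author": "octocat",
        "timestamp": "2025-04-07T12:00:00Z",
        "pr_sha": "0123abcd",
    }
)
```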
-------------------------------------------------------------------------------- /.platform/static/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvpr-vand/challenge/dd189038f4c6974033e5ef705b6b68f918e87899/.platform/static/.nojekyll -------------------------------------------------------------------------------- /.platform/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | VAND 2025 Evaluation 5 | 6 | 7 | 8 |
<h1>VAND 2025 Evaluation</h1>
9 |
10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /.platform/static/script.js: -------------------------------------------------------------------------------- 1 | async function loadResults() { 2 | try { 3 | // Fetch the CSV file 4 | const response = await fetch("results.csv"); 5 | const csvText = await response.text(); 6 | 7 | // Parse CSV (skip header row) 8 | let header = csvText.split("\n")[0]; 9 | header = header.replace(/_/g, " "); 10 | header = header.replace(/\b\w/g, (char) => char.toUpperCase()); 11 | const rows = csvText 12 | .split("\n") 13 | .slice(1) 14 | .filter((row) => row.trim() !== ""); 15 | 16 | // Create table HTML 17 | const table = ` 18 | <table> 19 | <thead> 20 | <tr> 21 | ${header 22 | .split(",") 23 | .map((col) => `<th>${col}</th>`) 24 | .join("")} 25 | </tr> 26 | </thead> 27 | <tbody> 28 | ${rows 29 | .map( 30 | (row) => 31 | `<tr>${row 32 | .split(",") 33 | .map((col) => `<td>${col}</td>`) 34 | .join("")}</tr>`, 35 | ) 36 | .join("")} 37 | </tbody> 38 | </table>
39 | `; 40 | 41 | // Add table to content div 42 | document.getElementById("content").innerHTML = table; 43 | } catch (error) { 44 | console.error("Error loading results:", error); 45 | document.getElementById("content").innerHTML = 46 | "
<p>Error loading results. Please try again later.</p>
"; 47 | } 48 | } 49 | 50 | // Load results when page loads 51 | document.addEventListener("DOMContentLoaded", loadResults); 52 | -------------------------------------------------------------------------------- /.platform/static/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, 3 | Arial, sans-serif; 4 | max-width: 1200px; 5 | margin: 0 auto; 6 | padding: 20px; 7 | } 8 | 9 | table { 10 | width: 100%; 11 | border-collapse: collapse; 12 | margin-top: 20px; 13 | } 14 | 15 | th, 16 | td { 17 | padding: 12px; 18 | text-align: left; 19 | border-bottom: 1px solid #ddd; 20 | } 21 | 22 | th { 23 | background-color: #f5f5f5; 24 | font-weight: 600; 25 | } 26 | 27 | tr:hover { 28 | background-color: #f9f9f9; 29 | } 30 | 31 | a { 32 | color: #0366d6; 33 | text-decoration: none; 34 | } 35 | 36 | a:hover { 37 | text-decoration: underline; 38 | } 39 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-prettier 3 | rev: v3.1.0 4 | hooks: 5 | - id: prettier 6 | - repo: https://github.com/astral-sh/ruff-pre-commit 7 | # Ruff version. 8 | rev: v0.4.4 9 | hooks: 10 | # Run the linter. 11 | - id: ruff 12 | args: [--fix] 13 | # Run the formatter. 14 | - id: ruff-format 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAND 2025 Evaluation Framework 2 | 3 | This repository contains the official evaluation script and framework for the **Visual Anomaly and Novelty Detection (VAND) 2025 Challenge**. It focuses on evaluating few-shot anomaly detection performance on the MVTec LOCO dataset. 4 | 5 | > [!IMPORTANT] 6 | > Participants should only modify the code within the `src/eval/submission` directory to implement their solution. Do **not** modify the core evaluation scripts in `src/eval`. Refer to `RULES.md` and the documentation within `src/eval/submission` for detailed submission guidelines. 7 | 8 | This repository contains: 9 | 10 | - The evaluation framework and scripts located in the `src/eval/` directory. 11 | - A submission template for participants in `src/eval/submission/`. 12 | - The official challenge rules in `RULES.md`. 13 | 14 | ## Getting Started 15 | 16 | 1. **Familiarize yourself with the rules:** Read `RULES.md` carefully. 17 | 2. **Explore the evaluation framework:** Understand how submissions will be evaluated by reviewing the documentation in `src/eval/submission/README.md`. 18 | 3. **Implement your solution:** Modify the code within `src/eval/submission/` according to the guidelines. 19 | 20 | ## Overview 21 | 22 | The evaluation framework is designed to: 23 | 24 | 1. Load the MVTec LOCO dataset for specific categories. 25 | 2. Instantiate the participant's anomaly detection model (defined in `src/eval/submission/model.py`). 26 | 3. Perform a few-shot evaluation protocol: 27 | - For each category, seed, and specified `k_shot` value: 28 | - Randomly sample `k_shot` images from the training set for model setup/adaptation. 29 | - Evaluate the model on the test set. 30 | - Calculate the image-level F1Max score. 31 | 4. Aggregate results across different seeds and k-shot values. 32 | 5. 
Compute final metrics: Area Under the F1-max Curve (AUFC), normalized AUFC, and the average image F1Max score. 33 | 34 | ## Dependencies 35 | 36 | This project uses [`uv`](https://github.com/astral-sh/uv) for package management. Key dependencies include: 37 | 38 | - `anomalib`: For MVTec LOCO dataset handling and metrics. 39 | - `scikit-learn`: For calculating the Area Under Curve (AUC). 40 | - `torch`: The deep learning framework. 41 | 42 | ## Project Structure 43 | 44 | ``` 45 | . 46 | ├── src/ 47 | │ └── eval/ 48 | │ ├── evaluate.py # Core evaluation script (DO NOT MODIFY) 49 | │ └── submission/ # <-- PARTICIPANT CODE GOES HERE 50 | │ ├── model.py # Define your Model class here 51 | │ ├── __init__.py 52 | │ └── README.md # Submission-specific instructions 53 | ├── datasets/ # (Optional) Default location for datasets 54 | │ └── mvtec_loco/ 55 | ├── results.json # Final aggregated evaluation results (auto generated) 56 | ├── metrics.csv # Detailed results per category/seed/k-shot (auto generated) 57 | ├── pyproject.toml # Main project configuration (dependencies, scripts) # <-- Add optional dependencies here 58 | ├── uv.lock # Dependency lock file 59 | ├── README.md # Top-level README for the challenge 60 | └── RULES.md # Official challenge rules 61 | 62 | ``` 63 | 64 | ## How to Use 65 | 66 | 1. **Install Dependencies:** 67 | 68 | - Ensure you have `uv` installed. 69 | - Sync the environment: 70 | `uv sync` 71 | This installs all necessary dependencies defined in `pyproject.toml` and `uv.lock` into a virtual environment (`.venv/`). 72 | 73 | 2. **Implement Your Model:** 74 | 75 | - Go to `src/eval/submission/model.py`. 76 | - Implement your anomaly detection logic within the `Model` class, adhering to the required interface (methods like `__init__`, `setup`, `forward`, `weights_url`). See the README within `src/eval/submission` for specifics. 77 | 78 | 3. **Prepare Dataset:** 79 | 80 | - Download the MVTec LOCO dataset. 81 | - By default, the script expects it at `./datasets/mvtec_loco`. You can change this using the `--dataset_path` argument. 82 | 83 | 4. **Run Evaluation:** 84 | - Activate the virtual environment: `source .venv/bin/activate` (or equivalent for your shell). 85 | - Run the evaluation script using the `eval` command (defined in `pyproject.toml`). 86 | - Use `--help` to see available options: 87 | ```bash 88 | eval --help 89 | ``` 90 | - Example command: 91 | ```bash 92 | eval --k_shots 1 2 4 8 --seeds 42 0 1234 --dataset_path /path/to/your/mvtec_loco 93 | ``` 94 | - `--k_shots`: Specifies the list of k-values for few-shot learning. 95 | - `--seeds`: Specifies the random seeds for reproducibility. Multiple seeds are run to ensure robustness. 96 | - `--dataset_path`: Overrides the default dataset location. 97 | 98 | ## Evaluation Details 99 | 100 | - **Dataset:** MVTec LOCO Anomaly Detection Dataset. The evaluation runs on the following categories: 101 | - `breakfast_box` 102 | - `juice_bottle` 103 | - `pushpins` 104 | - `screw_bag` 105 | - `splicing_connectors` 106 | 107 | ## Output Files 108 | 109 | - `metrics.csv`: A CSV file logging the individual results for every run. 110 | - `results.json`: A JSON file containing the final, aggregated performance metrics. 111 | 112 | ## Questions? 113 | 114 | Refer to the contact information provided in `RULES.md` or the official challenge communication channels. 
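As a quick illustration of how the aggregated numbers in `results.json` relate to the per-run scores in `metrics.csv`, here is a simplified sketch of the aggregation performed by `compute_average_metrics` in `src/eval/evaluate.py` (single seed and category for brevity; all scores are made-up placeholders):

```python
import pandas as pd
from sklearn.metrics import auc

# Hypothetical per-run scores: one row per (seed, k_shot, category) combination.
df = pd.DataFrame(
    {
        "seed": [42, 42, 42, 42],
        "k_shot": [1, 2, 4, 8],
        "category": ["pushpins"] * 4,
        "image_score": [0.55, 0.58, 0.61, 0.63],  # image-level F1Max per run
    }
)

# Average the image F1Max for each k-shot value, then integrate the
# F1Max-vs-k curve to obtain the AUFC and its normalized variant.
per_k = df.groupby("k_shot")["image_score"].mean().reset_index()
k, scores = per_k["k_shot"], per_k["image_score"]

aufc = auc(k, scores)
normalized_aufc = auc((k - k.min()) / (k.max() - k.min()), scores)
avg_image_score = scores.mean()

print({"aufc": aufc, "normalized_aufc": normalized_aufc, "avg_image_score": avg_image_score})
```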
115 | -------------------------------------------------------------------------------- /RULES.md: -------------------------------------------------------------------------------- 1 | # VAND 2025 Challenge Rules 2 | 3 | **Version:** 1.0 (Last Updated: 07-Apr-2025) 4 | 5 | Please read these rules carefully before participating. 6 | 7 | ## 1. Introduction 8 | 9 | - **Challenge Goal:** Briefly describe the main objective (e.g., develop few-shot anomaly detection models for industrial inspection). 10 | - **Organizers:** You can find the list of organizers [here](https://sites.google.com/view/vand30cvpr2025/home) 11 | - **Website:** https://sites.google.com/view/vand30cvpr2025/home 12 | - **Timeline:** Apr 7th - May26th 13 | 14 | ## 2. Eligibility 15 | 16 | - This challenge is open to individuals, teams, and academic and corporate entities worldwide. 17 | 18 | ## 3. Task Definition 19 | 20 | - **Problem:** Participants will create models using few-shot learning and VLMs to find and localize structural and logical anomalies in the MVTec LOCO AD dataset, which contains images of different industrial products showing both defects. This indicates that the models can handle structural defect detection and logical reasoning. 21 | - **Dataset:** [MVTec LOCO](https://www.mvtec.com/company/research/datasets/mvtec-loco) (specify categories used: `breakfast_box`, `juice_bottle`, `pushpins`, `screw_bag`, `splicing_connectors`). Link to the dataset page. 22 | 23 | - **Input:** 24 | - The only inputs to the models are the k-shot support images and the test image. 25 | - **Output:** 26 | - The models are expected to output a prediction score that is used to compute the Image F1Max score. 27 | - The models can optionally output a pixel-level anomaly map that is used to compute the pixel-level F1Max score. 28 | - **Few-Shot Protocol:** Models will be tested with 1, 2, 4, and 8-shot samples for each of the seed values and categories. 29 | 30 | ## 4. Evaluation 31 | 32 | - **Metrics:** The primary evaluation metric is the Image F1Max score. If the models are able to generate pixel-level anomaly maps, the pixel-level F1Max score is also computed. 33 | - **Evaluation Server/Platform:** The code for the evaluation pipeline lives in `eval.py`, with the workflow defined in `.github/workflows/eval.yml`. The participants are not allowed to modify these files. 34 | - **Leaderboard:** The leaderboard is autogenerated each time any participant/team makes a pull request that successfully runs the evaluation pipeline. 35 | 36 | ## 5. Submission Guidelines 37 | 38 | - **Format:** 39 | - Code must be packaged within the provided `src/eval/submission/` directory. 40 | - Participants must implement the `Model` class in `model.py`. 41 | - **Code Requirements:** 42 | - The submission system uses single Nvidia 3090 GPUs and 166GB of RAM per runner. The submission is expected to run within these constraints. 43 | - The runtime is expected to meet the cutoff limit. Any submission that does not meet the cutoff limit will automatically be terminated. 44 | - **Submission Process:** 45 | - The participants are expected to make a pull request to the repository with their submission. The submission process is outlined in the [README](README.md). 46 | - The participants are free to update their PRs until the submission deadline. 47 | - **Number of Submissions:** Currently, there is no limit on how many times a participant or team can submit. 
However, if we detect excessive computational resource usage by any participant/team, we may implement submission limits to ensure fair access for all. 48 | 49 | ## 6. Allowed Resources & Pre-training 50 | 51 | - **External Data:** The participants are allowed to use any external data for pre-training. However, pre-training on the MVTec LOCO dataset is NOT allowed. 52 | - **Pre-trained Models:** Pre-trained models should be available publically for download. 53 | - **Training Data Usage:** Participants can only use the provided k-shot samples during `model.setup` for training. These images are used for pre-conditioning during inference, and not used for training. 54 | 55 | ## 7. Code Release & Reproducibility 56 | 57 | - The entire evaluation pipeline is publically available on GitHub. 58 | - All submissions will only be allowed as a publically accessible pull request to the repository. 59 | 60 | ## 8. General Rules & Conduct 61 | 62 | - To make the evaluation fair and transparent, we will use the same evaluation protocol for all participants. This is done publically using GitHub Actions with the evaluation script available for scrutiny. 63 | - All participants must use the same evaluation protocol and metrics for their submissions. 64 | - Any team attempting to modify the evaluation script of the actions workflow will be disqualified. 65 | 66 | ## 9. Contact & Support 67 | 68 | - For questions regarding the rules or evaluation, please contact the organizing team mentioned on the challenge [website](https://sites.google.com/view/vand30cvpr2025/home). 69 | 70 | ## 10. Amendments 71 | 72 | - The organizers reserve the right to amend the rules, with notification to participants. This might include changes to the dataset, evaluation metrics, or submission guidelines. 73 | - The organizers reserve the right to change the evaluation server/platform, include the cut-off time for submissions if needed. 
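To make the submission interface concrete, below is a minimal sketch of the `Model` class that `src/eval/evaluate.py` expects in `src/eval/submission/model.py`. It mirrors the template shipped in this repository; the constant score is purely illustrative, and a real submission replaces it with its own scoring logic.

```python
import torch
from anomalib.data import ImageBatch
from torch import nn


class Model(nn.Module):
    """Skeleton submission that assigns a constant score to every image."""

    def setup(self, setup_data: dict[str, torch.Tensor]) -> None:
        # Called once per (category, seed, k-shot) run with the k-shot support
        # images and the dataset category, before inference on the test set.
        self.few_shot_samples = setup_data["few_shot_samples"]
        self.category = setup_data["dataset_category"]

    def weights_url(self, category: str) -> str | None:
        # Optionally return a publicly downloadable URL to (category-specific)
        # pre-trained weights; return None if nothing needs to be fetched.
        return None

    def forward(self, image: torch.Tensor) -> ImageBatch:
        # Must return at least `pred_score` (image-level); a pixel-level
        # anomaly map may optionally be added for the pixel F1Max metric.
        batch_size = image.shape[0]
        return ImageBatch(
            image=image,
            pred_score=torch.zeros(batch_size, device=image.device),
        )
```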
74 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "vand2025_eval" 3 | version = "0.1.0" 4 | description = "VAND 2025 evaluation script" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "anomalib[full]", 9 | "scikit-learn", 10 | # TODO: Add your additional dependencies here 11 | ] 12 | 13 | [tool.uv.sources] 14 | anomalib = { git = "https://github.com/openvinotoolkit/anomalib.git", branch = "main" } 15 | 16 | 17 | [tool.hatch.build.targets.wheel] 18 | packages = ["src/eval"] 19 | 20 | [project.scripts] 21 | eval = "eval.evaluate:eval" 22 | 23 | [build-system] 24 | requires = ["hatchling"] 25 | build-backend = "hatchling.build" 26 | 27 | [tool.ruff] 28 | line-length = 120 29 | 30 | [tool.ruff.lint] 31 | select = ["E", "F", "W", "I"] 32 | 33 | [tool.ruff.lint.isort] 34 | known-first-party = ["eval"] 35 | 36 | -------------------------------------------------------------------------------- /src/eval/evaluate.py: -------------------------------------------------------------------------------- 1 | """Evaluation script""" 2 | 3 | import argparse 4 | import json 5 | from pathlib import Path 6 | from tempfile import gettempdir 7 | from typing import cast 8 | from urllib.request import urlretrieve 9 | 10 | import pandas as pd 11 | import torch 12 | from anomalib.data import ImageBatch, MVTecLOCODataset 13 | from anomalib.data.utils import Split 14 | from anomalib.data.utils.download import DownloadProgressBar 15 | from anomalib.metrics.f1_score import _F1Max 16 | from sklearn.metrics import auc 17 | from torch import nn 18 | from torch.utils.data import DataLoader 19 | from torchvision.transforms.v2 import Resize 20 | from tqdm import tqdm 21 | 22 | from eval.submission.model import Model 23 | from eval.utils import auto_batch_size 24 | 25 | CATEGORIES = [ 26 | "breakfast_box", 27 | "juice_bottle", 28 | "pushpins", 29 | "screw_bag", 30 | "splicing_connectors", 31 | ] 32 | 33 | 34 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | 36 | 37 | def parse_args() -> argparse.Namespace: 38 | """Parse command line arguments. 39 | 40 | Returns: 41 | argparse.Namespace: Parsed arguments. 42 | """ 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument( 45 | "--seeds", 46 | type=int, 47 | nargs="*", 48 | default=[42, 0, 1234], 49 | help="List of seed values for reproducibility. Default is [42, 0, 1234].", 50 | ) 51 | parser.add_argument( 52 | "--k_shot", 53 | type=int, 54 | help="Single value for few-shot learning. Overwrites --k_shots if provided.", 55 | ) 56 | parser.add_argument( 57 | "--k_shots", 58 | type=int, 59 | nargs="+", 60 | default=[1, 2, 4, 8], 61 | help="List of integers for few-shot learning samples.", 62 | ) 63 | parser.add_argument( 64 | "--dataset_path", 65 | type=Path, 66 | default="./datasets/mvtec_loco", 67 | help="Path to the MVTEC LOCO dataset.", 68 | ) 69 | 70 | args = parser.parse_args() 71 | 72 | if args.k_shot is not None: 73 | args.k_shots = [args.k_shot] 74 | 75 | return args 76 | 77 | 78 | def get_dataloaders(dataset_path: Path | str, category: str, batch_size: int) -> tuple[DataLoader, DataLoader]: 79 | """Get the MVTec LOCO dataloader. 80 | 81 | Args: 82 | dataset_path (Path | str): Path to the dataset. 83 | category (str): Category of the MVTec dataset. 84 | batch_size (int): Batch size for the dataloaders. 
85 | 86 | Returns: 87 | tuple[DataLoader, DataLoader]: Tuple of train and test dataloaders. 88 | """ 89 | # Create the dataset 90 | # NOTE: We fix the image size to (256, 256) for consistent evaluation across all models. 91 | train_dataset = MVTecLOCODataset( 92 | root=dataset_path, 93 | category=category, 94 | split=Split.TRAIN, 95 | augmentations=Resize((256, 256)), 96 | ) 97 | test_dataset = MVTecLOCODataset( 98 | root=dataset_path, 99 | category=category, 100 | split=Split.TEST, 101 | augmentations=Resize((256, 256)), 102 | ) 103 | train_dataloader = DataLoader( 104 | train_dataset, 105 | batch_size=batch_size, 106 | shuffle=False, 107 | num_workers=4, 108 | collate_fn=train_dataset.collate_fn, 109 | ) 110 | test_dataloader = DataLoader( 111 | test_dataset, 112 | batch_size=batch_size, 113 | shuffle=False, 114 | num_workers=4, 115 | collate_fn=test_dataset.collate_fn, 116 | ) 117 | 118 | return train_dataloader, test_dataloader 119 | 120 | 121 | def download(url: str) -> Path: 122 | """Download a file from a URL. 123 | 124 | Args: 125 | url (str): URL of the file to download. 126 | 127 | Returns: 128 | Path: Path to the downloaded file. 129 | """ 130 | root = Path(gettempdir()) 131 | downloaded_file_path = root / url.split("/")[-1] 132 | if not downloaded_file_path.exists(): # Check if file already exists 133 | if url.startswith("http://") or url.startswith("https://"): 134 | with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]) as progress_bar: 135 | urlretrieve( # noqa: S310 # nosec B310 136 | url=f"{url}", 137 | filename=downloaded_file_path, 138 | reporthook=progress_bar.update_to, 139 | ) 140 | else: 141 | message = f"URL {url} is not valid. Please check the URL." 142 | raise ValueError(message) 143 | return downloaded_file_path 144 | 145 | 146 | def get_model( 147 | category: str, 148 | ) -> nn.Module: 149 | """Instantiate and potentially load weights for the model. 150 | 151 | Args: 152 | category (str): Category of the dataset. Used for potentially loading category-specific weights. 153 | 154 | Returns: 155 | nn.Module: Loaded model moved to the specified device. 156 | """ 157 | model = Model() 158 | model.eval() 159 | 160 | weights_url = model.weights_url(category) 161 | if weights_url is not None: 162 | weights_path = download(weights_url) 163 | model.load_state_dict(torch.load(weights_path, map_location=DEVICE)) 164 | 165 | return model.to(DEVICE) 166 | 167 | @auto_batch_size(max_batch_size=128) 168 | def compute_kshot_metrics( 169 | dataset_path, 170 | category, 171 | model: nn.Module, 172 | k_shot: int, 173 | seed: int, 174 | batch_size: int | None = None, 175 | ) -> dict[str, float]: 176 | """Compute metrics for a specific k-shot setting. 177 | 178 | Args: 179 | train_dataset (MVTecLOCODataset): Training dataset (used for sampling few-shot images). 180 | test_dataloader (DataLoader): Test dataloader for the category. 181 | model (nn.Module): The model instance (already on the correct device). 182 | k_shot (int): The number of few-shot samples (k). 183 | seed (int): Seed value for reproducibility of few-shot sampling. 184 | 185 | Returns: 186 | dict[str, float]: Computed metrics for this k-shot setting. 
187 | """ 188 | train_dataloader, test_dataloader = get_dataloaders(dataset_path, category, batch_size=batch_size) 189 | train_dataset = cast(MVTecLOCODataset, train_dataloader.dataset) # Get underlying dataset) 190 | 191 | image_metric = _F1Max().to(DEVICE) 192 | 193 | # Sample k_shot images from the training set deterministically 194 | torch.manual_seed(seed) 195 | k_shot_idxs = torch.randperm(len(train_dataset))[:k_shot].tolist() 196 | 197 | # Pass few-shot images and dataset category to model's setup method 198 | few_shot_images = torch.stack([cast(ImageBatch, train_dataset[idx]).image for idx in k_shot_idxs]).to(DEVICE) 199 | setup_data = { 200 | "few_shot_samples": few_shot_images, 201 | "dataset_category": train_dataset.category, 202 | } 203 | model.setup(setup_data) 204 | 205 | # Inference loop 206 | model.eval() # Ensure model is in eval mode 207 | with torch.no_grad(): # Disable gradient calculations for inference 208 | for data in tqdm(test_dataloader, desc=f"k={k_shot} Inference", leave=False): 209 | output = model(data.image.to(DEVICE)) 210 | image_metric.update(output.pred_score, data.gt_label.to(DEVICE)) 211 | 212 | # Compute final metrics 213 | k_shot_metrics = {"image_score": image_metric.compute().item()} 214 | 215 | return k_shot_metrics 216 | 217 | 218 | def compute_average_metrics( 219 | metrics: list[dict[str, int | float]] | pd.DataFrame, 220 | ) -> dict[str, float]: 221 | """Compute the average metrics across all seeds and categories. 222 | 223 | Args: 224 | metrics (list[dict[str, int | float]] | pd.DataFrame): Collected metrics. 225 | 226 | Returns: 227 | dict[str, float]: Average metrics across all seeds and categories. 228 | """ 229 | # Convert the metrics list to a pandas DataFrame 230 | if not isinstance(metrics, pd.DataFrame): 231 | df = pd.DataFrame(metrics) 232 | df.to_csv("metrics.csv") 233 | 234 | # Compute the average metrics for each seed and k_shot across categories 235 | average_seed_performance = df.groupby(["k_shot", "category"])[["image_score"]].mean().reset_index() 236 | 237 | # Calculate the mean image and pixel performance for each k-shot 238 | k_shot_performance = average_seed_performance.groupby("k_shot")[["image_score"]].mean().reset_index() 239 | 240 | # Extract the k-shot numbers and their corresponding average image scores 241 | k_shot_numbers = k_shot_performance["k_shot"] 242 | average_image_scores = k_shot_performance["image_score"] 243 | 244 | # Calculate the area under the F1-max curve (AUFC) 245 | aufc = auc(k_shot_numbers, average_image_scores) 246 | 247 | # Get the normalized aufc score 248 | normalized_k_shot_numbers = (k_shot_numbers - k_shot_numbers.min()) / (k_shot_numbers.max() - k_shot_numbers.min()) 249 | normalized_aufc = auc(normalized_k_shot_numbers, average_image_scores) 250 | 251 | # Directly calculate the average image score across all k-shot performances 252 | avg_image_score = k_shot_performance["image_score"].mean() 253 | 254 | # Output the final average metrics and AUFC 255 | final_avg_metrics = { 256 | "aufc": aufc, 257 | "normalized_aufc": normalized_aufc, 258 | "avg_image_score": avg_image_score, 259 | } 260 | 261 | return final_avg_metrics 262 | 263 | 264 | def evaluate_submission( 265 | seeds: list[int], 266 | k_shots: list[int], 267 | dataset_path: Path | str, 268 | ) -> dict[str, float]: 269 | """Run the full evaluation across seeds, categories, and k-shots. 270 | 271 | Args: 272 | seeds (list[int]): List of seed values. 273 | k_shots (list[int]): List of k-shot values. 
274 | dataset_path (Path | str): Path to the dataset. 275 | 276 | Returns: 277 | dict[str, float]: Final averaged metrics. 278 | """ 279 | 280 | metrics = [] 281 | print(f"Using device: {DEVICE}") 282 | 283 | # create dictso that we only compute batch size once per k_shot 284 | batch_size_dict = {k_shot: None for k_shot in k_shots} 285 | 286 | for category in tqdm(CATEGORIES, desc="Processing Categories"): 287 | # --- Per-Category Setup --- 288 | # Load model once per category 289 | model = get_model(category) 290 | for seed in tqdm(seeds, desc=f"Category {category} Seeds", leave=False): 291 | for k_shot in k_shots: # No tqdm here, handled in compute_kshot_metrics 292 | # Compute metrics for this specific seed/category/k-shot combination 293 | k_shot_metrics, batch_size = compute_kshot_metrics( 294 | dataset_path, 295 | category, 296 | model=model, 297 | k_shot=k_shot, 298 | seed=seed, 299 | batch_size=batch_size_dict[k_shot], 300 | ) 301 | 302 | # Append results 303 | metrics.append( 304 | { 305 | "seed": seed, 306 | "k_shot": k_shot, 307 | "category": category, 308 | "image_score": k_shot_metrics["image_score"], 309 | } 310 | ) 311 | 312 | # update batch size dict 313 | batch_size_dict[k_shot] = batch_size 314 | 315 | final_average_metrics = compute_average_metrics(metrics) 316 | print("Final Average Metrics Across All Seeds:", final_average_metrics) 317 | 318 | return final_average_metrics 319 | 320 | 321 | def eval(): 322 | args = parse_args() 323 | result = evaluate_submission( 324 | seeds=args.seeds, 325 | k_shots=args.k_shots, 326 | dataset_path=args.dataset_path, 327 | ) 328 | with open("results.json", "w") as f: 329 | json.dump(result, f, indent=2) 330 | 331 | 332 | if __name__ == "__main__": 333 | eval() 334 | -------------------------------------------------------------------------------- /src/eval/submission/README.md: -------------------------------------------------------------------------------- 1 | # VAND 2025 Challenge Submission Template 2 | 3 | This directory contains the template for participants to implement their solution. 4 | 5 | ## To Create a Submission 6 | 7 | 1. Fork the repository 8 | 2. Modify the `model.py` file to implement your solution. You can introduce any other files/directories in this directory. 9 | 3. Make a pull request to the repository. 10 | 4. If successful, the evaluation pipeline will run and you will be able to see the results on the leaderboard, and a comment will be added to the pull request with the results. 11 | 12 | ## Evaluation 13 | 14 | The evaluation pipeline is defined in `src/eval/evaluate.py`. It is used to evaluate the performance of the submitted model. DO NOT modify the evaluation pipeline. 15 | 16 | The repository uses [uv](https://docs.astral.sh/uv/getting-started/installation/) to manage dependencies, and running the evaluation script. 17 | 18 | 1. **Install `uv` (if you haven't already):** 19 | 20 | ```bash 21 | curl -LsSf https://astral.sh/uv/install.sh | sh 22 | source $HOME/.cargo/env # Or restart your shell 23 | ``` 24 | 25 | 2. **Set up the environment and run evaluation:** 26 | 27 | To test your submission locally, run the following command from the repository root: 28 | 29 | ```bash 30 | uv run eval --dataset_path /path/to/your/mvtec_loco 31 | ``` 32 | 33 | This will automatically create a virtual environment and install the dependencies. 
34 | 35 | If the above command fails, you can try the following steps manually: 36 | 37 | ```bash 38 | # Install dependencies (this creates a .venv folder) 39 | uv sync 40 | # Activate the virtual environment 41 | source .venv/bin/activate 42 | # Run the evaluation script (replace with the actual path to your dataset) 43 | uv run eval --dataset_path /path/to/your/mvtec_loco 44 | ``` 45 | 46 | 3. **Check the results:** The script will generate `metrics.csv` (detailed scores) and `results.json` (final aggregated metrics). 47 | -------------------------------------------------------------------------------- /src/eval/submission/__init__.py: -------------------------------------------------------------------------------- 1 | """VAND 2025 Challenge Submission Template.""" 2 | -------------------------------------------------------------------------------- /src/eval/submission/model.py: -------------------------------------------------------------------------------- 1 | """Model for submission.""" 2 | 3 | import torch 4 | from anomalib.data import ImageBatch 5 | from torch import nn 6 | 7 | 8 | class Model(nn.Module): 9 | """TODO: Implement your model here""" 10 | 11 | def setup(self, setup_data: dict[str, torch.Tensor]) -> None: 12 | """Setup the model. 13 | 14 | Optional: Use this to pass few-shot images and dataset category to the model. 15 | 16 | Args: 17 | setup_data (dict[str, torch.Tensor]): The setup data. 18 | """ 19 | pass 20 | 21 | def weights_url(self, category: str) -> str | None: 22 | """URL to the model weights. 23 | 24 | You can optionally use the category to download specific weights for each category. 25 | """ 26 | # TODO: Implement this if you want to download the weights from a URL 27 | return None 28 | 29 | def forward(self, image: torch.Tensor) -> ImageBatch: 30 | """Forward pass of the model. 31 | 32 | Args: 33 | image (torch.Tensor): The input image. 34 | 35 | Returns: 36 | ImageBatch: The output image batch. 37 | """ 38 | # TODO: Implement the forward pass of the model. 39 | batch_size = image.shape[0] 40 | return ImageBatch( 41 | image=image, 42 | pred_score=torch.zeros(batch_size, device=image.device), 43 | ) 44 | -------------------------------------------------------------------------------- /src/eval/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | 4 | def ispow2(n): 5 | return n > 0 and (n & (n - 1)) == 0 6 | 7 | def auto_batch_size(max_batch_size=128, min_batch_size=1): 8 | """ 9 | Decorator that retries the function with decreasing powers of two as batch size 10 | on OOM errors. Validates that starting_batch_size is a power of two. 
11 | """ 12 | if not ispow2(max_batch_size): 13 | raise ValueError(f"starting_batch_size must be a power of 2, got {max_batch_size}") 14 | 15 | def decorator(fn): 16 | @functools.wraps(fn) 17 | def wrapper(*args, **kwargs): 18 | if kwargs.get('batch_size') is not None: 19 | # when passed explicitly, we don't auto-configure the batch size 20 | print("\nUsing pre-specified batch size:", kwargs['batch_size']) 21 | return fn(*args, **kwargs), kwargs['batch_size'] 22 | batch_size = max_batch_size 23 | while batch_size >= min_batch_size: 24 | try: 25 | print(f"\nTrying with batch size {batch_size}") 26 | torch.cuda.empty_cache() 27 | kwargs["batch_size"] = batch_size 28 | result = fn(*args, **kwargs) 29 | print(f"\nSuccess with batch size: {batch_size}") 30 | return result, batch_size 31 | except RuntimeError as e: 32 | if 'out of memory' in str(e).lower(): 33 | print(f"\nOOM at batch size {batch_size}, trying smaller...") 34 | torch.cuda.empty_cache() 35 | batch_size //= 2 36 | else: 37 | raise e 38 | raise RuntimeError(f"\nAll batch sizes down to {min_batch_size} caused OOM.") 39 | return wrapper 40 | return decorator 41 | -------------------------------------------------------------------------------- /zizmor.yml: -------------------------------------------------------------------------------- 1 | rules: 2 | dangerous-triggers: 3 | ignore: 4 | - evaluate.yaml:3 5 | template-injection: 6 | ignore: 7 | - evaluate.yaml:32 8 | --------------------------------------------------------------------------------