├── .gitignore ├── .vscode └── settings.json ├── README.md ├── requirements.txt ├── setup.sh └── wdv3_timm.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### Python ### 49 | # Byte-compiled / optimized / DLL files 50 | __pycache__/ 51 | *.py[cod] 52 | *$py.class 53 | 54 | # C extensions 55 | *.so 56 | 57 | # Distribution / packaging 58 | .Python 59 | build/ 60 | develop-eggs/ 61 | dist/ 62 | downloads/ 63 | eggs/ 64 | .eggs/ 65 | lib/ 66 | lib64/ 67 | parts/ 68 | sdist/ 69 | var/ 70 | wheels/ 71 | share/python-wheels/ 72 | *.egg-info/ 73 | .installed.cfg 74 | *.egg 75 | MANIFEST 76 | 77 | # PyInstaller 78 | # Usually these files are written by a python script from a template 79 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
80 | *.manifest 81 | *.spec 82 | 83 | # Installer logs 84 | pip-log.txt 85 | pip-delete-this-directory.txt 86 | 87 | # Unit test / coverage reports 88 | htmlcov/ 89 | .tox/ 90 | .nox/ 91 | .coverage 92 | .coverage.* 93 | .cache 94 | nosetests.xml 95 | coverage.xml 96 | *.cover 97 | *.py,cover 98 | .hypothesis/ 99 | .pytest_cache/ 100 | cover/ 101 | 102 | # Translations 103 | *.mo 104 | *.pot 105 | 106 | # Django stuff: 107 | *.log 108 | local_settings.py 109 | db.sqlite3 110 | db.sqlite3-journal 111 | 112 | # Flask stuff: 113 | instance/ 114 | .webassets-cache 115 | 116 | # Scrapy stuff: 117 | .scrapy 118 | 119 | # Sphinx documentation 120 | docs/_build/ 121 | 122 | # PyBuilder 123 | .pybuilder/ 124 | target/ 125 | 126 | # Jupyter Notebook 127 | .ipynb_checkpoints 128 | 129 | # IPython 130 | profile_default/ 131 | ipython_config.py 132 | 133 | # pyenv 134 | # For a library or package, you might want to ignore these files since the code is 135 | # intended to run in multiple environments; otherwise, check them in: 136 | # .python-version 137 | 138 | # pipenv 139 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 140 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 141 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 142 | # install all needed dependencies. 143 | #Pipfile.lock 144 | 145 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 146 | __pypackages__/ 147 | 148 | # Celery stuff 149 | celerybeat-schedule 150 | celerybeat.pid 151 | 152 | # SageMath parsed files 153 | *.sage.py 154 | 155 | # Environments 156 | .env 157 | .venv 158 | env/ 159 | venv/ 160 | ENV/ 161 | env.bak/ 162 | venv.bak/ 163 | 164 | # Spyder project settings 165 | .spyderproject 166 | .spyproject 167 | 168 | # Rope project settings 169 | .ropeproject 170 | 171 | # mkdocs documentation 172 | /site 173 | 174 | # mypy 175 | .mypy_cache/ 176 | .dmypy.json 177 | dmypy.json 178 | 179 | # Pyre type checker 180 | .pyre/ 181 | 182 | # pytype static type analyzer 183 | .pytype/ 184 | 185 | # Cython debug symbols 186 | cython_debug/ 187 | 188 | ### VisualStudioCode ### 189 | .vscode/* 190 | !.vscode/settings.json 191 | !.vscode/tasks.json 192 | !.vscode/launch.json 193 | !.vscode/extensions.json 194 | *.code-workspace 195 | 196 | # Local History for Visual Studio Code 197 | .history/ 198 | 199 | ### VisualStudioCode Patch ### 200 | # Ignore all local history of files 201 | .history 202 | .ionide 203 | 204 | ### Windows ### 205 | # Windows thumbnail cache files 206 | Thumbs.db 207 | Thumbs.db:encryptable 208 | ehthumbs.db 209 | ehthumbs_vista.db 210 | 211 | # Dump file 212 | *.stackdump 213 | 214 | # Folder config file 215 | [Dd]esktop.ini 216 | 217 | # Recycle Bin used on file shares 218 | $RECYCLE.BIN/ 219 | 220 | # Windows Installer files 221 | *.cab 222 | *.msi 223 | *.msix 224 | *.msm 225 | *.msp 226 | 227 | # Windows shortcuts 228 | *.lnk 229 | 230 | # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python 231 | 232 | # temp and misc 233 | /misc/ 234 | /temp/ 235 | 236 | # direnv 237 | .envrc 238 | .envrc.* 239 | 240 | # dotenv 241 | .env 242 | .env.* 243 | 244 | # temp files 245 | **/tmp_*.* 246 | **/*.tmp.* 247 | 248 | # but keep examples 249 | !*.example 250 | 251 | # input images and heatmap outputs 252 | /images/ 253 | /heatmaps/ 254 | 
-------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.insertSpaces": true, 3 | "editor.tabSize": 4, 4 | "files.trimTrailingWhitespace": true, 5 | "editor.rulers": [100, 120], 6 | 7 | "files.associations": { 8 | "*.yaml": "yaml" 9 | }, 10 | "files.exclude": { 11 | "**/.git": true, 12 | "**/.svn": true, 13 | "**/.hg": true, 14 | "**/CVS": true, 15 | "**/.DS_Store": true, 16 | "**/Thumbs.db": true, 17 | "**/.ruff_cache": true, 18 | "**/__pycache__": true, 19 | "**/*.egg-info": true 20 | }, 21 | 22 | "[shellscript]": { 23 | "files.eol": "\n", 24 | "editor.tabSize": 4, 25 | "editor.detectIndentation": false 26 | }, 27 | 28 | "[python]": { 29 | "editor.wordBasedSuggestions": "off", 30 | "editor.formatOnSave": true, 31 | "editor.defaultFormatter": "charliermarsh.ruff", 32 | "editor.codeActionsOnSave": { 33 | "source.organizeImports": "always" 34 | } 35 | }, 36 | "python.analysis.include": ["./src", "./scripts", "./tests"], 37 | 38 | "[json]": { 39 | "editor.defaultFormatter": "esbenp.prettier-vscode", 40 | "editor.detectIndentation": false, 41 | "editor.formatOnSaveMode": "file", 42 | "editor.formatOnSave": true, 43 | "editor.tabSize": 2 44 | }, 45 | "[jsonc]": { 46 | "editor.defaultFormatter": "esbenp.prettier-vscode", 47 | "editor.detectIndentation": false, 48 | "editor.formatOnSaveMode": "file", 49 | "editor.formatOnSave": true, 50 | "editor.tabSize": 2 51 | }, 52 | 53 | "[toml]": { 54 | "editor.tabSize": 2, 55 | "editor.detectIndentation": false, 56 | "editor.formatOnSave": true, 57 | "editor.formatOnSaveMode": "file", 58 | "editor.defaultFormatter": "tamasfe.even-better-toml", 59 | "editor.rulers": [80, 100] 60 | }, 61 | "evenBetterToml.formatter.columnWidth": 88, 62 | 63 | "[yaml]": { 64 | "editor.detectIndentation": false, 65 | "editor.tabSize": 2, 66 | "editor.formatOnSave": true, 67 | "editor.formatOnSaveMode": 
"file", 68 | "diffEditor.ignoreTrimWhitespace": false, 69 | "editor.defaultFormatter": "redhat.vscode-yaml" 70 | }, 71 | "yaml.format.bracketSpacing": true, 72 | "yaml.format.proseWrap": "preserve", 73 | "yaml.format.singleQuote": false, 74 | "yaml.format.printWidth": 110, 75 | 76 | "[hcl]": { 77 | "editor.detectIndentation": false, 78 | "editor.formatOnSave": true, 79 | "editor.formatOnSaveMode": "file", 80 | "editor.defaultFormatter": "fredwangwang.vscode-hcl-format" 81 | }, 82 | 83 | "[markdown]": { 84 | "files.trimTrailingWhitespace": false 85 | }, 86 | 87 | "css.lint.validProperties": ["dock", "content-align", "content-justify"], 88 | "[css]": { 89 | "editor.formatOnSave": true 90 | }, 91 | 92 | "remote.autoForwardPorts": false, 93 | "remote.autoForwardPortsSource": "process" 94 | } 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wdv3-timm 2 | 3 | A small example showing how to use `timm` to run the WD Tagger V3 models. 4 | 5 | ## How To Use 6 | 7 | 1. Clone the repository and enter the directory: 8 | ```sh 9 | git clone https://github.com/neggles/wdv3-timm.git 10 | cd wdv3-timm 11 | ``` 12 | 13 | 2. Create a virtual environment and install the Python requirements. 14 | 15 | If you're using Linux, you can use the provided script: 16 | ```sh 17 | bash setup.sh 18 | ``` 19 | 20 | Or if you're on Windows (or just want to do it manually), you can do the following: 21 | ```sh 22 | # Create virtual environment 23 | python3.10 -m venv .venv 24 | # Activate it (on Windows: .venv\Scripts\activate) 25 | source .venv/bin/activate 26 | # Upgrade pip/setuptools/wheel 27 | python -m pip install -U pip setuptools wheel 28 | # At this point, optionally you can install PyTorch manually (e.g.
if you are not using an NVIDIA GPU) 29 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 30 | # Install requirements 31 | python -m pip install -r requirements.txt 32 | ``` 33 | 34 | 3. Run the example script, picking one of the 3 models (`vit`, `swinv2` or `convnext`; default `vit`) with `--model`: 35 | ```sh 36 | python wdv3_timm.py --model vit path/to/image.png 37 | ``` 38 | 39 | Example output from `python wdv3_timm.py --model vit a_picture_of_ganyu.png`: 40 | ```sh 41 | Loading model 'vit' from 'SmilingWolf/wd-vit-tagger-v3'... 42 | Loading tag list... 43 | Creating data transform... 44 | Loading image and preprocessing... 45 | Running inference... 46 | Processing results... 47 | -------- 48 | Caption: 1girl, horns, solo, bell, ahoge, colored_skin, blue_skin, neck_bell, looking_at_viewer, purple_eyes, upper_body, blonde_hair, long_hair, goat_horns, blue_hair, off_shoulder, sidelocks, bare_shoulders, alternate_costume, shirt, black_shirt, cowbell, ganyu_(genshin_impact) 49 | -------- 50 | Tags: 1girl, horns, solo, bell, ahoge, colored skin, blue skin, neck bell, looking at viewer, purple eyes, upper body, blonde hair, long hair, goat horns, blue hair, off shoulder, sidelocks, bare shoulders, alternate costume, shirt, black shirt, cowbell, ganyu \(genshin impact\) 51 | -------- 52 | Ratings: 53 | general: 0.827 54 | sensitive: 0.199 55 | questionable: 0.001 56 | explicit: 0.001 57 | -------- 58 | Character tags (threshold=0.75): 59 | ganyu_(genshin_impact): 0.991 60 | -------- 61 | General tags (threshold=0.35): 62 | 1girl: 0.996 63 | horns: 0.950 64 | solo: 0.947 65 | bell: 0.918 66 | ahoge: 0.897 67 | colored_skin: 0.881 68 | blue_skin: 0.872 69 | neck_bell: 0.854 70 | looking_at_viewer: 0.817 71 | purple_eyes: 0.734 72 | upper_body: 0.615 73 | blonde_hair: 0.609 74 | long_hair: 0.607 75 | goat_horns: 0.524 76 | blue_hair: 0.496 77 | off_shoulder: 0.472 78 | sidelocks: 0.470 79 | bare_shoulders: 0.464 80 | alternate_costume: 0.437 81 | shirt: 0.427 82 | black_shirt: 0.417 83 |
cowbell: 0.415 84 | ``` 85 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | huggingface-hub 3 | numpy 4 | pandas 5 | pillow >= 9.5.0 6 | simple-parsing >= 0.1.5 7 | timm @ git+https://github.com/huggingface/pytorch-image-models@main#egg=timm 8 | tokenizers 9 | torch >= 2.0.0 10 | torchvision 11 | transformers 12 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | 4 | # get the folder this script is in and make sure we're in it 5 | script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P) 6 | cd "${script_dir}" 7 | 8 | # make venv if not exist 9 | if [[ ! -d .venv ]]; then 10 | echo "Creating virtual environment..." 11 | python3.10 -m venv .venv 12 | fi 13 | 14 | # activate the venv 15 | source .venv/bin/activate 16 | 17 | # upgrade pip 18 | python -m pip install -U pip setuptools wheel 19 | 20 | # install requirements 21 | python -m pip install -r requirements.txt 22 | 23 | echo "Setup complete. Run 'source .venv/bin/activate' to enter the virtual environment." 
"""wdv3_timm.py: tag an image with one of the WD Tagger V3 models via ``timm``.

Example::

    python wdv3_timm.py --model vit path/to/image.png
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from simple_parsing import field, parse_known_args
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F

# Run on CUDA when available, otherwise on the CPU.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Short model names accepted by `--model`, mapped to their Hugging Face Hub repo IDs.
MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}

# Values of the "category" column in selected_tags.csv that load_labels_hf() splits on.
RATING_CATEGORY = 9
GENERAL_CATEGORY = 0
CHARACTER_CATEGORY = 4


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return `image` in RGB mode, flattening any transparency onto a white background.

    Images that carry an alpha band (e.g. LA, PA) or a `transparency` entry in
    `image.info` (e.g. palette images) are converted to RGBA and composited over
    opaque white; any other non-RGB mode is converted straight to RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        # LA/PA keep their alpha in a band rather than in info["transparency"],
        # so check both to avoid silently dropping transparency.
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    if image.mode == "RGBA":
        # Explicit alpha=255 so the background itself is fully opaque white.
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre `image` on a new white RGB square whose side is its longer dimension.

    Always returns a new image; the input is left untouched.
    """
    w, h = image.size
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    # Floor division: an odd leftover pixel of padding ends up on the right/bottom edge.
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas


@dataclass
class LabelData:
    """Tag names from selected_tags.csv plus the row indices belonging to each category."""

    names: list[str]  # tag name of every CSV row, in file order
    rating: list[np.int64]  # row indices whose category == RATING_CATEGORY
    general: list[np.int64]  # row indices whose category == GENERAL_CATEGORY
    character: list[np.int64]  # row indices whose category == CHARACTER_CATEGORY


def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download `selected_tags.csv` from a Hub repo and index its rows by category.

    Args:
        repo_id: Hugging Face Hub repository that contains `selected_tags.csv`.
        revision: Optional revision to download from (passed to `hf_hub_download`).
        token: Optional Hub access token (passed to `hf_hub_download`).

    Returns:
        A LabelData whose index lists point into `names`.

    Raises:
        FileNotFoundError: if the download fails with an `HfHubHTTPError`.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
    csv_path = Path(csv_path).resolve()

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    category = df["category"]
    return LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(category == RATING_CATEGORY)[0]),
        general=list(np.where(category == GENERAL_CATEGORY)[0]),
        character=list(np.where(category == CHARACTER_CATEGORY)[0]),
    )


def _labels_above(scored: list, indices: list, threshold: float) -> dict:
    """Return {name: score} for `scored[i]` (i in `indices`) with score > threshold, highest first."""
    kept = {name: score for name, score in (scored[i] for i in indices) if score > threshold}
    return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))


def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag probabilities into a caption and per-category score dicts.

    Args:
        probs: Per-tag probabilities, assumed 1-D and in the same order as
            `labels.names` (`zip` silently truncates to the shorter of the two).
        labels: Tag names and category indices, as returned by `load_labels_hf`.
        gen_threshold: General tags must score strictly above this to be kept.
        char_threshold: Character tags must score strictly above this to be kept.

    Returns:
        A tuple `(caption, taglist, rating_labels, char_labels, gen_labels)`:
        - caption: kept general tags followed by kept character tags, ", "-joined;
        - taglist: caption with "_" replaced by " " and parentheses backslash-escaped
          (presumably for prompt syntaxes that treat parentheses specially);
        - rating_labels: {name: score} for every rating tag, unfiltered;
        - char_labels / gen_labels: {name: score} of kept tags, highest score first.
    """
    # detach()/cpu() so callers may also pass a CUDA tensor or one that tracks gradients.
    scored = list(zip(labels.names, probs.detach().cpu().numpy()))

    rating_labels = dict(scored[i] for i in labels.rating)
    gen_labels = _labels_above(scored, labels.general, gen_threshold)
    char_labels = _labels_above(scored, labels.character, char_threshold)

    # General tags first, then character tags; each group keeps its own score order.
    caption = ", ".join([*gen_labels, *char_labels])
    # Raw strings: "\(" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    return caption, taglist, rating_labels, char_labels, gen_labels


@dataclass
class ScriptOptions:
    """Command-line options parsed by simple_parsing."""

    image_file: Path = field(positional=True)  # path of the image to tag
    model: str = field(default="vit")  # model name, one of the MODEL_REPO_MAP keys
    gen_threshold: float = field(default=0.35)  # general tags must score above this
    char_threshold: float = field(default=0.75)  # character tags must score above this


def _print_scores(title: str, scores: dict) -> None:
    """Print a separator line, `title`, then one `name: score` line per entry."""
    print("--------")
    print(title)
    for name, score in scores.items():
        print(f" {name}: {score:.3f}")


def main(opts: ScriptOptions):
    """Load the chosen tagger, run it on `opts.image_file` and print the results.

    Raises:
        ValueError: if `opts.model` is not a key of MODEL_REPO_MAP.
        FileNotFoundError: if the image does not exist or the tag list fails to download.
    """
    repo_id = MODEL_REPO_MAP.get(opts.model)
    if repo_id is None:
        # Validated here as well as under __main__ so main() is safe to call directly
        # instead of failing later on `"hf-hub:" + None`.
        raise ValueError(
            f"Unknown model name '{opts.model}', available models: {list(MODEL_REPO_MAP.keys())}"
        )
    image_path = Path(opts.image_file).resolve()
    if not image_path.is_file():
        raise FileNotFoundError(f"Image file not found: {image_path}")

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    print("Loading image and preprocessing...")
    # `with` closes the file handle held by Image.open(); pil_pad_square always builds a
    # new image, so img_input stays usable after the source image is closed.
    with Image.open(image_path) as source_img:
        img_input = pil_pad_square(pil_ensure_rgb(source_img))
    # run the model's input transform to convert to tensor and rescale, then add a batch dim
    inputs: Tensor = transform(img_input).unsqueeze(0)
    # NCHW image RGB to BGR
    inputs = inputs[:, [2, 1, 0]]

    print("Running inference...")
    with torch.inference_mode():
        # move model and input to the GPU, if available
        if torch_device.type != "cpu":
            model = model.to(torch_device)
            inputs = inputs.to(torch_device)
        # call the module rather than .forward() so any registered hooks run
        outputs = model(inputs)
        # apply the final activation function (timm doesn't support doing this internally)
        outputs = F.sigmoid(outputs)
    # bring the scores back to the CPU for numpy conversion (no-op when already there)
    outputs = outputs.to("cpu")

    print("Processing results...")
    caption, taglist, ratings, character, general = get_tags(
        probs=outputs.squeeze(0),
        labels=labels,
        gen_threshold=opts.gen_threshold,
        char_threshold=opts.char_threshold,
    )

    print("--------")
    print(f"Caption: {caption}")
    print("--------")
    print(f"Tags: {taglist}")

    _print_scores("Ratings:", ratings)
    _print_scores(f"Character tags (threshold={opts.char_threshold}):", character)
    _print_scores(f"General tags (threshold={opts.gen_threshold}):", general)

    print("Done!")


if __name__ == "__main__":
    opts, _ = parse_known_args(ScriptOptions)
    if opts.model not in MODEL_REPO_MAP:
        print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
        raise ValueError(f"Unknown model name '{opts.model}'")
    main(opts)