├── .github
│   └── workflows
│       └── pylint.yml
├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── README.ko.md
├── README.md
├── docs
│   ├── model-comparison.md
│   ├── screenshot.png
│   └── what-is-wd14-tagger.md
├── install.py
├── javascript
│   └── tagger.js
├── json_schema
│   └── db_json_v1_schema.json
├── preload.py
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── tagger.py
├── shell_scripts
│   ├── compare_weighted_frequencies.py
│   ├── create_safetensors_db.sh
│   ├── model_grep.sh
│   └── tag_based_image_dedup.sh
├── style.css
└── tagger
    ├── api.py
    ├── api_models.py
    ├── dbimutils.py
    ├── format.py
    ├── generator
    │   └── tf_data_reader.py
    ├── interrogator.py
    ├── preset.py
    ├── settings.py
    ├── ui.py
    ├── uiset.py
    └── utils.py

/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint
      - name: Analysing the code with pylint
        run: |
          pylint $(git ls-files '*.py')
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.vscode/
.venv/
.env

presets/
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# v1.2.0 (2023-09-16)

API changes:
Image interrogation via the API accepts two extra parameters, both empty
strings by default. `queue`: the name of a queue, which could be e.g. the
person or subject name. You can leave it empty for the first interrogation;
the interrogation will then be queued under a new auto-generated unique name,
listed in the response. Make sure you use that same name as `queue` for all
interrogations that you want grouped together. The second parameter is
`name_in_queue`: the name for the particular image being queued, e.g. a file
name.

If both queue and name are empty, a single interrogation is performed and the
response includes the nested objects "ratings" and "tags", so:
`{"ratings": {"sensitive": 0.5, ..}, "tags": {"tag1": 0.5, ..}}`

If neither name nor queue is empty, the interrogation is queued under that
name. If that name is already in the queue, it is changed - clobbered - with a
`#<index>` suffix. An exception is if the given name is `<sha256>`, in which
case an image checksum is used instead of a name. Duplicates are ignored.

During queuing, the response is the number of processed interrogations across
all active queues.

If name_in_queue is empty but queue is not, that particular queue is
finalized: a response is awaited for any interrogations still remaining in
this queue. The response, for this queue only, is an object with each
name_in_queue as key and the tags with weights as value. Ratings have their
tag name prefixed with "rating:". Example:
`{"name_in_queue": {"rating:sensitive": 0.5, "tag1": 0.5, ..}}`
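For illustration, a minimal client sketch of this queued flow. This is hypothetical: the `/tagger/v1` route prefix, port, and model name are assumptions, and `requests` is not among this extension's requirements; the payload fields mirror `TaggerInterrogateRequest` as used in `tagger/api.py`.

```python
"""Sketch: two queued interrogations grouped under one queue, then a finalize call."""
import base64
import requests  # assumed to be available; not a requirement of this extension

URL = "http://127.0.0.1:7860/tagger/v1/interrogate"  # prefix and port are assumptions


def encode(path: str) -> str:
    """Base64-encode an image file for the JSON payload."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


common = {"model": "wd14-vit-v2", "threshold": 0.35, "queue": "my-batch"}  # example model name

# queue two images under their file names; each response is a progress count
for name in ("a.png", "b.png"):
    r = requests.post(URL, json={**common, "image": encode(name), "name_in_queue": name})
    print(r.json())

# an empty name_in_queue finalizes the queue and returns
# {"a.png": {"tag1": 0.5, "rating:sensitive": 0.5, ...}, "b.png": {...}}
r = requests.post(URL, json={**common, "image": encode("a.png"), "name_in_queue": ""})
print(r.json())
```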
Fix operation in the absence of tensorflow_io
Fix deprecation warning
Added three scripts under shell_scripts:

* A bash script to generate, per safetensors file, the fraction of images
  that the model was trained on that were tagged with particular tokens.
* A python script to compare the interrogation results (read from db.json)
  and find the top -c safetensors files that contain similar weights (or at
  least, that was the intention; there may be better algorithms to compare,
  but it seems to do the job).
* And finally a model_grep script which lists the tags and number of trained
  images in a safetensors model.

# v1.1.2 c9f8efd (2023-08-26)

Explain recursive path usage better in ui
Fix sending tags via buttons to txt2img and img2img
type additions, inadvertently pushed, later retouched
allow setting gpu device via flag
Fix inverted cumulative checkbox
wrap_gradio_gpu_call fallback
Fix for preload shared access
preload update
A few ui changes
Fix not clearing the tags after writing them to files
Fix: tags were still added beyond the count threshold
fix search/replace bug
(here int based weights were reverted)
circumvent when unable to load tensorflow
fix for too many exclude_tags
add db.json validation schema, add schema validation
return fix for fastapi
pick up huggingface cache dir from env, with default, configurable also via settings
leave tensorflow requirements to the user
Fix for reappearance of gradio bug: duplicate image edit
(index based weights, but later reverted)
Instead of cache_dir use local_dir, leav


# v1.1.1 eada050 (2023-07-20)

Internal cleanup, no separate interrogation for inverse
Fix issues with search and sending selection to keep/exclude
Fix issue #14, picking up last edit box changes
Fix 2 issues reported by guansss
fix huggingface reload issues. Thanks to Atoli and coder168 for reporting
experimental tensorflow unloading, but after some discussion, maybe conversion to onnx can solve this. See #17, thanks again Sean Wang.
add gallery tab, rudimentary
fix some hf download issues
fixes for fastapi
added ML-Danbooru support, thanks to [CCRcmcpe](https://github.com/CCRcmcpe)


# v1.1.0 87706b7 (2023-07-16)

fix: failed to install onnxruntime package on MacOS, thanks to heady713
fastapi: remote unload model, picked up from [here](https://github.com/toriato/stable-diffusion-webui-wd14-tagger/pull/109)
attribute error fix from aria1th, also reported by yjunej
re-allowed weighted tags files, now configured in settings -> tagger
wzgrx pointed out there were some modules not installed by default, so I've added a requirements.txt file that auto-installs required dependencies. However, the initial requirements.txt had issues. To create the requirements.txt I ran:
```
pipreqs --force `pwd`
sed -i s/==.*$//g requirements.txt
```
but it ended up adding external modules that were shadowing webui modules. If you have installed those, you may find you are not even able to start the webui until you remove them.
Change to the directory of my extension and run:
```
pip uninstall webui
pip uninstall modules
pip uninstall launch
```
In particular, installing a module named modules was a serious problem. Python should flag that name as illegal.

There were some interrogators that did not work unless you had installed them manually. Now they are only listed if you have them.

Thanks to wzgrx for testing and reporting these last two issues.
changed internal file structure, thanks to idiotcomerce #4
more regex usage in search and exclusion tags
fixed a bug where some exclusion tags were not reflected in the tags file
changed internal error handling; it is a bit quirky, which I still intend to fix.
If you find it keeps complaining about an input field without reason, just try editing that one again (e.g. add a space there and remove it).


# v1.0.0 a1b59d6 (2023-07-10)

You may have to remove presets/default.json and save a new one with your desired defaults. Otherwise checkboxes may not have the right default values.

General changes:

Weights, when enabled, are not printed in the tags list. Weights are already displayed as bars in the list below, so they do not add information, only obfuscate the list IMO.
There is a settings entry for the tagger; several options have been moved there.
The list of tag weights stops at a number specified on the settings tab (the slider).
There are both included and excluded tags tabs.
Tags in the tags list on top are clickable.
Tags below are also clickable. There is a difference between clicking on the dotted line and on the actual word: a click on the word will add it to a search/kept tag (dependent on which was last active); a click on the dotted line will add it to the input box next to it.
Interrogations can be combined (checkbox), also for a single image.
Make the labels listed clickable again; a click will add the label to the selected listbox. This also functions when you are on the discarded tags tab.
Added search and replace input lists.
Changed behavior: when clicking on the dotted line, the tag is inserted in the exclude/replace input list; otherwise the tag is inserted in the additional/search input list.
Added a Minimum fraction for tags slider. This filters tags based on the fraction of images and interrogations per image that has this tag at the selected weight threshold. I find this kind of filtering makes more sense than limiting the tags list to a number, though that is ok to prevent cluttering up the view.

Added a string search selected tags input field (top right) and two buttons:
Move visible tags to keep tags
Move visible tags to exclude tags

For batch processing:
After each update a db.json is written in the images folder. The db contains the weights for queries; a rerun of the same images using an interrogator just rereads this db.json. This also works after a stable diffusion reload or a reboot, as long as this db.json is there.
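For illustration, the rough shape of such a db.json (the layout follows json_schema/db_json_v1_schema.json further down this listing; the checksums are placeholders and all values are made up). Query keys are the image's sha256 hex digest followed by the interrogator id, and each stored weight encodes query index plus fractional weight, as the `ceil(stored) - 1` decoding in shell_scripts/compare_weighted_frequencies.py suggests:

```json
{
  "rating": {
    "general": [0.25, 1.75],
    "sensitive": [0.75]
  },
  "tag": {
    "1girl": [0.98, 1.87],
    "solo": [0.95]
  },
  "query": {
    "<64-hex-sha256-of-image-1>wd14-vit-v2": ["a.png", 0],
    "<64-hex-sha256-of-image-2>wd14-vit-v2": ["b.png", 1]
  }
}
```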
There is a huge batch implementation, but I was unable to test it, not having the right tensorflow version. EXPERIMENTAL. It is only enabled if you have the right tf version, but it's likely buggy due to my lack of testing. Feel free to send me a patch if you can improve it; also see here.
Pre- or appending weights to weighted tag files, i.e. with weights enabled, will instead have the weights averaged.

After batch processing, the combined tag count average is listed for all processed files, as well as the corrected average when combining the weighted tags. This is not limited by the tag_count_threshold, as it relates to the weights of all tag files. Conversely, the already existing threshold slider does affect this list length.
The search tag can be a single regex, or as many comma-separated strings as there are replacements. Currently only a single regex, or multiple plain strings in search and replace, are allowed, but this is going to change in the near future, to allow all regexes and back-referencing per replacement as in a re.sub() (see the sketch below).
added a 'verbose' setting
a comma was previously missing when appending tags
several of the interrogators have been fixed
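The target behavior is that of Python's `re.sub()` with group back-references; an illustrative sketch:

```python
import re

# one search regex with a capture group; the replacement
# back-references the captured group with \1
print(re.sub(r"(red|long) hair", r"\1_hair", "red hair, long hair"))
# -> red_hair, long_hair
```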
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
Thanks for taking the time to contribute to this project.

The following is a set of guidelines for contributing to this project. These are just guidelines, not rules; use your best judgment, and this document is also subject to change.

Table of Contents
=================
1. Contribution Workflow
   * Styleguides
   * Git Commit Messages
   * Styleguides, general notes
   * JavaScript Styleguide
   * Python Styleguide
   * Documentation Styleguide
2. License
3. Questions

# Contribution Workflow
* Fork the repo and create your branch from master.
* If you've added code that should be tested, add tests.
* If you've changed APIs, update the documentation.
* Ensure the test suite passes.
* Make sure your code lints.
* Issue that pull request!

# Styleguides
## Git Commit Messages
* Use the present tense ("Add feature" not "Added feature")
* Use the imperative mood ("Move cursor to..." not "Moves cursor to...")
* Limit the first line to 72 characters or less
* Reference issues and pull requests liberally after the first line
* When only changing documentation, include [ci skip] in the commit title
* Consider starting the commit message with an applicable emoji.
* A sign-off is not required, but encouraged, using the -s flag. Example: git commit -s -m "Adding a new feature"

Example commit message:
```
:rocket: Adds `launch()` method

The launch method accepts a single argument for the speed of the launch.
This method is necessary to get to the moon and fixes #76.
This commit closes issue #34

Signed-off-by: Jane Doe
```

## Styleguides, general notes
The current code does not follow the below proposed styleguides everywhere. Please try to follow the styleguides as much as possible, but if you see something that does not follow them, please do not change it. Commits should be atomic and only change one thing, and changing the style obfuscates the changes. The same goes for whitespace changes.

* If you change current code, please do use the styleguides, even if the code around it does not follow them.
* If you do not adhere to the styleguides, that is ok as well, but please make sure your code is readable and easy to understand.


## JavaScript Styleguide
All JavaScript must adhere to [JavaScript Standard Style](https://standardjs.com/).
[![JavaScript Style Guide](https://cdn.rawgit.com/standard/standard/master/badge.svg)](https://standardjs.com/)

## Python Styleguide
Try to adhere to [PEP 8](https://www.python.org/dev/peps/pep-0008/). It is not required, but it is recommended.

## Documentation Styleguide
Use [JSDoc](http://usejsdoc.org/) syntax to document code.
Use [GitHub-flavored Markdown](https://guides.github.com/features/mastering-markdown/) syntax to format documentation.

Thank you for your interest in contributing to this project!

# License
Largely public domain; I think tagger/dbimutils.py was [MIT](https://choosealicense.com/licenses/mit/)

# Questions
If you have any questions about the repo, open an issue or contact me directly at [email](mailto:pi.co.0o.byte@gmail.com).
--------------------------------------------------------------------------------
/README.ko.md:
--------------------------------------------------------------------------------
[Automatic1111 웹UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)를 위한 태깅(라벨링) 확장 기능
---
DeepDanbooru 와 같은 모델을 통해 단일 또는 여러 이미지로부터 부루에서 사용하는 태그를 알아냅니다.

[You don't know how to read Korean? Read it in English here!](README.md)

## 들어가기 앞서
모델과 대부분의 코드는 제가 만들지 않았고 [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) 와 MrSmilingWolf 의 태거에서 가져왔습니다.

## 설치하기
1. *확장기능* -> *URL로부터 확장기능 설치* -> 이 레포지토리 주소 입력 -> *설치*
   - 또는 이 레포지토리를 `extensions/` 디렉터리 내에 클론합니다.
   ```sh
   $ git clone https://github.com/picobyte/stable-diffusion-webui-wd14-tagger.git extensions/tagger
   ```

1. 모델 추가하기
   - #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)*
     처음 실행할 때 [HuggingFace 레포지토리](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger)로부터 자동으로 받아옵니다.

     모델과 관련된 또는 추가 학습에 대한 질문은 원작자인 MrSmilingWolf#5991 으로 물어봐주세요.

   - #### *DeepDanbooru*
     1. 다양한 모델 파일은 아래 주소에서 찾을 수 있습니다.
        - [DeepDanbooru model](https://github.com/KichangKim/DeepDanbooru/releases)
        - [e621 model by 🐾Zack🐾#1984](https://discord.gg/BDFpq9Yb7K)
          *(NSFW 주의!)*

     1. 모델과 설정 파일이 포함된 프로젝트 폴더를 `models/deepdanbooru` 경로로 옮깁니다.

     1. 파일 구조는 다음과 같습니다:
        ```
        models/
        └╴deepdanbooru/
          ├╴deepdanbooru-v3-20211112-sgd-e28/
          │ ├╴project.json
          │ └╴...
          │
          ├╴deepdanbooru-v4-20200814-sgd-e30/
          │ ├╴project.json
          │ └╴...
          │
          ├╴e621-v3-20221117-sgd-e32/
          │ ├╴project.json
          │ └╴...
          │
          ...
        ```

1. 웹UI 를 시작하거나 재시작합니다.
   - 또는 *Interrogator* 드롭다운 상자 우측에 있는 새로고침 버튼을 누릅니다.


## 스크린샷
![Screenshot](docs/screenshot.png)

Artwork made by [hecattaart](https://vk.com/hecattaart?w=wall-89063929_3767)

## 저작권

빌려온 코드(예: `dbimutils.py`)를 제외하고 모두 Public domain
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Fork changes
- v3 taggers added
- Z3D-E621-Convnext tagger added
- Onnxruntime dependency version required by the v3 taggers

Tagger for [Automatic1111's WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
---
Interrogate booru style tags for single or multiple image files using various models, such as DeepDanbooru.
[한국어를 사용하시나요? 여기에 한국어 설명서가 있습니다!](README.ko.md)

## Disclaimer
I didn't make any of the models, and most of the code was heavily borrowed from [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and MrSmilingWolf's tagger.

## Installation
1. *Extensions* -> *Install from URL* -> Enter URL of this repository -> Press *Install* button
   - or clone this repository under `extensions/`
   ```sh
   $ git clone https://github.com/picobyte/stable-diffusion-webui-wd14-tagger.git extensions/tagger
   ```

1. *(optional)* Add an interrogate model
   - #### [*Waifu Diffusion 1.4 Tagger by MrSmilingWolf*](docs/what-is-wd14-tagger.md)
     Downloads automatically from the [HuggingFace repository](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) the first time you run it.

   - #### *DeepDanbooru*
     1. Various model files can be found below.
        - [DeepDanbooru models](https://github.com/KichangKim/DeepDanbooru/releases)
        - [e621 model by 🐾Zack🐾#1984](https://discord.gg/BDFpq9Yb7K)
          *(link contains NSFW contents!)*

     1. Move the project folder containing the model and config to `models/deepdanbooru`

     1. The file structure should look like:
        ```
        models/
        └╴deepdanbooru/
          ├╴deepdanbooru-v3-20211112-sgd-e28/
          │ ├╴project.json
          │ └╴...
          │
          ├╴deepdanbooru-v4-20200814-sgd-e30/
          │ ├╴project.json
          │ └╴...
          │
          ├╴e621-v3-20221117-sgd-e32/
          │ ├╴project.json
          │ └╴...
          │
          ...
        ```

1. Start or restart the WebUI.
   - or press the refresh button next to the *Interrogator* dropdown box.
   - "You must close stable diffusion completely after installation and re-run it!"


## Model comparison
[Model comparison](docs/model-comparison.md)

## Screenshot
![Screenshot](docs/screenshot.png)

Artwork made by [hecattaart](https://vk.com/hecattaart?w=wall-89063929_3767)

## Copyright

Public domain, except borrowed parts (e.g. `dbimutils.py`)
--------------------------------------------------------------------------------
/docs/model-comparison.md:
--------------------------------------------------------------------------------
# Model comparison
---

* Used image: [hecattaart's artwork](https://vk.com/hecattaart?w=wall-89063929_3767)
* Threshold: `0.5`

## DeepDanbooru

### [`deepdanbooru-v3-20211112-sgd-e28`](https://github.com/KichangKim/DeepDanbooru/releases/tag/v3-20211112-sgd-e28)
> 1girl, animal ears, cat ears, cat tail, clothes writing, full body, rating:safe, shiba inu, shirt, shoes, simple background, sneakers, socks, solo, standing, t-shirt, tail, white background, white shirt

### [`deepdanbooru-v4-20200814-sgd-e30`](https://github.com/KichangKim/DeepDanbooru/releases/tag/v4-20200814-sgd-e30)
> 1girl, animal, animal ears, bottomless, clothes writing, full body, rating:safe, shirt, shoes, short sleeves, sneakers, solo, standing, t-shirt, tail, white background, white shirt

## e621

### `e621-v3-20221117-sgd-e32`
> anthro, bottomwear, clothing, footwear, fur, hi res, mammal, shirt, shoes, shorts, simple background, sneakers, socks, solo, standing, text on clothing, text on topwear, topwear, white background

### [`Z3D E621 Convnext`](https://huggingface.co/toynya/Z3D-E621-Convnext)
> mammal, solo, clothing, anthro, felid, footwear, feline, socks, topwear, shirt, domestic cat, felis, simple background, clothed, white background, shoes, piercing, ear piercing, hi res, fur

## ML-Danbooru

### [`ML-Danbooru Caformer dec-5-97527`](https://huggingface.co/deepghs/ml-danbooru-onnx)
> shirt, white background, shoes, simple background, solo, dog, socks, animal, white shirt, black footwear, no humans, animal focus, full body, t-shirt, star (symbol), star print, artist name, green legwear, sweat, leg hair, clothes writing, no pants, standing, short sleeves, sneakers, signature, looking at viewer, closed mouth, cat, walking, oversized clothes, dirty, dirty clothes, clothed animal, outline, print legwear, shiba inu, black eyes, bandaid, bottomless, print shirt, english text, oversized shirt, artist logo, legs apart, romaji text, chromatic aberration, earrings, 1boy, tail, sweatdrop, male focus, :3, furry, naked shirt, white outline, bare legs, blush, jewelry, supreme, looking away, bruise, legs, tears, ferret, 1girl, off shoulder, profanity, holding, plaid legwear, watermark, injury, dirty face, 1other, pug, scratches, long shirt

### [`ML-Danbooru TResNet-D 6-30000`](https://huggingface.co/deepghs/ml-danbooru-onnx)
> tail, animal ears, white background, solo, 1girl, cat, simple background, shoes, full body, clothes writing, shirt, t-shirt, standing, sneakers, furry, socks, animal, short sleeves, food print, cat ears, print shirt, white shirt, short hair, dog, bottomless, dog tail, artist name, sportswear, furry male, paw print, signature, black footwear, closed mouth, green legwear, no pants, dog ears, :3, bare legs, green ribbon, cat tail, nike, legs apart, black eyes, english text, furry female, looking at viewer, green footwear, underwear, oversized clothes, dated, animal print, star print, dress, fake animal ears, shadow, bangs, naked shirt, blush, cat print, clothed animal, holding, sweat, brown eyes, thighs, adidas, male focus, black hair, extra ears, green panties, pigeon-toed, cat girl, tears, smile, collar, white footwear, watermark, looking down, legs, twitter username, animalization, whiskers, alternate costume, animal focus, sweatdrop, you work you lose, green eyes, medium hair, oversized shirt, character name, shiba inu, shorts, 1boy, blue eyes, baseball uniform, looking at another, brown hair
## Waifu Diffusion Tagger

### [`WD ConvNeXT v1`](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger)
> solo, tail, shirt, shoes, white background, furry, simple background, full body, socks

### [`WD ConvNeXT v2`](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
> shirt, tail, solo, simple background, white background, shoes, full body, socks, clothes writing, white shirt, earrings, animal focus, sweat

### [`WD ConvNeXTV2 v1`](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
> shirt, white background, shoes, simple background, solo, socks, leg hair, black footwear, tail, full body, white shirt, cat, furry, star (symbol), black eyes, looking at viewer

### [`WD ConvNext v3`](https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3)
> shirt, solo, socks, tail, shoes, simple background, white background, full body, furry

### [`WD EVA02 v3 Large`](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)
> shirt, shoes, no humans, animal focus, walking, socks, full body, cat, simple background, black footwear, white background, meme, clothed animal, green socks, print shirt, tail, solo, white shirt, clothes writing, leg hair, furry, star (symbol)

### [`WD SwinV2 v1`](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2)
> leg hair, solo, shirt, 1boy, male focus, shoes, white background, tail, full body, simple background, socks, white shirt, furry, black footwear, dirty, arm hair, standing

### [`WD SwinV2 v3`](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3)
> leg hair, shirt, solo, shoes, socks, simple background, white background, tail, 1boy, white shirt, animal, full body, male focus, black footwear

### [`WD ViT v1`](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger)
> solo, 1boy, shirt, male focus, furry, socks, tail, dog, leg hair, shoes, white background, animal ears, simple background

### [`WD ViT v2`](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
> solo, shirt, 1boy, male focus, furry, tail, socks, animal ears, white background, shoes, simple background, cat

### [`WD ViT v3`](https://huggingface.co/SmilingWolf/wd-vit-tagger-v3)
> shirt, shoes, solo, simple background, white background, white shirt, socks, full body, print shirt, walking, cat, tail

### [`WD ViT v3 Large`](https://huggingface.co/SmilingWolf/wd-vit-large-tagger-v3)
> shirt, shoes, socks, simple background, full body, black footwear, white shirt, solo, tail, cat, green socks, white background, meme, walking, animal focus, furry, no humans, whiskers

### [`WD MOAT v2`](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
> shirt, solo, shoes, socks, simple background, full body, white background, 1boy, white shirt
--------------------------------------------------------------------------------
/docs/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/67372a/stable-diffusion-webui-wd14-tagger/6624e0803a0021732adf9a986ad5f6fa20082323/docs/screenshot.png
--------------------------------------------------------------------------------
/docs/what-is-wd14-tagger.md:
--------------------------------------------------------------------------------
What is the Waifu Diffusion 1.4 Tagger?
---

An image-to-text model created and maintained by [MrSmilingWolf](https://huggingface.co/SmilingWolf), which was used to train Waifu Diffusion.

Please ask the original author `MrSmilingWolf#5991` for questions related to the model or additional training.

## SwinV2 vs Convnext vs ViT
> It's got characters now, the HF space has been updated too. Model of choice for classification is SwinV2 now. ConvNext was used to extract features because SwinV2 is a bit of a pain cuz it is twice as slow and more memory intensive

— [this message](https://discord.com/channels/930499730843250783/930499731451428926/1066830289382408285) from the [東方Project AI discord server](https://discord.com/invite/touhouai)

> To make it clear: the ViT model is the one used to tag images for WD 1.4. That's why the repo was originally called like that. This one has been trained on the same data and tags, but has got no other relation to WD 1.4, aside from stemming from the same coordination effort. They were trained in parallel, and the best one at the time was selected for WD 1.4
>
> This particular model was trained later and might actually be slightly better than the ViT one. Difference is in the noise range tho

— [this thread](https://discord.com/channels/930499730843250783/1052283314997837955) from the [東方Project AI discord server](https://discord.com/invite/touhouai)

## Performance
> I stack them together and get a 1.1GB model with higher validation metrics than the three separated, so they each do their own thing and averaging the predictions sorta helps covering for each models failures. I suppose.
> As for my impression for each model:
> - SwinV2: a memory and GPU hog. Best metrics of the bunch, my model is compatible with timm weights (so it can be used on PyTorch if somebody ports it) but slooow. Good for a few predictions, would reconsider for massive tagging jobs if you're pressed for time
> - ConvNext: nice perfs, good metrics. A sweet spot. The 1024 final embedding size provides ample space for training the Dense layer on other datasets, like E621.
> - ViT: fastest of the bunch, at least on TPU, probably on GPU too? Slightly less then stellar metrics when compared with the other two. Onnxruntime and Tensorflow keep adding optimizations for Transformer models so that's good too.

— [this message](https://discord.com/channels/930499730843250783/930499731451428926/1066833768112996384) from the [東方Project AI discord server](https://discord.com/invite/touhouai)

## Links
- [MrSmilingWolf's HuggingFace profile](https://huggingface.co/SmilingWolf)
- [MrSmilingWolf's GitHub profile](https://github.com/SmilingWolf)
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
"""Install requirements for WD14-tagger."""
import os
import sys

from launch import run  # pylint: disable=import-error

NAME = "WD14-tagger"
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                        "requirements.txt")
print(f"loading {NAME} reqs from {req_file}")
run(f'"{sys.executable}" -m pip install -q -r "{req_file}"',
    f"Checking {NAME} requirements.",
    f"Couldn't install {NAME} requirements.")
--------------------------------------------------------------------------------
/javascript/tagger.js:
--------------------------------------------------------------------------------
/**
 * wait until the element is loaded, then resolve with it
 * @param {string} selector
 * @param {number} timeout
 * @param {Element} $rootElement
 * @returns {Promise}
 */
function waitQuerySelector(selector, timeout = 5000, $rootElement = gradioApp()) {
    return new Promise((resolve, reject) => {
        const element = $rootElement.querySelector(selector)
        // resolve immediately if the element is already present
        if (element) {
            return resolve(element)
        }

        let timeoutId

        const observer = new MutationObserver(() => {
            const element = $rootElement.querySelector(selector)
            if (!element) {
                return
            }

            if (timeoutId) {
                clearTimeout(timeoutId)
            }

            observer.disconnect()
            resolve(element)
        })

        timeoutId = setTimeout(() => {
            observer.disconnect()
            reject(new Error(`timeout, cannot find element by '${selector}'`))
        }, timeout)

        observer.observe($rootElement, {
            childList: true,
            subtree: true
        })
    })
}

function tag_clicked(tag, is_inverse) {
    // escape regex special characters
    const escapedTag = tag.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');

    // add the tag to the selected textarea
    let $selectedTextarea;
    if (is_inverse) {
        $selectedTextarea = document.getElementById('keep-tags');
    } else {
        $selectedTextarea = document.getElementById('exclude-tags');
    }
    const $textarea = $selectedTextarea.querySelector('textarea');
    let value = $textarea.value;
    // ignore if the tag already exists in the textbox
    const pattern = new RegExp(`(^|,)\\s{0,}${escapedTag}\\s{0,}($|,)`);
    if (pattern.test(value)) {
        return;
    }
    const emptyRegex = new RegExp(`^\\s*$`);
    if (!emptyRegex.test(value)) {
        value += ', ';
    }
    // besides setting the value, an event needs to be triggered or the value isn't actually stored
    $textarea.value = value + escapedTag;
    $textarea.dispatchEvent(new Event('input'));
    $textarea.dispatchEvent(new Event('blur'));
}

document.addEventListener('DOMContentLoaded', () => {
    Promise.all([
        // option texts
        waitQuerySelector('#keep-tags'),
        waitQuerySelector('#exclude-tags'),
        waitQuerySelector('#search-tags'),
        waitQuerySelector('#replace-tags'),

        // tag-confidence labels
        waitQuerySelector('#rating-confidences'),
        waitQuerySelector('#tag-confidences'),
        waitQuerySelector('#discard-tag-confidences')
    ]).then(elements => {

        const $keepTags = elements[0];
        const $excludeTags = elements[1];
        const $searchTags = elements[2];
        const $replaceTags = elements[3];
        const $ratingConfidents = elements[4];
        const $tagConfidents = elements[5];
        const $discardTagConfidents = elements[6];

        let $selectedTextarea = $keepTags;

        /**
         * @this {HTMLElement}
         * @param {MouseEvent} e
         * @listens document#click
         */
        function onClickTextarea(e) {
            $selectedTextarea = this;
        }

        $keepTags.addEventListener('click', onClickTextarea);
        $excludeTags.addEventListener('click', onClickTextarea);
        $searchTags.addEventListener('click', onClickTextarea);
        $replaceTags.addEventListener('click', onClickTextarea);

        /**
         * @this {HTMLElement}
         * @param {MouseEvent} e
         * @listens document#click
         */
        function onClickLabels(e) {
            // find the clicked label item's wrapper element
            let tag = e.target.innerText;

            // an unlucky click can catch all tags and percentages; prevent inserting those here
            const multiTag = new RegExp(`\\n.*\\n`);
            if (tag.match(multiTag)) {
                return;
            }

            // when clicking on the dotted line or the percentage, you get the percentage as well. Don't include it in the tags.
            // use this fact to choose whether to insert in positive or negative. May require some getting used to, but saves
            // having to select the input field.
            const pctPattern = new RegExp(`\\n?([0-9.]+)%$`);
            let percentage = tag.match(pctPattern);
            if (percentage) {
                tag = tag.replace(pctPattern, '');
                if (tag == '') {
                    //percentage = percentage[1];
                    // could trigger a set Threshold value event
                    return;
                }
                // when clicking on the dotted line, insert in either the exclude or replace list
                // when not clicking on the dotted line, insert in the additional or search list
                if ($selectedTextarea == $keepTags) {
                    $selectedTextarea = $excludeTags;
                } else if ($selectedTextarea == $searchTags) {
                    $selectedTextarea = $replaceTags;
                }
            } else if ($selectedTextarea == $excludeTags) {
                $selectedTextarea = $keepTags;
            } else if ($selectedTextarea == $replaceTags) {
                $selectedTextarea = $searchTags;
            }

            let value = $selectedTextarea.querySelector('textarea').value;
            // except for replace-tags, because multiple tags can be replaced with the same one
            if ($selectedTextarea != $replaceTags) {
                // ignore if the tag already exists in the textbox
                const escapedTag = tag.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
                const pattern = new RegExp(`(^|,)\\s{0,}${escapedTag}\\s{0,}($|,)`);
                if (pattern.test(value)) {
                    return;
                }
            }

            // besides setting the value, an event needs to be triggered or the value isn't actually stored
            const spaceOrAlreadyWithComma = new RegExp(`(^|.*,)\\s*$`);
            if (!spaceOrAlreadyWithComma.test(value)) {
                value += ', ';
            }
            const input_event = new Event('input');
            $selectedTextarea.querySelector('textarea').value = value + tag;
            $selectedTextarea.querySelector('textarea').dispatchEvent(input_event);
            const input_event2 = new Event('blur');
            $selectedTextarea.querySelector('textarea').dispatchEvent(input_event2);

        }

        $tagConfidents.addEventListener('click', onClickLabels)
        $discardTagConfidents.addEventListener('click', onClickLabels)

    }).catch(err => {
        console.error(err)
    })
})
--------------------------------------------------------------------------------
/json_schema/db_json_v1_schema.json:
--------------------------------------------------------------------------------
{
  "type": "object",
  "properties": {
    "rating": { "$ref": "#/$defs/weighted_label" },
    "tag": { "$ref": "#/$defs/weighted_label" },
    "query": {
      "type": "object",
      "patternProperties": {
        "^[0-9a-f]{64}.*$": {
          "type": "array",
          "prefixItems": [
            { "type": "string" },
            { "type": "number", "minimum": 0 }
          ],
          "minItems": 2,
          "maxItems": 2
        }
      }
    },
    "meta": {
      "type": "object",
      "properties": {
        "index_shift": {
          "type": "integer",
          "minimum": 0,
          "maximum": 16
        }
      }
    },
    "add": { "type": "string" },
    "exclude": { "type": "string" },
    "keep": { "type": "string" },
    "repl": { "type": "string" },
    "search": { "type": "string" }
  },
  "required": ["rating", "tag", "query"],
  "additionalProperties": false,
  "$defs": {
    "weighted_label": {
      "type": "object",
      "patternProperties": {
        "^[^,]+$": {
          "type": "array",
          "items": {
            "type": "number",
            "minimum": 0
          }
        }
      }
    }
  }
}
--------------------------------------------------------------------------------
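As a quick check, a db.json produced by batch processing can be validated against this schema with the `jsonschema` package (already in requirements.txt). A minimal sketch, with assumed file paths:

```python
"""Sketch: validate a batch-generated db.json against the v1 schema."""
import json
from jsonschema import ValidationError, validate  # jsonschema is in requirements.txt

with open("json_schema/db_json_v1_schema.json", encoding="utf-8") as f:
    schema = json.load(f)

# db.json is written to the interrogated images folder; this path is assumed
with open("test/db.json", encoding="utf-8") as f:
    db = json.load(f)

try:
    validate(instance=db, schema=schema)
    print("db.json conforms to the v1 schema")
except ValidationError as err:
    print(f"validation failed: {err.message}")
```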
/preload.py:
--------------------------------------------------------------------------------
""" Preload module for DeepDanbooru or onnxtagger. """
from argparse import ArgumentParser


def preload(parser: ArgumentParser):
    """ Preload module for DeepDanbooru or onnxtagger. """
    # default deepdanbooru uses different paths:
    # models/deepbooru and models/torch_deepdanbooru
    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc6632

    parser.add_argument(
        '--deepdanbooru-projects-path',
        type=str,
        help='Path to directory with DeepDanbooru project(s).'
    )
    parser.add_argument(
        '--onnxtagger-path',
        type=str,
        help='Path to directory with ONNX project(s).'
    )
    # TODO allow using devices in parallel, specified as a comma-separated list
    parser.add_argument(
        '--additional-device-ids',
        type=str,
        help='Device ID to use. cpu:0, gpu:0 or gpu:1, etc.',
    )
--------------------------------------------------------------------------------
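A hypothetical WebUI launch line exercising these flags (the paths and device id are examples, not defaults):

```sh
python launch.py --deepdanbooru-projects-path models/deepdanbooru \
    --onnxtagger-path models/onnxtagger \
    --additional-device-ids gpu:1
```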
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.ruff]

target-version = "py39"

extend-select = [
    "B",
    "C",
    "I",
    "W",
]

exclude = [
    "addons",
]

ignore = [
    "E501",  # Line too long
    "E731",  # Do not assign a `lambda` expression, use a `def`

    "I001",  # Import block is un-sorted or un-formatted
    "C901",  # Function is too complex
    "C408",  # Rewrite as a literal
    "W605",  # invalid escape sequence, messes with some docstrings
]

#[tool.ruff.per-file-ignores]
#"webui.py" = ["E402"]  # Module level import not at top of file

#[tool.ruff.flake8-bugbear]
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
#extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]

[tool.pytest.ini_options]
base_url = "http://127.0.0.1:7860"

[tool.pylint.'MESSAGES CONTROL']
extension-pkg-whitelist = ["pydantic"]
disable = ["C", "R", "W", "E", "I"]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
deepdanbooru
jsonschema
fastapi
gradio
huggingface_hub
numpy
opencv_contrib_python
opencv_python
opencv_python_headless
packaging
pandas
Pillow
tensorflow
tqdm
onnxruntime>=1.17.0
--------------------------------------------------------------------------------
/scripts/tagger.py:
--------------------------------------------------------------------------------
"""Tagger module entry point."""
from PIL import Image, ImageFile

from modules import script_callbacks  # pylint: disable=import-error
from tagger.api import on_app_started  # pylint: disable=import-error
from tagger.ui import on_ui_tabs  # pylint: disable=import-error
from tagger.settings import on_ui_settings  # pylint: disable=import-error


# if you do not initialize the Image object
# Image.registered_extensions() returns only PNG
Image.init()

# PIL spits errors when loading a truncated image by default
# https://pillow.readthedocs.io/en/stable/reference/ImageFile.html#PIL.ImageFile.LOAD_TRUNCATED_IMAGES
ImageFile.LOAD_TRUNCATED_IMAGES = True


script_callbacks.on_app_started(on_app_started)
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_ui_settings(on_ui_settings)
--------------------------------------------------------------------------------
/shell_scripts/compare_weighted_frequencies.py:
--------------------------------------------------------------------------------
from typing import Dict
from math import ceil
from json import load
import argparse
import re
from collections import defaultdict

# read two json files and compare the weighted frequencies of the tags in the
# two files. The first file is json and contains all safetensors files, with
# major sections and weighted tags.

# the second file is the result of interrogations of images, with weighted
# tags. Tags may be substrings of the tags in the first file.

# the next argument is an interrogator id, a string
# optionally there are comma-delimited images as arguments

# in the end, print the top scoring safetensors files and major sections that
# contain the tags most similar to the tags in the second file

# all weights are between 0 and 1, higher is more important


# example usage:
# first run shell_scripts/create_safetensors_db.sh
# then interrogate an image in a subdirectory test/


# cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/
#
# python shell_scripts/compare_weighted_frequencies.py safetensors_db.json \
#     test/db.json
# .. lists the used interrogation models


# python shell_scripts/compare_weighted_frequencies.py safetensors_db.json \
#     -c 20 test/db.json


desc = 'Compare weighted frequencies of tags in two json files'
parser = argparse.ArgumentParser(description=desc)
hlp = 'number of results to print'
parser.add_argument('-c', '--count', default=10, type=int, help=hlp)
parser.add_argument('file1', help='all safetensors json file')
parser.add_argument('file2', help='image interrogation json file')
parser.add_argument('id', help='interrogator id', nargs='?', default="")
parser.add_argument('images', nargs='*', help='images', default=[])
args = parser.parse_args()


with open(args.file1) as f:
    all_sftns = load(f)

with open(args.file2) as f:
    data = load(f)

query = data["query"]

indices = set()
if args.id == "":
    # collect the distinct interrogator ids: the part of the
    # query key after the 64-character hex checksum
    uniq = set()
    for k in data["query"]:
        uniq.add(k[64:])
    if len(uniq) != 1:
        print("Missing interrogator id, contained are:")
        for k in uniq:
            print(k)
        exit(1)
    else:
        # use the only one
        args.id = uniq.pop()

for k, t in data["query"].items():
    img_fn, idx = t
    if k[64:] == args.id:
        if len(args.images) > 0:
            for i in args.images:
                if img_fn[-len(i):] == i:
                    break
            else:
                continue
        indices.add(int(idx))

interrogation_result = {}
for t, lst in data["tag"].items():
    wt = 0.0
    for stored in lst:
        i = ceil(stored) - 1
        if i in indices:
            wt += stored - i
    if wt > 0.0:
        interrogation_result[t] = wt / len(indices)

scores: Dict[str, float] = defaultdict(float)

for safetensor in all_sftns:
    for major in all_sftns[safetensor]:
        ct = len(all_sftns[safetensor][major])
        if ct == 0:
            continue

        for tag, wt in interrogation_result.items():
            if tag in all_sftns[safetensor][major]:
                sftns_wt = all_sftns[safetensor][major][tag]
                n = (1.0 - abs(sftns_wt - wt))
                scores[safetensor + "\t" + major] += n / ct
            else:
                rex = re.compile(r'\b{}\b'.format(tag))
                t_len = len(tag)
                # the tag may be a substring of a tag in the safetensors file;
                # however, only whole words are considered, with a penalty if
                # the string lengths are close to each other
                highest = 0.0
                for sftns_tag in all_sftns[safetensor][major]:
                    if rex.search(sftns_tag):
                        sftns_tag_len = len(sftns_tag)
                        sftns_wt = all_sftns[safetensor][major][sftns_tag]
                        n = (sftns_tag_len - t_len) / sftns_tag_len
                        n -= abs(sftns_wt - wt)
                        highest = max(highest, n)
                scores[safetensor + "\t" + major] += highest / ct

# sort the scores
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)

# print the top scoring safetensors files and major sections
for i in range(args.count):
    if i >= len(sorted_scores):
        break
    print(sorted_scores[i][0] + "\t" + str(sorted_scores[i][1]))
--------------------------------------------------------------------------------
/shell_scripts/create_safetensors_db.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Create a database wherein, for each safetensors model
# and major tag, the occurrence frequency of each
# associated subtag is listed.
#
# requires https://github.com/by321/safetensors_util.git
# gnu parallel, jq, sed, awk
#

# To build the safetensors_db.json database with
# "file.safetensors" { "major tag": { "tag1": <freq>, "tag2": .. } }:


# cd stable-diffusion-webui/extensions/stable-diffusion-webui-wd14-tagger/
# git clone https://github.com/by321/safetensors_util.git
#
# bash shell_scripts/create_safetensors_db.sh -f -p ../../models/Lora -u safetensors_util/ -o safetensors_db.json
#
## now you can compare interrogation weights with the safetensors_db.json using
## shell_scripts/compare_weighted_frequencies.py, see there for usage.


# number of cpus to use by default, or use -j to specify
ncpu=$(nproc --all)
[ $ncpu -gt 8 ] && ncpu=8

path=.
utilpath=.
force=0
out=safetensors_db.json

while [ $# -gt 0 ]; do
    case "$1" in
        -h|--help)
            echo "Usage: $0 [-j ncpu] [-p path] [-u utilpath] [-f] [-o out]"
            echo " -j ncpu      number of cpus to use (default: $ncpu)"
            echo " -p path      path to directory containing safetensors models (default: ./)"
            echo " -u utilpath  path to safetensors_util.py"
            echo " -f           force overwrite of output file"
            echo " -o out       output file (default: safetensors_db.json)"
            exit 0
            ;;
        -j) ncpu="$2"; shift 2;;
        -p) path="$2/"; shift 2;;
        -u) utilpath="$2/"; shift 2;;
        -f) force=1; shift 1;;
        -o) out="$2"; shift 2;;
    esac
done

if [ ! -d "${path}" ]; then
    echo "Error: '${path}' does not exist (use -p to specify path)"
    exit 1
fi

if [ ! -e "${utilpath}/safetensors_util.py" ]; then
    echo "Error: ${utilpath}/safetensors_util.py does not exist (use -u to specify path)"
    exit 1
fi

if [ -e "${out}" -a $force -eq 0 ]; then
    echo "Error: ${out} already exists (use -f to overwrite)"
    exit 1
fi

ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null |
sed -n '1b;p' | jq -r 'select(.__metadata__ != null) | .__metadata__ | .ss_tag_frequency | select( . != null )' 2>/dev/null | sed 's/\" /\"/' |
awk -v FS=': ' '{
  if (index(\$2, \"null\") > 0) next
  o = index(\$0, \"{\")
  if (o == 1) printf \"\\\"'{}'\\\": \"
  if (o > 0) {
    print \$0
    m = 0
  } else {
    c = index(\$0, \"}\")
    if (c > 0) {
      L=\"\"
      for (i in a) {
        if (L != \"\") print \",\"
        printf \"%s: %.6f\", i, a[i] / m
        L = \"x\"
      }
      delete a
      if (c == 1) print \$0\",\"
      else print \"\n\"\$0
    } else {
      x = index(\$2, \",\")
      v = int(x != 0 ? substr(\$2, 1, x - 1) : \$2)
      if (v > m) m = v
      a[\$1] = v
    }
  }
}'" | sed -r '
s/^/  /;
1s/^/{\n/;
s/\\"//g
s/^([ \t]+"[^"]+):*(: [01]+(\.[0-9]+)?,?)$/\1"\2/
$s/,?$/\n}/
' > "${out}"
--------------------------------------------------------------------------------
/shell_scripts/model_grep.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# requires jq, grep
#
# Usage: ./model_grep.sh [-p path] [-u utilpath] <regex>

# number of cpus to use by default, or use -j to specify
ncpu=$(nproc --all)
[ $ncpu -gt 8 ] && ncpu=8

path=.
utilpath=.

while [ $# -gt 1 ]; do
    case "$1" in
        -h|--help)
            echo "Usage: $0 [-j ncpu] [-p path] [-u utilpath] <regex>"
            echo " -j ncpu      number of cpus to use (default: $ncpu)"
            echo " -p path      path to directory containing safetensors models (default: ./)"
            echo " -u utilpath  path to safetensors_util.py"
            echo " <regex>      extended regex to match against model names"
            exit 0
            ;;
        -j) ncpu="$2"; shift 2;;
        -p) path="$2"; shift 2;;
        -u) utilpath="$2"; shift 2;;
    esac
done

if [ ! -d "${path}" ]; then
    echo "Error: ${path} does not exist (use -p to specify path)"
    exit 1
fi

if [ ! -e "${utilpath}/safetensors_util.py" ]; then
    echo "Error: ${utilpath}/safetensors_util.py does not exist (use -u to specify path)"
    exit 1
fi

if [ -n "$1" ]; then
    echo "reading from $1 with $ncpu cpus" 1>&2
    ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null |
    sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -o -E '\"[^\"]*${1}[^\"]*\": [0-9]+' | sed 's~^~'{}':~'"
else
    echo "reading from stdin" 1>&2
    tmp=$(mktemp)
    sed 's/^/"[^\"]*/;s/$/[^\"]*": [0-9]+/' < /dev/stdin > $tmp
    ls -1 ${path}/*.safetensors | parallel -n 1 -j $ncpu "python ${utilpath}/safetensors_util.py metadata {} -pm 2>/dev/null |
    sed -n '1b;p' | jq '.__metadata__.ss_tag_frequency' 2>/dev/null | grep -oE -f $tmp | sed 's~^~'{}':~'"
    rm "$tmp"
fi
--------------------------------------------------------------------------------
/shell_scripts/tag_based_image_dedup.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# this script is for deduping images based on tags, after they have been interrogated using this extension
#
# the file removal instructions are written to remove_instructions.sh
# you have to run remove_instructions.sh manually to remove the files
# this script requires exiftool and feh
# TODO: implement this in the extension
#
# Usage:
# repo_dir=/path/to/repo
# cd /path/to/images
#

# fields are tab-separated: each line holds the files sharing one tag set
while IFS=$'\t' read -r first_file second_file etc; do
    # images may be jpg, jpeg or png
    first_image=$(basename "$first_file" ".txt")
    if [[ -f "$first_image.jpg" ]]; then
        first_image="$first_image.jpg"
    elif [[ -f "$first_image.jpeg" ]]; then
        first_image="$first_image.jpeg"
    elif [[ -f "$first_image.png" ]]; then
        first_image="$first_image.png"
    else
        echo "No image file found for $first_file" 1>&2
        continue
    fi
    second_image=$(basename "$second_file" ".txt")
    if [[ -f "$second_image.jpg" ]]; then
        second_image="$second_image.jpg"
    elif [[ -f "$second_image.jpeg" ]]; then
        second_image="$second_image.jpeg"
    elif [[ -f "$second_image.png" ]]; then
        second_image="$second_image.png"
    else
        echo "No image file found for $second_file" 1>&2
        continue
    fi
    feh -g 950x800+5+30 -Z --scale-down -d -S filename --title "$first_image" "$first_image"&
    pid1=$!
    feh -g 950x800+963+30 -Z --scale-down -d -S filename --title "$second_image" "$second_image"&
    pid2=$!
    # stdin is the duplicates list, so read the answer from the terminal
    read -p "Are $first_image and $second_image the same? " -n 1 -r REPLY </dev/tty
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo "Not the same" 1>&2
        continue
    fi
    # keep the file with the largest dimensions
    first_width=$(exiftool "$first_image" | grep -E '^Image Width' | cut -d ':' -f 2)
    first_height=$(exiftool "$first_image" | grep -E '^Image Height' | cut -d ':' -f 2)
    second_width=$(exiftool "$second_image" | grep -E '^Image Width' | cut -d ':' -f 2)
    second_height=$(exiftool "$second_image" | grep -E '^Image Height' | cut -d ':' -f 2)
    echo -e "$first_image: ${first_width}x${first_height}\t-\t$second_image: ${second_width}x${second_height}" 1>&2
    first_product=$((first_width * first_height))
    second_product=$((second_width * second_height))

    if [ $first_product -eq $second_product ]; then
        read -p "Same size for 1) $first_image and 2) $second_image. Which one do you want to keep? (1/2) [skip]" -n 1 -r REPLY </dev/tty
        if [[ $REPLY =~ ^[1]$ ]]; then
            echo "Keeping $first_file" 1>&2
            echo rm "$second_file" "$second_image"
        elif [[ $REPLY =~ ^[2]$ ]]; then
            echo "Keeping $second_file" 1>&2
            echo rm "$first_file" "$first_image"
        else
            echo "Skipping" 1>&2
        fi
    elif [ $((first_width * first_height)) -gt $((second_width * second_height)) ]; then
        echo "Keeping $first_file" 1>&2
        echo rm "$second_file" "$second_image"
    else
        echo "Keeping $second_file" 1>&2
        echo rm "$first_file" "$first_image"
    fi
    kill $pid1 $pid2
done < <(
    ls -1 *.txt | while read f; do
        sed 's/, /\n/g' "$f" | sort | tr '\n' ',' | sed "s~,$~\t$f\n~"
    done | sort | awk -F'\t' '{
        a[$1] = a[$1] == "" ? $2 : a[$1]"\t"$2;
    } END {
        for (i in a) {
            if (index(a[i], "\t") != 0) {
                print a[i];
            }
        }
    }') > remove_instructions.sh
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
/* 'foreground-color' is not a valid CSS property; 'color' is assumed to be the intent */
#rating-confidences .output-label>div:not(:first-child) {
    cursor: pointer;
}

#tag-confidences .output-label>div:not(:first-child) {
    cursor: pointer;
}

#rating-confidences .output-label>div:not(:first-child):hover {
    color: #f5f5f5;
}

#tag-confidences .output-label>div:not(:first-child):hover {
    color: #f5f5f5;
}

#rating-confidences .output-label>div:not(:first-child):active {
    color: #e6e6e6;
}

#tag-confidences .output-label>div:not(:first-child):active {
    color: #e6e6e6;
}

#discard-tag-confidences .output-label>div:not(:first-child) {
    cursor: pointer;
}

#discard-tag-confidences .output-label>div:not(:first-child):hover {
    color: #f5f5f5;
}

#discard-tag-confidences .output-label>div:not(:first-child):active {
    color: #e6e6e6;
}
#tags a {
    font-weight: inherit;
    color: #888;
}
#tags a:hover {
    color: #f5f5f5;
}

#gallery-container {
    display: flex;
    flex-direction: column;
}

#gallery-container > .gradio-row {
    flex: 1;
}

#gallery {
    height: 800px;
    margin-bottom: 0;
}

#tag-editor {
    height: 800px !important;
    overflow: hidden !important;
}

#tag-editor > .wrap {
    max-height: calc(800px - 40px) !important;
    overflow-y: auto !important;
    padding-right: 8px;
}
#tag-editor label {
    padding: 4px 0;
    margin: 2px 0;
}

#tag-editor > .scroll-hide {
    overflow: hidden !important;
}

#tag-editor > .wrap::-webkit-scrollbar {
    width: 6px;
}

#tag-editor > .wrap::-webkit-scrollbar-track {
    background: #f1f1f1;
    border-radius: 3px;
}

#tag-editor > .wrap::-webkit-scrollbar-thumb {
    background: #888;
    border-radius: 3px;
}

#tag-editor > .wrap::-webkit-scrollbar-thumb:hover {
    background: #555;
}
--------------------------------------------------------------------------------
/tagger/api.py:
--------------------------------------------------------------------------------
"""API module for FastAPI"""
from typing import Callable, Dict, Optional
from threading import Lock
from secrets import compare_digest
import asyncio
from collections import defaultdict
from hashlib import sha256
import string
from random import choices

from modules import shared  # pylint: disable=import-error
from modules.api.api import decode_base64_to_image  # pylint: disable=E0401
from modules.call_queue import queue_lock  # pylint: disable=import-error
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials

from tagger import utils  # pylint: disable=import-error
from tagger import api_models as models  # pylint: disable=import-error


class Api:
    """Api class for FastAPI"""
    def __init__(
        self, app: FastAPI, qlock: Lock, prefix: Optional[str] = None
    ) -> None:
        if shared.cmd_opts.api_auth:
            self.credentials = {}
            for auth in shared.cmd_opts.api_auth.split(","):
                user, password = auth.split(":")
                self.credentials[user] = password

        self.app = app
        self.queue: Dict[str, asyncio.Queue] = {}
        self.res: Dict[str, Dict[str, Dict[str, float]]] = \
            defaultdict(dict)
        self.queue_lock = qlock
        self.tasks: Dict[str, asyncio.Task] = {}

        self.runner: Optional[asyncio.Task] = None
        self.prefix = prefix
        self.running_batches: Dict[str, Dict[str, float]] = \
            defaultdict(lambda: defaultdict(int))

        self.add_api_route(
            'interrogate',
            self.endpoint_interrogate,
            methods=['POST'],
            response_model=models.TaggerInterrogateResponse
        )

        self.add_api_route(
            'interrogators',
            self.endpoint_interrogators,
            methods=['GET'],
            response_model=models.TaggerInterrogatorsResponse
        )

        self.add_api_route(
            'unload-interrogators',
            self.endpoint_unload_interrogators,
            methods=['POST'],
            response_model=str,
        )

    async def add_to_queue(self, m, q, n='', i=None, t=0.0) -> Dict[
        str, Dict[str, float]
    ]:
        if m not in self.queue:
            self.queue[m] = asyncio.Queue()
        # loop = asyncio.get_running_loop()
        # asyncio.run_coroutine_threadsafe(
        task = asyncio.create_task(self.queue[m].put((q, n, i, t)))
        # , loop)

        if self.runner is None:
            loop = asyncio.get_running_loop()
            asyncio.ensure_future(self.batch_process(), loop=loop)
        await task
        return await self.tasks[q + "\t" + n]
90 | threshold=t, 91 | name_in_queue='', 92 | queue='' 93 | ) 94 | ) 95 | self.res[q][n] = res.caption["tag"] 96 | for k, v in res.caption["rating"].items(): 97 | self.res[q][n]["rating:"+k] = v 98 | return self.running_batches 99 | 100 | async def finish_queue(self, m, q) -> Dict[str, Dict[str, float]]: 101 | if q in self.running_batches[m]: 102 | del self.running_batches[m][q] 103 | if q in self.res: 104 | return self.res.pop(q) 105 | return self.running_batches 106 | 107 | async def batch_process(self) -> None: 108 | # loop = asyncio.get_running_loop() 109 | while len(self.queue) > 0: 110 | for m in self.queue: 111 | # if zero the queue might just be pending 112 | while True: 113 | try: 114 | # q, n, i, t = asyncio.run_coroutine_threadsafe( 115 | # self.queue[m].get_nowait(), loop).result() 116 | q, n, i, t = self.queue[m].get_nowait() 117 | except asyncio.QueueEmpty: 118 | break 119 | self.tasks[q+"\t"+n] = asyncio.create_task( 120 | self.do_queued_interrogation(m, q, n, i, t) if n != "" 121 | else self.finish_queue(m, q) 122 | ) 123 | 124 | for model in self.running_batches: 125 | if len(self.running_batches[model]) == 0: 126 | del self.queue[model] 127 | else: 128 | await asyncio.sleep(0.1) 129 | 130 | self.running_batches.clear() 131 | self.runner = None 132 | 133 | def auth(self, creds: Optional[HTTPBasicCredentials] = None): 134 | if creds is None: 135 | creds = Depends(HTTPBasic()) 136 | if creds.username in self.credentials: 137 | if compare_digest(creds.password, 138 | self.credentials[creds.username]): 139 | return True 140 | 141 | raise HTTPException( 142 | status_code=401, 143 | detail="Incorrect username or password", 144 | headers={ 145 | "WWW-Authenticate": "Basic" 146 | }) 147 | 148 | def add_api_route(self, path: str, endpoint: Callable, **kwargs): 149 | if self.prefix: 150 | path = f'{self.prefix}/{path}' 151 | 152 | if shared.cmd_opts.api_auth: 153 | return self.app.add_api_route(path, endpoint, dependencies=[ 154 | Depends(self.auth)], **kwargs) 155 | return self.app.add_api_route(path, endpoint, **kwargs) 156 | 157 | async def queue_interrogation(self, m, q, n='', i=None, t=0.0) -> Dict[ 158 | str, Dict[str, float] 159 | ]: 160 | """ queue an interrogation, or add to batch """ 161 | if n == '': 162 | task = asyncio.create_task(self.add_to_queue(m, q)) 163 | else: 164 | if n == '': 165 | n = sha256(i).hexdigest() 166 | if n in self.res[q]: 167 | return self.running_batches 168 | elif n in self.res[q]: 169 | # clobber name if it's already in the queue 170 | j = 0 171 | while f'{n}#{j}' in self.res[q]: 172 | j += 1 173 | n = f'{n}#{j}' 174 | self.res[q][n] = {} 175 | # add image to queue 176 | task = asyncio.create_task(self.add_to_queue(m, q, n, i, t)) 177 | return await task 178 | 179 | def endpoint_interrogate(self, req: models.TaggerInterrogateRequest): 180 | """ one file interrogation, queueing, or batch results """ 181 | if req.image is None: 182 | raise HTTPException(404, 'Image not found') 183 | 184 | if req.model not in utils.interrogators: 185 | raise HTTPException(404, 'Model not found') 186 | 187 | m, q, n = (req.model, req.queue, req.name_in_queue) 188 | res: Dict[str, Dict[str, float]] = {} 189 | 190 | if q != '' or n != '': 191 | if q == '': 192 | # generate a random queue name, not in use 193 | while True: 194 | q = ''.join(choices(string.ascii_uppercase + 195 | string.digits, k=8)) 196 | if q not in self.queue: 197 | break 198 | print(f'WD14 tagger api generated queue name: {q}') 199 | res = asyncio.run(self.queue_interrogation(m, q, n, req.image, 200 
| req.threshold), debug=True) 201 | else: 202 | image = decode_base64_to_image(req.image) 203 | interrogator = utils.interrogators[m] 204 | res = {"tag": {}, "rating": {}} 205 | with self.queue_lock: 206 | res["rating"], tag = interrogator.interrogate(image) 207 | 208 | for k, v in tag.items(): 209 | if v > req.threshold: 210 | res["tag"][k] = v 211 | 212 | return models.TaggerInterrogateResponse(caption=res) 213 | 214 | def endpoint_interrogators(self): 215 | return models.TaggerInterrogatorsResponse( 216 | models=list(utils.interrogators.keys()) 217 | ) 218 | 219 | def endpoint_unload_interrogators(self): 220 | unloaded_models = 0 221 | 222 | for i in utils.interrogators.values(): 223 | if i.unload(): 224 | unloaded_models = unloaded_models + 1 225 | 226 | return f"Successfully unload {unloaded_models} model(s)" 227 | 228 | 229 | def on_app_started(_, app: FastAPI): 230 | Api(app, queue_lock, '/tagger/v1') 231 | -------------------------------------------------------------------------------- /tagger/api_models.py: -------------------------------------------------------------------------------- 1 | """Purpose: Pydantic models for the API.""" 2 | from typing import List, Dict 3 | 4 | from modules.api import models as sd_models # pylint: disable=E0401 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class TaggerInterrogateRequest(sd_models.InterrogateRequest): 9 | """Interrogate request model""" 10 | model: str = Field( 11 | title='Model', 12 | description='The interrogate model used.', 13 | ) 14 | threshold: float = Field( 15 | title='Threshold', 16 | description='The threshold used for the interrogate model.', 17 | default=0.0, 18 | ) 19 | queue: str = Field( 20 | title='Queue', 21 | description='name of queue; leave empty for single response', 22 | default='', 23 | ) 24 | name_in_queue: str = Field( 25 | title='Name', 26 | description='name to queue image as or use . leave empty to ' 27 | 'retrieve the final response', 28 | default='', 29 | ) 30 | 31 | 32 | class TaggerInterrogateResponse(BaseModel): 33 | """Interrogate response model""" 34 | caption: Dict[str, Dict[str, float]] = Field( 35 | title='Caption', 36 | description='The generated captions for the image.' 
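        # Illustrative shapes (assumed from api.py, not enforced here): a
        # direct, unqueued interrogation nests results as
        #   {"rating": {"sensitive": 0.5, ...}, "tag": {"tag1": 0.9, ...}}
        # while finalizing a queue returns one flattened dict per queued name:
        #   {"img1.png": {"rating:sensitive": 0.5, "tag1": 0.9, ...}}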
37 | ) 38 | 39 | 40 | class TaggerInterrogatorsResponse(BaseModel): 41 | """Interrogators response model""" 42 | models: List[str] = Field( 43 | title='Models', 44 | description='' 45 | ) 46 | -------------------------------------------------------------------------------- /tagger/dbimutils.py: -------------------------------------------------------------------------------- 1 | """DanBooru IMage Utility functions""" 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def fill_transparent(image: Image.Image, color='WHITE'): 9 | image = image.convert('RGBA') 10 | new_image = Image.new('RGBA', image.size, color) 11 | new_image.paste(image, mask=image) 12 | image = new_image.convert('RGB') 13 | return image 14 | 15 | 16 | def resize(pic: Image.Image, size: int, keep_ratio=True) -> Image.Image: 17 | if not keep_ratio: 18 | target_size = (size, size) 19 | else: 20 | min_edge = min(pic.size) 21 | target_size = ( 22 | int(pic.size[0] / min_edge * size), 23 | int(pic.size[1] / min_edge * size), 24 | ) 25 | 26 | target_size = (target_size[0] & ~3, target_size[1] & ~3) 27 | 28 | return pic.resize(target_size, resample=Image.Resampling.LANCZOS) 29 | 30 | 31 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED): 32 | """ Read an image, convert to 24-bit if necessary """ 33 | if img.endswith(".gif"): 34 | img = Image.open(img) 35 | img = img.convert("RGB") 36 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 37 | else: 38 | img = cv2.imread(img, flag) 39 | return img 40 | 41 | 42 | def smart_24bit(img): 43 | """ Convert an image to 24-bit if necessary """ 44 | if img.dtype is np.dtype(np.uint16): 45 | img = (img / 257).astype(np.uint8) 46 | 47 | if len(img.shape) == 2: 48 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 49 | elif img.shape[2] == 4: 50 | trans_mask = img[:, :, 3] == 0 51 | img[trans_mask] = [255, 255, 255, 255] 52 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) 53 | return img 54 | 55 | 56 | def make_square(img, target_size): 57 | """ Make an image square """ 58 | old_size = img.shape[:2] 59 | desired_size = max(old_size) 60 | desired_size = max(desired_size, target_size) 61 | 62 | delta_w = desired_size - old_size[1] 63 | delta_h = desired_size - old_size[0] 64 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 65 | left, right = delta_w // 2, delta_w - (delta_w // 2) 66 | 67 | color = [255, 255, 255] 68 | new_im = cv2.copyMakeBorder( 69 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color 70 | ) 71 | return new_im 72 | 73 | 74 | def smart_resize(img, size): 75 | """ Resize an image """ 76 | # Assumes the image has already gone through make_square 77 | if img.shape[0] > size: 78 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 79 | elif img.shape[0] < size: 80 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 81 | return img 82 | -------------------------------------------------------------------------------- /tagger/format.py: -------------------------------------------------------------------------------- 1 | """Format module, for formatting output filenames""" 2 | import re 3 | import hashlib 4 | 5 | from typing import Dict, Callable, NamedTuple 6 | from pathlib import Path 7 | 8 | 9 | class Info(NamedTuple): 10 | path: Path 11 | output_ext: str 12 | 13 | 14 | def hashfun(i: Info, algo='sha1') -> str: 15 | try: 16 | hasher = hashlib.new(algo) 17 | except ImportError as err: 18 | raise ValueError(f"'{algo}' is invalid hash algorithm") from err 19 | 20 | with open(i.path, 'rb') as file: 21 | 
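        # hashlib.new() accepts any algorithm in hashlib.algorithms_available,
        # so a template such as [hash:sha256] or [hash:md5] resolves through
        # this same function via parse() below.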
hasher.update(file.read()) 22 | 23 | return hasher.hexdigest() 24 | 25 | 26 | pattern = re.compile(r'\[([\w:]+)\]') 27 | 28 | # all function must returns string or raise TypeError or ValueError 29 | # other errors will cause the extension error 30 | available_formats: Dict[str, Callable] = { 31 | 'name': lambda i: i.path.stem, 32 | 'extension': lambda i: i.path.suffix[1:], 33 | 'hash': hashfun, 34 | 35 | 'output_extension': lambda i: i.output_ext 36 | } 37 | 38 | 39 | def parse(match: re.Match, info: Info) -> str: 40 | matches = match[1].split(':') 41 | name, args = matches[0], matches[1:] 42 | 43 | if name not in available_formats: 44 | return match[0] 45 | 46 | return available_formats[name](info, *args) 47 | -------------------------------------------------------------------------------- /tagger/generator/tf_data_reader.py: -------------------------------------------------------------------------------- 1 | """ Credits to SmilingWolf """ 2 | 3 | import tensorflow as tf 4 | try: 5 | import tensorflow_io as tfio # pylint: disable=import-error 6 | except ImportError: 7 | tfio = None 8 | 9 | def is_webp(contents): 10 | """Checks if the image is a webp image""" 11 | riff_header = tf.strings.substr(contents, 0, 4) 12 | webp_header = tf.strings.substr(contents, 8, 4) 13 | 14 | is_riff = riff_header == b"RIFF" 15 | is_fourcc_webp = webp_header == b"WEBP" 16 | return is_riff and is_fourcc_webp 17 | 18 | 19 | class DataGenerator: 20 | """ Data generator for the dataset """ 21 | def __init__(self, file_list, target_height, target_width, batch_size): 22 | self.file_list = file_list 23 | self.target_width = target_width 24 | self.target_height = target_height 25 | self.batch_size = batch_size 26 | 27 | def read_image(self, filename): 28 | image_bytes = tf.io.read_file(filename) 29 | return filename, image_bytes 30 | 31 | def parse_single_image(self, filename, image_bytes): 32 | """ Parses a single image """ 33 | if is_webp(image_bytes): 34 | image = tfio.image.decode_webp(image_bytes) 35 | else: 36 | image = tf.io.decode_image( 37 | image_bytes, channels=0, dtype=tf.uint8, 38 | expand_animations=False 39 | ) 40 | 41 | # Black and white image 42 | if tf.shape(image)[2] == 1: 43 | image = tf.repeat(image, 3, axis=-1) 44 | 45 | # Black and white image with alpha 46 | elif tf.shape(image)[2] == 2: 47 | image, mask = tf.unstack(image, num=2, axis=-1) 48 | mask = tf.expand_dims(mask, axis=-1) 49 | image = tf.expand_dims(image, axis=-1) 50 | image = tf.repeat(image, 3, axis=-1) 51 | image = tf.concat([image, mask], -1) 52 | 53 | # Alpha to white 54 | if tf.shape(image)[2] == 4: 55 | alpha_mask = image[:, :, 3] 56 | alpha_mask = tf.cast(alpha_mask, tf.float32) / 255 57 | alpha_mask = tf.repeat(tf.expand_dims(alpha_mask, -1), 4, axis=-1) 58 | 59 | matte = tf.ones_like(image, dtype=tf.uint8) * [255, 255, 255, 255] 60 | 61 | weighted_matte = tf.cast(matte, dtype=alpha_mask.dtype) * (1 - alpha_mask) # noqa: E501 62 | weighted_image = tf.cast(image, dtype=alpha_mask.dtype) * alpha_mask # noqa: E501 63 | image = weighted_image + weighted_matte 64 | 65 | # Remove alpha channel 66 | image = tf.cast(image, dtype=tf.uint8)[:, :, :-1] 67 | 68 | # Pillow/Tensorflow RGB -> OpenCV BGR 69 | image = image[:, :, ::-1] 70 | return filename, image 71 | 72 | def resize_single_image(self, filename, image): 73 | """ Resizes a single image """ 74 | height, width, _ = tf.unstack(tf.shape(image)) 75 | 76 | if height <= self.target_height and width <= self.target_width: 77 | return filename, image 78 | 79 | image = tf.image.resize( 
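            # (this resize only ever downscales: the early return above
            # already skips images within the target size, and AREA is the
            # usual resampling choice for shrinking)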
80 | image, 81 | (self.target_height, self.target_width), 82 | method=tf.image.ResizeMethod.AREA, 83 | preserve_aspect_ratio=True, 84 | ) 85 | image = tf.cast(tf.math.round(image), dtype=tf.uint8) 86 | return filename, image 87 | 88 | def pad_single_image(self, filename, image): 89 | """ Pads a single image """ 90 | height, width, _ = tf.unstack(tf.shape(image)) 91 | 92 | float_h = tf.cast(height, dtype=tf.float32) 93 | float_w = tf.cast(width, dtype=tf.float32) 94 | float_target_h = tf.cast(self.target_height, dtype=tf.float32) 95 | float_target_w = tf.cast(self.target_width, dtype=tf.float32) 96 | 97 | padding_top = tf.cast((float_target_h - float_h) / 2, dtype=tf.int32) 98 | padding_right = tf.cast((float_target_w - float_w) / 2, dtype=tf.int32) 99 | padding_bottom = self.target_height - padding_top - height 100 | padding_left = self.target_width - padding_right - width 101 | 102 | padding = [[padding_top, padding_bottom], 103 | [padding_right, padding_left], [0, 0]] 104 | image = tf.pad(image, padding, mode="CONSTANT", constant_values=255) 105 | return filename, image 106 | 107 | def gen_ds(self): 108 | """ Generates the dataset """ 109 | if tfio is None: 110 | print("Tensorflow IO is not installed, try\n" 111 | "`pip install tensorflow_io' or use another interrogator") 112 | return [] 113 | images_list = tf.data.Dataset.from_tensor_slices(self.file_list) 114 | 115 | images_data = images_list.map( 116 | self.read_image, num_parallel_calls=tf.data.AUTOTUNE 117 | ) 118 | images_data = images_data.map( 119 | self.parse_single_image, num_parallel_calls=tf.data.AUTOTUNE 120 | ) 121 | images_data = images_data.map( 122 | self.resize_single_image, num_parallel_calls=tf.data.AUTOTUNE 123 | ) 124 | images_data = images_data.map( 125 | self.pad_single_image, num_parallel_calls=tf.data.AUTOTUNE 126 | ) 127 | 128 | images_list = images_data.batch( 129 | self.batch_size, drop_remainder=False, 130 | num_parallel_calls=tf.data.AUTOTUNE 131 | ) 132 | images_list = images_list.prefetch(tf.data.AUTOTUNE) 133 | return images_list 134 | -------------------------------------------------------------------------------- /tagger/interrogator.py: -------------------------------------------------------------------------------- 1 | """ Interrogator class and subclasses for tagger """ 2 | import os 3 | from pathlib import Path 4 | import io 5 | import json 6 | import inspect 7 | from re import match as re_match 8 | from platform import system, uname 9 | from typing import Tuple, List, Dict, Callable 10 | from pandas import read_csv 11 | from PIL import Image, UnidentifiedImageError 12 | from numpy import asarray, float32, expand_dims, exp 13 | from tqdm import tqdm 14 | from huggingface_hub import hf_hub_download 15 | 16 | from modules.paths import extensions_dir 17 | from modules import shared 18 | from tagger import settings # pylint: disable=import-error 19 | from tagger.uiset import QData, IOData # pylint: disable=import-error 20 | from . 
import dbimutils # pylint: disable=import-error # noqa 21 | 22 | Its = settings.InterrogatorSettings 23 | 24 | # select a device to process 25 | use_cpu = ('all' in shared.cmd_opts.use_cpu) or ( 26 | 'interrogate' in shared.cmd_opts.use_cpu) 27 | 28 | # https://onnxruntime.ai/docs/execution-providers/ 29 | # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 30 | onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 31 | 32 | if shared.cmd_opts.additional_device_ids is not None: 33 | m = re_match(r'([cg])pu:\d+$', shared.cmd_opts.additional_device_ids) 34 | if m is None: 35 | raise ValueError('--device-id is not cpu: or gpu:') 36 | if m.group(1) == 'c': 37 | onnxrt_providers.pop(0) 38 | TF_DEVICE_NAME = f'/{shared.cmd_opts.additional_device_ids}' 39 | elif use_cpu: 40 | TF_DEVICE_NAME = '/cpu:0' 41 | onnxrt_providers.pop(0) 42 | else: 43 | TF_DEVICE_NAME = '/gpu:0' 44 | 45 | print(f'== WD14 tagger {TF_DEVICE_NAME}, {uname()} ==') 46 | 47 | 48 | class Interrogator: 49 | """ Interrogator class for tagger """ 50 | # the raw input and output. 51 | input = { 52 | "cumulative": False, 53 | "large_query": False, 54 | "unload_after": False, 55 | "add": '', 56 | "keep": '', 57 | "exclude": '', 58 | "search": '', 59 | "replace": '', 60 | "output_dir": '', 61 | } 62 | output = None 63 | odd_increment = 0 64 | 65 | @classmethod 66 | def flip(cls, key): 67 | def toggle(): 68 | cls.input[key] = not cls.input[key] 69 | return toggle 70 | 71 | @staticmethod 72 | def get_errors() -> str: 73 | errors = '' 74 | if len(IOData.err) > 0: 75 | # write errors in html pointer list, every error in a
<li> tag 76 | errors = IOData.error_msg() 77 | if len(QData.err) > 0: 78 | errors += 'Fix to write correct output:<ul><li>' + \ 79 | '</li><li>'.join(QData.err) + '</li></ul>
    ' 80 | return errors 81 | 82 | @classmethod 83 | def set(cls, key: str) -> Callable[[str], Tuple[str, str]]: 84 | def setter(val) -> Tuple[str, str]: 85 | if key == 'input_glob': 86 | IOData.update_input_glob(val) 87 | return (val, cls.get_errors()) 88 | if val != cls.input[key]: 89 | tgt_cls = IOData if key == 'output_dir' else QData 90 | getattr(tgt_cls, "update_" + key)(val) 91 | cls.input[key] = val 92 | return (cls.input[key], cls.get_errors()) 93 | 94 | return setter 95 | 96 | @staticmethod 97 | def load_image(path: str) -> Image: 98 | try: 99 | return Image.open(path) 100 | except FileNotFoundError: 101 | print(f'${path} not found') 102 | except UnidentifiedImageError: 103 | # just in case, user has mysterious file... 104 | print(f'${path} is not a supported image type') 105 | except ValueError: 106 | print(f'${path} is not readable or StringIO') 107 | return None 108 | 109 | def __init__(self, name: str) -> None: 110 | self.name = name 111 | self.model = None 112 | self.tags = None 113 | # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled 114 | self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2 115 | 116 | def load(self): 117 | raise NotImplementedError() 118 | 119 | def large_batch_interrogate(self, images: List, dry_run=False) -> str: 120 | raise NotImplementedError() 121 | 122 | def unload(self) -> bool: 123 | unloaded = False 124 | 125 | if self.model is not None: 126 | del self.model 127 | self.model = None 128 | unloaded = True 129 | print(f'Unloaded {self.name}') 130 | 131 | if hasattr(self, 'tags'): 132 | del self.tags 133 | self.tags = None 134 | 135 | return unloaded 136 | 137 | def interrogate_image(self, image: Image) -> None: 138 | sha = IOData.get_bytes_hash(image.tobytes()) 139 | QData.clear(1 - Interrogator.input["cumulative"]) 140 | 141 | fi_key = sha + self.name 142 | count = 0 143 | 144 | if fi_key in QData.query: 145 | # this file was already queried for this interrogator. 
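            # (fi_key is the image's byte-hash plus the interrogator name, so
            # re-running the same model on an unchanged image reuses the
            # cached result instead of repeating inference)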
146 | QData.single_data(fi_key) 147 | else: 148 | # single process 149 | count += 1 150 | data = ('', '', fi_key) + self.interrogate(image) 151 | # When drag-dropping an image, the path [0] is not known 152 | if Interrogator.input["unload_after"]: 153 | self.unload() 154 | 155 | QData.apply_filters(data) 156 | 157 | for got in QData.in_db.values(): 158 | QData.apply_filters(got) 159 | 160 | Interrogator.output = QData.finalize(count) 161 | 162 | def batch_interrogate_image(self, index: int) -> None: 163 | # if outputpath is '', no tags file will be written 164 | if len(IOData.paths[index]) == 5: 165 | path, out_path, output_dir, image_hash, image = IOData.paths[index] 166 | elif len(IOData.paths[index]) == 4: 167 | path, out_path, output_dir, image_hash = IOData.paths[index] 168 | image = Interrogator.load_image(path) 169 | # should work, we queried before to get the image_hash 170 | else: 171 | path, out_path, output_dir = IOData.paths[index] 172 | image = Interrogator.load_image(path) 173 | if image is None: 174 | return 175 | 176 | image_hash = IOData.get_bytes_hash(image.tobytes()) 177 | IOData.paths[index].append(image_hash) 178 | if getattr(shared.opts, 'tagger_store_images', False): 179 | IOData.paths[index].append(image) 180 | 181 | if output_dir: 182 | output_dir.mkdir(0o755, True, True) 183 | # next iteration we don't need to create the directory 184 | IOData.paths[index][2] = '' 185 | QData.image_dups[image_hash].add(path) 186 | 187 | abspath = str(path.absolute()) 188 | fi_key = image_hash + self.name 189 | 190 | if fi_key in QData.query: 191 | # this file was already queried for this interrogator. 192 | i = QData.get_index(fi_key, abspath) 193 | # this file was already queried and stored 194 | QData.in_db[i] = (abspath, out_path, '', {}, {}) 195 | else: 196 | data = (abspath, out_path, fi_key) + self.interrogate(image) 197 | # also the tags can indicate that the image is a duplicate 198 | no_floats = sorted(filter(lambda x: not isinstance(x[0], float), 199 | data[3].items()), key=lambda x: x[0]) 200 | sorted_tags = ','.join(f'({k},{v:.1f})' for (k, v) in no_floats) 201 | QData.image_dups[sorted_tags].add(abspath) 202 | QData.apply_filters(data) 203 | QData.had_new = True 204 | 205 | def batch_interrogate(self) -> None: 206 | """ Interrogate all images in the input list """ 207 | QData.clear(1 - Interrogator.input["cumulative"]) 208 | 209 | if Interrogator.input["large_query"] is True and self.run_mode < 2: 210 | # TODO: write specified tags files instead of simple .txt 211 | image_list = [str(x[0].resolve()) for x in IOData.paths] 212 | self.large_batch_interrogate(image_list, self.run_mode == 0) 213 | 214 | # alternating dry run and run modes 215 | self.run_mode = (self.run_mode + 1) % 2 216 | count = len(image_list) 217 | Interrogator.output = QData.finalize(count) 218 | else: 219 | verb = getattr(shared.opts, 'tagger_verbose', True) 220 | count = len(QData.query) 221 | 222 | for i in tqdm(range(len(IOData.paths)), disable=verb, desc='Tags'): 223 | self.batch_interrogate_image(i) 224 | 225 | if Interrogator.input["unload_after"]: 226 | self.unload() 227 | 228 | count = len(QData.query) - count 229 | Interrogator.output = QData.finalize_batch(count) 230 | 231 | def interrogate( 232 | self, 233 | image: Image 234 | ) -> Tuple[ 235 | Dict[str, float], # rating confidences 236 | Dict[str, float] # tag confidences 237 | ]: 238 | raise NotImplementedError() 239 | 240 | 241 | class DeepDanbooruInterrogator(Interrogator): 242 | """ Interrogator for DeepDanbooru models """ 243 | def 
__init__(self, name: str, project_path: os.PathLike) -> None: 244 | super().__init__(name) 245 | self.project_path = project_path 246 | self.model = None 247 | self.tags = None 248 | 249 | def load(self) -> None: 250 | print(f'Loading {self.name} from {str(self.project_path)}') 251 | 252 | # deepdanbooru package is not include in web-sd anymore 253 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc663 254 | from launch import is_installed, run_pip 255 | if not is_installed('deepdanbooru'): 256 | package = os.environ.get( 257 | 'DEEPDANBOORU_PACKAGE', 258 | 'git+https://github.com/KichangKim/DeepDanbooru.' 259 | 'git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff' 260 | ) 261 | 262 | run_pip( 263 | f'install {package} tensorflow tensorflow-io', 'deepdanbooru') 264 | 265 | import tensorflow as tf 266 | 267 | # tensorflow maps nearly all vram by default, so we limit this 268 | # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth 269 | # TODO: only run on the first run 270 | for device in tf.config.experimental.list_physical_devices('GPU'): 271 | try: 272 | tf.config.experimental.set_memory_growth(device, True) 273 | except RuntimeError as err: 274 | print(err) 275 | 276 | with tf.device(TF_DEVICE_NAME): 277 | import deepdanbooru.project as ddp 278 | 279 | self.model = ddp.load_model_from_project( 280 | project_path=self.project_path, 281 | compile_model=False 282 | ) 283 | 284 | print(f'Loaded {self.name} model from {str(self.project_path)}') 285 | 286 | self.tags = ddp.load_tags_from_project( 287 | project_path=self.project_path 288 | ) 289 | 290 | def unload(self) -> bool: 291 | return False 292 | 293 | def interrogate( 294 | self, 295 | image: Image 296 | ) -> Tuple[ 297 | Dict[str, float], # rating confidences 298 | Dict[str, float] # tag confidences 299 | ]: 300 | # init model 301 | if self.model is None: 302 | self.load() 303 | 304 | import deepdanbooru.data as ddd 305 | 306 | # convert an image to fit the model 307 | image_bufs = io.BytesIO() 308 | image.save(image_bufs, format='PNG') 309 | image = ddd.load_image_for_evaluate( 310 | image_bufs, 311 | self.model.input_shape[2], 312 | self.model.input_shape[1] 313 | ) 314 | 315 | image = image.reshape((1, *image.shape[0:3])) 316 | 317 | # evaluate model 318 | result = self.model.predict(image) 319 | 320 | confidences = result[0].tolist() 321 | ratings = {} 322 | tags = {} 323 | 324 | for i, tag in enumerate(self.tags): 325 | if tag[:7] != "rating:": 326 | tags[tag] = confidences[i] 327 | else: 328 | ratings[tag[7:]] = confidences[i] 329 | 330 | return ratings, tags 331 | 332 | def large_batch_interrogate(self, images: List, dry_run=False) -> str: 333 | raise NotImplementedError() 334 | 335 | 336 | # FIXME this is silly, in what scenario would the env change from MacOS to 337 | # another OS? TODO: remove if the author does not respond. 338 | def get_onnxrt(): 339 | try: 340 | import onnxruntime 341 | return onnxruntime 342 | except ImportError: 343 | # only one of these packages should be installed at one time in an env 344 | # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime 345 | # TODO: remove old package when the environment changes? 
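        # (onnxruntime, onnxruntime-gpu and onnxruntime-silicon all install
        # the same importable "onnxruntime" module, hence only one of them
        # should be present in a given environment)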
346 | from launch import is_installed, run_pip 347 | if not is_installed('onnxruntime'): 348 | if system() == "Darwin": 349 | package_name = "onnxruntime-silicon" 350 | else: 351 | package_name = "onnxruntime-gpu" 352 | package = os.environ.get( 353 | 'ONNXRUNTIME_PACKAGE', 354 | package_name 355 | ) 356 | 357 | run_pip(f'install {package}', 'onnxruntime') 358 | 359 | import onnxruntime 360 | return onnxruntime 361 | 362 | 363 | class WaifuDiffusionInterrogator(Interrogator): 364 | """ Interrogator for Waifu Diffusion models """ 365 | def __init__( 366 | self, 367 | name: str, 368 | model_path='model.onnx', 369 | tags_path='selected_tags.csv', 370 | repo_id=None, 371 | is_hf=True, 372 | ) -> None: 373 | super().__init__(name) 374 | self.repo_id = repo_id 375 | self.model_path = model_path 376 | self.tags_path = tags_path 377 | self.tags = None 378 | self.model = None 379 | self.tags = None 380 | self.local_model = None 381 | self.local_tags = None 382 | self.is_hf = is_hf 383 | 384 | def download(self) -> None: 385 | mdir = Path(shared.models_path, 'interrogators') 386 | if self.is_hf: 387 | cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) 388 | print(f"Loading {self.name} model file from {self.repo_id}, " 389 | f"{self.model_path}") 390 | 391 | model_path = hf_hub_download( 392 | repo_id=self.repo_id, 393 | filename=self.model_path, 394 | cache_dir=cache) 395 | tags_path = hf_hub_download( 396 | repo_id=self.repo_id, 397 | filename=self.tags_path, 398 | cache_dir=cache) 399 | else: 400 | model_path = self.local_model 401 | tags_path = self.local_tags 402 | 403 | download_model = { 404 | 'name': self.name, 405 | 'model_path': model_path, 406 | 'tags_path': tags_path, 407 | } 408 | mpath = Path(mdir, 'model.json') 409 | 410 | data = [download_model] 411 | 412 | if not os.path.exists(mdir): 413 | os.mkdir(mdir) 414 | 415 | elif os.path.exists(mpath): 416 | with io.open(file=mpath, mode='r', encoding='utf-8') as filename: 417 | try: 418 | data = json.load(filename) 419 | # No need to append if it's already contained 420 | if download_model not in data: 421 | data.append(download_model) 422 | except json.JSONDecodeError as err: 423 | print(f'Adding download_model {mpath} raised {repr(err)}') 424 | data = [download_model] 425 | 426 | with io.open(mpath, 'w', encoding='utf-8') as filename: 427 | json.dump(data, filename) 428 | return model_path, tags_path 429 | 430 | def load(self) -> None: 431 | model_path, tags_path = self.download() 432 | ort = get_onnxrt() 433 | self.model = ort.InferenceSession(model_path, 434 | providers=onnxrt_providers) 435 | 436 | print(f'Loaded {self.name} model from {self.repo_id}') 437 | self.tags = read_csv(tags_path) 438 | 439 | def interrogate( 440 | self, 441 | image: Image 442 | ) -> Tuple[ 443 | Dict[str, float], # rating confidences 444 | Dict[str, float] # tag confidences 445 | ]: 446 | # init model 447 | if self.model is None: 448 | self.load() 449 | 450 | # code for converting the image and running the model is taken from the 451 | # link below. thanks, SmilingWolf! 
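        # The ONNX graph takes one NHWC float32 batch of square BGR images,
        # e.g. (1, 448, 448, 3) for the wd14 v2 models; the input height is
        # read from the model below, so other sizes work unchanged.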
452 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py 453 | 454 | # convert an image to fit the model 455 | _, height, _, _ = self.model.get_inputs()[0].shape 456 | 457 | # alpha to white 458 | image = dbimutils.fill_transparent(image) 459 | 460 | image = asarray(image) 461 | # PIL RGB to OpenCV BGR 462 | image = image[:, :, ::-1] 463 | 464 | tags = dict 465 | 466 | image = dbimutils.make_square(image, height) 467 | image = dbimutils.smart_resize(image, height) 468 | image = image.astype(float32) 469 | image = expand_dims(image, 0) 470 | 471 | # evaluate model 472 | input_name = self.model.get_inputs()[0].name 473 | label_name = self.model.get_outputs()[0].name 474 | confidences = self.model.run([label_name], {input_name: image})[0] 475 | 476 | tags = self.tags[:][['name']] 477 | tags['confidences'] = confidences[0] 478 | 479 | # first 4 items are for rating (general, sensitive, questionable, 480 | # explicit) 481 | ratings = dict(tags[:4].values) 482 | 483 | # rest are regular tags 484 | tags = dict(tags[4:].values) 485 | 486 | return ratings, tags 487 | 488 | def dry_run(self, images) -> Tuple[str, Callable[[str], None]]: 489 | 490 | def process_images(filepaths, _): 491 | lines = [] 492 | for image_path in filepaths: 493 | image_path = image_path.numpy().decode("utf-8") 494 | lines.append(f"{image_path}\n") 495 | with io.open("dry_run_read.txt", "a", encoding="utf-8") as filen: 496 | filen.writelines(lines) 497 | 498 | scheduled = [f"{image_path}\n" for image_path in images] 499 | 500 | # Truncate the file from previous runs 501 | print("updating dry_run_read.txt") 502 | io.open("dry_run_read.txt", "w", encoding="utf-8").close() 503 | with io.open("dry_run_scheduled.txt", "w", encoding="utf-8") as filen: 504 | filen.writelines(scheduled) 505 | return process_images 506 | 507 | def run(self, images, pred_model) -> Tuple[str, Callable[[str], None]]: 508 | threshold = QData.threshold 509 | self.tags["sanitized_name"] = self.tags["name"].map( 510 | lambda i: i if i in Its.kaomojis else i.replace("_", " ") 511 | ) 512 | 513 | def process_images(filepaths, images): 514 | preds = pred_model(images).numpy() 515 | 516 | for ipath, pred in zip(filepaths, preds): 517 | ipath = ipath.numpy().decode("utf-8") 518 | 519 | self.tags["preds"] = pred 520 | generic = self.tags[self.tags["category"] == 0] 521 | chosen = generic[generic["preds"] > threshold] 522 | chosen = chosen.sort_values(by="preds", ascending=False) 523 | tags_names = chosen["sanitized_name"] 524 | 525 | key = ipath.split("/")[-1].split(".")[0] + "_" + self.name 526 | QData.add_tags = tags_names 527 | QData.apply_filters((ipath, '', {}, {}), key, False) 528 | 529 | tags_string = ", ".join(tags_names) 530 | txtfile = Path(ipath).with_suffix(".txt") 531 | with io.open(txtfile, "w", encoding="utf-8") as filename: 532 | filename.write(tags_string) 533 | return images, process_images 534 | 535 | def large_batch_interrogate(self, images, dry_run=True) -> None: 536 | """ Interrogate a large batch of images. 
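    A dry run only records the scheduled and processed file lists in
    dry_run_scheduled.txt and dry_run_read.txt; a real run writes a .txt
    tags file next to each image.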
""" 537 | 538 | # init model 539 | if not hasattr(self, 'model') or self.model is None: 540 | self.load() 541 | 542 | os.environ["TF_XLA_FLAGS"] = '--tf_xla_auto_jit=2 '\ 543 | '--tf_xla_cpu_global_jit' 544 | # Reduce logging 545 | # os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 546 | 547 | import tensorflow as tf 548 | 549 | from tagger.generator.tf_data_reader import DataGenerator 550 | 551 | # tensorflow maps nearly all vram by default, so we limit this 552 | # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth 553 | # TODO: only run on the first run 554 | gpus = tf.config.experimental.list_physical_devices("GPU") 555 | if gpus: 556 | for device in gpus: 557 | try: 558 | tf.config.experimental.set_memory_growth(device, True) 559 | except RuntimeError as err: 560 | print(err) 561 | 562 | if dry_run: # dry run 563 | height, width = 224, 224 564 | process_images = self.dry_run(images) 565 | else: 566 | _, height, width, _ = self.model.inputs[0].shape 567 | 568 | @tf.function 569 | def pred_model(model): 570 | return self.model(model, training=False) 571 | 572 | process_images = self.run(images, pred_model) 573 | 574 | generator = DataGenerator( 575 | file_list=images, target_height=height, target_width=width, 576 | batch_size=getattr(shared.opts, 'tagger_batch_size', 1024) 577 | ).gen_ds() 578 | 579 | orig_add_tags = QData.add_tags 580 | for filepaths, image_list in tqdm(generator): 581 | process_images(filepaths, image_list) 582 | QData.add_tag = orig_add_tags 583 | del os.environ["TF_XLA_FLAGS"] 584 | 585 | 586 | class MLDanbooruInterrogator(Interrogator): 587 | """ Interrogator for the MLDanbooru model. """ 588 | def __init__( 589 | self, 590 | name: str, 591 | repo_id: str, 592 | model_path: str, 593 | tags_path='classes.json', 594 | ) -> None: 595 | super().__init__(name) 596 | self.model_path = model_path 597 | self.tags_path = tags_path 598 | self.repo_id = repo_id 599 | self.tags = None 600 | self.model = None 601 | 602 | def download(self) -> Tuple[str, str]: 603 | print(f"Loading {self.name} model file from {self.repo_id}") 604 | cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) 605 | 606 | model_path = hf_hub_download( 607 | repo_id=self.repo_id, 608 | filename=self.model_path, 609 | cache_dir=cache 610 | ) 611 | tags_path = hf_hub_download( 612 | repo_id=self.repo_id, 613 | filename=self.tags_path, 614 | cache_dir=cache 615 | ) 616 | return model_path, tags_path 617 | 618 | def load(self) -> None: 619 | model_path, tags_path = self.download() 620 | 621 | ort = get_onnxrt() 622 | self.model = ort.InferenceSession(model_path, 623 | providers=onnxrt_providers) 624 | print(f'Loaded {self.name} model from {model_path}') 625 | 626 | with open(tags_path, 'r', encoding='utf-8') as filen: 627 | self.tags = json.load(filen) 628 | 629 | def interrogate( 630 | self, 631 | image: Image 632 | ) -> Tuple[ 633 | Dict[str, float], # rating confidents 634 | Dict[str, float] # tag confidents 635 | ]: 636 | # init model 637 | if self.model is None: 638 | self.load() 639 | 640 | image = dbimutils.fill_transparent(image) 641 | image = dbimutils.resize(image, 448) # TODO CUSTOMIZE 642 | 643 | x = asarray(image, dtype=float32) / 255 644 | # HWC -> 1CHW 645 | x = x.transpose((2, 0, 1)) 646 | x = expand_dims(x, 0) 647 | 648 | input_ = self.model.get_inputs()[0] 649 | output = self.model.get_outputs()[0] 650 | # evaluate model 651 | y, = self.model.run([output.name], {input_.name: x}) 652 | 653 | # Softmax 654 | y = 1 / (1 + exp(-y)) 655 | 656 | tags = {tag: float(conf) for tag, 
conf in zip(self.tags, y.flatten())} 657 | return {}, tags 658 | 659 | def large_batch_interrogate(self, images: List, dry_run=False) -> str: 660 | raise NotImplementedError() 661 | 662 | 663 | class Z3DInterrogator(Interrogator): 664 | """ Interrogator for Z3D Waifu Diffusion models """ 665 | def __init__( 666 | self, 667 | name: str, 668 | model_path='model.onnx', 669 | tags_path='tags-selected.csv', 670 | repo_id=None, 671 | is_hf=True, 672 | ) -> None: 673 | super().__init__(name) 674 | self.repo_id = repo_id 675 | self.model_path = model_path 676 | self.tags_path = tags_path 677 | self.tags = None 678 | self.model = None 679 | self.tags = None 680 | self.local_model = None 681 | self.local_tags = None 682 | self.is_hf = is_hf 683 | 684 | def download(self) -> None: 685 | mdir = Path(shared.models_path, 'interrogators') 686 | if self.is_hf: 687 | cache = getattr(shared.opts, 'tagger_hf_cache_dir', Its.hf_cache) 688 | print(f"Loading {self.name} model file from {self.repo_id}, " 689 | f"{self.model_path}") 690 | 691 | model_path = hf_hub_download( 692 | repo_id=self.repo_id, 693 | filename=self.model_path, 694 | cache_dir=cache) 695 | tags_path = hf_hub_download( 696 | repo_id=self.repo_id, 697 | filename=self.tags_path, 698 | cache_dir=cache) 699 | else: 700 | model_path = self.local_model 701 | tags_path = self.local_tags 702 | 703 | download_model = { 704 | 'name': self.name, 705 | 'model_path': model_path, 706 | 'tags_path': tags_path, 707 | } 708 | mpath = Path(mdir, 'model.json') 709 | 710 | data = [download_model] 711 | 712 | if not os.path.exists(mdir): 713 | os.mkdir(mdir) 714 | 715 | elif os.path.exists(mpath): 716 | with io.open(file=mpath, mode='r', encoding='utf-8') as filename: 717 | try: 718 | data = json.load(filename) 719 | # No need to append if it's already contained 720 | if download_model not in data: 721 | data.append(download_model) 722 | except json.JSONDecodeError as err: 723 | print(f'Adding download_model {mpath} raised {repr(err)}') 724 | data = [download_model] 725 | 726 | with io.open(mpath, 'w', encoding='utf-8') as filename: 727 | json.dump(data, filename) 728 | return model_path, tags_path 729 | 730 | def load(self) -> None: 731 | model_path, tags_path = self.download() 732 | ort = get_onnxrt() 733 | self.model = ort.InferenceSession(model_path, 734 | providers=onnxrt_providers) 735 | 736 | print(f'Loaded {self.name} model from {self.repo_id}') 737 | self.tags = read_csv(tags_path) 738 | 739 | def interrogate( 740 | self, 741 | image: Image 742 | ) -> Tuple[ 743 | Dict[str, float], # rating confidences 744 | Dict[str, float] # tag confidences 745 | ]: 746 | # init model 747 | if self.model is None: 748 | self.load() 749 | 750 | # code for converting the image and running the model is taken from the 751 | # link below. thanks, SmilingWolf! 
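        # (same preprocessing as WaifuDiffusionInterrogator below: fill
        # transparency with white, pad to a square, resize to the model's
        # input size, and flip RGB to BGR)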
752 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py 753 | 754 | # convert an image to fit the model 755 | _, height, _, _ = self.model.get_inputs()[0].shape 756 | 757 | # alpha to white 758 | image = dbimutils.fill_transparent(image) 759 | 760 | image = asarray(image) 761 | # PIL RGB to OpenCV BGR 762 | image = image[:, :, ::-1] 763 | 764 | tags = dict 765 | 766 | image = dbimutils.make_square(image, height) 767 | image = dbimutils.smart_resize(image, height) 768 | image = image.astype(float32) 769 | image = expand_dims(image, 0) 770 | 771 | # evaluate model 772 | input_name = self.model.get_inputs()[0].name 773 | label_name = self.model.get_outputs()[0].name 774 | confidences = self.model.run([label_name], {input_name: image})[0] 775 | 776 | tags = self.tags[:][['name']] 777 | tags['confidences'] = confidences[0] 778 | 779 | # first 4 items are for rating (Safe, Questionable, Explicit) 780 | ratings = dict(tags[:0].values) 781 | 782 | # rest are regular tags 783 | tags = dict(tags[0:].values) 784 | 785 | return ratings, tags 786 | 787 | def dry_run(self, images) -> Tuple[str, Callable[[str], None]]: 788 | 789 | def process_images(filepaths, _): 790 | lines = [] 791 | for image_path in filepaths: 792 | image_path = image_path.numpy().decode("utf-8") 793 | lines.append(f"{image_path}\n") 794 | with io.open("dry_run_read.txt", "a", encoding="utf-8") as filen: 795 | filen.writelines(lines) 796 | 797 | scheduled = [f"{image_path}\n" for image_path in images] 798 | 799 | # Truncate the file from previous runs 800 | print("updating dry_run_read.txt") 801 | io.open("dry_run_read.txt", "w", encoding="utf-8").close() 802 | with io.open("dry_run_scheduled.txt", "w", encoding="utf-8") as filen: 803 | filen.writelines(scheduled) 804 | return process_images 805 | 806 | def run(self, images, pred_model) -> Tuple[str, Callable[[str], None]]: 807 | threshold = QData.threshold 808 | self.tags["sanitized_name"] = self.tags["name"].map( 809 | lambda i: i if i in Its.kaomojis else i.replace("_", " ") 810 | ) 811 | 812 | def process_images(filepaths, images): 813 | preds = pred_model(images).numpy() 814 | 815 | for ipath, pred in zip(filepaths, preds): 816 | ipath = ipath.numpy().decode("utf-8") 817 | 818 | self.tags["preds"] = pred 819 | generic = self.tags[self.tags["rating"] == 0] 820 | chosen = generic[generic["preds"] > threshold] 821 | chosen = chosen.sort_values(by="preds", ascending=False) 822 | tags_names = chosen["sanitized_name"] 823 | 824 | key = ipath.split("/")[-1].split(".")[0] + "_" + self.name 825 | QData.add_tags = tags_names 826 | QData.apply_filters((ipath, '', {}, {}), key, False) 827 | 828 | tags_string = ", ".join(tags_names) 829 | txtfile = Path(ipath).with_suffix(".txt") 830 | with io.open(txtfile, "w", encoding="utf-8") as filename: 831 | filename.write(tags_string) 832 | return images, process_images 833 | 834 | def large_batch_interrogate(self, images, dry_run=True) -> None: 835 | """ Interrogate a large batch of images. 
""" 836 | 837 | # init model 838 | if not hasattr(self, 'model') or self.model is None: 839 | self.load() 840 | 841 | os.environ["TF_XLA_FLAGS"] = '--tf_xla_auto_jit=2 '\ 842 | '--tf_xla_cpu_global_jit' 843 | # Reduce logging 844 | # os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 845 | 846 | import tensorflow as tf 847 | 848 | from tagger.generator.tf_data_reader import DataGenerator 849 | 850 | # tensorflow maps nearly all vram by default, so we limit this 851 | # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth 852 | # TODO: only run on the first run 853 | gpus = tf.config.experimental.list_physical_devices("GPU") 854 | if gpus: 855 | for device in gpus: 856 | try: 857 | tf.config.experimental.set_memory_growth(device, True) 858 | except RuntimeError as err: 859 | print(err) 860 | 861 | if dry_run: # dry run 862 | height, width = 224, 224 863 | process_images = self.dry_run(images) 864 | else: 865 | _, height, width, _ = self.model.inputs[0].shape 866 | 867 | @tf.function 868 | def pred_model(model): 869 | return self.model(model, training=False) 870 | 871 | process_images = self.run(images, pred_model) 872 | 873 | generator = DataGenerator( 874 | file_list=images, target_height=height, target_width=width, 875 | batch_size=getattr(shared.opts, 'tagger_batch_size', 1024) 876 | ).gen_ds() 877 | 878 | orig_add_tags = QData.add_tags 879 | for filepaths, image_list in tqdm(generator): 880 | process_images(filepaths, image_list) 881 | QData.add_tag = orig_add_tags 882 | del os.environ["TF_XLA_FLAGS"] -------------------------------------------------------------------------------- /tagger/preset.py: -------------------------------------------------------------------------------- 1 | """Module for Tagger, to save and load presets.""" 2 | import os 3 | import json 4 | 5 | from typing import Tuple, List, Dict 6 | from pathlib import Path 7 | from gradio.context import Context 8 | from modules.images import sanitize_filename_part # pylint: disable=E0401 9 | 10 | PresetDict = Dict[str, Dict[str, any]] 11 | 12 | 13 | class Preset: 14 | """Preset class for Tagger, to save and load presets.""" 15 | base_dir: Path 16 | default_filename: str 17 | default_values: PresetDict 18 | components: List[object] 19 | 20 | def __init__( 21 | self, 22 | base_dir: os.PathLike, 23 | default_filename='default.json' 24 | ) -> None: 25 | self.base_dir = Path(base_dir) 26 | self.default_filename = default_filename 27 | self.default_values = self.load(default_filename)[1] 28 | self.components = [] 29 | 30 | def component(self, component_class: object, **kwargs) -> object: 31 | # find all the top components from the Gradio context and create a path 32 | parent = Context.block 33 | paths = [kwargs['label']] 34 | 35 | while parent is not None: 36 | if hasattr(parent, 'label'): 37 | paths.insert(0, parent.label) 38 | 39 | parent = parent.parent 40 | 41 | path = '/'.join(paths) 42 | 43 | component = component_class(**{ 44 | **kwargs, 45 | **self.default_values.get(path, {}) 46 | }) 47 | 48 | component.path = path 49 | 50 | self.components.append(component) 51 | return component 52 | 53 | def load(self, filename: str) -> Tuple[str, PresetDict]: 54 | if not filename.endswith('.json'): 55 | filename += '.json' 56 | 57 | path = self.base_dir.joinpath(sanitize_filename_part(filename)) 58 | configs = {} 59 | 60 | if path.is_file(): 61 | configs = json.loads(path.read_text(encoding='utf-8')) 62 | 63 | return path, configs 64 | 65 | def save(self, filename: str, *values) -> Tuple: 66 | path, configs = self.load(filename) 67 | 
68 | for index, component in enumerate(self.components): 69 | config = configs.get(component.path, {}) 70 | config['value'] = values[index] 71 | 72 | for attr in ['visible', 'min', 'max', 'step']: 73 | if hasattr(component, attr): 74 | config[attr] = config.get(attr, getattr(component, attr)) 75 | 76 | configs[component.path] = config 77 | 78 | self.base_dir.mkdir(0o777, True, True) 79 | path.write_text(json.dumps(configs, indent=4), encoding='utf-8') 80 | 81 | return 'successfully saved the preset' 82 | 83 | def apply(self, filename: str) -> Tuple: 84 | values = self.load(filename)[1] 85 | outputs = [] 86 | 87 | for component in self.components: 88 | config = values.get(component.path, {}) 89 | 90 | if 'value' in config and hasattr(component, 'choices'): 91 | if config['value'] not in component.choices: 92 | config['value'] = None 93 | 94 | outputs.append(component.update(**config)) 95 | 96 | return (*outputs, 'successfully loaded the preset') 97 | 98 | def list(self) -> List[str]: 99 | presets = [ 100 | p.name 101 | for p in self.base_dir.glob('*.json') 102 | if p.is_file() 103 | ] 104 | 105 | if len(presets) < 1: 106 | presets.append(self.default_filename) 107 | 108 | return presets 109 | -------------------------------------------------------------------------------- /tagger/settings.py: -------------------------------------------------------------------------------- 1 | """Settings tab entries for the tagger module""" 2 | import os 3 | from typing import List 4 | from modules import shared # pylint: disable=import-error 5 | import gradio as gr 6 | 7 | # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400! 8 | DEFAULT_KAMOJIS = '0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||' # pylint: disable=line-too-long # noqa: E501 9 | 10 | DEFAULT_OFF = '[name].[output_extension]' 11 | 12 | HF_CACHE = os.environ.get('HF_HOME', os.environ.get('HUGGINGFACE_HUB_CACHE', 13 | str(os.path.join(shared.models_path, 'interrogators')))) 14 | 15 | def slider_wrapper(value, elem_id, **kwargs): 16 | # required or else gradio will throw errors 17 | return gr.Slider(**kwargs) 18 | 19 | 20 | def on_ui_settings(): 21 | """Called when the UI settings tab is opened""" 22 | Its = InterrogatorSettings 23 | section = 'tagger', 'Tagger' 24 | shared.opts.add_option( 25 | key='tagger_out_filename_fmt', 26 | info=shared.OptionInfo( 27 | DEFAULT_OFF, 28 | label='Tag file output format. Leave blank to use same filename or' 29 | ' e.g. "[name].[hash:sha1].[output_extension]". 
Also allowed are ' 30 | '[extension] or any other [hash:] supported by hashlib', 31 | section=section, 32 | ), 33 | ) 34 | shared.opts.onchange( 35 | key='tagger_out_filename_fmt', 36 | func=Its.set_output_filename_format 37 | ) 38 | shared.opts.add_option( 39 | key='tagger_count_threshold', 40 | info=shared.OptionInfo( 41 | 100.0, 42 | label="Maximum number of tags to be shown in the UI", 43 | section=section, 44 | component=slider_wrapper, 45 | component_args={"minimum": 1.0, "maximum": 500.0, "step": 1.0}, 46 | ), 47 | ) 48 | shared.opts.add_option( 49 | key='tagger_batch_recursive', 50 | info=shared.OptionInfo( 51 | True, 52 | label='Glob recursively with input directory pattern', 53 | section=section, 54 | ), 55 | ) 56 | shared.opts.add_option( 57 | key='tagger_auto_serde_json', 58 | info=shared.OptionInfo( 59 | True, 60 | label='Auto load and save JSON database', 61 | section=section, 62 | ), 63 | ) 64 | shared.opts.add_option( 65 | key='tagger_store_images', 66 | info=shared.OptionInfo( 67 | False, 68 | label='Store images in database', 69 | section=section, 70 | ), 71 | ) 72 | shared.opts.add_option( 73 | key='tagger_weighted_tags_files', 74 | info=shared.OptionInfo( 75 | False, 76 | label='Write weights to tags files', 77 | section=section, 78 | ), 79 | ) 80 | shared.opts.add_option( 81 | key='tagger_verbose', 82 | info=shared.OptionInfo( 83 | False, 84 | label='Console log tag counts per file, no progress bar', 85 | section=section, 86 | ), 87 | ) 88 | shared.opts.add_option( 89 | key='tagger_repl_us', 90 | info=shared.OptionInfo( 91 | True, 92 | label='Use spaces instead of underscore', 93 | section=section, 94 | ), 95 | ) 96 | shared.opts.add_option( 97 | key='tagger_repl_us_excl', 98 | info=shared.OptionInfo( 99 | DEFAULT_KAMOJIS, 100 | label='Underscore replacement excludes (split by comma)', 101 | section=section, 102 | ), 103 | ) 104 | shared.opts.onchange( 105 | key='tagger_repl_us_excl', 106 | func=Its.set_us_excl 107 | ) 108 | shared.opts.add_option( 109 | key='tagger_escape', 110 | info=shared.OptionInfo( 111 | False, 112 | label='Escape brackets', 113 | section=section, 114 | ), 115 | ) 116 | shared.opts.add_option( 117 | key='tagger_batch_size', 118 | info=shared.OptionInfo( 119 | 1024, 120 | label='batch size for large queries', 121 | section=section, 122 | ), 123 | ) 124 | # see huggingface_hub guides/manage-cache 125 | shared.opts.add_option( 126 | key='tagger_hf_cache_dir', 127 | info=shared.OptionInfo( 128 | HF_CACHE, 129 | label='HuggingFace cache directory, ' 130 | 'see huggingface_hub guides/manage-cache', 131 | section=section, 132 | ), 133 | ) 134 | 135 | 136 | def split_str(string: str, separator=',') -> List[str]: 137 | return [x.strip() for x in string.split(separator) if x] 138 | 139 | 140 | class InterrogatorSettings: 141 | kamojis = set(split_str(DEFAULT_KAMOJIS)) 142 | output_filename_format = DEFAULT_OFF 143 | hf_cache = HF_CACHE 144 | 145 | @classmethod 146 | def set_us_excl(cls): 147 | ruxs = getattr(shared.opts, 'tagger_repl_us_excl', DEFAULT_KAMOJIS) 148 | cls.kamojis = set(split_str(ruxs)) 149 | 150 | @classmethod 151 | def set_output_filename_format(cls): 152 | fnfmt = getattr(shared.opts, 'tagger_out_filename_fmt', DEFAULT_OFF) 153 | if fnfmt[-12:] == '.[extension]': 154 | print("refused to write an image extension") 155 | fnfmt = fnfmt[:-12] + '.[output_extension]' 156 | 157 | cls.output_filename_format = fnfmt.strip() 158 | -------------------------------------------------------------------------------- /tagger/ui.py: 
-------------------------------------------------------------------------------- 1 | """ This module contains the ui for the tagger tab. """ 2 | from typing import Dict, Tuple, List, Optional 3 | import gradio as gr 4 | import re 5 | import json 6 | from pathlib import Path 7 | from PIL import Image 8 | from packaging import version 9 | 10 | try: 11 | from tensorflow import __version__ as tf_version 12 | except ImportError: 13 | tf_version = '0.0.0' 14 | 15 | from html import escape as html_esc 16 | 17 | from modules import ui # pylint: disable=import-error 18 | from modules import generation_parameters_copypaste as parameters_copypaste # pylint: disable=import-error # noqa 19 | 20 | try: 21 | from modules.call_queue import wrap_gradio_gpu_call 22 | except ImportError: 23 | from webui import wrap_gradio_gpu_call # pylint: disable=import-error 24 | from tagger import utils # pylint: disable=import-error 25 | from tagger import settings 26 | from tagger.interrogator import Interrogator as It # pylint: disable=E0401 27 | from tagger.uiset import IOData, QData # pylint: disable=import-error 28 | 29 | TAG_INPUTS = ["add", "keep", "exclude", "search", "replace"] 30 | COMMON_OUTPUT = Tuple[ 31 | Optional[str], # tags as string 32 | Optional[str], # html tags as string 33 | Optional[str], # discarded tags as string 34 | Optional[Dict[str, float]], # rating confidences 35 | Optional[Dict[str, float]], # tag confidences 36 | Optional[Dict[str, float]], # excluded tag confidences 37 | str, # error message 38 | ] 39 | 40 | class GalleryState: 41 | selected_index: Optional[int] = None 42 | 43 | def unload_interrogators() -> Tuple[str]: 44 | unloaded_models = 0 45 | remaining_models = '' 46 | 47 | for i in utils.interrogators.values(): 48 | if i.unload(): 49 | unloaded_models = unloaded_models + 1 50 | elif i.model is not None: 51 | if remaining_models == '': 52 | remaining_models = f', remaining models:
<ul><li>{i.name}</li><li>' 53 | else: 54 | remaining_models = remaining_models + f'<li>{i.name}</li><li>
    • ' 55 | if remaining_models != '': 56 | remaining_models = remaining_models + "Some tensorflow models could "\ 57 | "not be unloaded, a known issue." 58 | QData.clear(1) 59 | 60 | return (f'{unloaded_models} model(s) unloaded{remaining_models}',) 61 | 62 | 63 | def on_interrogate( 64 | input_glob: str, output_dir: str, name: str, filt: str, *args 65 | ) -> COMMON_OUTPUT: 66 | # input glob should always be rechecked for new files 67 | IOData.update_input_glob(input_glob) 68 | if output_dir != It.input["output_dir"]: 69 | IOData.update_output_dir(output_dir) 70 | It.input["output_dir"] = output_dir 71 | 72 | if len(IOData.err) > 0: 73 | return (None,) * 6 + (IOData.error_msg(),) 74 | 75 | for i, val in enumerate(args): 76 | part = TAG_INPUTS[i] 77 | if val != It.input[part]: 78 | getattr(QData, "update_" + part)(val) 79 | It.input[part] = val 80 | 81 | interrogator: It = next((i for i in utils.interrogators.values() if 82 | i.name == name), None) 83 | if interrogator is None: 84 | return (None,) * 6 + (f"'{name}': invalid interrogator",) 85 | 86 | interrogator.batch_interrogate() 87 | return search_filter(filt) 88 | 89 | 90 | def on_gallery() -> List: 91 | return QData.get_image_dups() 92 | 93 | 94 | def on_interrogate_image(*args) -> COMMON_OUTPUT: 95 | # hack brcause image interrogaion occurs twice 96 | It.odd_increment = It.odd_increment + 1 97 | if It.odd_increment & 1 == 1: 98 | return (None,) * 6 + ('',) 99 | return on_interrogate_image_submit(*args) 100 | 101 | 102 | def on_interrogate_image_submit( 103 | image: Image, name: str, filt: str, *args 104 | ) -> COMMON_OUTPUT: 105 | for i, val in enumerate(args): 106 | part = TAG_INPUTS[i] 107 | if val != It.input[part]: 108 | getattr(QData, "update_" + part)(val) 109 | It.input[part] = val 110 | 111 | if image is None: 112 | return (None,) * 6 + ('No image selected',) 113 | interrogator: It = next((i for i in utils.interrogators.values() if 114 | i.name == name), None) 115 | if interrogator is None: 116 | return (None,) * 6 + (f"'{name}': invalid interrogator",) 117 | 118 | interrogator.interrogate_image(image) 119 | return search_filter(filt) 120 | 121 | 122 | def move_selection_to_input( 123 | filt: str, field: str 124 | ) -> Tuple[Optional[str], Optional[str], str]: 125 | """ moves the selected to the input field """ 126 | if It.output is None: 127 | return (None, None, '') 128 | tags = It.output[1] 129 | got = It.input[field] 130 | existing = set(got.split(', ')) 131 | if filt: 132 | re_part = re.compile('(' + re.sub(', ?', '|', filt) + ')') 133 | tags = {k: v for k, v in tags.items() if re_part.search(k) and 134 | k not in existing} 135 | print("Tags remaining: ", tags) 136 | 137 | if len(tags) == 0: 138 | return ('', None, '') 139 | 140 | if got != '': 141 | got = got + ', ' 142 | 143 | (data, info) = It.set(field)(got + ', '.join(tags.keys())) 144 | return ('', data, info) 145 | 146 | 147 | def move_selection_to_keep( 148 | tag_search_filter: str 149 | ) -> Tuple[Optional[str], Optional[str], str]: 150 | return move_selection_to_input(tag_search_filter, "keep") 151 | 152 | 153 | def move_selection_to_exclude( 154 | tag_search_filter: str 155 | ) -> Tuple[Optional[str], Optional[str], str]: 156 | return move_selection_to_input(tag_search_filter, "exclude") 157 | 158 | 159 | def search_filter(filt: str) -> COMMON_OUTPUT: 160 | """ filters the tags and lost tags for the search field """ 161 | ratings, tags, lost, info = It.output 162 | if ratings is None: 163 | return (None,) * 6 + (info,) 164 | if filt: 165 | re_part = 
re.compile('(' + re.sub(', ?', '|', filt) + ')') 166 | tags = {k: v for k, v in tags.items() if re_part.search(k)} 167 | lost = {k: v for k, v in lost.items() if re_part.search(k)} 168 | 169 | h_tags = ', '.join(f'{k}' for k in tags.keys()) 171 | h_lost = ', '.join(f'{k}' for k in lost.keys()) 173 | 174 | return (', '.join(tags.keys()), h_tags, h_lost, ratings, tags, lost, info) 175 | 176 | def update_file_tags(file_path: str, selected_tags: List[str]) -> None: 177 | """Updates the tags file for a given image with the selected tags.""" 178 | txt_path = Path(file_path).with_suffix('.txt') 179 | if txt_path.exists(): 180 | txt_path.write_text(', '.join(selected_tags), encoding='utf-8') 181 | 182 | def get_file_tags(file_path: str) -> List[str]: 183 | """Gets the existing tags from a file's associated txt file.""" 184 | txt_path = Path(file_path).with_suffix('.txt') 185 | if txt_path.exists(): 186 | content = txt_path.read_text(encoding='utf-8') 187 | return [tag.strip() for tag in content.split(',') if tag.strip()] 188 | return [] 189 | 190 | def parse_weight(weight_str: float) -> Tuple[int, float]: 191 | """Parses a weight string into image_id and actual weight.""" 192 | image_id = int(weight_str) 193 | weight = weight_str - image_id 194 | return image_id, weight * 100 # Convert to percentage 195 | 196 | def get_image_tags(image_id: int) -> Dict[str, float]: 197 | """Gets all tags and their weights for a specific image from db.json.""" 198 | if not QData.json_db or not QData.json_db.exists(): 199 | return {} 200 | 201 | try: 202 | data = json.loads(QData.json_db.read_text()) 203 | tags_data = data.get("tag", {}) 204 | image_tags = {} 205 | 206 | for tag, weights in tags_data.items(): 207 | for weight in weights: 208 | parsed_id, parsed_weight = parse_weight(weight) 209 | if parsed_id == image_id: 210 | corrected_tag = QData.correct_tag(tag) 211 | image_tags[corrected_tag] = parsed_weight 212 | 213 | return image_tags 214 | 215 | except (json.JSONDecodeError, AttributeError) as e: 216 | print(f"Error reading db.json: {e}") 217 | return {} 218 | 219 | def get_image_id_from_path(file_path: str) -> Optional[int]: 220 | """Gets the image ID from the query section of db.json.""" 221 | if not QData.json_db or not QData.json_db.exists(): 222 | return None 223 | 224 | try: 225 | data = json.loads(QData.json_db.read_text()) 226 | queries = data.get("query", {}) 227 | 228 | # Search for the file path in queries 229 | for _, query_data in queries.items(): 230 | if query_data[0] == str(Path(file_path).absolute()): 231 | return query_data[1] 232 | return None 233 | 234 | except (json.JSONDecodeError, AttributeError) as e: 235 | print(f"Error reading db.json: {e}") 236 | return None 237 | 238 | def get_sorted_tags(file_path: str) -> List[str]: 239 | """Gets all tags for an image, sorted by weight.""" 240 | image_id = get_image_id_from_path(file_path) 241 | if image_id is None: 242 | return [] 243 | 244 | # Get tags from db.json for this image 245 | image_tags = get_image_tags(image_id) 246 | 247 | # Get tags from the txt file 248 | file_tags = set(get_file_tags(file_path)) 249 | 250 | # Sort tags by weight 251 | sorted_tags = sorted( 252 | image_tags.items(), 253 | key=lambda x: (-x[1], x[0]) # Sort by weight desc, then tag name asc 254 | ) 255 | 256 | # First add tags that are in the txt file 257 | result = [tag for tag in sorted_tags if tag[0] in file_tags] 258 | 259 | # Then add tags that are not in the txt file 260 | result.extend(tag for tag in sorted_tags if tag[0] not in file_tags) 261 | 262 | 
return [tag[0] for tag in result] 263 | 264 | def on_gallery_select(evt: gr.SelectData, state: gr.State) -> tuple: 265 | """Handler for gallery selection event.""" 266 | image_paths = QData.get_image_dups() 267 | if not image_paths or evt.index >= len(image_paths): 268 | return gr.CheckboxGroup.update(choices=[], value=[], label="No image selected"), None 269 | 270 | selected_path = image_paths[evt.index] 271 | file_tags = get_file_tags(selected_path) 272 | all_tags = get_sorted_tags(selected_path) 273 | 274 | return ( 275 | gr.CheckboxGroup.update( 276 | choices=all_tags, 277 | value=file_tags, 278 | label=f"Tags for {Path(selected_path).name}" 279 | ), 280 | evt.index 281 | ) 282 | 283 | def on_tags_change(selected_tags: List[str], state: gr.State) -> None: 284 | """Handler for checkbox group changes.""" 285 | if state is None: 286 | return 287 | 288 | image_paths = QData.get_image_dups() 289 | if not image_paths or state >= len(image_paths): 290 | return 291 | 292 | selected_path = image_paths[state] 293 | update_file_tags(selected_path, selected_tags) 294 | 295 | def create_gallery_ui(tab_gallery): 296 | """Creates the gallery UI components.""" 297 | with tab_gallery: 298 | selected_index = gr.State(None) 299 | 300 | # Create a container div for consistent height 301 | with gr.Box(elem_id="gallery-container"): 302 | # Use Row for main layout 303 | with gr.Row(): 304 | # Left column with gallery 305 | with gr.Column(scale=1): 306 | gallery = gr.Gallery( 307 | label='Gallery', 308 | elem_id='gallery', 309 | object_fit="contain", 310 | height="800px", 311 | show_label=False 312 | ) 313 | 314 | # Right column with tag editor 315 | with gr.Column(scale=1): 316 | tag_editor = gr.CheckboxGroup( 317 | label="Select image to edit tags", 318 | choices=[], 319 | value=[], 320 | interactive=True, 321 | container=True, 322 | elem_id="tag-editor" 323 | ) 324 | 325 | # Connect the components 326 | gallery.select( 327 | fn=on_gallery_select, 328 | inputs=[selected_index], 329 | outputs=[tag_editor, selected_index] 330 | ) 331 | 332 | tag_editor.change( 333 | fn=on_tags_change, 334 | inputs=[tag_editor, selected_index], 335 | outputs=[] 336 | ) 337 | 338 | return gallery 339 | 340 | def on_ui_tabs(): 341 | """ configures the ui on the tagger tab """ 342 | # If checkboxes misbehave you have to adapt the default.json preset 343 | tag_input = {} 344 | 345 | with gr.Blocks(analytics_enabled=False) as tagger_interface: 346 | with gr.Tabs(): 347 | with gr.TabItem("Tag Generation"): 348 | with gr.Row(): 349 | with gr.Column(variant='panel'): 350 | 351 | # input components 352 | with gr.Tabs(): 353 | with gr.TabItem(label='Single process'): 354 | image = gr.Image( 355 | label='Source', 356 | source='upload', 357 | interactive=True, 358 | type="pil" 359 | ) 360 | image_submit = gr.Button( 361 | value='Interrogate image', 362 | variant='primary' 363 | ) 364 | 365 | with gr.TabItem(label='Batch from directory'): 366 | input_glob = utils.preset.component( 367 | gr.Textbox, 368 | value='', 369 | label='Input directory - To recurse use ** or */* ' 370 | 'in your glob; also check the settings tab.', 371 | placeholder='/path/to/images or to/images/**/*' 372 | ) 373 | output_dir = utils.preset.component( 374 | gr.Textbox, 375 | value=It.input["output_dir"], 376 | label='Output directory', 377 | placeholder='Leave blank to save images ' 378 | 'to the same path.' 
379 | ) 380 | 381 | batch_submit = gr.Button( 382 | value='Interrogate', 383 | variant='primary' 384 | ) 385 | with gr.Row(variant='compact'): 386 | with gr.Column(variant='panel'): 387 | large_query = utils.preset.component( 388 | gr.Checkbox, 389 | label='huge batch query (TF 2.10, ' 390 | 'experimental)', 391 | value=False, 392 | interactive=version.parse(tf_version) == 393 | version.parse('2.10') 394 | ) 395 | with gr.Column(variant='panel'): 396 | save_tags = utils.preset.component( 397 | gr.Checkbox, 398 | label='Save to tags files', 399 | value=True 400 | ) 401 | 402 | info = gr.HTML( 403 | label='Info', 404 | interactive=False, 405 | elem_classes=['info'] 406 | ) 407 | 408 | # interrogator selector 409 | with gr.Column(): 410 | # preset selector 411 | with gr.Row(variant='compact'): 412 | available_presets = utils.preset.list() 413 | selected_preset = gr.Dropdown( 414 | label='Preset', 415 | choices=available_presets, 416 | value=available_presets[0] 417 | ) 418 | 419 | save_preset_button = gr.Button( 420 | value=ui.save_style_symbol 421 | ) 422 | 423 | ui.create_refresh_button( 424 | selected_preset, 425 | lambda: None, 426 | lambda: {'choices': utils.preset.list()}, 427 | 'refresh_preset' 428 | ) 429 | 430 | with gr.Row(variant='compact'): 431 | def refresh(): 432 | utils.refresh_interrogators() 433 | return sorted(x.name for x in utils.interrogators 434 | .values()) 435 | interrogator_names = refresh() 436 | interrogator = utils.preset.component( 437 | gr.Dropdown, 438 | label='Interrogator', 439 | choices=interrogator_names, 440 | value=( 441 | None 442 | if len(interrogator_names) < 1 else 443 | interrogator_names[-1] 444 | ) 445 | ) 446 | 447 | ui.create_refresh_button( 448 | interrogator, 449 | lambda: None, 450 | lambda: {'choices': refresh()}, 451 | 'refresh_interrogator' 452 | ) 453 | 454 | unload_all_models = gr.Button( 455 | value='Unload all interrogate models' 456 | ) 457 | with gr.Row(variant='compact'): 458 | tag_input["add"] = utils.preset.component( 459 | gr.Textbox, 460 | label='Additional tags (comma split)', 461 | elem_id='additional-tags' 462 | ) 463 | with gr.Row(variant='compact'): 464 | threshold = utils.preset.component( 465 | gr.Slider, 466 | label='Weight threshold', 467 | minimum=0, 468 | maximum=1, 469 | value=QData.threshold 470 | ) 471 | tag_frac_threshold = utils.preset.component( 472 | gr.Slider, 473 | label='Min tag fraction in batch and ' 474 | 'interrogations', 475 | minimum=0, 476 | maximum=1, 477 | value=QData.tag_frac_threshold, 478 | ) 479 | with gr.Row(variant='compact'): 480 | cumulative = utils.preset.component( 481 | gr.Checkbox, 482 | label='Combine interrogations', 483 | value=False 484 | ) 485 | unload_after = utils.preset.component( 486 | gr.Checkbox, 487 | label='Unload model after running', 488 | value=False 489 | ) 490 | with gr.Row(variant='compact'): 491 | tag_input["search"] = utils.preset.component( 492 | gr.Textbox, 493 | label='Search tag, .. 
->', 494 | elem_id='search-tags' 495 | ) 496 | tag_input["replace"] = utils.preset.component( 497 | gr.Textbox, 498 | label='-> Replace tag, ..', 499 | elem_id='replace-tags' 500 | ) 501 | with gr.Row(variant='compact'): 502 | tag_input["keep"] = utils.preset.component( 503 | gr.Textbox, 504 | label='Keep tag, ..', 505 | elem_id='keep-tags' 506 | ) 507 | tag_input["exclude"] = utils.preset.component( 508 | gr.Textbox, 509 | label='Exclude tag, ..', 510 | elem_id='exclude-tags' 511 | ) 512 | 513 | # output components 514 | with gr.Column(variant='panel'): 515 | with gr.Row(variant='compact'): 516 | with gr.Column(variant='compact'): 517 | mv_selection_to_keep = gr.Button( 518 | value='Move visible tags to keep tags', 519 | variant='secondary' 520 | ) 521 | mv_selection_to_exclude = gr.Button( 522 | value='Move visible tags to exclude tags', 523 | variant='secondary' 524 | ) 525 | with gr.Column(variant='compact'): 526 | tag_search_selection = utils.preset.component( 527 | gr.Textbox, 528 | label='Multi string search: part1, part2.. ' 529 | '(Enter key to update)', 530 | ) 531 | with gr.Tabs(): 532 | with gr.TabItem(label='Ratings and included tags'): 533 | # clickable tags to populate excluded tags 534 | tags = gr.State(value="") 535 | html_tags = gr.HTML( 536 | label='Tags', 537 | elem_id='tags', 538 | ) 539 | 540 | with gr.Row(): 541 | parameters_copypaste.bind_buttons( 542 | parameters_copypaste.create_buttons( 543 | ["txt2img", "img2img"], 544 | ), 545 | None, 546 | tags 547 | ) 548 | rating_confidences = gr.Label( 549 | label='Rating confidences', 550 | elem_id='rating-confidences', 551 | ) 552 | tag_confidences = gr.Label( 553 | label='Tag confidences', 554 | elem_id='tag-confidences', 555 | ) 556 | with gr.TabItem(label='Excluded tags'): 557 | # clickable tags to populate keep tags 558 | discarded_tags = gr.HTML( 559 | label='Tags', 560 | elem_id='tags', 561 | ) 562 | excluded_tag_confidences = gr.Label( 563 | label='Excluded Tag confidences', 564 | elem_id='discard-tag-confidences', 565 | ) 566 | tab_gallery = gr.TabItem(label='Tag Curation') 567 | gallery = create_gallery_ui(tab_gallery) 568 | 569 | 570 | # register events 571 | # Checkboxes 572 | cumulative.input(fn=It.flip('cumulative'), inputs=[], outputs=[]) 573 | large_query.input(fn=It.flip('large_query'), inputs=[], outputs=[]) 574 | unload_after.input(fn=It.flip('unload_after'), inputs=[], outputs=[]) 575 | 576 | save_tags.input(fn=IOData.flip_save_tags(), inputs=[], outputs=[]) 577 | 578 | # Preset and unload buttons 579 | selected_preset.change(fn=utils.preset.apply, inputs=[selected_preset], 580 | outputs=[*utils.preset.components, info]) 581 | 582 | save_preset_button.click(fn=utils.preset.save, inputs=[selected_preset, 583 | *utils.preset.components], outputs=[info]) 584 | 585 | unload_all_models.click(fn=unload_interrogators, outputs=[info]) 586 | 587 | # Sliders 588 | threshold.input(fn=QData.set("threshold"), inputs=[threshold], 589 | outputs=[]) 590 | threshold.release(fn=QData.set("threshold"), inputs=[threshold], 591 | outputs=[]) 592 | 593 | tag_frac_threshold.input(fn=QData.set("tag_frac_threshold"), 594 | inputs=[tag_frac_threshold], outputs=[]) 595 | tag_frac_threshold.release(fn=QData.set("tag_frac_threshold"), 596 | inputs=[tag_frac_threshold], outputs=[]) 597 | 598 | # Input textboxes (blur == lose focus) 599 | for tag in TAG_INPUTS: 600 | tag_input[tag].blur(fn=wrap_gradio_gpu_call(It.set(tag)), 601 | inputs=[tag_input[tag]], 602 | outputs=[tag_input[tag], info]) 603 | 604 | 
input_glob.blur(fn=wrap_gradio_gpu_call(It.set("input_glob")), 605 | inputs=[input_glob], outputs=[input_glob, info]) 606 | output_dir.blur(fn=wrap_gradio_gpu_call(It.set("output_dir")), 607 | inputs=[output_dir], outputs=[output_dir, info]) 608 | 609 | tab_gallery.select(fn=on_gallery, inputs=[], outputs=[gallery]) 610 | 611 | common_output = [tags, html_tags, discarded_tags, rating_confidences, 612 | tag_confidences, excluded_tag_confidences, info] 613 | 614 | # search input textbox 615 | for fun in [tag_search_selection.change, tag_search_selection.submit]: 616 | fun(fn=wrap_gradio_gpu_call(search_filter), 617 | inputs=[tag_search_selection], outputs=common_output) 618 | 619 | # buttons to move tags (right) 620 | mv_selection_to_keep.click( 621 | fn=wrap_gradio_gpu_call(move_selection_to_keep), 622 | inputs=[tag_search_selection], 623 | outputs=[tag_search_selection, tag_input["keep"], info]) 624 | 625 | mv_selection_to_exclude.click( 626 | fn=wrap_gradio_gpu_call(move_selection_to_exclude), 627 | inputs=[tag_search_selection], 628 | outputs=[tag_search_selection, tag_input["exclude"], info]) 629 | 630 | common_input = [interrogator, tag_search_selection] + \ 631 | [tag_input[tag] for tag in TAG_INPUTS] 632 | 633 | # interrogation events 634 | image_submit.click(fn=wrap_gradio_gpu_call(on_interrogate_image_submit), 635 | inputs=[image] + common_input, outputs=common_output) 636 | 637 | image.change(fn=wrap_gradio_gpu_call(on_interrogate_image), 638 | inputs=[image] + common_input, outputs=common_output) 639 | 640 | batch_submit.click(fn=wrap_gradio_gpu_call(on_interrogate), 641 | inputs=[input_glob, output_dir] + common_input, 642 | outputs=common_output) 643 | 644 | return [(tagger_interface, "Tagger", "tagger")] 645 | -------------------------------------------------------------------------------- /tagger/uiset.py: -------------------------------------------------------------------------------- 1 | """ for handling ui settings """ 2 | 3 | from typing import List, Dict, Tuple, Callable, Set, Optional 4 | import os 5 | from pathlib import Path 6 | from glob import glob 7 | from math import ceil 8 | from hashlib import sha256 9 | from re import compile as re_comp, sub as re_sub, match as re_match, IGNORECASE 10 | from json import dumps, loads 11 | from jsonschema import validate, ValidationError 12 | from functools import partial 13 | from collections import defaultdict 14 | from PIL import Image 15 | 16 | from modules import shared # pylint: disable=import-error 17 | from modules.deepbooru import re_special # pylint: disable=import-error 18 | from tagger import format as tags_format # pylint: disable=import-error 19 | from tagger import settings # pylint: disable=import-error 20 | 21 | Its = settings.InterrogatorSettings 22 | 23 | # PIL.Image.registered_extensions() returns only PNG if you call early 24 | supported_extensions = { 25 | e 26 | for e, f in Image.registered_extensions().items() 27 | if f in Image.OPEN 28 | } 29 | 30 | # interrogator return type 31 | ItRetTP = Tuple[ 32 | Dict[str, float], # rating confidences 33 | Dict[str, float], # tag confidences 34 | Dict[str, float], # excluded tag confidences 35 | str, # error message 36 | ] 37 | 38 | 39 | class IOData: 40 | """ data class for input and output paths """ 41 | last_path_mtimes = None 42 | base_dir = None 43 | output_root = None 44 | paths: List[List[str]] = [] 45 | save_tags = True 46 | err: Set[str] = set() 47 | 48 | @classmethod 49 | def error_msg(cls) -> str: 50 | return "Errors:
        " + ''.join(f'
<li>{x}</li>' for x in cls.err) + \
 51 |                "</ul>
      " 52 | 53 | @classmethod 54 | def flip_save_tags(cls) -> callable: 55 | def toggle(): 56 | cls.save_tags = not cls.save_tags 57 | return toggle 58 | 59 | @classmethod 60 | def toggle_save_tags(cls) -> None: 61 | cls.save_tags = not cls.save_tags 62 | 63 | @classmethod 64 | def update_output_dir(cls, output_dir: str) -> None: 65 | """ update output directory, and set input and output paths """ 66 | pout = Path(output_dir) 67 | if pout != cls.output_root: 68 | paths = [x[0] for x in cls.paths] 69 | cls.paths = [] 70 | cls.output_root = pout 71 | cls.set_batch_io(paths) 72 | 73 | @staticmethod 74 | def get_bytes_hash(data) -> str: 75 | """ get sha256 checksum of file """ 76 | # Note: the checksum from an image is not the same as from file 77 | return sha256(data).hexdigest() 78 | 79 | @classmethod 80 | def get_hashes(cls) -> Set[str]: 81 | """ get hashes of all files """ 82 | ret = set() 83 | for entries in cls.paths: 84 | if len(entries) == 4: 85 | ret.add(entries[3]) 86 | else: 87 | # if there is no checksum, calculate it 88 | image = Image.open(entries[0]) 89 | checksum = cls.get_bytes_hash(image.tobytes()) 90 | entries.append(checksum) 91 | ret.add(checksum) 92 | return ret 93 | 94 | @classmethod 95 | def update_input_glob(cls, input_glob: str) -> None: 96 | """ update input glob pattern, and set input and output paths """ 97 | input_glob = input_glob.strip() 98 | 99 | paths = [] 100 | 101 | # if there is no glob pattern, insert it automatically 102 | if not input_glob.endswith('*'): 103 | if not input_glob.endswith(os.sep): 104 | input_glob += os.sep 105 | input_glob += '*' 106 | 107 | # get root directory of input glob pattern 108 | base_dir = input_glob.replace('?', '*') 109 | base_dir = base_dir.split(os.sep + '*').pop(0) 110 | msg = 'Invalid input directory' 111 | if not os.path.isdir(base_dir): 112 | cls.err.add(msg) 113 | return 114 | cls.err.discard(msg) 115 | 116 | recursive = getattr(shared.opts, 'tagger_batch_recursive', True) 117 | path_mtimes = [] 118 | for filename in glob(input_glob, recursive=recursive): 119 | if not os.path.isdir(filename): 120 | ext = os.path.splitext(filename)[1].lower() 121 | if ext in supported_extensions: 122 | path_mtimes.append(os.path.getmtime(filename)) 123 | paths.append(filename) 124 | elif ext != '.txt' and 'db.json' not in filename: 125 | print(f'{filename}: not an image extension: "{ext}"') 126 | 127 | # interrogating in a directory with no pics, still flush the cache 128 | if len(path_mtimes) > 0 and cls.last_path_mtimes == path_mtimes: 129 | print('No changed images') 130 | return 131 | 132 | QData.clear(2) 133 | cls.last_path_mtimes = path_mtimes 134 | 135 | if not cls.output_root: 136 | cls.output_root = Path(base_dir) 137 | elif cls.base_dir and cls.output_root == Path(cls.base_dir): 138 | cls.output_root = Path(base_dir) 139 | 140 | # XXX what is this basedir magic trying to achieve? 
        # (base_dir_last records the final path component of the input root;
        # set_batch_io looks it up in each image path so the tree below the
        # input root can be mirrored under output_root)
141 |         cls.base_dir_last = Path(base_dir).parts[-1]
142 |         cls.base_dir = base_dir
143 | 
144 |         QData.read_json(cls.output_root)
145 | 
146 |         print(f'found {len(paths)} image(s)')
147 |         cls.set_batch_io(paths)
148 | 
149 |     @classmethod
150 |     def set_batch_io(cls, paths: List[str]) -> None:
151 |         """ set input and output paths for batch mode """
152 |         checked_dirs = set()
153 |         cls.paths = []
154 |         for path in paths:
155 |             path = Path(path)
156 |             if not cls.save_tags:
157 |                 cls.paths.append([path, '', ''])
158 |                 continue
159 | 
160 |             # guess the output path
161 |             base_dir_last_idx = path.parts.index(cls.base_dir_last)
162 |             # format output filename
163 | 
164 |             info = tags_format.Info(path, 'txt')
165 |             fmt = partial(lambda info, m: tags_format.parse(m, info), info)
166 | 
167 |             msg = 'Invalid output format'
168 |             cls.err.discard(msg)
169 |             try:
170 |                 formatted_output_filename = tags_format.pattern.sub(
171 |                     fmt,
172 |                     Its.output_filename_format
173 |                 )
174 |             except (TypeError, ValueError):
175 |                 cls.err.add(msg)
176 | 
177 |             output_dir = cls.output_root.joinpath(
178 |                 *path.parts[base_dir_last_idx + 1:]).parent
179 | 
180 |             tags_out = output_dir.joinpath(formatted_output_filename)
181 | 
182 |             if output_dir in checked_dirs:
183 |                 cls.paths.append([path, tags_out, ''])
184 |             else:
185 |                 checked_dirs.add(output_dir)
186 |                 if os.path.exists(output_dir):
187 |                     msg = 'output_dir: not a directory.'
188 |                     if os.path.isdir(output_dir):
189 |                         cls.paths.append([path, tags_out, ''])
190 |                         cls.err.discard(msg)
191 |                     else:
192 |                         cls.err.add(msg)
193 |                 else:
194 |                     cls.paths.append([path, tags_out, output_dir])
195 | 
196 | 
197 | class QData:
198 |     """ Query data: contains parameters for the query """
199 |     add_tags = []
200 |     keep_tags = set()
201 |     exclude_tags = []
202 |     search_tags = {}
203 |     replace_tags = []
204 |     threshold = 0.35
205 |     tag_frac_threshold = 0.05
206 |     count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100)
207 | 
208 |     # read from db.json, update with what should be written to db.json:
209 |     json_db = None
210 |     weighed = (defaultdict(list), defaultdict(list))
211 |     query = {}
212 | 
213 |     # representing the (cumulative) current interrogations
214 |     ratings = defaultdict(float)
215 |     tags = defaultdict(list)
216 |     discarded_tags = defaultdict(list)
217 |     in_db = {}
218 |     for_tags_file = defaultdict(lambda: defaultdict(float))
219 | 
220 |     had_new = False
221 |     err = set()
222 |     image_dups = defaultdict(set)
223 | 
224 |     @classmethod
225 |     def set(cls, key: str) -> Callable[[str], None]:
226 |         def setter(val) -> None:
227 |             setattr(cls, key, val)
228 |         return setter
229 | 
236 |     @classmethod
237 |     def clear(cls, mode: int) -> None:
238 |         """ clear tags and ratings """
239 |         cls.tags.clear()
240 |         cls.discarded_tags.clear()
241 |         cls.ratings.clear()
242 |         cls.for_tags_file.clear()
243 |         if mode > 0:
244 |             cls.in_db.clear()
245 |             cls.image_dups.clear()
246 |         if mode > 1:
247 |             cls.json_db = None
248 |             cls.weighed = (defaultdict(list), defaultdict(list))
249 |             cls.query = {}
250 |         if mode > 2:
251 |             cls.add_tags = []
252 |             cls.keep_tags = set()
253 |             cls.exclude_tags = []
254 |             cls.search_tags = {}
255 |             cls.replace_tags = []
256 | 
257 |     @classmethod
258 |     def test_add(cls, tag: str, current: str, incompatible: list) -> None:
259 |         """ check that the tag is not in an incompatible tag collection """
260 |         msg = f'Empty tag in {current} tags'
261 |         if tag == '':
262 |             cls.err.add(msg)
263 |             return
264 |         cls.err.discard(msg)
265 |         for bad in incompatible:
266 |             if current < bad:
267 |                 msg = f'"{tag}" is both in {bad} and {current} tags'
268 |             else:
269 |                 msg = f'"{tag}" is both in {current} and {bad} tags'
270 |             attr = getattr(cls, bad + '_tags')
271 |             if bad == 'search':
272 |                 for rex in attr.values():
273 |                     if rex.match(tag):
274 |                         cls.err.add(msg)
275 |                         return
276 |             elif bad == 'exclude':
277 |                 if any(rex.match(tag) for rex in attr):
278 |                     cls.err.add(msg)
279 |                     return
280 |             else:
281 |                 if tag in attr:
282 |                     cls.err.add(msg)
283 |                     return
284 | 
285 |         attr = getattr(cls, current + '_tags')
286 |         if current in ['add', 'replace']:
287 |             attr.append(tag)
288 |         elif current == 'keep':
289 |             attr.add(tag)
290 |         else:
291 |             rex = cls.compile_rex(tag)
292 |             if rex:
293 |                 if current == 'exclude':
294 |                     attr.append(rex)
295 |                 elif current == 'search':
296 |                     attr[len(attr)] = rex
297 |             else:
298 |                 cls.err.add(f'empty regex in {current} tags')
299 | 
300 |     @classmethod
301 |     def update_keep(cls, keep: str) -> None:
302 |         cls.keep_tags = set()
303 |         if keep == '':
304 |             return
305 |         un_re = re_comp(r' keep(?: and \w+)? tags')
306 |         cls.err = {err for err in cls.err if not un_re.search(err)}
307 |         for tag in map(str.strip, keep.split(',')):
308 |             cls.test_add(tag, 'keep', ['exclude', 'search'])
309 | 
310 |     @classmethod
311 |     def update_add(cls, add: str) -> None:
312 |         cls.add_tags = []
313 |         if add == '':
314 |             return
315 |         un_re = re_comp(r' add(?: and \w+)? tags')
316 |         cls.err = {err for err in cls.err if not un_re.search(err)}
317 |         for tag in map(str.strip, add.split(',')):
318 |             cls.test_add(tag, 'add', ['exclude', 'search'])
319 | 
320 |         # silently raise count threshold to avoid issue in apply_filters
321 |         if len(cls.add_tags) > cls.count_threshold:
322 |             cls.count_threshold = len(cls.add_tags)
323 | 
324 |     @staticmethod
325 |     def compile_rex(rex: str) -> Optional:
326 |         if rex in {'', '^', '$', '^$'}:
327 |             return None
328 |         if rex[0] == '^':
329 |             rex = rex[1:]
330 |         if rex[-1] == '$':
331 |             rex = rex[:-1]
332 |         return re_comp('^'+rex+'$', flags=IGNORECASE)
333 | 
334 |     @classmethod
335 |     def update_exclude(cls, exclude: str) -> None:
336 |         cls.exclude_tags = []
337 |         if exclude == '':
338 |             return
339 |         un_re = re_comp(r' exclude(?: and \w+)? tags')
340 |         cls.err = {err for err in cls.err if not un_re.search(err)}
341 | 
342 |         # escape regex metacharacters, unescaped they make the tag filters inoperable
343 |         exclude = re_sub(r'(\*|\+|\?|\(|\))', r'\\\1', exclude)
344 | 
345 |         for excl in map(str.strip, exclude.split(',')):
346 |             incompatible = ['add', 'keep', 'search', 'replace']
347 |             cls.test_add(excl, 'exclude', incompatible)
348 | 
349 |     @classmethod
350 |     def update_search(cls, search_str: str) -> None:
351 |         cls.search_tags = {}
352 |         if search_str == '':
353 |             return
354 |         un_re = re_comp(r' search(?: and \w+)? tags')
355 |         cls.err = {err for err in cls.err if not un_re.search(err)}
356 |         for rex in map(str.strip, search_str.split(',')):
357 |             incompatible = ['add', 'keep', 'exclude', 'replace']
358 |             cls.test_add(rex, 'search', incompatible)
359 | 
360 |         msg = 'Unequal number of search and replace tags'
361 |         if len(cls.search_tags) != len(cls.replace_tags):
362 |             cls.err.add(msg)
363 |         else:
364 |             cls.err.discard(msg)
365 | 
366 |     @classmethod
367 |     def update_replace(cls, replace: str) -> None:
368 |         cls.replace_tags = []
369 |         if replace == '':
370 |             return
371 |         un_re = re_comp(r' replace(?: and \w+)? tags')
372 |         cls.err = {err for err in cls.err if not un_re.search(err)}
373 |         for repl in map(str.strip, replace.split(',')):
374 |             cls.test_add(repl, 'replace', ['exclude', 'search'])
375 |         msg = 'Unequal number of search and replace tags'
376 |         if len(cls.search_tags) != len(cls.replace_tags):
377 |             cls.err.add(msg)
378 |         else:
379 |             cls.err.discard(msg)
380 | 
381 |     @classmethod
382 |     def get_i_wt(cls, stored: float) -> Tuple[int, float]:
383 |         """
384 |         in db.json and QData.weighed the index and weight are packed into
385 |         one float per interrogation: the integer part is an incrementing
386 |         per-interrogation index, the fraction below it is the weight.
387 |         """
388 |         i = ceil(stored) - 1
389 |         return i, stored - i
390 | 
391 |     @classmethod
392 |     def read_json(cls, outdir) -> None:
393 |         """ read db.json if it exists, validate it, and update cls """
394 |         cls.json_db = None
395 |         if getattr(shared.opts, 'tagger_auto_serde_json', True):
396 |             cls.json_db = outdir.joinpath('db.json')
397 |             if cls.json_db.is_file():
398 |                 print(f'Reading {cls.json_db}')
399 |                 cls.had_new = False
400 |                 msg = f'Error reading {cls.json_db}'
401 |                 cls.err.discard(msg)
402 |                 # validate json using either json_schema/db_json_v1_schema.json
403 |                 # or json_schema/db_json_v2_schema.json
404 | 
405 |                 schema = Path(__file__).parent.parent.joinpath(
406 |                     'json_schema', 'db_json_v1_schema.json'
407 |                 )
408 |                 try:
409 |                     data = loads(cls.json_db.read_text())
410 |                     validate(data, loads(schema.read_text()))
411 | 
412 |                     # convert v2 back to v1
413 |                     if "meta" in data:
414 |                         cls.had_new = True  # <- force write for v2 -> v1
415 |                 except (ValidationError, IndexError) as err:
416 |                     print(f'{msg}: {repr(err)}')
417 |                     cls.err.add(msg)
418 |                     data = {"query": {}, "tag": [], "rating": []}
419 | 
420 |                 cls.query = data["query"]
421 |                 cls.weighed = (
422 |                     defaultdict(list, data["rating"]),
423 |                     defaultdict(list, data["tag"])
424 |                 )
425 |                 print(f'Read {cls.json_db}: {len(cls.query)} interrogations, '
426 |                       f'{len(cls.weighed[1])} tags.')
427 | 
428 |     @classmethod
429 |     def write_json(cls) -> None:
430 |         """ write db.json """
431 |         if cls.json_db is not None:
432 |             data = {
433 |                 "rating": cls.weighed[0],
434 |                 "tag": cls.weighed[1],
435 |                 "query": cls.query,
436 |             }
437 |             cls.json_db.write_text(dumps(data, indent=2))
438 |             print(f'Wrote {cls.json_db}: {len(cls.query)} interrogations, '
439 |                   f'{len(cls.weighed[1])} tags.')
440 | 
441 |     @classmethod
442 |     def get_index(cls, fi_key: str, path='') -> int:
443 |         """ get index for filestamp-interrogator """
444 |         if path and path != cls.query[fi_key][0]:
445 |             if cls.query[fi_key][0] != '':
446 |                 print(f'Dup or rename: Identical checksums for {path}\n'
447 |                       f'and: {cls.query[fi_key][0]} (path updated)')
448 |             cls.had_new = True
449 |             cls.query[fi_key] = (path, cls.query[fi_key][1])
450 | 
451 |         return cls.query[fi_key][1]
452 | 
453 |     @classmethod
454 |     def single_data(cls, fi_key: str) -> None:
455 |         """ get tags and ratings for filestamp-interrogator """
456 |         index = cls.query.get(fi_key)[1]
457 |         data = ({}, {})
458 |         for j in range(2):
459 |             for ent, lst in cls.weighed[j].items():
460 |                 for i, val in map(cls.get_i_wt, lst):
461 |                     if i == index:
462 |                         data[j][ent] = val
463 | 
464 |         QData.in_db[index] = ('', '', '') + data
465 | 
466 |     @classmethod
467 |     def is_excluded(cls, ent: str) -> bool:
468 |         """ check if tag is excluded """
469 |         return any(re_match(x, ent) for x in cls.exclude_tags)
470 | 
471 |     @classmethod
472 |     def correct_tag(cls, tag: str) -> str:
473 |         """ correct tag for display """
474 |         replace_underscore = getattr(shared.opts, 'tagger_repl_us', True)
475 |         if replace_underscore and tag not in Its.kamojis:
476 |             tag = tag.replace('_', ' ')
477 | 
478 |         if getattr(shared.opts, 'tagger_escape', False):
479 |             tag = re_special.sub(r'\\\1', tag)  # tag_escape_pattern
480 | 
481 |         if len(cls.search_tags) == len(cls.replace_tags):
482 |             for i, regex in cls.search_tags.items():
483 |                 if re_match(regex, tag):
484 |                     tag = re_sub(regex, cls.replace_tags[i], tag)
485 |                     break
486 | 
487 |         return tag
488 | 
489 |     @classmethod
490 |     def apply_filters(cls, data) -> None:
491 |         """ apply filters to query data, store in db.json if required """
492 |         # data = (path, tags_file, fi_key, ratings, tags)
493 |         # fi_key == '' means this is a new file or interrogation for that file
494 | 
495 |         tags = sorted(data[4].items(), key=lambda x: x[1], reverse=True)
496 | 
497 |         fi_key = data[2]
498 |         index = len(cls.query)
499 | 
500 |         ratings = sorted(data[3].items(), key=lambda x: x[1], reverse=True)
501 |         # loop over ratings
502 |         for rating, val in ratings:
503 |             if fi_key != '':
504 |                 cls.weighed[0][rating].append(val + index)
505 |             cls.ratings[rating] += val
506 | 
507 |         max_ct = cls.count_threshold - len(cls.add_tags)
508 |         count = 0
509 |         # loop over tags with db update
510 |         for tag, val in tags:
511 |             if isinstance(tag, float):
512 |                 print(f'bad return from interrogator, float: {tag} {val}')
513 |                 # FIXME: why does this happen? what does it mean?
514 |                 continue
515 | 
516 |             if fi_key != '' and val >= 0.005:
517 |                 cls.weighed[1][tag].append(val + index)
518 | 
519 |             if count < max_ct:
520 |                 tag = cls.correct_tag(tag)
521 |                 if tag not in cls.keep_tags:
522 |                     if cls.is_excluded(tag) or val < cls.threshold:
523 |                         if tag not in cls.add_tags and \
524 |                                 len(cls.discarded_tags) < max_ct:
525 |                             cls.discarded_tags[tag].append(val)
526 |                         continue
527 |                 if data[1] != '':
528 |                     current = cls.for_tags_file[data[1]].get(tag, 0.0)
529 |                     cls.for_tags_file[data[1]][tag] = min(val + current, 1.0)
530 |                 count += 1
531 |                 if tag not in cls.add_tags:
532 |                     # those are already added
533 |                     cls.tags[tag].append(val)
534 |             elif fi_key == '':
535 |                 break
536 | 
537 |         if getattr(shared.opts, 'tagger_verbose', True):
538 |             print(f'{data[0]}: {count}/{len(tags)} tags kept')
539 | 
540 |         if fi_key != '':
541 |             cls.query[fi_key] = (data[0], index)
542 | 
543 |     @classmethod
544 |     def finalize_batch(cls, count: int) -> ItRetTP:
545 |         """ finalize the batch query """
546 |         if cls.json_db and cls.had_new:
547 |             cls.write_json()
548 |             cls.had_new = False
549 | 
550 |         # collect the weights per file/interrogation previously stored in the db
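        # a worked example of the stored-weight packing (see get_i_wt above):
        # index 3 with weight 0.75 is stored as 3 + 0.75 = 3.75; decoding
        # gives ceil(3.75) - 1 == 3 and 3.75 - 3 == 0.75. The ceil() keeps a
        # full weight of 1.0 intact: index 3, weight 1.0 is stored as 4.0,
        # and ceil(4.0) - 1 == 3 still recovers the correct index.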
551 |         for index in range(2):
552 |             for ent, lst in cls.weighed[index].items():
553 |                 for i, val in map(cls.get_i_wt, lst):
554 |                     if i not in cls.in_db:
555 |                         continue
556 |                     cls.in_db[i][3+index][ent] = val
557 | 
558 |         # process the data retrieved from the db and add it to the stats
559 |         for got in cls.in_db.values():
560 |             no_floats = sorted(filter(lambda x: not isinstance(x[0], float),
561 |                                got[3].items()), key=lambda x: x[0])
562 |             sorted_tags = ','.join(f'({k},{v:.1f})' for (k, v) in no_floats)
563 |             QData.image_dups[sorted_tags].add(got[0])
564 |             cls.apply_filters(got)
565 | 
566 |         # average
567 |         return cls.finalize(count)
568 | 
569 |     @staticmethod
570 |     def sort_tags(tags: Dict[str, float]) -> List[Tuple[str, float]]:
571 |         """ sort tags by value, return list of tuples """
572 |         return sorted(tags.items(), key=lambda x: x[1], reverse=True)
573 | 
574 |     @classmethod
575 |     def get_image_dups(cls) -> List[str]:
576 |         # first sort the values so that those without a comma come first
577 |         ordered = sorted(cls.image_dups.items(), key=lambda x: ',' in x[0])
578 |         return [str(x) for s in ordered if len(s[1]) > 1 for x in s[1]]
579 | 
580 |     @classmethod
581 |     def finalize(cls, count: int) -> ItRetTP:
582 |         """ finalize the query, return the results """
583 | 
584 |         count += len(cls.in_db)
585 |         if count == 0:
586 |             return None, None, None, 'no results for query'
587 | 
588 |         ratings, tags, discarded_tags = {}, {}, {}
589 | 
590 |         for n in cls.for_tags_file.keys():
591 |             for k in cls.add_tags:
592 |                 cls.for_tags_file[n][k] = 1.0 * count
593 | 
594 |         for k in cls.add_tags:
595 |             tags[k] = 1.0
596 | 
597 |         for k, lst in cls.tags.items():
598 |             # the fraction of all interrogations in which the tag scored above the threshold
599 |             fraction_of_queries = len(lst) / count
600 | 
601 |             if fraction_of_queries >= cls.tag_frac_threshold:
602 |                 # store the average weight over all interrogations: sum / count
603 |                 tags[k] = sum(lst) / count
604 |                 # trigger an event to place the tag in the active tags list
605 |                 # replace if k interferes with html code
606 |             else:
607 |                 discarded_tags[k] = sum(lst) / count
608 |                 for n in cls.for_tags_file.keys():
609 |                     if k in cls.for_tags_file[n]:
610 |                         if k not in cls.add_tags and k not in cls.keep_tags:
611 |                             del cls.for_tags_file[n][k]
612 | 
613 |         for k, lst in cls.discarded_tags.items():
614 |             fraction_of_queries = len(lst) / count
615 |             discarded_tags[k] = sum(lst) / count
616 | 
617 |         for ent, val in cls.ratings.items():
618 |             ratings[ent] = val / count
619 | 
620 |         weighted_tags_files = getattr(shared.opts,
621 |                                       'tagger_weighted_tags_files', False)
622 |         for file, remaining_tags in cls.for_tags_file.items():
623 |             sorted_tags = cls.sort_tags(remaining_tags)
624 |             if weighted_tags_files:
625 |                 sorted_tags = [f'({k}:{v})' for k, v in sorted_tags]
626 |             else:
627 |                 sorted_tags = [k for k, v in sorted_tags]
628 |             file.write_text(', '.join(sorted_tags), encoding='utf-8')
629 | 
630 |         warn = ""
631 |         if len(QData.err) > 0:
632 |             warn = "Warnings (fix and try again - it should be cheap):<ul>" + \
                   ''.join([f'<li>{x}</li>' for x in QData.err]) + "</ul>
      " 634 | 635 | if count > 1 and len(cls.get_image_dups()) > 0: 636 | warn += "There were duplicates, see gallery tab" 637 | return ratings, tags, discarded_tags, warn 638 | -------------------------------------------------------------------------------- /tagger/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the tagger module""" 2 | import os 3 | 4 | from typing import List, Dict 5 | from pathlib import Path 6 | 7 | from modules import shared, scripts # pylint: disable=import-error 8 | from modules.shared import models_path # pylint: disable=import-error 9 | 10 | default_ddp_path = Path(models_path, 'deepdanbooru') 11 | default_onnx_path = Path(models_path, 'TaggerOnnx') 12 | from tagger.preset import Preset # pylint: disable=import-error 13 | from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, Z3DInterrogator, \ 14 | MLDanbooruInterrogator # pylint: disable=E0401 # noqa: E501 15 | from tagger.interrogator import WaifuDiffusionInterrogator # pylint: disable=E0401 # noqa: E501 16 | 17 | preset = Preset(Path(scripts.basedir(), 'presets')) 18 | 19 | interrogators: Dict[str, Interrogator] = { 20 | 'wd14-vit.v1': WaifuDiffusionInterrogator( 21 | 'WD14 ViT v1', 22 | repo_id='SmilingWolf/wd-v1-4-vit-tagger' 23 | ), 24 | 'wd14-vit.v2': WaifuDiffusionInterrogator( 25 | 'WD14 ViT v2', 26 | repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', 27 | ), 28 | 'wd14-convnext.v1': WaifuDiffusionInterrogator( 29 | 'WD14 ConvNeXT v1', 30 | repo_id='SmilingWolf/wd-v1-4-convnext-tagger' 31 | ), 32 | 'wd14-convnext.v2': WaifuDiffusionInterrogator( 33 | 'WD14 ConvNeXT v2', 34 | repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2', 35 | ), 36 | 'wd14-convnextv2.v1': WaifuDiffusionInterrogator( 37 | 'WD14 ConvNeXTV2 v1', 38 | # the name is misleading, but it's v1 39 | repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2', 40 | ), 41 | 'wd14-eva02.v3.large': WaifuDiffusionInterrogator( 42 | 'WD14 EVA02 v3 Large', 43 | # Moved "Large" to the end to fix organization 44 | repo_id='SmilingWolf/wd-eva02-large-tagger-v3', 45 | ), 46 | 'wd14-swinv2-v1': WaifuDiffusionInterrogator( 47 | 'WD14 SwinV2 v1', 48 | # again misleading name 49 | repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2', 50 | ), 51 | 'wd-v1-4-moat-tagger.v2': WaifuDiffusionInterrogator( 52 | 'WD14 moat tagger v2', 53 | repo_id='SmilingWolf/wd-v1-4-moat-tagger-v2' 54 | ), 55 | 'wd-v1-4-vit-tagger.v3': WaifuDiffusionInterrogator( 56 | 'WD14 ViT v3', 57 | repo_id='SmilingWolf/wd-vit-tagger-v3' 58 | ), 59 | 'wd14-vit.v3.large': WaifuDiffusionInterrogator( 60 | 'WD14 ViT v3 Large', 61 | # Moved "Large" to the end to fix organization 62 | repo_id='SmilingWolf/wd-vit-large-tagger-v3', 63 | ), 64 | 'wd-v1-4-convnext-tagger.v3': WaifuDiffusionInterrogator( 65 | 'WD14 ConvNext v3', 66 | repo_id='SmilingWolf/wd-convnext-tagger-v3' 67 | ), 68 | 'wd-v1-4-swinv2-tagger.v3': WaifuDiffusionInterrogator( 69 | 'WD14 SwinV2 v3', 70 | repo_id='SmilingWolf/wd-swinv2-tagger-v3' 71 | ), 72 | 'mld-caformer.dec-5-97527': MLDanbooruInterrogator( 73 | 'ML-Danbooru Caformer dec-5-97527', 74 | repo_id='deepghs/ml-danbooru-onnx', 75 | model_path='ml_caformer_m36_dec-5-97527.onnx' 76 | ), 77 | 'mld-tresnetd.6-30000': MLDanbooruInterrogator( 78 | 'ML-Danbooru TResNet-D 6-30000', 79 | repo_id='deepghs/ml-danbooru-onnx', 80 | model_path='TResnet-D-FLq_ema_6-30000.onnx' 81 | ), 82 | 'Z3D-E621-Convnext': Z3DInterrogator( 83 | 'Z3D-E621-Convnext', 84 | repo_id='toynya/Z3D-E621-Convnext', 85 | 
model_path='model.onnx'
 86 |     ),
 87 | }
 88 | 
 89 | 
 90 | def refresh_interrogators() -> List[str]:
 91 |     """Refreshes the interrogators list"""
 92 |     # load deepdanbooru project
 93 |     ddp_path = shared.cmd_opts.deepdanbooru_projects_path
 94 |     if ddp_path is None:
 95 |         ddp_path = default_ddp_path
 96 |     onnx_path = shared.cmd_opts.onnxtagger_path
 97 |     if onnx_path is None:
 98 |         onnx_path = default_onnx_path
 99 |     os.makedirs(ddp_path, exist_ok=True)
100 |     os.makedirs(onnx_path, exist_ok=True)
101 | 
102 |     for path in os.scandir(ddp_path):
103 |         print(f"Scanning {path} as deepdanbooru project")
104 |         if not path.is_dir():
105 |             print(f"Warning: {path} is not a directory, skipped")
106 |             continue
107 | 
108 |         if not Path(path, 'project.json').is_file():
109 |             print(f"Warning: {path} has no project.json, skipped")
110 |             continue
111 | 
112 |         interrogators[path.name] = DeepDanbooruInterrogator(path.name, path)
113 |     # scan for onnx models as well
114 |     for path in os.scandir(onnx_path):
115 |         print(f"Scanning {path} as onnx model")
116 |         if not path.is_dir():
117 |             print(f"Warning: {path} is not a directory, skipped")
118 |             continue
119 | 
120 |         onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')]
121 |         if len(onnx_files) != 1:
122 |             print(f"Warning: {path} requires exactly one .onnx model, skipped")
123 |             continue
124 |         local_path = Path(path, onnx_files[0].name)
125 | 
126 |         csv = [x for x in os.scandir(path) if x.name.endswith('.csv')]
127 |         if len(csv) == 0:
128 |             print(f"Warning: {path} has no selected tags .csv file, skipped")
129 |             continue
130 | 
131 |         def tag_select_csvs_up_front(k):
132 |             return sum(-1 if t in k.name.lower() else 1 for t in ["tag", "select"])
133 | 
134 |         csv.sort(key=tag_select_csvs_up_front)
135 |         tags_path = Path(path, csv[0])
136 | 
137 |         if path.name not in interrogators:
138 |             if path.name == 'wd-v1-4-convnextv2-tagger-v2':
139 |                 interrogators[path.name] = WaifuDiffusionInterrogator(
140 |                     path.name,
141 |                     repo_id='SmilingWolf/SW-CV-ModelZoo',
142 |                     is_hf=False
143 |                 )
144 |             elif path.name == 'Z3D-E621-Convnext':
145 |                 interrogators[path.name] = WaifuDiffusionInterrogator(
146 |                     'Z3D-E621-Convnext', is_hf=False)
147 |             else:
148 |                 raise NotImplementedError(f"Add {path.name} resolution similar "
149 |                                           "to above here")
150 | 
151 |         interrogators[path.name].local_model = str(local_path)
152 |         interrogators[path.name].local_tags = str(tags_path)
153 | 
154 |     return sorted(interrogators.keys())
155 | 
156 | 
157 | def split_str(string: str, separator=',') -> List[str]:
158 |     return [x.strip() for x in string.split(separator) if x]
159 | 
--------------------------------------------------------------------------------
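For reference, a minimal sketch of the float packing that `db.json` relies on, mirroring `QData.get_i_wt` in `tagger/uiset.py`; the `encode`/`decode` helper names are illustrative only and not part of the extension:

```python
from math import ceil

def encode(index: int, weight: float) -> float:
    """Pack an interrogation index and a (0, 1] weight into one float."""
    return index + weight

def decode(stored: float) -> tuple:
    """Mirror of QData.get_i_wt: recover (index, weight) from one float."""
    i = ceil(stored) - 1  # ceil, not int(): keeps a weight of exactly 1.0 intact
    return i, stored - i

# 0.75 and 1.0 are exactly representable doubles, so these round-trips are exact
assert decode(encode(3, 0.75)) == (3, 0.75)
assert decode(encode(3, 1.0)) == (3, 1.0)
```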