├── version.txt
├── scripts
├── tokenizer
│ ├── __init__.py
│ └── clip_tokenizer.py
├── tag_editor_ui
│ ├── __init__.py
│ ├── ui_common.py
│ ├── uibase.py
│ ├── ui_instance.py
│ ├── ui_classes.py
│ ├── block_gallery_state.py
│ ├── block_toprow.py
│ ├── block_dataset_gallery.py
│ ├── tab_filter_by_tags.py
│ ├── block_tag_select.py
│ ├── tab_move_or_delete_files.py
│ ├── block_tag_filter.py
│ ├── tab_filter_by_selection.py
│ ├── block_load_dataset.py
│ ├── tab_batch_edit_captions.py
│ └── tab_edit_caption_of_selected_image.py
├── dte_instance.py
├── logger.py
├── singleton.py
├── dataset_tag_editor
│ ├── __init__.py
│ ├── interrogators
│ │ ├── __init__.py
│ │ ├── blip2_captioning.py
│ │ ├── git_large_captioning.py
│ │ ├── waifu_diffusion_tagger_timm.py
│ │ └── waifu_diffusion_tagger.py
│ ├── interrogator_names.py
│ ├── custom_scripts.py
│ ├── dataset.py
│ ├── kohya_finetune_metadata.py
│ ├── filters.py
│ └── taggers_builtin.py
├── paths.py
├── model_loader.py
├── tagger.py
├── utilities.py
├── config.py
└── main.py
├── .gitignore
├── ss01.png
├── pic
├── ss02.png
├── ss03.png
├── ss04.png
├── ss05.png
├── ss06.png
├── ss07.png
├── ss08.png
├── ss09.png
├── ss10.png
└── ss12.png
├── javascript
├── 90_ui.js
├── 99_main.js
└── 00_modified_gallery.js
├── style.css
├── .github
├── ISSUE_TEMPLATE
│ ├── feature_request.md
│ └── bug_report.md
└── workflows
│ └── codeql.yml
├── LICENSE
├── userscripts
└── taggers
│ ├── cafeai_aesthetic_classifier.py
│ ├── aesthetic_shadow.py
│ ├── waifu_aesthetic_classifier.py
│ └── improved_aesthetic_predictor.py
├── README-JP.md
├── DESCRIPTION_OF_DISPLAY.md
├── DESCRIPTION_OF_DISPLAY-JP.md
└── README.md
/version.txt:
--------------------------------------------------------------------------------
1 | 0.3.4
2 |
--------------------------------------------------------------------------------
/scripts/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from . import clip_tokenizer
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/__init__.py:
--------------------------------------------------------------------------------
1 | from .ui_common import *
2 | from .ui_instance import *
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/ui_common.py:
--------------------------------------------------------------------------------
1 | from scripts.dte_instance import dte_instance, dte_module
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | config.json
4 | .vscode/
5 | models/
--------------------------------------------------------------------------------
/ss01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/ss01.png
--------------------------------------------------------------------------------
/pic/ss02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss02.png
--------------------------------------------------------------------------------
/pic/ss03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss03.png
--------------------------------------------------------------------------------
/pic/ss04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss04.png
--------------------------------------------------------------------------------
/pic/ss05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss05.png
--------------------------------------------------------------------------------
/pic/ss06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss06.png
--------------------------------------------------------------------------------
/pic/ss07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss07.png
--------------------------------------------------------------------------------
/pic/ss08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss08.png
--------------------------------------------------------------------------------
/pic/ss09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss09.png
--------------------------------------------------------------------------------
/pic/ss10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss10.png
--------------------------------------------------------------------------------
/pic/ss12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/HEAD/pic/ss12.png
--------------------------------------------------------------------------------
/scripts/dte_instance.py:
--------------------------------------------------------------------------------
# Module-level editor instance. Both `dte_instance` and the `dte_module` alias
# are imported from this module by scripts/tag_editor_ui/ui_common.py.
import scripts.dataset_tag_editor as dte_module
dte_instance = dte_module.DatasetTagEditor()
--------------------------------------------------------------------------------
/scripts/logger.py:
--------------------------------------------------------------------------------
def write(content):
    """Print ``content`` to stdout behind the ``[tag-editor] `` prefix.

    ``content`` may be any object (string, exception, path, ...); it is
    formatted rather than concatenated, so non-string values no longer raise
    ``TypeError``.
    """
    print(f"[tag-editor] {content}")


def warn(content):
    """Print ``content`` tagged as a warning; output goes through ``write``."""
    write(f"[tag-editor:WARNING] {content}")


def error(content):
    """Print ``content`` tagged as an error; output goes through ``write``."""
    write(f"[tag-editor:ERROR] {content}")
--------------------------------------------------------------------------------
/scripts/singleton.py:
--------------------------------------------------------------------------------
class Singleton(object):
    """Base class that hands out one shared instance per concrete class.

    Constructing a subclass repeatedly returns the same object. Python still
    calls ``__init__`` on every construction call, so subclasses that define
    ``__init__`` re-run it on the shared instance each time.
    """

    _instance = None

    def __new__(class_, *args, **kwargs):
        # isinstance() rather than an "is None" check: a subclass inherits the
        # parent's cached ``_instance`` attribute, and must not reuse an
        # object that was created for the parent class.
        if not isinstance(class_._instance, class_):
            # object.__new__ rejects extra arguments when __new__ is
            # overridden, so constructor arguments are left for __init__.
            class_._instance = object.__new__(class_)
        return class_._instance
7 |
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/__init__.py:
--------------------------------------------------------------------------------
1 | from . import taggers_builtin
2 | from . import filters
3 | from . import dataset as ds
4 | from . import kohya_finetune_metadata
5 |
6 | from .dte_logic import DatasetTagEditor
7 |
8 | __all__ = ["ds", "taggers_builtin", "filters", "kohya_finetune_metadata", "DatasetTagEditor"]
9 |
--------------------------------------------------------------------------------
/javascript/90_ui.js:
--------------------------------------------------------------------------------
/**
 * Ask the user whether caption changes should be saved and, on confirmation,
 * click the hidden save button.
 *
 * @param {number} idx - index value supplied by the caller; negative values
 *     skip the prompt entirely.
 * @returns {number|undefined} -1 when idx is negative, otherwise undefined.
 */
function dataset_tag_editor_ask_save_change_or_not(idx){
    if (idx < 0){
        return -1
    }
    // Declared with const: the previous bare assignment created an implicit
    // global (and would throw a ReferenceError in strict mode).
    const res = window.confirm(`Save changes in captions?`)
    if (res){
        const set_button = gradioApp().getElementById("dataset_tag_editor_btn_hidden_save_caption");
        // The button may be absent from the page; do nothing in that case.
        if (set_button){
            set_button.click()
        }
    }
}
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/interrogators/__init__.py:
--------------------------------------------------------------------------------
1 | from .blip2_captioning import BLIP2Captioning
2 | from .git_large_captioning import GITLargeCaptioning
3 | from .waifu_diffusion_tagger import WaifuDiffusionTagger
4 | from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
5 |
6 | __all__ = [
7 | "BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
8 | ]
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/uibase.py:
--------------------------------------------------------------------------------
1 | from scripts.singleton import Singleton
2 |
class UIBase(Singleton):
    """Common base for UI blocks; subclasses provide the two hook methods."""

    def create_ui(self, *args, **kwargs):
        """Hook to be overridden by subclasses; raises NotImplementedError here."""
        raise NotImplementedError()

    def set_callbacks(self, *args, **kwargs):
        """Hook to be overridden by subclasses; raises NotImplementedError here."""
        raise NotImplementedError()

    def func_to_set_value(self, name, type=None):
        """Return a callable that stores its argument on ``self`` under ``name``.

        When ``type`` is given, the value is first passed through ``type(...)``.
        The returned callable hands back the stored (possibly converted) value.
        """
        def func(value):
            stored = value if type is None else type(value)
            setattr(self, name, stored)
            return stored
        return func
--------------------------------------------------------------------------------
/scripts/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from scripts.singleton import Singleton
4 |
def base_dir_path():
    """Return the absolute path one level above the directory holding this file."""
    this_file = Path(__file__)
    return this_file.parents[1].absolute()
7 |
def base_dir():
    """Return the result of ``base_dir_path()`` converted to ``str``."""
    root = base_dir_path()
    return str(root)
10 |
class Paths(Singleton):
    """Holds the base directory and the sub-directories derived from it."""

    def __init__(self):
        root = base_dir_path()
        self.base_path: Path = root
        self.script_path: Path = root / "scripts"
        self.userscript_path: Path = root / "userscripts"
        self.model_path = root / "models"
17 |
18 | paths = Paths()
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
/* Container: absolutely positioned near the right edge, above siblings. */
.token-counter-dte{
    position: absolute;
    display: inline-block;
    right: 2em;
    min-width: 0 !important;
    width: auto;
    z-index: 100;
}

.token-counter-dte div{
    display: inline;
}

/* Badge look of the inner span (padding plus background, shadow and border). */
.token-counter-dte span{
    padding: 0.1em 0.75em;
    background: var(--input-background-fill) !important;
    box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075);
    border: 2px solid rgba(192,192,192,0.4) !important;
    border-radius: 0.4em;
}
--------------------------------------------------------------------------------
/scripts/model_loader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from torch.hub import download_url_to_file
4 |
def load(model_path:Path, model_url:str, progress:bool=True, force_download:bool=False):
    """Make sure a model file exists at ``model_path``, downloading it if needed.

    Args:
        model_path: Destination of the model file (``str`` or ``Path``).
        model_url: URL to download from; ``None`` disables downloading.
        progress: Forwarded to ``download_url_to_file`` as its ``progress`` flag.
        force_download: Download again even when ``model_path`` already exists.

    Returns:
        ``model_path`` as a ``Path``. The path is returned even when nothing
        was downloaded and the file does not exist (``model_url`` is ``None``).
    """
    model_path = Path(model_path)

    # Without a URL there is nothing to fetch, whatever the state on disk.
    if model_url is None:
        return model_path

    # An existing path is reused unless the caller explicitly asks for a
    # fresh download (previously force_download was ignored in this case).
    if model_path.exists() and not force_download:
        return model_path

    # exist_ok avoids failing when the directory appears between check and creation.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    download_url_to_file(model_url, model_path, progress=progress)
    return model_path
17 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/ui_instance.py:
--------------------------------------------------------------------------------
1 | from .ui_classes import *
2 |
3 | __all__ = [
4 | 'toprow', 'load_dataset', 'dataset_gallery', 'gallery_state', 'filter_by_tags', 'filter_by_selection', 'batch_edit_captions', 'edit_caption_of_selected_image', 'move_or_delete_files'
5 | ]
6 |
# One module-level instance per UI class; the names match the entries of
# __all__ declared earlier in this module.
toprow = ToprowUI()
load_dataset = LoadDatasetUI()
dataset_gallery = DatasetGalleryUI()
gallery_state = GalleryStateUI()
filter_by_tags = FilterByTagsUI()
filter_by_selection = FilterBySelectionUI()
batch_edit_captions = BatchEditCaptionsUI()
edit_caption_of_selected_image = EditCaptionOfSelectedImageUI()
move_or_delete_files = MoveOrDeleteFilesUI()
16 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/ui_classes.py:
--------------------------------------------------------------------------------
1 | from .block_toprow import ToprowUI
2 | from .block_load_dataset import LoadDatasetUI
3 | from .block_dataset_gallery import DatasetGalleryUI
4 | from .block_gallery_state import GalleryStateUI
5 | from .block_tag_filter import TagFilterUI
6 | from .block_tag_select import TagSelectUI
7 | from .tab_filter_by_tags import FilterByTagsUI
8 | from .tab_filter_by_selection import FilterBySelectionUI
9 | from .tab_batch_edit_captions import BatchEditCaptionsUI
10 | from .tab_edit_caption_of_selected_image import EditCaptionOfSelectedImageUI
11 | from .tab_move_or_delete_files import MoveOrDeleteFilesUI
12 |
13 | __all__ = [
14 | 'ToprowUI', 'LoadDatasetUI', 'DatasetGalleryUI', 'GalleryStateUI', 'TagFilterUI', 'TagSelectUI', 'FilterByTagsUI', 'FilterBySelectionUI', 'BatchEditCaptionsUI', 'EditCaptionOfSelectedImageUI', 'MoveOrDeleteFilesUI'
15 | ]
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/interrogator_names.py:
--------------------------------------------------------------------------------
1 |
# BLIP2 captioning model names.
# NOTE(review): presumably Hugging Face model names; confirm against the code
# that builds the repository id from them.
BLIP2_CAPTIONING_NAMES = [
    "blip2-opt-2.7b",
    "blip2-opt-2.7b-coco",
    "blip2-opt-6.7b",
    "blip2-opt-6.7b-coco",
    "blip2-flan-t5-xl",
    "blip2-flan-t5-xl-coco",
    "blip2-flan-t5-xxl",
]
11 |
12 |
# {tagger name : default tagger threshold}
# v1: not sure whether these values are appropriate.
# v2, v3: P=R thresholds listed on each repo at https://huggingface.co/SmilingWolf
WD_TAGGERS = {
    "wd-v1-4-vit-tagger" : 0.35,
    "wd-v1-4-convnext-tagger" : 0.35,
    "wd-v1-4-vit-tagger-v2" : 0.3537,
    "wd-v1-4-convnext-tagger-v2" : 0.3685,
    "wd-v1-4-convnextv2-tagger-v2" : 0.371,
    "wd-v1-4-moat-tagger-v2" : 0.3771
}
# Same {tagger name : default tagger threshold} layout.
# NOTE(review): presumably the taggers handled by WaifuDiffusionTaggerTimm; confirm.
WD_TAGGERS_TIMM = {
    "wd-v1-4-swinv2-tagger-v2" : 0.3771,
    "wd-vit-tagger-v3" : 0.2614,
    "wd-convnext-tagger-v3" : 0.2682,
    "wd-swinv2-tagger-v3" : 0.2653,
}
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Environment (please complete the following information):**
27 | - OS: [e.g. Windows, Linux]
28 | - Browser: [e.g. chrome, safari]
29 | - Version of SD WebUI: [e.g. v1.9.3, by AUTOMATIC1111]
30 | - Version of this app: [e.g. v0.0.7]
31 |
32 | **Additional context**
33 | Add any other context about the problem here.
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 toshiaki1729
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/block_gallery_state.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING
3 | import gradio as gr
4 |
5 | from .ui_common import *
6 | from .uibase import UIBase
7 |
8 | if TYPE_CHECKING:
9 | from .ui_classes import *
10 |
class GalleryStateUI(UIBase):
    """Collects key/value status texts and shows them in a single HTML element."""

    def __init__(self):
        # key -> text; rendered in insertion order by get_current_gallery_txt().
        self.texts = dict()

    def register_value(self, key:str, value:str):
        """Store (or replace) the text shown for ``key``."""
        self.texts[key] = value

    def remove_key(self, key:str):
        """Drop ``key``; raises ``KeyError`` if it was never registered."""
        del self.texts[key]

    def get_current_gallery_txt(self):
        """Return all registered entries as ``key : value`` lines of HTML."""
        # The text is displayed through gr.HTML, so entries are separated with
        # an explicit <br>. The original literal was split across two source
        # lines (unterminated string); <br> is the reconstructed separator.
        return ''.join(f'{k} : {v}<br>' for k, v in self.texts.items())

    def create_ui(self):
        """Create the HTML element, initialised with the current text."""
        self.txt_gallery = gr.HTML(value=self.get_current_gallery_txt())

    def set_callbacks(self, dataset_gallery:DatasetGalleryUI):
        """Refresh the HTML element whenever the gallery's hidden image index changes."""
        dataset_gallery.nb_hidden_image_index.change(fn=lambda:None).then(
            fn=self.update_gallery_txt,
            inputs=None,
            outputs=self.txt_gallery
        )

    def update_gallery_txt(self):
        """Callback returning the current text for the HTML element."""
        return self.get_current_gallery_txt()
39 |
40 |
41 |
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/interrogators/blip2_captioning.py:
--------------------------------------------------------------------------------
1 | from transformers import Blip2Processor, Blip2ForConditionalGeneration
2 |
3 | from modules import devices, shared
4 | from scripts.paths import paths
5 |
6 |
class BLIP2Captioning:
    """Captioner built on a BLIP2 processor/model pair loaded from ``model_repo``."""

    def __init__(self, model_repo: str):
        self.MODEL_REPO = model_repo
        # Both stay None until load() is called.
        self.processor: Blip2Processor = None
        self.model: Blip2ForConditionalGeneration = None

    def load(self):
        """Create processor and model unless both are already present."""
        if self.model is not None and self.processor is not None:
            return
        self.processor = Blip2Processor.from_pretrained(self.MODEL_REPO)
        self.model = Blip2ForConditionalGeneration.from_pretrained(self.MODEL_REPO).to(devices.device)

    def unload(self):
        """Release processor and model unless models are configured to stay in memory."""
        if shared.opts.interrogate_keep_models_in_memory:
            return
        self.model = None
        self.processor = None
        devices.torch_gc()

    def apply(self, image):
        """Return the decoded captions for ``image``, or ``""`` when nothing is loaded."""
        if self.model is None or self.processor is None:
            return ""
        model_inputs = self.processor(images=image, return_tensors="pt").to(devices.device)
        generated = self.model.generate(**model_inputs)
        return self.processor.batch_decode(generated, skip_special_tokens=True)
30 |
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/interrogators/git_large_captioning.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoProcessor, AutoModelForCausalLM
2 | from modules import shared, devices, lowvram
3 |
4 | from scripts.paths import paths
5 |
6 |
7 | # brought from https://huggingface.co/docs/transformers/main/en/model_doc/git and modified
class GITLargeCaptioning:
    """Captioner built on the ``microsoft/git-large-coco`` processor/model pair."""

    # Repository id passed to from_pretrained() for both processor and model.
    MODEL_REPO = "microsoft/git-large-coco"

    def __init__(self):
        # Both stay None until load() is called.
        self.processor: AutoProcessor = None
        self.model: AutoModelForCausalLM = None

    def load(self):
        """Create processor and model if either is missing; the model is moved to ``shared.device``."""
        if self.model is None or self.processor is None:
            self.processor = AutoProcessor.from_pretrained(self.MODEL_REPO)
            self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to(shared.device)
        # NOTE(review): the nesting of this call could not be recovered from the
        # flattened source; it is assumed to run on every load(). Confirm.
        lowvram.send_everything_to_cpu()

    def unload(self):
        """Release processor and model unless models are configured to stay in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.processor = None
            devices.torch_gc()

    def apply(self, image):
        """Return the first decoded caption for ``image``, or ``""`` when nothing is loaded."""
        if self.model is None or self.processor is None:
            return ""
        inputs = self.processor(images=image, return_tensors="pt").to(shared.device)
        # Generation length is capped by the interrogate_clip_max_length option.
        ids = self.model.generate(
            pixel_values=inputs.pixel_values,
            max_length=shared.opts.interrogate_clip_max_length,
        )
        return self.processor.batch_decode(ids, skip_special_tokens=True)[0]
36 |
--------------------------------------------------------------------------------
/scripts/tagger.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Optional, Generator, Any
3 |
4 | from PIL import Image
5 |
6 | from modules import shared, lowvram, devices
7 | from modules import deepbooru as db
8 |
9 | # Custom tagger classes have to inherit from this class
class Tagger:
    """Base class for taggers, usable as a context manager around start()/stop()."""

    def __enter__(self):
        lowvram.send_everything_to_cpu()
        devices.torch_gc()
        self.start()
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.stop()

    def start(self):
        """Called on entering the context; does nothing by default."""
        pass

    def stop(self):
        """Called on leaving the context; does nothing by default."""
        pass

    def predict(self, image: Image.Image, threshold: Optional[float] = None) -> list[str]:
        """Predict the tags of one image."""
        raise NotImplementedError()

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None) -> Generator[list[str], Any, None]:
        """Optional batch prediction for a more efficient data loading system.

        ``None`` is passed as ``data`` to check whether this method is implemented.
        """
        raise NotImplementedError()

    def name(self):
        """Name shown in the UI."""
        raise NotImplementedError()
39 |
40 |
def get_replaced_tag(tag: str):
    """Apply the deepbooru display options (spaces, escaping) to ``tag``."""
    opts = shared.opts
    spaces_enabled = opts.deepbooru_use_spaces
    escape_enabled = opts.deepbooru_escape
    result = tag.replace('_', ' ') if spaces_enabled else tag
    if escape_enabled:
        # Prefix every character matched by db.re_special with a backslash.
        result = re.sub(db.re_special, r'\\\1', result)
    return result
49 |
50 |
def get_arranged_tags(probs: dict[str, float]):
    """Return the tag names ordered by descending probability.

    Tags with equal probability keep the order they have in ``probs``.
    """
    return sorted(probs, key=probs.get, reverse=True)
53 |
--------------------------------------------------------------------------------
/userscripts/taggers/cafeai_aesthetic_classifier.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from PIL import Image
4 | from transformers import pipeline
5 | import torch
6 |
7 | from modules import devices, shared
8 | from scripts.tagger import Tagger
9 |
10 | # brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py
11 |
12 | # I'm not sure if this is really working
13 | BATCH_SIZE = 8
14 |
class CafeAIAesthetic(Tagger):
    """Tagger that emits one ``[CAFE]score_N`` tag from the cafe_aesthetic classifier."""

    def load(self):
        """Build the image-classification pipeline on the configured device."""
        dev = devices.device
        if dev.index is None:
            # No explicit index on the configured device: use index 0 of the same type.
            dev = torch.device(dev.type, 0)
        self.pipe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=dev, batch_size=BATCH_SIZE)

    def unload(self):
        """Drop the pipeline unless models are configured to stay in memory."""
        if shared.opts.interrogate_keep_models_in_memory:
            return
        self.pipe_aesthetic = None
        devices.torch_gc()

    def start(self):
        self.load()
        return self

    def stop(self):
        self.unload()

    def _get_score(self, data):
        """Turn one pipeline result (list of label/score dicts) into the tag list."""
        scores = {entry["label"]: entry["score"] for entry in data}
        ae = scores['aesthetic']

        # edit here to change tag
        return [f"[CAFE]score_{math.floor(ae*10)}"]

    def predict(self, image: Image.Image, threshold=None):
        result = self.pipe_aesthetic(image, top_k=2)
        return self._get_score(result)

    def predict_pipe(self, data: list[Image.Image], threshold=None):
        # A None input only probes whether this method is implemented.
        if data is None:
            return
        for result in self.pipe_aesthetic(data, batch_size=BATCH_SIZE):
            yield self._get_score(result)

    def name(self):
        return "cafeai aesthetic classifier"
--------------------------------------------------------------------------------
/scripts/tokenizer/clip_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
2 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
3 |
4 | from typing import List
5 | from functools import reduce
6 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
7 | from modules import shared, extra_networks, prompt_parser
8 | from modules.sd_hijack import model_hijack
9 | import open_clip.tokenizer
10 |
class VanillaClip:
    """Accessors for the tokenizer carried by a CLIP embedder object."""

    def __init__(self, clip):
        self.clip = clip

    def vocab(self):
        """Return the vocabulary reported by the embedder's tokenizer."""
        tokenizer = self.clip.tokenizer
        return tokenizer.get_vocab()

    def byte_decoder(self):
        """Return the ``byte_decoder`` attribute of the embedder's tokenizer."""
        tokenizer = self.clip.tokenizer
        return tokenizer.byte_decoder
20 |
class OpenClip:
    """Accessors for the module-level tokenizer of ``open_clip``."""

    def __init__(self, clip):
        self.clip = clip
        self.tokenizer = open_clip.tokenizer._tokenizer

    def vocab(self):
        """Return the tokenizer's ``encoder`` attribute."""
        encoder = self.tokenizer.encoder
        return encoder

    def byte_decoder(self):
        """Return the tokenizer's ``byte_decoder`` attribute."""
        decoder = self.tokenizer.byte_decoder
        return decoder
31 |
def tokenize(text:str, use_raw_clip:bool=True):
    """Tokenize ``text`` and return ``(tokens, token_count)``.

    With ``use_raw_clip`` the conditioning model's own tokenizer is used and
    the count is the number of tokens. Otherwise the text goes through prompt
    parsing and ``model_hijack.clip.tokenize_line``.
    """
    if use_raw_clip:
        raw_tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
        return raw_tokens, len(raw_tokens)

    try:
        text, _ = extra_networks.parse_prompt(text)
        _, flat_prompts, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt = reduce(lambda acc, part: acc+part, flat_prompts)
    except Exception:
        # Fall back to the text as it stands at the point of failure.
        prompt = text
    chunks, token_count = model_hijack.clip.tokenize_line(prompt)
    tokens = reduce(lambda acc, part: acc+part, [chunk.tokens for chunk in chunks])
    return tokens, token_count
46 |
def get_target_token_count(token_count:int):
    """Return ``model_hijack.clip.get_target_prompt_token_count(token_count)``."""
    clip = model_hijack.clip
    return clip.get_target_prompt_token_count(token_count)
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/custom_scripts.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | import importlib.util
4 | from types import ModuleType
5 |
6 | from scripts import logger
7 | from scripts.paths import paths
8 |
9 |
class CustomScripts:
    """Loads ``*.py`` files from a directory and collects classes derived from a base class."""

    def __init__(self, scripts_dir:Path) -> None:
        # stem of the script file -> loaded module; kept across calls.
        self.scripts = dict()
        self.scripts_dir = scripts_dir.absolute()

    def _load_module_from(self, path:Path):
        """Execute the file at ``path`` as a module named after its stem and return it."""
        module_spec = importlib.util.spec_from_file_location(path.stem, path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        return module

    def _load_derived_classes(self, module:ModuleType, base_class:type):
        """Return the classes found in ``module`` that are strict subclasses of ``base_class``."""
        return [
            obj
            for obj in (getattr(module, name) for name in dir(module))
            if isinstance(obj, type) and issubclass(obj, base_class) and obj is not base_class
        ]

    def load_derived_classes(self, baseclass:type):
        """Load every ``*.py`` in the scripts directory and return subclasses of ``baseclass``.

        A script that fails to load is logged with its traceback and skipped,
        so one broken script no longer prevents the others from being loaded.
        Returns an empty list when the scripts directory does not exist.
        """
        import traceback

        if not self.scripts_dir.is_dir():
            logger.warn(f"NOT A DIRECTORY: {self.scripts_dir}")
            return []

        back_syspath = sys.path
        classes = []
        try:
            # Let the scripts import from the extension's base directory.
            sys.path = [str(paths.base_path)] + sys.path
            for path in self.scripts_dir.glob("*.py"):
                try:
                    self.scripts[path.stem] = self._load_module_from(path)
                except Exception:
                    logger.error(f"Error on loading {path}")
                    logger.error(traceback.format_exc())
            for stem, module in self.scripts.items():
                try:
                    classes.extend(self._load_derived_classes(module, baseclass))
                except Exception:
                    logger.error(f"Error on loading {stem}")
                    logger.error(traceback.format_exc())
        finally:
            sys.path = back_syspath

        return classes
--------------------------------------------------------------------------------
/userscripts/taggers/aesthetic_shadow.py:
--------------------------------------------------------------------------------
1 | # This code is using the image classification "aesthetic-shadow-v2" by shadowlilac (https://huggingface.co/shadowlilac/aesthetic-shadow-v2)
2 | # and "aesthetic-shadow-v2" is licensed under CC-BY-NC 4.0 (https://spdx.org/licenses/CC-BY-NC-4.0)
3 |
4 | import math
5 |
6 | from PIL import Image
7 | from transformers import pipeline
8 | import torch
9 |
10 | from modules import devices, shared
11 | from scripts.tagger import Tagger
12 |
13 | # brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py
14 |
15 | # I'm not sure if this is really working
16 | BATCH_SIZE = 3
17 |
18 | # tags used in Animagine-XL
SCORE_N = {
    'very aesthetic':0.71,
    'aesthetic':0.45,
    'displeasing':0.27,
    'very displeasing':-float('inf'),
}

def get_aesthetic_tag(score:float):
    """Return the first tag in ``SCORE_N`` whose threshold ``score`` strictly exceeds.

    Entries are checked in the dict's order; ``None`` is returned when no
    threshold is exceeded.
    """
    matches = (tag for tag, threshold in SCORE_N.items() if score > threshold)
    return next(matches, None)
30 |
class AestheticShadowV2(Tagger):
    """Tagger that maps the aesthetic-shadow-v2 ``hq`` score to one aesthetic tag."""

    def load(self):
        """Build the image-classification pipeline on the configured device."""
        dev = devices.device
        if dev.index is None:
            # No explicit index on the configured device: use index 0 of the same type.
            dev = torch.device(dev.type, 0)
        self.pipe_aesthetic = pipeline("image-classification", "shadowlilac/aesthetic-shadow-v2", device=dev, batch_size=BATCH_SIZE)

    def unload(self):
        """Drop the pipeline unless models are configured to stay in memory."""
        if shared.opts.interrogate_keep_models_in_memory:
            return
        self.pipe_aesthetic = None
        devices.torch_gc()

    def start(self):
        self.load()
        return self

    def stop(self):
        self.unload()

    def _get_score(self, data):
        """Turn one pipeline result (list of label/score dicts) into the tag list."""
        scores = {entry["label"]: entry["score"] for entry in data}
        hq = scores['hq']
        return [get_aesthetic_tag(hq)]

    def predict(self, image: Image.Image, threshold=None):
        result = self.pipe_aesthetic(image)
        return self._get_score(result)

    def predict_pipe(self, data: list[Image.Image], threshold=None):
        # A None input only probes whether this method is implemented.
        if data is None:
            return
        for result in self.pipe_aesthetic(data, batch_size=BATCH_SIZE):
            yield self._get_score(result)

    def name(self):
        return "aesthetic shadow"
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Set, Dict
2 |
class Data:
    """One image path together with the tags parsed from its caption."""

    def __init__(self, imgpath: str, caption: str):
        self.imgpath = imgpath
        # Comma-separated caption -> whitespace-stripped tags, order preserved.
        pieces = caption.split(',')
        self.tags = list(map(str.strip, pieces))
        self.tagset = set(self.tags)

    def tag_contains_allof(self, tags: Set[str]):
        """Return True when every tag in ``tags`` is present."""
        own = self.tagset
        return own.issuperset(tags)

    def tag_contains_anyof(self, tags: Set[str]):
        """Return True when at least one tag in ``tags`` is present."""
        no_overlap = self.tagset.isdisjoint(tags)
        return not no_overlap
14 |
15 |
class Dataset:
    """Collection of ``Data`` entries keyed by image path."""

    def __init__(self):
        # image path -> Data
        self.datas: Dict[str, "Data"] = dict()

    def __len__(self):
        return len(self.datas)

    def clear(self):
        """Remove every entry."""
        self.datas.clear()

    def merge(self, dataset, overwrite: bool = True):
        """Copy the entries of ``dataset`` into this one and return ``self``.

        With ``overwrite=False`` entries whose path already exists are kept.
        Arguments that are not a ``Dataset`` are ignored.
        """
        if isinstance(dataset, Dataset):
            # Read from dataset.datas: Dataset defines no __getitem__, so the
            # previous ``dataset[path]`` raised TypeError on every merge.
            for path, data in dataset.datas.items():
                if overwrite or path not in self.datas:
                    self.datas[path] = data
        return self

    def append_data(self, data: "Data"):
        """Add ``data`` under its ``imgpath``, replacing any existing entry."""
        self.datas[data.imgpath] = data

    def remove(self, dataset):
        """Remove every path contained in ``dataset`` and return ``self``.

        Arguments that are not a ``Dataset`` are ignored.
        """
        if isinstance(dataset, Dataset):
            # Snapshot the keys so that ``ds.remove(ds)`` does not mutate the
            # dict while it is being iterated.
            for path in list(dataset.datas):
                self.datas.pop(path, None)
        return self

    def remove_by_path(self, path: str):
        """Remove the entry for ``path``; unknown paths are ignored."""
        self.datas.pop(path, None)

    def copy(self):
        """Return a new Dataset sharing the same Data objects (shallow copy)."""
        res = Dataset()
        res.datas = self.datas.copy()
        return res

    def filter(self, filter):
        """Return ``filter.apply(self)``."""
        return filter.apply(self)

    def get_data(self, path: str):
        """Return the entry for ``path`` or ``None``."""
        return self.datas.get(path)

    def get_data_tags(self, path: str):
        """Return the tag list of ``path``, or an empty list when unknown."""
        data = self.get_data(path)
        return data.tags if data else []

    def get_data_tagset(self, path: str):
        """Return the tag set of ``path``, or an empty set when unknown."""
        data = self.get_data(path)
        # set() rather than {}: the latter is an empty dict, not a set.
        return data.tagset if data else set()

    def get_tagset(self):
        """Return the union of the tag sets of all entries."""
        return set().union(*(data.tagset for data in self.datas.values()))

    def get_taglist(self):
        """Return all distinct tags as a list (order unspecified)."""
        return list(self.get_tagset())
80 |
81 |
--------------------------------------------------------------------------------
/scripts/utilities.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import math
3 |
4 | from PIL import Image
5 |
# Pillow<9.0 has no ``Image.Resampling``; alias it to the module so that
# ``Image.Resampling.LANCZOS`` resolves to the module-level constant.
try:
    Image.Resampling
except AttributeError:  # Pillow<9.0
    Image.Resampling = Image
8 |
9 |
def resize(image: Image.Image, size: Tuple[int, int]):
    """Return *image* resized to *size* using the LANCZOS resampling filter."""
    lanczos = Image.Resampling.LANCZOS
    return image.resize(size, resample=lanczos)
12 |
13 |
def get_rgb_image(image:Image.Image):
    """Return *image* as RGB, flattening any alpha channel onto white.

    Images that are neither RGB nor RGBA are first converted: to RGBA when
    their ``info`` carries a "transparency" entry, otherwise to RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Composite over an opaque white canvas, then drop the alpha channel.
    background = Image.new("RGBA", image.size, (255, 255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
22 |
23 |
def resize_and_fill(image: Image.Image, size: Tuple[int, int], repeat_edge = True, fill_rgb:tuple[int,int,int] = (255, 255, 255)):
    """Fit *image* inside *size* keeping its aspect ratio and pad the remainder.

    The image is scaled so that it fits within ``(width, height)``; when the
    scaled image already matches *size* it is returned as is. Otherwise the
    result is a new RGB image of *size* with the scaled image centred and the
    leftover area filled either with *fill_rgb* (``repeat_edge=False``) or
    with strips resampled from zero-thickness boxes along the edges of the
    scaled image (``repeat_edge=True``).
    """
    width, height = size
    scale_w = width / image.width
    scale_h = height / image.height

    # Shrink the dimension that would otherwise overshoot the target box.
    resized_w, resized_h = width, height
    if scale_w < scale_h:
        resized_h = image.height * resized_w // image.width
    elif scale_h < scale_w:
        resized_w = image.width * resized_h // image.height

    resized = resize(image, (resized_w, resized_h))
    if (resized_w, resized_h) == (width, height):
        return resized

    if not repeat_edge:
        canvas = Image.new("RGB", size, fill_rgb)
        canvas.paste(resized, box=((width - resized_w) // 2, (height - resized_h) // 2))
        return canvas.convert("RGB")

    # Padding on each side; the odd pixel (if any) goes to the right/bottom.
    fill_l = math.floor((width - resized_w) / 2)
    fill_r = width - resized_w - fill_l
    fill_t = math.floor((height - resized_h) / 2)
    fill_b = height - resized_h - fill_t

    canvas = Image.new("RGB", (width, height))
    canvas.paste(resized, (fill_l, fill_t))

    # (thickness, strip size, source box in `resized`, paste origin),
    # applied in the order top, bottom, left, right.
    strips = (
        (fill_t, (width, fill_t), (0, 0, width, 0), (0, 0)),
        (
            fill_b,
            (width, fill_b),
            (0, resized.height, width, resized.height),
            (0, resized.height + fill_t),
        ),
        (fill_l, (fill_l, height), (0, 0, 0, height), (0, 0)),
        (
            fill_r,
            (fill_r, height),
            (resized.width, 0, resized.width, height),
            (resized.width + fill_l, 0),
        ),
    )
    for thickness, strip_size, source_box, origin in strips:
        if thickness > 0:
            canvas.paste(resized.resize(strip_size, box=source_box), origin)
    return canvas
67 |
68 |
69 |
--------------------------------------------------------------------------------
/userscripts/taggers/waifu_aesthetic_classifier.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import numpy as np
4 | import math
5 |
6 | from transformers import CLIPModel, CLIPProcessor
7 |
8 | from modules import devices, shared
9 | from scripts import model_loader
10 | from scripts.paths import paths
11 | from scripts.tagger import Tagger
12 |
13 | # brought from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
class Classifier(torch.nn.Module):
    """Three-layer MLP (ReLU, ReLU, sigmoid) producing values in (0, 1)."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        half_hidden = hidden_size // 2
        # Attribute names are part of the checkpoint's state-dict keys.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, half_hidden)
        self.fc3 = torch.nn.Linear(half_hidden, output_size)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x:torch.Tensor):
        """Apply fc1 -> ReLU -> fc2 -> ReLU -> fc3 -> sigmoid to *x*."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))
31 |
32 | # brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
def image_embeddings(image:Image, model:CLIPModel, processor:CLIPProcessor):
    """Return the CLIP image features of *image*, scaled by the array's norm.

    The features are moved to the CPU, converted to NumPy, divided by
    ``np.linalg.norm`` of the whole array and squeezed along axis 0.
    """
    pixel_values = processor(images=image, return_tensors='pt')['pixel_values']
    pixel_values = pixel_values.to(devices.device)
    features = model.get_image_features(pixel_values=pixel_values)
    result:np.ndarray = features.cpu().detach().numpy()
    normalized = result / np.linalg.norm(result)
    return normalized.squeeze(axis=0)
38 |
39 |
class WaifuAesthetic(Tagger):
    """Tagger that scores an image with the waifu-diffusion aesthetic head.

    The image is embedded with CLIP ViT-B/32 and the embedding is fed to the
    small ``Classifier`` MLP; the sigmoid output is bucketed into a
    ``[WD]score_N`` tag.
    """

    def load(self):
        """Fetch the classifier weights and build the CLIP model/processor."""
        file = model_loader.load(
            model_path=paths.models_path / "aesthetic" / "aes-B32-v0.pth",
            model_url='https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/models/aes-B32-v0.pth'
        )
        CLIP_REPOS = 'openai/clip-vit-base-patch32'
        self.model = Classifier(512, 256, 1)
        # Deserialize onto the CPU first so the checkpoint loads even when it
        # was saved from a device that is not available here; the module is
        # moved to the working device right after.
        self.model.load_state_dict(torch.load(file, map_location="cpu"))
        self.model = self.model.to(devices.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS)
        self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval()

    def unload(self):
        """Release all models unless the user asked to keep them in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.clip_processor = None
            self.clip_model = None
            devices.torch_gc()

    def start(self):
        """Load the models and return ``self``."""
        self.load()
        return self

    def stop(self):
        """Release the models (see :meth:`unload`)."""
        self.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return a one-element list with the score tag for *image*.

        *threshold* is accepted but not used.
        """
        # Pure inference: skip autograd bookkeeping for both CLIP and the head.
        with torch.no_grad():
            image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
            prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
        # edit here to change tag
        return [f"[WD]score_{math.floor(prediction.item()*10)}"]

    def name(self):
        """Return the display name of this tagger."""
        return "wd aesthetic classifier"
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"

on:
  push:
    branches: [ "main" ]
  pull_request:
    # The branches below must be a subset of the branches above
    branches: [ "main" ]
  schedule:
    - cron: '34 9 * * 3'

jobs:
  analyze:
    name: Analyze
    runs-on: ubuntu-latest
    permissions:
      actions: read
      contents: read
      security-events: write

    strategy:
      fail-fast: false
      matrix:
        language: [ 'javascript', 'python' ]
        # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
        # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support

    steps:
    - name: Checkout repository
      # v3 runs on a retired Node runtime; v4 is the maintained major.
      uses: actions/checkout@v4

    # Initializes the CodeQL tools for scanning.
    # NOTE: CodeQL Action v2 has been deprecated by GitHub; all three
    # codeql-action steps below must use the same major version.
    - name: Initialize CodeQL
      uses: github/codeql-action/init@v3
      with:
        languages: ${{ matrix.language }}
        # If you wish to specify custom queries, you can do so here or in a config file.
        # By default, queries listed here will override any specified in a config file.
        # Prefix the list here with "+" to use these queries and those in the config file.

        # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
        # queries: security-extended,security-and-quality


    # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
    # If this step fails, then you should remove it and run the build manually (see below)
    - name: Autobuild
      uses: github/codeql-action/autobuild@v3

    # ℹ️ Command-line programs to run using the OS shell.
    # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun

    #   If the Autobuild fails above, remove it and uncomment the following three lines.
    #   Modify them (or add more) to build your code; if your project needs a
    #   custom build, please refer to the EXAMPLE below for guidance.

    # - run: |
    #     echo "Run, Build Application using script"
    #     ./location_of_script_within_repo/buildscript.sh

    - name: Perform CodeQL Analysis
      uses: github/codeql-action/analyze@v3
      with:
        category: "/language:${{matrix.language}}"
75 |
--------------------------------------------------------------------------------
/userscripts/taggers/improved_aesthetic_predictor.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import math
6 |
7 | from transformers import CLIPModel, CLIPProcessor
8 |
9 | from modules import devices, shared
10 | from scripts import model_loader
11 | from scripts.paths import paths
12 | from scripts.tagger import Tagger
13 |
14 | # brought from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py and modified
class Classifier(nn.Module):
    """Linear regression head: five Linear layers with Dropout after the first three."""

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # (in_features, out_features, dropout probability or None).
        # The resulting Sequential indices (0..7) are part of the checkpoint's
        # state-dict keys, so the layer order must not change.
        plan = (
            (self.input_size, 1024, 0.2),
            (1024, 128, 0.2),
            (128, 64, 0.1),
            (64, 16, None),
            (16, 1, None),
        )
        stack = []
        for in_features, out_features, drop in plan:
            stack.append(nn.Linear(in_features, out_features))
            if drop is not None:
                stack.append(nn.Dropout(drop))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run *x* through the layer stack."""
        return self.layers(x)
32 |
33 | # brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
def image_embeddings(image:Image, model:CLIPModel, processor:CLIPProcessor):
    """Return the CLIP image features of *image*, scaled by the array's norm.

    The features are moved to the CPU, converted to NumPy, divided by
    ``np.linalg.norm`` of the whole array and squeezed along axis 0.
    """
    pixel_values = processor(images=image, return_tensors='pt')['pixel_values']
    pixel_values = pixel_values.to(devices.device)
    features = model.get_image_features(pixel_values=pixel_values)
    result:np.ndarray = features.cpu().detach().numpy()
    normalized = result / np.linalg.norm(result)
    return normalized.squeeze(axis=0)
39 |
40 |
class ImprovedAestheticPredictor(Tagger):
    """Tagger that scores an image with the "improved aesthetic predictor" head.

    The image is embedded with CLIP ViT-L/14 and the embedding is fed to the
    ``Classifier`` regression head; the floored score becomes an
    ``[IAP]score_N`` tag.
    """

    def load(self):
        """Fetch the predictor weights and build the CLIP model/processor."""
        MODEL_VERSION = "sac+logos+ava1-l14-linearMSE"
        file = model_loader.load(
            model_path=paths.models_path / "aesthetic" / f"{MODEL_VERSION}.pth",
            model_url=f'https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/{MODEL_VERSION}.pth'
        )
        CLIP_REPOS = 'openai/clip-vit-large-patch14'
        self.model = Classifier(768)
        # Deserialize onto the CPU first so the checkpoint loads even when it
        # was saved from a device that is not available here.
        self.model.load_state_dict(torch.load(file, map_location="cpu"))
        # eval() is required: the head contains Dropout layers, which would
        # otherwise stay active and make every prediction random.
        self.model = self.model.to(devices.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS)
        self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval()

    def unload(self):
        """Release all models unless the user asked to keep them in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.clip_processor = None
            self.clip_model = None
            devices.torch_gc()

    def start(self):
        """Load the models and return ``self``."""
        self.load()
        return self

    def stop(self):
        """Release the models (see :meth:`unload`)."""
        self.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return a one-element list with the score tag for *image*.

        *threshold* is accepted but not used.
        """
        # Pure inference: skip autograd bookkeeping for both CLIP and the head.
        with torch.no_grad():
            image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
            prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
        # edit here to change tag
        return [f"[IAP]score_{math.floor(prediction.item())}"]

    def name(self):
        """Return the display name of this tagger."""
        return "Improved Aesthetic Predictor"
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/block_toprow.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING
3 | import gradio as gr
4 |
5 | from .ui_common import *
6 | from .uibase import UIBase
7 |
8 | if TYPE_CHECKING:
9 | from .ui_classes import *
10 |
11 |
# ToprowUI: UIBase subclass holding the "Save all changes" controls and the
# options for writing kohya-ss finetuning metadata.
12 | class ToprowUI(UIBase):
# create_ui: instantiate the gradio components. Every initial value is read
# from the matching attribute of cfg_general.
13 | def create_ui(self, cfg_general):
14 | with gr.Column(variant='panel'):
15 | with gr.Row():
16 | with gr.Column(scale=1):
17 | self.btn_save_all_changes = gr.Button(value='Save all changes', variant='primary')
18 | with gr.Column(scale=2):
19 | self.cb_backup = gr.Checkbox(value=cfg_general.backup, label='Backup original text file (original file will be renamed like filename.000, .001, .002, ...)', interactive=True)
20 | gr.HTML(value='Note: New text file will be created if you are using filename as captions.')
21 | with gr.Row():
22 | self.cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Use kohya-ss's finetuning metadata json", interactive=True)
23 | with gr.Row():
# This column is bound to self.cl_kohya_metadata so set_callbacks can toggle
# its visibility; it starts visible only when save_kohya_metadata is set.
24 | with gr.Column(variant='panel', visible=cfg_general.save_kohya_metadata) as self.cl_kohya_metadata:
25 | self.tb_metadata_output = gr.Textbox(label='json path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path)
26 | self.tb_metadata_input = gr.Textbox(label='json input path (Optional, only for append results)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path)
27 | with gr.Row():
28 | self.cb_metadata_overwrite = gr.Checkbox(value=cfg_general.meta_overwrite, label="Overwrite if output file exists", interactive=True)
29 | self.cb_metadata_as_caption = gr.Checkbox(value=cfg_general.meta_save_as_caption, label="Save metadata as caption", interactive=True)
30 | self.cb_metadata_use_fullpath = gr.Checkbox(value=cfg_general.meta_use_full_path, label="Save metadata image key as fullpath", interactive=True)
# The results textbox lives in a row created with visible=False.
# NOTE(review): no callback in this class writes to txt_result.
31 | with gr.Row(visible=False):
32 | self.txt_result = gr.Textbox(label='Results', interactive=False)
33 |
# set_callbacks: wire the save button and the metadata checkbox.
# load_dataset supplies tb_caption_file_ext, the last input of the save handler.
34 | def set_callbacks(self, load_dataset:LoadDatasetUI):
35 |
# Click handler: parameters arrive in the order of the `inputs` list below.
# Note that dte_instance.save_dataset takes caption_file_ext as its SECOND
# argument, so the order is deliberately rearranged in the call.
36 | def save_all_changes(backup: bool, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool, caption_file_ext:str):
# An empty input path is passed on as None.
37 | if not metadata_input:
38 | metadata_input = None
39 | dte_instance.save_dataset(backup, caption_file_ext, save_kohya_metadata, metadata_output, metadata_input, metadata_overwrite, metadata_as_caption, metadata_use_fullpath)
40 |
# Registered without `outputs`: the handler's return value is not shown.
41 | self.btn_save_all_changes.click(
42 | fn=save_all_changes,
43 | inputs=[self.cb_backup, self.cb_save_kohya_metadata, self.tb_metadata_output, self.tb_metadata_input, self.cb_metadata_overwrite, self.cb_metadata_as_caption, self.cb_metadata_use_fullpath, load_dataset.tb_caption_file_ext]
44 | )
45 |
# Show or hide the metadata option column to follow the checkbox value.
46 | self.cb_save_kohya_metadata.change(
47 | fn=lambda x:gr.update(visible=x),
48 | inputs=self.cb_save_kohya_metadata,
49 | outputs=self.cl_kohya_metadata
50 | )
51 |
52 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/block_dataset_gallery.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, Callable, List
3 |
4 | import gradio as gr
5 |
6 | from .ui_common import *
7 | from .uibase import UIBase
8 |
9 | if TYPE_CHECKING:
10 | from .ui_classes import *
11 |
class DatasetGalleryUI(UIBase):
    """Dataset image gallery plus the hidden widgets that track its selection."""

    def __init__(self):
        # Nothing is selected until the hidden index widgets report a value.
        self.selected_path = ''
        self.selected_index = -1
        self.selected_index_prev = -1

    def create_ui(self, image_columns):
        """Create the hidden helper widgets and the gallery itself."""
        with gr.Row(visible=False):
            self.cbg_hidden_dataset_filter = gr.CheckboxGroup(label='Dataset Filter')
            self.nb_hidden_dataset_filter_apply = gr.Number(label='Filter Apply', value=-1)
            self.btn_hidden_set_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_index")
            self.nb_hidden_image_index = gr.Number(value=None, label='hidden_idx_next')
            self.nb_hidden_image_index_prev = gr.Number(value=None, label='hidden_idx_prev')
        self.gl_dataset_images = gr.Gallery(
            label='Dataset Images',
            elem_id="dataset_tag_editor_dataset_gallery",
            columns=image_columns,
        )

    def set_callbacks(self, load_dataset:LoadDatasetUI, gallery_state:GalleryStateUI, get_filters:Callable[[], dte_module.filters.Filter]):
        """Register the event handlers that keep the selection state in sync."""
        index_widgets = [self.nb_hidden_image_index, self.nb_hidden_image_index_prev]
        filter_widgets = [self.cbg_hidden_dataset_filter, self.nb_hidden_dataset_filter_apply]

        gallery_state.register_value('Selected Image', self.selected_path)

        def reset_indices():
            # Loading a dataset invalidates both remembered indices.
            return [-1, -1]

        load_dataset.btn_load_datasets.click(fn=reset_indices, outputs=index_widgets)

        def set_index_clicked(next_idx: int, prev_idx: int):
            # Missing values are treated as "nothing selected".
            prev_idx = -1 if prev_idx is None else int(prev_idx)
            next_idx = -1 if next_idx is None else int(next_idx)
            img_paths = dte_instance.get_filtered_imgpaths(filters=get_filters())
            count = len(img_paths)

            if not 0 <= prev_idx < count:
                prev_idx = -1

            if 0 <= next_idx < count:
                self.selected_path = img_paths[next_idx]
            else:
                next_idx = -1
                self.selected_path = ''

            gallery_state.register_value('Selected Image', self.selected_path)
            return [next_idx, prev_idx]

        # The JS shim replaces the inputs with (selected gallery index,
        # previous "next" index) before the Python handler runs.
        self.btn_hidden_set_index.click(
            fn=set_index_clicked,
            _js="(x, y) => [dataset_tag_editor_gl_dataset_images_selected_index(), x]",
            inputs=index_widgets,
            outputs=index_widgets,
        )
        self.nb_hidden_image_index.change(
            fn=self.func_to_set_value('selected_index', int),
            inputs=self.nb_hidden_image_index,
        )
        self.nb_hidden_image_index_prev.change(
            fn=self.func_to_set_value('selected_index_prev', int),
            inputs=self.nb_hidden_image_index_prev,
        )

        def passthrough(a, b):
            # All real work happens in the JS shim; Python only echoes it back.
            return [a, b]

        self.nb_hidden_dataset_filter_apply.change(
            fn=passthrough,
            _js='(x, y) => [y>=0 ? dataset_tag_editor_gl_dataset_images_filter(x) : x, -1]',
            inputs=filter_widgets,
            outputs=filter_widgets,
        )
72 | )
73 |
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/kohya_finetune_metadata.py:
--------------------------------------------------------------------------------
1 | # This code is based on following codes written by kohya-ss and modified by toshiaki1729.
2 | # https://github.com/kohya-ss/sd-scripts/blob/main/finetune/merge_captions_to_metadata.py
3 | # https://github.com/kohya-ss/sd-scripts/blob/main/finetune/merge_dd_tags_to_metadata.py
4 |
5 | # The original code is distributed in the Apache License 2.0.
6 | # Full text of the license is available at the following link.
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # implement metadata output compatible to kohya-ss's finetuning captions
10 | # on commit hash: ae33d724793e14f16b4c68bdad79f836c86b1b8e
11 |
12 | import json
13 | from glob import glob
14 | from pathlib import Path
15 | from PIL import Image
16 |
def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as_caption=False, use_full_path=False):
    """Write the tags of *dataset* to a kohya-ss style metadata JSON file.

    Args:
        dataset: object whose ``datas`` mapping holds items with ``imgpath``
            and ``tags`` attributes.
        dataset_dir: accepted for interface compatibility; not used here.
        out_path: destination JSON file (always rewritten).
        in_path: optional JSON file whose entries are used as the starting
            point. When omitted and *out_path* already exists and
            *overwrite* is false, *out_path* itself is used.
        overwrite: when true, an existing *out_path* is not merged into.
        save_as_caption: store a comma-joined string under ``"caption"``
            instead of a list under ``"tags"``.
        use_full_path: key entries by absolute image path instead of by
            file stem.
    """
    if in_path is None and Path(out_path).is_file() and not overwrite:
        in_path = out_path

    result = {}
    if in_path is not None:
        # Best effort: an unreadable or malformed input file simply means we
        # start from an empty result. JSONDecodeError and UnicodeDecodeError
        # are both ValueError subclasses.
        try:
            loaded = json.loads(Path(in_path).read_text(encoding='utf-8'))
        except (OSError, ValueError):
            loaded = None
        # Valid JSON that is not an object (e.g. a list) cannot be merged into.
        if isinstance(loaded, dict):
            result = loaded

    tags_key = 'caption' if save_as_caption else 'tags'

    for data in dataset.datas.values():
        img_path, tags = Path(data.imgpath), data.tags

        img_key = str(img_path.absolute()) if use_full_path else img_path.stem
        save_caption = ', '.join(tags) if save_as_caption else tags

        entry = result.get(img_key)
        if not isinstance(entry, dict):
            # Missing or malformed entry: start a fresh one.
            entry = result[img_key] = {}
        entry[tags_key] = save_caption

    with open(out_path, 'w', encoding='utf-8', newline='') as f:
        json.dump(result, f, indent=2)
44 |
45 |
def read(dataset_dir, json_path, use_temp_dir:bool):
    """Load images and tags listed in a kohya-ss style metadata JSON file.

    Each key of the JSON object is either the path of an existing image file
    or a name that is looked up as ``<dataset_dir>/<key>.*`` (the first match
    that opens successfully wins). Entries whose image cannot be found or
    opened are skipped.

    Args:
        dataset_dir: directory searched for keys that are not file paths.
        json_path: metadata file to read (UTF-8).
        use_temp_dir: when false, each opened image gets an
            ``already_saved_as`` attribute holding its absolute path.

    Returns:
        ``(imgpaths, images, taglists)``: absolute paths, a mapping of
        absolute path to opened image, and the tag list for each path
        (caption pieces first, then the ``"tags"`` list).
    """
    dataset_dir = Path(dataset_dir)
    json_path = Path(json_path)
    metadata = json.loads(json_path.read_text('utf8'))
    imgpaths = []
    images = {}
    taglists = []

    def load_image(img_path):
        # Use the argument itself; reading a variable of the enclosing scope
        # here would load the wrong file (or fail) in the full-path branch.
        img_path = Path(img_path)
        try:
            img = Image.open(img_path)
        except Exception:
            # Best effort: anything that prevents opening skips this file.
            return None, None
        abs_path = str(img_path.absolute())
        if not use_temp_dir:
            img.already_saved_as = abs_path
        return abs_path, img

    for image_key, img_md in metadata.items():
        img_path = Path(image_key)
        abs_path = None
        img = None
        if img_path.is_file():
            abs_path, img = load_image(img_path)
            if abs_path is None or img is None:
                continue
            images[abs_path] = img
        else:
            # Key is a bare name: try every extension found in dataset_dir.
            for path in glob(str(dataset_dir.absolute() / (image_key + '.*'))):
                abs_path, img = load_image(path)
                if abs_path is None or img is None:
                    continue
                images[abs_path] = img
                break
        if abs_path is None or img is None:
            continue
        caption = img_md.get('caption')
        tags = img_md.get('tags')
        if tags is None:
            tags = []
        if caption is not None and isinstance(caption, str):
            caption = [s.strip() for s in caption.split(',')]
            tags = [s for s in caption if s] + tags
        imgpaths.append(abs_path)
        taglists.append(tags)

    return imgpaths, images, taglists
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from typing import Tuple
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | import torchvision.transforms as tf
7 | from torch.utils.data import Dataset, DataLoader
8 | from tqdm import tqdm
9 |
10 | from modules import shared, devices
11 | import launch
12 |
13 |
class ImageDataset(Dataset):
    """Torch dataset over an in-memory list of images with optional transforms."""

    def __init__(self, images:list[Image.Image], transforms:tf.Compose=None):
        self.images = images
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        """Return image *i*, passed through ``transforms`` when one is set."""
        sample = self.images[i]
        if self.transforms is None:
            return sample
        return self.transforms(sample)
27 |
28 |
29 |
class WaifuDiffusionTaggerTimm:
    """timm-based WD tagger that yields (label, probability) pairs per image."""
    # some codes are brought from https://github.com/neggles/wdv3-timm and modified

    def __init__(self, model_repo, label_filename="selected_tags.csv"):
        self.LABEL_FILENAME = label_filename
        self.MODEL_REPO = model_repo
        self.model = None
        self.transform = None
        self.labels = []

    def load(self):
        """Install timm if needed, then load model, transform and label names."""
        import huggingface_hub

        if not launch.is_installed("timm"):
            launch.run_pip(
                "install -U timm",
                "requirements for dataset-tag-editor [timm]",
            )
        import timm
        from timm.data import create_transform, resolve_data_config

        # A model kept in memory from a previous run is reused as is.
        if not self.model:
            repo_id = "hf-hub:" + self.MODEL_REPO
            self.model: torch.nn.Module = timm.create_model(repo_id).eval()
            weights = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
            self.model.load_state_dict(weights)
            self.model.to(devices.device)

        data_config = resolve_data_config(self.model.pretrained_cfg, model=self.model)
        self.transform = create_transform(**data_config)

        path_label = huggingface_hub.hf_hub_download(
            self.MODEL_REPO, self.LABEL_FILENAME
        )
        import pandas as pd

        label_table = pd.read_csv(path_label)
        self.labels = label_table["name"].tolist()

    def unload(self):
        """Drop the model unless the user asked to keep models in memory."""
        if shared.opts.interrogate_keep_models_in_memory:
            return
        self.model = None
        devices.torch_gc()

    def apply(self, image: Image.Image):
        """Return ``[(label, probability), ...]`` for one image ([] if unloaded)."""
        if not self.model:
            return []

        image_t: torch.Tensor = self.transform(image).unsqueeze(0)
        # Reverse the order of the three channels along dim 1.
        image_t = image_t[:, [2, 1, 0]]
        image_t = image_t.to(devices.device)

        with torch.inference_mode():
            features = self.model.forward(image_t)
            probs = F.sigmoid(features).detach().cpu().numpy()

        scores = probs[0].astype(float)
        labels: list[Tuple[str, float]] = list(zip(self.labels, scores))
        return labels

    def apply_multi(self, images: list[Image.Image], batch_size: int):
        """Yield, per batch, a list with one ``[(label, prob), ...]`` per image.

        This is a generator; it yields nothing when no model is loaded.
        """
        if not self.model:
            return []

        loader = DataLoader(ImageDataset(images, self.transform), batch_size=batch_size)

        with torch.inference_mode():
            for batch in tqdm(loader):
                batch = batch[:, [2, 1, 0]].to(devices.device)
                features = self.model.forward(batch)
                probs = F.sigmoid(features).detach().cpu().numpy()
                labels: list[Tuple[str, float]] = [
                    list(zip(self.labels, row.astype(float))) for row in probs
                ]
                yield labels
105 |
--------------------------------------------------------------------------------
/scripts/config.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import json
3 |
4 | from scripts import logger
5 | from scripts.paths import paths
6 | from scripts.dte_instance import dte_instance
7 |
# Local shorthands for the sort enums exposed by the editor instance.
SortBy = dte_instance.SortBy
SortOrder = dte_instance.SortOrder

# Location of the persisted settings file.
CONFIG_PATH = paths.base_path / "config.json"

# --- Setting schemas -------------------------------------------------------
# Field names are part of the interface and are kept exactly as they are,
# including the misspelt "sory_by" of BatchEditConfig.
GeneralConfig = namedtuple(
    "GeneralConfig",
    (
        "backup",
        "dataset_dir",
        "caption_ext",
        "load_recursive",
        "load_caption_from_filename",
        "replace_new_line",
        "use_interrogator",
        "use_interrogator_names",
        "use_custom_threshold_booru",
        "custom_threshold_booru",
        "use_custom_threshold_waifu",
        "custom_threshold_waifu",
        "custom_threshold_z3d",
        "save_kohya_metadata",
        "meta_output_path",
        "meta_input_path",
        "meta_overwrite",
        "meta_save_as_caption",
        "meta_use_full_path",
    ),
)
FilterConfig = namedtuple(
    "FilterConfig",
    ("sw_prefix", "sw_suffix", "sw_regex", "sort_by", "sort_order", "logic"),
)
BatchEditConfig = namedtuple(
    "BatchEditConfig",
    (
        "show_only_selected",
        "prepend",
        "use_regex",
        "target",
        "sw_prefix",
        "sw_suffix",
        "sw_regex",
        "sory_by",
        "sort_order",
        "batch_sort_by",
        "batch_sort_order",
        "token_count",
    ),
)
EditSelectedConfig = namedtuple(
    "EditSelectedConfig",
    (
        "auto_copy",
        "sort_on_save",
        "warn_change_not_saved",
        "use_interrogator_name",
        "sort_by",
        "sort_order",
    ),
)
MoveDeleteConfig = namedtuple(
    "MoveDeleteConfig", ("range", "target", "caption_ext", "destination")
)

# --- Default values --------------------------------------------------------
# Spelled out by keyword so each value can be matched to its field at a glance.
CFG_GENERAL_DEFAULT = GeneralConfig(
    backup=True,
    dataset_dir="",
    caption_ext=".txt",
    load_recursive=False,
    load_caption_from_filename=True,
    replace_new_line=False,
    use_interrogator="No",
    use_interrogator_names=[],
    use_custom_threshold_booru=False,
    custom_threshold_booru=0.7,
    use_custom_threshold_waifu=False,
    custom_threshold_waifu=0.35,
    custom_threshold_z3d=0.35,
    save_kohya_metadata=False,
    meta_output_path="",
    meta_input_path="",
    meta_overwrite=True,
    meta_save_as_caption=False,
    meta_use_full_path=False,
)
# Positive (inclusive) tag filter combines tags with AND by default ...
CFG_FILTER_P_DEFAULT = FilterConfig(
    sw_prefix=False,
    sw_suffix=False,
    sw_regex=False,
    sort_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
    logic="AND",
)
# ... while the negative (exclusive) one defaults to OR.
CFG_FILTER_N_DEFAULT = FilterConfig(
    sw_prefix=False,
    sw_suffix=False,
    sw_regex=False,
    sort_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
    logic="OR",
)
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(
    show_only_selected=True,
    prepend=False,
    use_regex=False,
    target="Only Selected Tags",
    sw_prefix=False,
    sw_suffix=False,
    sw_regex=False,
    sory_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
    batch_sort_by=SortBy.ALPHA.value,
    batch_sort_order=SortOrder.ASC.value,
    token_count=75,
)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(
    auto_copy=False,
    sort_on_save=False,
    warn_change_not_saved=False,
    use_interrogator_name="",
    sort_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig(
    range="Selected One", target=[], caption_ext=".txt", destination=""
)
118 |
119 |
class Config:
    """In-memory view of ``config.json``: named sections stored in a dict."""

    def __init__(self):
        self.config = dict()

    def load(self):
        """Read ``CONFIG_PATH`` into ``self.config``.

        A missing file yields an empty configuration silently. A file that
        cannot be read, is not valid JSON, or does not contain a JSON object
        yields an empty configuration and a warning.
        """
        if not CONFIG_PATH.is_file():
            self.config = dict()
            return
        try:
            loaded = json.loads(CONFIG_PATH.read_text("utf8"))
            # read()/write() rely on dict semantics; reject lists, strings, ...
            if not isinstance(loaded, dict):
                raise ValueError("config.json does not contain a JSON object")
        except (OSError, ValueError):
            # JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            logger.warn("Error on loading config.json. Default settings will be loaded.")
            self.config = dict()
        else:
            self.config = loaded
            logger.write("Settings has been read from config.json")

    def save(self):
        """Write ``self.config`` to ``CONFIG_PATH`` as indented JSON."""
        CONFIG_PATH.write_text(json.dumps(self.config, indent=4), "utf8")

    def read(self, name: str):
        """Return the section stored under *name*, or ``None`` if absent."""
        return self.config.get(name)

    def write(self, cfg: dict, name: str):
        """Store *cfg* as the section called *name* (in memory only)."""
        self.config[name] = cfg
144 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/tab_filter_by_tags.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, List, Callable
3 | import gradio as gr
4 |
5 | from .ui_common import *
6 | from .uibase import UIBase
7 | from .block_tag_filter import TagFilterUI
8 |
9 | if TYPE_CHECKING:
10 | from .ui_classes import *
11 |
12 | filters = dte_module.filters
13 |
# FilterByTagsUI: UIBase subclass that owns two TagFilterUI blocks, one created
# in INCLUSIVE mode (positive filter) and one in EXCLUSIVE mode (negative).
14 | class FilterByTagsUI(UIBase):
15 | def __init__(self):
16 | self.tag_filter_ui = TagFilterUI(tag_filter_mode=filters.TagFilter.Mode.INCLUSIVE)
17 | self.tag_filter_ui_neg = TagFilterUI(tag_filter_mode=filters.TagFilter.Mode.EXCLUSIVE)
18 |
# create_ui: build the two "clear" buttons and one tab per filter.
# cfg_filter_p / cfg_filter_n supply the initial logic, sort and search-option
# values for the positive and negative filter respectively.
19 | def create_ui(self, cfg_filter_p, cfg_filter_n, get_filters):
20 | with gr.Row():
21 | self.btn_clear_tag_filters = gr.Button(value='Clear tag filters')
22 | self.btn_clear_all_filters = gr.Button(value='Clear ALL filters')
23 |
24 | with gr.Tab(label='Positive Filter'):
25 | with gr.Column(variant='panel'):
26 | gr.HTML(value='Search tags / Filter images by tags (INCLUSIVE)')
# Positive filter: 'OR' and 'NONE' are honoured, anything else means AND.
27 | logic_p = filters.TagFilter.Logic.OR if cfg_filter_p.logic=='OR' else filters.TagFilter.Logic.NONE if cfg_filter_p.logic=='NONE' else filters.TagFilter.Logic.AND
28 | self.tag_filter_ui.create_ui(get_filters, logic_p, cfg_filter_p.sort_by, cfg_filter_p.sort_order, cfg_filter_p.sw_prefix, cfg_filter_p.sw_suffix, cfg_filter_p.sw_regex)
29 |
30 | with gr.Tab(label='Negative Filter'):
31 | with gr.Column(variant='panel'):
32 | gr.HTML(value='Search tags / Filter images by tags (EXCLUSIVE)')
# Negative filter: 'AND' and 'NONE' are honoured, anything else means OR.
33 | logic_n = filters.TagFilter.Logic.AND if cfg_filter_n.logic=='AND' else filters.TagFilter.Logic.NONE if cfg_filter_n.logic=='NONE' else filters.TagFilter.Logic.OR
34 | self.tag_filter_ui_neg.create_ui(get_filters, logic_n, cfg_filter_n.sort_by, cfg_filter_n.sort_order, cfg_filter_n.sw_prefix, cfg_filter_n.sw_suffix, cfg_filter_n.sw_regex)
35 |
# set_callbacks: connect filter changes and the clear buttons to the
# gallery / batch-edit / move-delete widgets passed in by the caller.
36 | def set_callbacks(self, o_update_gallery:List[gr.components.Component], o_update_filter_and_gallery:List[gr.components.Component], batch_edit_captions:BatchEditCaptionsUI, move_or_delete_files:MoveOrDeleteFilesUI, update_gallery:Callable[[], List], update_filter_and_gallery:Callable[[], List], get_filters:Callable[[], List[dte_module.filters.Filter]]):
# Values produced on every filter change. The pieces are concatenated in the
# same order as common_callback_output below, so the two must stay in step.
37 | common_callback = lambda : \
38 | update_gallery() + \
39 | batch_edit_captions.get_common_tags(get_filters, self) + \
40 | [move_or_delete_files.update_current_move_or_delete_target_num()] + \
41 | [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
42 |
# Components receiving the values of common_callback, position by position.
43 | common_callback_output = \
44 | o_update_gallery + \
45 | [batch_edit_captions.tb_common_tags, batch_edit_captions.tb_edit_tags] + \
46 | [move_or_delete_files.ta_move_or_delete_target_dataset_num]+ \
47 | [batch_edit_captions.tag_select_ui_remove.cbg_tags]
48 |
49 |
# Positive filter: besides the common outputs, also send the selected tags
# (comma-joined) to tb_sr_selected_tags. The _js hook calls
# dataset_tag_editor_gl_dataset_images_close() before passing args through.
50 | self.tag_filter_ui.on_filter_update(
51 | fn=lambda :
52 | common_callback() +
53 | [', '.join(self.tag_filter_ui.filter.tags)],
54 | inputs=None,
55 | outputs=common_callback_output + [batch_edit_captions.tb_sr_selected_tags],
56 | _js='(...args) => {dataset_tag_editor_gl_dataset_images_close(); return args}'
57 | )
58 |
# Negative filter: common outputs only.
59 | self.tag_filter_ui_neg.on_filter_update(
60 | fn=common_callback,
61 | inputs=None,
62 | outputs=common_callback_output,
63 | _js='(...args) => {dataset_tag_editor_gl_dataset_images_close(); return args}'
64 | )
65 |
66 | self.tag_filter_ui.set_callbacks()
67 | self.tag_filter_ui_neg.set_callbacks()
68 |
69 | self.btn_clear_tag_filters.click(
70 | fn=lambda:self.clear_filters(update_filter_and_gallery),
71 | outputs=o_update_filter_and_gallery
72 | )
73 |
# NOTE(review): this handler is identical to the one above, i.e. here
# "Clear ALL filters" only clears the two tag filters. Presumably the other
# filters are cleared by a handler registered elsewhere; verify.
74 | self.btn_clear_all_filters.click(
75 | fn=lambda:self.clear_filters(update_filter_and_gallery),
76 | outputs=o_update_filter_and_gallery
77 | )
78 |
# clear_filters: reset both tag filters, then return whatever the supplied
# update_filter_and_gallery callback returns.
79 | def clear_filters(self, update_filter_and_gallery):
80 | self.tag_filter_ui.clear_filter()
81 | self.tag_filter_ui_neg.clear_filter()
82 | return update_filter_and_gallery()
83 |
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/filters.py:
--------------------------------------------------------------------------------
1 | from typing import Set, Dict
2 | from enum import Enum
3 |
4 |
class Filter:
    """Base class of dataset filters; this default implementation filters nothing."""

    def apply(self, dataset):
        """Return the given dataset unchanged."""
        return dataset

    def __str__(self):
        """Return an empty description."""
        return str()
10 |
11 |
class TagFilter(Filter):
    """Filter a dataset by the tags its entries contain.

    `logic` selects whether an entry has to contain all (AND) or any (OR) of
    `tags`; `mode` selects whether matching entries are kept (INCLUSIVE) or
    removed (EXCLUSIVE).
    """

    class Logic(Enum):
        NONE = 0
        AND = 1
        OR = 2

    class Mode(Enum):
        NONE = 0
        INCLUSIVE = 1
        EXCLUSIVE = 2

    def __init__(self, tags: "Set[str] | None" = None, logic: Logic = Logic.NONE, mode: Mode = Mode.NONE):
        """Store the filter settings.

        Args:
            tags: tags to test for; a new empty set is created when omitted.
            logic: how the tags are combined (AND / OR / NONE).
            mode: keep (INCLUSIVE) or drop (EXCLUSIVE) the matching entries.
        """
        # A fresh set per instance: a `set()` default in the signature would
        # be one object shared (and mutated) by every filter created without
        # explicit tags.
        self.tags = set() if tags is None else tags
        self.logic = logic
        self.mode = mode

    def apply(self, dataset):
        """Remove the entries rejected by this filter from `dataset` and return it.

        The dataset is returned untouched when there are no tags or when the
        logic or the mode is NONE.
        """
        if not self.tags or self.logic == TagFilter.Logic.NONE or self.mode == TagFilter.Mode.NONE:
            return dataset

        if self.logic == TagFilter.Logic.AND:
            def matches(data):
                return data.tag_contains_allof(self.tags)
        elif self.logic == TagFilter.Logic.OR:
            def matches(data):
                return data.tag_contains_anyof(self.tags)
        else:
            return dataset

        if self.mode == TagFilter.Mode.INCLUSIVE:
            remove_matching = False
        elif self.mode == TagFilter.Mode.EXCLUSIVE:
            remove_matching = True
        else:
            return dataset

        # Collect first, remove afterwards: the dataset must not change size
        # while its items are being iterated.
        paths_remove = [
            path
            for path, data in dataset.datas.items()
            if bool(matches(data)) == remove_matching
        ]

        for path in paths_remove:
            dataset.remove_by_path(path)

        return dataset

    def __str__(self):
        """Describe the filter, e.g. 'NOT AND(a, b)'; empty when there are no tags."""
        if not self.tags:
            return ''
        res = 'NOT ' if self.mode == TagFilter.Mode.EXCLUSIVE else ''
        if self.logic == TagFilter.Logic.AND:
            res += 'AND'
        elif self.logic == TagFilter.Logic.OR:
            res += 'OR'
        else:
            # No logic selected: only the optional 'NOT ' prefix is reported.
            return res
        text = ', '.join(self.tags)
        return res + f'({text})'
75 |
76 |
77 |
class PathFilter(Filter):
    """Filter a dataset by an explicit collection of paths.

    INCLUSIVE keeps only the entries whose path is in `paths`; any other
    mode except NONE removes the entries whose path is in `paths`.
    """

    class Mode(Enum):
        NONE = 0
        INCLUSIVE = 1
        EXCLUSIVE = 2

    def __init__(self, paths: "Set[str] | None" = None, mode: Mode = Mode.NONE):
        """Store the filter settings.

        Args:
            paths: paths to keep or remove; a new empty set when omitted.
            mode: keep (INCLUSIVE) or drop (EXCLUSIVE) the listed paths.
        """
        # The former `{}` default was a dict (not a set) shared by every
        # instance; it made the INCLUSIVE branch fail on `set - dict`.
        self.paths = set() if paths is None else paths
        self.mode = mode

    def apply(self, dataset):
        """Remove the entries rejected by this filter from `dataset` and return it."""
        if self.mode == PathFilter.Mode.NONE:
            return dataset

        # Normalise to a set so that any iterable of paths can be used.
        listed_paths = set(self.paths)
        if self.mode == PathFilter.Mode.INCLUSIVE:
            paths_remove = set(dataset.datas) - listed_paths
        else:
            paths_remove = listed_paths

        for path in paths_remove:
            dataset.remove_by_path(path)

        return dataset
100 |
101 |
class TagScoreFilter(Filter):
    """Filter a dataset by comparing the score of one tag with a threshold."""

    class Mode(Enum):
        NONE = 0
        LESS_THAN = 1
        GREATER_THAN = 2

    def __init__(self, scores: Dict[str, Dict[str, float]], tag: str, threshold: float, mode: Mode = Mode.NONE):
        """Store the per-path score tables, the tag to look up, the threshold and the mode."""
        self.scores = scores
        self.mode = mode
        self.tag = tag
        self.threshold = threshold

    def apply(self, dataset):
        """Remove the entries rejected by this filter from `dataset` and return it.

        Nothing is removed in NONE mode. In GREATER_THAN mode every dataset
        path whose score is not above the threshold is removed; in any other
        mode the paths scoring above the threshold are removed.
        """
        if self.mode == TagScoreFilter.Mode.NONE:
            return dataset

        # Paths whose score for the tag exceeds the threshold; a missing or
        # falsy score is treated as 0.
        above_threshold = set()
        for path, tag_scores in self.scores.items():
            score = tag_scores.get(self.tag) or 0
            if score > self.threshold:
                above_threshold.add(path)

        if self.mode == TagScoreFilter.Mode.GREATER_THAN:
            every_path = {path for path in dataset.datas.keys()}
            doomed = every_path - above_threshold
        else:
            doomed = above_threshold

        for path in doomed:
            dataset.remove_by_path(path)

        return dataset
--------------------------------------------------------------------------------
/README-JP.md:
--------------------------------------------------------------------------------
1 | # Dataset Tag Editor
2 | [**スタンドアロン版はこちらです**](https://github.com/toshiaki1729/dataset-tag-editor-standalone): いくつかの既知のバグを回避するのに有効かもしれません。
3 |
4 |
5 | [English Readme](README.md)
6 |
7 | [Stable Diffusion web UI by AUTOMATIC1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)用の拡張機能です。
8 |
9 | web UI 上で学習用データセットのキャプションを編集できるようにします。
10 |
11 | 
12 |
13 | DeepBooru interrogator で生成したような、カンマ区切り形式のキャプションを編集するのに適しています。
14 |
15 | キャプションとして画像ファイル名を利用している場合も読み込むことができますが、保存はテキストファイルのみ対応しています。
16 |
17 | ## インストール方法
18 | ### WebUIのExtensionsタブからインストールする
19 | "Install from URL" タブに `https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor.git` をコピーしてインストールできます。
20 | "Available" タブにこの拡張機能が表示されている場合は、ワンクリックでインストール可能です。
21 | **web UI の "Extensions" タブから更新をした際、完全に更新を適用するには web UI を再起動する必要がある場合があります。**
22 |
23 | ### 手動でインストールする
24 | web UI の `extensions` フォルダにリポジトリのクローンを作成し再起動してください。
25 |
26 | web UI のフォルダで以下のコマンドを実行することによりインストールできます。
27 | ```commandline
28 | git clone https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor.git extensions/dataset-tag-editor
29 | ```
30 |
31 | ## 特徴
32 | 以下、「タグ」はカンマ区切りされたキャプションの各部分を意味します。
33 | - テキスト形式(webUI方式)またはjson形式 ([kohya-ss sd-scripts metadata](https://github.com/kohya-ss/sd-scripts))のキャプションを編集できます
34 | - 画像を見ながらキャプションの編集ができます
35 | - タグの検索ができます
36 | - 複数タグで絞り込んでキャプションの編集ができます
37 | - 絞り込み方法として、AND/OR/NOT検索ができます
38 | - タグを一括で置換・削除・追加できます
39 | - タグを一括で並べ替えできます
40 | - タグまたはキャプション全体について一括置換ができます
41 | - [正規表現](https://docs.python.org/ja/3/library/re.html#regular-expression-syntax) が利用可能です
42 | - Interrogatorを使用してタグの追加や編集ができます
43 | - BLIP、DeepDanbooru、[Z3D-E621-Convnext](https://huggingface.co/toynya/Z3D-E621-Convnext)、 [WDv1.4 Tagger](https://huggingface.co/SmilingWolf)の各ネットワークによる学習結果(v1, v2, v3)が使用可能です
44 | - お好みのTaggerを `userscripts/taggers` に追加できます (`scripts.tagger.Tagger`を継承したクラスでラップしてください)
45 | - 当該フォルダにAesthetic Scoreに基づいたTaggerをいくつか実装しています
46 | - 画像やキャプションファイルの一括移動・削除ができます
47 |
48 |
49 | ## 使い方
50 | 1. web UI でデータセットを作成する
51 | - 既にリサイズ・トリミングされた画像を使用することをお勧めします
52 | 1. データセットを読み込む
53 | - 必要に応じてDeepDanbooru等でタグ付けができます
54 | 1. キャプションを編集する
55 | - "Filter by Tags" タブでキャプションの編集をしたい画像を絞り込む
56 | - 画像を手動で選んで絞り込む場合は "Filter by Selection" タブを使用する
57 | - 一括でタグを置換・削除・追加する場合は "Batch Edit Captions" タブを使用する
58 | - キャプションを個別に編集したい場合は "Edit Caption of Selected Image" タブを使用する
59 | - DeepDanbooru等も利用可能
60 | - 選択したものをデータセットから一括移動・削除したい場合は "Move or Delete Files" タブを使用する
61 | 1. "Save all changes" ボタンをクリックして保存する
62 |
63 |
64 | ## タグ編集の手引き
65 |
66 | 基本的な手順は以下の通りです
67 |
68 | 1. 編集対象の画像をフィルターで絞り込む
69 | 1. まとめて編集する
70 |
71 | 一括で行われる編集はすべて**表示されている画像(=絞り込まれた画像)にのみ**適用されます。
72 |
73 | ### 1. フィルターの選び方
74 | - **全ての画像を一括で処理したい場合**
75 | フィルターは不要です。
76 | - **何枚かを処理したい場合**
77 | 1. **共通のタグや、共通して持たないタグがある**
78 | "Filter by Tags" タブで画像を絞り込み、編集対象だけが表示されるようにする。
79 | 1. **何も共通点が無い**
80 | "Filter by Selection" タブで画像を絞り込む。
81 | フィルターへの画像の追加は[Enter]キーがショートカットです。
82 |
83 | ### 2. 編集の仕方
84 | - **新しいタグを追加したい場合**
85 | 1. "Batch Edit Captions" タブを開く
86 | 1. "Edit tags" 内に追加したいタグをカンマ区切りで追記する
87 | 1. "Apply changes to filtered images" ボタンを押す
88 | 
89 | 例:"foo" と "bar" が表示されている画像に追加されます
90 |
91 | - **絞り込まれた画像に共通なタグを編集(置換)したい場合**
92 | 1. "Batch Edit Captions" タブを開く
93 | 1. "Edit tags" 内に表示されたタグを書き換える
94 | 1. "Apply changes to filtered images" ボタンを押す
95 | 
96 | 例:"male focus" と "solo" がそれぞれ "foo" と "bar" に置換されます
97 |
98 | - **タグを取り除きたい場合**
99 | 置換と同様の手順で、対象のタグを空欄に書き換えることで取り除けます。
100 | 共通のタグでない(一部の画像にのみ含まれる等)場合は、"Batch Edit Captions" タブにある "Remove" を利用することもできます。
101 |
102 | - **柔軟にタグを追加・削除・置換したい場合**
103 | 1. "Batch Edit Captions" タブを開く
104 | 2. "Use regex" にチェックを入れて "Search and Replace" する
105 | 
106 | 例:"1boy", "2boys", … がそれぞれ、 "1girl", "2girls", … に置換されます。
107 | カンマはタグの区切りとみなされるため、カンマを追加・削除することで新しいタグを追加・削除できます。
108 | 正規表現(regex)を使うと、複雑な条件に応じてタグの編集が可能です。
109 |
110 |
111 | ## トラブルシューティング
112 | ### ギャラリーに画像が表示されず、コンソールに "All files must contained within the Gradio python app working directory…" と出ている
113 | "Settings" タブで、サムネイル画像を一時保存するフォルダを指定してください。
114 | "Directory to save temporary files" にパスを指定して "Force using temporary file…" をチェックしてください。
115 |
116 | ### 大量の画像や巨大な画像を開いたときに動作が遅くなる
117 | "Settings" タブで、"Force image gallery to use temporary files" にチェックを入れて、 "Maximum resolution of ..." に希望の解像度を入れてください。
118 | 数百万もの画像を含むなど、あまりにも巨大なデータセットでは効果がないかもしれません。
119 | もしくは、[**スタンドアロン版**](https://github.com/toshiaki1729/dataset-tag-editor-standalone)を試してください。
120 | 
121 |
122 |
123 | ## 表示内容
124 |
125 | [こちら](DESCRIPTION_OF_DISPLAY-JP.md)に移動しました
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | from typing import List, Tuple
4 | from modules import shared, devices
5 | import launch
6 |
7 | from scripts.paths import paths
8 |
9 |
class WaifuDiffusionTagger:
    """Tagger that runs an ONNX model downloaded from a Hugging Face repository.

    `load()` downloads the model and its label list, `apply()` scores an
    image against every label, and `unload()` releases the model.
    """

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def __init__(
        self,
        model_name,
        model_filename="model.onnx",
        label_filename="selected_tags.csv",
    ):
        """Remember where the model lives; nothing is downloaded here.

        Args:
            model_name: Hugging Face repository id holding the files.
            model_filename: name of the ONNX model file in that repository.
            label_filename: name of the CSV label file in that repository.
        """
        self.MODEL_FILENAME = model_filename
        self.LABEL_FILENAME = label_filename
        self.MODEL_REPO = model_name
        self.model = None
        self.labels = []

    @staticmethod
    def _select_providers():
        """Return the ONNX execution providers to request, in order of preference."""
        use_cpu = shared.cmd_opts.use_cpu
        if "all" in use_cpu or "interrogate" in use_cpu:
            return ["CPUExecutionProvider"]
        return [
            "CUDAExecutionProvider",
            "DmlExecutionProvider",
            "CPUExecutionProvider",
        ]

    @staticmethod
    def _check_available_device():
        """Return "cuda", "directml" or "cpu" depending on what torch can use."""
        import torch

        if torch.cuda.is_available():
            return "cuda"
        if launch.is_installed("torch-directml"):
            # This code cannot detect DirectML available device without pytorch-directml
            try:
                import torch_directml

                torch_directml.device()
            except Exception:
                # Best-effort probe: any failure just means DirectML is not
                # usable, so fall through to the CPU answer. (A bare
                # `except:` would also swallow KeyboardInterrupt/SystemExit.)
                pass
            else:
                return "directml"
        return "cpu"

    def _ensure_onnxruntime(self):
        """Install the onnxruntime flavour matching the device if none is installed."""
        if launch.is_installed("onnxruntime"):
            return

        dev = self._check_available_device()
        if dev == "cuda":
            launch.run_pip(
                "install -U onnxruntime-gpu",
                "requirements for dataset-tag-editor [onnxruntime-gpu]",
            )
        elif dev == "directml":
            launch.run_pip(
                "install -U onnxruntime-directml",
                "requirements for dataset-tag-editor [onnxruntime-directml]",
            )
        else:
            print(
                "Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow."
            )
            launch.run_pip(
                "install -U onnxruntime",
                "requirements for dataset-tag-editor [onnxruntime for CPU]",
            )

    def load(self):
        """Download the model (unless one is already loaded) and read the label list."""
        import huggingface_hub

        if not self.model:
            path_model = huggingface_hub.hf_hub_download(
                self.MODEL_REPO, self.MODEL_FILENAME
            )
            providers = self._select_providers()
            self._ensure_onnxruntime()
            # Imported only now because it may have been installed just above.
            import onnxruntime as ort

            self.model = ort.InferenceSession(path_model, providers=providers)

        path_label = huggingface_hub.hf_hub_download(
            self.MODEL_REPO, self.LABEL_FILENAME
        )
        import pandas as pd

        self.labels = pd.read_csv(path_label)["name"].tolist()

    def unload(self):
        """Drop the model unless the web UI is set to keep interrogators in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            devices.torch_gc()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def apply(self, image: "Image.Image") -> List[Tuple[str, float]]:
        """Score `image` against every label.

        Returns:
            A list of (label, probability) pairs; an empty list when no model
            is loaded, so both paths return the same type.
        """
        if not self.model:
            return []

        from modules import images

        model_input = self.model.get_inputs()[0]
        _, height, width, _ = model_input.shape

        # the way to fill empty pixels is quite different from original one;
        # original: fill by white pixels
        # this: repeat the pixels on the edge
        image = images.resize_image(2, image.convert("RGB"), width, height)
        image_np = np.array(image, dtype=np.float32)
        # PIL RGB to OpenCV BGR
        image_np = image_np[:, :, ::-1]
        # add the leading batch axis
        image_np = np.expand_dims(image_np, 0)

        input_name = model_input.name
        label_name = self.model.get_outputs()[0].name
        probs = self.model.run([label_name], {input_name: image_np})[0]
        labels: List[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))

        return labels
120 |
--------------------------------------------------------------------------------
/DESCRIPTION_OF_DISPLAY.md:
--------------------------------------------------------------------------------
1 | # Description of Display
2 |
3 | ## Common
4 | 
5 | - "Save all changes" button
6 | - save captions to text file
7 | - changes will not be applied to the text files until you press this button
8 | - if "Backup original text file" is checked, original text files will be renamed not to be overwritten
9 | - backup file name will be like filename.000, -.001, -.002, ...
10 | - new caption text file will be created if it does not exist
11 | - "Reload/Save Settings" Accordion (closed initially)
12 | - you can reload/save/restore all settings in the UI here
13 | - settings will be saved in `.../tag-editor-root-dir/config.json`
14 | - "Dataset Directory" text box
15 | - input the directory of training images and load them by clicking "Load" button
16 | - loading options are below
17 | - you can make caption on loading by using interrogator if needed
18 | - "Dataset Images" gallery
19 | - to view and select images
20 | - the number of columns can be changed in web UI's "Settings" tab
21 |
22 | ***
23 |
24 | ## "Filter by Tags" tab
25 | 
26 | ### Common
27 | - "Clear tag filters" button
28 | - clear tag search text and tag selection
29 | - "Clear ALL filters" button
30 | - clear all filters including image selection filter in the next tab
31 |
32 | ### Search tags / Filter images by tags
33 | Positive (inclusive) / Negative (exclusive) filters can be used by toggling tabs.
34 | - "Search Tags" text box
35 | - search and filter the tags displayed below
36 | - "Sort by / Sort order" radio buttons
37 | - change sort order of the tags displayed below
38 | - "Filter Images by Tags" checkboxes
39 | - filter images displayed in the left gallery by tags
40 | - also filter tags depending on captions of the displayed images
41 |
42 | ***
43 |
44 | ## "Filter by Selection" tab
45 | 
46 |
47 | - "Add selection" button
48 | - to include selected dataset image in selection
49 | - "Enter" is shortcut key
50 | - Tips: you can change the selected image in gallery using arrow keys
51 | - "Remove selection" button
52 | - to remove selected image from selection
53 | - "Delete" is shortcut key
54 | - "Invert selection" button
55 | - select all images in the entire dataset that have not been selected
56 | - "Clear selection" button
57 | - to remove all current selection, not to clear current filter
58 | - "Apply selection filter" button
59 | - apply selection filter on displaying dataset images
60 |
61 | ***
62 |
63 | ## "Batch Edit Captions" tab
64 | 
65 | ### "Search and Replace" tab
66 |
67 | - "Edit common tags" is a simple way to edit tags.
68 | - "Common Tags" text box (not editable)
69 | - shows the common tags among the displayed images in comma separated style
70 | - "Edit Tags" text box
71 | - you can edit the selected tags for all captions of the displayed images
72 | - each tags will be replaced by the tags in "same place"
73 | - erase tags by changing it into blank
74 | - you can add some tags to the captions by appending new tags
75 | - the tags will be added to the beginning/end of text files depending on the checkbox below
76 | - "Apply changes to filtered images" button
77 | - apply the tag changes only to displayed images
78 |
79 | - "Search and Replace" is a little complicated but powerful way to edit tags.
80 | - Regular expression can be used here.
81 | - "Search/Replace Text" textboxes
82 | - "Search Text" will be replaced by "Replace Text"
83 | - "Search and Replace in" radio buttons
84 | - to select the replacing method
85 | - "Only Selected Tags" : replace separately within each of the selected tags only
86 | - "Each Tags" : replace separately within each tag
87 | - "Entire Caption" : do replace in entire caption at once
88 | - "Search and Replace" button to apply
89 |
90 | ### "Remove" tab
91 | Simple way to batch remove tags
92 | - "Remove duplicate tags" button
93 | - make each tags in each captions appear only once
94 | - "Remove selected tags" button
95 | - remove tags selected below
96 |
97 | ***
98 |
99 | ## "Edit Caption of Selected Image" tab
100 | 
101 |
102 | ### "Read Caption from Selected Image" tab
103 | - "Caption of Selected Image" textbox
104 | - shows the caption of the selected image in the dataset gallery
105 |
106 | ### "Interrogate Selected Image" tab
107 | - "Interrogate Result" textbox
108 | - shows the result of interrogator
109 |
110 | ### Common
111 | - "Copy and Overwrite / Prepend / Append" button
112 | - copy/prepend/append the content in the textbox above to the textbox below
113 | - "Edit Caption" textbox
114 | - edit caption here
115 | - "Apply changes to selected image" button
116 | - change the caption of selected image into the text in "Edit Tags" textbox
117 |
118 | ***
119 |
120 | ## "Move or Delete Files" tab
121 | 
122 | - "Move or Delete" radio buttons to select target image
123 | - "Target" checkboxes to select which files to be moved or deleted
124 | - "Move File(s)" button
125 | - move files to "Destination Directory"
126 | - "DELETE File(s)" button
127 | - delete files
128 | - Note: This won't move the files into $Recycle.Bin, just do DELETE them completely.
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/block_tag_select.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable
2 | import gradio as gr
3 |
4 | from .ui_common import *
5 |
6 | TagFilter = dte_module.filters.TagFilter
7 | Filter = dte_module.filters.Filter
8 |
9 | SortBy = dte_instance.SortBy
10 | SortOrder = dte_instance.SortOrder
11 |
12 |
class TagSelectUI():
    """Searchable, sortable checkbox list for selecting tags.

    The instance keeps the search word, the sort settings, the
    prefix/suffix/regex switches and the set of selected tags; every event
    handler updates that state and returns an update for the checkbox group.
    """

    def __init__(self):
        self.filter_word = ''
        self.sort_by = SortBy.ALPHA
        self.sort_order = SortOrder.ASC
        self.selected_tags = set()
        self.tags = set()
        self.get_filters = lambda:[]
        self.prefix = False
        self.suffix = False
        self.regex = False


    def create_ui(self, get_filters: Callable[[], List[Filter]], sort_by = SortBy.ALPHA, sort_order = SortOrder.ASC, prefix=False, suffix=False, regex=False):
        """Create the widgets and initialise the state from the arguments.

        Args:
            get_filters: callable returning the filters currently applied.
            sort_by, sort_order: initial values of the sort radio buttons.
            prefix, suffix, regex: initial values of the search switches.
        """
        self.get_filters = get_filters
        # Keep the state in sync with what the widgets show initially;
        # otherwise the first refresh would use the defaults from __init__
        # instead of the settings passed in here.
        self.sort_by = sort_by
        self.sort_order = sort_order
        self.prefix = prefix
        self.suffix = suffix
        self.regex = regex

        self.tb_search_tags = gr.Textbox(label='Search Tags', interactive=True)
        with gr.Row():
            # The checkboxes start from the same values as the stored state.
            self.cb_prefix = gr.Checkbox(label='Prefix', value=self.prefix, interactive=True)
            self.cb_suffix = gr.Checkbox(label='Suffix', value=self.suffix, interactive=True)
            self.cb_regex = gr.Checkbox(label='Use regex', value=self.regex, interactive=True)
        with gr.Row():
            self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=sort_by, interactive=True, label='Sort by')
            self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=sort_order, interactive=True, label='Sort Order')
        with gr.Row():
            self.btn_select_visibles = gr.Button(value='Select visible tags')
            self.btn_deselect_visibles = gr.Button(value='Deselect visible tags')
        self.cbg_tags = gr.CheckboxGroup(label='Select Tags', interactive=True)


    def set_callbacks(self):
        """Connect every widget to its handler; all handlers refresh the checkbox group."""
        self.tb_search_tags.change(fn=self.tb_search_tags_changed, inputs=self.tb_search_tags, outputs=self.cbg_tags)
        self.cb_prefix.change(fn=self.cb_prefix_changed, inputs=self.cb_prefix, outputs=self.cbg_tags)
        self.cb_suffix.change(fn=self.cb_suffix_changed, inputs=self.cb_suffix, outputs=self.cbg_tags)
        self.cb_regex.change(fn=self.cb_regex_changed, inputs=self.cb_regex, outputs=self.cbg_tags)
        self.rb_sort_by.change(fn=self.rd_sort_by_changed, inputs=self.rb_sort_by, outputs=self.cbg_tags)
        self.rb_sort_order.change(fn=self.rd_sort_order_changed, inputs=self.rb_sort_order, outputs=self.cbg_tags)
        self.btn_select_visibles.click(fn=self.btn_select_visibles_clicked, outputs=self.cbg_tags)
        self.btn_deselect_visibles.click(fn=self.btn_deselect_visibles_clicked, inputs=self.cbg_tags, outputs=self.cbg_tags)
        self.cbg_tags.change(fn=self.cbg_tags_changed, inputs=self.cbg_tags, outputs=self.cbg_tags)


    def _visible_tags(self):
        """Return, as a set, the tags matching the current search word and switches.

        Uses the same query as cbg_tags_update, so "visible" means exactly
        the tags offered as choices in the checkbox group.
        """
        return set(dte_instance.get_filtered_tags(self.get_filters(), self.filter_word, True, prefix=self.prefix, suffix=self.suffix, regex=self.regex))


    def tb_search_tags_changed(self, tb_search_tags: str):
        """Store the new search word and refresh the tag list."""
        self.filter_word = tb_search_tags
        return self.cbg_tags_update()


    def cb_prefix_changed(self, prefix:bool):
        """Store the 'Prefix' switch and refresh the tag list."""
        self.prefix = prefix
        return self.cbg_tags_update()


    def cb_suffix_changed(self, suffix:bool):
        """Store the 'Suffix' switch and refresh the tag list."""
        self.suffix = suffix
        return self.cbg_tags_update()


    def cb_regex_changed(self, regex:bool):
        """Store the 'Use regex' switch and refresh the tag list."""
        self.regex = regex
        return self.cbg_tags_update()


    def rd_sort_by_changed(self, rb_sort_by: str):
        """Store the sort key and refresh the tag list."""
        self.sort_by = rb_sort_by
        return self.cbg_tags_update()


    def rd_sort_order_changed(self, rd_sort_order: str):
        """Store the sort order and refresh the tag list."""
        self.sort_order = rd_sort_order
        return self.cbg_tags_update()


    def cbg_tags_changed(self,
        cbg_tags#: List[str]
        ):
        """Replace the selection by the checked entries and refresh the tag list."""
        self.selected_tags = set(dte_instance.read_tags(cbg_tags))
        return self.cbg_tags_update()


    def btn_deselect_visibles_clicked(self,
        cbg_tags#: List[str]
        ):
        """Remove the checked tags that are currently visible from the selection."""
        # _visible_tags() returns a set, so the intersection below works
        # regardless of the container type get_filtered_tags returns.
        checked_and_visible = set(dte_instance.read_tags(cbg_tags)) & self._visible_tags()
        self.selected_tags -= checked_and_visible
        return self.cbg_tags_update()


    def btn_select_visibles_clicked(self):
        """Add every currently visible tag to the selection."""
        self.selected_tags |= self._visible_tags()
        return self.cbg_tags_update()


    def cbg_tags_update(self):
        """Recompute choices and selection and return the checkbox group update."""
        # Choices: tags matching the search word under the current filters.
        tags = dte_instance.get_filtered_tags(self.get_filters(), self.filter_word, True, prefix=self.prefix, suffix=self.suffix, regex=self.regex)
        # Same query without a search word; used to prune the selection.
        self.tags = set(dte_instance.get_filtered_tags(self.get_filters(), filter_tags=True, prefix=self.prefix, suffix=self.suffix, regex=self.regex))
        self.selected_tags &= self.tags
        tags = dte_instance.sort_tags(tags=tags, sort_by=self.sort_by, sort_order=self.sort_order)
        # NOTE(review): write_tags presumably converts tags into the display
        # strings read_tags parses back - confirm in dte_instance.
        tags = dte_instance.write_tags(tags, self.sort_by)
        selected_tags = dte_instance.write_tags(list(self.selected_tags), self.sort_by)
        return gr.CheckboxGroup.update(value=selected_tags, choices=tags)
--------------------------------------------------------------------------------
/DESCRIPTION_OF_DISPLAY-JP.md:
--------------------------------------------------------------------------------
1 |
2 | # 表示内容
3 |
4 | ## 共通
5 |
6 | 
7 |
8 | - "Save all changes" ボタン
9 | - キャプションをテキストファイルに保存します。このボタンを押すまで全ての変更は適用されません。
10 | - "Backup original text file" にチェックを入れることで、保存時にオリジナルのテキストファイル名をバックアップします。
11 | - バックアップファイル名は、filename.000、 -.001、 -.002、…、のように付けられます。
12 | - キャプションを含むテキストファイルが無い場合は新しく作成されます。
13 | - "Reload/Save Settings" アコーディオン
14 | - UIで表示されている設定を全て再読み込み・保存したり、デフォルトに戻せます。
15 | - "Reload settings" : 設定を再読み込みします。
16 | - "Save current settings" : 現在の設定を保存します。
17 | - "Restore settings to default" : 設定をデフォルトに戻します(保存はしません)。
18 | - 設定は `.../tag-editor-root-dir/config.json` に保存されています。
19 | - "Dataset Directory" テキストボックス
20 | - 学習データセットのあるディレクトリを入力してください。
21 | - 下のオプションからロード方法を変更できます。
22 | - "Load from subdirectories" をチェックすると、全てのサブディレクトリを含めて読み込みます。
23 | - "Load caption from filename if no text file exists" をチェックすると、画像と同名のテキストファイルが無い場合に画像ファイル名からキャプションを読み込みます。
24 | - "Use Interrogator Caption" ラジオボタン
25 | - 読み込み時にBLIPやDeepDanbooruを使用するか、またその結果をキャプションにどう反映させるかを選びます。
26 | - "No": BLIPやDeepDanbooruを使用しません。
27 | - "If Empty": キャプションが無い場合のみ使用します。
28 | - "Overwrite" / "Prepend" / "Append": 生成したキャプションで上書き/先頭に追加/末尾に追加します。
29 | - "Dataset Images" ギャラリー
30 | - 教師画像の確認と選択ができます。
31 | - 表示する列数はwebUIの "Settings" タブから変更できます。
32 |
33 | ***
34 |
35 | ## "Filter by Tags" タブ
36 | 
37 | ## 共通
38 | - "Clear tag filters" ボタン
39 | - タグの検索やタグによる画像の絞り込みを取り消します。
40 | - "Clear ALL filters" ボタン
41 | - "Filter by Selection" タブでの画像選択による絞り込みを含めて、全ての絞り込みを取り消します。
42 |
43 | ## Search tags / Filter images by tags
44 | "Positive Filter" : 指定した条件を**満たす**画像を表示します。
45 | "Negative Filter" : 指定した条件を**満たさない**画像を表示します。
46 | 両フィルタは同時に指定可能です
47 |
48 | - "Search Tags" テキストボックス
49 | - 入力した文字で下に表示されているタグを検索し絞り込みます。
50 | - "Sort by / Sort order" ラジオボタン
51 | - 下に表示されているタグの並び順を切り替えます。
52 | - Alphabetical Order / Frequency : アルファベット順/出現頻度順
53 | - Ascending / Descending : 昇順/降順
54 | - "Filter Logic" ラジオボタン
55 | - 絞り込みの方法を指定します。
56 | - "AND" : 選択したタグを全て含む画像
57 | - "OR" : 選択したタグのいずれかを含む画像
58 | - "NONE" : フィルタを無効にする
59 | - "Filter Images by Tags" チェックボックス
60 | - 選択したタグによって左の画像を絞り込みます。絞り込まれた画像のキャプションの内容に応じて、タグも絞り込まれます。
61 |
62 | ***
63 |
64 | ## "Filter by Selection" タブ
65 | 
66 |
67 | - "Add selection" ボタン
68 | - 左で選択した画像を選択対象に追加します。
69 | - ショートカットは "Enter" キーです。
70 | - Tips: ギャラリーで選択している画像は矢印キーでも変更できます。
71 | - "Remove selection" ボタン
72 | - "Filter Images" で選択している画像を選択対象から外します。
73 | - ショートカットは "Delete" キーです。
74 | - "Invert selection" ボタン
75 | - 現在の選択対象を反転し、全データセットのうち選択されていないものに変更します。
76 | - "Clear selection" ボタン
77 | - 全ての選択を解除します。既にある絞り込みは解除しません。
78 | - "Apply selection filter" ボタン
79 | - 選択対象によって左の画像を絞り込みます。
80 |
81 |
82 | ***
83 |
84 | ## "Batch Edit Captions" タブ
85 | 
86 | ## "Search and Replace" タブ
87 | 複数のタグを一括置換できます
88 | - "Edit common tags" は、表示されている画像のタグを編集するシンプルな方法です。
89 | - "Common Tags" テキストボックス (編集不可)
90 | - 表示されている画像に共通するタグをカンマ区切りで表示します。
91 | - 上の "Show only … Positive Filter" をチェックすることで、"Filter by Tags" の "Positive Filter" で選択したもののみ表示することができます。
92 | - "Edit Tags" テキストボックス
93 | - 共通のタグを編集します。編集内容は絞り込まれている画像にのみ適用されます。表示されていないタグには影響しません。
94 | - 編集すると、カンマ区切りで同じ場所にあるタグを置換できます。
95 | - タグを空白に変えることで削除できます。
96 | - 末尾にタグを追加することでキャプションに新たなタグを追加できます。
97 | - タグが追加される位置はキャプションの先頭と末尾を選べます。
98 | - "Prepend additional tags" をチェックすると先頭、チェックを外すと末尾に追加します。
99 | - "Apply changes to filtered images" ボタン
100 | - 絞り込まれている画像に、タグの変更を適用します。
101 | - "Search and Replace" では、表示されている画像のタグまたはキャプション全体に対して一括置換ができます。
102 | - "Use regex" にチェックを入れることで、正規表現が利用可能です。
103 | - "Search Text" テキストボックス
104 | - 置換対象の文字列を入力します。
105 | - "Replace Text" テキストボックス
106 | - この文字列で "Search Text" を置換します。
107 | - "Search and Replace in" ラジオボタン
108 | - 一括置換の範囲を選択します
109 | - "Only Selected Tags" : "Positive Filter" で選択したタグのみ、それぞれのタグを個別に置換
110 | - "Each Tags" : それぞれのタグを個別に置換
111 | - "Entire Caption" : キャプション全体を一度に置換
112 | - "Search and Replace" ボタン
113 | - 一括置換を実行します
114 |
115 | ## "Remove" タブ
116 | 複数のタグを簡単に一括削除できます
117 | - "Remove duplicate tags" ボタン
118 | - キャプションの中に複数存在するタグを1つにする
119 | - "Remove selected tags" ボタン
120 | - 下で選択したタグを削除する
121 | - "Search Tags" からタグの検索と絞り込みが可能
122 | - "Select visible tags" で表示されているタグ全てを選択、"Deselect visible tags" で選択解除する。
123 |
124 | ***
125 |
126 | ## "Edit Caption of Selected Image" タブ
127 | 
128 |
129 | ## "Read Caption from Selected Image" タブ
130 | - "Caption of Selected Image" テキストボックス
131 | - 左で選択した画像のキャプションを表示します。
132 |
133 | ## "Interrogate Selected Image" タブ
134 | - "Interrogate Result" テキストボックス
135 | - 左で選択した画像にBLIPやDeepDanbooruを使用した結果を表示します。
136 |
137 | ## 共通
138 | - "Copy and Overwrite / Prepend / Append" ボタン
139 | - 上のテキストボックスの内容を、下のテキストボックスに、コピーして上書き/先頭に追加/末尾に追加します。
140 | - "Edit Caption" テキストボックス
141 | - ここでキャプションの編集が可能です。
142 | - "Apply changes to selected image" ボタン
143 | - 選択している画像のキャプションを "Edit Tags" の内容に変更します。
144 |
145 | ## "Move or Delete Files" タブ
146 | 
147 |
148 | - "Move or Delete" ラジオボタン
149 | - 操作を実行する対象を選びます。
150 | - "Selected One" : 左のギャラリーで選択されている画像のみ
151 | - "All Displayed Ones" : 左のギャラリーで表示されている画像全て
152 | - "Target" チェックボックス
153 | - 操作を実行する対象を選びます。
154 | - "Image File" : 画像ファイル
155 | - "Caption Text File" : キャプションファイル
156 | - "Caption Backup File" : キャプションファイルのバックアップ
157 | - "Move File(s)" ボタン
158 | - "Destination Directory" で指定したディレクトリにファイルを移動します。
159 | - "DELETE File(s)" ボタン
160 | - ファイルを削除します。
161 | - 注意 : ごみ箱には送られず、完全に削除されます。
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Tag Editor
2 | [**Stand alone version is here**](https://github.com/toshiaki1729/dataset-tag-editor-standalone): This may be better to avoid some known bugs.
3 |
4 | **Due to the gradio update in webUI, the latest version doesn't support old versions of webUI.**
5 | Please see [Releases](https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor/releases) page and check the compatibility with webUI you are using.
6 |
7 | [日本語 Readme](README-JP.md)
8 |
9 | This is an extension to edit captions in training dataset for [Stable Diffusion web UI by AUTOMATIC1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
10 |
11 | 
12 |
13 | It works well with text captions in comma-separated style (such as the tags generated by DeepBooru interrogator).
14 |
15 | Caption in the filenames of images can be loaded, but edited captions can only be saved in the form of text files.
16 |
17 | ## Installation
18 | ### Extensions tab on WebUI
19 | Copy `https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor.git` into "Install from URL" tab.
20 |
21 | Also, if you see this extension listed, you can install from "Available" tab with a single click.
22 |
23 | **Please note that if you update this extension from "Extensions" tab, you will need to restart web UI to reload completely.**
24 |
25 | ### Install Manually
26 | To install, clone the repository into the `extensions` directory and restart the web UI.
27 |
28 | On the web UI directory, run the following command to install:
29 | ```commandline
30 | git clone https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor.git extensions/dataset-tag-editor
31 | ```
32 |
33 | ## Features
34 | Note: "tag" means each block of a caption separated by commas.
35 | - Edit and save captions in text file (webUI style) or json file ([kohya-ss sd-scripts metadata](https://github.com/kohya-ss/sd-scripts))
36 | - Edit captions while viewing related images
37 | - Search tags
38 | - Filter images to edit their caption by tags
39 | - AND/OR logic can be used in each of the Positive/Negative filters
40 | - Batch replace/remove/append tags
41 | - Batch sort tags
42 | - Batch search and replace
43 | - [regular expression](https://docs.python.org/3/library/re.html#regular-expression-syntax) can be used
44 | - Use interrogators
45 | - BLIP, BLIP2, GIT, DeepDanbooru, [Z3D-E621-Convnext](https://huggingface.co/toynya/Z3D-E621-Convnext), SmilingWolf's [WDv1.4 Tagger](https://huggingface.co/SmilingWolf) (v1, v2, v3 and some variants of them)
46 | - You can add Custom Tagger in `userscripts/taggers` (they have to be wrapped by a class derived from `scripts.tagger.Tagger`)
47 | - Some Aesthetic Score Predictors are implemented in there
48 | - Batch remove image and/or caption files
49 |
50 |
51 | ## Usage
52 | 1. Make dataset using web UI
53 | - better to use already cropped images
54 | 1. Load them
55 | - use interrogator if needed
56 | 1. Edit their captions
57 | - filter images you want to edit by tags in "Filter by Tags" tab
58 | - filter images manually in "Filter by Selection" tab
59 | - replace/remove tags or append new tags in "Batch Edit Captions" tab
60 | - edit captions individually in "Edit Caption of Selected Image" tab
61 | - you also can use interrogator here
62 | - move/delete files in "Move or Delete Files" tab if needed
63 | 1. Click "Save all changes" button
64 |
65 |
66 | ## By the way, how can I edit tags quickly?
67 |
68 | Basic workflow is as follows:
69 |
70 | 1. Filter images
71 | 1. Batch edit
72 |
73 | Please note that all batch editing will be applied **only to displayed images (=filtered images)**.
74 |
75 | ### 1. Which filter is appropriate?
76 | - **I want to edit all at once**
77 | No filter is required.
78 | - **Some images require editing**
79 | 1. **They should / shouldn't already have same tag(s)**
80 | Go to "Filter by Tags" so that the only images to be edited are displayed.
81 | 1. **They have nothing in common**
82 | Go to "Filter by Selection" and apply.
83 | Images can also be added to the filter by pushing [Enter] key.
84 |
85 | ### 2. How can I edit as I want?
86 | - **I want to add some new tags**
87 | 1. Go to "Batch Edit Captions" tab
88 | 1. Append tags to "Edit tags" textbox
89 | 1. Push "Apply changes to filtered images" button
90 | 
91 | "foo" and "bar" will be added to all images displayed.
92 |
93 | - **I want to replace the tags which are common to displayed images**
94 | 1. Go to "Batch Edit Captions" tab
95 | 1. Replace tags in "Edit tags" textbox
96 | 1. Push "Apply changes to filtered images" button
97 | 
98 | "male focus" and "solo" will be replaced with "foo" and "bar".
99 |
100 | - **I want to remove some tags**
101 | The same as replacing. Just replace the tags with "blank".
102 | Also you can use "Remove" tab in "Batch Edit Captions".
103 |
104 | - **I want to add/replace/remove tags more flexibly**
105 | 1. Go to "Batch Edit Captions" tab
106 | 2. Use "Search and Replace" with "Use regex" checked
107 | 
108 | "1boy", "2boys", … will be replaced with "1girl", "2girls", … in each tag of the displayed images.
109 | A comma is regarded as the separator between two tags.
110 | By using regex, you can add/replace/remove tags according to more complex conditions.
111 |
112 |
113 | ## Trouble shooting
114 | ### Cannot see any image in dataset and saying "All files must contained within the Gradio python app working directory…"
115 | Set a folder to store temporary images in the "Settings" tab.
116 | Input path in "Directory to save temporary files" and check "Force using temporary file…"
117 |
118 | ### So laggy when opening many images or extremely large image
119 | Check "Force image gallery to use temporary files" and input number in "Maximum resolution of ..." in the "Settings" tab.
120 | It may not work with dataset with millions of images.
121 | If it doesn't work, please consider using [**stand alone version**](https://github.com/toshiaki1729/dataset-tag-editor-standalone).
122 | 
123 |
124 |
125 | ## Description of Display
126 |
127 | Moved to [here](DESCRIPTION_OF_DISPLAY.md)
128 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/tab_move_or_delete_files.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, List, Callable
3 |
4 | import gradio as gr
5 |
6 | from .ui_common import *
7 | from .uibase import UIBase
8 |
9 | if TYPE_CHECKING:
10 | from .ui_classes import *
11 |
12 |
class MoveOrDeleteFilesUI(UIBase):
    """UI block that moves or deletes dataset files.

    The operation targets either the image currently selected in the dataset
    gallery ('Selected One') or every image passing the current filters
    ('All Displayed Ones'). The affected file kinds (image, caption text,
    caption backup) are chosen with a checkbox group.
    """

    def __init__(self):
        # Current radio choice: 'Selected One' or 'All Displayed Ones'.
        self.target_data = 'Selected One'
        # Last rendered "Target dataset num: N" text.
        self.current_target_txt = ''
        # Assigned in set_callbacks(); recomputes the target-count text.
        self.update_func = None

    def create_ui(self, cfg_file_move_delete):
        """Create the widgets, using `cfg_file_move_delete` for initial values.

        Reads `.range`, `.target`, `.caption_ext` and `.destination` from the
        given config object.
        """
        gr.HTML(value='Note: Moved or deleted images will be unloaded.')
        self.target_data = cfg_file_move_delete.range
        self.rb_move_or_delete_target_data = gr.Radio(choices=['Selected One', 'All Displayed Ones'], value=cfg_file_move_delete.range, label='Move or Delete')
        self.cbg_move_or_delete_target_file = gr.CheckboxGroup(choices=['Image File', 'Caption Text File', 'Caption Backup File'], label='Target', value=cfg_file_move_delete.target)
        self.tb_move_or_delete_caption_ext = gr.Textbox(label='Caption File Ext', placeholder='txt', value=cfg_file_move_delete.caption_ext)
        self.ta_move_or_delete_target_dataset_num = gr.HTML(value='Target dataset num: 0')
        self.tb_move_or_delete_destination_dir = gr.Textbox(label='Destination Directory', value=cfg_file_move_delete.destination)
        self.btn_move_or_delete_move_files = gr.Button(value='Move File(s)', variant='primary')
        gr.HTML(value='Note: DELETE cannot be undone. The files will be deleted completely.')
        self.btn_move_or_delete_delete_files = gr.Button(value='DELETE File(s)', variant='primary')

    def update_current_move_or_delete_target_num(self):
        """Return the target-count text, recomputing it once callbacks are set."""
        if self.update_func:
            return self.update_func(self.target_data)
        else:
            return self.current_target_txt

    def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
        """Wire the event handlers of this block.

        Args:
            o_update_filter_and_gallery: output components refreshed by
                `update_filter_and_gallery`.
            dataset_gallery: gallery block providing the current selection.
            get_filters: returns the filters currently applied.
            update_filter_and_gallery: returns the new values for
                `o_update_filter_and_gallery`.
        """
        def _get_current_move_or_delete_target_num(text: str):
            # Remember the choice so update_current_move_or_delete_target_num()
            # can recompute later without the radio value.
            self.target_data = text
            if self.target_data == 'Selected One':
                self.current_target_txt = f'Target dataset num: {1 if dataset_gallery.selected_index != -1 else 0}'
            elif self.target_data == 'All Displayed Ones':
                img_paths = dte_instance.get_filtered_imgpaths(filters=get_filters())
                self.current_target_txt = f'Target dataset num: {len(img_paths)}'
            else:
                self.current_target_txt = 'Target dataset num: 0'
            return self.current_target_txt

        self.update_func = _get_current_move_or_delete_target_num

        update_args = {
            'fn': self.update_func,
            'inputs': [self.rb_move_or_delete_target_data],
            'outputs' : [self.ta_move_or_delete_target_dataset_num]
        }

        self.rb_move_or_delete_target_data.change(**update_args)
        dataset_gallery.cbg_hidden_dataset_filter.change(lambda:None).then(**update_args)
        dataset_gallery.nb_hidden_image_index.change(lambda:None).then(**update_args)

        def move_files(
                target_data: str,
                target_file, #: List[str], : to avoid error on gradio v3.23.0
                caption_ext: str,
                dest_dir: str):
            move_img = 'Image File' in target_file
            move_txt = 'Caption Text File' in target_file
            move_bak = 'Caption Backup File' in target_file
            if target_data == 'Selected One':
                img_path = dataset_gallery.selected_path
                if img_path:
                    dte_instance.move_dataset_file(img_path, caption_ext, dest_dir, move_img, move_txt, move_bak)
                    dte_instance.construct_tag_infos()

            elif target_data == 'All Displayed Ones':
                dte_instance.move_dataset(dest_dir, caption_ext, get_filters(), move_img, move_txt, move_bak)

            return update_filter_and_gallery()

        self.btn_move_or_delete_move_files.click(
            fn=move_files,
            inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext, self.tb_move_or_delete_destination_dir],
            outputs=o_update_filter_and_gallery
        ).then(**update_args).then(
            fn=None,
            _js='() => dataset_tag_editor_gl_dataset_images_close()'
        )

        def delete_files(
                target_data: str,
                target_file, #: List[str], : to avoid error on gradio v3.23.0
                caption_ext: str):
            delete_img = 'Image File' in target_file
            delete_txt = 'Caption Text File' in target_file
            delete_bak = 'Caption Backup File' in target_file
            if target_data == 'Selected One':
                img_path = dataset_gallery.selected_path
                if img_path:
                    dte_instance.delete_dataset_file(img_path, delete_img, caption_ext, delete_txt, delete_bak)
                    dte_instance.construct_tag_infos()

            elif target_data == 'All Displayed Ones':
                dte_instance.delete_dataset(caption_ext, get_filters(), delete_img, delete_txt, delete_bak)

            return update_filter_and_gallery()

        # Chain the count refresh and gallery close after the deletion (same
        # shape as the move handler). Two independent click handlers could run
        # the refresh before the files were actually removed.
        self.btn_move_or_delete_delete_files.click(
            fn=delete_files,
            inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext],
            outputs=o_update_filter_and_gallery
        ).then(**update_args).then(
            fn=None,
            _js='() => dataset_tag_editor_gl_dataset_images_close()'
        )
116 |
--------------------------------------------------------------------------------
/javascript/99_main.js:
--------------------------------------------------------------------------------
// Gallery wrappers used throughout this file: one is later bound to the element
// "dataset_tag_editor_dataset_gallery", the other to "dataset_tag_editor_filter_gallery".
// DTEModifiedGallery is not defined in this file (presumably 00_modified_gallery.js — confirm).
let dteModifiedGallery_dataset = new DTEModifiedGallery()
let dteModifiedGallery_filter = new DTEModifiedGallery()
3 |
4 |
/** Return the selected index reported by the dataset gallery wrapper. */
function dataset_tag_editor_gl_dataset_images_selected_index() {
    const selectedIndex = dteModifiedGallery_dataset.getSelectedIndex()
    return selectedIndex
}
8 |
/** Return the selected index reported by the filter gallery wrapper. */
function dataset_tag_editor_gl_filter_images_selected_index() {
    const selectedIndex = dteModifiedGallery_filter.getSelectedIndex()
    return selectedIndex
}
12 |
/**
 * Apply `indices` as the filter of the dataset gallery wrapper and hand the
 * same value back to the caller.
 */
function dataset_tag_editor_gl_dataset_images_filter(indices) {
    const gallery = dteModifiedGallery_dataset
    gallery.filter(indices)
    return indices
}
17 |
/** Clear the filter of the dataset gallery wrapper and return an empty array. */
function dataset_tag_editor_gl_dataset_images_clear_filter() {
    const gallery = dteModifiedGallery_dataset
    gallery.clearFilter()
    const noIndices = []
    return noIndices
}
22 |
/** Ask the dataset gallery wrapper to perform its close click. */
function dataset_tag_editor_gl_dataset_images_close() {
    const gallery = dteModifiedGallery_dataset
    gallery.clickClose()
}
26 |
/** Ask the filter gallery wrapper to perform its close click. */
function dataset_tag_editor_gl_filter_images_close() {
    const gallery = dteModifiedGallery_filter
    gallery.clickClose()
}
30 |
// Click handler for the dataset gallery: refreshes the wrapper's filter, runs
// its click handling, then clicks the hidden "set index" button if it exists.
let dataset_tag_editor_gl_dataset_images_clicked = function () {
    const gallery = dteModifiedGallery_dataset
    gallery.updateFilter()
    gallery.clickHandler()
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
39 |
// "Next" click handler for the dataset gallery: refreshes the wrapper's filter,
// runs its next-click handling, then clicks the hidden "set index" button if it exists.
let dataset_tag_editor_gl_dataset_images_next_clicked = function () {
    const gallery = dteModifiedGallery_dataset
    gallery.updateFilter()
    gallery.clickNextHandler()
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
48 |
// "Close" click handler for the dataset gallery: refreshes the wrapper's filter,
// runs its close-click handling, then clicks the hidden "set index" button if it exists.
let dataset_tag_editor_gl_dataset_images_close_clicked = function () {
    const gallery = dteModifiedGallery_dataset
    gallery.updateFilter()
    gallery.clickCloseHandler()
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
57 |
// Key handler for the dataset gallery. After the wrapper's own key handling,
// [Enter] clicks the "add image selection" button (when present) and suppresses
// the default action; finally the hidden "set index" button is clicked if it exists.
let dataset_tag_editor_gl_dataset_images_key_handler = function (e) {
    dteModifiedGallery_dataset.keyHandler(e)
    if (e.key === 'Enter') {
        const addButton = gradioApp().getElementById('dataset_tag_editor_btn_add_image_selection');
        if (addButton) {
            addButton.click();
        }
        e.preventDefault();
    }
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
75 |
76 |
// Click handler for the filter gallery: refreshes the wrapper's filter, runs its
// click handling, then clicks the hidden "set selection index" button if it exists.
let dataset_tag_editor_gl_filter_images_clicked = function () {
    const gallery = dteModifiedGallery_filter
    gallery.updateFilter()
    gallery.clickHandler()
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_selection_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
85 |
// "Next" click handler for the filter gallery: refreshes the wrapper's filter, runs
// its next-click handling, then clicks the hidden "set selection index" button if it exists.
let dataset_tag_editor_gl_filter_images_next_clicked = function () {
    const gallery = dteModifiedGallery_filter
    gallery.updateFilter()
    gallery.clickNextHandler()
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_selection_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
94 |
// "Close" click handler for the filter gallery: refreshes the wrapper's filter, runs
// its close-click handling, then clicks the hidden "set selection index" button if it exists.
let dataset_tag_editor_gl_filter_images_close_clicked = function () {
    const gallery = dteModifiedGallery_filter
    gallery.updateFilter()
    gallery.clickCloseHandler()
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_selection_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
103 |
// Key handler for the filter gallery. After the wrapper's own key handling,
// [Delete] clicks the "remove image selection" button (when present) and suppresses
// the default action; finally the hidden "set selection index" button is clicked if it exists.
let dataset_tag_editor_gl_filter_images_key_handler = function (e) {
    dteModifiedGallery_filter.keyHandler(e)
    if (e.key === 'Delete') {
        const removeButton = gradioApp().getElementById('dataset_tag_editor_btn_remove_image_selection');
        if (removeButton) {
            removeButton.click();
        }
        e.preventDefault();
    }
    const setIndexButton = gradioApp().getElementById("dataset_tag_editor_btn_hidden_set_selection_index");
    if (!setIndexButton) {
        return
    }
    setIndexButton.click()
}
121 |
// Once the DOM is ready, watch the gradio app for mutations. On every mutation
// the two gallery wrappers are (re)bound to their elements and handlers, and the
// token counters are moved next to their prompt elements.
document.addEventListener("DOMContentLoaded", function () {
    /**
     * Insert the element `id_counter` right before the element `id` inside the
     * latter's parent. Does nothing when either element is missing or when both
     * already share the same parent.
     */
    function changeTokenCounterPos(id, id_counter){
        const promptElem = gradioApp().getElementById(id)
        const counter = gradioApp().getElementById(id_counter)

        // Either element may not be rendered (yet); dereferencing null here would
        // throw inside the observer callback on every mutation.
        if(!promptElem || !counter){
            return
        }

        if(counter.parentElement == promptElem.parentElement){
            return
        }

        promptElem.parentElement.insertBefore(counter, promptElem)
        promptElem.parentElement.style.position = "relative"
        counter.style.width = "auto"
    }

    let o = new MutationObserver(function (m) {
        let elem_gl_dataset = gradioApp().getElementById("dataset_tag_editor_dataset_gallery")
        let elem_gl_filter = gradioApp().getElementById("dataset_tag_editor_filter_gallery")
        if(elem_gl_dataset){
            dteModifiedGallery_dataset.setElement(elem_gl_dataset)
            dteModifiedGallery_dataset.addKeyHandler(dataset_tag_editor_gl_dataset_images_key_handler)
            dteModifiedGallery_dataset.addClickHandler(dataset_tag_editor_gl_dataset_images_clicked)
            dteModifiedGallery_dataset.addClickNextHandler(dataset_tag_editor_gl_dataset_images_next_clicked)
            dteModifiedGallery_dataset.addClickCloseHandler(dataset_tag_editor_gl_dataset_images_close_clicked)
        }
        if(elem_gl_filter){
            dteModifiedGallery_filter.setElement(elem_gl_filter)
            dteModifiedGallery_filter.addKeyHandler(dataset_tag_editor_gl_filter_images_key_handler)
            dteModifiedGallery_filter.addClickHandler(dataset_tag_editor_gl_filter_images_clicked)
            dteModifiedGallery_filter.addClickNextHandler(dataset_tag_editor_gl_filter_images_next_clicked)
            dteModifiedGallery_filter.addClickCloseHandler(dataset_tag_editor_gl_filter_images_close_clicked)
        }

        // Token counters are only repositioned once 'settings_json' exists.
        if(gradioApp().getElementById('settings_json') == null) return
        changeTokenCounterPos('dte_caption', 'dte_caption_counter')
        changeTokenCounterPos('dte_edit_caption', 'dte_edit_caption_counter')
        changeTokenCounterPos('dte_interrogate', 'dte_interrogate_counter')
    });

    o.observe(gradioApp(), { childList: true, subtree: true })
});
--------------------------------------------------------------------------------
/scripts/dataset_tag_editor/taggers_builtin.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 |
7 | from modules import devices, shared
8 | from modules import deepbooru as db
9 |
10 | from scripts.tagger import Tagger, get_replaced_tag
11 | from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
12 |
13 |
class BLIP(Tagger):
    """Tagger that captions images with `shared.interrogator`."""

    def start(self):
        """Load the shared interrogator."""
        shared.interrogator.load()

    def stop(self):
        """Unload the shared interrogator."""
        shared.interrogator.unload()

    def predict(self, image:Image.Image, threshold=None):
        """Return the non-empty comma-separated pieces of the generated caption.

        `threshold` is accepted for interface compatibility and is not used.
        """
        caption = shared.interrogator.generate_caption(image)
        return list(filter(None, caption.split(',')))

    def name(self):
        """Return the display name of this tagger."""
        return 'BLIP'
27 |
28 |
29 |
class BLIP2(Tagger):
    """Tagger that captions images with a BLIP2 model from the "Salesforce/" namespace."""

    def __init__(self, repo_name):
        """Create the captioner for "Salesforce/<repo_name>"."""
        self.interrogator = BLIP2Captioning("Salesforce/" + repo_name)
        self.repo_name = repo_name

    def start(self):
        """Load the captioning model."""
        self.interrogator.load()

    def stop(self):
        """Unload the captioning model."""
        self.interrogator.unload()

    # `Image` is the PIL module; the image class is `Image.Image`.
    def predict(self, image:Image.Image, threshold=None):
        """Return the non-empty comma-separated pieces of the first caption.

        `threshold` is accepted for interface compatibility and is not used.
        """
        tags = self.interrogator.apply(image)[0].split(",")
        return [t for t in tags if t]

    # def predict_multi(self, images:list):
    #     captions = self.interrogator.apply(images)
    #     return [[t for t in caption.split(',') if t] for caption in captions]

    def name(self):
        """Return the repository name given at construction."""
        return self.repo_name
51 |
52 |
class GITLarge(Tagger):
    """Tagger that captions images with `GITLargeCaptioning`."""

    def __init__(self):
        self.interrogator = GITLargeCaptioning()

    def start(self):
        """Load the captioning model."""
        self.interrogator.load()

    def stop(self):
        """Unload the captioning model."""
        self.interrogator.unload()

    # `Image` is the PIL module; the image class is `Image.Image`.
    def predict(self, image:Image.Image, threshold=None):
        """Return the non-empty comma-separated pieces of the first caption.

        `threshold` is accepted for interface compatibility and is not used.
        """
        tags = self.interrogator.apply(image)[0].split(",")
        return [t for t in tags if t]

    # def predict_multi(self, images:list):
    #     captions = self.interrogator.apply(images)
    #     return [[t for t in caption.split(',') if t] for caption in captions]

    def name(self):
        """Return the display name of this tagger."""
        return "GIT-large-COCO"
73 |
74 |
class DeepDanbooru(Tagger):
    """Tagger that runs the web UI's DeepDanbooru model (`modules.deepbooru`)."""

    def start(self):
        """Start the DeepDanbooru model."""
        db.model.start()

    def stop(self):
        """Stop the DeepDanbooru model."""
        db.model.stop()

    # brought from webUI modules/deepbooru.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the replaced tag names predicted for `image`.

        Tags whose probability is below a truthy `threshold` are dropped, and
        "rating:" tags are dropped unless `dataset_editor_use_rating` is set.
        """
        from modules import images

        resized = images.resize_image(2, image.convert("RGB"), 512, 512)
        batch = np.expand_dims(np.array(resized, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(batch).to(devices.device)
            probabilities = db.model.model(x)[0].detach().cpu().numpy()

        return [
            get_replaced_tag(tag)
            for tag, probability in zip(db.model.model.tags, probabilities)
            if not (threshold and probability < threshold)
            if shared.opts.dataset_editor_use_rating or not tag.startswith("rating:")
        ]

    def name(self):
        """Return the display name of this tagger."""
        return 'DeepDanbooru'
106 |
107 |
class WaifuDiffusion(Tagger):
    """Tagger wrapping a `WaifuDiffusionTagger` from the "SmilingWolf/" namespace."""

    def __init__(self, repo_name, threshold):
        """Store the default `threshold` and build the tagger for `repo_name`."""
        self.threshold = threshold
        self.repo_name = repo_name
        self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name)

    def start(self):
        """Load the model and return this tagger."""
        self.tagger_inst.load()
        return self

    def stop(self):
        """Unload the model."""
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    # set threshold<0 to use default value for now...
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the replaced tag names for `image`.

        With `threshold=None` every label is returned; a negative threshold
        falls back to the value given at construction. The first four labels
        are skipped unless `dataset_editor_use_rating` is set.
        """
        labels = self.tagger_inst.apply(image)

        # may not use ratings
        if not shared.opts.dataset_editor_use_rating:
            labels = labels[4:]

        if threshold is None:
            return [get_replaced_tag(tag) for tag, _ in labels]

        cutoff = self.threshold if threshold < 0 else threshold
        return [get_replaced_tag(tag) for tag, value in labels if value > cutoff]

    def name(self):
        """Return the repository name given at construction."""
        return self.repo_name
143 |
144 |
class WaifuDiffusionTimm(WaifuDiffusion):
    """`WaifuDiffusion` variant that uses `WaifuDiffusionTaggerTimm` and batched prediction."""

    def __init__(self, repo_name, threshold, batch_size=4):
        """Initialise the base class, then swap in the timm tagger."""
        super().__init__(repo_name, threshold)
        self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
        self.batch_size = batch_size

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
        """Yield one list of replaced tag names per label set produced by `apply_multi`.

        With `threshold=None` every label is yielded; a negative threshold
        falls back to the value given at construction. The first four labels
        are skipped unless `dataset_editor_use_rating` is set.
        """
        cutoff = threshold
        batches = self.tagger_inst.apply_multi(data, batch_size=self.batch_size)
        for batch in batches:
            for labels in batch:
                if not shared.opts.dataset_editor_use_rating:
                    labels = labels[4:]

                if cutoff is None:
                    yield [get_replaced_tag(tag) for tag, _ in labels]
                    continue

                if cutoff < 0:
                    cutoff = self.threshold
                yield [get_replaced_tag(tag) for tag, value in labels if value > cutoff]
165 |
166 |
class Z3D_E621(Tagger):
    """Tagger for "toynya/Z3D-E621-Convnext", using the labels in "tags-selected.csv"."""

    def __init__(self):
        self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")

    def start(self):
        """Load the model and return this tagger."""
        self.tagger_inst.load()
        return self

    def stop(self):
        """Unload the model."""
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the replaced tag names for `image`.

        With `threshold=None` every label is returned; otherwise only labels
        whose value is greater than `threshold`.
        """
        labels = self.tagger_inst.apply(image)
        if threshold is None:
            return [get_replaced_tag(tag) for tag, _ in labels]
        return [get_replaced_tag(tag) for tag, value in labels if value > threshold]

    def name(self):
        """Return the display name of this tagger."""
        return "Z3D-E621-Convnext"
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/block_tag_filter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable
2 | import gradio as gr
3 |
4 | from .ui_common import *
5 |
# Short module-level aliases for names used throughout this file.
# `dte_module` and `dte_instance` are not defined here; presumably they come from
# the star import of `.ui_common` — confirm.
filters = dte_module.filters
TagFilter = filters.TagFilter

SortBy = dte_instance.SortBy
SortOrder = dte_instance.SortOrder
11 |
12 |
13 |
class TagFilterUI():
    """UI block for searching, sorting and ticking tags that form a `TagFilter`."""

    def __init__(self, tag_filter_mode = TagFilter.Mode.INCLUSIVE):
        # Filter configuration.
        self.filter_mode = tag_filter_mode
        self.logic = TagFilter.Logic.AND
        self.selected_tags = set()
        self.filter = TagFilter(logic=self.logic, mode=self.filter_mode)
        # Tag search options.
        self.filter_word = ''
        self.prefix = False
        self.suffix = False
        self.regex = False
        # Ordering of the displayed tags.
        self.sort_by = SortBy.ALPHA
        self.sort_order = SortOrder.ASC
        # Callable returning the filters currently applied; replaced in create_ui().
        self.get_filters = lambda:[]
        # (fn, inputs, outputs, _js) tuples registered via on_filter_update().
        self.on_filter_update_callbacks = []

    def get_filter(self):
        """Return the current `TagFilter`."""
        return self.filter

    def create_ui(self, get_filters: Callable[[], List[filters.Filter]], logic = TagFilter.Logic.AND, sort_by = SortBy.ALPHA, sort_order = SortOrder.ASC, prefix=False, suffix=False, regex=False):
        """Store the initial settings and create the widgets of this block."""
        self.get_filters = get_filters
        self.logic = logic
        self.sort_by = sort_by
        self.sort_order = sort_order
        self.prefix = prefix
        self.suffix = suffix
        self.regex = regex
        self.filter = TagFilter(logic=self.logic, mode=self.filter_mode)

        self.tb_search_tags = gr.Textbox(label='Search Tags', interactive=True)
        with gr.Row():
            self.cb_prefix = gr.Checkbox(label='Prefix', value=self.prefix, interactive=True)
            self.cb_suffix = gr.Checkbox(label='Suffix', value=self.suffix, interactive=True)
            self.cb_regex = gr.Checkbox(label='Use regex', value=self.regex, interactive=True)
        with gr.Row():
            self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=sort_by, interactive=True, label='Sort by')
            self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=sort_order, interactive=True, label='Sort Order')

        # Map the logic value onto the label shown by the radio button.
        if self.logic == TagFilter.Logic.AND:
            logic_label = 'AND'
        elif self.logic == TagFilter.Logic.OR:
            logic_label = 'OR'
        else:
            logic_label = 'NONE'
        self.rb_logic = gr.Radio(choices=['AND', 'OR', 'NONE'], value=logic_label, label='Filter Logic', interactive=True)
        self.cbg_tags = gr.CheckboxGroup(label='Filter Images by Tags', interactive=True)


    def on_filter_update(self, fn:Callable[[List], List], inputs=None, outputs=None, _js=None):
        """Register a callback to be chained after the filter changes (see set_callbacks)."""
        self.on_filter_update_callbacks.append((fn, inputs, outputs, _js))


    def set_callbacks(self):
        """Wire the change events of all widgets created in create_ui()."""
        # Widgets that only affect which tags are listed.
        self.tb_search_tags.change(fn=self.tb_search_tags_changed, inputs=self.tb_search_tags, outputs=self.cbg_tags)
        self.cb_prefix.change(fn=self.cb_prefix_changed, inputs=self.cb_prefix, outputs=self.cbg_tags)
        self.cb_suffix.change(fn=self.cb_suffix_changed, inputs=self.cb_suffix, outputs=self.cbg_tags)
        self.cb_regex.change(fn=self.cb_regex_changed, inputs=self.cb_regex, outputs=self.cbg_tags)
        self.rb_sort_by.change(fn=self.rd_sort_by_changed, inputs=self.rb_sort_by, outputs=self.cbg_tags)
        self.rb_sort_order.change(fn=self.rd_sort_order_changed, inputs=self.rb_sort_order, outputs=self.cbg_tags)

        # Widgets that change the filter itself also trigger the registered callbacks.
        self.rb_logic.change(fn=self.rd_logic_changed, inputs=[self.rb_logic], outputs=[self.cbg_tags])
        for callback, cb_inputs, cb_outputs, cb_js in self.on_filter_update_callbacks:
            self.rb_logic.change(fn=lambda:None).then(fn=callback, inputs=cb_inputs, outputs=cb_outputs, _js=cb_js)
        self.cbg_tags.change(fn=self.cbg_tags_changed, inputs=[self.cbg_tags], outputs=[self.cbg_tags])
        for callback, cb_inputs, cb_outputs, cb_js in self.on_filter_update_callbacks:
            self.cbg_tags.change(fn=lambda:None).then(fn=callback, inputs=cb_inputs, outputs=cb_outputs, _js=cb_js)


    def tb_search_tags_changed(self, tb_search_tags: str):
        """Store the search word and refresh the tag list."""
        self.filter_word = tb_search_tags
        return self.cbg_tags_update()


    def cb_prefix_changed(self, prefix:bool):
        """Store the 'Prefix' option and refresh the tag list."""
        self.prefix = prefix
        return self.cbg_tags_update()


    def cb_suffix_changed(self, suffix:bool):
        """Store the 'Suffix' option and refresh the tag list."""
        self.suffix = suffix
        return self.cbg_tags_update()


    def cb_regex_changed(self, use_regex:bool):
        """Store the 'Use regex' option and refresh the tag list."""
        self.regex = use_regex
        return self.cbg_tags_update()


    def rd_sort_by_changed(self, rb_sort_by: str):
        """Store the sort key and refresh the tag list."""
        self.sort_by = rb_sort_by
        return self.cbg_tags_update()


    def rd_sort_order_changed(self, rd_sort_order: str):
        """Store the sort order and refresh the tag list."""
        self.sort_order = rd_sort_order
        return self.cbg_tags_update()


    def rd_logic_changed(self, rd_logic: str):
        """Translate the radio label into a `TagFilter.Logic`, rebuild the filter and refresh."""
        if rd_logic == 'AND':
            self.logic = TagFilter.Logic.AND
        elif rd_logic == 'OR':
            self.logic = TagFilter.Logic.OR
        else:
            self.logic = TagFilter.Logic.NONE
        self.filter = TagFilter(self.selected_tags, self.logic, self.filter_mode)
        return self.cbg_tags_update()


    def cbg_tags_changed(self,
                         cbg_tags#: List[str]
                         ):
        """Store the ticked tags (read back from their displayed form) and refresh."""
        ticked = set(dte_instance.read_tags(cbg_tags))
        self.selected_tags = dte_instance.cleanup_tagset(ticked)
        return self.cbg_tags_update()


    def cbg_tags_update(self):
        """Rebuild the filter and return the update for the tag checkbox group.

        Tags already in the filter are listed first, followed by the remaining
        candidate tags; both parts are sorted with the current sort settings.
        """
        self.selected_tags = dte_instance.cleanup_tagset(self.selected_tags)
        self.filter = TagFilter(self.selected_tags, self.logic, self.filter_mode)

        # The flag passed to get_filtered_tags compares against AND for an
        # inclusive filter and against OR otherwise.
        if self.filter_mode == TagFilter.Mode.INCLUSIVE:
            compared_logic = TagFilter.Logic.AND
        else:
            compared_logic = TagFilter.Logic.OR
        candidates = dte_instance.get_filtered_tags(self.get_filters(), self.filter_word, self.filter.logic == compared_logic, prefix=self.prefix, suffix=self.suffix, regex=self.regex)
        chosen = self.filter.tags

        candidates = dte_instance.sort_tags(tags=candidates, sort_by=self.sort_by, sort_order=self.sort_order)
        chosen = dte_instance.sort_tags(tags=chosen, sort_by=self.sort_by, sort_order=self.sort_order)

        remaining = [tag for tag in candidates if tag not in self.filter.tags]
        choices = dte_instance.write_tags(chosen + remaining, self.sort_by)
        values = dte_instance.write_tags(chosen, self.sort_by)

        return gr.CheckboxGroup.update(value=values, choices=choices)


    def clear_filter(self):
        """Reset the filter, the search word and the ticked tags."""
        self.selected_tags = set()
        self.filter_word = ''
        self.filter = TagFilter(logic=self.logic, mode=self.filter_mode)
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/tab_filter_by_selection.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, List, Callable
3 | import gradio as gr
4 |
5 | from .ui_common import *
6 | from .uibase import UIBase
7 |
8 | if TYPE_CHECKING:
9 | from .ui_classes import *
10 |
# Short alias for the filters module; `dte_module` is not defined in this file
# (presumably provided by the star import of `.ui_common` — confirm).
filters = dte_module.filters
12 |
13 |
14 | class FilterBySelectionUI(UIBase):
15 | def __init__(self):
16 | self.path_filter = filters.PathFilter()
17 | self.selected_index = -1
18 | self.selected_path = ''
19 | self.tmp_selection = set()
20 |
21 | def get_current_txt_selection(self):
22 | return f"""Selected Image : {self.selected_path}"""
23 |
    def create_ui(self, image_columns:int):
        """Create the widgets of the "Filter by Selection" block.

        `image_columns` is passed to the filter gallery as its `columns` value.
        """
        # Hidden helpers: the button carries the element id that the JavaScript side
        # clicks; the number holds the selected index of the filter gallery.
        with gr.Row(visible=False):
            self.btn_hidden_set_selection_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_selection_index")
            self.nb_hidden_selection_image_index = gr.Number(value=-1)
        gr.HTML("""Select images from the left gallery.""")

        with gr.Column(variant='panel'):
            with gr.Row():
                self.btn_add_image_selection = gr.Button(value='Add selection [Enter]', elem_id='dataset_tag_editor_btn_add_image_selection')
                self.btn_add_all_displayed_image_selection = gr.Button(value='Add ALL Displayed')

            self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery", columns=image_columns)
            self.txt_selection = gr.HTML(value=self.get_current_txt_selection())

            with gr.Row():
                self.btn_remove_image_selection = gr.Button(value='Remove selection [Delete]', elem_id='dataset_tag_editor_btn_remove_image_selection')
                self.btn_invert_image_selection = gr.Button(value='Invert selection')
                self.btn_clear_image_selection = gr.Button(value='Clear selection')

        self.btn_apply_image_selection_filter = gr.Button(value='Apply selection filter', variant='primary')
44 |
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
    """Connect the selection-filter widgets to their event handlers.

    Args:
        o_update_filter_and_gallery: Components that receive the values
            returned by ``update_filter_and_gallery``.
        dataset_gallery: Main gallery UI; its ``selected_path`` is the image
            added by the "add" button.
        filter_by_tags: Tag filter UI; its ``btn_clear_all_filters`` button
            also resets the path filter and the pending selection.
        get_filters: Returns the filters currently in effect.
        update_filter_and_gallery: Recomputes the values for
            ``o_update_filter_and_gallery``.
    """

    def selection_images():
        # Images of the pending selection, in the order the gallery shows them.
        return [dte_instance.images[p] for p in arrange_selection_order(self.tmp_selection)]

    def selection_index_changed(idx:int = -1):
        # The index arrives from the hidden number component and may be None.
        idx = int(idx) if idx is not None else -1
        img_paths = arrange_selection_order(self.tmp_selection)
        if idx < 0 or len(img_paths) <= idx:
            # Out of range means "nothing selected".
            self.selected_path = ''
            idx = -1
        else:
            self.selected_path = img_paths[idx]
        self.selected_index = idx
        return [self.get_current_txt_selection(), idx]

    self.btn_hidden_set_selection_index.click(
        fn=selection_index_changed,
        _js="(x) => dataset_tag_editor_gl_filter_images_selected_index()",
        inputs=[self.nb_hidden_selection_image_index],
        outputs=[self.txt_selection, self.nb_hidden_selection_image_index]
    )

    def add_image_selection():
        img_path = dataset_gallery.selected_path
        if img_path:
            self.tmp_selection.add(img_path)
        return selection_images()

    self.btn_add_image_selection.click(
        fn=add_image_selection,
        outputs=[self.gl_filter_images]
    )

    def add_all_displayed_image_selection():
        img_paths = dte_instance.get_filtered_imgpaths(filters=get_filters())
        self.tmp_selection |= set(img_paths)
        return selection_images()

    self.btn_add_all_displayed_image_selection.click(
        fn=add_all_displayed_image_selection,
        outputs=self.gl_filter_images
    )

    def invert_image_selection():
        img_paths = dte_instance.get_img_path_set()
        self.tmp_selection = img_paths - self.tmp_selection
        return selection_images()

    self.btn_invert_image_selection.click(
        fn=invert_image_selection,
        outputs=self.gl_filter_images
    )

    def remove_image_selection():
        img_path = self.selected_path
        if img_path:
            # discard() instead of remove(): the remembered path can be stale
            # (inverting the selection drops it from the set without
            # resetting selected_path), and remove() would raise KeyError.
            self.tmp_selection.discard(img_path)
            self.selected_path = ''
            self.selected_index = -1

        return [
            selection_images(),
            self.get_current_txt_selection(),
            -1
        ]

    self.btn_remove_image_selection.click(
        fn=remove_image_selection,
        outputs=[self.gl_filter_images, self.txt_selection, self.nb_hidden_selection_image_index]
    )

    def clear_image_selection():
        self.tmp_selection.clear()
        self.selected_path = ''
        self.selected_index = -1
        return [
            [],
            self.get_current_txt_selection(),
            -1
        ]

    self.btn_clear_image_selection.click(
        fn=clear_image_selection,
        outputs=
        [self.gl_filter_images, self.txt_selection, self.nb_hidden_selection_image_index]
    )

    def clear_image_filter():
        # Reset the applied path filter as well as the pending selection.
        self.path_filter = filters.PathFilter()
        return clear_image_selection() + update_filter_and_gallery()

    filter_by_tags.btn_clear_all_filters.click(lambda:None).then(
        fn=clear_image_filter,
        outputs=
        [self.gl_filter_images, self.txt_selection, self.nb_hidden_selection_image_index] +
        o_update_filter_and_gallery
    )

    def apply_image_selection_filter():
        if len(self.tmp_selection) > 0:
            self.path_filter = filters.PathFilter(self.tmp_selection, filters.PathFilter.Mode.INCLUSIVE)
        else:
            # An empty selection applies a default-constructed PathFilter.
            self.path_filter = filters.PathFilter()
        return update_filter_and_gallery()

    self.btn_apply_image_selection_filter.click(
        fn=apply_image_selection_filter,
        outputs=o_update_filter_and_gallery
    ).then(
        fn=None,
        _js='() => dataset_tag_editor_gl_dataset_images_close()'
    )
154 |
155 |
def arrange_selection_order(paths: List[str]):
    """Return the given image paths as a new list in ascending order."""
    ordered = list(paths)
    ordered.sort()
    return ordered
158 |
159 |
--------------------------------------------------------------------------------
/javascript/00_modified_gallery.js:
--------------------------------------------------------------------------------
1 | class DTEModifiedGallery{
2 | #elem;
3 | #items_grid;
4 | #items_selector;
5 | #current_filter = null;
6 | #selected_idx = -1
7 | #filter_idx = -1
8 |
9 | setElement(elem){
10 | this.#elem = elem;
11 | this.#items_grid = this.#elem.querySelectorAll('div.grid-wrap > div.grid-container > button.thumbnail-item')
12 | this.#items_selector = this.#elem.querySelectorAll('div.preview > div.thumbnails > button.thumbnail-item')
13 | }
14 |
15 | updateFilter(){
16 | if (!this.#elem) return;
17 |
18 | if (this.#items_grid){
19 | for(let i = 0; i < this.#items_grid.length; ++i){
20 | if(!this.#current_filter || this.#current_filter.includes(i)){
21 | this.#items_grid[i].hidden = false
22 | }
23 | else{
24 | this.#items_grid[i].hidden = true
25 | }
26 | }
27 | }
28 | if(this.#items_selector){
29 | for(let i = 0; i < this.#items_selector.length; ++i){
30 | if(!this.#current_filter || this.#current_filter.includes(i)){
31 | this.#items_selector[i].hidden = false
32 | }
33 | else{
34 | this.#items_selector[i].hidden = true
35 | }
36 | }
37 | }
38 | }
39 |
40 | filter(indices){
41 | if (!this.#elem) return;
42 | this.#current_filter = indices.map((e) => +e).sort((a, b) => a - b)
43 | this.updateFilter()
44 | }
45 |
46 | clearFilter(){
47 | this.#current_filter = null
48 | this.updateFilter()
49 | }
50 |
51 | getVisibleSelectedIndex(){
52 | if (!this.#elem || !this.#items_selector) return -1;
53 |
54 | let button = this.#elem.querySelector('.gradio-gallery .thumbnail-item.selected')
55 |
56 | for (let i = 0; i < this.#items_selector.length; ++i){
57 | if (this.#items_selector[i] == button){
58 | return i;
59 | }
60 | }
61 | return -1
62 | }
63 |
64 | getSelectedIndex() {
65 | if (!this.#elem || !this.#items_selector) return -1;
66 | if (!this.#current_filter) return this.#selected_idx
67 | return this.#filter_idx
68 | }
69 |
70 | keyHandler(e){
71 | switch(e.key)
72 | {
73 | case 'ArrowLeft':
74 | {
75 | let filteridx = this.getSelectedIndex()
76 | if (filteridx < 0) break;
77 | if (this.#current_filter){
78 | let next = (filteridx + this.#current_filter.length - 1) % this.#current_filter.length;
79 | this.#filter_idx = next
80 | this.#selected_idx = this.#current_filter[next]
81 | }
82 | else{
83 | this.#selected_idx = (filteridx + this.#items_selector.length - 1) % this.#items_selector.length;
84 | }
85 | let button = this.#items_selector[this.#selected_idx]
86 | if(button){
87 | button.click()
88 | }
89 | break;
90 | }
91 | case 'ArrowRight':
92 | {
93 | let filteridx = this.getSelectedIndex()
94 | if (filteridx < 0) break;
95 | if (this.#current_filter){
96 | let next = (filteridx + 1) % this.#current_filter.length;
97 | this.#filter_idx = next
98 | this.#selected_idx = this.#current_filter[next]
99 | }
100 | else{
101 | this.#selected_idx = (filteridx + 1) % this.#items_selector.length;
102 | }
103 | let button = this.#items_selector[this.#selected_idx]
104 | if(button){
105 | button.click()
106 | }
107 | break;
108 | }
109 | case 'Escape':
110 | {
111 | let imgPreview_close = this.#elem.querySelector('div.preview > div > button[class^="svelte"]')
112 | if (imgPreview_close != null) {
113 | imgPreview_close.click()
114 | }
115 | this.#filter_idx = -1
116 | this.#selected_idx = -1
117 | }
118 | }
119 | }
120 |
121 | clickHandler(){
122 | if(!this.#items_selector) return
123 | let idx = this.getVisibleSelectedIndex()
124 |
125 | if(!this.#current_filter){
126 | this.#selected_idx = idx
127 | return
128 | }
129 |
130 | for(let i = 0; i img')
188 | if (fullImg_preview != null) {
189 | fullImg_preview.forEach(function (e) {
190 | if (e) {
191 | e.addEventListener('click', callback_clicked, false);
192 | }
193 | });
194 | }
195 |
196 | }
197 |
198 | addClickCloseHandler(callback_clicked){
199 | if (!this.#elem) return;
200 |
201 | let imgPreview_close = this.#elem.querySelectorAll('div.preview > div > button[class^="svelte"]')
202 | if (imgPreview_close != null) {
203 | imgPreview_close.forEach(function (e) {
204 | if (e) {
205 | e.addEventListener('click', callback_clicked, false);
206 | }
207 | });
208 | }
209 |
210 | }
211 |
212 | clickClose(){
213 | if (!this.#elem) return;
214 |
215 | let imgPreview_close = this.#elem.querySelectorAll('div.preview > div > button[class^="svelte"]')
216 | if (imgPreview_close != null) {
217 | imgPreview_close.forEach(function (e) {
218 | if (e) {
219 | e.click()
220 | }
221 | });
222 | }
223 |
224 | }
225 |
226 | }
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/block_load_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, List, Callable
3 | import gradio as gr
4 |
5 | from modules import shared
6 | from modules.shared import opts
7 |
8 | from .ui_common import *
9 | from .uibase import UIBase
10 |
11 | if TYPE_CHECKING:
12 | from .ui_classes import *
13 |
14 |
15 | class LoadDatasetUI(UIBase):
def __init__(self):
    """Start with an empty caption-file extension."""
    self.caption_file_ext = str()
18 |
# Build the dataset-loading widgets and store each one on `self`.
#
# Creates the dataset-directory and caption-extension text boxes, the Load
# and Unload buttons, the load-option checkboxes, the interrogator radio and
# multi-select dropdown, and the score-threshold checkboxes and sliders.
# Every initial widget value is read from `cfg_general`;
# `self.caption_file_ext` is set to `cfg_general.caption_ext`.
#
# NOTE(review): this listing has lost its indentation, so the exact nesting
# of the Column / Row / Accordion containers cannot be confirmed from here.
19 | def create_ui(self, cfg_general):
20 | with gr.Column(variant="panel"):
21 | with gr.Row():
22 | with gr.Column(scale=3):
23 | self.tb_img_directory = gr.Textbox(
24 | label="Dataset directory",
25 | placeholder="C:\\directory\\of\\datasets",
26 | value=cfg_general.dataset_dir,
27 | )
28 | with gr.Column(scale=1, min_width=60):
29 | self.tb_caption_file_ext = gr.Textbox(
30 | label="Caption File Ext",
31 | placeholder=".txt (on Load and Save)",
32 | value=cfg_general.caption_ext,
33 | )
34 | self.caption_file_ext = cfg_general.caption_ext
35 | with gr.Column(scale=1, min_width=80):
36 | self.btn_load_datasets = gr.Button(value="Load")
37 | self.btn_unload_datasets = gr.Button(value="Unload")
38 | with gr.Accordion(label="Dataset Load Settings"):
39 | with gr.Row():
40 | with gr.Column():
41 | self.cb_load_recursive = gr.Checkbox(
42 | value=cfg_general.load_recursive,
43 | label="Load from subdirectories",
44 | )
45 | self.cb_load_caption_from_filename = gr.Checkbox(
46 | value=cfg_general.load_caption_from_filename,
47 | label="Load caption from filename if no text file exists",
48 | )
49 | self.cb_replace_new_line_with_comma = gr.Checkbox(
50 | value=cfg_general.replace_new_line,
51 | label="Replace new-line character with comma",
52 | )
53 | with gr.Column():
54 | self.rb_use_interrogator = gr.Radio(
55 | choices=[
56 | "No",
57 | "If Empty",
58 | "Overwrite",
59 | "Prepend",
60 | "Append",
61 | ],
62 | value=cfg_general.use_interrogator,
63 | label="Use Interrogator Caption",
64 | )
65 | self.dd_intterogator_names = gr.Dropdown(
66 | label="Interrogators",
67 | choices=dte_instance.INTERROGATOR_NAMES,
68 | value=cfg_general.use_interrogator_names,
69 | interactive=True,
70 | multiselect=True,
71 | )
72 | with gr.Accordion(label="Interrogator Settings", open=False):
73 | with gr.Row():
74 | self.cb_use_custom_threshold_booru = gr.Checkbox(
75 | value=cfg_general.use_custom_threshold_booru,
76 | label="Use Custom Threshold (Booru)",
77 | interactive=True,
78 | )
79 | self.sl_custom_threshold_booru = gr.Slider(
80 | minimum=0,
81 | maximum=1,
82 | value=cfg_general.custom_threshold_booru,
83 | step=0.01,
84 | interactive=True,
85 | label="Booru Score Threshold",
86 | )
87 | with gr.Row():
88 | self.sl_custom_threshold_z3d = gr.Slider(
89 | minimum=0,
90 | maximum=1,
91 | value=cfg_general.custom_threshold_z3d,
92 | step=0.01,
93 | interactive=True,
94 | label="Z3D-E621 Score Threshold",
95 | )
96 | with gr.Row():
97 | self.cb_use_custom_threshold_waifu = gr.Checkbox(
98 | value=cfg_general.use_custom_threshold_waifu,
99 | label="Use Custom Threshold (WDv1.4 Tagger)",
100 | interactive=True,
101 | )
102 | self.sl_custom_threshold_waifu = gr.Slider(
103 | minimum=0,
104 | maximum=1,
105 | value=cfg_general.custom_threshold_waifu,
106 | step=0.01,
107 | interactive=True,
108 | label="WDv1.4 Tagger Score Threshold",
109 | )
110 |
def set_callbacks(
    self,
    o_update_filter_and_gallery: List[gr.components.Component],
    toprow: ToprowUI,
    dataset_gallery: DatasetGalleryUI,
    filter_by_tags: FilterByTagsUI,
    filter_by_selection: FilterBySelectionUI,
    batch_edit_captions: BatchEditCaptionsUI,
    update_filter_and_gallery: Callable[[], List],
):
    """Attach the handlers of the Load and Unload buttons.

    Both handlers return values for the same components, in this order: the
    dataset gallery, the selection gallery, the hidden dataset-filter
    checkbox group, its hidden "apply" number, and finally
    ``o_update_filter_and_gallery``.
    """

    def refreshed_components():
        # Output components shared by the Load and Unload buttons.
        return [
            dataset_gallery.gl_dataset_images,
            filter_by_selection.gl_filter_images,
            dataset_gallery.cbg_hidden_dataset_filter,
            dataset_gallery.nb_hidden_dataset_filter_apply,
        ] + o_update_filter_and_gallery

    def build_outputs(gallery_imgs, hidden_filter_update):
        # Values matching refreshed_components(); the selection gallery is
        # always emptied and the hidden "apply" number is set to 1.
        leading = [gallery_imgs, [], hidden_filter_update, 1]
        cleared = filter_by_tags.clear_filters(update_filter_and_gallery)
        remove_tags_update = batch_edit_captions.tag_select_ui_remove.cbg_tags_update()
        return leading + cleared + [remove_tags_update]

    def load_files_from_dir(
        dataset_dir: str,
        caption_file_ext: str,
        recursive: bool,
        load_caption_from_filename: bool,
        replace_new_line: bool,
        use_interrogator: str,
        use_interrogator_names,  #: List[str], : to avoid error on gradio v3.23.0
        use_custom_threshold_booru: bool,
        custom_threshold_booru: float,
        use_custom_threshold_waifu: bool,
        custom_threshold_waifu: float,
        custom_threshold_z3d: float,
        use_kohya_metadata: bool,
        kohya_json_path: str,
    ):
        # Translate the radio label into the interrogate method; any other
        # label (including "No") means no interrogation.
        methods = dte_instance.InterrogateMethod
        method_by_label = {
            "If Empty": methods.PREFILL,
            "Overwrite": methods.OVERWRITE,
            "Prepend": methods.PREPEND,
            "Append": methods.APPEND,
        }
        interrogate_method = method_by_label.get(use_interrogator, methods.NONE)

        if use_custom_threshold_booru:
            threshold_booru = custom_threshold_booru
        else:
            threshold_booru = opts.interrogate_deepbooru_score_threshold

        if use_custom_threshold_waifu:
            threshold_waifu = custom_threshold_waifu
        else:
            threshold_waifu = -1

        threshold_z3d = custom_threshold_z3d

        # The kohya metadata path is only passed on when its checkbox is set.
        metadata_path = kohya_json_path if use_kohya_metadata else None

        dte_instance.load_dataset(
            dataset_dir,
            caption_file_ext,
            recursive,
            load_caption_from_filename,
            replace_new_line,
            interrogate_method,
            use_interrogator_names,
            threshold_booru,
            threshold_waifu,
            threshold_z3d,
            opts.dataset_editor_use_temp_files,
            metadata_path,
            opts.dataset_editor_max_res,
        )

        # With no filters, every loaded image is shown and checked.
        imgs = dte_instance.get_filtered_imgs(filters=[])
        img_indices = dte_instance.get_filtered_imgindices(filters=[])
        hidden_filter_update = gr.CheckboxGroup.update(
            value=[str(i) for i in img_indices],
            choices=[str(i) for i in img_indices],
        )
        return build_outputs(imgs, hidden_filter_update)

    self.btn_load_datasets.click(
        fn=load_files_from_dir,
        inputs=[
            self.tb_img_directory,
            self.tb_caption_file_ext,
            self.cb_load_recursive,
            self.cb_load_caption_from_filename,
            self.cb_replace_new_line_with_comma,
            self.rb_use_interrogator,
            self.dd_intterogator_names,
            self.cb_use_custom_threshold_booru,
            self.sl_custom_threshold_booru,
            self.cb_use_custom_threshold_waifu,
            self.sl_custom_threshold_waifu,
            self.sl_custom_threshold_z3d,
            toprow.cb_save_kohya_metadata,
            toprow.tb_metadata_output,
        ],
        outputs=refreshed_components(),
    )

    def unload_files():
        dte_instance.clear()
        hidden_filter_update = gr.CheckboxGroup.update(value=[], choices=[])
        return build_outputs([], hidden_filter_update)

    self.btn_unload_datasets.click(
        fn=unload_files,
        outputs=refreshed_components(),
    )
238 |
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/tab_batch_edit_captions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, List, Callable
3 | import gradio as gr
4 |
5 | from .ui_common import *
6 | from .uibase import UIBase
7 | from .block_tag_select import TagSelectUI
8 |
9 | if TYPE_CHECKING:
10 | from .ui_classes import *
11 |
# Module-level aliases for dte_instance.SortBy / dte_instance.SortOrder.
# Below they are iterated for the radio choices (`e.value`) and called to
# convert the selected radio value back (`SortBy(sort_by)`).
12 | SortBy = dte_instance.SortBy
13 | SortOrder = dte_instance.SortOrder
14 |
15 | class BatchEditCaptionsUI(UIBase):
def __init__(self):
    """Create the tag-selection sub-UI used by the Remove tab.

    ``show_only_selected_tags`` starts off; ``create_ui`` later overwrites
    it from the configuration.
    """
    self.tag_select_ui_remove = TagSelectUI()
    self.show_only_selected_tags = bool()
19 |
# Build the batch-edit tabs ("Search and Replace", "Remove", "Extras") and
# store each widget on `self`.
#
# cfg_batch_edit supplies every initial widget value; in addition
# `self.show_only_selected_tags` is set to `cfg_batch_edit.show_only_selected`.
# get_filters is passed on to `self.tag_select_ui_remove.create_ui`.
#
# NOTE(review): `cfg_batch_edit.sory_by` is spelled this way in the call
# below - presumably it matches the config field name; verify against config.py.
# NOTE(review): this listing has lost its indentation, so the exact nesting
# of the Tab / Column / Row / Accordion containers cannot be confirmed from here.
20 | def create_ui(self, cfg_batch_edit, get_filters:Callable[[], List[dte_module.filters.Filter]]):
21 | with gr.Tab(label='Search and Replace'):
22 | with gr.Column(variant='panel'):
23 | gr.HTML('Edit common tags.')
24 | self.cb_show_only_tags_selected = gr.Checkbox(value=cfg_batch_edit.show_only_selected, label='Show only the tags selected in the Positive Filter')
25 | self.show_only_selected_tags = cfg_batch_edit.show_only_selected
26 | self.tb_common_tags = gr.Textbox(label='Common Tags', interactive=False)
27 | self.tb_edit_tags = gr.Textbox(label='Edit Tags', interactive=True)
28 | self.cb_prepend_tags = gr.Checkbox(value=cfg_batch_edit.prepend, label='Prepend additional tags')
29 | self.btn_apply_edit_tags = gr.Button(value='Apply changes to filtered images', variant='primary')
30 | with gr.Accordion(label='Show description of how to edit tags', open=False):
31 | gr.HTML(value="""
32 | 1. The tags common to all displayed images are shown in comma separated style.
33 | 2. When changes are applied, all tags in each displayed images are replaced.
34 | 3. If you change some tags into blank, they will be erased.
35 | 4. If you add some tags to the end, they will be added to the end/beginning of the text file.
36 | 5. Changes are not applied to the text files until the "Save all changes" button is pressed.
37 | ex A.
38 | Original Text = "A, A, B, C" Common Tags = "B, A" Edit Tags = "X, Y"
39 | Result = "Y, Y, X, C" (B->X, A->Y)
40 | ex B.
41 | Original Text = "A, B, C" Common Tags = "(nothing)" Edit Tags = "X, Y"
42 | Result = "A, B, C, X, Y" (add X and Y to the end (default))
43 | Result = "X, Y, A, B, C" (add X and Y to the beginning ("Prepend additional tags" checked))
44 | ex C.
45 | Original Text = "A, B, C, D, E" Common Tags = "A, B, D" Edit Tags = ", X, "
46 | Result = "X, C, E" (A->"", B->X, D->"")
47 | """)
48 | with gr.Column(variant='panel'):
49 | gr.HTML('Search and Replace for all images displayed.')
50 | self.tb_sr_search_tags = gr.Textbox(label='Search Text', interactive=True)
51 | self.tb_sr_replace_tags = gr.Textbox(label='Replace Text', interactive=True)
52 | self.cb_use_regex = gr.Checkbox(label='Use regex', value=cfg_batch_edit.use_regex)
53 | self.rb_sr_replace_target = gr.Radio(['Only Selected Tags', 'Each Tags', 'Entire Caption'], value=cfg_batch_edit.target, label='Search and Replace in', interactive=True)
54 | self.tb_sr_selected_tags = gr.Textbox(label='Selected Tags', interactive=False, lines=2)
55 | self.btn_apply_sr_tags = gr.Button(value='Search and Replace', variant='primary')
56 | with gr.Tab(label='Remove'):
57 | with gr.Column(variant='panel'):
58 | gr.HTML('Remove duplicate tags from the images displayed.')
59 | self.btn_remove_duplicate = gr.Button(value='Remove duplicate tags', variant='primary')
60 | with gr.Column(variant='panel'):
61 | gr.HTML('Remove selected tags from the images displayed.')
62 | self.btn_remove_selected = gr.Button(value='Remove selected tags', variant='primary')
63 | self.tag_select_ui_remove.create_ui(get_filters, cfg_batch_edit.sory_by, cfg_batch_edit.sort_order, cfg_batch_edit.sw_prefix, cfg_batch_edit.sw_suffix, cfg_batch_edit.sw_regex)
64 | with gr.Tab(label='Extras'):
65 | with gr.Column(variant='panel'):
66 | gr.HTML('Sort tags in the images displayed.')
67 | with gr.Row():
68 | self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_batch_edit.batch_sort_by, interactive=True, label='Sort by')
69 | self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_batch_edit.batch_sort_order, interactive=True, label='Sort Order')
70 | self.btn_sort_selected = gr.Button(value='Sort tags', variant='primary')
71 | with gr.Column(variant='panel'):
72 | gr.HTML('Truncate tags by token count.')
73 | self.nb_token_count = gr.Number(value=cfg_batch_edit.token_count, precision=0)
74 | self.btn_truncate_by_token = gr.Button(value='Truncate tags by token count', variant='primary')
75 |
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], load_dataset:LoadDatasetUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
    """Attach the handlers of the batch-edit widgets.

    Each editing handler applies its change to the images matched by
    ``get_filters()`` and returns ``update_filter_and_gallery()`` for
    ``o_update_filter_and_gallery``.
    """

    def split_tags(text: str):
        # Comma-separated text -> stripped, non-empty tags.
        stripped = (part.strip() for part in text.split(','))
        return [tag for tag in stripped if tag]

    # Loading a dataset blanks both tag text boxes.
    load_dataset.btn_load_datasets.click(
        fn=lambda:['', ''],
        outputs=[self.tb_common_tags, self.tb_edit_tags]
    )

    def apply_edit_tags(search_tags: str, replace_tags: str, prepend: bool):
        search_list = split_tags(search_tags)
        replace_list = split_tags(replace_tags)

        dte_instance.replace_tags(search_tags = search_list, replace_tags = replace_list, filters=get_filters(), prepend = prepend)

        # Apply the same replacement to the tags held by the positive and negative tag filters.
        positive_ui = filter_by_tags.tag_filter_ui
        positive_ui.get_filter().tags = dte_instance.get_replaced_tagset(positive_ui.get_filter().tags, search_list, replace_list)
        negative_ui = filter_by_tags.tag_filter_ui_neg
        negative_ui.get_filter().tags = dte_instance.get_replaced_tagset(negative_ui.get_filter().tags, search_list, replace_list)

        return update_filter_and_gallery()

    self.btn_apply_edit_tags.click(
        fn=apply_edit_tags,
        inputs=[self.tb_common_tags, self.tb_edit_tags, self.cb_prepend_tags],
        outputs=o_update_filter_and_gallery
    ).then(
        fn=None,
        _js='() => dataset_tag_editor_gl_dataset_images_close()'
    )

    def search_and_replace(search_text: str, replace_text: str, target_text: str, use_regex: bool):
        if target_text == 'Only Selected Tags':
            selected_tags = set(filter_by_tags.tag_filter_ui.selected_tags)
            dte_instance.search_and_replace_selected_tags(search_text = search_text, replace_text=replace_text, selected_tags=selected_tags, filters=get_filters(), use_regex=use_regex)
        elif target_text == 'Each Tags':
            selected_tags = None
            dte_instance.search_and_replace_selected_tags(search_text = search_text, replace_text=replace_text, selected_tags=None, filters=get_filters(), use_regex=use_regex)
        elif target_text == 'Entire Caption':
            selected_tags = None
            dte_instance.search_and_replace_caption(search_text=search_text, replace_text=replace_text, filters=get_filters(), use_regex=use_regex)
        else:
            # Any other target: nothing is edited, only the refresh runs.
            return update_filter_and_gallery()

        # Apply the same search-and-replace to the tag sets of both tag filters.
        positive_ui = filter_by_tags.tag_filter_ui
        positive_ui.filter.tags = dte_instance.search_and_replace_tag_set(search_text, replace_text, positive_ui.filter.tags, selected_tags, use_regex)
        negative_ui = filter_by_tags.tag_filter_ui_neg
        negative_ui.filter.tags = dte_instance.search_and_replace_tag_set(search_text, replace_text, negative_ui.filter.tags, selected_tags, use_regex)

        return update_filter_and_gallery()

    self.btn_apply_sr_tags.click(
        fn=search_and_replace,
        inputs=[self.tb_sr_search_tags, self.tb_sr_replace_tags, self.rb_sr_replace_target, self.cb_use_regex],
        outputs=o_update_filter_and_gallery
    ).then(
        fn=None,
        _js='() => dataset_tag_editor_gl_dataset_images_close()'
    )

    def cb_show_only_tags_selected_changed(value: bool):
        self.show_only_selected_tags = value
        return self.get_common_tags(get_filters, filter_by_tags)

    self.cb_show_only_tags_selected.change(
        fn=cb_show_only_tags_selected_changed,
        inputs=self.cb_show_only_tags_selected,
        outputs=[self.tb_common_tags, self.tb_edit_tags]
    )

    def remove_duplicated_tags():
        dte_instance.remove_duplicated_tags(get_filters())
        return update_filter_and_gallery()

    self.btn_remove_duplicate.click(
        fn=remove_duplicated_tags,
        outputs=o_update_filter_and_gallery
    )

    self.tag_select_ui_remove.set_callbacks()

    def remove_selected_tags():
        dte_instance.remove_tags(self.tag_select_ui_remove.selected_tags, get_filters())
        return update_filter_and_gallery()

    self.btn_remove_selected.click(
        fn=remove_selected_tags,
        outputs=o_update_filter_and_gallery
    )

    def sort_selected_tags(sort_by:str, sort_order:str):
        # The radios deliver plain values; convert them back via the aliases.
        by = SortBy(sort_by)
        order = SortOrder(sort_order)
        dte_instance.sort_filtered_tags(get_filters(), sort_by=by, sort_order=order)
        return update_filter_and_gallery()

    self.btn_sort_selected.click(
        fn=sort_selected_tags,
        inputs=[self.rb_sort_by, self.rb_sort_order],
        outputs=o_update_filter_and_gallery
    )

    self.cb_show_only_tags_selected.change(
        fn=self.func_to_set_value('show_only_selected_tags'),
        inputs=self.cb_show_only_tags_selected
    )

    def truncate_by_token_count(token_count:int):
        # Negative counts are clamped to zero.
        limit = max(int(token_count), 0)
        dte_instance.truncate_filtered_tags_by_token_count(get_filters(), limit)
        return update_filter_and_gallery()

    self.btn_truncate_by_token.click(
        fn=truncate_by_token_count,
        inputs=self.nb_token_count,
        outputs=o_update_filter_and_gallery
    )
188 |
189 |
def get_common_tags(self, get_filters:Callable[[], List[dte_module.filters.Filter]], filter_by_tags:FilterByTagsUI):
    """Return the common tags of the filtered images as text, twice.

    The tags come from ``dte_instance.get_common_tags`` for the current
    filters, joined with ``', '``. When ``self.show_only_selected_tags`` is
    set, only tags also present in the positive tag filter are kept. The
    two-element result feeds the "Common Tags" and "Edit Tags" text boxes.
    """
    common = dte_instance.get_common_tags(filters=get_filters())
    if self.show_only_selected_tags:
        common = [t for t in common if t in filter_by_tags.tag_filter_ui.filter.tags]
    text = ', '.join(common)
    return [text, text]
--------------------------------------------------------------------------------
/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING, List, Callable
3 | import gradio as gr
4 |
5 | from modules import shared
6 | from modules.call_queue import wrap_queued_call
7 | from scripts.dte_instance import dte_module
8 |
9 | from .ui_common import *
10 | from .uibase import UIBase
11 | from scripts.tokenizer import clip_tokenizer
12 |
13 | if TYPE_CHECKING:
14 | from .ui_classes import *
15 |
# Module-level aliases for dte_instance.SortBy / dte_instance.SortOrder.
# Below they are iterated for the radio choices (`e.value`) and called to
# convert the selected radio value back (`SortBy(sort_by)`).
16 | SortBy = dte_instance.SortBy
17 | SortOrder = dte_instance.SortOrder
18 |
19 | class EditCaptionOfSelectedImageUI(UIBase):
def __init__(self):
    """Start in the "saved" state: no unsaved caption change is pending."""
    self.change_is_saved = bool(1)
22 |
# Build the single-image caption editor and store each widget on `self`.
#
# Creates a hidden row (image-index number, caption text box, save button),
# the "Read Caption from Selected Image" and "Interrogate Selected Image"
# tabs with their copy / prepend / append buttons, the auto-copy, sort-on-save
# and warn-on-unsaved checkboxes, the sort radios, and the editable caption
# box with its two "apply" buttons. Initial values come from
# `cfg_edit_selected`; the row holding the sort radios is created with
# `visible=cfg_edit_selected.sort_on_save`.
#
# NOTE(review): this listing has lost its indentation, so the exact nesting
# of the Row / Tab / Column containers cannot be confirmed from here.
23 | def create_ui(self, cfg_edit_selected):
24 | with gr.Row(visible=False):
25 | self.nb_hidden_image_index_save_or_not = gr.Number(value=-1, label='hidden_s_or_n')
26 | self.tb_hidden_edit_caption = gr.Textbox()
27 | self.btn_hidden_save_caption = gr.Button(elem_id="dataset_tag_editor_btn_hidden_save_caption")
28 | with gr.Tab(label='Read Caption from Selected Image'):
29 | self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption')
30 | self.token_counter_caption = gr.HTML(value='0/75', elem_id='dte_caption_counter', elem_classes=["token-counter-dte"])
31 | with gr.Row():
32 | self.btn_copy_caption = gr.Button(value='Copy and Overwrite')
33 | self.btn_prepend_caption = gr.Button(value='Prepend')
34 | self.btn_append_caption = gr.Button(value='Append')
35 |
36 | with gr.Tab(label='Interrogate Selected Image'):
37 | with gr.Row():
38 | self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_instance.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
39 | self.btn_interrogate_si = gr.Button(value='Interrogate')
40 | with gr.Column():
41 | self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate')
42 | self.token_counter_interrogate = gr.HTML(value='', elem_id='dte_interrogate_counter')
43 | with gr.Row():
44 | self.btn_copy_interrogate = gr.Button(value='Copy and Overwrite')
45 | self.btn_prepend_interrogate = gr.Button(value='Prepend')
46 | self.btn_append_interrogate = gr.Button(value='Append')
47 | with gr.Column():
48 | self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
49 | self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
50 | with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
51 | self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
52 | self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
53 | self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
54 | with gr.Column():
55 | self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6, elem_id= 'dte_edit_caption')
56 | self.token_counter_edit_caption = gr.HTML(value='0/75', elem_id='dte_edit_caption_counter', elem_classes=["token-counter-dte"])
57 | self.btn_apply_changes_selected_image = gr.Button(value='Apply changes to selected image', variant='primary')
58 | self.btn_apply_changes_all_images = gr.Button(value='Apply changes to ALL displayed images', variant='primary')
59 |
60 | gr.HTML("""Changes are not applied to the text files until the "Save all changes" button is pressed.""")
61 |
62 | def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, load_dataset:LoadDatasetUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
63 | load_dataset.btn_load_datasets.click(
64 | fn=lambda:['', -1],
65 | outputs=[self.tb_caption, self.nb_hidden_image_index_save_or_not]
66 | )
67 |
68 | def gallery_index_changed(next_idx:int, prev_idx:int, edit_caption: str, copy_automatically: bool, warn_change_not_saved: bool):
69 | next_idx = int(next_idx) if next_idx is not None else -1
70 | prev_idx = int(prev_idx) if prev_idx is not None else -1
71 | img_paths = dte_instance.get_filtered_imgpaths(filters=get_filters())
72 | prev_tags_txt = ''
73 | if 0 <= prev_idx and prev_idx < len(img_paths):
74 | prev_tags_txt = ', '.join(dte_instance.get_tags_by_image_path(img_paths[prev_idx]))
75 | else:
76 | prev_idx = -1
77 |
78 | next_tags_txt = ''
79 | if 0 <= next_idx and next_idx < len(img_paths):
80 | next_tags_txt = ', '.join(dte_instance.get_tags_by_image_path(img_paths[next_idx]))
81 |
82 | return\
83 | [prev_idx if warn_change_not_saved and edit_caption != prev_tags_txt and not self.change_is_saved else -1] +\
84 | [next_tags_txt, next_tags_txt if copy_automatically else edit_caption] +\
85 | [edit_caption]
86 |
87 | self.nb_hidden_image_index_save_or_not.change(
88 | fn=lambda a:None,
89 | _js='(a) => dataset_tag_editor_ask_save_change_or_not(a)',
90 | inputs=self.nb_hidden_image_index_save_or_not
91 | )
92 | dataset_gallery.nb_hidden_image_index.change(lambda:None).then(
93 | fn=gallery_index_changed,
94 | inputs=[dataset_gallery.nb_hidden_image_index, dataset_gallery.nb_hidden_image_index_prev, self.tb_edit_caption, self.cb_copy_caption_automatically, self.cb_ask_save_when_caption_changed],
95 | outputs=[self.nb_hidden_image_index_save_or_not] + [self.tb_caption, self.tb_edit_caption] + [self.tb_hidden_edit_caption]
96 | )
97 |
98 | def change_selected_image_caption(tags_text: str, idx:int, sort: bool, sort_by:str, sort_order:str):
99 | idx = int(idx)
100 | img_paths = dte_instance.get_filtered_imgpaths(filters=get_filters())
101 |
102 | edited_tags = [t.strip() for t in tags_text.split(',')]
103 | edited_tags = [t for t in edited_tags if t]
104 |
105 | if sort:
106 | edited_tags = dte_instance.sort_tags(edited_tags, SortBy(sort_by), SortOrder(sort_order))
107 |
108 | if 0 <= idx and idx < len(img_paths):
109 | dte_instance.set_tags_by_image_path(imgpath=img_paths[idx], tags=edited_tags)
110 | return update_filter_and_gallery()
111 |
112 | self.btn_hidden_save_caption.click(
113 | fn=change_selected_image_caption,
114 | inputs=[self.tb_hidden_edit_caption, self.nb_hidden_image_index_save_or_not, self.cb_sort_caption_on_save, self.rb_sort_by, self.rb_sort_order],
115 | outputs=o_update_filter_and_gallery
116 | )
117 |
118 | self.btn_copy_caption.click(
119 | fn=lambda a:a,
120 | inputs=[self.tb_caption],
121 | outputs=[self.tb_edit_caption]
122 | )
123 |
124 | self.btn_append_caption.click(
125 | fn=lambda a, b : b + (', ' if a and b else '') + a,
126 | inputs=[self.tb_caption, self.tb_edit_caption],
127 | outputs=[self.tb_edit_caption]
128 | )
129 |
130 | self.btn_prepend_caption.click(
131 | fn=lambda a, b : a + (', ' if a and b else '') + b,
132 | inputs=[self.tb_caption, self.tb_edit_caption],
133 | outputs=[self.tb_edit_caption]
134 | )
135 |
def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float, threshold_z3d: float):
    """Run the chosen interrogator on the image currently selected in the gallery.

    Args:
        interrogator_name: Name of the interrogator; an empty value short-circuits.
        use_threshold_booru: Whether ``threshold_booru`` overrides the shared option.
        threshold_booru: Custom threshold used only when ``use_threshold_booru`` is set.
        use_threshold_waifu: Whether ``threshold_waifu`` is used.
        threshold_waifu: Custom threshold used only when ``use_threshold_waifu`` is set.
        threshold_z3d: Passed through unchanged.

    Returns:
        ``''`` when no interrogator is chosen, otherwise the result of
        ``dte_instance.interrogate_image``.
    """
    if not interrogator_name:
        return ''

    # Fall back to the web UI's own setting when no custom threshold is requested.
    if not use_threshold_booru:
        threshold_booru = shared.opts.interrogate_deepbooru_score_threshold
    # -1 is handed over when the custom waifu threshold is disabled.
    if not use_threshold_waifu:
        threshold_waifu = -1

    return dte_instance.interrogate_image(
        dataset_gallery.selected_path,
        interrogator_name,
        threshold_booru,
        threshold_waifu,
        threshold_z3d,
    )
143 |
144 | self.btn_interrogate_si.click(
145 | fn=interrogate_selected_image,
146 | inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu, load_dataset.sl_custom_threshold_z3d],
147 | outputs=[self.tb_interrogate]
148 | )
149 |
150 | self.btn_copy_interrogate.click(
151 | fn=lambda a:a,
152 | inputs=[self.tb_interrogate],
153 | outputs=[self.tb_edit_caption]
154 | )
155 |
156 | self.btn_append_interrogate.click(
157 | fn=lambda a, b : b + (', ' if a and b else '') + a,
158 | inputs=[self.tb_interrogate, self.tb_edit_caption],
159 | outputs=[self.tb_edit_caption]
160 | )
161 |
162 | self.btn_prepend_interrogate.click(
163 | fn=lambda a, b : a + (', ' if a and b else '') + b,
164 | inputs=[self.tb_interrogate, self.tb_edit_caption],
165 | outputs=[self.tb_edit_caption]
166 | )
167 |
def change_in_caption():
    """Record that the caption has been modified since the last apply."""
    # Cleared here and set back to True by the apply handlers.
    self.change_is_saved = False
170 |
171 | self.tb_edit_caption.change(
172 | fn=change_in_caption
173 | )
174 |
175 | self.tb_caption.change(
176 | fn=change_in_caption
177 | )
178 |
def apply_changes(edited: str, sort: bool, sort_by: str, sort_order: str):
    """Save the edited caption to the image currently selected in the gallery.

    Marks the change as saved, then delegates to
    ``change_selected_image_caption`` using ``dataset_gallery.selected_index``.
    """
    self.change_is_saved = True
    selected = dataset_gallery.selected_index
    return change_selected_image_caption(edited, selected, sort, sort_by, sort_order)
182 |
183 | self.btn_apply_changes_selected_image.click(
184 | fn=apply_changes,
185 | inputs=[self.tb_edit_caption, self.cb_sort_caption_on_save, self.rb_sort_by, self.rb_sort_order],
186 | outputs=o_update_filter_and_gallery,
187 | _js='(a,b,c,d) => {dataset_tag_editor_gl_dataset_images_close(); return [a, b, c, d]}'
188 | )
189 |
def apply_chages_all(tags_text: str, sort: bool, sort_by: str, sort_order: str):
    """Write the edited caption to every image that passes the current filters.

    Args:
        tags_text: Comma-separated caption text; pieces are stripped and blanks dropped.
        sort: When true, the parsed tags are sorted before being stored.
        sort_by: Value passed to ``SortBy`` when sorting.
        sort_order: Value passed to ``SortOrder`` when sorting.

    Returns:
        Whatever ``update_filter_and_gallery()`` returns.
    """
    self.change_is_saved = True
    target_paths = dte_instance.get_filtered_imgpaths(filters=get_filters())

    new_tags = [piece for piece in map(str.strip, tags_text.split(',')) if piece]

    if sort:
        new_tags = dte_instance.sort_tags(new_tags, SortBy(sort_by), SortOrder(sort_order))

    # NOTE(review): the same list object is handed over for every image —
    # confirm set_tags_by_image_path does not keep a shared reference.
    for path in target_paths:
        dte_instance.set_tags_by_image_path(imgpath=path, tags=new_tags)
    return update_filter_and_gallery()
203 |
204 | self.btn_apply_changes_all_images.click(
205 | fn=apply_chages_all,
206 | inputs=[self.tb_edit_caption, self.cb_sort_caption_on_save, self.rb_sort_by, self.rb_sort_order],
207 | outputs=o_update_filter_and_gallery,
208 | _js='(a,b,c,d) => {dataset_tag_editor_gl_dataset_images_close(); return [a, b, c, d]}'
209 | )
210 |
211 | self.cb_sort_caption_on_save.change(
212 | fn=lambda x:gr.update(visible=x),
213 | inputs=self.cb_sort_caption_on_save,
214 | outputs=self.sort_settings
215 | )
216 |
def update_token_counter(text: str):
    """Return a ``"<count>/<target>"`` token counter string for ``text``.

    The count comes from ``clip_tokenizer.tokenize`` (second element of its
    result) and the target from ``clip_tokenizer.get_target_token_count``.
    """
    use_raw = shared.opts.dataset_editor_use_raw_clip_token
    _, n_tokens = clip_tokenizer.tokenize(text, use_raw)
    target = clip_tokenizer.get_target_token_count(n_tokens)
    return f"{n_tokens}/{target}"
221 |
222 | update_caption_token_counter_args = {
223 | 'fn' : wrap_queued_call(update_token_counter),
224 | 'inputs' : [self.tb_caption],
225 | 'outputs' : [self.token_counter_caption]
226 | }
227 | update_edit_caption_token_counter_args = {
228 | 'fn' : wrap_queued_call(update_token_counter),
229 | 'inputs' : [self.tb_edit_caption],
230 | 'outputs' : [self.token_counter_edit_caption]
231 | }
232 | update_interrogate_token_counter_args = {
233 | 'fn' : wrap_queued_call(update_token_counter),
234 | 'inputs' : [self.tb_interrogate],
235 | 'outputs' : [self.token_counter_interrogate]
236 | }
237 |
238 | self.tb_caption.change(**update_caption_token_counter_args)
239 | self.tb_edit_caption.change(**update_edit_caption_token_counter_args)
240 | self.tb_interrogate.change(**update_interrogate_token_counter_args)
241 |
--------------------------------------------------------------------------------
/scripts/main.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Type, Dict, Any
2 | from modules import shared, script_callbacks
3 | from modules.shared import opts
4 | import gradio as gr
5 | from scripts.config import *
6 |
7 | import scripts.tag_editor_ui as ui
8 |
9 | # ================================================================
10 | # General Callbacks
11 | # ================================================================
12 |
# Module-wide settings object used by the write_*/read_* helpers below and
# loaded/saved from the callbacks defined in on_ui_tabs.
config = Config()
14 |
15 |
def write_general_config(*args):
    """Pack ``args`` into a ``GeneralConfig`` and write it under ``"general"``."""
    config.write(GeneralConfig(*args)._asdict(), "general")
19 |
20 |
def write_filter_config(*args):
    """Write the positive and negative filter settings under ``"filter"``.

    ``args`` is split in the middle: the first half builds the positive
    ``FilterConfig`` and the second half the negative one.
    """
    middle = len(args) // 2
    positive_values, negative_values = args[:middle], args[middle:]
    payload = {
        "positive": FilterConfig(*positive_values)._asdict(),
        "negative": FilterConfig(*negative_values)._asdict(),
    }
    config.write(payload, "filter")
26 |
27 |
def write_batch_edit_config(*args):
    """Pack ``args`` into a ``BatchEditConfig`` and write it under ``"batch_edit"``."""
    config.write(BatchEditConfig(*args)._asdict(), "batch_edit")
31 |
32 |
def write_edit_selected_config(*args):
    """Pack ``args`` into an ``EditSelectedConfig`` and write it under ``"edit_selected"``."""
    config.write(EditSelectedConfig(*args)._asdict(), "edit_selected")
36 |
37 |
def write_move_delete_config(*args):
    """Pack ``args`` into a ``MoveDeleteConfig`` and write it under ``"file_move_delete"``."""
    config.write(MoveDeleteConfig(*args)._asdict(), "file_move_delete")
41 |
42 |
def read_config(name: str, config_type: Type, default: NamedTuple, compat_func=None):
    """Read one settings section and overlay it on its defaults.

    Args:
        name: Section name passed to ``config.read``.
        config_type: Named-tuple class used to build the result.
        default: Default values; returned unchanged when the section is empty
            or missing.
        compat_func: Optional callable applied to the stored dict before
            merging (used to migrate older layouts).

    Returns:
        ``default`` itself, or a ``config_type`` built from the defaults
        overridden by stored values. Stored keys unknown to ``default`` are
        discarded.
    """
    stored = config.read(name)
    if not stored:
        return default

    if compat_func:
        stored = compat_func(stored)

    defaults = default._asdict()
    merged = defaults | stored
    # Keep only the fields the config type knows about, in field order.
    known = {field: merged[field] for field in defaults}
    return config_type(**known)
53 |
54 |
def read_general_config():
    """Read the ``"general"`` section, migrating the legacy per-interrogator flags.

    Stored settings that lack ``"use_interrogator_names"`` get that key
    derived from the old boolean ``use_*_to_prefill`` flags before being merged
    with ``CFG_GENERAL_DEFAULT``.
    """
    # for compatibility: old flag name -> interrogator name
    legacy_flags = (
        ("use_blip_to_prefill", "BLIP"),
        ("use_git_to_prefill", "GIT-large-COCO"),
        ("use_booru_to_prefill", "DeepDanbooru"),
        ("use_waifu_to_prefill", "wd-v1-4-vit-tagger"),
    )

    def compat_func(d: Dict[str, Any]):
        # Already in the current layout: nothing to migrate.
        if "use_interrogator_names" in d:
            return d
        # The stored dict is updated in place, as before.
        d["use_interrogator_names"] = [
            interrogator for flag, interrogator in legacy_flags if d.get(flag)
        ]
        return d

    return read_config("general", GeneralConfig, CFG_GENERAL_DEFAULT, compat_func)
75 |
76 |
def read_filter_config():
    """Read the ``"filter"`` section as a ``(positive, negative)`` pair.

    Each half falls back to its default (``CFG_FILTER_P_DEFAULT`` /
    ``CFG_FILTER_N_DEFAULT``) when missing or empty; otherwise stored values
    override the defaults and unknown keys are dropped.
    """
    section = config.read("filter") or {}

    def overlay(saved, default):
        # No stored values for this half: hand back the default object itself.
        if not saved:
            return default
        defaults = default._asdict()
        merged = defaults | saved
        return FilterConfig(**{field: merged[field] for field in defaults})

    positive = overlay(section.get("positive"), CFG_FILTER_P_DEFAULT)
    negative = overlay(section.get("negative"), CFG_FILTER_N_DEFAULT)
    return positive, negative
92 |
93 |
def read_batch_edit_config():
    """Read the ``"batch_edit"`` section as a ``BatchEditConfig``."""
    return read_config(
        name="batch_edit",
        config_type=BatchEditConfig,
        default=CFG_BATCH_EDIT_DEFAULT,
    )
96 |
97 |
def read_edit_selected_config():
    """Read the ``"edit_selected"`` section as an ``EditSelectedConfig``."""
    return read_config(
        name="edit_selected",
        config_type=EditSelectedConfig,
        default=CFG_EDIT_SELECTED_DEFAULT,
    )
100 |
101 |
def read_move_delete_config():
    """Read the ``"file_move_delete"`` section as a ``MoveDeleteConfig``."""
    return read_config(
        name="file_move_delete",
        config_type=MoveDeleteConfig,
        default=CFG_MOVE_DELETE_DEFAULT,
    )
104 |
105 |
106 | # ================================================================
107 | # General Callbacks for Updating UIs
108 | # ================================================================
109 |
110 |
def get_filters():
    """Return the active filters: positive tag, negative tag, then path selection."""
    tag_tab = ui.filter_by_tags
    return [
        tag_tab.tag_filter_ui.get_filter(),
        tag_tab.tag_filter_ui_neg.get_filter(),
        ui.filter_by_selection.path_filter,
    ]
117 |
118 |
def update_gallery():
    """Refresh the gallery state texts and return new values for the gallery outputs.

    Returns:
        A six-element list: the filtered image indices as strings, the
        constants ``1, -1, -1, -1``, and the current gallery state text.
    """
    shown_indices = ui.dte_instance.get_filtered_imgindices(filters=get_filters())
    total = len(ui.dte_instance.dataset)
    shown = len(shown_indices)
    ui.gallery_state.register_value("Displayed Images", f"{shown} / {total} total")

    positive_ui = ui.filter_by_tags.tag_filter_ui
    negative_ui = ui.filter_by_tags.tag_filter_ui_neg
    # Pieces are evaluated in the same order as they appear in the final text.
    positive_text = f"{positive_ui.get_filter()}"
    if positive_ui.get_filter().tags and negative_ui.get_filter().tags:
        joiner = ' AND '
    else:
        joiner = ''
    negative_text = f"{negative_ui.get_filter()}"
    ui.gallery_state.register_value(
        "Current Tag Filter", f"{positive_text} {joiner} {negative_text}"
    )

    selected_count = len(ui.filter_by_selection.path_filter.paths)
    ui.gallery_state.register_value("Current Selection Filter", f"{selected_count} images")

    return [
        list(map(str, shown_indices)),
        1,
        -1,
        -1,
        -1,
        ui.gallery_state.get_current_gallery_txt(),
    ]
142 |
143 |
def update_filter_and_gallery():
    """Return refreshed values for the tag filters, gallery and caption editors.

    The result is one flat list: both tag-filter checkbox updates, the values
    from ``update_gallery()``, the common-tag values from the batch-edit tab,
    the joined positive filter tags, the remove-tag selector update, and two
    empty strings.
    """
    tag_tab = ui.filter_by_tags
    batch_tab = ui.batch_edit_captions

    filter_updates = [
        tag_tab.tag_filter_ui.cbg_tags_update(),
        tag_tab.tag_filter_ui_neg.cbg_tags_update(),
    ]
    gallery_values = update_gallery()
    common_tag_values = batch_tab.get_common_tags(get_filters, tag_tab)
    selected_tags_text = ", ".join(tag_tab.tag_filter_ui.filter.tags)
    remove_selector_update = batch_tab.tag_select_ui_remove.cbg_tags_update()

    return (
        filter_updates
        + gallery_values
        + common_tag_values
        + [selected_tags_text]
        + [remove_selector_update]
        + ["", ""]
    )
156 |
157 |
158 | # ================================================================
159 | # Script Callbacks
160 | # ================================================================
161 |
162 |
def on_ui_tabs():
    """Build the "Dataset Tag Editor" interface and wire up its callbacks.

    Loads the stored settings, creates every UI block inside one
    ``gr.Blocks`` context, registers the reload/save/restore settings buttons,
    and hands the shared output lists and update functions to each UI block's
    ``set_callbacks``.

    Returns:
        A one-element list holding the tuple
        ``(blocks, "Dataset Tag Editor", "dataset_tag_editor_interface")``.
    """
    config.load()

    cfg_general = read_general_config()
    cfg_filter_p, cfg_filter_n = read_filter_config()
    cfg_batch_edit = read_batch_edit_config()
    cfg_edit_selected = read_edit_selected_config()
    cfg_file_move_delete = read_move_delete_config()

    ui.dte_instance.load_interrogators()

    with gr.Blocks(analytics_enabled=False) as dataset_tag_editor_interface:

        gr.HTML(
            value="""
            This extension works well with text captions in comma-separated style (such as the tags generated by DeepBooru interrogator).
            """
        )

        ui.toprow.create_ui(cfg_general)

        with gr.Accordion(label="Reload/Save Settings (config.json)", open=False):
            with gr.Row():
                btn_reload_config_file = gr.Button(value="Reload settings")
                btn_save_setting_as_default = gr.Button(value="Save current settings")
                btn_restore_default = gr.Button(value="Restore settings to default")

        with gr.Row(equal_height=False):
            with gr.Column():
                ui.load_dataset.create_ui(cfg_general)
                ui.dataset_gallery.create_ui(opts.dataset_editor_image_columns)
                ui.gallery_state.create_ui()

            with gr.Tab(label="Filter by Tags"):
                ui.filter_by_tags.create_ui(cfg_filter_p, cfg_filter_n, get_filters)

            with gr.Tab(label="Filter by Selection"):
                ui.filter_by_selection.create_ui(opts.dataset_editor_image_columns)

            with gr.Tab(label="Batch Edit Captions"):
                ui.batch_edit_captions.create_ui(cfg_batch_edit, get_filters)

            with gr.Tab(label="Edit Caption of Selected Image"):
                ui.edit_caption_of_selected_image.create_ui(cfg_edit_selected)

            with gr.Tab(label="Move or Delete Files"):
                ui.move_or_delete_files.create_ui(cfg_file_move_delete)

        # ----------------------------------------------------------------
        # General

        # Short aliases, taken only after every create_ui() call has run.
        toprow = ui.toprow
        load_dataset = ui.load_dataset
        tag_filter = ui.filter_by_tags.tag_filter_ui
        tag_filter_neg = ui.filter_by_tags.tag_filter_ui_neg
        batch_edit = ui.batch_edit_captions
        tag_select_remove = batch_edit.tag_select_ui_remove
        edit_selected = ui.edit_caption_of_selected_image
        move_delete = ui.move_or_delete_files

        # The order of each list must match the field order of the
        # corresponding *Config named tuple, which is filled positionally.
        components_general = [
            toprow.cb_backup,
            load_dataset.tb_img_directory,
            load_dataset.tb_caption_file_ext,
            load_dataset.cb_load_recursive,
            load_dataset.cb_load_caption_from_filename,
            load_dataset.cb_replace_new_line_with_comma,
            load_dataset.rb_use_interrogator,
            load_dataset.dd_intterogator_names,
            load_dataset.cb_use_custom_threshold_booru,
            load_dataset.sl_custom_threshold_booru,
            load_dataset.cb_use_custom_threshold_waifu,
            load_dataset.sl_custom_threshold_waifu,
            load_dataset.sl_custom_threshold_z3d,
            toprow.cb_save_kohya_metadata,
            toprow.tb_metadata_output,
            toprow.tb_metadata_input,
            toprow.cb_metadata_overwrite,
            toprow.cb_metadata_as_caption,
            toprow.cb_metadata_use_fullpath,
        ]
        # Positive filter components first, then the same set for the negative filter.
        components_filter = []
        for filter_ui in (tag_filter, tag_filter_neg):
            components_filter = components_filter + [
                filter_ui.cb_prefix,
                filter_ui.cb_suffix,
                filter_ui.cb_regex,
                filter_ui.rb_sort_by,
                filter_ui.rb_sort_order,
                filter_ui.rb_logic,
            ]
        components_batch_edit = [
            batch_edit.cb_show_only_tags_selected,
            batch_edit.cb_prepend_tags,
            batch_edit.cb_use_regex,
            batch_edit.rb_sr_replace_target,
            tag_select_remove.cb_prefix,
            tag_select_remove.cb_suffix,
            tag_select_remove.cb_regex,
            tag_select_remove.rb_sort_by,
            tag_select_remove.rb_sort_order,
            batch_edit.rb_sort_by,
            batch_edit.rb_sort_order,
            batch_edit.nb_token_count,
        ]
        components_edit_selected = [
            edit_selected.cb_copy_caption_automatically,
            edit_selected.cb_sort_caption_on_save,
            edit_selected.cb_ask_save_when_caption_changed,
            edit_selected.dd_intterogator_names_si,
            edit_selected.rb_sort_by,
            edit_selected.rb_sort_order,
        ]
        components_move_delete = [
            move_delete.rb_move_or_delete_target_data,
            move_delete.cbg_move_or_delete_target_file,
            move_delete.tb_move_or_delete_caption_ext,
            move_delete.tb_move_or_delete_destination_dir,
        ]

        configurable_components = (
            components_general
            + components_filter
            + components_batch_edit
            + components_edit_selected
            + components_move_delete
        )

        def reload_config_file():
            """Reload config.json and return all settings as one flat tuple."""
            config.load()
            filter_p, filter_n = read_filter_config()
            logger.write("Reload config.json")
            sections = (
                read_general_config(),
                filter_p,
                filter_n,
                read_batch_edit_config(),
                read_edit_selected_config(),
                read_move_delete_config(),
            )
            return tuple(value for section in sections for value in section)

        btn_reload_config_file.click(
            fn=reload_config_file, outputs=configurable_components
        )

        def save_settings_callback(*a):
            """Split the flat component values by section, write each, then save."""
            sized_writers = (
                (write_general_config, len(components_general)),
                (write_filter_config, len(components_filter)),
                (write_batch_edit_config, len(components_batch_edit)),
                (write_edit_selected_config, len(components_edit_selected)),
            )
            offset = 0
            for write_section, count in sized_writers:
                write_section(*a[offset : offset + count])
                offset += count
            # Everything left over belongs to the move/delete section.
            write_move_delete_config(*a[offset:])
            config.save()
            logger.write("Current settings have been saved into config.json")

        btn_save_setting_as_default.click(
            fn=save_settings_callback, inputs=configurable_components
        )

        def restore_default_settings():
            """Write the default settings and return them as one flat tuple."""
            write_general_config(*CFG_GENERAL_DEFAULT)
            write_filter_config(*CFG_FILTER_P_DEFAULT, *CFG_FILTER_N_DEFAULT)
            write_batch_edit_config(*CFG_BATCH_EDIT_DEFAULT)
            write_edit_selected_config(*CFG_EDIT_SELECTED_DEFAULT)
            write_move_delete_config(*CFG_MOVE_DELETE_DEFAULT)
            logger.write("Restore default settings")
            defaults = (
                CFG_GENERAL_DEFAULT,
                CFG_FILTER_P_DEFAULT,
                CFG_FILTER_N_DEFAULT,
                CFG_BATCH_EDIT_DEFAULT,
                CFG_EDIT_SELECTED_DEFAULT,
                CFG_MOVE_DELETE_DEFAULT,
            )
            return tuple(value for section in defaults for value in section)

        btn_restore_default.click(
            fn=restore_default_settings, outputs=configurable_components
        )

        # Output components matching, in order, the list returned by update_gallery().
        o_update_gallery = [
            ui.dataset_gallery.cbg_hidden_dataset_filter,
            ui.dataset_gallery.nb_hidden_dataset_filter_apply,
            ui.dataset_gallery.nb_hidden_image_index,
            ui.dataset_gallery.nb_hidden_image_index_prev,
            edit_selected.nb_hidden_image_index_save_or_not,
            ui.gallery_state.txt_gallery,
        ]

        # Output components matching, in order, update_filter_and_gallery().
        o_update_filter_and_gallery = (
            [tag_filter.cbg_tags, tag_filter_neg.cbg_tags]
            + o_update_gallery
            + [batch_edit.tb_common_tags, batch_edit.tb_edit_tags]
            + [batch_edit.tb_sr_selected_tags]
            + [tag_select_remove.cbg_tags]
            + [edit_selected.tb_caption, edit_selected.tb_edit_caption]
        )

        ui.toprow.set_callbacks(ui.load_dataset)
        ui.load_dataset.set_callbacks(
            o_update_filter_and_gallery,
            ui.toprow,
            ui.dataset_gallery,
            ui.filter_by_tags,
            ui.filter_by_selection,
            ui.batch_edit_captions,
            update_filter_and_gallery,
        )
        ui.dataset_gallery.set_callbacks(ui.load_dataset, ui.gallery_state, get_filters)
        ui.gallery_state.set_callbacks(ui.dataset_gallery)
        ui.filter_by_tags.set_callbacks(
            o_update_gallery,
            o_update_filter_and_gallery,
            ui.batch_edit_captions,
            ui.move_or_delete_files,
            update_gallery,
            update_filter_and_gallery,
            get_filters,
        )
        ui.filter_by_selection.set_callbacks(
            o_update_filter_and_gallery,
            ui.dataset_gallery,
            ui.filter_by_tags,
            get_filters,
            update_filter_and_gallery,
        )
        ui.batch_edit_captions.set_callbacks(
            o_update_filter_and_gallery,
            ui.load_dataset,
            ui.filter_by_tags,
            get_filters,
            update_filter_and_gallery,
        )
        ui.edit_caption_of_selected_image.set_callbacks(
            o_update_filter_and_gallery,
            ui.dataset_gallery,
            ui.load_dataset,
            get_filters,
            update_filter_and_gallery,
        )
        ui.move_or_delete_files.set_callbacks(
            o_update_filter_and_gallery,
            ui.dataset_gallery,
            get_filters,
            update_filter_and_gallery,
        )

    tab = (
        dataset_tag_editor_interface,
        "Dataset Tag Editor",
        "dataset_tag_editor_interface",
    )
    return [tab]
427 |
428 |
def on_ui_settings():
    """Register this extension's options in the "Dataset Tag Editor" settings section."""
    section = ("dataset-tag-editor", "Dataset Tag Editor")

    # (option key, default value, label), registered in this order.
    option_specs = (
        ("dataset_editor_image_columns", 6, "Number of columns on image gallery"),
        ("dataset_editor_max_res", 0, "Max resolution of temporary files"),
        (
            "dataset_editor_use_temp_files",
            False,
            "Force image gallery to use temporary files",
        ),
        (
            "dataset_editor_use_raw_clip_token",
            True,
            "Use raw CLIP token to calculate token count (without emphasis or embeddings)",
        ),
        ("dataset_editor_use_rating", False, "Use rating tags"),
        (
            "dataset_editor_num_cpu_workers",
            -1,
            "Number of CPU workers when preprocessing images (set -1 to auto)",
        ),
        ("dataset_editor_batch_size_vit", 4, "Inference batch size for ViT taggers"),
        (
            "dataset_editor_batch_size_convnext",
            4,
            "Inference batch size for ConvNeXt taggers",
        ),
        (
            "dataset_editor_batch_size_swinv2",
            4,
            "Inference batch size for SwinTransformerV2 taggers",
        ),
    )

    for key, default_value, label in option_specs:
        shared.opts.add_option(
            key, shared.OptionInfo(default_value, label, section=section)
        )
497 |
498 |
499 | script_callbacks.on_ui_settings(on_ui_settings)
500 | script_callbacks.on_ui_tabs(on_ui_tabs)
501 |
--------------------------------------------------------------------------------