├── src
│   └── __init__.py
├── danbooru
│   ├── __init__.py
│   └── fetch_tags.py
├── .env
├── .gitmodules
├── LICENSE
├── .gitignore
└── e2e_tagger.py

/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/danbooru/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.env:
--------------------------------------------------------------------------------
PYTHONPATH=./src/Augmented-DDTagger
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "src/Augmented-DDTagger"]
	path = src/Augmented-DDTagger
	url = https://github.com/AdjointOperator/Augmented-DDTagger
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 AdjointOperator

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/danbooru/fetch_tags.py:
--------------------------------------------------------------------------------
import aiohttp
import argparse
import asyncio
import functools
import hashlib
import io
import json
import os
import re

from asyncio import gather, Semaphore
from pathlib import Path
from urllib.parse import urljoin

from lxml import html
from PIL import Image
from tqdm.auto import tqdm

# Danbooru API credentials (optional).
DD_API_KEY = None
DD_USER_NAME = None

MD5_PREFER_FNAME = True
LONG_SIDE = 768
iqdb_url = "https://danbooru.iqdb.org/"
UA = 'Mozilla/4.08 (compatible; MSIE 6.0; Windows NT 5.1)'

known_img_suffix = {'jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'tiff', 'tif'}
# Assumption (the original pattern is truncated in this dump): match a bare
# 32-hex-digit md5, as found in booru-style filenames.
md5_m = re.compile(r'(?<![0-9a-fA-F])[0-9a-fA-F]{32}(?![0-9a-fA-F])')
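# ... (the remainder of this file is not present in the dump; presumably it
# holds the iqdb lookup and Danbooru tag-download logic, including the
# fetch_danbooru_tag coroutine that e2e_tagger.py below runs via asyncio.run)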
--------------------------------------------------------------------------------
/e2e_tagger.py:
--------------------------------------------------------------------------------
# ... (the opening ~50 lines of this file are not present in the dump; they
# must supply the imports (numpy as np, yaml, deepcopy, Path, json, argparse,
# asyncio) plus the TaggerConfig dataclass, sigmoid, all_tags, MASK and the
# mask_general/mask_character/mask_meta arrays, get_predictions, and
# fetch_danbooru_tag referenced below)


def get_danbooru_tags_dict(img_path: Path) -> dict[str, list[str]] | None:
    if img_path.with_suffix('.json').exists():
        with open(img_path.with_suffix('.json')) as f:
            db_json = json.load(f)
        tags_dict = dict(
            general=db_json['tag_string_general'].split(),
            character=db_json['tag_string_character'].split(),
            copyright=db_json['tag_string_copyright'].split(),
            artist=db_json['tag_string_artist'].split(),
            meta=db_json['tag_string_meta'].split(),
        )
    else:
        tags_dict = None
    return tags_dict


def get_pred_dict(img_path: Path, score: np.ndarray, tags: list[str]) -> dict[str, np.float32]:
    mask = MASK & (score > 0.001)
    indices = np.where(mask)[0]
    order = np.argsort(score[indices])[::-1]
    return {tags[i]: score[i] for i in indices[order]}


def get_original_label(img_path: Path) -> str | None:
    if img_path.with_suffix('.txt').exists():
        with open(img_path.with_suffix('.txt')) as f:
            return f.read().strip()
    return None


def combine_dd_first(config: TaggerConfig, danbooru_tags: dict[str, list[str]] | None, tagger_preds: dict[str, float], original_label: str):
    if danbooru_tags:
        return [tag for cate, taglist in danbooru_tags.items() for tag in taglist if cate in config._keep_catagories]
    elif not config.fallback:
        raise ValueError('No danbooru tags found and fallback is disabled')
    return [tag for tag, confid in tagger_preds.items() if confid > config.tagger_threshold2]


def combine_AND(config: TaggerConfig, danbooru_tags: dict[str, list[str]] | None, tagger_preds: dict[str, float], original_label: str):
    """
    Args:
        danbooru_tags: maps a category (general, meta, character, artist, copyright)
            to its list of underscore-joined tags; None if no Danbooru match
        tagger_preds: tag -> confidence, sorted by confidence in descending order
        original_label: the image's pre-existing label text, if any
    """
    if danbooru_tags:
        dd_tags = set(tag for cate, taglist in danbooru_tags.items() for tag in taglist if cate in config._keep_catagories)
        return_tags = []
        for tag, confid in tagger_preds.items():
            if confid < config.tagger_threshold1:
                break
            if tag in dd_tags:
                return_tags.append(tag)
        return return_tags
    return [tag for tag, confid in tagger_preds.items() if confid > config.tagger_threshold2]


def combine_OR(config: TaggerConfig, danbooru_tags: dict[str, list[str]] | None, tagger_preds: dict[str, float], original_label: str):
    tags = [tag for tag, confid in tagger_preds.items() if confid > config.tagger_threshold2]
    if danbooru_tags:
        dd_tags = set(tag for cate, taglist in danbooru_tags.items() for tag in taglist if cate in config._keep_catagories)
        tags = list(dd_tags | set(tags))
    return tags


def combine_tagger_first(config: TaggerConfig, danbooru_tags: dict[str, list[str]] | None, tagger_preds: dict[str, float], original_label: str):
    return [tag for tag, confid in tagger_preds.items() if confid > config.tagger_threshold2]


def combine(config: TaggerConfig, pred_dict: dict[str, np.float32], danbooru_tag_dict: dict[str, list[str]] | None, original_label: str | None) -> list[str]:
    assert config.combine_mode in ('AND', 'OR', 'DD-first', 'TAGGER-first')

    ret_dict = {
        'AND': combine_AND,
        'OR': combine_OR,
        'DD-first': combine_dd_first,
        'TAGGER-first': combine_tagger_first,
    }
    return ret_dict[config.combine_mode](config, danbooru_tag_dict, pred_dict, original_label)  # type: ignore
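# A toy walkthrough (editor's sketch; all values hypothetical) of how the four
# combine modes above differ. Suppose the tagger predicts
#     tagger_preds = {'1girl': 0.98, 'smile': 0.60, 'hat': 0.20}
# and Danbooru returned danbooru_tags = {'general': ['1girl', 'hat']}, with
# keep_catagories=('general',), tagger_threshold1=0.3, tagger_threshold2=0.55:
#     AND          -> ['1girl']                  (score >= t1 and confirmed by Danbooru)
#     OR           -> {'1girl', 'smile', 'hat'}  (score > t2, union with Danbooru)
#     DD-first     -> ['1girl', 'hat']           (Danbooru wins; tagger only as fallback)
#     TAGGER-first -> ['1girl', 'smile']         (score > t2 only)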
def an_astolfo_is_1girl(all_tags: list[str]):
    """Generate gender-related tags according to the tagger's prediction. If it looks like 1girl, then it is 1girl."""

    gp1 = ['1girl', '2girls', '3girls', '4girls', '5girls', '6+girls']
    gp2 = ['1boy', '2boys', '3boys', '4boys', '5boys', '6+boys']
    gp3 = ['no_humans']
    rm = ['genderswap', 'genderswap_(mtf)', 'genderswap_(ftm)', 'ambiguous_gender'] + gp1 + gp2 + gp3
    rm = set(rm)
    groups = [gp1, gp2, gp3]
    gp_indices = [np.array([all_tags.index(t) for t in g]) for g in groups]

    def make_astolfo_girl(config: TaggerConfig, combined_tags: list[str], score: np.ndarray):
        # Keep the single best tag of the highest-scoring group unconditionally,
        # plus the best tag of any other group that clears threshold2; then drop
        # every other gender/count tag from the combined list.
        gp_scores = [score[idxs] for idxs in gp_indices]
        gp_scores_max = [gp_score.max() for gp_score in gp_scores]
        top_gp = np.argmax(gp_scores_max)
        top_loc = np.argmax(gp_scores[top_gp])
        keep = [groups[top_gp][top_loc]]
        for i in range(len(groups)):
            if i == top_gp:
                continue
            if gp_scores_max[i] > config.tagger_threshold2:
                keep.append(groups[i][np.argmax(gp_scores[i])])
        for tag in combined_tags:
            if tag not in rm:
                keep.append(tag)
        return keep
    return make_astolfo_girl


useless_tags = set(
    ['virtual_youtuber'] +
    [tag for tag in all_tags if 'alternate_' in tag] +
    ['genderswap', 'genderswap_(mtf)', 'genderswap_(ftm)', 'ambiguous_gender']
)


def rm_useless_tags(combined_tags):
    return [tag for tag in combined_tags if tag not in useless_tags]


make_astolfo_girl = an_astolfo_is_1girl(all_tags)


def process(
    config: TaggerConfig,
    pred_dict: dict[str, np.float32],
    scores: np.ndarray,
    danbooru_tag_dict: dict[str, list[str]] | None,
    original_label: str | None,
):
    tags = combine(config, pred_dict, danbooru_tag_dict, original_label)
    tags = make_astolfo_girl(config, tags, scores)
    tags = rm_useless_tags(tags)
    return tags


class Postprocess:
    def __init__(
        self,
        rm: list[str] | None = None,
        replace: list[tuple[str, str]] | None = None,
        prepend: list[str] | None = None,
        append: list[str] | None = None,
        must_have: list[str] | None = None,
    ):
        self.rm = set(rm) if rm else set()
        self.replace = {k: v for k, v in replace} if replace else {}
        _prepend = set(prepend) if prepend else set()
        self.append = list(set(append)) if append else []
        if must_have:
            # must_have tags are removed from the body of the tag list and
            # deduplicated into prepend, so each appears exactly once, up front.
            for t in must_have:
                self.rm.add(t)
                _prepend.add(t)
        self.prepend = list(_prepend)

    def __call__(self, tags: list[str]) -> list[str]:
        filtered_tags = []
        for t in tags:
            if t in self.rm:
                continue
            if t in self.replace:
                t = self.replace[t]
            filtered_tags.append(t)
        return self.prepend + filtered_tags + self.append

    def update(self, p: Path) -> 'Postprocess':
        if not p.exists():
            return deepcopy(self)
        with open(p, 'r') as f:
            other: dict = yaml.load(f, Loader=yaml.FullLoader)
        if not other:
            return deepcopy(self)
        for k, v in other.items():
            # Scalar values are promoted to single-element lists.
            if not isinstance(v, list):
                other[k] = [v]

        rm = list(self.rm) + other.get('rm', [])
        replace = list(self.replace.items()) + [tuple(v) for v in other.get('replace', [])]
        prepend = list(self.prepend) + other.get('prepend', [])
        append = list(self.append) + other.get('append', [])
        must_have = list(other.get('must_have', []))
        return Postprocess(rm, replace, prepend, append, must_have)
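# Directory-level rules: each directory may carry a rules.yaml whose entries
# extend the Postprocess accumulated from its ancestors. A sketch of the schema
# implied by Postprocess.update above (all keys and values hypothetical):
#
#     rm: [lowres, jpeg_artifacts]             # tags to drop
#     replace: [[long_hair, very_long_hair]]   # [old, new] pairs
#     prepend: my_trigger_word                 # scalars are promoted to lists
#     append: [masterpiece]
#     must_have: solo                          # dropped elsewhere, prepended once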
def _img_dirwalk_with_postprocessor(path: Path, postprocessor: Postprocess):

    postprocessor = postprocessor.update(path / 'rules.yaml')
    if not path.name.startswith('_'):
        # The directory name itself (up to an optional '#' comment) becomes a
        # prepended tag, so folders can be named after a character or concept.
        name = path.name.split('#', 1)[0].strip()
        postprocessor.prepend.append(name)

    for p in path.iterdir():
        if p.is_dir():
            yield from _img_dirwalk_with_postprocessor(p, postprocessor)
        else:
            if p.suffix[1:] in known_img_suffix:
                yield p, postprocessor


def img_dirwalk_with_postprocessor(path: Path | str):
    path = Path(path)
    postprocess = Postprocess()
    yield from _img_dirwalk_with_postprocessor(path, postprocess)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-p', type=str, required=True, help='Path to image or directory')
    parser.add_argument('--model_path', type=str, required=False, default=None, help='Path to model for AugDD')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference in AugDD')
    parser.add_argument('--backend', type=str, default='WD14-SwinV2', help='Backend model to use in AugDD', choices=['WD14-SwinV2', 'WD14-ConvNext', 'DeepDanbooru', 'WD14'])
    parser.add_argument('--nproc', type=int, default=-1, help='Number of processes to use for AugDD. -1 means all')
    parser.add_argument('--max_chunk', type=int, default=16, help='Maximum number of batches to process before one save in AugDD')
    parser.add_argument('--danbooru-concurrency', type=int, default=2, help='Number of concurrent requests to Danbooru')
    parser.add_argument('--iqdb-concurrency', type=int, default=2, help='Number of concurrent requests to iqdb')
    parser.add_argument('--retry-nomatch', action='store_true', help='Retry images that have no matches on iqdb')
    parser.add_argument('--keep-catagories', nargs='+', default=['general'], help='Available: general, character, artist, copyright, meta')
    parser.add_argument('--threshold1', '-t1', type=float, default=0.3, help='Tagger threshold for AND mode')
    parser.add_argument('--threshold2', '-t2', type=float, default=0.55, help='Tagger threshold for OR and fallback mode')
    parser.add_argument('--combine-mode', '-cm', type=str, default='AND', choices=['AND', 'OR', 'TAGGER-first', 'DD-first'])
    return parser.parse_args()


if __name__ == '__main__':

    args = get_args()
    if not args.model_path:
        args.model_path = 'src/Augmented-DDTagger/models/wd-v1-4-swinv2-tagger-v2'
        print('Using default model path:', args.model_path)

    # Widen the score mask to the requested categories. The tagger itself only
    # predicts general/character/meta tags; artist and copyright tags can only
    # come from Danbooru.
    if 'general' in args.keep_catagories:
        MASK |= mask_general
    if 'character' in args.keep_catagories:
        MASK |= mask_character
    if 'meta' in args.keep_catagories:
        MASK |= mask_meta

    config = TaggerConfig(
        keep_catagories=tuple(args.keep_catagories),
        tagger_threshold1=args.threshold1,
        tagger_threshold2=args.threshold2,
        model_path=args.model_path,
        combine_mode=args.combine_mode,
        fallback=True,
    )

    path = Path(args.path)

    print('Running AugDD...')
    scores, rel_paths = get_predictions(
        model_path=args.model_path,
        root_path=path,
        batch_size=args.batch_size,
        backend=args.backend,
        nproc=args.nproc,
        max_chunk=args.max_chunk,
    )
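    # get_predictions is assumed to return raw logits, so squash them into
    # [0, 1] confidences before thresholding.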
    scores = sigmoid(np.float32(scores))

    score_dicts = dict(zip(rel_paths, scores))  # type: ignore  # rel_path -> score vector

    print('Fetching Danbooru tags...')
    asyncio.run(fetch_danbooru_tag(args.path, args.iqdb_concurrency, args.danbooru_concurrency, args.retry_nomatch))

    for img_path, postproc in img_dirwalk_with_postprocessor(path):
        danbooru_tag_dict = get_danbooru_tags_dict(img_path)
        img_scores = score_dicts[str(img_path.relative_to(path))]
        pred_dict = get_pred_dict(img_path, img_scores, all_tags)
        original_label = get_original_label(img_path)
        tags = process(config, pred_dict, img_scores, danbooru_tag_dict, original_label)
        tags = postproc(tags)
        with open(img_path.with_suffix('.txt'), 'w') as f:
            f.write(' '.join(tags))
--------------------------------------------------------------------------------