├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── anime2sd ├── __init__.py ├── arrange.py ├── balancing.py ├── basics.py ├── captioning │ ├── __init__.py │ ├── captioning.py │ ├── tag_and_caption.py │ ├── tagging_basics.py │ ├── tagging_character.py │ └── waifuc_actions.py ├── character.py ├── classif │ ├── __init__.py │ ├── classify_characters.py │ ├── file_utils.py │ ├── imagewise.py │ └── merge_clusters.py ├── common_preprocess.py ├── download.py ├── emb_utils.py ├── execution_ordering.py ├── extract_frames.py ├── image_selection.py ├── parse_arguments.py ├── remove_duplicates.py └── waifuc_customize.py ├── automatic_pipeline.py ├── configs ├── csv_examples │ ├── character_mapping_example.csv │ ├── character_mapping_example_no_underscore.csv │ ├── default_weighting.csv │ ├── embedding_names_example.csv │ └── weighting_example.csv ├── hcp │ ├── caption.txt │ ├── dataset.yaml │ ├── diag_oft.yaml │ ├── loha.yaml │ ├── lokr.yaml │ ├── lora_conventional.yaml │ ├── text2img.yaml │ └── train_base.yaml ├── pipelines │ ├── base.toml │ ├── booru.toml │ └── screenshots.toml └── tag_filtering │ ├── blacklist_tags.txt │ ├── character_tags.json │ ├── overlap_tags.json │ └── overlap_tags_simplified.json ├── docs ├── Character_ref_organization.md ├── Conversion_scripts.md ├── Dataset_organization.md ├── Main_arguments.md ├── Pipeline.md ├── Start_training.md └── example_logs │ ├── booru_hikikomari_2023-12-0120-27-45.log │ ├── booru_sousou_no_frieren_2023-12-0201-59-37.log │ ├── hikikomari_weighting_2023-12-0120-36-13.log │ ├── screenshots_hikikomari_2023-12-0120-27-45.log │ ├── screenshots_sousou_no_frieren_2023-12-0201-59-37.log │ └── sousou_no_frieren_weighting_2023-12-0202-16-35.log ├── flatten_folder.py ├── install.bat ├── install.py ├── install.sh ├── prepare_hcp.py ├── requirements.txt ├── scripts_v1 ├── README.md ├── arrange_folder.py ├── augment_metadata.py ├── classifier_dataset_preparation │ ├── crop_and_make_dataset.py │ ├── data_split.py │ └── make_data_dic_imagenetsyle.py ├── classifier_training │ ├── models │ │ ├── README.md │ │ ├── download_convert_models.py │ │ ├── setup.py │ │ └── vit_animesion │ │ │ ├── __init__.py │ │ │ ├── configs.py │ │ │ ├── model.py │ │ │ ├── transformer.py │ │ │ └── utils.py │ ├── requirements.txt │ ├── train.py │ ├── utilities │ │ ├── __init__.py │ │ ├── build_vocab.py │ │ ├── calc_tokens_len.py │ │ ├── custom_tokenizer.py │ │ ├── data_selection.py │ │ ├── data_selection_customize.py │ │ ├── loss-landscapes │ │ │ ├── .gitignore │ │ │ ├── MANIFEST.in │ │ │ ├── README.md │ │ │ ├── examples │ │ │ │ └── core-features.ipynb │ │ │ ├── img │ │ │ │ ├── loss-contour-3d.png │ │ │ │ ├── loss-contour.png │ │ │ │ └── loss-landscape.png │ │ │ ├── loss_landscapes │ │ │ │ ├── __init__.py │ │ │ │ ├── contrib │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── connecting_paths.py │ │ │ │ │ └── trajectories.py │ │ │ │ ├── main.py │ │ │ │ ├── metrics │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── metric.py │ │ │ │ │ ├── rl_metrics.py │ │ │ │ │ └── sl_metrics.py │ │ │ │ └── model_interface │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── model_parameters.py │ │ │ │ │ └── model_wrapper.py │ │ │ ├── requirements.txt │ │ │ ├── setup.cfg │ │ │ └── setup.py │ │ ├── matcher.py │ │ ├── misc.py │ │ ├── model_selection.py │ │ ├── plot_mask_schedules.py │ │ ├── scheduler.py │ │ └── video_transform.py │ └── vocab.pkl ├── classify_characters.py ├── correct_metadata_from_foldername.py ├── danbooru_tag_tree │ ├── .gitignore │ ├── parse_page.py │ ├── parse_tag.py │ ├── tag_list_urls.json │ ├── tag_tree.json │ └── 
tag_tree.py ├── detect_faces.py ├── extract_frames.py ├── generate_captions.py ├── generate_multiply.py ├── remove_similar.ipynb ├── rename_character.py ├── requirements.txt ├── subsidiary │ ├── batch_resize.py │ ├── classify_characters_advanced.py │ ├── classify_characters_ensemble.py │ ├── crop_faces.py │ ├── find_duplicate.py │ ├── generate_captions_advanced.py │ ├── rename_md5.py │ └── retrieve_high_score.py └── tagger │ ├── blip │ ├── blip.py │ ├── med.py │ ├── med_config.json │ └── vit.py │ ├── make_caption.py │ └── tag_images_by_wd14_tagger.py ├── setup.py ├── tests ├── integration_tests │ ├── test_classify_characters.py │ └── test_download_images.py └── unit_tests │ ├── test_basics.py │ └── test_select_to_add.py └── utilities ├── batch_bundle_convert.py ├── batch_hcp_convert.py ├── convert_metadata.py ├── correct_path_field.py ├── count_tag_appearance.py ├── get_core_tags.py ├── rename_characters.py ├── replace_tags.py └── update_safetensor_metadata.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .flake8 131 | .*.swp 132 | 133 | scripts_v1/tagger/wd14_tagger_model/ 134 | scripts_v1/tagger/wd14_tagger_model_v1/ 135 | data/ 136 | logs/ 137 | tmp/ 138 | *.egg-info 139 | **/wandb/ 140 | **/results_training 141 | */*.ckpt 142 | !docs/**/*.log 143 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "waifuc"] 2 | path = waifuc 3 | url = https://github.com/deepghs/waifuc 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 CyberMeow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /anime2sd/__init__.py: -------------------------------------------------------------------------------- 1 | from .common_preprocess import rearrange_related_files, load_metadata_from_aux 2 | from .download import download_animes, download_images 3 | from .extract_frames import extract_and_remove_similar 4 | from .remove_duplicates import DuplicateRemover 5 | from .classif.classify_characters import classify_from_directory 6 | from .emb_utils import update_emb_init_info 7 | from .image_selection import ( 8 | select_dataset_images_from_directory, 9 | ) 10 | from .captioning import ( 11 | tag_and_caption_from_directory, 12 | compute_and_save_core_tags, 13 | tag_and_caption_from_directory_core_final, 14 | TaggingManager, 15 | CaptionGenerator, 16 | CoreTagProcessor, 17 | CharacterTagProcessor, 18 | ) 19 | from .arrange import arrange_folder 20 | from .balancing import read_weight_mapping, get_repeat 21 | -------------------------------------------------------------------------------- /anime2sd/balancing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import time 4 | import logging 5 | import fnmatch 6 | from tqdm import tqdm 7 | 8 | import numpy as np 9 | 10 | from anime2sd.basics import get_images_recursively 11 | 12 | 13 | class WeightTree(object): 14 | def __init__(self, dirname, weight_mapping=None, progress_bar=None): 15 | self.dirname = dirname 16 | self.n_images = 0 17 | self.contain_images = False 18 | self.children = [] 19 | 20 | for path in os.listdir(dirname): 21 | path = os.path.join(self.dirname, path) 22 | if os.path.isfile(path): 23 | extension = os.path.splitext(path)[1] 24 | if extension.lower() in [".jpg", ".jpeg", ".png", ".webp"]: 25 | if progress_bar is not None: 26 | progress_bar.update(1) 27 | self.n_images += 1 28 | self.contain_images = True 29 | elif os.path.isdir(path): 30 | sub_weight_tree = WeightTree(path, weight_mapping, progress_bar) 31 | if sub_weight_tree.contain_images or len(sub_weight_tree.children) > 0: 32 | self.children.append(sub_weight_tree) 33 | self.weight = self.modify_weight(weight_mapping) 34 | 35 | def modify_weight(self, training_weights): 36 | if training_weights is None: 37 | return 1 38 | basename = os.path.basename(self.dirname) 39 | if basename in training_weights: 40 | # print(self.dirname) 41 | # print(training_weights[basename]) 42 | return float(training_weights[basename]) 43 | for pattern in training_weights: 44 | if fnmatch.fnmatch(self.dirname, pattern): 45 | # print(self.dirname) 46 | # print(training_weights[pattern]) 47 | return float(training_weights[pattern]) 48 | return 1 49 | 50 | def compute_sampling_prob(self, baseprob, dir_list, prob_list, n_images_list): 51 | weights_list = [] 52 | for weight_tree in self.children: 53 | weights_list.append(weight_tree.weight) 54 | if self.contain_images: 55 | weights_list.append(self.weight) 56 | probs = np.array(weights_list) / np.sum(weights_list) 57 | # Modify dir_list and prob_list in place 58 | if self.contain_images: 59 | dir_list.append(self.dirname) 60 | prob_list.append(baseprob * probs[-1]) 61 | n_images_list.append(self.n_images) 62 | for i, weight_tree in enumerate(self.children): 63 | weight_tree.compute_sampling_prob( 64 | baseprob * probs[i], dir_list, prob_list, n_images_list 65 | ) 66 | 67 | 68 | def read_weight_mapping(weight_mapping_csv): 69 | weight_mapping = {} 70 | with open(weight_mapping_csv, "r") as f: 71 | 
reader = csv.reader(f) 72 | for row in reader: 73 | pattern, weight = row 74 | weight_mapping[pattern] = weight 75 | return weight_mapping 76 | 77 | 78 | def get_repeat( 79 | src_dir, 80 | weight_mapping, 81 | min_multiply=1, 82 | max_multiply=100, 83 | log_file=None, 84 | logger=None, 85 | ): 86 | if logger is None: 87 | logger = logging.getLogger() 88 | n_images_totol = len(get_images_recursively(src_dir)) 89 | bar = tqdm(total=n_images_totol) 90 | 91 | weight_tree = WeightTree(src_dir, weight_mapping, bar) 92 | 93 | dir_list = [] 94 | prob_list = [] 95 | n_images_list = [] 96 | 97 | weight_tree.compute_sampling_prob(1, dir_list, prob_list, n_images_list) 98 | 99 | probs = np.array(prob_list) 100 | n_images_array = np.array(n_images_list) 101 | per_image_weights = probs / n_images_array 102 | 103 | # This makes the weights larger than 1 104 | per_image_multiply = per_image_weights / np.min(per_image_weights) 105 | per_image_multiply = per_image_multiply * min_multiply 106 | per_image_multiply_final = np.minimum( 107 | np.around(per_image_multiply, 2), max_multiply 108 | ) 109 | 110 | if log_file is not None: 111 | original_handlers = logger.handlers[:] 112 | for handler in original_handlers: 113 | logger.removeHandler(handler) 114 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 115 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 116 | fh = logging.FileHandler(log_file) 117 | fh.setLevel(logging.INFO) 118 | fh.setFormatter(formatter) 119 | logger.addHandler(fh) 120 | # logger.propagate = False 121 | 122 | n_images_total = 0 123 | n_images_virtual_total = 0 124 | 125 | for k in np.argsort(per_image_multiply): 126 | dirname = dir_list[k] 127 | n_images = n_images_list[k] 128 | multiply = per_image_multiply_final[k] 129 | n_images_total += n_images 130 | n_images_virtual_total += n_images * multiply 131 | with open(os.path.join(dirname, "multiply.txt"), "w") as f: 132 | f.write(str(multiply)) 133 | if log_file is not None: 134 | logger.info(dirname) 135 | logger.info(f"sampling probability: {prob_list[k]}") 136 | logger.info(f"number of images: {n_images}") 137 | logger.info(f"original multipy: {per_image_multiply[k]}") 138 | logger.info(f"final multipy: {multiply}\n") 139 | 140 | logger.info(f"Number of images: {n_images_totol}") 141 | logger.info(f"Virtual dataset size: {n_images_virtual_total}") 142 | time.sleep(1) 143 | 144 | if log_file is not None: 145 | logger.handlers = [] # Removing existing handlers 146 | for handler in original_handlers: 147 | logger.addHandler(handler) 148 | # logger.propagate = True 149 | logger.info(f"Number of images: {n_images_totol}") 150 | logger.info(f"Virtual dataset size: {n_images_virtual_total}") 151 | time.sleep(1) 152 | -------------------------------------------------------------------------------- /anime2sd/captioning/__init__.py: -------------------------------------------------------------------------------- 1 | from .captioning import * 2 | from .tagging_basics import * 3 | from .tagging_character import * 4 | from .waifuc_actions import * 5 | from .tag_and_caption import * 6 | -------------------------------------------------------------------------------- /anime2sd/captioning/tagging_basics.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Dict, Union, Optional 3 | 4 | from hbutils.string import singular_form, plural_form 5 | 6 | 7 | def get_all_singular_plural_forms(tags): 8 | """ 9 | Get all singular and plural forms of the 
given tags. 10 | 11 | :param tags: List of tags. 12 | :type tags: list[str] 13 | :return: List of all singular and plural forms of the tags. 14 | :rtype: list[str] 15 | """ 16 | forms = set() 17 | for tag in tags: 18 | forms.add(tag) # Add the original form 19 | sing = singular_form(tag) 20 | forms.add(sing) 21 | plur = plural_form(tag) 22 | forms.add(plur) 23 | return list(forms) 24 | 25 | 26 | def sort_tags(tags, sort_mode): 27 | """ 28 | Sorts tags based on the specified mode. 29 | 30 | :param tags: List or Dictionary of tags. 31 | :param sort_mode: Sorting mode ('original', 'shuffle', 'score'). 32 | :return: Sorted tags. 33 | """ 34 | assert sort_mode in ["original", "shuffle", "score"] 35 | npeople_tags = [] 36 | remaining_tags = [] 37 | 38 | if "solo" in tags: 39 | npeople_tags.append("solo") 40 | 41 | for tag in tags: 42 | if tag == "solo": 43 | continue 44 | if "girls" in tag or "boys" in tag or tag in ["1girl", "1boy"]: 45 | npeople_tags.append(tag) 46 | else: 47 | remaining_tags.append(tag) 48 | 49 | if sort_mode == "score" and isinstance(tags, dict): 50 | # Sorting remaining_tags by score in descending order 51 | remaining_tags = sorted( 52 | remaining_tags, 53 | key=lambda tag: tags[tag], 54 | reverse=True, # Higher scores first 55 | ) 56 | elif sort_mode == "shuffle": 57 | random.shuffle(remaining_tags) 58 | 59 | return npeople_tags + remaining_tags 60 | 61 | 62 | def drop_tags_from_dictionary( 63 | tags: Union[List[str], Dict[str, float]], 64 | kept_tags: List[str], 65 | dropped_tags: Optional[List[str]] = None, 66 | ): 67 | if isinstance(tags, dict): 68 | kept_tags = {tag: value for tag, value in tags.items() if tag in kept_tags} 69 | if dropped_tags is not None: 70 | dropped_tags = { 71 | tag: value for tag, value in tags.items() if tag in dropped_tags 72 | } 73 | return kept_tags, dropped_tags 74 | 75 | 76 | def drop_blacklisted_tags(tags, blacklisted_tags): 77 | """ 78 | Remove blacklisted tags from the list or dictionary of tags. 79 | 80 | :param tags: List or dictionary of tags. 81 | :param blacklisted_tags: Set of blacklisted tags. 82 | :return: List or dictionary of tags after removing blacklisted tags. 83 | """ 84 | # Handle both underscore and whitespace in tags 85 | blacklist = set(tag.replace(" ", "_") for tag in blacklisted_tags) 86 | blacklist_update = set(tag.replace("_", " ") for tag in blacklist) 87 | blacklist.update(blacklist_update) 88 | 89 | if isinstance(tags, dict): 90 | return { 91 | tag: value 92 | for tag, value in tags.items() 93 | if tag not in blacklisted_tags 94 | and tag.replace(" ", "_") not in blacklisted_tags 95 | and tag.replace("_", " ") not in blacklisted_tags 96 | } 97 | elif isinstance(tags, list): 98 | return [ 99 | tag 100 | for tag in tags 101 | if tag not in blacklisted_tags 102 | and tag.replace(" ", "_") not in blacklisted_tags 103 | and tag.replace("_", " ") not in blacklisted_tags 104 | ] 105 | else: 106 | raise ValueError(f"Unsuppored types {type(tags)} for {tags}") 107 | 108 | 109 | def drop_overlap_tags(tags, overlap_tags_dict, check_superword=True): 110 | """ 111 | Removes overlap tags from the list of tags. 112 | 113 | :param tags: List or Dictionary of tags. 114 | :param overlap_tags_dict: Dictionary with overlap tag information. 115 | Assume here to take the underscore format. 116 | :return: A list or dictionary with overlap tags removed. 
117 | """ 118 | # If tags is a dictionary, extract the keys for processing 119 | # and remember to return a dictionary 120 | return_as_dict = False 121 | original_tags = tags 122 | if isinstance(tags, dict): 123 | return_as_dict = True 124 | tags = list(tags.keys()) 125 | 126 | result_tags = [] 127 | tags_underscore = [tag.replace(" ", "_") for tag in tags] 128 | 129 | for tag, tag_ in zip(tags, tags_underscore): 130 | to_remove = False 131 | 132 | # Case 1: If the tag is a key and some of 133 | # the associated values are in tags 134 | if tag_ in overlap_tags_dict: 135 | overlap_values = set(val for val in overlap_tags_dict[tag_]) 136 | if overlap_values.intersection(set(tags_underscore)): 137 | to_remove = True 138 | 139 | if check_superword: 140 | # Checking superword condition separately 141 | for tag_another in tags: 142 | if tag in tag_another and tag != tag_another: 143 | to_remove = True 144 | break 145 | 146 | if not to_remove: 147 | result_tags.append(tag) 148 | 149 | # If the input was a dictionary 150 | # return as a dictionary with the same values 151 | if return_as_dict: 152 | result_tags = {tag: original_tags[tag] for tag in result_tags} 153 | 154 | return result_tags 155 | -------------------------------------------------------------------------------- /anime2sd/classif/__init__.py: -------------------------------------------------------------------------------- 1 | from .classify_characters import * 2 | from .file_utils import * 3 | from .imagewise import * 4 | from .merge_clusters import * 5 | -------------------------------------------------------------------------------- /anime2sd/common_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import logging 5 | from tqdm import tqdm 6 | from typing import List, Dict, Optional 7 | 8 | from anime2sd.basics import ( 9 | get_images_recursively, 10 | get_related_paths, 11 | get_default_metadata, 12 | ) 13 | from anime2sd.waifuc_customize import LocalSource, SaveExporter, TagRenameAction 14 | 15 | 16 | def construct_file_list(src_dir: str): 17 | """ 18 | Construct a list of all files in the directory and checks for duplicates. 19 | 20 | Args: 21 | src_dir (str): The directory to search. 22 | Reurns: 23 | A list of all file paths in the directory. 24 | """ 25 | all_files = {} 26 | for root, _, filenames in os.walk(src_dir): 27 | for filename in filenames: 28 | path = os.path.join(root, filename) 29 | if filename in all_files and filename != "multiply.txt": 30 | raise ValueError(f"Duplicate filename found: {filename}") 31 | all_files[filename] = path 32 | return all_files 33 | 34 | 35 | def rearrange_related_files(src_dir: str, logger: Optional[logging.Logger] = None): 36 | """ 37 | Rearrange related files in some directory. 38 | 39 | Args: 40 | src_dir (src): The directory containing images and other files to rearrange. 41 | logger (Logger): A logger to use for logging. 
42 | """ 43 | if logger is None: 44 | logger = logging.getLogger() 45 | all_files = construct_file_list(src_dir) 46 | image_files = get_images_recursively(src_dir) 47 | 48 | logger.info("Arranging related files ...") 49 | for img_path in tqdm(image_files, desc="Rearranging related files"): 50 | related_paths = get_related_paths(img_path) 51 | for related_path in related_paths: 52 | # If the related file does not exist in the expected location 53 | if not os.path.exists(related_path): 54 | # Search for the file in the all_files dictionary 55 | found_path = all_files.get(os.path.basename(related_path)) 56 | if found_path is None: 57 | if related_path.endswith("json"): 58 | logger.warning(f"No related file found for {related_path}") 59 | meta_data = get_default_metadata(img_path) 60 | with open(related_path, "w") as f: 61 | json.dump(meta_data, f) 62 | else: 63 | # Move the found file to the expected location 64 | shutil.move(found_path, related_path) 65 | logger.info( 66 | f"Moved related file from {found_path} " f"to {related_path}" 67 | ) 68 | 69 | 70 | def load_metadata_from_aux( 71 | src_dir: str, 72 | load_grabber_ext: Optional[str], 73 | load_aux: List[str], 74 | overwrite_path: bool, 75 | character_mapping: Optional[Dict[str, str]], 76 | logger: Optional[logging.Logger] = None, 77 | ) -> None: 78 | """ 79 | Load metadata from auxiliary data and export it with potential modifications. 80 | 81 | This function loads metadata from a source directory, potentially modifies it, 82 | and then saves it back to the same directory. 83 | 84 | Args: 85 | src_dir (str): 86 | The source directory from which to load the metadata. 87 | load_grabber_ext (Optional[str]): 88 | The extension of the grabber information files to be loaded. 89 | load_aux (List[str]): 90 | A list of auxiliary data attributes to be loaded. 91 | overwrite_path (bool): 92 | Flag to indicate if the path in the metadata should be overwritten. 93 | character_mapping (Optional[Dict[str, str]]): 94 | A mapping from old character names to new character names. 95 | logger (Logger): Logger to use for logging. 
96 | """ 97 | if logger is None: 98 | logger = logging.getLogger() 99 | logger.info("Load metadata from auxiliary data ...") 100 | source = LocalSource( 101 | src_dir, 102 | load_grabber_ext=load_grabber_ext, 103 | load_aux=load_aux, 104 | overwrite_path=overwrite_path, 105 | ) 106 | if character_mapping: 107 | # Renaming characters 108 | source = source.attach( 109 | TagRenameAction(character_mapping, fields=["characters"]) 110 | ) 111 | source.export( 112 | SaveExporter( 113 | src_dir, 114 | no_meta=False, 115 | in_place=True, 116 | ) 117 | ) 118 | -------------------------------------------------------------------------------- /anime2sd/emb_utils.py: -------------------------------------------------------------------------------- 1 | """Check embedding names are valid and update embedding information 2 | with initialization text 3 | """ 4 | import os 5 | import json 6 | import logging 7 | from typing import Optional, List, Dict 8 | from transformers import AutoTokenizer 9 | 10 | 11 | # TODO: Separate function that take directly embedding names from those that take 12 | # characters and image type, and deal with outfits embedding etc properly 13 | def update_emb_init_info( 14 | filepath: str, 15 | characters: List[str], 16 | image_types: List[str], 17 | emb_init_dict: Optional[Dict[str, List[str]]] = None, 18 | overwrite: bool = False, 19 | logger: Optional[logging.Logger] = None, 20 | ) -> None: 21 | """ 22 | Updates the JSON file with character names and optionally with embedding 23 | initialization information. 24 | Additionally, checks the validity of embedding names for HCP training. 25 | 26 | Args: 27 | filepath (str): 28 | Path to the trigger word JSON file. 29 | characters (List[str]): 30 | List of character names to add. 31 | image_types (List[str): 32 | Types of the images ("screenshots", "booru", or other). 33 | emb_init_dict (Optional[Dict[str, List[str]]]): 34 | Optional dictionary for embedding initializations. 35 | overwrite (bool): 36 | Whether to overwrite existing JSON content. 37 | logger (Optional[logging.Logger]): 38 | Optional logger to use. Defaults to None, which uses the default logger. 
39 | """ 40 | if logger is None: 41 | logger = logging.getLogger() 42 | if isinstance(image_types, str): 43 | image_types = [image_types] 44 | name_init_map = {} 45 | 46 | # Read existing content if not overwriting 47 | if not overwrite and os.path.exists(filepath): 48 | with open(filepath, "r") as file: 49 | name_init_map = json.load(file) 50 | 51 | # Add characters to the map 52 | for character in characters: 53 | embedding_name = character.split()[0] 54 | if embedding_name not in name_init_map: 55 | name_init_map[embedding_name] = [] 56 | 57 | # Add image_type to the map 58 | for image_type in image_types: 59 | if image_type not in name_init_map: 60 | if image_type == "screenshots": 61 | default_init_text = "anime screencap" 62 | elif image_type == "booru": 63 | default_init_text = "masterpiece" 64 | else: 65 | default_init_text = "" 66 | name_init_map[image_type] = [default_init_text] 67 | 68 | # Update with emb_init_dict 69 | if emb_init_dict: 70 | for emb, tags in emb_init_dict.items(): 71 | if emb in name_init_map: 72 | # Add new tags to the existing list, avoiding duplicates 73 | name_init_map[emb].extend( 74 | [tag for tag in tags if tag not in name_init_map[emb]] 75 | ) 76 | else: 77 | name_init_map[emb] = tags 78 | 79 | # Initialize tokenizer 80 | tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 81 | invalid_embedding_names = [] 82 | for embedding_name in name_init_map.keys(): 83 | if embedding_name.lower() in tokenizer.vocab: 84 | invalid_embedding_names.append(embedding_name) 85 | 86 | # Log warning for invalid embedding names 87 | if invalid_embedding_names: 88 | logger.warning( 89 | "Some embedding names may not be valid for HCP training: " 90 | + ", ".join(invalid_embedding_names) 91 | ) 92 | 93 | # Write the updated content back to the JSON file 94 | with open(filepath, "w") as file: 95 | json.dump(name_init_map, file, indent=4) 96 | -------------------------------------------------------------------------------- /anime2sd/extract_frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import logging 4 | import subprocess 5 | from typing import Optional 6 | 7 | from .basics import parse_anime_info 8 | from .remove_duplicates import DuplicateRemover 9 | 10 | 11 | def check_cuda_availability(logger=None): 12 | if logger is None: 13 | logger = logging.getLogger() 14 | try: 15 | output = subprocess.check_output( 16 | ["ffmpeg", "-hwaccels"], universal_newlines=True 17 | ) 18 | return "cuda" in output 19 | except Exception as e: 20 | logging.warning(f"Error checking CUDA availability: {e}") 21 | return False 22 | 23 | 24 | def get_ffmpeg_command(file, file_pattern, extract_key, logger=None): 25 | if logger is None: 26 | logger = logging.getLogger() 27 | cuda_available = check_cuda_availability(logger) 28 | command = ["ffmpeg"] 29 | 30 | if cuda_available: 31 | command.extend(["-hwaccel", "cuda"]) 32 | else: 33 | logger.warning("CUDA is not available. 
Proceeding without CUDA.") 34 | 35 | command.extend(["-i", file]) 36 | 37 | if extract_key: 38 | command.extend(["-vf", "select='eq(pict_type,I)'", "-vsync", "vfr"]) 39 | else: 40 | command.extend( 41 | [ 42 | "-filter:v", 43 | "mpdecimate=hi=64*200:lo=64*50:frac=0.33,setpts=N/FRAME_RATE/TB", 44 | ] 45 | ) 46 | 47 | command.extend(["-qscale:v", "1", "-qmin", "1", "-c:a", "copy", file_pattern]) 48 | 49 | return command 50 | 51 | 52 | def extract_and_remove_similar( 53 | src_dir: str, 54 | dst_dir: str, 55 | prefix: Optional[str] = None, 56 | ep_init: Optional[int] = None, 57 | extract_key: bool = False, 58 | duplicate_remover: Optional[DuplicateRemover] = None, 59 | logger: Optional[logging.Logger] = None, 60 | ) -> None: 61 | """ 62 | Extracts frames from video files in the specified source directory, 63 | saves them to the destination directory, and optionally removes similar frames. 64 | 65 | The function supports multiple video file formats such as mp4, mkv, avi, etc. 66 | It uses FFmpeg to extract frames from videos. 67 | If a `DuplicateRemover` instance is provided, it removes similar frames within each 68 | episode's directory and across the entire source directory (for opening and ending). 69 | 70 | Args: 71 | src_dir (str): 72 | The directory containing source video files. 73 | dst_dir (str): 74 | The directory where extracted frames will be saved. 75 | prefix (Optional[str]): 76 | A prefix to add to the names of extracted frames. 77 | Defaults to None in which case prefix if inferred from file name. 78 | ep_init (Optional[int]): 79 | An initial episode number to start from for naming the extracted frames. 80 | Defaults to None in which case episode number is inferred from file name. 81 | extract_key (bool): 82 | Flag indicating whether to extract only key frames. 83 | Defaults to False. 84 | duplicate_remover (Optional[DuplicateRemover]): 85 | An instance of DuplicateRemover to remove duplicate frames. 86 | Defaults to None in which case no duplicate removal is performed. 87 | logger (Optional[logging.Logger]): 88 | A logger for logging messages. 89 | Defaults to None in which case a default logger is used. 
90 | """ 91 | if logger is None: 92 | logger = logging.getLogger() 93 | # Supported video file extensions 94 | video_extensions = [".mp4", ".mkv", ".avi", ".flv", ".mov", ".wmv"] 95 | 96 | # Recursively find all video files in the specified 97 | # source directory and its subdirectories 98 | files = [ 99 | os.path.join(root, file) 100 | for root, dirs, files in os.walk(src_dir) 101 | for file in files 102 | if os.path.splitext(file)[1] in video_extensions 103 | ] 104 | 105 | # Loop through each file 106 | for i, file in enumerate(sorted(files)): 107 | # Extract the filename without extension 108 | filename_without_ext = os.path.splitext(os.path.basename(file))[0] 109 | 110 | # Extract the anime name and episode number 111 | anime_name, ep_num = parse_anime_info(filename_without_ext) 112 | anime_name = "_".join(re.split(r"\s+", anime_name)) 113 | prefix_anime = f"{prefix if isinstance(prefix, str) else anime_name}_" 114 | if isinstance(ep_init, int): 115 | ep_num = i + ep_init 116 | elif ep_num is None: 117 | ep_num = i 118 | 119 | # Create the output directory 120 | dst_ep_dir = os.path.join(dst_dir, filename_without_ext) 121 | os.makedirs(dst_ep_dir, exist_ok=True) 122 | file_pattern = os.path.join(dst_ep_dir, f"{prefix_anime}EP{ep_num}_%d.png") 123 | 124 | # Run ffmpeg on the file, saving the output to the output directory 125 | ffmpeg_command = get_ffmpeg_command(file, file_pattern, extract_key, logger) 126 | logger.info(ffmpeg_command) 127 | subprocess.run(ffmpeg_command, check=True) 128 | 129 | if duplicate_remover is not None: 130 | duplicate_remover.remove_similar_from_dir(dst_ep_dir) 131 | 132 | # Go through all files again to remove duplicates from op and ed 133 | if duplicate_remover is not None: 134 | duplicate_remover.remove_similar_from_dir(dst_dir, portion="first") 135 | duplicate_remover.remove_similar_from_dir(dst_dir, portion="last") 136 | -------------------------------------------------------------------------------- /configs/csv_examples/character_mapping_example.csv: -------------------------------------------------------------------------------- 1 | terakomari_gandezblood,Terakomari 2 | villhaze,Villhaze 3 | amatsu_karla,Amatsu 4 | karen_helvetius,KarenH 5 | nelia_cunningham,Nelia 6 | sakuna_memoire,Sakuna 7 | millicent_bluenight,Millicent 8 | flote_mascarail,Mascarail 9 | -------------------------------------------------------------------------------- /configs/csv_examples/character_mapping_example_no_underscore.csv: -------------------------------------------------------------------------------- 1 | aoba kokona,AobaKokona 2 | kuraue hinata,KuraueHinata 3 | kurosaki honoka,KurosakiHonoka 4 | saitou kaede (yama no susume),SaitoKaede 5 | yukimura aoi,YukimuraAoi 6 | senjuin koharu,SenjuinKoharu 7 | onozuka hikari,OnozukaHikari 8 | yukimura megumi,YukimuraMegumi 9 | suu (yama no susume), 10 | -------------------------------------------------------------------------------- /configs/csv_examples/default_weighting.csv: -------------------------------------------------------------------------------- 1 | screenshots,3 2 | booru,2 3 | regularization,5 4 | 0_characters,0.5 5 | 1_character,1.5 6 | 2_characters,1 7 | 2+_characters,1 8 | 3_characters,0.8 9 | 3+_characters,0.8 10 | 4_characters,0.6 11 | 4+_characters,0.6 12 | 5_characters,0.6 13 | 5+_characters,0.6 14 | 6+_characters,0.6 15 | character_others,0.8 16 | -------------------------------------------------------------------------------- /configs/csv_examples/embedding_names_example.csv: 
-------------------------------------------------------------------------------- 1 | ADol, 2 | Airi, 3 | Akira, 4 | Alice, 5 | Hirokazu, 6 | Izumi, 7 | Kaworu, 8 | Kazeharu, 9 | Makoto, 10 | Manami, 11 | Mayuri, 12 | Seria, 13 | Tomoko, 14 | xclmsA, 15 | xclsB, 16 | screenshots,anime screencap 17 | fanart,masterpiece 18 | -------------------------------------------------------------------------------- /configs/csv_examples/weighting_example.csv: -------------------------------------------------------------------------------- 1 | 1_character,2 2 | 2_characters,2 3 | 3_characters,1 4 | 4_characters,0.8 5 | 5_characters,0.6 6 | 6+_characters,0.4 7 | character_others,0.8 8 | others,0.4 9 | */face_height_ratio_0-25,1 10 | */face_height_ratio_25-50,1.2 11 | */face_height_ratio_25-75,0.8 12 | */face_height_ratio_75-100,0.3 13 | KurosakiHonoka+YukimuraAoi,2 14 | AobaKokona+YukimuraAoi,1.2 15 | AobaKokona+KurosakiHonoka,1.5 16 | YukimuraAoi,1.2 17 | KuraueHinata,1.2 18 | AobaKokona,1.2 19 | SenjuinKoharu,0.8 20 | YukimuraAoi Twintail,0.6 21 | KuraueHinata Hairdown,0.6 22 | YuriYama,0.6 23 | MioYama,0.6 24 | OnozukaHikari,0.6 25 | TakekanaKasumi,0.6 26 | SasaharaYuka,0.6 27 | Suusan,0.4 28 | KuraueMai,0.4 29 | KuraueKenichi,0.4 30 | YukimuraMakoto,0.4 31 | YukimuraMegumi,0.4 32 | KurosakiTaiki,0.3 33 | KuraueHinata+YukimuraAoi,2 34 | AobaKokona+SaitoKaede+YukimuraAoi,1.2 35 | AobaKokona+KuraueHinata+YukimuraAoi,1.6 36 | KuraueHinata+SaitoKaede+YukimuraAoi,1.2 37 | MioYama+TakekanaKasumi+YuriYama,1.2 38 | AobaKokona+KuraueHinata+SaitoKaede+YukimuraAoi,2 39 | AobaKokona+KuraueHinata+KurosakiHonoka+YukimuraAoi,1.4 40 | -------------------------------------------------------------------------------- /configs/hcp/caption.txt: -------------------------------------------------------------------------------- 1 | {caption} 2 | -------------------------------------------------------------------------------- /configs/hcp/dataset.yaml: -------------------------------------------------------------------------------- 1 | config_dir: 'configs/hcp' 2 | dataset_dir: 'data' 3 | 4 | data: 5 | dataset1: 6 | _target_: hcpdiff.data.TextImagePairDataset 7 | _partial_: True # Not directly instantiate the object here. There are other parameters to be added in the runtime. 
8 | batch_size: 8 9 | cache_latents: True 10 | att_mask_encode: False 11 | loss_weight: 1.0 12 | cache_path: '${dataset_dir}/latent_cache.pth' 13 | 14 | source: 15 | data_source_1: 16 | _target_: hcpdiff.data.source.Text2ImageAttMapSource 17 | img_root: 'imgs/' 18 | repeat: 1 19 | prompt_template: '${config_dir}/caption.txt' 20 | caption_file: 21 | _target_: hcpdiff.data.TXTCaptionLoader 22 | path: 'imgs/' 23 | 24 | att_mask: null 25 | bg_color: [ 255, 255, 255 ] # RGB; for ARGB -> RGB 26 | 27 | word_names: {} 28 | 29 | text_transforms: 30 | _target_: torchvision.transforms.Compose 31 | transforms: 32 | # - _target_: hcpdiff.utils.caption_tools.TagDropout 33 | # p: 0.1 34 | # - _target_: hcpdiff.utils.caption_tools.TagShuffle 35 | - _target_: hcpdiff.utils.caption_tools.TemplateFill 36 | word_names: ${....word_names} 37 | 38 | bucket: 39 | # _target_: hcpdiff.data.bucket.RatioBucket.from_files # the buckets are automatically selected but this would require recaching latents when dataset changes 40 | _target_: hcpdiff.data.bucket.RatioBucket.from_ratios # aspect ratio bucket with fixed ratios 41 | target_area: ${hcp.eval:"512*512"} 42 | num_bucket: 10 43 | pre_build_bucket: '${dataset_dir}/bucket_cache.pkl' 44 | -------------------------------------------------------------------------------- /configs/hcp/diag_oft.yaml: -------------------------------------------------------------------------------- 1 | exp_dir_base: 'hcp_exps' 2 | config_dir: 'configs/hcp' 3 | emb_dir: 'embs' 4 | emb_lr: 1e-2 5 | 6 | _base_: 7 | - ${config_dir}/train_base.yaml 8 | - ${config_dir}/dataset.yaml 9 | 10 | exp_dir: ${exp_dir_base}/${hcp.time:} 11 | 12 | model: 13 | pretrained_model_name_or_path: 'deepghs/animefull-latest' # JosephusCheung/ACertainty, Crosstyan/BPModel 14 | tokenizer_repeats: 1 15 | clip_skip: 1 16 | ema_unet: 0 17 | ema_text_encoder: 0 18 | 19 | train: 20 | train_steps: 50000 21 | save_step: 5000 22 | gradient_accumulation_steps: 1 23 | 24 | scheduler: 25 | name: 'constant_with_warmup' 26 | num_warmup_steps: 1000 27 | num_training_steps: 50000 28 | 29 | unet: null 30 | text_encoder: null 31 | lora_unet: null 32 | lora_text_encoder: null 33 | 34 | plugin_unet: 35 | diag_oft: 36 | _target_: lycoris.hcp.DiagOFTBlock.wrap_model 37 | _partial_: True 38 | rescaled: True 39 | lr: 2e-5 40 | dim: 32 41 | layers: 42 | - 're:.*\.attn.?$' 43 | - 're:.*\.ff$' 44 | 45 | plugin_TE: null 46 | # diag_oft: 47 | # _target_: lycoris.hcp.DiagOFTBlock.wrap_model 48 | # _partial_: True 49 | # rescaled: True 50 | # lr: 5e-7 51 | # dim: 32 52 | # layers: 53 | # - 're:.*self_attn$' 54 | # - 're:.*mlp$' 55 | 56 | tokenizer_pt: 57 | emb_dir: '${emb_dir}' 58 | replace: False 59 | train: 60 | - name: pt1 61 | lr: ${emb_lr} 62 | 63 | logger: 64 | - _target_: hcpdiff.loggers.CLILogger 65 | _partial_: True 66 | out_path: 'train.log' 67 | log_step: 20 68 | enable_log_image: False 69 | - _target_: hcpdiff.loggers.TBLogger 70 | _partial_: True 71 | out_path: 'tblog/' 72 | log_step: 5 73 | enable_log_image: False 74 | -------------------------------------------------------------------------------- /configs/hcp/loha.yaml: -------------------------------------------------------------------------------- 1 | exp_dir_base: 'hcp_exps' 2 | config_dir: 'configs/hcp' 3 | emb_dir: 'embs' 4 | emb_lr: 1e-2 5 | 6 | _base_: 7 | - ${config_dir}/train_base.yaml 8 | - ${config_dir}/dataset.yaml 9 | 10 | exp_dir: ${exp_dir_base}/${hcp.time:} 11 | 12 | model: 13 | pretrained_model_name_or_path: 'deepghs/animefull-latest' # JosephusCheung/ACertainty, 
Crosstyan/BPModel 14 | tokenizer_repeats: 1 15 | clip_skip: 1 16 | ema_unet: 0 17 | ema_text_encoder: 0 18 | 19 | train: 20 | train_steps: 50000 21 | save_step: 5000 22 | gradient_accumulation_steps: 1 23 | 24 | scheduler: 25 | name: 'constant_with_warmup' 26 | num_warmup_steps: 1000 27 | num_training_steps: 50000 28 | 29 | unet: null 30 | text_encoder: null 31 | lora_unet: null 32 | lora_text_encoder: null 33 | 34 | plugin_unet: 35 | loha: 36 | _target_: lycoris.hcp.LohaBlock.wrap_model 37 | _partial_: True 38 | lr: 2e-4 39 | dim: 8 40 | alpha: 4 41 | layers: 42 | - 're:.*\.attn.?$' 43 | - 're:.*\.ff$' 44 | 45 | plugin_TE: null 46 | # loha: 47 | # _target_: lycoris.hcp.LohaBlock.wrap_model 48 | # _partial_: True 49 | # lr: 6e-5 50 | # dim: 4 51 | # alpha: 1 52 | # layers: 53 | # - 're:.*self_attn$' 54 | # - 're:.*mlp$' 55 | 56 | tokenizer_pt: 57 | emb_dir: '${emb_dir}' 58 | replace: False 59 | train: 60 | - name: pt1 61 | lr: ${emb_lr} 62 | 63 | logger: 64 | - _target_: hcpdiff.loggers.CLILogger 65 | _partial_: True 66 | out_path: 'train.log' 67 | log_step: 20 68 | enable_log_image: False 69 | - _target_: hcpdiff.loggers.TBLogger 70 | _partial_: True 71 | out_path: 'tblog/' 72 | log_step: 5 73 | enable_log_image: False 74 | -------------------------------------------------------------------------------- /configs/hcp/lokr.yaml: -------------------------------------------------------------------------------- 1 | exp_dir_base: 'hcp_exps' 2 | config_dir: 'configs/hcp' 3 | emb_dir: 'embs' 4 | emb_lr: 1e-2 5 | 6 | _base_: 7 | - ${config_dir}/train_base.yaml 8 | - ${config_dir}/dataset.yaml 9 | 10 | exp_dir: ${exp_dir_base}/${hcp.time:} 11 | 12 | model: 13 | pretrained_model_name_or_path: 'deepghs/animefull-latest' # JosephusCheung/ACertainty, Crosstyan/BPModel 14 | tokenizer_repeats: 1 15 | clip_skip: 1 16 | ema_unet: 0 17 | ema_text_encoder: 0 18 | 19 | train: 20 | train_steps: 50000 21 | save_step: 5000 22 | gradient_accumulation_steps: 1 23 | 24 | scheduler: 25 | name: 'constant_with_warmup' 26 | num_warmup_steps: 1000 27 | num_training_steps: 50000 28 | 29 | unet: null 30 | text_encoder: null 31 | lora_unet: null 32 | lora_text_encoder: null 33 | 34 | plugin_unet: 35 | lokr: 36 | _target_: lycoris.hcp.LokrBlock.wrap_model 37 | _partial_: True 38 | lr: 2e-4 39 | dim: 10000 40 | alpha: 0 41 | factor: 8 42 | layers: 43 | - 're:.*\.attn.?$' 44 | - 're:.*\.ff$' 45 | 46 | plugin_TE: null 47 | # lokr: 48 | # _target_: lycoris.hcp.LokrBlock.wrap_model 49 | # _partial_: True 50 | # lr: 2e-5 51 | # dim: 10000 52 | # alpha: 0 53 | # factor: 8 54 | # layers: 55 | # - 're:.*self_attn$' 56 | # - 're:.*mlp$' 57 | 58 | tokenizer_pt: 59 | emb_dir: '${emb_dir}' 60 | replace: False 61 | train: 62 | - name: pt1 63 | lr: ${emb_lr} 64 | 65 | logger: 66 | - _target_: hcpdiff.loggers.CLILogger 67 | _partial_: True 68 | out_path: 'train.log' 69 | log_step: 20 70 | enable_log_image: False 71 | - _target_: hcpdiff.loggers.TBLogger 72 | _partial_: True 73 | out_path: 'tblog/' 74 | log_step: 5 75 | enable_log_image: False 76 | -------------------------------------------------------------------------------- /configs/hcp/lora_conventional.yaml: -------------------------------------------------------------------------------- 1 | exp_dir_base: 'hcp_exps' 2 | config_dir: 'configs/hcp' 3 | emb_dir: 'embs' 4 | emb_lr: 1e-2 5 | 6 | _base_: 7 | - ${config_dir}/train_base.yaml 8 | - ${config_dir}/dataset.yaml 9 | 10 | exp_dir: ${exp_dir_base}/${hcp.time:} 11 | 12 | model: 13 | pretrained_model_name_or_path: 
'deepghs/animefull-latest' # JosephusCheung/ACertainty, Crosstyan/BPModel 14 | tokenizer_repeats: 1 15 | clip_skip: 1 16 | ema_unet: 0 17 | ema_text_encoder: 0 18 | 19 | train: 20 | train_steps: 50000 21 | save_step: 5000 22 | gradient_accumulation_steps: 1 23 | 24 | scheduler: 25 | name: 'constant_with_warmup' 26 | num_warmup_steps: 1000 27 | num_training_steps: 50000 28 | 29 | unet: null 30 | text_encoder: null 31 | plugin_unet: null 32 | plugin_TE: null 33 | 34 | lora_unet: 35 | - lr: 2e-4 36 | rank: 16 37 | alpha: 8 38 | layers: 39 | - 're:.*\.attn.?$' 40 | - 're:.*\.ff$' 41 | 42 | lora_text_encoder: null 43 | # - lr: 6e-5 44 | # rank: 8 45 | # alpha: 4 46 | # layers: 47 | # - 're:.*self_attn$' 48 | # - 're:.*mlp$' 49 | 50 | tokenizer_pt: 51 | emb_dir: '${emb_dir}' 52 | replace: False 53 | train: 54 | - name: pt1 55 | lr: ${emb_lr} 56 | 57 | logger: 58 | - _target_: hcpdiff.loggers.CLILogger 59 | _partial_: True 60 | out_path: 'train.log' 61 | log_step: 20 62 | enable_log_image: False 63 | - _target_: hcpdiff.loggers.TBLogger 64 | _partial_: True 65 | out_path: 'tblog/' 66 | log_step: 5 67 | enable_log_image: False 68 | -------------------------------------------------------------------------------- /configs/hcp/text2img.yaml: -------------------------------------------------------------------------------- 1 | # base_state*base_model_alpha + (lora_state[i]*lora_scale[i]*lora_alpha[i]) + (part_state[k]*part_alpha[k]) 2 | exp_dir: 'exps/2023-07-26-01-05-35' # experiment directory 3 | model_steps: 1000 # steps of selected model 4 | emb_dir: '${exp_dir}/ckpts/' 5 | 6 | pretrained_model: 'deepghs/animefull-latest' 7 | prompt: '' 8 | neg_prompt: 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry' 9 | N_repeats: 1 10 | clip_skip: 1 11 | clip_final_norm: True 12 | bs: 1 13 | num: 1 14 | seed: null 15 | dtype: 'fp16' 16 | 17 | condition: null 18 | 19 | ex_input: {} 20 | 21 | # Syntactic sugar for interface 22 | save: 23 | out_dir: 'output/' 24 | save_cfg: True 25 | image_type: png 26 | quality: 95 27 | # image_type: webp 28 | # quality: 75 29 | 30 | offload: null 31 | 32 | #vae_optimize: null 33 | vae_optimize: 34 | tiling: False 35 | slicing: False 36 | 37 | interface: 38 | - _target_: hcpdiff.vis.DiskInterface 39 | show_steps: 0 40 | save_root: ${save.out_dir} 41 | save_cfg: ${save.save_cfg} 42 | image_type: ${save.image_type} 43 | quality: ${save.quality} 44 | 45 | infer_args: 46 | width: 512 47 | height: 640 48 | guidance_scale: 7.5 49 | num_inference_steps: 25 50 | 51 | new_components: 52 | scheduler: 53 | _target_: diffusers.EulerAncestralDiscreteScheduler # change Sampler 54 | beta_start: 0.00085 55 | beta_end: 0.012 56 | beta_schedule: 'scaled_linear' 57 | 58 | merge: 59 | plugin_cfg: loha.yaml 60 | alpha: 1 61 | 62 | group_unet: 63 | type: 'unet' 64 | base_model_alpha: 1.0 65 | plugin: 66 | loha: 67 | path: '${.....exp_dir}/ckpts/unet-${.....model_steps}.safetensors' 68 | alpha: ${....alpha} 69 | layers: 'all' 70 | 71 | group_TE: 72 | type: 'TE' 73 | base_model_alpha: 1.0 74 | plugin: 75 | loha: 76 | path: '${.....exp_dir}/ckpts/text_encoder-${.....model_steps}.safetensors' 77 | alpha: ${....alpha} 78 | layers: 'all' 79 | -------------------------------------------------------------------------------- /configs/hcp/train_base.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: 
hcp_exps/${hcp.time:} 2 | mixed_precision: 'bf16' 3 | allow_tf32: False 4 | seed: 114514 5 | ckpt_type: 'safetensors' # [torch, safetensors] 6 | 7 | vis_info: 8 | prompt: null 9 | negative_prompt: '' 10 | 11 | train: 12 | train_steps: 50000 13 | train_epochs: null # Choose one of [train_steps, train_epochs] 14 | gradient_accumulation_steps: 1 15 | workers: 4 16 | max_grad_norm: 1.0 17 | set_grads_to_none: False 18 | save_step: 5000 19 | cfg_scale: '1.0' # for DreamArtist 20 | 21 | resume: null 22 | # resume: 23 | # ckpt_path: 24 | # unet: [] 25 | # TE: [] 26 | # words: {} 27 | # start_step: 0 28 | 29 | loss: 30 | criterion: 31 | _target_: torch.nn.MSELoss 32 | _partial_: True 33 | reduction: 'none' # support for attention mask 34 | type: 'eps' # 'eps' or 'sample' 35 | 36 | optimizer: 37 | _target_: bitsandbytes.optim.AdamW8bit 38 | _partial_: True 39 | weight_decay: 1e-3 40 | 41 | optimizer_pt: 42 | _target_: bitsandbytes.optim.AdamW8bit 43 | _partial_: True 44 | weight_decay: 5e-4 45 | 46 | scale_lr: False # auto scale lr with total batch size 47 | scheduler: 48 | name: 'constant_with_restart' 49 | num_warmup_steps: 1000 50 | num_training_steps: 50000 51 | scheduler_kwargs: {} # args for scheduler 52 | 53 | scale_lr_pt: False 54 | scheduler_pt: ${.scheduler} 55 | 56 | model: 57 | revision: null 58 | pretrained_model_name_or_path: null 59 | tokenizer_repeats: 2 60 | enable_xformers: True 61 | gradient_checkpointing: False 62 | force_cast_precision: False 63 | ema: null 64 | clip_skip: 0 65 | clip_final_norm: True 66 | 67 | tokenizer: null 68 | noise_scheduler: null 69 | unet: null 70 | text_encoder: null 71 | vae: null 72 | 73 | previewer: null 74 | -------------------------------------------------------------------------------- /configs/pipelines/booru.toml: -------------------------------------------------------------------------------- 1 | # General Configuration 2 | [general] 3 | # Directory containing source files 4 | src_dir = "data/intermediate/booru/raw" 5 | # Directory to save output files 6 | dst_dir = "data" 7 | # Extra path component to add between dst_dir/[training|intermediate] and image type 8 | extra_path_component = "" 9 | # Stage number or alias to start from 10 | start_stage = 0 11 | # Stage number or alias to end at 12 | end_stage = 7 13 | # Directory to save logs. Set to None or none to disable. 14 | log_dir = "logs" 15 | # Prefix for log files, defaults to --anime_name if provided otherwise 'logfile' 16 | log_prefix = {} 17 | # Pipeline type that is used to construct dataset 18 | pipeline_type = "booru" 19 | # Image type that we are dealing with, used for folder name and might appear in caption as well. Default to --pipeline_type. 20 | image_type = "booru" 21 | 22 | # Metadata Loading and Saving 23 | [metadata_handling] 24 | # Extension of the grabber information files to load. Attributes from this file would overwrite those loaded from --load_aux. 
25 | load_grabber_ext = ".tags" 26 | 27 | # Configuration for downloading images from Danbooru 28 | [booru_download] 29 | # The anime name used for downloading images from Danbooru 30 | # Set to {} or comment this line to use --anime_name 31 | anime_name_booru = {} 32 | # Path to CSV file containing character mapping information, used for renaming characters and potentially for downloading as well 33 | # Set to {} or comment this line to disable 34 | character_info_file = {} 35 | # character_info_file = "configs/csv_examples/character_mapping_example.csv" 36 | # Indicates whether to attempt downloading for all characters in the character info file 37 | download_for_characters = true 38 | # Limit on the total number of images to download from Danbooru 39 | # Set to 0 for no limit or specify a number 40 | booru_download_limit = 0 41 | # Limit on the number of images to download per character from Danbooru 42 | # Set to 0 for no limit or specify a number 43 | # Note that for efficiency if both booru_download_limit and booru_download_limit_per_character are set, 44 | # we are not guaranteed to download booru_download_limit number of images 45 | booru_download_limit_per_character = 500 46 | # List of allowed ratings for filtering images, set to empty list to disable 47 | allowed_ratings = [] 48 | # List of allowed classes for filtering images, set to empty list to disable 49 | allowed_image_classes = ["illustration", "bangumi"] 50 | # Maximum size for the smaller dimension of too large downloaded images to resize to 51 | max_download_size = 1024 52 | 53 | # Character Cropping Configuration 54 | [character_cropping] 55 | # Use 3 stage crop to get halfbody and head crops 56 | use_3stage_crop = 2 57 | 58 | # Character Clustering/Classification Configuration 59 | [character_classification] 60 | # Directory containing reference character images 61 | character_ref_dir = "data/ref_images" 62 | # The number of additional reference images to add to each character from classification result" 63 | n_add_to_ref_per_character = 20 64 | # Whether to ignore existing character metadata during classification 65 | ignore_character_metadata = false 66 | # Minimum cluster samples in character clustering 67 | cluster_min_samples = 5 68 | # Whether to keep unnamed clusters when reference images are provided or when characters are available in metadata 69 | keep_unnamed_clusters = false 70 | # Whether we try to attribute label when multiple candidates are available when performing classification with metadata character information 71 | # This typically coressponds to the case where we have one character that always appear with another specific character, 72 | # or to some specific form of a character that is recognized as character tag in Danbooru 73 | accept_multiple_candidates = false 74 | 75 | # Dataset Construction Configuration 76 | [dataset_construction] 77 | # Overwrite existing character metadata for uncropped images 78 | character_overwrite_uncropped = false 79 | # Remove unclassified characters in the character metadata field 80 | character_remove_unclassified = false 81 | 82 | # Tagging Configuration 83 | [tagging] 84 | # Whether to overwrite existing tags 85 | overwrite_tags = true 86 | 87 | # General Tag Processing Configuration 88 | [tag_processing] 89 | # Mode to sort the tags 90 | sort_mode = "score" 91 | # Whether to append dropped character tags to the caption 92 | append_dropped_character_tags = false 93 | # Max number of tags to include in caption 94 | max_tag_number = 30 95 | # Process tags from 
original tags instead of processed tags 96 | process_from_original_tags = true 97 | # Different ways to prune tags 98 | prune_mode = "character_core" 99 | 100 | # Folder Organization Configuration 101 | [folder_organization] 102 | # Description of the concept balancing directory hierarchy 103 | arrange_format = "n_characters/character" 104 | # If have more than X characters put X+ 105 | max_character_number = 2 106 | # Put others instead of character name if number of images of the character combination is smaller than this number 107 | min_images_per_combination = 10 108 | 109 | # Balancing Configuration 110 | [balancing] 111 | # Minimum multiply of each image 112 | min_multiply = 1 113 | # Maximum multiply of each image 114 | max_multiply = 100 115 | # If provided use the provided csv to modify weights 116 | weight_csv = "configs/csv_examples/default_weighting.csv" 117 | -------------------------------------------------------------------------------- /configs/pipelines/screenshots.toml: -------------------------------------------------------------------------------- 1 | # General Configuration 2 | [general] 3 | # Directory containing source files 4 | src_dir = "data/intermediate/screenshots/animes" 5 | # Directory to save output files 6 | dst_dir = "data" 7 | # Extra path component to add between dst_dir/[training|intermediate] and image type 8 | extra_path_component = "" 9 | # Stage number or alias to start from 10 | start_stage = 0 11 | # Stage number or alias to end at 12 | end_stage = 7 13 | # Directory to save logs. Set to None or none to disable. 14 | log_dir = "logs" 15 | # Prefix for log files, defaults to --anime_name if provided otherwise 'logfile' 16 | log_prefix = {} 17 | # Pipeline type that is used to construct dataset 18 | pipeline_type = "screenshots" 19 | # Image type that we are dealing with, used for folder name and might appear in caption as well. Default to --pipeline_type. 
20 | image_type = "screenshots" 21 | 22 | # Configuration for downloading animes from nyaa.si 23 | [nyaa_download] 24 | # The anime name used for downloading animes from nyaa.si 25 | anime_name = "my_favorite_anime" 26 | # The candidate submitters used for downloading animes from nyaa.si 27 | candidate_submitters = ["Erai-raws", "SubsPlease", "CameEsp", "ohys"] 28 | # The resolution of anime to download 29 | anime_resolution = 720 30 | # The minimum episode to download 31 | # Set to {} or comment this line to disable 32 | min_download_episode = {} 33 | # The maximum episode to download 34 | # Set to {} or comment this line to disable 35 | max_download_episode = {} 36 | 37 | # Video Extraction Configuration 38 | [video_extraction] 39 | # Only extract key frames 40 | extract_key = true 41 | # Output image prefix, when not provided we try to infer it from video file name 42 | image_prefix = {} 43 | # Episode number to start with, when not provided we try to infer it from video file name 44 | ep_init = {} 45 | 46 | # Character Cropping Configuration 47 | [character_cropping] 48 | # Use 3 stage crop to get halfbody and head crops 49 | # Set to {} or comment this line to disable 50 | use_3stage_crop = {} 51 | 52 | # Character Clustering/Classification Configuration 53 | [character_classification] 54 | # Directory containing reference character images 55 | character_ref_dir = "data/ref_images" 56 | # The number of additional reference images to add to each character from classification result" 57 | n_add_to_ref_per_character = 0 58 | # Whether to keep unnamed clusters when reference images are provided or when characters are available in metadata 59 | keep_unnamed_clusters = false 60 | 61 | # Dataset Construction Configuration 62 | [dataset_construction] 63 | # Number of images with no characters to keep 64 | n_anime_reg = 500 65 | 66 | # Tagging Configuration 67 | [tagging] 68 | # Whether to overwrite existing tags 69 | overwrite_tags = false 70 | 71 | # General Tag Processing Configuration 72 | [tag_processing] 73 | # Mode to sort the tags 74 | sort_mode = "score" 75 | # Whether to append dropped character tags to the caption 76 | append_dropped_character_tags = false 77 | # Max number of tags to include in caption 78 | max_tag_number = 30 79 | # Process tags from original tags instead of processed tags 80 | process_from_original_tags = true 81 | # Different ways to prune tags 82 | prune_mode = "character_core" 83 | 84 | # Folder Organization Configuration 85 | [folder_organization] 86 | # Description of the concept balancing directory hierarchy 87 | arrange_format = "n_characters/character" 88 | # If have more than X characters put X+ 89 | max_character_number = 3 90 | # Put others instead of character name if number of images of the character combination is smaller than this number 91 | min_images_per_combination = 10 92 | 93 | # Balancing Configuration 94 | [balancing] 95 | # Minimum multiply of each image 96 | min_multiply = 1 97 | # Maximum multiply of each image 98 | max_multiply = 100 99 | # If provided use the provided csv to modify weights 100 | weight_csv = "configs/csv_examples/default_weighting.csv" 101 | -------------------------------------------------------------------------------- /configs/tag_filtering/blacklist_tags.txt: -------------------------------------------------------------------------------- 1 | alternate_costume 2 | alternate_breast_size 3 | alternate_color 4 | alternate_costume 5 | alternate_eye_color 6 | alternate_form 7 | alternate_hair_color 8 | 
9 | alternate_hairstyle
10 | alternate_headwear
11 | alternate_legwear
12 | alternate_pectoral_size
13 | alternate_skin_color
14 | alternate_universe
15 | alternate_weapon
16 | alternate_wings
17 | cosplay
18 | casual
19 | adapted_costume
20 | contemporary
21 | bimbofication
22 | aged_down
23 | aged_up
24 | no_eyewear
25 | age_progression
26 | genderswap
27 | genderswap_(mtf)
28 | genderswap_(ftm)
29 | crossdressing
30 | parody
31 | virtual_youtuber
32 | feet_out_of_frame
33 | head_out_of_frame
34 | 
-------------------------------------------------------------------------------- /configs/tag_filtering/character_tags.json: --------------------------------------------------------------------------------
1 | {
2 |     "whitelist": [
3 |         "drill",
4 |         "pubic hair",
5 |         "closed eyes",
6 |         "half-closed eyes",
7 |         "empty eyes",
8 |         "fake tail"
9 |     ],
10 |     "suffixes": [
11 |         [
12 |             "man",
13 |             "woman",
14 |             "eyes",
15 |             "skin",
16 |             "hair",
17 |             "bun",
18 |             "bangs",
19 |             "cut",
20 |             "sidelocks",
21 |             "twintails",
22 |             "braid",
23 |             "braids",
24 |             "afro",
25 |             "ahoge",
26 |             "drill",
27 |             "bald",
28 |             "dreadlocks",
29 |             "side up",
30 |             "ponytail",
31 |             "updo",
32 |             "beard",
33 |             "mustache",
34 |             "goatee",
35 |             "hair intake",
36 |             "thick eyebrows",
37 |             "otoko no ko",
38 |             "bishounen",
39 |             "short hair with long locks",
40 |             "one eye covered"
41 |         ],
42 |         [
43 |             "fang",
44 |             "mark",
45 |             "freckles",
46 |             "elf",
47 |             "ear",
48 |             "horn",
49 |             "fur",
50 |             "halo",
51 |             "wings",
52 |             "heterochromia",
53 |             "tail",
54 |             "animal ear fluff",
55 |             "girl",
56 |             "boy"
57 |         ]
58 |     ],
59 |     "prefixes": [
60 |         [
61 |             "hair over",
62 |             "hair between",
63 |             "dark-skinned",
64 |             "mature",
65 |             "old"
66 |         ],
67 |         [
68 |             "mole",
69 |             "scar",
70 |             "furry",
71 |             "muscular"
72 |         ]
73 |     ]
74 | }
75 | 
-------------------------------------------------------------------------------- /docs/Character_ref_organization.md: --------------------------------------------------------------------------------
1 | # Organization of the Character Reference Directory
2 | 
3 | The character reference directory should contain subfolders with character images. Moreover, it can be organized in a hierarchical way following the convention of `character/appearance/outfits/accessories/objects/extras`, as in the following example.
4 | 
5 | 
6 | ```
7 | .
8 | ├── Noise
9 | ├── KarenH
10 | ├── Mascarail
11 | ├── Melakonsi
12 | ├── Melca
13 | ├── Millicent
14 | ├── Sakuna
15 | ├── Terakomari
16 | │   ├── cone hair bun
17 | │   │   ├── red dress
18 | │   │   └── uniform
19 | │   ├── hair bun
20 | │   └── none
21 | │       └── pajama
22 | └── Villhaze
23 | ```
24 | 
25 | ## Character Name, Appearance, and Character Embedding
26 | 
27 | For the character level, you can use any name that starts with "Noise" or "noise" to put images of characters or random people that you do not want to be classified as targeted characters. This is useful for preventing wrong classification results.
28 | 
29 | For the appearance level, use something that starts with `_` to have a single embedding for character and appearance. For example, putting `_cone hair bun` under `Terakomari` would result in `Terakomari_cone_hair_bun` being considered as an individual embedding when saving embedding initialization information, and this is also what will be used in captions. Otherwise, `Terakomari, cone hair bun` is used in captions (`, ` can be replaced by other separators by specifying `--character_inner_sep`).
30 | 
31 | You can put `None` or `none` to skip any level so that it is not used in captions. In the above example, the caption would be `Terakomari, pajama` and **not** `Terakomari, none, pajama`. Generally speaking, the current character classification mechanism with ccip embeddings works sufficiently well up to this level.
32 | 
33 | 
34 | ## Outfits and More
35 | 
36 | Starting from this level, it is possible to have multiple items in the folder name, separated by `+`, e.g. `red uniform+black skirt`. Anything starting with `_` will be considered an embedding (**TODO**: this is yet to be implemented). When multiple items of the same type exist, they are separated by `--caption_inner_sep` in captions, while items of different types are separated by `--character_outer_sep`. It is important to note that ccip embeddings do not work well for outfits and beyond, so manual inspection will be needed after stage 3 to put everything in order. (Hopefully we get cwip done soon.)
37 | 
-------------------------------------------------------------------------------- /docs/Dataset_organization.md: --------------------------------------------------------------------------------
1 | # Dataset Organization
2 | 
3 | After the entire process, you will get the following structure in `/path/to/dataset_dir` if you use the default configuration files and run the `booru` and `screenshots` pipelines in parallel.
4 | 
5 | ```
6 | .
7 | ├── intermediate
8 | │   ├── booru
9 | │   │   ├── classified
10 | │   │   ├── cropped
11 | │   │   └── raw
12 | │   └── screenshots
13 | │       ├── animes
14 | │       ├── classified
15 | │       ├── cropped
16 | │       └── raw
17 | └── training
18 |     ├── booru
19 |     │   ├── 1_character
20 |     │   ├── 2+_characters
21 |     │   └── emb_init.json
22 |     ├── screenshots
23 |     │   ├── 0_characters
24 |     │   ├── 1_character
25 |     │   ├── 2_characters
26 |     │   ├── 3+_characters
27 |     │   └── emb_init.json
28 |     ├── core_tag.json
29 |     ├── emb_init.json
30 |     └── wildcard.txt
31 | ```
32 | :bulb: If `--remove_intermediate` is specified, the folders `classified` and `cropped` are removed during the process.
33 | 
34 | The folder that should be used for training is `/path/to/dataset_dir/training`. Besides the training data, it contains two important files.
35 | - `emb_init.json` provides information for embedding initialization to be used for pivotal tuning (`emb_init.json` in the subfolders can be ignored).
36 | - `wildcard.txt` provides the wildcard to be used with [sd-dynamic-prompts](https://github.com/adieyal/sd-dynamic-prompts).
37 | 
38 | You can put other folders, such as your regularization images, in the training folder before launching the process so that they are taken into account as well when we compute the repeats to balance the concepts at the end.
39 | 
40 | ## Organization per Image Type
41 | 
42 | Each folder `/path/to/dataset_dir/training/{image_type}` is organized in the following way if `--arrange_format` is set to `n_characters/character` (the default value).
43 | 
44 | **Level 1**
45 | ```
46 | ├── ./0_characters
47 | ├── ./1_character
48 | ├── ./2_characters
49 | ├── ./3_characters
50 | ├── ./4+_characters
51 | ```
52 | 
53 | :bulb: Use `--max_character_number n` so that images containing more than `n` characters are all put together. If you don't want them to be included in the dataset, you can remove them manually.
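As an illustration of the level-1 naming pattern above, here is a minimal sketch; it is not the actual implementation in `anime2sd`, and the function name and default `max_character_number` below are made up for the example.

```python
def n_characters_folder(n_chars: int, max_character_number: int = 4) -> str:
    """Illustrative sketch of the level-1 naming convention shown above:
    images with more than max_character_number characters share one 'X+' folder."""
    if n_chars > max_character_number:
        return f"{max_character_number}+_characters"
    return f"{n_chars}_character" if n_chars == 1 else f"{n_chars}_characters"


# n_characters_folder(0) -> '0_characters'
# n_characters_folder(1) -> '1_character'
# n_characters_folder(6) -> '4+_characters'
```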
54 | 
55 | **Level 2**
56 | ```
57 | ├── ./1_character
58 | │   ├── ./1_character/AobaKokona
59 | │   ├── ./1_character/AobaMai
60 | │   ├── ./1_character/KuraueHinata
61 | │   ├── ./1_character/KuraueHinata Hairdown
62 | │   ├── ./1_character/KuraueKenichi
63 | │   ├── ./1_character/KuraueMai
64 | │   ├── ./1_character/KurosakiHonoka
65 | │   ├── ./1_character/KurosakiTaiki
66 | ...
67 | ```
68 | :bulb: Use `--min_images_per_combination m` so that character combinations with fewer than `m` images are all put in the folder `character_others`.
69 | TODO: Add an argument to optionally remove them.
70 | 
71 | The hierarchical organization makes it possible to auto-balance between different concepts without worrying too much about the number of images in each class.
72 | 
73 | 
74 | ## Multi-Anime Dataset and the Like
75 | 
76 | You can pass the argument `--extra_path_component` to replace `{image_type}` with `{extra_path_component}/{image_type}` in the aforementioned paths. This allows you, for example, to keep things well organized when processing multiple animes in parallel.
77 | 
78 | Note that you will need to set `--compute_core_tag_up_levels` to 2 (or an even higher number if `--extra_path_component` contains path separators) if you want to have a single wildcard and embedding initialization file for the entire dataset. Similarly, you may want to increase `--rearrange_up_levels` or `--compute_multiply_up_levels` to make sure that dataset balancing is computed from the root training folder.
79 | 
-------------------------------------------------------------------------------- /docs/Start_training.md: --------------------------------------------------------------------------------
1 | # Start Training
2 | 
3 | Once we go through the pipeline, the dataset is hierarchically organized in `/path/to/dataset_dir/training` with `multiply.txt` in each subfolder indicating the repeat of the images from this directory. You can pretty much launch the training process with your favorite trainer at this stage, modulo a few more steps to make sure that the data are read correctly.
4 | 
5 | 
6 | ## Training with EveryDream
7 | 
8 | With `multiply.txt` in each folder, the above structure is directly compatible with [EveryDream2](https://github.com/victorchall/EveryDream2trainer).
9 | 
10 | ## Training with Kohya Trainer
11 | 
12 | For [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts) you need to perform one more step with `flatten_folder.py`:
13 | 
14 | ```bash
15 | python flatten_folder.py \
16 |     --separator ~ \
17 |     --src_dir /path/to/dataset_dir/training
18 | ```
19 | 
20 | As long as the separator (`~` by default) does not appear in any folder name, you can undo the change with
21 | 
22 | ```bash
23 | python flatten_folder.py \
24 |     --separator ~ \
25 |     --src_dir /path/to/dataset_dir/training \
26 |     --revert
27 | ```
28 | 
29 | It is important to switch between the two modes as I rely on the folder structure to compute the repeats for now (a short sketch of the folder names this produces is given at the end of this page).
30 | 
31 | ## Training with HCP-Diffusion
32 | 
33 | [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion) requires setting up a yaml file to specify the repeat of each data source, and its configuration is generally more complicated, so I have provided `prepare_hcp.py` to streamline the process (to be run in the hcp-diffusion python environment).
34 | 
35 | ```bash
36 | python prepare_hcp.py \
37 |     --config_dst_dir /path/to/training_config_dir \
38 |     --dataset_dir /path/to/dataset_dir/training \
39 |     --pivotal \
40 |     --trigger_word_file /path/to/dataset_dir/emb_init.json
41 | ```
42 | 
43 | Once this is done, the embeddings are created in `/path/to/training_config_dir/embs` and you can start training with
44 | 
45 | ```bash
46 | accelerate launch -m hcpdiff.train_ac_single \
47 |     --cfg /path/to/training_config_dir/lora_conventional.yaml
48 | ```
49 | 
50 | ### Further details
51 | - `--pivotal` indicates pivotal tuning, i.e. training of embeddings and network at the same time (this is not possible with either kohya or EveryDream). Remove this argument if you do not want to train embeddings.
52 | - You can customize the embeddings you want to create and how they are initialized by modifying the content of `emb_init.json`.
53 | - Use `--help` to see more arguments. Notably, you can set `--emb_dir`, `--exp_dir`, and `--main_config_file` (which defaults to `hcp_configs/lora_conventional.yaml`), among others.
54 | - To modify training and dataset parameters, you can either directly modify the files in `hcp_configs` before running the script, or modify `dataset.yaml` and `lora_conventional.yaml` (or whichever config file you use) in `/path/to/training_config_dir` after running the script.
55 | - You should not move the generated config files because some absolute paths are used.
56 | 
57 | ### Post training conversion
58 | After training, the output files from HCP-Diffusion cannot be readily used by [a1111/sd-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui). For conversion, please refer to [Conversion_scripts.md](Conversion_scripts.md).
59 | 
60 | 
61 | ## Training with ...
62 | 
63 | Each trainer has its strengths and drawbacks. If you know of another good trainer that I have overlooked here, please let me know.
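As a side note on the kohya conversion above, the following is a minimal sketch of how each leaf folder and its `multiply.txt` are encoded into a single kohya-style folder name. It is not the script itself but mirrors the logic of `flatten_folder.py` at the repository root; the Villhaze path and the 4.1 multiply are illustrative values taken from the example logs.

```python
import os


def kohya_folder_name(src_dir, leaf_dir, separator="~"):
    """Sketch mirroring flatten_folder.py: encode the rounded repeat and the
    relative path into one folder name '<repeat>_<subfolders joined by separator>'."""
    multiply_file = os.path.join(leaf_dir, "multiply.txt")
    repeat = 1
    if os.path.exists(multiply_file):
        with open(multiply_file, "r") as f:
            repeat = round(float(f.readline().strip()))
    subpath = leaf_dir.replace(src_dir, "").lstrip(os.path.sep)
    return os.path.join(src_dir, f"{repeat}_{subpath.replace(os.path.sep, separator)}")


# With a multiply.txt containing "4.1", 'training/screenshots/1_character/Villhaze'
# becomes 'training/4_screenshots~1_character~Villhaze', which kohya-ss/sd-scripts
# reads as "repeat each image of this folder 4 times".
```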
64 | -------------------------------------------------------------------------------- /docs/example_logs/hikikomari_weighting_2023-12-0120-36-13.log: -------------------------------------------------------------------------------- 1 | 2023-12-01 20:36:14,033 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/booru/1_character/Terakomari 2 | 2023-12-01 20:36:14,033 - INFO - sampling probability: 0.039473684210526314 3 | 2023-12-01 20:36:14,033 - INFO - number of images: 145 4 | 2023-12-01 20:36:14,033 - INFO - original multipy: 1.0 5 | 2023-12-01 20:36:14,033 - INFO - final multipy: 1.0 6 | 7 | 2023-12-01 20:36:14,033 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/0_characters/character_others 8 | 2023-12-01 20:36:14,033 - INFO - sampling probability: 0.125 9 | 2023-12-01 20:36:14,033 - INFO - number of images: 337 10 | 2023-12-01 20:36:14,033 - INFO - original multipy: 1.3625123639960437 11 | 2023-12-01 20:36:14,033 - INFO - final multipy: 1.36 12 | 13 | 2023-12-01 20:36:14,033 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/1_character/Terakomari 14 | 2023-12-01 20:36:14,033 - INFO - sampling probability: 0.078125 15 | 2023-12-01 20:36:14,033 - INFO - number of images: 147 16 | 2023-12-01 20:36:14,033 - INFO - original multipy: 1.9522392290249437 17 | 2023-12-01 20:36:14,033 - INFO - final multipy: 1.95 18 | 19 | 2023-12-01 20:36:14,033 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/booru/1_character/Villhaze 20 | 2023-12-01 20:36:14,034 - INFO - sampling probability: 0.039473684210526314 21 | 2023-12-01 20:36:14,034 - INFO - number of images: 53 22 | 2023-12-01 20:36:14,034 - INFO - original multipy: 2.7358490566037736 23 | 2023-12-01 20:36:14,034 - INFO - final multipy: 2.74 24 | 25 | 2023-12-01 20:36:14,034 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/1_character/Villhaze 26 | 2023-12-01 20:36:14,034 - INFO - sampling probability: 0.078125 27 | 2023-12-01 20:36:14,034 - INFO - number of images: 70 28 | 2023-12-01 20:36:14,034 - INFO - original multipy: 4.099702380952381 29 | 2023-12-01 20:36:14,034 - INFO - final multipy: 4.1 30 | 31 | 2023-12-01 20:36:14,034 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/2_characters/Terakomari+Villhaze 32 | 2023-12-01 20:36:14,034 - INFO - sampling probability: 0.08928571428571429 33 | 2023-12-01 20:36:14,034 - INFO - number of images: 53 34 | 2023-12-01 20:36:14,034 - INFO - original multipy: 6.188230008984727 35 | 2023-12-01 20:36:14,034 - INFO - final multipy: 6.19 36 | 37 | 2023-12-01 20:36:14,034 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/booru/1_character/character_others 38 | 2023-12-01 20:36:14,034 - INFO - sampling probability: 0.031578947368421054 39 | 2023-12-01 20:36:14,034 - INFO - number of images: 18 40 | 2023-12-01 20:36:14,034 - INFO - original multipy: 6.4444444444444455 41 | 2023-12-01 20:36:14,034 - INFO - final multipy: 6.44 42 | 43 | 2023-12-01 20:36:14,034 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/1_character/Sakuna 44 | 2023-12-01 20:36:14,034 - INFO - sampling probability: 0.078125 45 | 2023-12-01 20:36:14,034 - INFO - number of images: 42 46 | 2023-12-01 20:36:14,034 - INFO - original multipy: 6.832837301587302 47 | 2023-12-01 
20:36:14,034 - INFO - final multipy: 6.83 48 | 49 | 2023-12-01 20:36:14,035 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/booru/2+_characters/Terakomari+Villhaze 50 | 2023-12-01 20:36:14,035 - INFO - sampling probability: 0.05555555555555556 51 | 2023-12-01 20:36:14,035 - INFO - number of images: 28 52 | 2023-12-01 20:36:14,035 - INFO - original multipy: 7.288359788359791 53 | 2023-12-01 20:36:14,035 - INFO - final multipy: 7.29 54 | 55 | 2023-12-01 20:36:14,035 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/booru/1_character/Sakuna 56 | 2023-12-01 20:36:14,035 - INFO - sampling probability: 0.039473684210526314 57 | 2023-12-01 20:36:14,035 - INFO - number of images: 18 58 | 2023-12-01 20:36:14,035 - INFO - original multipy: 8.055555555555555 59 | 2023-12-01 20:36:14,035 - INFO - final multipy: 8.06 60 | 61 | 2023-12-01 20:36:14,035 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/booru/2+_characters/character_others 62 | 2023-12-01 20:36:14,035 - INFO - sampling probability: 0.04444444444444445 63 | 2023-12-01 20:36:14,035 - INFO - number of images: 15 64 | 2023-12-01 20:36:14,035 - INFO - original multipy: 10.883950617283954 65 | 2023-12-01 20:36:14,035 - INFO - final multipy: 10.88 66 | 67 | 2023-12-01 20:36:14,035 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/1_character/Millicent 68 | 2023-12-01 20:36:14,035 - INFO - sampling probability: 0.078125 69 | 2023-12-01 20:36:14,035 - INFO - number of images: 21 70 | 2023-12-01 20:36:14,035 - INFO - original multipy: 13.665674603174605 71 | 2023-12-01 20:36:14,035 - INFO - final multipy: 13.67 72 | 73 | 2023-12-01 20:36:14,035 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/2_characters/character_others 74 | 2023-12-01 20:36:14,035 - INFO - sampling probability: 0.07142857142857144 75 | 2023-12-01 20:36:14,035 - INFO - number of images: 10 76 | 2023-12-01 20:36:14,036 - INFO - original multipy: 26.238095238095244 77 | 2023-12-01 20:36:14,036 - INFO - final multipy: 26.24 78 | 79 | 2023-12-01 20:36:14,036 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/1_character/character_others 80 | 2023-12-01 20:36:14,036 - INFO - sampling probability: 0.0625 81 | 2023-12-01 20:36:14,036 - INFO - number of images: 7 82 | 2023-12-01 20:36:14,036 - INFO - original multipy: 32.79761904761905 83 | 2023-12-01 20:36:14,036 - INFO - final multipy: 32.8 84 | 85 | 2023-12-01 20:36:14,036 - INFO - /home/sashi/Documents/Projects/Anime/extract-training-from-anime/data/training/screenshots/2_characters/Terakomari 86 | 2023-12-01 20:36:14,036 - INFO - sampling probability: 0.08928571428571429 87 | 2023-12-01 20:36:14,036 - INFO - number of images: 2 88 | 2023-12-01 20:36:14,036 - INFO - original multipy: 163.98809523809527 89 | 2023-12-01 20:36:14,036 - INFO - final multipy: 100.0 90 | 91 | 2023-12-01 20:36:14,036 - INFO - Number of images: 966 92 | 2023-12-01 20:36:14,036 - INFO - Virtual dataset size: 3544.5099999999998 93 | -------------------------------------------------------------------------------- /flatten_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def list_image_subfolders(path): 6 | """Recursively list all subfolders of a directory.""" 7 | subfolders = [] 8 | contain_subfolder = False 9 | 
for item in os.listdir(path):
10 |         item_path = os.path.join(path, item)
11 |         if os.path.isdir(item_path):
12 |             contain_subfolder = True
13 |             subfolders.extend(list_image_subfolders(item_path))
14 |     if not contain_subfolder:
15 |         subfolders.append(path)
16 |     return subfolders
17 | 
18 | 
19 | def get_new_path(src_dir, path, separator):
20 |     multiply_file = os.path.join(path, 'multiply.txt')
21 |     repeat = 1
22 |     if os.path.exists(multiply_file):
23 |         with open(multiply_file, 'r') as f:
24 |             repeat = round(float(f.readline().strip()))
25 |     subpath = path.replace(src_dir, '').lstrip(os.path.sep)
26 |     new_subpath = subpath.replace(os.path.sep, separator)
27 |     return os.path.join(src_dir, f'{repeat}_{new_subpath}')
28 | 
29 | 
30 | def revert_path(src_dir, path, separator):
31 |     subpath = '_'.join(
32 |         path.replace(src_dir, '').lstrip(os.path.sep).split('_')[1:])
33 |     new_subpath = subpath.replace(separator, os.path.sep)
34 |     return os.path.join(src_dir, new_subpath)
35 | 
36 | 
37 | def remove_empty_folders(path_abs):
38 |     walk = list(os.walk(path_abs))
39 |     for path, _, _ in walk[::-1]:
40 |         if len(os.listdir(path)) == 0:
41 |             os.rmdir(path)
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument('--src_dir', required=True,
47 |                         help='Path to the source directory')
48 |     parser.add_argument('--separator', default='~',
49 |                         help='String to separate folders of different levels')
50 |     parser.add_argument('--revert', action='store_true')
51 |     args = parser.parse_args()
52 |     for path in list_image_subfolders(args.src_dir):
53 |         if args.revert:
54 |             new_path = revert_path(args.src_dir, path, args.separator)
55 |             os.makedirs(os.path.dirname(new_path), exist_ok=True)
56 |         else:
57 |             new_path = get_new_path(args.src_dir, path, args.separator)
58 |         os.rename(path, new_path)
59 |     remove_empty_folders(args.src_dir)
60 | 
-------------------------------------------------------------------------------- /install.bat: --------------------------------------------------------------------------------
1 | @echo off
2 | setlocal
3 | 
4 | set "python_cmd=python"
5 | set "delimiter=****************************************************************"
6 | 
7 | echo %delimiter%
8 | echo Python Environment Setup Script
9 | echo %delimiter%
10 | 
11 | REM Check if the script is being run as administrator
12 | net session >nul 2>&1
13 | if %errorlevel% == 0 (
14 |     echo ERROR: This script should not be run as administrator.
15 |     exit /b
16 | )
17 | 
18 | REM Update Git submodules
19 | echo Updating Git submodules...
20 | git submodule update --init --recursive
21 | if not %errorlevel% == 0 (
22 |     echo ERROR: Failed to update Git submodules.
23 |     exit /b
24 | )
25 | 
26 | echo.
27 | echo Select environment setup:
28 | echo 1) venv
29 | echo 2) conda
30 | echo 3) existing environment
31 | set /p env_choice="Enter choice [1-3]: "
32 | goto choose_env
33 | REM Function to setup using venv
34 | :setup_venv
35 | %python_cmd% -m venv --help >nul 2>&1
36 | if not %errorlevel% == 0 (
37 |     echo ERROR: venv is not installed or not available.
38 |     exit /b
39 | )
40 | %python_cmd% -m venv venv
41 | if not exist "venv" (
42 |     echo ERROR: Failed to create venv environment.
43 |     exit /b
44 | )
45 | call venv\Scripts\activate.bat
46 | goto end
47 | 
48 | REM Function to setup using conda
49 | :setup_conda
50 | where conda >nul 2>&1
51 | if not %errorlevel% == 0 (
52 |     echo ERROR: conda is not installed.
53 |     exit /b
54 | )
55 | 
56 | REM Conda environment setup is tricky in batch and might not work as expected
57 | REM You might need to adjust this part based on your Conda setup
58 | call conda create --name anime2sd python=3.10 -y
59 | if not %errorlevel% == 0 (
60 |     echo ERROR: Failed to create conda environment.
61 |     exit /b
62 | )
63 | call conda activate anime2sd
64 | if not %errorlevel% == 0 (
65 |     echo ERROR: Failed to activate conda environment.
66 |     exit /b
67 | )
68 | goto end
69 | :choose_env
70 | REM Choose environment setup
71 | if "%env_choice%"=="1" goto setup_venv
72 | if "%env_choice%"=="2" goto setup_conda
73 | if "%env_choice%"=="3" (
74 |     REM Check for existing environment
75 |     if defined VIRTUAL_ENV goto end
76 |     if defined CONDA_DEFAULT_ENV goto end
77 |     echo ERROR: No existing Python environment is activated.
78 |     exit /b
79 | ) else (
80 |     echo Invalid choice. Exiting.
81 |     exit /b
82 | )
83 | 
84 | :end
85 | echo %delimiter%
86 | echo Environment setup complete.
87 | echo %delimiter%
88 | 
89 | REM Run the install.py script and check for errors
90 | %python_cmd% install.py
91 | if not %errorlevel% == 0 (
92 |     echo ERROR: Installation failed. Please check the error messages above.
93 |     exit /b
94 | )
95 | 
96 | REM Add notices at the end of the script
97 | echo %delimiter%
98 | echo NOTICE: If you want to run frame extraction (stage 1 of the 'screenshots' pipeline), please make sure FFmpeg is installed and can be run from the command line.
99 | echo On Windows, you can install FFmpeg using Chocolatey: https://chocolatey.org/install
100 | echo Then, run 'choco install ffmpeg'
101 | echo Or download it directly from https://ffmpeg.org/download.html
102 | echo.
103 | echo NOTICE: If you want to use onnxruntime on GPU, please make sure that CUDA 11.8 toolkit is installed and can be found on PATH
104 | echo For installation, go to: https://developer.nvidia.com/cuda-11-8-0-download-archive
105 | echo %delimiter%
106 | 
107 | REM Provide instructions based on the chosen environment setup
108 | if "%env_choice%"=="1" (
109 |     echo To activate the venv environment, run: call venv\Scripts\activate.bat
110 | ) else if "%env_choice%"=="2" (
111 |     echo To activate the conda environment, run: conda activate anime2sd
112 | )
113 | echo %delimiter%
114 | 
115 | endlocal
-------------------------------------------------------------------------------- /install.py: --------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import sys
4 | 
5 | 
6 | def run(command, desc=None):
7 |     if desc is not None:
8 |         print(desc)
9 | 
10 |     # Join the command list into a single string if it's a list
11 |     if isinstance(command, list):
12 |         command = " ".join(command)
13 | 
14 |     process = subprocess.run(command, shell=True, capture_output=False)
15 |     if process.returncode != 0:
16 |         print(f"Error running command: {command}")
17 |         print(f"Error code: {process.returncode}")
18 |         sys.exit(1)
19 | 
20 | 
21 | def install_package(package, command):
22 |     if not is_installed(package):
23 |         run(command, f"Installing {package}")
24 | 
25 | 
26 | def is_installed(package):
27 |     try:
28 |         subprocess.run(
29 |             f"{sys.executable} -m pip show {package}",
30 |             shell=True,
31 |             capture_output=True,
32 |             check=True,
33 |         )
34 |         return True
35 |     except subprocess.CalledProcessError:
36 |         return False
37 | 
38 | 
39 | def prepare_environment():
40 |     # Install PyTorch
41 |     # Use cuda 11.8 here for consistency with onnxruntime but that
42 |     # does not really matter since they use cuda
from different places anyway 43 | install_package( 44 | "torch", 45 | ( 46 | f"{sys.executable} -m pip install torch torchvision torchaudio " 47 | "--index-url https://download.pytorch.org/whl/cu118" 48 | ), 49 | ) 50 | 51 | # Install other requirements from requirements.txt 52 | requirements_path = os.path.join(os.getcwd(), "requirements.txt") 53 | if os.path.exists(requirements_path): 54 | run( 55 | f"{sys.executable} -m pip install -r {requirements_path}", 56 | "Installing requirements from requirements.txt", 57 | ) 58 | 59 | # Install waifuc package 60 | waifuc_path = os.path.join(os.getcwd(), "waifuc") 61 | if os.path.exists(waifuc_path): 62 | os.chdir(waifuc_path) 63 | run( 64 | f"{sys.executable} -m pip install .", 65 | "Installing waifuc package", 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | prepare_environment() 71 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Define python command 4 | python_cmd="python" 5 | 6 | # Pretty print 7 | delimiter="################################################################" 8 | 9 | printf "\n%s\n" "${delimiter}" 10 | printf "\e[1m\e[32mPython Environment Setup Script\n\e[0m" 11 | printf "\n%s\n" "${delimiter}" 12 | 13 | 14 | # Update Git submodules 15 | printf "\e[1m\e[34mUpdating Git submodules...\e[0m\n" 16 | git submodule update --init --recursive 17 | if [ $? -ne 0 ]; then 18 | printf "\e[1m\e[31mERROR: Failed to update Git submodules.\e[0m\n" 19 | exit 1 20 | fi 21 | 22 | 23 | # Check if the script is being run as root 24 | if [[ $(id -u) -eq 0 ]]; then 25 | printf "\e[1m\e[31mERROR: This script should not be run as root.\e[0m\n" 26 | exit 1 27 | fi 28 | 29 | 30 | # Prompt user for environment setup choice 31 | echo "Select environment setup:" 32 | echo "1) venv" 33 | echo "2) conda" 34 | echo "3) existing environment" 35 | read -p "Enter choice [1-3]: " env_choice 36 | 37 | 38 | # Function to setup using venv 39 | setup_venv() { 40 | if ! ${python_cmd} -m venv --help > /dev/null 2>&1; then 41 | printf "\e[1m\e[31mERROR: venv is not installed or not available.\e[0m\n" 42 | exit 1 43 | fi 44 | 45 | # Create venv environment 46 | ${python_cmd} -m venv venv 47 | if [ ! -d "venv" ]; then 48 | printf "\e[1m\e[31mERROR: Failed to create venv environment.\e[0m\n" 49 | exit 1 50 | fi 51 | 52 | # Activate venv environment 53 | if [ -f "venv/bin/activate" ]; then 54 | source venv/bin/activate 55 | else 56 | printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m\n" 57 | exit 1 58 | fi 59 | } 60 | 61 | 62 | # Function to setup using conda 63 | setup_conda() { 64 | if ! command -v conda > /dev/null; then 65 | printf "\e[1m\e[31mERROR: conda is not installed.\e[0m\n" 66 | exit 1 67 | fi 68 | 69 | # Path to the Conda initialization script 70 | CONDA_INIT_SCRIPT="$HOME/miniconda3/etc/profile.d/conda.sh" 71 | 72 | # Check if the Conda initialization script exists 73 | if [ ! -f "$CONDA_INIT_SCRIPT" ]; then 74 | printf "\e[1m\e[31mERROR: Conda initialization script not found at %s.\e[0m\n" "$CONDA_INIT_SCRIPT" 75 | exit 1 76 | fi 77 | 78 | # Initialize Conda for script 79 | source "$CONDA_INIT_SCRIPT" 80 | 81 | # Create conda environment 82 | conda create --name anime2sd python=3.10 -y 83 | if [ $? 
-ne 0 ]; then 84 | printf "\e[1m\e[31mERROR: Failed to create conda environment.\e[0m\n" 85 | exit 1 86 | fi 87 | 88 | # Activate conda environment 89 | conda activate anime2sd 90 | if [ $? -ne 0 ]; then 91 | printf "\e[1m\e[31mERROR: Failed to activate conda environment.\e[0m\n" 92 | exit 1 93 | fi 94 | } 95 | 96 | 97 | # Choose environment setup 98 | case $env_choice in 99 | 1) 100 | setup_venv 101 | ;; 102 | 2) 103 | setup_conda 104 | ;; 105 | 3) 106 | if [[ -z "${VIRTUAL_ENV}" && -z "${CONDA_DEFAULT_ENV}" ]]; then 107 | printf "\e[1m\e[31mERROR: No existing Python environment is activated.\e[0m\n" 108 | exit 1 109 | fi 110 | ;; 111 | *) 112 | printf "\e[1m\e[31mInvalid choice. Exiting.\e[0m\n" 113 | exit 1 114 | ;; 115 | esac 116 | 117 | printf "\e[1m\e[32mEnvironment setup complete.\e[0m\n" 118 | printf "\n%s\n" "${delimiter}" 119 | 120 | # Run the install.py script and check for errors 121 | if ! ${python_cmd} install.py; then 122 | printf "\e[1m\e[31mERROR: Installation failed. Please check the error messages above.\e[0m\n" 123 | exit 1 124 | fi 125 | 126 | printf "\e[1m\e[32mInstallation complete.\e[0m\n" 127 | printf "\n%s\n" "${delimiter}" 128 | 129 | 130 | # Add notices at the end of the script 131 | if [ $? -eq 0 ]; then 132 | printf "\e[1m\e[33mNOTICE:\e[0m If you want to run frame extraction (stage 1 of the 'screenshots' pipeline), " 133 | printf "please make sure ffmpeg is installed and can be run from the command line.\n" 134 | printf "\e[1m\e[33mOn Ubuntu, run:\e[0m sudo apt update && sudo apt install ffmpeg\n\n" 135 | 136 | printf "\e[1m\e[33mNOTICE:\e[0m If you want to use onnxruntime on GPU, " 137 | printf "please make sure that CUDA 11.8 toolkit is installed and can be found on LD_LIBRARY_PATH\n" 138 | printf "\e[1m\e[33mFor installation, go to:\e[0m https://developer.nvidia.com/cuda-11-8-0-download-archive\n" 139 | fi 140 | printf "\n%s\n" "${delimiter}" 141 | 142 | 143 | # Provide instructions based on the chosen environment setup 144 | case $env_choice in 145 | 1) 146 | printf "\e[1m\e[32mTo activate the venv environment, run:\e[0m source venv/bin/activate\n" 147 | printf "\n%s\n" "${delimiter}" 148 | ;; 149 | 2) 150 | printf "\e[1m\e[32mTo activate the conda environment, run:\e[0m conda activate anime2sd\n" 151 | printf "\n%s\n" "${delimiter}" 152 | ;; 153 | 3) 154 | # No additional instructions needed for existing environment 155 | ;; 156 | esac 157 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dghs_imgutils==0.2.10 2 | natsort==8.4.0 3 | numpy==1.26.2 4 | Pillow==10.1.0 5 | tqdm==4.64.0 6 | toml==0.10.2 7 | scikit_learn==1.3.2 8 | timm==0.9.12 9 | hbutils==0.9.3 10 | pynyaasi==0.0.1 11 | torrentp==0.1.6 12 | transformers==4.35.0 # Only used for embedding name checking 13 | wheel==0.42.0 # For waifuc installation 14 | -------------------------------------------------------------------------------- /scripts_v1/classifier_dataset_preparation/data_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pandas as pd 4 | 5 | 6 | def data_split(data_dic_path, split): 7 | # splits data into training and testing 8 | 9 | data_folder = os.path.dirname(os.path.realpath(data_dic_path)) 10 | df = pd.read_csv(data_dic_path, sep=',') 11 | print('Original df: ', len(df)) 12 | 13 | samples_per_class_df = df.groupby('class_id', as_index=True).count() 14 | 15 | 
df_list_train = [] 16 | df_list_test = [] 17 | for class_id, total_samples_class in enumerate( 18 | samples_per_class_df['file_rel_path']): 19 | train_samples_class = int(total_samples_class * split[0]) 20 | test_samples_class = total_samples_class - train_samples_class 21 | assert (train_samples_class + 22 | test_samples_class == total_samples_class) 23 | train_subset_class = df.loc[df['class_id'] == class_id].groupby( 24 | 'class_id').head(train_samples_class) 25 | test_subset_class = df.loc[df['class_id'] == class_id].groupby( 26 | 'class_id').tail(test_samples_class) 27 | df_list_train.append(train_subset_class) 28 | df_list_test.append(test_subset_class) 29 | 30 | df_train = pd.concat(df_list_train) 31 | df_test = pd.concat(df_list_test) 32 | 33 | print('Train df: ') 34 | print(df_train.head()) 35 | print(df_train.shape) 36 | print('Test df: ') 37 | print(df_test.head()) 38 | print(df_test.shape) 39 | 40 | df_train_name = os.path.join(data_folder, 'train.csv') 41 | df_train.to_csv(df_train_name, sep=',', header=True, index=False) 42 | 43 | df_test_name = os.path.join(data_folder, 'test.csv') 44 | df_test.to_csv(df_test_name, sep=',', header=True, index=False) 45 | print('Finished saving train and test split dictionaries.') 46 | 47 | 48 | def main(): 49 | 50 | data_dic_path = os.path.abspath(sys.argv[1]) 51 | split_train = float(sys.argv[2]) # % of data for train (def: 0.8) 52 | split_test = float(sys.argv[3]) # % of data for val (def: 0.2) 53 | assert split_train + split_test == 1, 'Arguments for split ratios should add up to 1' 54 | split = [split_train, split_test] 55 | 56 | print(split) 57 | data_split(data_dic_path, split) 58 | 59 | 60 | main() 61 | -------------------------------------------------------------------------------- /scripts_v1/classifier_dataset_preparation/make_data_dic_imagenetsyle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import pandas as pd 5 | import numpy as np 6 | import time 7 | from PIL import Image 8 | 9 | 10 | def make_data_dic(data_folder): 11 | 12 | # makes an imagefolder (imagenet style) with images of class in 13 | # a certain folder into a txt dictionary with the first column being 14 | # the file dir (relative) and the second into the class 15 | types = ('*.jpg', '*.jpeg', '*.png', '*.webp') # the tuple of file types 16 | files_all = [] 17 | for file_type in types: 18 | # files_all is the list of files 19 | path = os.path.join(data_folder, '**', file_type) 20 | files_curr_type = glob.glob(path, recursive=True) 21 | files_all.extend(files_curr_type) 22 | 23 | print(file_type, len(files_curr_type)) 24 | print('Total image files pre-filtering', len(files_all)) 25 | 26 | class_name_list = [] # holds classes names and is also relative path 27 | filename_classid_dic = {} # filename and classid pairs 28 | classid_classname_dic = {} # id and class name/rel path as dict 29 | 30 | idx = -1 31 | for file_path in files_all: 32 | # verify the image is RGB 33 | im = Image.open(file_path) 34 | if im.mode == 'RGB': 35 | abs_path, filename = os.path.split(file_path) 36 | _, class_name = os.path.split(abs_path) 37 | rel_path = os.path.join(class_name, filename) 38 | if class_name not in class_name_list: 39 | idx += 1 40 | class_name_list.append(class_name) 41 | classid_classname_dic[idx] = class_name 42 | 43 | tag_file = file_path + '.tags' 44 | if os.path.exists(tag_file): 45 | with open(tag_file, 'r') as f: 46 | tags = f.readline() 47 | tags = [tag.strip() for tag in 
tags.split(',')] 48 | else: 49 | tags = [] 50 | 51 | filename_classid_dic[rel_path] = [idx, tags] 52 | 53 | no_classes = idx + 1 54 | print('Total number of classes: ', no_classes) 55 | print('Total images files post-filtering (RGB only): ', 56 | len(filename_classid_dic)) 57 | 58 | # save dataframe to hold the class IDs and the relative paths of the files 59 | df = pd.DataFrame.from_dict(filename_classid_dic, 60 | orient='index', 61 | columns=['class_id', 'tags']) 62 | idx_col = np.arange(0, len(df), 1) 63 | df['idx_col'] = idx_col 64 | df['file_rel_path'] = df.index 65 | df.set_index('idx_col', inplace=True) 66 | df = df[['class_id', 'file_rel_path', 'tags']] 67 | print(df.head()) 68 | df_name = os.path.join(data_folder, 'labels.csv') 69 | df.to_csv(df_name, sep=',', header=True, index=False) 70 | 71 | df_classid_classname = pd.DataFrame.from_dict(classid_classname_dic, 72 | orient='index', 73 | columns=['class_id']) 74 | idx_col = np.arange(0, len(df_classid_classname), 1) 75 | df_classid_classname['idx_col'] = idx_col 76 | df_classid_classname['class_name'] = df_classid_classname.index 77 | df_classid_classname.set_index('idx_col', inplace=True) 78 | cols = df_classid_classname.columns.tolist() 79 | cols = cols[-1:] + cols[:-1] # reordering the columns 80 | df_classid_classname = df_classid_classname[cols] 81 | print(df_classid_classname.head()) 82 | df_classid_classname_name = os.path.join(data_folder, 83 | 'classid_classname.csv') 84 | df_classid_classname.to_csv(df_classid_classname_name, 85 | sep=',', 86 | header=True, 87 | index=False) 88 | classname_file = os.path.join(data_folder, 'classnames.txt') 89 | with open(classname_file, 'w') as f: 90 | f.write('\n'.join(sorted(class_name_list))) 91 | time.sleep(1) 92 | 93 | 94 | def main(): 95 | ''' 96 | input is the path to the folder with imagenet-like structure 97 | imagenet/ 98 | imagenet/class1/ 99 | imagenet/class2/ 100 | ... 101 | imagenet/classN/ 102 | ''' 103 | try: 104 | data_folder = os.path.abspath(sys.argv[1]) 105 | except: 106 | data_folder = '.' 107 | make_data_dic(data_folder) 108 | 109 | 110 | main() 111 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/models/README.md: -------------------------------------------------------------------------------- 1 | ### Setup 2 | ``` 3 | pip install -e . 
4 | python download_convert_models.py 5 | # can modify to download different models, by default it downloads all 5 ViTs pretrained on ImageNet21k 6 | ``` 7 | ### Usage 8 | ``` 9 | from vit_animesion import ViT, ViTConfigExtended, PRETRAINED_CONFIGS 10 | model_name = 'B_16' 11 | def_config = PRETRAINED_CONFIGS['{}'.format(model_name)]['config'] 12 | configuration = ViTConfigExtended(**def_config) 13 | model = ViT(configuration, name=model_name, pretrained=True, 14 | load_repr_layer=False, ret_attn_scores=False) 15 | ``` 16 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/models/download_convert_models.py: -------------------------------------------------------------------------------- 1 | from vit_animesion import ViT, ViTConfigExtended, PRETRAINED_CONFIGS 2 | 3 | models_list = ['B_16', 'B_32', 'L_16', 'L_32', 'H_14'] 4 | for model_name in models_list: 5 | def_config = PRETRAINED_CONFIGS['{}'.format(model_name)]['config'] 6 | configuration = ViTConfigExtended(**def_config) 7 | model = ViT(configuration, name=model_name, pretrained=True, load_repr_layer=True) 8 | 9 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/models/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'vit-animesion' 16 | DESCRIPTION = 'Visual Transformers (ViT) for Animesion' 17 | URL = 'blank' 18 | EMAIL = 'blank' 19 | AUTHOR = 'blank' 20 | REQUIRES_PYTHON = '>=3.5.0' 21 | VERSION = '0.0.7' 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = [ 25 | 'torch' 26 | ] 27 | 28 | # What packages are optional? 29 | EXTRAS = { 30 | # 'fancy feature': ['django'], 31 | } 32 | 33 | # The rest you shouldn't have to touch too much :) 34 | # ------------------------------------------------ 35 | # Except, perhaps the License and Trove Classifiers! 36 | # If you do change the License, remember to change the Trove Classifier for that! 37 | 38 | here = os.path.abspath(os.path.dirname(__file__)) 39 | 40 | # Import the README and use it as the long-description. 41 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 42 | try: 43 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 44 | long_description = '\n' + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Load the package's __version__.py module as a dictionary. 49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 53 | exec(f.read(), about) 54 | else: 55 | about['__version__'] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | """Support setup.py upload.""" 60 | 61 | description = 'Build and publish the package.' 
62 | user_options = [] 63 | 64 | @staticmethod 65 | def status(s): 66 | """Prints things in bold.""" 67 | print('\033[1m{0}\033[0m'.format(s)) 68 | 69 | def initialize_options(self): 70 | pass 71 | 72 | def finalize_options(self): 73 | pass 74 | 75 | def run(self): 76 | try: 77 | self.status('Removing previous builds…') 78 | rmtree(os.path.join(here, 'dist')) 79 | except OSError: 80 | pass 81 | 82 | self.status('Building Source and Wheel (universal) distribution…') 83 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 84 | 85 | self.status('Uploading the package to PyPI via Twine…') 86 | os.system('twine upload dist/*') 87 | 88 | self.status('Pushing git tags…') 89 | os.system('git tag v{0}'.format(about['__version__'])) 90 | os.system('git push --tags') 91 | 92 | sys.exit() 93 | 94 | 95 | # Where the magic happens: 96 | setup( 97 | name=NAME, 98 | version=about['__version__'], 99 | description=DESCRIPTION, 100 | long_description=long_description, 101 | long_description_content_type='text/markdown', 102 | author=AUTHOR, 103 | author_email=EMAIL, 104 | python_requires=REQUIRES_PYTHON, 105 | url=URL, 106 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 107 | # py_modules=['model'], # If your package is a single module, use this instead of 'packages' 108 | install_requires=REQUIRED, 109 | extras_require=EXTRAS, 110 | include_package_data=True, 111 | license='Apache', 112 | classifiers=[ 113 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 114 | 'License :: OSI Approved :: Apache Software License', 115 | 'Programming Language :: Python', 116 | 'Programming Language :: Python :: 3', 117 | 'Programming Language :: Python :: 3.6', 118 | ], 119 | # $ setup.py publish support. 
120 | cmdclass={ 121 | 'upload': UploadCommand, 122 | }, 123 | ) 124 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/models/vit_animesion/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ViT 2 | from .configs import * 3 | 4 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/models/vit_animesion/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/lukemelas/simple-bert 3 | """ 4 | 5 | import numpy as np 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | 10 | 11 | def split_last(x, shape): 12 | "split the last dimension to given shape" 13 | shape = list(shape) 14 | assert shape.count(-1) <= 1 15 | if -1 in shape: 16 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 17 | return x.view(*x.size()[:-1], *shape) 18 | 19 | 20 | def merge_last(x, n_dims): 21 | "merge the last n_dims to a dimension" 22 | s = x.size() 23 | assert n_dims > 1 and n_dims < len(s) 24 | return x.view(*s[:-n_dims], -1) 25 | 26 | 27 | class MultiHeadedSelfAttention(nn.Module): 28 | """Multi-Headed Dot Product Attention""" 29 | def __init__(self, dim, num_heads, dropout, ret_attn_scores): 30 | super().__init__() 31 | self.proj_q = nn.Linear(dim, dim) 32 | self.proj_k = nn.Linear(dim, dim) 33 | self.proj_v = nn.Linear(dim, dim) 34 | self.drop = nn.Dropout(dropout) 35 | self.n_heads = num_heads 36 | self.ret_attn_scores = ret_attn_scores 37 | 38 | def forward(self, x, mask): 39 | """ 40 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 41 | mask : (B(batch_size) x S(seq_len)) 42 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 43 | """ 44 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 45 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 46 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 47 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 48 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 49 | if mask is not None: 50 | mask = mask[:, None, None, :].float() 51 | scores -= 10000.0 * (1.0 - mask) 52 | # this is what's used to visualize attention 53 | scores = self.drop(F.softmax(scores, dim=-1)) 54 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 55 | h = (scores @ v).transpose(1, 2).contiguous() 56 | # -merge-> (B, S, D) 57 | h = merge_last(h, 2) 58 | if self.ret_attn_scores: 59 | return h, scores 60 | else: 61 | return h 62 | 63 | 64 | class PositionWiseFeedForward(nn.Module): 65 | """FeedForward Neural Networks for each position""" 66 | def __init__(self, dim, ff_dim): 67 | super().__init__() 68 | self.fc1 = nn.Linear(dim, ff_dim) 69 | self.fc2 = nn.Linear(ff_dim, dim) 70 | 71 | def forward(self, x): 72 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 73 | return self.fc2(F.gelu(self.fc1(x))) 74 | 75 | 76 | class Block(nn.Module): 77 | """Transformer Block""" 78 | def __init__(self, dim, num_heads, ff_dim, hidden_dropout_prob, 79 | attention_probs_dropout_prob, layer_norm_eps, ret_attn_scores): 80 | super().__init__() 81 | self.attn = MultiHeadedSelfAttention(dim, num_heads, attention_probs_dropout_prob, ret_attn_scores) 82 | self.proj = nn.Linear(dim, dim) 83 | self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps) 84 | self.pwff = 
PositionWiseFeedForward(dim, ff_dim) 85 | self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps) 86 | self.drop = nn.Dropout(hidden_dropout_prob) 87 | self.ret_attn_scores = ret_attn_scores 88 | 89 | def forward(self, x, mask): 90 | if self.ret_attn_scores: 91 | h, scores = self.attn(self.norm1(x), mask) # eq 1 92 | else: 93 | h = self.attn(self.norm1(x), mask) # eq 1 94 | h = self.drop(self.proj(h)) # eq 1 95 | x = x + h # eq 2 96 | h = self.drop(self.pwff(self.norm2(x))) # eq 3 97 | x = x + h # eq 3 98 | if self.ret_attn_scores: 99 | return x, scores 100 | else: 101 | return x 102 | 103 | 104 | class Transformer(nn.Module): 105 | """Transformer with Self-Attentive Blocks""" 106 | def __init__(self, num_layers, dim, num_heads, ff_dim, hidden_dropout_prob, 107 | attention_probs_dropout_prob, layer_norm_eps, ret_attn_scores, ret_interm_repr): 108 | super().__init__() 109 | self.blocks = nn.ModuleList([ 110 | Block(dim, num_heads, ff_dim, hidden_dropout_prob, 111 | attention_probs_dropout_prob, layer_norm_eps, ret_attn_scores) for _ in range(num_layers)]) 112 | 113 | self.ret_attn_scores = ret_attn_scores 114 | self.ret_interm_repr = ret_interm_repr 115 | 116 | def forward(self, x, mask=None): 117 | if self.ret_attn_scores: 118 | scores_list = [] 119 | if self.ret_interm_repr: 120 | interm_repr_list = [] 121 | 122 | for block in self.blocks: 123 | if self.ret_attn_scores: 124 | x, scores = block(x, mask) 125 | scores_list.append(scores) 126 | else: 127 | x = block(x, mask) 128 | if self.ret_interm_repr: 129 | interm_repr_list.append(x) 130 | 131 | if self.ret_interm_repr and self.ret_attn_scores: 132 | return x, interm_repr_list, scores_list 133 | elif self.ret_interm_repr: 134 | return x, interm_repr_list 135 | elif self.ret_attn_scores: 136 | return x, scores_list 137 | else: 138 | return x 139 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/requirements.txt: -------------------------------------------------------------------------------- 1 | efficientnet-pytorch==0.7.1 2 | einops==0.3.0 3 | gradio==2.2.8 4 | huggingface-hub==0.0.12 5 | opencv-python==4.4.0.46 6 | regex==2021.8.3 7 | tokenizers==0.10.3 8 | torchsummary==1.5.1 9 | transformers==4.9.1 10 | wandb==0.11.2 11 | wordcloud==1.8.1 12 | gdown==3.13.0 -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduler import * 2 | from .data_selection import * 3 | from .data_selection_customize import * 4 | from .misc import * 5 | from .model_selection import * 6 | from .custom_tokenizer import CustomTokenizer 7 | from .build_vocab import Vocabulary 8 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/build_vocab.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import pickle 3 | import argparse 4 | import pandas as pd 5 | from collections import Counter 6 | 7 | class Vocabulary(object): 8 | # https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning 9 | """Simple vocabulary wrapper.""" 10 | def __init__(self): 11 | self.word2idx = {} 12 | self.idx2word = {} 13 | self.idx = 0 14 | 15 | def add_word(self, word): 16 | if not word in self.word2idx: 17 | self.word2idx[word] = self.idx 18 | self.idx2word[self.idx] = word 19 | self.idx += 1 20 | 21 | 
def __call__(self, word): 22 | if not word in self.word2idx: 23 | return self.word2idx['[UNK]'] 24 | return self.word2idx[word] 25 | 26 | def ret_word(self, idx): 27 | if not idx in self.idx2word: 28 | return '[UNK]' 29 | return self.idx2word[idx] 30 | 31 | def __len__(self): 32 | return len(self.word2idx) 33 | 34 | 35 | def build_vocab(args): 36 | """Build a simple vocabulary wrapper.""" 37 | df = pd.read_csv(args.df_path, usecols=['tags_cat0']) 38 | counter = Counter() 39 | for i, tag in enumerate(df.tags_cat0): 40 | tokens = ast.literal_eval(tag) 41 | counter.update(tokens) 42 | if i % 10000 == 0: 43 | print('Progress: {}/{}'.format(i, len(df))) 44 | 45 | # If the word frequency is less than 'threshold', then the word is discarded. 46 | print('Number of unique tokens: ', len(counter)) 47 | print('Total number of tokens: ', sum(counter.values())) 48 | words = [word for word, cnt in counter.items() if cnt >= args.threshold] 49 | 50 | # Create a vocab wrapper and add some special tokens. 51 | vocab = Vocabulary() 52 | vocab.add_word('[PAD]') 53 | vocab.add_word('[UNUSED]') 54 | vocab.add_word('[CLS]') 55 | vocab.add_word('[SEP]') 56 | vocab.add_word('[UNK]') 57 | 58 | # Add the words to the vocabulary. 59 | for i, word in enumerate(words): 60 | vocab.add_word(word) 61 | print("Total vocabulary size: {}".format(len(vocab))) 62 | 63 | return vocab 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--df_path', type=str, required=True, 68 | help='Path for dataframe file with whole list of files and tags') 69 | parser.add_argument('--vocab_path', type=str, default='vocab.pkl') 70 | parser.add_argument('--threshold', type=int, default=2, 71 | help='Minimum tag count to put into final vocabulary') 72 | args = parser.parse_args() 73 | 74 | vocab = build_vocab(args) 75 | with open(args.vocab_path, 'wb') as f: 76 | pickle.dump(vocab, f) 77 | print('Saved vocabulary wrapper to ', args.vocab_path) 78 | 79 | if __name__ == '__main__': 80 | main() 81 | 82 | 83 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/calc_tokens_len.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import statistics 3 | import pandas as pd 4 | from transformers import BertTokenizer 5 | 6 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 7 | df = pd.read_csv("/edahome/pcslab/pcs05/edwin/data/Danbooru2018AnimeCharacterRecognitionDataset_Revamped/compressed/dafre_tags_symbolsremoved_minlen2_minapp2_profsremoved_filledempty.csv") 8 | len_tags_list = [] 9 | len_tags_tokenized_list = [] 10 | for i, tag in enumerate(df.tags_cat0): 11 | tag_list = ast.literal_eval(tag) 12 | tag_str = ' '.join(tag_list) 13 | tag_tokenized = tokenizer(tag_str)['input_ids'] 14 | 15 | #print(len(tag_list), len(tag_tokenized)) 16 | len_tags_list.append(len(tag_list)) 17 | len_tags_tokenized_list.append(len(tag_tokenized)) 18 | if i % 10000 == 0: 19 | print('{}/{}'.format(i, len(df))) 20 | 21 | print('Median, mean and standard deviation of tags in list form :', 22 | statistics.median(len_tags_list), statistics.mean(len_tags_list), statistics.stdev(len_tags_list)) 23 | print('Median, mean and standard deviation of tags after being tokenized using BERT default tokenizer: ', 24 | statistics.median(len_tags_tokenized_list), statistics.mean(len_tags_tokenized_list), statistics.stdev(len_tags_tokenized_list)) 25 | 
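For clarity, here is a small self-contained sketch of how the `Vocabulary` wrapper from `build_vocab.py` above behaves. The tag names and counts below are made up; the real script accumulates them from the `tags_cat0` column of the dataframe and then pickles the result to `vocab.pkl`.

```python
from collections import Counter

from build_vocab import Vocabulary  # the wrapper class defined above

# Hypothetical tag counts; build_vocab.py collects these from the csv.
counter = Counter({"long_hair": 3, "blue_eyes": 1})

vocab = Vocabulary()
for special in ("[PAD]", "[UNUSED]", "[CLS]", "[SEP]", "[UNK]"):
    vocab.add_word(special)
for word, cnt in counter.items():
    if cnt >= 2:  # --threshold: rarer tags are discarded
        vocab.add_word(word)

print(len(vocab))          # 6 (five special tokens + 'long_hair')
print(vocab("long_hair"))  # 5, the id assigned after the special tokens
print(vocab("blue_eyes"))  # 4, the id of '[UNK]' (this tag fell below the threshold)
print(vocab.ret_word(0))   # '[PAD]'
```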
-------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/custom_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from .build_vocab import Vocabulary 4 | 5 | class CustomTokenizer(object): 6 | def __init__(self, vocab_path, max_text_seq_len, ret_tensor=True): 7 | with open(vocab_path, 'rb') as f: 8 | self.vocab = pickle.load(f) 9 | self.vocab_size = len(self.vocab) 10 | self.max_text_seq_len = max_text_seq_len 11 | self.ret_tensor = ret_tensor 12 | 13 | def __call__(self, tag_list): 14 | no_tokens = len(tag_list) + 2 15 | diff = abs(self.max_text_seq_len - no_tokens) 16 | 17 | tokens = [] 18 | tokens.append(self.vocab('[CLS]')) 19 | 20 | if no_tokens > self.max_text_seq_len: 21 | tokens.extend([self.vocab(tag) for tag in tag_list[:self.max_text_seq_len-2]]) 22 | tokens.append(self.vocab('[SEP]')) 23 | elif no_tokens < self.max_text_seq_len: 24 | tokens.extend([self.vocab(tag) for tag in tag_list]) 25 | tokens.append(self.vocab('[SEP]')) 26 | tokens.extend([self.vocab('[PAD]') for _ in range(diff)]) 27 | else: 28 | tokens.extend([self.vocab(tag) for tag in tag_list]) 29 | tokens.append(self.vocab('[SEP]')) 30 | 31 | if self.ret_tensor: 32 | return torch.tensor([tokens], dtype=torch.int64) 33 | return tokens 34 | 35 | def decode(self, tokens_list): 36 | if self.ret_tensor: 37 | tokens_notensor = tokens_list.squeeze().tolist() 38 | tag_list = [self.vocab.ret_word(idx) for idx in tokens_notensor] 39 | return tag_list 40 | else: 41 | return [self.vocab.ret_word(idx) for idx in tokens_list] 42 | 43 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/data_selection_customize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | import random 4 | import pandas as pd 5 | from PIL import Image 6 | from PIL import ImageFile 7 | 8 | import torch 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | from transformers import BertTokenizer 12 | 13 | from .custom_tokenizer import CustomTokenizer 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def load_data(args, split, label_csv_name='labels.csv'): 19 | 20 | transform = None 21 | 22 | dataset = CustomDataset(args, 23 | split=split, 24 | label_csv_name=label_csv_name, 25 | transform=transform) 26 | 27 | dataset_loader = data.DataLoader(dataset, 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | num_workers=args.no_cpu_workers, 31 | drop_last=True) 32 | 33 | return dataset, dataset_loader 34 | 35 | 36 | def get_transform(split, image_size): 37 | if split == 'train': 38 | transform = transforms.Compose([ 39 | transforms.Resize((image_size + 32, image_size + 32)), 40 | transforms.RandomCrop((image_size, image_size)), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ColorJitter(brightness=0.1, 43 | contrast=0.1, 44 | saturation=0.1, 45 | hue=0.1), 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 48 | ]) 49 | else: 50 | transform = transforms.Compose([ 51 | transforms.Resize((image_size, image_size)), 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 54 | ]) 55 | return transform 56 | 57 | 58 | class CustomDataset(data.Dataset): 59 | def __init__(self, 60 | args, 61 | split='train', 62 | transform=None, 63 | 
label_csv_name='labels.csv'): 64 | super().__init__() 65 | self.root = os.path.abspath(args.dataset_path) 66 | self.image_size = args.image_size 67 | self.split = split 68 | self.transform = transform 69 | 70 | self.tokenizer_method = args.tokenizer 71 | self.max_text_seq_len = args.max_text_seq_len 72 | self.shuffle = args.shuffle_tokens 73 | self.label_csv = os.path.join(self.root, label_csv_name) 74 | 75 | if self.transform is None: 76 | self.transform = get_transform(split=split, 77 | image_size=self.image_size) 78 | 79 | if self.max_text_seq_len: 80 | if self.tokenizer_method == 'wp': 81 | self.tokenizer = BertTokenizer.from_pretrained( 82 | 'bert-base-uncased') 83 | elif self.tokenizer_method == 'tag': 84 | self.tokenizer = CustomTokenizer( 85 | vocab_path=os.path.join(args.dataset_path, 'vocab.pkl'), 86 | max_text_seq_len=args.max_text_seq_len) 87 | self.df = pd.read_csv(self.label_csv) 88 | else: 89 | self.df = pd.read_csv(self.label_csv) 90 | 91 | self.targets = self.df['class_id'].to_numpy() 92 | self.data = self.df['file_rel_path'].to_numpy() 93 | 94 | self.classes = pd.read_csv(os.path.join(self.root, 95 | 'classid_classname.csv')) 96 | self.num_classes = len(self.classes) 97 | 98 | def __getitem__(self, idx): 99 | 100 | if torch.is_tensor(idx): 101 | idx = idx.tolist() 102 | 103 | img_dir, target = self.data[idx], self.targets[idx] 104 | img_dir = os.path.join(self.root, 'data', img_dir) 105 | img = Image.open(img_dir) 106 | 107 | if self.transform: 108 | img = self.transform(img) 109 | 110 | if self.max_text_seq_len: 111 | caption = ast.literal_eval(self.df.iloc[idx].tags) 112 | if self.shuffle: 113 | random.shuffle(caption) 114 | if self.tokenizer_method == 'wp': 115 | caption = ' '.join(caption) # originally joined by '[SEP]' 116 | caption = self.tokenizer(caption, 117 | return_tensors='pt', 118 | padding='max_length', 119 | max_length=self.max_text_seq_len, 120 | truncation=True)['input_ids'] 121 | elif self.tokenizer_method == 'tag': 122 | caption = self.tokenizer(caption) 123 | return img, target, caption 124 | else: 125 | return img, target 126 | 127 | def __len__(self): 128 | return len(self.targets) 129 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/.gitignore: -------------------------------------------------------------------------------- 1 | # environments 2 | env/ 3 | ENV/ 4 | Env/ 5 | 6 | # data 7 | data/ 8 | 9 | # debugging files 10 | tests/paste.txt 11 | 12 | # jupyter notebook checkpoints 13 | .ipynb_checkpoints/ 14 | 15 | # pip 16 | loss_landscapes.egg-info/ 17 | 18 | # dist 19 | dist/ 20 | build/ 21 | 22 | # tests 23 | tests/ 24 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/MANIFEST.in: -------------------------------------------------------------------------------- 1 | Include the README 2 | include *.md 3 | 4 | # Include the license file 5 | include LICENSE.txt -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/img/loss-contour-3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyber-meow/anime_screenshot_pipeline/c9e3fb804c3847d136c2124a68c7af4b17ef3219/scripts_v1/classifier_training/utilities/loss-landscapes/img/loss-contour-3d.png -------------------------------------------------------------------------------- 
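Referring back to data_selection_customize.py above, a small sketch of what its get_transform pipeline does to a single image: the train split resizes to image_size + 32, random-crops back to image_size, applies flip and color jitter, then normalizes to roughly [-1, 1]. The image size, dummy image, and import path are assumptions made only for illustration.

from PIL import Image
from utilities.data_selection_customize import get_transform

transform = get_transform(split='train', image_size=224)
dummy = Image.new('RGB', (640, 360), color=(120, 180, 200))

tensor = transform(dummy)
print(tensor.shape)                              # torch.Size([3, 224, 224])
print(tensor.min().item(), tensor.max().item())  # values lie within [-1, 1]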
/scripts_v1/classifier_training/utilities/loss-landscapes/img/loss-contour.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyber-meow/anime_screenshot_pipeline/c9e3fb804c3847d136c2124a68c7af4b17ef3219/scripts_v1/classifier_training/utilities/loss-landscapes/img/loss-contour.png -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/img/loss-landscape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyber-meow/anime_screenshot_pipeline/c9e3fb804c3847d136c2124a68c7af4b17ef3219/scripts_v1/classifier_training/utilities/loss-landscapes/img/loss-landscape.png -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/__init__.py: -------------------------------------------------------------------------------- 1 | from loss_landscapes.main import point 2 | from loss_landscapes.main import linear_interpolation 3 | from loss_landscapes.main import random_line 4 | from loss_landscapes.main import planar_interpolation 5 | from loss_landscapes.main import random_plane 6 | from loss_landscapes.model_interface.model_wrapper import ModelWrapper, GeneralModelWrapper 7 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyber-meow/anime_screenshot_pipeline/c9e3fb804c3847d136c2124a68c7af4b17ef3219/scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/contrib/__init__.py -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/contrib/connecting_paths.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module exposes functions for loss landscape operations which are more complex than simply 3 | computing the loss at different points in parameter space. This includes things such as Kolsbjerg 4 | et al.'s Automated Nudged Elastic Band algorithm. 5 | """ 6 | 7 | 8 | import abc 9 | import copy 10 | import numpy as np 11 | from loss_landscapes.model_interface.model_interface import wrap_model 12 | 13 | 14 | class _ParametricCurve(abc.ABC): 15 | """ A _ParametricCurve is used in the Garipov path search algorithm. """ 16 | # todo 17 | 18 | 19 | class _PolygonChain(_ParametricCurve): 20 | """ A _ParametricCurve consisting of consecutive line segments. """ 21 | # todo 22 | pass 23 | 24 | 25 | class _BezierCurve(_ParametricCurve): 26 | """ 27 | A Bezier curve is a parametric curve defined by a set of control points, including 28 | a start point and an end-point. The order of the curve refers to the number of control 29 | points excluding the start point: for example, an order 1 (linear) Bezier curve is 30 | defined by 2 points, an order 2 (quadratic) Bezier curve is defined by 3 points, and 31 | so on. 32 | 33 | In this library, each point is a neural network model with a specific value assignment 34 | to the model parameters. 35 | """ 36 | def __init__(self, model_start, model_end, order=2): 37 | """ 38 | Define a Bezier curve between a start point and an end point. 
The order of the 39 | curve refers to the number of control points, excluding the start point. The default 40 | order of 1, for example, results in no further control points being added after 41 | the given start and end points. 42 | 43 | :param model_start: point defining start of curve 44 | :param model_end: point defining end of curve 45 | :param order: number of control points, excluding start point 46 | """ 47 | super().__init__() 48 | if order != 2: 49 | raise NotImplementedError('Currently only order 2 bezier curves are supported.') 50 | 51 | self.model_start_wrapper = wrap_model(copy.deepcopy(model_start)) 52 | self.model_end_wrapper = wrap_model(copy.deepcopy(model_end)) 53 | self.order = order 54 | self.control_points = [] 55 | 56 | # add intermediate control points 57 | if order > 1: 58 | start_parameters = self.model_start_wrapper.get_parameter_tensor() 59 | end_parameters = self.model_end_wrapper.get_parameter_tensor() 60 | direction = (end_parameters - start_parameters) / order 61 | 62 | for i in range(1, order): 63 | model_template_wrapper = copy.deepcopy(self.model_start_wrapper) 64 | model_template_wrapper.set_parameter_tensor(start_parameters + (direction * i)) 65 | self.control_points.append(model_template_wrapper) 66 | 67 | def fit(self): 68 | # todo 69 | raise NotImplementedError() 70 | 71 | 72 | def auto_neb() -> np.ndarray: 73 | """ Automatic Nudged Elastic Band algorithm, as used in https://arxiv.org/abs/1803.00885 """ 74 | # todo return list of points in parameter space to represent trajectory 75 | # todo figure out how to return points as coordinates in 2D 76 | raise NotImplementedError() 77 | 78 | 79 | def garipov_curve_search(model_a, model_b, curve_type='polygon_chain') -> np.ndarray: 80 | """ 81 | We refer by 'Garipov curve search' to the algorithm proposed by Garipov et al (2018) for 82 | finding low-loss paths between two arbitrary minima in a loss landscape. The core idea 83 | of the method is to define a parametric curve in the model's parameter space connecting 84 | one minima to the other, and then minimizing the expected loss along this curve by 85 | modifying its parameterization. For details, see https://arxiv.org/abs/1802.10026 86 | 87 | This is an alternative to the auto_neb algorithm. 88 | """ 89 | model_a_wrapper = wrap_model(model_a) 90 | model_b_wrapper = wrap_model(model_b) 91 | 92 | point_a = model_a_wrapper.get_parameter_tensor() 93 | point_b = model_b_wrapper.get_parameter_tensor() 94 | 95 | # todo 96 | if curve_type == 'polygon_chain': 97 | raise NotImplementedError('Not implemented yet.') 98 | elif curve_type == 'bezier_curve': 99 | raise NotImplementedError('Not implemented yet.') 100 | else: 101 | raise AttributeError('Curve type is not polygon_chain or bezier_curve.') 102 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/contrib/trajectories.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes and functions for tracking a model's optimization trajectory and computing 3 | a low-dimensional approximation of the trajectory. 4 | """ 5 | 6 | 7 | from abc import ABC, abstractmethod 8 | from datetime import datetime 9 | import numpy as np 10 | from loss_landscapes.model_interface.model_interface import wrap_model 11 | 12 | 13 | class TrajectoryTracker(ABC): 14 | """ 15 | A TrajectoryTracker facilitates tracking the optimization trajectory of a 16 | DL/RL model. 
Trajectory trackers provide facilities for storing model parameters 17 | as well as for retrieving and operating on stored parameters. 18 | """ 19 | 20 | @abstractmethod 21 | def __getitem__(self, timestep) -> np.ndarray: 22 | """ 23 | Returns the position of the model from the given training timestep as a numpy array. 24 | :param timestep: training step of parameters to retrieve 25 | :return: numpy array 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def get_item(self, timestep) -> np.ndarray: 31 | """ 32 | Returns the position of the model from the given training timestep as a numpy array. 33 | :param timestep: training step of parameters to retrieve 34 | :return: numpy array 35 | """ 36 | pass 37 | 38 | @abstractmethod 39 | def get_trajectory(self) -> list: 40 | """ 41 | Returns a reference to the currently stored trajectory. 42 | :return: numpy array 43 | """ 44 | pass 45 | 46 | @abstractmethod 47 | def save_position(self, model): 48 | """ 49 | Appends the current model parameterization to the stored training trajectory. 50 | :param model: model object with current state of interest 51 | :return: N/A 52 | """ 53 | pass 54 | 55 | 56 | class FullTrajectoryTracker(TrajectoryTracker): 57 | """ 58 | A FullTrajectoryTracker is a tracker which stores a history of points in the tracked 59 | model's original parameter space, and can be used to perform a variety of computations 60 | on the trajectory. The tracker spills data into storage rather than keeping everything 61 | in main memory. 62 | """ 63 | def __init__(self, model, agent_interface=None, directory='./', experiment_name=None): 64 | super().__init__() 65 | self.dir = directory + (experiment_name if experiment_name is not None else str(datetime.now()) + '/') 66 | self.next_idx = 0 67 | self.save_position(model) 68 | self.agent_interface = agent_interface 69 | 70 | def __getitem__(self, timestep) -> np.ndarray: 71 | if not (1 <= timestep < self.next_idx): 72 | raise IndexError('Given timestep does not exist.') 73 | return np.load(self.dir + str(timestep) + '.npy') 74 | 75 | def get_item(self, timestep) -> np.ndarray: 76 | return self.__getitem__(timestep) 77 | 78 | def save_position(self, model): 79 | np.save(self.dir + str(self.next_idx) + '.npy', wrap_model(model, self.agent_interface).get_parameter_tensor(deepcopy=True).as_numpy()) 80 | self.next_idx += 1 81 | 82 | def get_trajectory(self) -> list: 83 | """ 84 | WARNING: be aware that full trajectory tracking requires N * M memory, where N is the 85 | number of iterations tracked and M is the size of the model. The amount of memory used 86 | by the trajectory tracker can easily become very large. 87 | :return: list of numpy arrays 88 | """ 89 | return [self[idx] for idx in range(self.next_idx)] 90 | 91 | 92 | class ProjectingTrajectoryTracker(TrajectoryTracker): 93 | """ 94 | A ProjectingTrajectoryTracker is a tracker which applies dimensionality reduction to 95 | all model parameterizations upon storage. This is particularly appropriate for large 96 | models, where storing a history of points in the model's parameter space would be 97 | unfeasible in terms of memory. 
98 | """ 99 | def __init__(self, model, agent_interface=None, n_bases=2): 100 | super().__init__() 101 | self.trajectory = [] 102 | self.agent_interface = agent_interface 103 | 104 | n = wrap_model(model, agent_interface).get_parameter_tensor().numel() 105 | self.A = np.column_stack( 106 | [np.random.normal(size=n) for _ in range(n_bases)] 107 | ) 108 | 109 | def __getitem__(self, timestep) -> np.ndarray: 110 | return self.trajectory[timestep] 111 | 112 | def get_item(self, timestep) -> np.ndarray: 113 | return self.__getitem__(timestep) 114 | 115 | def get_trajectory(self) -> list: 116 | return self.trajectory 117 | 118 | def save_position(self, model): 119 | # we solve the equation Ax = b using least squares, where A is the matrix of basis vectors 120 | b = wrap_model(model, self.agent_interface).get_parameter_tensor().as_numpy() 121 | self.trajectory.append(np.linalg.lstsq(self.A, b, rcond=None)[0]) 122 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from loss_landscapes.metrics.metric import Metric, MetricPipeline 2 | from loss_landscapes.metrics.rl_metrics import ExpectedReturnMetric 3 | from loss_landscapes.metrics.sl_metrics import Loss, LossGradient, LossPerturbations 4 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/metrics/metric.py: -------------------------------------------------------------------------------- 1 | """ Base classes for model evaluation metrics. """ 2 | 3 | from abc import ABC, abstractmethod 4 | from loss_landscapes.model_interface.model_wrapper import ModelWrapper 5 | 6 | 7 | class Metric(ABC): 8 | """ A quantity that can be computed given a model or an agent. """ 9 | 10 | def __init__(self): 11 | super().__init__() 12 | 13 | @abstractmethod 14 | def __call__(self, model_wrapper: ModelWrapper): 15 | pass 16 | 17 | 18 | class MetricPipeline(Metric): 19 | """ A sequence of metrics to be computed in order, given a model or an agent. 
""" 20 | 21 | def __init__(self, metrics: list): 22 | super().__init__() 23 | self.metrics = metrics 24 | 25 | def __call__(self, model_wrapper: ModelWrapper) -> tuple: 26 | return tuple([metric(model_wrapper) for metric in self.metrics]) 27 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/metrics/rl_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd 3 | from loss_landscapes.metrics.metric import Metric 4 | 5 | 6 | class ExpectedReturnMetric(Metric): 7 | def __init__(self, gym_environment, n_episodes): 8 | super().__init__() 9 | self.gym_environment = gym_environment 10 | self.n_episodes = n_episodes 11 | 12 | def __call__(self, agent): 13 | returns = [] 14 | 15 | # compute total return for each episode 16 | for episode in range(self.n_episodes): 17 | episode_return = 0 18 | obs, reward, done, _ = self.gym_environment.step( 19 | agent(torch.from_numpy(self.gym_environment.reset()).float()) 20 | ) 21 | episode_return += reward 22 | 23 | while not done: 24 | obs, reward, done, info = self.gym_environment.step( 25 | agent(torch.from_numpy(obs).float()) 26 | ) 27 | episode_return += reward 28 | returns.append(episode_return) 29 | 30 | # return average of episode returns 31 | return sum(returns) / len(returns) 32 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/model_interface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyber-meow/anime_screenshot_pipeline/c9e3fb804c3847d136c2124a68c7af4b17ef3219/scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/model_interface/__init__.py -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/loss_landscapes/model_interface/model_wrapper.py: -------------------------------------------------------------------------------- 1 | """ Class used to define interface to complex models """ 2 | 3 | import abc 4 | import itertools 5 | import torch.nn 6 | from loss_landscapes.model_interface.model_parameters import ModelParameters 7 | 8 | 9 | class ModelWrapper(abc.ABC): 10 | def __init__(self, modules: list): 11 | self.modules = modules 12 | 13 | def get_modules(self) -> list: 14 | return self.modules 15 | 16 | def get_module_parameters(self) -> ModelParameters: 17 | return ModelParameters([p for module in self.modules for p in module.parameters()]) 18 | 19 | def train(self, mode=True) -> 'ModelWrapper': 20 | for module in self.modules: 21 | module.train(mode) 22 | return self 23 | 24 | def eval(self) -> 'ModelWrapper': 25 | return self.train(False) 26 | 27 | def requires_grad_(self, requires_grad=True) -> 'ModelWrapper': 28 | for module in self.modules: 29 | for p in module.parameters(): 30 | p.requires_grad = requires_grad 31 | return self 32 | 33 | def zero_grad(self) -> 'ModelWrapper': 34 | for module in self.modules: 35 | for p in module.parameters(): 36 | if p.grad is not None: 37 | p.grad.detach_() 38 | p.grad.zero_() 39 | return self 40 | 41 | def parameters(self): 42 | return itertools.chain([module.parameters() for module in self.modules]) 43 | 44 | def named_parameters(self): 45 | return itertools.chain([module.named_parameters() for module in self.modules]) 46 | 47 | 
@abc.abstractmethod 48 | def forward(self, x): 49 | pass 50 | 51 | 52 | class SimpleModelWrapper(ModelWrapper): 53 | def __init__(self, model: torch.nn.Module): 54 | super().__init__([model]) 55 | 56 | def forward(self, x): 57 | return self.modules[0](x) 58 | 59 | 60 | class GeneralModelWrapper(ModelWrapper): 61 | def __init__(self, model, modules: list, forward_fn): 62 | super().__init__(modules) 63 | self.model = model 64 | self.forward_fn = forward_fn 65 | 66 | def forward(self, x): 67 | return self.forward_fn(self.model, x) 68 | 69 | 70 | def wrap_model(model): 71 | if isinstance(model, ModelWrapper): 72 | return model.requires_grad_(False) 73 | elif isinstance(model, torch.nn.Module): 74 | return SimpleModelWrapper(model).requires_grad_(False) 75 | else: 76 | raise ValueError('Only models of type torch.nn.modules.module.Module can be passed without a wrapper.') 77 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | tqdm 4 | torch 5 | torchvision -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file(s) in the wheel. 3 | license_files = LICENSE.txt 4 | 5 | [bdist_wheel] 6 | 7 | 8 | # support. Removing this line (or setting universal to 0) will prevent 9 | 10 | # bdist_wheel from trying to make a universal wheel. For more see: 11 | 12 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#wheels 13 | 14 | universal=0 -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/loss-landscapes/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from os import path 3 | 4 | # Get the long description from the README file 5 | with open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8') as f: 6 | long_description = f.read() 7 | 8 | setup( 9 | name='loss_landscapes', 10 | version='3.0.7', 11 | packages=find_packages(exclude='tests'), 12 | url='https://github.com/marcellodebernardi/loss-landscapes', 13 | license='MIT', 14 | author='Marcello De Bernardi', 15 | author_email='marcello.debernardi@stcatz.ox.ac.uk', 16 | description='A library for approximating loss landscapes in low-dimensional parameter subspaces', 17 | long_description=long_description, 18 | long_description_content_type='text/markdown', 19 | python_requires='>=3.5', 20 | install_requires=['numpy'], 21 | classifiers=[ 22 | 'Development Status :: 4 - Beta', 23 | 'Intended Audience :: Developers', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | 'License :: OSI Approved :: MIT License', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | 'Programming Language :: Python :: 3.7', 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/matcher.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import linear_sum_assignment 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 
#target_classes[idx] = target_classes_o 7 | 8 | def _get_src_permutation_idx(self, indices): 9 | # permute predictions following indices 10 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 11 | src_idx = torch.cat([src for (src, _) in indices]) 12 | return batch_idx, src_idx 13 | 14 | class HungarianMatcher(nn.Module): 15 | def __init__(self, cost=1.0): 16 | super().__init__() 17 | self.cost = cost 18 | 19 | @torch.no_grad() 20 | def forward(self, outputs, targets): 21 | bs, num_queries = outputs.shape[:2] 22 | 23 | out_prob = outputs.flatten(0, 1).softmax(-1) 24 | tgt_ids = targets.flatten() 25 | 26 | cost_class = -out_prob[:, tgt_ids] 27 | 28 | cost_matrix = self.cost * cost_class 29 | cost_matrix = cost_matrix.view(bs, num_queries, -1).cpu() 30 | 31 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(num_queries, -1))] 32 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 33 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/plot_mask_schedules.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | from scheduler import MasksSchedule 7 | 8 | plt.rcParams['font.family'] = 'serif' 9 | plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif'] 10 | 11 | try: 12 | schedule = sys.argv[1] 13 | except: 14 | schedule = 'sigmoid' 15 | 16 | batch_size = 16 17 | no_epochs = 100 18 | steps_per_epoch = 20000 19 | total_steps = no_epochs * steps_per_epoch 20 | 21 | cdwu_percent = 0.1 22 | cdwu_steps = int(total_steps*cdwu_percent) 23 | 24 | max_text_seq_len = 16 25 | 26 | # device=device, mask_schedule=args.mask_schedule, 27 | # masking_behavior=args.masking_behavior, tokenizer=args.tokenizer, vocab_size=args.vocab_size, 28 | # batch_size=args.batch_size, max_text_seq_len=args.max_text_seq_len, 29 | # warmup_steps=mask_wu_steps, cooldown_steps=mask_cd_steps, total_steps=total_steps, cycles=.5 30 | 31 | mask_scheduler = MasksSchedule(torch.device('cpu'), schedule, 32 | 'constant', 'wp', 30522, batch_size, max_text_seq_len, cdwu_steps, cdwu_steps, total_steps) 33 | 34 | sample_captions = torch.tensor([[101, 1015, 2611, 2630, 2159, 2829, 2606, 1015, 1015, 2611, 5967, 9427, 2849, 10557, 2159, 102]], dtype=torch.int64) 35 | sample_captions = sample_captions.repeat(batch_size, 1) 36 | 37 | masked_percent_list = [] 38 | 39 | for step in range(total_steps): 40 | 41 | captions_updated, labels_text = mask_scheduler.ret_mask([step], sample_captions) 42 | 43 | text_len = captions_updated.shape[0] * captions_updated.shape[1] 44 | masked_text_len = torch.where(captions_updated==1, 1, 0).sum().item() 45 | masked_percent = ( masked_text_len / text_len) * 100 46 | masked_percent_list.append(masked_percent) 47 | 48 | if step % 5000 == 0: 49 | print(step/total_steps, text_len, masked_text_len, masked_percent) 50 | # print(step, sample_captions, captions_updated, labels_text) 51 | 52 | print(len(masked_percent_list)) 53 | plt.plot(np.arange(total_steps), masked_percent_list) 54 | plt.ylim([-1, 101]) 55 | plt.xlabel('Global step') 56 | plt.ylabel('Tokens (text) masked (%)') 57 | plt.title('Percentage of tokens masked as function of training progress') 58 | plt.show() 59 | -------------------------------------------------------------------------------- 
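Looking back at utilities/matcher.py above, this is a rough usage sketch of its simplified HungarianMatcher. Because the cost matrix is split into chunks of num_queries, the sketch assumes one target class id per query (targets shaped batch x queries); the batch size, query count, class count, and import path are illustrative assumptions rather than values used anywhere in the repository.

import torch
from utilities.matcher import HungarianMatcher

matcher = HungarianMatcher(cost=1.0)

# 2 samples, 5 predicted queries, 10 classes; one target class id per query slot.
outputs = torch.randn(2, 5, 10)
targets = torch.randint(0, 10, (2, 5))

indices = matcher(outputs, targets)  # list of (prediction_idx, target_idx) tensors per sample
for i, (pred_idx, tgt_idx) in enumerate(indices):
    print(f'sample {i}: predictions {pred_idx.tolist()} -> targets {tgt_idx.tolist()}')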
/scripts_v1/classifier_training/utilities/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | 5 | import torch 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | class WarmupCosineSchedule(LambdaLR): 9 | """ Linear warmup and then cosine decay. 10 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 11 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 12 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 13 | """ 14 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 15 | self.warmup_steps = warmup_steps 16 | self.t_total = t_total 17 | self.cycles = cycles 18 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 19 | 20 | def lr_lambda(self, step): 21 | if step < self.warmup_steps: 22 | return float(step) / float(max(1.0, self.warmup_steps)) 23 | # progress after warmup 24 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 25 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 26 | 27 | 28 | class MasksSchedule(): 29 | 30 | def __init__(self, device, mask_schedule, masking_behavior, 31 | tokenizer, vocab_size, batch_size, max_text_seq_len, 32 | warmup_steps, cooldown_steps, total_steps, cycles=.5): 33 | 34 | self.device = device 35 | 36 | self.mask_schedule = mask_schedule 37 | self.masking_behavior = masking_behavior 38 | self.tokenizer = tokenizer 39 | self.vocab_size = vocab_size 40 | 41 | self.batch_size = batch_size 42 | self.max_text_seq_len = max_text_seq_len 43 | 44 | self.warmup_steps = warmup_steps 45 | self.cooldown_steps = cooldown_steps 46 | self.total_steps = total_steps 47 | self.cycles = cycles 48 | 49 | if self.tokenizer == 'wp': 50 | # 0 is [PAD], 101 is [CLS], 102 is [SEP] 51 | self.special_tokens = [0, 101, 102] 52 | elif self.tokenizer == 'tag': 53 | # 0 is [PAD], 2 is [CLS], 3 is [SEP] 54 | self.special_tokens = [0, 2, 3] 55 | 56 | def ret_mask(self, step, tokens_text=None): 57 | step = step[0] 58 | 59 | if self.mask_schedule == None: 60 | return None, None 61 | 62 | elif self.mask_schedule == 'bert': 63 | # 15 % masking like bert but only mask (0) or 1 64 | masks = torch.from_numpy(np.random.choice(a=[0, 1], size=(self.batch_size, self.max_text_seq_len), 65 | p=[0.15, 0.85])).to(self.device) 66 | 67 | elif self.mask_schedule == 'full': 68 | # from beginning all masks equal to 1 69 | masks = torch.from_numpy(np.random.choice(a=[0, 1], size=(self.batch_size, self.max_text_seq_len), 70 | p=[1, 0])).to(self.device) 71 | 72 | elif self.mask_schedule == 'sigmoid': 73 | # during warmup attend to all tokens 74 | # during cooldown attend to no text tokens 75 | # else attend to a percentage of text tokens following cosine function 76 | 77 | if step < self.warmup_steps: 78 | masks = torch.from_numpy(np.random.choice(a=[0, 1], size=(self.batch_size, self.max_text_seq_len), 79 | p=[0, 1])).to(self.device) 80 | 81 | elif step > (self.total_steps - self.cooldown_steps): 82 | masks = torch.from_numpy(np.random.choice(a=[0, 1], size=(self.batch_size, self.max_text_seq_len), 83 | p=[1, 0])).to(self.device) 84 | 85 | else: 86 | progress = (float(step - self.warmup_steps) / 87 | (float(max(1, self.total_steps - self.warmup_steps - self.cooldown_steps)))) 88 | 89 | prob_visible = 
max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 90 | prob_mask = 1.0 - prob_visible 91 | 92 | masks = torch.from_numpy(np.random.choice(a=[0, 1], size=(self.batch_size, self.max_text_seq_len), 93 | p=[prob_mask, prob_visible])).to(self.device) 94 | 95 | if self.masking_behavior == 'constant': 96 | # if mask then change token to 1 (unused token) 97 | updated_numbers = torch.ones(self.batch_size, self.max_text_seq_len, dtype=torch.int64).to(self.device) 98 | elif self.masking_behavior == 'random': 99 | updated_numbers = torch.randint(0, self.vocab_size-1, (self.batch_size, self.max_text_seq_len)).to(self.device) 100 | 101 | tokens_text_updated = torch.where( 102 | (masks==0) & (tokens_text!=self.special_tokens[0]) & (tokens_text!=self.special_tokens[1]) & (tokens_text!=self.special_tokens[2]), 103 | updated_numbers, tokens_text) 104 | labels_text = torch.where( 105 | (masks==1) | (tokens_text==self.special_tokens[0]) | (tokens_text==self.special_tokens[1]) | (tokens_text==self.special_tokens[2]), 106 | -100, tokens_text) 107 | 108 | return tokens_text_updated, labels_text 109 | -------------------------------------------------------------------------------- /scripts_v1/classifier_training/utilities/video_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import cv2 4 | 5 | def video_transform(args): 6 | vid_in = cv2.VideoCapture(args.video_path) 7 | 8 | vid_out_name = os.path.splitext(args.video_path)[0]+'_fps={}_size={}.mp4'.format(args.fps, args.vid_size) 9 | vid_out = cv2.VideoWriter(vid_out_name, cv2.VideoWriter_fourcc(*'mp4v'), args.fps, (args.vid_size, args.vid_size)) 10 | 11 | while(vid_in.isOpened()): 12 | ret, frame = vid_in.read() 13 | if ret: 14 | img = cv2.resize(frame, (args.vid_size, args.vid_size)) 15 | vid_out.write(img) 16 | else: 17 | break 18 | print('Finished saving video: ', vid_out_name) 19 | 20 | vid_in.release() 21 | vid_out.release() 22 | cv2.destroyAllWindows() 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--video_path", required=True, 27 | help="Path for the video.") 28 | parser.add_argument("--fps", default=4, type=int, 29 | help="Number of frames per second.") 30 | parser.add_argument("--vid_size", default=640, type=int, 31 | help="Height and width of frames.") 32 | args = parser.parse_args() 33 | 34 | video_transform(args) 35 | 36 | if __name__ == '__main__': 37 | main() -------------------------------------------------------------------------------- /scripts_v1/classifier_training/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyber-meow/anime_screenshot_pipeline/c9e3fb804c3847d136c2124a68c7af4b17ef3219/scripts_v1/classifier_training/vocab.pkl -------------------------------------------------------------------------------- /scripts_v1/correct_metadata_from_foldername.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | 8 | 9 | def get_files_recursively(folder_path): 10 | allowed_patterns = [ 11 | '*.[Pp][Nn][Gg]', '*.[Jj][Pp][Gg]', '*.[Jj][Pp][Ee][Gg]', 12 | '*.[Gg][Ii][Ff]', '*.[Ww][Ee][Bb][Pp]', 13 | ] 14 | 15 | image_path_list = [ 16 | str(path) for pattern in allowed_patterns 17 | for path in Path(folder_path).rglob(pattern) 18 | ] 19 | 20 | return image_path_list 21 | 22 | 23 | def 
correct_metadata(path, path_format, character_list, use_subject_file): 24 | json_file = os.path.splitext(path)[0] + '.json' 25 | if not os.path.exists(json_file): 26 | print(f'Warning: {json_file} unfound, skip') 27 | return 28 | with open(json_file, 'r') as f: 29 | metadata = json.load(f) 30 | to_correct = path_format.split('/') 31 | dirname = os.path.dirname(path) 32 | for folder_type in reversed(to_correct): 33 | dirname, basename = os.path.split(dirname) 34 | correct_metadata_single(metadata, folder_type, 35 | basename, character_list) 36 | with open(json_file, 'w') as f: 37 | json.dump(metadata, f) 38 | 39 | subject_file = os.path.splitext(path)[0] + '.subjects' 40 | 41 | if use_subject_file and os.path.exists(subject_file): 42 | with open(subject_file, 'r') as f: 43 | characters_from_subject = list( 44 | map(lambda x: x.strip(), f.read().strip().split(';'))) 45 | if characters_from_subject == [""]: 46 | characters_from_subject = [] 47 | with open(json_file, 'r') as f: 48 | metadata = json.load(f) 49 | metadata['character'] = characters_from_subject 50 | with open(json_file, 'w') as f: 51 | json.dump(metadata, f) 52 | 53 | 54 | def correct_metadata_single( 55 | metadata, folder_type, basename, character_list=None): 56 | if folder_type == '*': 57 | return 58 | elif folder_type == 'n_faces': 59 | if basename == '1face': 60 | metadata['n_faces'] = 1 61 | else: 62 | count = basename.rstrip('faces') 63 | if count.isnumeric(): 64 | count = int(count) 65 | metadata['n_faces'] = count 66 | elif folder_type == 'n_people': 67 | if basename == '1person': 68 | metadata['n_people'] = 1 69 | else: 70 | count = basename.rstrip('people') 71 | if count.isnumeric(): 72 | count = int(count) 73 | metadata['n_people'] = count 74 | elif folder_type == 'character': 75 | if basename in ['character_others', 'others']: 76 | return 77 | if basename == 'ood': 78 | characters = [] 79 | else: 80 | characters = sorted(list(set(basename.split('+')))) 81 | for to_remove in ['unknown', 'ood']: 82 | if to_remove in characters: 83 | characters.remove(to_remove) 84 | if 'character' in metadata: 85 | characters_in_meta = sorted( 86 | list(set(metadata['character']))) 87 | for to_remove in ['unknown', 'ood']: 88 | if to_remove in characters_in_meta: 89 | characters_in_meta.remove(to_remove) 90 | # No need for correction if metadata agrees with folder name 91 | # Use the original metadata preserve order that correspond to 92 | # that of facepos 93 | if characters == characters_in_meta: 94 | return 95 | if character_list is not None: 96 | for character in characters: 97 | assert character in character_list, \ 98 | f'Invalid character {character} for {basename}' 99 | metadata['character'] = characters 100 | else: 101 | print(f'Warning: invalid folder type {folder_type}') 102 | 103 | 104 | if __name__ == '__main__': 105 | 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument( 108 | '--src_dir', type=str, 109 | help='Directory to load images') 110 | parser.add_argument( 111 | '--format', type=str, default='*/character', 112 | help='Description of the output directory hierarchy' 113 | ) 114 | parser.add_argument( 115 | "--character_list", type=str, default=None, 116 | help="Txt file containing character names separated " 117 | + "by comma or new line") 118 | parser.add_argument( 119 | "--use_subject_file", action="store_true", 120 | help="Use the subject file (if available) to replace character names in the json file") 121 | args = parser.parse_args() 122 | 123 | if args.character_list is not None: 124 | with 
open(args.character_list, 'r') as f: 125 | lines = f.readlines() 126 | character_list = [] 127 | for line in lines: 128 | character_list.extend(line.strip().split(',')) 129 | print(character_list) 130 | else: 131 | character_list = None 132 | 133 | paths = get_files_recursively(args.src_dir) 134 | 135 | for path in tqdm(paths): 136 | correct_metadata(path, args.format, character_list, 137 | args.use_subject_file) 138 | -------------------------------------------------------------------------------- /scripts_v1/danbooru_tag_tree/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /scripts_v1/danbooru_tag_tree/parse_page.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup as Soup 3 | 4 | 5 | def query(d, l): 6 | for (i, _) in l: 7 | d = d[i] 8 | return d 9 | 10 | def clean_title(x: str): 11 | return x.replace('tag group:', '').replace('Tag group:', '').replace('List of ', '') 12 | 13 | 14 | def parse_page(url): 15 | try: 16 | html = requests.get(f'https://danbooru.donmai.us{url}').text 17 | soup = Soup(html) 18 | 19 | content = soup.find(id='wiki-page-body') 20 | if content is None: 21 | print(url, 'None') 22 | return None 23 | elements = list(content.children) 24 | 25 | group_dict = {} 26 | now_stack = [] 27 | before = '' 28 | for i in elements: 29 | match i.name: 30 | case 'h3'|'h4'|'h5'|'h6' as level: 31 | print(level) 32 | title = clean_title(i.text) 33 | if title == 'See also': break 34 | 35 | pop = 0 36 | while (before>=level or before=='ul') and now_stack: 37 | prev_title, before = now_stack.pop() 38 | print(prev_title, before) 39 | pop += 1 40 | if before=level or before=='ul') and now_stack: 32 | prev_title, before = now_stack.pop() 33 | print(prev_title, before) 34 | pop += 1 35 | if before None: 16 | self.name = name 17 | self.is_tag = is_tag 18 | self.parent = parent 19 | self.childs = childs 20 | 21 | def __init__(self) -> None: 22 | self.node_table: dict[str, TagTree.Node] = {} 23 | self.root = self.Node('root', False, None, []) 24 | 25 | def build_from_json(self, json_file): 26 | with open(json_file, 'r', encoding='utf-8') as f: 27 | data = json.load(f) 28 | self.build_from_dict(data) 29 | 30 | def _build_from_dict(self, root, key, data): 31 | if data is None: 32 | return None 33 | if isinstance(data, str): 34 | return TagTree.Node(key, True, root, {}) 35 | self_data = data.get('self', None) 36 | 37 | is_tag = isinstance(self_data, str) 38 | if is_tag: 39 | data.pop('self') 40 | new_node = TagTree.Node( 41 | key, is_tag, root 42 | ) 43 | 44 | all_childs = {} 45 | # if isinstance(self_data, dict): 46 | # self_node = TagTree.Node(f'{root.name}-self', 'self') 47 | # all_childs[key] = self._build_from_dict(root, k, v) 48 | for k, v in data.items(): 49 | if v is None: continue 50 | child = self._build_from_dict(new_node, k, v) 51 | self.node_table[k] = 
self.node_table.get(k, []) + [child] 52 | all_childs[k] = child 53 | 54 | new_node.childs = all_childs 55 | return new_node 56 | 57 | def build_from_dict(self, data: dict): 58 | all_childs = {} 59 | for k, v in data.items(): 60 | if v is None: continue 61 | child = self._build_from_dict(self.root, k, v) 62 | self.node_table[k] = self.node_table.get(k, []) + [child] 63 | all_childs[k] = child 64 | 65 | self.root.childs = all_childs 66 | 67 | def find_nodes(self, query: list[str], reverse_query=False): 68 | query = list(query) 69 | if reverse_query: 70 | query.reverse() 71 | query_root = query.pop(0) 72 | if query_root not in self.node_table: 73 | raise ValueError('Tag/Groups not Found !') 74 | target_node = self.node_table[query_root] 75 | if len(target_node)>1: 76 | raise ValueError( 77 | 'Have multiple groups with same name, ' 78 | 'please give some parent group for querying' 79 | ) 80 | 81 | target_node = target_node[0] 82 | for i in query: 83 | if i not in target_node.childs: 84 | raise ValueError('Tag/Groups not Found !') 85 | target_node = target_node.childs[i] 86 | return target_node 87 | 88 | def get_groups(self, query, reverse_query=False): 89 | if isinstance(query, str): 90 | query = [query] 91 | 92 | target_node = self.find_nodes(query, reverse_query) 93 | all_groups = [target_node.name] 94 | while target_node.parent is not None: 95 | target_node = target_node.parent 96 | all_groups.append(target_node.name) 97 | 98 | return all_groups 99 | 100 | def _get_tags(self, node: 'TagTree.Node'): 101 | res = [] 102 | if node.is_tag: 103 | res = [node.name] 104 | for i in node.childs.values(): 105 | res += self._get_tags(i) 106 | return res 107 | 108 | def get_tags(self, query, reverse_query=False): 109 | if isinstance(query, str): 110 | query = [query] 111 | target_node = self.find_nodes(query, reverse_query) 112 | all_tag = self._get_tags(target_node) 113 | return all_tag 114 | 115 | 116 | # tree = TagTree() 117 | # tree.build_from_json('./tag_tree.json') 118 | # 119 | # tag = ['Attire'] 120 | # print(f'query tag/groups: {tag}') 121 | # print(tree.get_tags(tag, reverse_query=True)) 122 | # print(tree.get_groups(tag)) 123 | -------------------------------------------------------------------------------- /scripts_v1/detect_faces.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | import json 5 | 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | from anime_face_detector import create_detector 11 | 12 | 13 | def get_files_recursively(folder_path): 14 | allowed_patterns = [ 15 | '*.[Pp][Nn][Gg]', '*.[Jj][Pp][Gg]', '*.[Jj][Pp][Ee][Gg]', 16 | '*.[Gg][Ii][Ff]', '*.[Ww][Ee][Bb][Pp]', 17 | ] 18 | 19 | image_path_list = [ 20 | str(path) for pattern in allowed_patterns 21 | for path in Path(folder_path).rglob(pattern) 22 | ] 23 | 24 | return image_path_list 25 | 26 | 27 | def detect_faces(detector, 28 | image, 29 | score_thres=0.75, 30 | ratio_thres=2, 31 | debug=False): 32 | preds = detector(image) # bgr 33 | h, w = image.shape[:2] 34 | facedata = { 35 | 'n_faces': 0, 36 | 'facepos': [], 37 | 'fh_ratio': 0, 38 | 'cropped': False, 39 | } 40 | 41 | for pred in preds: 42 | bb = pred['bbox'] 43 | score = bb[-1] 44 | left, top, right, bottom = [int(pos) for pos in bb[:4]] 45 | fw, fh = right - left, bottom - top 46 | # ignore the face if too far from square or too low score 47 | if (fw / fh > ratio_thres or 48 | fh / fw > ratio_thres or score < score_thres): 49 | continue 50 | 
facedata['n_faces'] = facedata['n_faces'] + 1 51 | left_rel = left / w 52 | top_rel = top / h 53 | right_rel = right / w 54 | bottom_rel = bottom / h 55 | facedata['facepos'].append( 56 | [left_rel, top_rel, right_rel, bottom_rel]) 57 | if fh / h > facedata['fh_ratio']: 58 | facedata['fh_ratio'] = fh / h 59 | if debug: 60 | cv2.rectangle(image, (left, top), (right, bottom), (255, 0, 255), 61 | 4) 62 | 63 | return facedata 64 | 65 | 66 | def main(args): 67 | 68 | print("loading face detector.") 69 | detector = create_detector('yolov3') 70 | 71 | print("processing.") 72 | 73 | paths = get_files_recursively(args.src_dir) 74 | 75 | for path in tqdm(paths): 76 | # print(path) 77 | filename_noext = os.path.splitext(path)[0] 78 | 79 | try: 80 | image = cv2.imdecode( 81 | np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) 82 | except cv2.error as e: 83 | print(f'Error reading the image {path}: {e}') 84 | continue 85 | if image is None: 86 | print(f'Error reading the image {path}: get None') 87 | continue 88 | if len(image.shape) == 2: 89 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 90 | if image.shape[2] == 4: 91 | # print(f"image has alpha. ignore: {path}") 92 | image = image[:, :, :3].copy() 93 | 94 | h, w = image.shape[:2] 95 | 96 | facedata = detect_faces(detector, 97 | image, 98 | score_thres=args.score_thres, 99 | ratio_thres=args.ratio_thres, 100 | debug=args.debug) 101 | 102 | json_file = f"{filename_noext}.json" 103 | if os.path.exists(json_file): 104 | with open(json_file, "r") as f: 105 | metadata = json.load(f) | facedata 106 | else: 107 | metadata = facedata 108 | 109 | with open(json_file, "w") as f: 110 | json.dump(metadata, f) 111 | 112 | 113 | if __name__ == '__main__': 114 | 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument( 117 | "--src_dir", type=str, 118 | help="Directory to load images") 119 | parser.add_argument( 120 | "--score_thres", 121 | type=float, 122 | default=0.75, 123 | help="Score threshold above which is counted as face") 124 | parser.add_argument( 125 | "--ratio_thres", 126 | type=float, 127 | default=2, 128 | help="Ratio threshold below which is counted as face") 129 | parser.add_argument( 130 | "--debug", 131 | action="store_true", 132 | help="Render rect for face") 133 | args = parser.parse_args() 134 | 135 | main(args) 136 | -------------------------------------------------------------------------------- /scripts_v1/extract_frames.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from anime2sd import extract_and_remove_similar 4 | 5 | 6 | if __name__ == "__main__": 7 | # Parse command line arguments 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--src_dir", default='.', 10 | help="directory containing source files") 11 | parser.add_argument("--dst_dir", default='.', 12 | help="directory to save output files") 13 | parser.add_argument("--prefix", default='', help="output file prefix") 14 | parser.add_argument("--ep_init", 15 | type=int, 16 | default=1, 17 | help="episode number to start with") 18 | parser.add_argument( 19 | "--similar_thresh", 20 | type=float, 21 | default=0.985, 22 | help="cosine similarity threshold for image duplicate detection") 23 | parser.add_argument("--no-remove-similar", 24 | action="store_true", 25 | help="flag to not remove similar images") 26 | args = parser.parse_args() 27 | 28 | # Process the files 29 | extract_and_remove_similar(args.src_dir, args.dst_dir, args.prefix, 30 | args.ep_init, thresh=args.similar_thresh, 31 | 
to_remove_similar=not args.no_remove_similar) 32 | -------------------------------------------------------------------------------- /scripts_v1/rename_character.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import csv 4 | import shutil 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | 8 | 9 | def get_files_recursively(folder_path): 10 | allowed_patterns = [ 11 | '*.[Pp][Nn][Gg]', '*.[Jj][Pp][Gg]', '*.[Jj][Pp][Ee][Gg]', 12 | ] 13 | 14 | image_path_list = [ 15 | str(path) for pattern in allowed_patterns 16 | for path in Path(folder_path).rglob(pattern) 17 | ] 18 | 19 | return image_path_list 20 | 21 | 22 | def read_class_mapping(class_mapping_csv): 23 | class_mapping = {} 24 | with open(class_mapping_csv, 'r') as f: 25 | reader = csv.reader(f) 26 | for row in reader: 27 | old_class, new_class = row 28 | class_mapping[old_class] = new_class 29 | return class_mapping 30 | 31 | 32 | def rename_folder(folder_name, class_mapping, drop_unknown_class=False): 33 | dirname, folder_name = os.path.split(folder_name) 34 | old_classes = folder_name.split('+') 35 | new_classes = [] 36 | unknown_class = False 37 | for old_class in old_classes: 38 | if old_class in class_mapping: 39 | new_class = class_mapping[old_class] 40 | else: 41 | new_class = old_class 42 | if new_class not in class_mapping.values(): 43 | unknown_class = True 44 | new_classes.append(new_class) 45 | if unknown_class and drop_unknown_class: 46 | return None 47 | return os.path.join(dirname, '+'.join(new_classes)) 48 | 49 | 50 | def modify_tags_file(tags_file, class_mapping): 51 | with open(tags_file, 'r') as f: 52 | lines = f.readlines() 53 | new_lines = [] 54 | for line in lines: 55 | if line.startswith('character:'): 56 | old_classes = line.lstrip('character:').split(',') 57 | new_classes = [] 58 | for old_class in old_classes: 59 | old_class = old_class.strip() 60 | if old_class in class_mapping: 61 | new_class = class_mapping[old_class] 62 | else: 63 | new_class = old_class 64 | new_classes.append(new_class) 65 | line = 'character: ' + ', '.join(new_classes) + '\n' 66 | new_lines.append(line) 67 | with open(tags_file, 'w') as f: 68 | f.writelines(new_lines) 69 | 70 | 71 | def modify_caption_file(caption_file, class_mapping): 72 | with open(caption_file, 'r') as f: 73 | lines = f.readlines() 74 | new_lines = [] 75 | for line in lines: 76 | for old_class in class_mapping: 77 | new_class = class_mapping[old_class] 78 | line = line.replace(old_class, new_class) 79 | new_lines.append(line) 80 | with open(caption_file, 'w') as f: 81 | f.writelines(new_lines) 82 | 83 | 84 | def rename_folder_and_tags(folder, class_mapping, drop_unknown_class=False): 85 | new_folder_name = rename_folder(folder, class_mapping, drop_unknown_class) 86 | if new_folder_name is None: 87 | shutil.rmtree(folder) 88 | return 89 | if os.path.exists(new_folder_name): 90 | for file in os.listdir(folder): 91 | new_file_path = os.path.join(new_folder_name, file) 92 | os.rename(os.path.join(folder, file), new_file_path) 93 | else: 94 | os.rename(folder, new_folder_name) 95 | for file in get_files_recursively(new_folder_name): 96 | file_noext = os.path.splitext(file)[0] 97 | tags_file = file + '.tags' 98 | if os.path.exists(tags_file): 99 | modify_tags_file(tags_file, class_mapping) 100 | caption_file = file_noext + '.txt' 101 | if os.path.exists(caption_file): 102 | modify_caption_file(caption_file, class_mapping) 103 | 104 | 105 | def get_all_subdirectories(root_dir): 106 | subfolders = 
[] 107 | for root, dirs, files in os.walk(root_dir): 108 | subfolders.append(root) 109 | return subfolders 110 | 111 | 112 | def main(src_dir, class_mapping_csv, drop_unknown_class): 113 | class_mapping = read_class_mapping(class_mapping_csv) 114 | for folder in tqdm(get_all_subdirectories(src_dir)): 115 | rename_folder_and_tags(os.path.join( 116 | src_dir, folder), class_mapping, drop_unknown_class) 117 | 118 | 119 | if __name__ == '__main__': 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--src_dir', required=True, 122 | help='Path to the source directory') 123 | parser.add_argument('--class_mapping_csv', required=True, 124 | help='Path to the class mapping CSV file') 125 | parser.add_argument('--drop_unknown_class', action='store_true', 126 | help='Drop folders with unknown class names') 127 | args = parser.parse_args() 128 | main(args.src_dir, args.class_mapping_csv, args.drop_unknown_class) 129 | -------------------------------------------------------------------------------- /scripts_v1/requirements.txt: -------------------------------------------------------------------------------- 1 | anime_face_detector==0.0.9 2 | efficientnet_pytorch==0.7.1 3 | einops==0.6.0 4 | fairscale==0.4.13 5 | huggingface_hub==0.0.12 6 | matplotlib==3.6.2 7 | numpy==1.24.2 8 | opencv_python==4.4.0.46 9 | pandas==1.3.5 10 | Pillow==9.4.0 11 | scipy 12 | setuptools==65.5.0 13 | tensorflow==2.10.1 14 | timm==0.6.12 15 | toml==0.10.2 16 | torch==1.12.1 17 | torchsummary==1.5.1 18 | torchvision==0.13.1 19 | tqdm==4.64.0 20 | transformers==4.9.1 21 | -------------------------------------------------------------------------------- /scripts_v1/subsidiary/batch_resize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | from PIL import Image 5 | 6 | 7 | def get_files_recursively(folder_path): 8 | allowed_patterns = [ 9 | '*.[Pp][Nn][Gg]', '*.[Jj][Pp][Gg]', '*.[Jj][Pp][Ee][Gg]', 10 | ] 11 | 12 | image_path_list = [ 13 | str(path) for pattern in allowed_patterns 14 | for path in Path(folder_path).rglob(pattern) 15 | ] 16 | 17 | return image_path_list 18 | 19 | 20 | def resize_image(image_path, max_size): 21 | # Open the image 22 | image = Image.open(image_path) 23 | 24 | # Get the current width and height of the image 25 | width, height = image.size 26 | if max(width, height) <= max_size: 27 | return 28 | 29 | # Calculate the new size of the image based on the maximum size 30 | if width > height: 31 | new_width = max_size 32 | new_height = int((max_size / width) * height) 33 | else: 34 | new_width = int((max_size / height) * width) 35 | new_height = max_size 36 | 37 | # Resize the image 38 | image = image.resize((new_width, new_height), Image.ANTIALIAS) 39 | 40 | # Save the resized image 41 | image.save(image_path, "PNG", quality=100) 42 | 43 | 44 | if __name__ == '__main__': 45 | 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--src_dir", type=str, 48 | help="Directory to load images") 49 | parser.add_argument("--max_image_size", type=int, default=1024) 50 | args = parser.parse_args() 51 | 52 | paths = get_files_recursively(args.src_dir) 53 | # Find all .png files in the current directory and its subdirectories 54 | for image_path in tqdm(paths): 55 | # Resize the image 56 | resize_image(image_path, args.max_image_size) 57 | -------------------------------------------------------------------------------- /scripts_v1/subsidiary/find_duplicate.py: 
--------------------------------------------------------------------------------
import os
import sys
from pathlib import Path
from tqdm import tqdm


def get_files_recursively(folder_path):
    # Each extension must be its own list element; without the separating
    # commas, adjacent string literals are concatenated into one pattern
    # that matches nothing.
    allowed_patterns = [
        '*.[Pp][Nn][Gg]',
        '*.[Jj][Pp][Gg]',
        '*.[Jj][Pp][Ee][Gg]',
        '*.[Gg][Ii][Ff]',
        '*.[Ww][Ee][Bb][Pp]',
    ]

    image_path_list = [
        str(path)
        for pattern in allowed_patterns
        for path in Path(folder_path).rglob(pattern)]

    return image_path_list


def find_duplicates(dir1, dir2):

    duplicates1 = []
    duplicates2 = []
    detected_files = []

    # Get the list of files in each directory
    paths1 = get_files_recursively(dir1)
    files1 = [os.path.basename(path) for path in paths1]
    paths2 = get_files_recursively(dir2)
    files2 = [os.path.basename(path) for path in paths2]

    # Iterate over the files in the first directory
    for path1, file1 in zip(paths1, files1):
        # Check if the file also exists in the second directory
        for path2, file2 in zip(paths2, files2):
            if file1 == file2:
                duplicates1.append(path1)
                # Some duplicates in dir2 may remain
                if file1 not in detected_files:
                    duplicates2.append(path2)
                break
    return duplicates1, duplicates2, paths1, paths2


def move_to_subfolder(path1, path2, dir1, dir2,
                      files1=None, files2=None,
                      subfolder='duplicate'):

    file1_path = path1
    file2_path = path2

    duplicate_path1 = os.path.join(dir1, subfolder)
    duplicate_path2 = os.path.join(dir2, subfolder)
    if not os.path.exists(duplicate_path1):
        os.makedirs(duplicate_path1)
    if not os.path.exists(duplicate_path2):
        os.makedirs(duplicate_path2)
    file = os.path.basename(file1_path)
    os.rename(file1_path, os.path.join(duplicate_path1, file))
    # os.rename(file2_path, os.path.join(duplicate_path2, file))

    if files1 is None:
        files1 = os.listdir(dir1)
    if files2 is None:
        files2 = os.listdir(dir2)

    txt_file = f"{file1_path}.txt"
    if os.path.exists(txt_file):
        os.rename(txt_file,
                  os.path.join(duplicate_path1, os.path.basename(txt_file)))
    # txt_file = f"{file2_path}.txt"
    # if txt_file in files2:
    #     os.rename(os.path.join(dir2, txt_file),
    #               os.path.join(duplicate_path2, os.path.basename(txt_file)))

    txt_file = f"{file1_path}.tags"
    if os.path.exists(txt_file):
        os.rename(txt_file,
                  os.path.join(duplicate_path1, os.path.basename(txt_file)))
    # txt_file = f"{file2_path}.tags"
    # if txt_file in files2:
    #     os.rename(os.path.join(dir2, txt_file),
    #               os.path.join(duplicate_path2, os.path.basename(txt_file)))


def remove_empty_folders(path_abs):
    walk = list(os.walk(path_abs))
    for path, _, _ in walk[::-1]:
        if len(os.listdir(path)) == 0:
            os.rmdir(path)


if __name__ == '__main__':

    dir1 = sys.argv[1]
    dir2 = sys.argv[2]
    duplicates1, duplicates2, paths1, paths2 = find_duplicates(dir1, dir2)
    print(len(duplicates1))
    print(len(duplicates2))
    image_extensions = ['.png', '.jpg', '.gif', '.jpeg', '.webp']
    for file_path1, file_path2 in tqdm(zip(duplicates1, duplicates2)):
        if os.path.splitext(file_path1)[1].lower() in image_extensions:
            move_to_subfolder(
                file_path1, file_path2, dir1, dir2, paths1, paths2)
    remove_empty_folders(dir1)
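A minimal usage sketch for find_duplicate.py above. The script is normally run as "python find_duplicate.py dir1 dir2", which also moves the detected files (and their .txt/.tags sidecars) into a "duplicate" subfolder; the snippet below only reports duplicates without moving anything. The directory paths are illustrative, and it assumes scripts_v1/subsidiary is on sys.path:

    from find_duplicate import find_duplicates

    # Compare basenames between the two directories; nothing is moved here.
    dups1, dups2, _, _ = find_duplicates('data/screenshots', 'data/booru')
    print(len(dups1), 'files in the first directory also appear in the second')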
--------------------------------------------------------------------------------
/scripts_v1/subsidiary/rename_md5.py:
--------------------------------------------------------------------------------
import os
import shutil
import argparse
import hashlib
from pathlib import Path
from tqdm.auto import tqdm


def get_files_recursively(folder_path):
    allowed_patterns = [
        '*.[Pp][Nn][Gg]',
        '*.[Jj][Pp][Gg]',
        '*.[Jj][Pp][Ee][Gg]',
        '*.[Gg][Ii][Ff]',
        '*.[Ww][Ee][Bb][Pp]',
    ]

    image_path_list = [
        str(path)
        for pattern in allowed_patterns
        for path in Path(folder_path).rglob(pattern)]

    return image_path_list


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--source', type=str, default='.')
    parser.add_argument('--dest', type=str, default='files_md5')
    args = parser.parse_args()

    if not os.path.exists(args.dest):
        os.makedirs(args.dest)

    for path in tqdm(get_files_recursively(args.source)):
        # Open and read the file, then calculate the MD5 of its contents
        with open(path, 'rb') as file_to_check:
            data = file_to_check.read()
            md5_returned = hashlib.md5(data).hexdigest()
        _, ext = os.path.splitext(path)
        new_path = os.path.join(args.dest, md5_returned + ext)
        shutil.copy(path, new_path)

--------------------------------------------------------------------------------
/scripts_v1/subsidiary/retrieve_high_score.py:
--------------------------------------------------------------------------------
import os
import sys


if __name__ == '__main__':

    dir1 = sys.argv[1]
    image_extensions = ['.png', '.jpg', '.jpeg']
    file_scores = []
    for filename in os.listdir(dir1):
        filename_noext, ext = os.path.splitext(filename)
        if ext.lower() in image_extensions:
            tag_file = filename + '.tags'
            with open(os.path.join(dir1, tag_file), 'r') as f:
                lines = f.readlines()
            for line in lines:
                if line.startswith('score:'):
                    # Slice off the literal prefix; str.lstrip would strip
                    # a character set rather than the prefix 'score:'
                    score = int(line[len('score:'):].strip())
                    file_scores.append((score, filename))
    file_scores = list(reversed(sorted(file_scores)))
    retain = int(sys.argv[2])
    dst_folder = os.path.join(dir1, 'high_score')
    os.makedirs(dst_folder, exist_ok=True)
    for (_, file) in file_scores[:retain]:
        os.rename(os.path.join(dir1, file),
                  os.path.join(dst_folder, file))
        os.rename(os.path.join(dir1, file + '.tags'),
                  os.path.join(dst_folder, file + '.tags'))

--------------------------------------------------------------------------------
/scripts_v1/tagger/blip/med_config.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}

--------------------------------------------------------------------------------
/scripts_v1/tagger/make_caption.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from torchvision import transforms 10 | from torchvision.transforms.functional import InterpolationMode 11 | from blip.blip import blip_decoder 12 | 13 | from pathlib import Path 14 | # from Salesforce_BLIP.models.blip import blip_decoder 15 | 16 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def get_files_recursively(folder_path): 20 | allowed_patterns = [ 21 | '*.[Pp][Nn][Gg]', '*.[Jj][Pp][Gg]', '*.[Jj][Pp][Ee][Gg]', 22 | '*.[Gg][Ii][Ff]', '*.[Ww][Ee][Bb][Pp]' 23 | ] 24 | 25 | image_path_list = [ 26 | str(path) for pattern in allowed_patterns 27 | for path in Path(folder_path).rglob(pattern) 28 | ] 29 | 30 | return image_path_list 31 | 32 | 33 | def main(args): 34 | # fix the seed for reproducibility 35 | seed = args.seed # + utils.get_rank() 36 | torch.manual_seed(seed) 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | 40 | if not os.path.exists("blip"): 41 | args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path 42 | 43 | cwd = os.getcwd() 44 | print('Current Working Directory is: ', cwd) 45 | os.chdir('finetune') 46 | 47 | print(f"load images from {args.train_data_dir}") 48 | image_paths = get_files_recursively(args.train_data_dir) 49 | print(f"found {len(image_paths)} images.") 50 | 51 | print(f"loading BLIP caption: {args.caption_weights}") 52 | image_size = 384 53 | model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json") 54 | model.eval() 55 | model = model.to(DEVICE) 56 | print("BLIP loaded") 57 | 58 | # 正方形でいいのか? 
という気がするがソースがそうなので 59 | transform = transforms.Compose([ 60 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 63 | ]) 64 | 65 | # captioningする 66 | def run_batch(path_imgs): 67 | imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) 68 | 69 | with torch.no_grad(): 70 | if args.beam_search: 71 | captions = model.generate(imgs, sample=False, num_beams=args.num_beams, 72 | max_length=args.max_length, min_length=args.min_length) 73 | else: 74 | captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) 75 | 76 | for (image_path, _), caption in zip(path_imgs, captions): 77 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: 78 | f.write(caption + "\n") 79 | if args.debug: 80 | print(image_path, caption) 81 | 82 | b_imgs = [] 83 | for image_path in tqdm(image_paths, smoothing=0.0): 84 | raw_image = Image.open(image_path) 85 | if raw_image.mode != "RGB": 86 | print(f"convert image mode {raw_image.mode} to RGB: {image_path}") 87 | raw_image = raw_image.convert("RGB") 88 | 89 | image = transform(raw_image) 90 | b_imgs.append((image_path, image)) 91 | if len(b_imgs) >= args.batch_size: 92 | run_batch(b_imgs) 93 | b_imgs.clear() 94 | if len(b_imgs) > 0: 95 | run_batch(b_imgs) 96 | 97 | print("done!") 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 103 | parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", 104 | help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") 105 | parser.add_argument("--caption_extention", type=str, default=None, 106 | help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 107 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 108 | parser.add_argument("--beam_search", action="store_true", 109 | help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") 110 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 111 | parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") 112 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") 113 | parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") 114 | parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") 115 | parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') 116 | parser.add_argument("--debug", action="store_true", help="debug mode") 117 | 118 | args = parser.parse_args() 119 | 120 | # スペルミスしていたオプションを復元する 121 | if args.caption_extention is not None: 122 | args.caption_extension = args.caption_extention 123 | 124 | main(args) 125 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | # List of requirements 5 | # This could be retrieved from requirements.txt 6 | requirements = [] 7 | 8 | 9 | # Package (minimal) configuration 10 | setup( 11 | name="anime2sd", 12 | version="0.0.1", 13 | description="pipeline for anime all in one sd model", 14 | package_dir={"": "."}, 15 | packages=find_packages(), # __init__.py folders search 16 | install_requires=requirements 17 | ) 18 | -------------------------------------------------------------------------------- /tests/integration_tests/test_classify_characters.py: -------------------------------------------------------------------------------- 1 | # tests/classif/test_classify_characters.py 2 | import pytest 3 | import os 4 | from anime2sd.classif.classify_characters import classify_from_directory 5 | 6 | 7 | @pytest.fixture(scope="module") 8 | def test_data(): 9 | # Define your test directories 10 | root_dir = "data" 11 | src_dir = os.path.join(root_dir, "intermediate", "screenshots", "cropped_mini") 12 | dst_dir = os.path.join(root_dir, "intermediate", "screenshots", "classified_mini") 13 | os.makedirs(dst_dir, exist_ok=True) 14 | character_ref_dir = os.path.join(root_dir, "ref_images", "tearmoon") 15 | 16 | # (Optional) Setup the test data before running the tests 17 | # shutil.copytree('path_to_sample_data', src_dir) 18 | # shutil.copytree('path_to_sample_ref_images', character_ref_dir) 19 | 20 | return src_dir, dst_dir, character_ref_dir 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def test_data_inplace(): 25 | # Define your test directories 26 | root_dir = "data" 27 | src_dir = os.path.join(root_dir, "intermediate", "screenshots", "classified_mini") 28 | dst_dir = os.path.join(root_dir, "intermediate", "screenshots", "classified_mini") 29 | os.makedirs(dst_dir, exist_ok=True) 30 | character_ref_dir = os.path.join(root_dir, "ref_images", "tearmoon") 31 | 32 | return src_dir, dst_dir, character_ref_dir 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def test_data_booru(): 37 | # Define your test directories 38 | root_dir = "data" 39 | src_dir = os.path.join(root_dir, "intermediate", "booru", "cropped") 40 | dst_dir = os.path.join(root_dir, "intermediate", "booru", "classified") 41 | os.makedirs(dst_dir, exist_ok=True) 42 | character_ref_dir = os.path.join(root_dir, "ref_images", "hikikomari") 43 | 44 | return src_dir, dst_dir, character_ref_dir 45 | 46 | 47 | @pytest.fixture(scope="module") 48 | def test_data_booru_inplace(): 49 | # Define your test directories 50 | root_dir = "data" 51 | src_dir = os.path.join(root_dir, "intermediate", "booru", "classified") 52 | dst_dir = os.path.join(root_dir, "intermediate", "booru", "classified") 53 | os.makedirs(dst_dir, exist_ok=True) 54 | character_ref_dir = os.path.join(root_dir, "ref_images", "hikikomari") 55 | 56 | return src_dir, dst_dir, character_ref_dir 57 | 58 | 59 | def test_clustering(test_data): 60 | src_dir, dst_dir, _ = test_data 61 | # Call the function with the test arguments 62 | classify_from_directory( 63 | src_dir, 64 | dst_dir, 65 | None, 66 | to_extract_from_noise=True, 67 | keep_unnamed=True, 68 | clu_min_samples=5, 69 | merge_threshold=0.85, 70 | move=False, 71 | ) 72 | 73 | 74 | def test_classify_ref(test_data): 75 | src_dir, dst_dir, character_ref_dir = test_data 76 | # Call the function with the test arguments 77 | classify_from_directory( 78 | src_dir, 79 | dst_dir, 80 | character_ref_dir, 81 | 
to_extract_from_noise=True, 82 | keep_unnamed=True, 83 | # keep_unnamed=False, 84 | clu_min_samples=5, 85 | merge_threshold=0.85, 86 | move=False, 87 | ) 88 | 89 | 90 | def test_classify_ref_inplace(test_data_inplace): 91 | src_dir, dst_dir, character_ref_dir = test_data_inplace 92 | # Call the function with the test arguments 93 | classify_from_directory( 94 | src_dir, 95 | dst_dir, 96 | character_ref_dir, 97 | to_extract_from_noise=True, 98 | to_filter=True, 99 | # keep_unnamed=True, 100 | keep_unnamed=False, 101 | clu_min_samples=5, 102 | merge_threshold=0.85, 103 | move=True, 104 | ) 105 | 106 | 107 | def test_classify_existing(test_data_booru): 108 | src_dir, dst_dir, _ = test_data_booru 109 | # Call the function with the test arguments 110 | classify_from_directory( 111 | src_dir, 112 | dst_dir, 113 | None, 114 | to_extract_from_noise=True, 115 | keep_unnamed=True, 116 | clu_min_samples=5, 117 | merge_threshold=0.85, 118 | move=False, 119 | ) 120 | 121 | 122 | def test_classify_existing_ref_inplace(test_data_booru_inplace): 123 | src_dir, dst_dir, character_ref_dir = test_data_booru_inplace 124 | # Call the function with the test arguments 125 | classify_from_directory( 126 | src_dir, 127 | dst_dir, 128 | character_ref_dir, 129 | to_extract_from_noise=True, 130 | to_filter=True, 131 | keep_unnamed=True, 132 | clu_min_samples=5, 133 | merge_threshold=0.85, 134 | n_add_images_to_ref=5, 135 | move=True, 136 | ) 137 | -------------------------------------------------------------------------------- /tests/integration_tests/test_download_images.py: -------------------------------------------------------------------------------- 1 | from anime2sd import download_images 2 | 3 | 4 | if __name__ == "__main__": 5 | output_dir = "data/intermediate/16bit/booru/raw" 6 | tags = ["16bit_sensation"] 7 | limit_per_character = 6 8 | max_image_size = 640 9 | character_mapping = { 10 | "akisato_konoha": "Konoha", 11 | "riko_(machikado_mazoku)": "Riko", 12 | "riko_(made_in_abyss)": "", 13 | } 14 | save_aux = ["tags", "characters"] 15 | 16 | # Call the function 17 | download_images( 18 | output_dir=output_dir, 19 | tags=tags, 20 | limit_per_character=limit_per_character, 21 | max_image_size=max_image_size, 22 | character_mapping=character_mapping, 23 | save_aux=save_aux, 24 | ) 25 | -------------------------------------------------------------------------------- /tests/unit_tests/test_basics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from anime2sd.basics import parse_anime_info 3 | 4 | 5 | @pytest.mark.parametrize( 6 | "filename, expected", 7 | [ 8 | ( 9 | "[SubsPlease] 16bit Sensation - Another Layer - 07 (1080p) [771BDD0C].mkv", 10 | ("16bit Sensation - Another Layer", 7), 11 | ), 12 | ( 13 | "[HorribleSubs] Toaru Kagaku no Railgun T - 25 [1080p].mkv", 14 | ("Toaru Kagaku no Railgun T", 25), 15 | ), 16 | ("[Hayaisubs] Yama no Susume 2 - 18 [720p].mkv", ("Yama no Susume 2", 18)), 17 | # Add more test cases as needed 18 | ( 19 | "[RandomGroup] Anime Title - Extra Info - 10 [720p].mkv", 20 | ("Anime Title - Extra Info", 10), 21 | ), 22 | ( 23 | "Yama no Susume (Saison 2) 16 vostfr [720p]", 24 | ("Yama no Susume (Saison 2) 16 vostfr", None), 25 | ), 26 | ( 27 | "[Ohys-Raws] Toaru Kagaku no Railgun T - SP2 (BD 1280x720 x264 AAC).mp4", 28 | ("Toaru Kagaku no Railgun T", None), 29 | ), 30 | ( 31 | "[EA]Toaru_Kagaku_no_Railgun_T_24_[1920x1080][Hi10p][373BAEBF].mkv", 32 | ("Toaru_Kagaku_no_Railgun_T_24_", None), 33 | ), 34 | ("Only Title.mkv", ("Only 
Title", None)), 35 | ("[Group] Only Title - No Episode.mkv", ("Only Title", None)), 36 | ], 37 | ) 38 | def test_parse_anime_info(filename, expected): 39 | assert parse_anime_info(filename) == expected 40 | -------------------------------------------------------------------------------- /tests/unit_tests/test_select_to_add.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from anime2sd.classif import select_indices_recursively 4 | 5 | 6 | def test_even_distribution(): 7 | list_of_indices = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])] 8 | n_to_select = 6 9 | expected_result = [1, 2, 4, 5, 7, 8] 10 | assert select_indices_recursively(list_of_indices, n_to_select) == expected_result 11 | 12 | 13 | def test_uneven_distribution(): 14 | list_of_indices = [np.array([1, 2]), np.array([3, 4, 5, 6]), np.array([7])] 15 | n_to_select = 5 16 | expected_result = [1, 2, 3, 4, 7] 17 | assert ( 18 | sorted(select_indices_recursively(list_of_indices, n_to_select)) 19 | == expected_result 20 | ) 21 | 22 | 23 | def test_more_to_select_than_available(): 24 | list_of_indices = [np.array([1]), np.array([2, 3]), np.array([4, 5])] 25 | n_to_select = 10 26 | expected_result = [1, 2, 3, 4, 5] 27 | assert ( 28 | sorted(select_indices_recursively(list_of_indices, n_to_select)) 29 | == expected_result 30 | ) 31 | 32 | 33 | def test_empty_arrays(): 34 | list_of_indices = [np.array([]), np.array([])] 35 | n_to_select = 5 36 | expected_result = [] 37 | assert select_indices_recursively(list_of_indices, n_to_select) == expected_result 38 | 39 | 40 | def test_no_arrays(): 41 | list_of_indices = [] 42 | n_to_select = 5 43 | expected_result = [] 44 | assert select_indices_recursively(list_of_indices, n_to_select) == expected_result 45 | 46 | 47 | def test_zero_to_select(): 48 | list_of_indices = [np.array([1, 2, 3]), np.array([4, 5, 6])] 49 | n_to_select = 0 50 | expected_result = [] 51 | assert select_indices_recursively(list_of_indices, n_to_select) == expected_result 52 | 53 | 54 | # Add more tests if necessary to cover additional scenarios or edge cases. 
55 | 56 | # Run the tests 57 | if __name__ == "__main__": 58 | pytest.main() 59 | -------------------------------------------------------------------------------- /utilities/convert_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | 7 | 8 | def get_images_recursively(folder_path): 9 | allowed_patterns = [ 10 | '*.[Pp][Nn][Gg]', 11 | '*.[Jj][Pp][Gg]', 12 | '*.[Jj][Pp][Ee][Gg]', 13 | '*.[Ww][Ee][Bb][Pp]', 14 | '*.[Gg][Ii][Ff]', 15 | ] 16 | 17 | image_path_list = [ 18 | str(path) for pattern in allowed_patterns 19 | for path in Path(folder_path).rglob(pattern) 20 | ] 21 | 22 | return image_path_list 23 | 24 | 25 | def convert_metadata(src_dir): 26 | img_paths = get_images_recursively(src_dir) 27 | for img_path in tqdm(img_paths): 28 | meta_file_path = os.path.splitext(img_path)[0] + '.json' 29 | if os.path.exists(meta_file_path): 30 | # Rename the metadata file 31 | new_meta_file_path = os.path.join( 32 | os.path.dirname(img_path), 33 | f".{os.path.splitext(os.path.basename(img_path))[0]}_meta.json" 34 | ) 35 | os.rename(meta_file_path, new_meta_file_path) 36 | 37 | # Modify its content 38 | with open(new_meta_file_path, 'r') as meta_file: 39 | meta_data = json.load(meta_file) 40 | 41 | # Rename fields 42 | if "character" in meta_data: 43 | meta_data["characters"] = meta_data.pop("character") 44 | if "general" in meta_data: 45 | meta_data["type"] = meta_data.pop("general") 46 | 47 | # Add new fields 48 | meta_data["path"] = os.path.abspath(img_path) 49 | meta_data["current_path"] = os.path.abspath(img_path) 50 | meta_data["filename"] = os.path.basename(img_path) 51 | 52 | # Save the modified metadata 53 | with open(new_meta_file_path, 'w') as meta_file: 54 | json.dump(meta_data, meta_file, indent=4) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser(description="Convert metadata files.") 59 | parser.add_argument("--src_dir", required=True, 60 | help="Directory containing the metadata to modify.") 61 | args = parser.parse_args() 62 | convert_metadata(args.src_dir) 63 | -------------------------------------------------------------------------------- /utilities/correct_path_field.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | 6 | 7 | from anime2sd.basics import get_corr_meta_names, get_images_recursively 8 | 9 | 10 | def find_corresponding_image(file_path, ref_images): 11 | file_path, file_ext = os.path.splitext(file_path) 12 | new_file_path = "_".join(file_path.split("_")[:-1]) + file_ext 13 | for img_path in ref_images: 14 | if os.path.basename(img_path) == os.path.basename(new_file_path): 15 | return img_path 16 | return None 17 | 18 | 19 | def update_metadata(src_dir, ref_dir): 20 | src_images = get_images_recursively(src_dir) 21 | ref_images = get_images_recursively(ref_dir) 22 | 23 | for img_path in tqdm(src_images): 24 | meta_path, _ = get_corr_meta_names(img_path) 25 | if os.path.exists(meta_path): 26 | with open(meta_path, "r") as file: 27 | metadata = json.load(file) 28 | 29 | corresponding_img_path = find_corresponding_image(img_path, ref_images) 30 | if corresponding_img_path: 31 | metadata["path"] = corresponding_img_path 32 | with open(meta_path, "w") as file: 33 | json.dump(metadata, file, indent=4) 34 | else: 35 | print(f"Warning: No corresponding image found for {img_path.name}") 36 | 
37 | 38 | def main(): 39 | parser = argparse.ArgumentParser( 40 | description="Update image metadata with reference directory paths." 41 | ) 42 | parser.add_argument( 43 | "--src_dir", required=True, help="Source directory containing image files." 44 | ) 45 | parser.add_argument( 46 | "--ref_dir", required=True, help="Reference directory to match image files." 47 | ) 48 | 49 | args = parser.parse_args() 50 | 51 | update_metadata(args.src_dir, args.ref_dir) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /utilities/count_tag_appearance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from collections import Counter 5 | 6 | 7 | def count_tags_in_directory(directory): 8 | # Counter object to hold the tags and their counts 9 | tags_counter = Counter() 10 | 11 | # Walk through the directory 12 | for root, dirs, files in os.walk(directory): 13 | for file in files: 14 | # Check if the file is a .txt file 15 | if file.endswith('.txt'): 16 | file_path = os.path.join(root, file) 17 | with open(file_path, 'r') as f: 18 | # Read the content of the file 19 | content = f.read() 20 | # Split the content by commas to get the tags 21 | tags = content.split(',') 22 | # Remove leading and trailing whitespaces from each tag 23 | tags = [tag.strip() for tag in tags] 24 | # Update the counter with the tags 25 | tags_counter.update(tags) 26 | 27 | return tags_counter 28 | 29 | 30 | def save_tags_count_to_file(tags_counter, output_file): 31 | # Sort the tags by frequency in descending order 32 | sorted_tags = sorted( 33 | tags_counter.items(), key=lambda x: x[1], reverse=True) 34 | 35 | # Write the sorted tags and their counts to the output file 36 | with open(output_file, 'w') as f: 37 | for tag, count in sorted_tags: 38 | f.write(f"{tag}: {count}\n") 39 | 40 | 41 | if __name__ == '__main__': 42 | # Create an ArgumentParser object 43 | parser = argparse.ArgumentParser( 44 | description='Count and sort tags from text files in a directory.') 45 | 46 | # Add arguments 47 | parser.add_argument('--directory', type=str, 48 | help='Path to the directory to search for .txt files.') 49 | parser.add_argument('--output_file', type=str, 50 | default='output.txt', help='Path to the output file.') 51 | 52 | # Parse the arguments 53 | args = parser.parse_args() 54 | 55 | # Count the tags in the specified directory 56 | tags_counter = count_tags_in_directory(args.directory) 57 | 58 | # Save the tags count to the specified output file 59 | save_tags_count_to_file(tags_counter, args.output_file) 60 | -------------------------------------------------------------------------------- /utilities/get_core_tags.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from anime2sd import get_character_core_tags_and_save 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser( 7 | description="Extract frequent tags for characters.") 8 | parser.add_argument( 9 | "--src_dir", type=str, 10 | help="Path to the folder containing images and metadata.") 11 | parser.add_argument( 12 | "--frequency_threshold", type=float, default=0.5, 13 | help="Minimum frequency for a tag to be considered core tag.") 14 | parser.add_argument( 15 | "--core_tag_output", type=str, default="core_tags.json", 16 | help="Output JSON file to save the frequent tags.") 17 | parser.add_argument( 18 | "--wildcard_output", type=str, 
default="wildcard.txt", 19 | help="Output TXT file to save the character names and their tags.") 20 | 21 | args = parser.parse_args() 22 | get_character_core_tags_and_save( 23 | args.src_dir, args.core_tag_output, args.wildcard_output, 24 | frequency_threshold=args.frequency_threshold) 25 | -------------------------------------------------------------------------------- /utilities/rename_characters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import json 5 | from tqdm import tqdm 6 | 7 | from anime2sd.basics import get_images_recursively 8 | from anime2sd.basics import get_corr_meta_names 9 | 10 | 11 | def replace_tag_in_file(file_path, old_name, new_name, separator): 12 | with open(file_path, "r") as f: 13 | content = f.read() 14 | # Split the content by commas to get the tags 15 | tags = content.split(separator) 16 | # Replace the old_name tag with new_name if it exists 17 | tags = [new_name if tag.strip() == old_name else tag for tag in tags] 18 | # Join the tags back with commas 19 | content = separator.join(tags) 20 | with open(file_path, "w") as f: 21 | f.write(content) 22 | 23 | 24 | def replace_character_in_json(json_path, old_name, new_name): 25 | with open(json_path, "r") as f: 26 | content = json.load(f) 27 | if "characters" in content: 28 | content["characters"] = [ 29 | new_name if c == old_name else c for c in content["characters"] 30 | ] 31 | with open(json_path, "w") as f: 32 | json.dump(content, f, indent=4) 33 | 34 | 35 | def rename_folders(src_dir, old_name, new_name): 36 | for root, dirs, _ in os.walk(src_dir): 37 | for dir_name in dirs: 38 | name_parts = dir_name.split("+") 39 | if old_name in name_parts: 40 | # Replace only the exact old_name 41 | name_parts = [ 42 | new_name if part == old_name else part for part in name_parts 43 | ] 44 | new_dir_name = "+".join(name_parts) 45 | shutil.move( 46 | os.path.join(root, dir_name), os.path.join(root, new_dir_name) 47 | ) 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser(description="Replace tags in txt files.") 52 | parser.add_argument( 53 | "--src_dir", required=True, help="Source directory containing txt files." 54 | ) 55 | parser.add_argument( 56 | "--old_name", required=True, help="Old character name to be replaced." 57 | ) 58 | parser.add_argument( 59 | "--new_name", required=True, help="New character name to replace the old one." 
60 | ) 61 | parser.add_argument( 62 | "--caption_separation", 63 | type=str, 64 | default=",", 65 | help="Symbol used to separate character names in caption", 66 | ) 67 | 68 | args = parser.parse_args() 69 | 70 | # Walk through the src_dir and find all txt files 71 | for img_path in tqdm(get_images_recursively(args.src_dir)): 72 | img_noext, _ = os.path.splitext(img_path) 73 | for potential_exts in [".txt", ".characters"]: 74 | potential_file = img_noext + potential_exts 75 | if os.path.exists(potential_file): 76 | replace_tag_in_file( 77 | potential_file, 78 | args.old_name, 79 | args.new_name, 80 | args.caption_separation, 81 | ) 82 | meta_path, _ = get_corr_meta_names(img_path) 83 | if os.path.exists(meta_path): 84 | replace_character_in_json(meta_path, args.old_name, args.new_name) 85 | rename_folders(args.src_dir, args.old_name, args.new_name) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /utilities/replace_tags.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def replace_tag_in_file(file_path, old_name, new_name): 6 | with open(file_path, 'r') as f: 7 | content = f.read() 8 | 9 | # Split the content by commas to get the tags 10 | tags = content.split(',') 11 | 12 | # Replace the old_name tag with new_name if it exists 13 | tags = [new_name if tag.strip() == old_name else tag for tag in tags] 14 | 15 | # Join the tags back with commas 16 | content = ','.join(tags) 17 | 18 | with open(file_path, 'w') as f: 19 | f.write(content) 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description="Replace tags in txt files.") 24 | parser.add_argument("--src_dir", required=True, 25 | help="Source directory containing txt files.") 26 | parser.add_argument("--old_name", required=True, 27 | help="Old tag name to be replaced.") 28 | parser.add_argument("--new_name", required=True, 29 | help="New tag name to replace the old one.") 30 | parser.add_argument("--ext", default='.txt', 31 | help="Extension of the files to replace tags.") 32 | 33 | args = parser.parse_args() 34 | 35 | # Walk through the src_dir and find all txt files 36 | for dirpath, dirnames, filenames in os.walk(args.src_dir): 37 | for filename in filenames: 38 | if filename.endswith(args.ext): 39 | file_path = os.path.join(dirpath, filename) 40 | replace_tag_in_file(file_path, args.old_name, args.new_name) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /utilities/update_safetensor_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import struct 4 | import os 5 | import shutil 6 | 7 | 8 | def read_metadata(safetensors_file_path): 9 | with open(safetensors_file_path, 'rb') as file: 10 | length_of_header = struct.unpack('