├── README.md
└── lora2wildcard.py


/README.md:
--------------------------------------------------------------------------------
# lora2wildcard

A small script that generates a wildcard file from the LoRA directory of A1111 or reForge.
It extracts the tags used for training from the metadata embedded in each safetensors file and turns them into wildcard entries.
If a .json file written by Civitai Helper exists next to a LoRA, its activation text is used as a fallback (or preferred, with `--act`).

### usage
```sh
python lora2wildcard.py LORA_DIR_PATH

python lora2wildcard.py -h
```

### sample output
```sh
,1girl, 2b, black blindfold, black dress, black hairband, blindfold, clothing cutout, dress, hairband, juliet sleeves, long sleeves, mole, mole under mouth, puffy sleeves, short hair, solo, white hair, yorha no. 2 type b
,asuka soryu langley, solo
```
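
### how it works
The training tags live in the metadata header that kohya-style trainers embed in each safetensors file. Roughly, this is how the script reads them (a minimal sketch; `my_lora.safetensors` is a placeholder path):
```py
import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata()  # dict of str -> str, or None

# Training tags are stored as JSON, typically under "ss_tag_frequency",
# shaped like {dataset_name: {tag: count}}.
if metadata and "ss_tag_frequency" in metadata:
    print(json.loads(metadata["ss_tag_frequency"]))
```
Tags whose frequency falls below the `--th` threshold, relative to the most frequent tag, are dropped.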
--------------------------------------------------------------------------------
/lora2wildcard.py:
--------------------------------------------------------------------------------
import argparse
import json
import random
import sys
import time
from datetime import datetime
from pathlib import Path

from safetensors import safe_open


def extract_metadata_from_safetensors(file_path):
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()

        if metadata is None:
            print(f"{file_path}: metadata not found")
            return None

        return metadata
    except Exception as e:
        print(f"{file_path}: {e}")
        return None


def parse_tags(metadata):
    tags = []

    possible_keys = ["ss_tag_frequency", "ss_tag_frequency_0", "tag_frequency", "tags"]

    for key in possible_keys:
        if key in metadata:
            try:
                tag_data = json.loads(metadata[key])

                if isinstance(tag_data, dict):
                    # Either a flat {tag: freq} dict or the nested
                    # {dataset_name: {tag: freq}} form written by kohya's scripts.
                    for tag, freq in tag_data.items():
                        if isinstance(freq, dict):
                            for _tag, _freq in freq.items():
                                tags.append({"tag": _tag, "frequency": _freq})
                        else:
                            tags.append({"tag": tag, "frequency": freq})
                elif isinstance(tag_data, list):
                    # Plain tag list: strings or {"name": ..., "frequency": ...} dicts.
                    for item in tag_data:
                        if isinstance(item, str):
                            tags.append({"tag": item, "frequency": 1})
                        elif isinstance(item, dict) and "name" in item:
                            tags.append({"tag": item["name"], "frequency": item.get("frequency", 1)})
            except json.JSONDecodeError:
                # Not JSON; treat the raw value as a single tag.
                tags.append({"tag": metadata[key], "frequency": 1})

    if "ss_character_tags" in metadata:
        try:
            char_tags = json.loads(metadata["ss_character_tags"])
            if isinstance(char_tags, list):
                for tag in char_tags:
                    tags.append({"tag": tag, "frequency": 1, "type": "character"})
        except json.JSONDecodeError:
            pass

    if "ss_dataset_name" in metadata:
        tags.append({"tag": f"dataset: {metadata['ss_dataset_name']}", "frequency": 1, "type": "dataset"})

    if "ss_network_args" in metadata:
        try:
            network_args = json.loads(metadata["ss_network_args"])
            if "network_module" in network_args:
                tags.append({"tag": f"network: {network_args['network_module']}", "frequency": 1, "type": "network"})
        except json.JSONDecodeError:
            pass

    return tags

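
# Worked example of the threshold logic below (illustrative values): tags arrive
# sorted by frequency, so the first count seen is the maximum. With th=0.5 and
# tags = [("1girl", 100), ("solo", 80), ("dog", 10)], the cutoff is
# 0.5 * 100 = 50, so "1girl" and "solo" are kept and "dog" is dropped.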
def generate_prompt_from_tags(tags, th=-1, prohibited_tags=None):
    prohibited_tags = prohibited_tags or []
    max_count = None
    res = []
    for tag, count in tags:
        if not max_count:
            max_count = count

        if tag in prohibited_tags:
            print(f"ignore {tag=}")
            continue

        if th < 0:
            # Negative threshold: draw a random cutoff per tag.
            v = random.random() * max_count
        else:
            v = th * max_count

        if count > v:
            res.append(tag)

    # Deduplicate and return a stable, sorted prompt string.
    return ", ".join(sorted(set(res)))


def get_prompt_from_metadata(lora_path, th, prohibited_tags):
    metadata = extract_metadata_from_safetensors(lora_path)
    if metadata:
        tags = parse_tags(metadata)
        tags.sort(key=lambda x: x.get("frequency", 0), reverse=True)
        tags = [(k["tag"], k["frequency"]) for k in tags]
        return generate_prompt_from_tags(tags, th, prohibited_tags)
    else:
        print(f"{lora_path=} metadata not found!")
        return ""


def get_activation_text_from_json(lora_path):
    # Civitai Helper saves a .json file next to each LoRA; use its
    # "activation text" field if present.
    info_path = Path(lora_path).with_suffix(".json")

    if info_path.is_file():
        with open(info_path, "r", encoding="utf-8") as f:
            info = json.load(f)
        return info.get("activation text", "")
    else:
        return ""


def get_safetensors_files(dir_path):
    if not dir_path.is_dir():
        raise ValueError(f"dir not found: {dir_path}")

    return dir_path.glob("**/*.safetensors")


def get_time_str():
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def get_path_token(p):
    # Join the last two path components, e.g. .../models/lora -> "models_lora".
    parts = p.parts
    if len(parts) >= 2:
        return "_".join(parts[-2:])
    elif len(parts) == 1:
        return parts[0]
    else:
        return ""


def main():
    parser = argparse.ArgumentParser(description="Extract the tags used for training from LoRA files and generate wildcards. Infrequently used tags are omitted.")
    parser.add_argument("lora_dir_path", help="lora dir path")
    parser.add_argument("--th", "-t", type=float, default=0.5, help="tag threshold. 0.5 keeps tags whose frequency is above half that of the most frequent tag.")
    parser.add_argument("--weight", "-w", type=float, default=1.0, help="lora weight (currently unused)")
    parser.add_argument("--prohibited_tags", "-pt", default="simple background, white background", help="comma-separated tags to exclude")
    parser.add_argument("--act", "-a", action="store_true", help="Prefer the activation text from .json files over the training tags.")

    start_time = time.time()

    args = parser.parse_args()

    prohibited_tags = [s.strip() for s in args.prohibited_tags.split(",")]

    print(f"{prohibited_tags = }")

    dir_path = Path(args.lora_dir_path)

    safetensors_files = get_safetensors_files(dir_path)

    prompt_list = []

    for sf in safetensors_files:
        # Try the preferred source first, then fall back to the other one.
        if args.act:
            prompt = get_activation_text_from_json(sf)
        else:
            prompt = get_prompt_from_metadata(sf, args.th, prohibited_tags)

        if not prompt:
            if args.act:
                prompt = get_prompt_from_metadata(sf, args.th, prohibited_tags)
            else:
                prompt = get_activation_text_from_json(sf)

        # Each wildcard line starts with a comma (see the sample output in the README).
        prompt_list.append("," + prompt)

    output_path = get_path_token(dir_path) + "_" + get_time_str() + ".txt"

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(prompt_list))

    end_time = time.time()

    print(f"elapsed time : {end_time - start_time}")

    return 0


if __name__ == "__main__":
    sys.exit(main())
--------------------------------------------------------------------------------