├── GPT_prompt_helper.ipynb ├── LICENSE ├── README.md ├── SD1-sliders-inference.ipynb ├── XL-sliders-inference.ipynb ├── __init__.py ├── demo_SDXL_Turbo.ipynb ├── demo_concept_sliders.ipynb ├── demo_image_editing.ipynb ├── eval-scripts ├── .ipynb_checkpoints │ └── generate_images_xl-checkpoint.py ├── clip_score.py ├── generate_images-uce.py ├── generate_images_customdiffusion.py ├── generate_images_sd1.py ├── generate_images_textinversion.py ├── generate_images_textinversion_xl.py ├── generate_images_xl.py └── lpip_score.py ├── flux-sliders ├── flux-requirements.txt ├── train-flux-concept-sliders.ipynb └── utils │ ├── custom_flux_pipeline.py │ ├── lora.py │ ├── model_util.py │ ├── prompt_util.py │ ├── ptp_utils.py │ └── train_util.py ├── images └── main_figure.png ├── prompts ├── prompts-car.csv ├── prompts-food.csv ├── prompts-person.csv ├── prompts-room.csv └── prompts-sky.csv ├── requirements.txt └── trainscripts ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc └── __init__.cpython-39.pyc ├── imagesliders ├── __pycache__ │ ├── config_util.cpython-39.pyc │ ├── debug_util.cpython-39.pyc │ ├── lora.cpython-39.pyc │ ├── model_util.cpython-39.pyc │ ├── prompt_util.cpython-39.pyc │ └── train_util.cpython-39.pyc ├── config_util.py ├── data │ ├── config-xl.yaml │ ├── config.yaml │ ├── prompts-xl.yaml │ └── prompts.yaml ├── debug_util.py ├── lora.py ├── model_util.py ├── prompt_util.py ├── train_lora-scale-xl.py ├── train_lora-scale.py └── train_util.py └── textsliders ├── .ipynb_checkpoints ├── train_lora-checkpoint.py └── train_lora_xl-checkpoint.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── __init__.cpython-39.pyc ├── config_util.cpython-39.pyc ├── debug_util.cpython-39.pyc ├── lora.cpython-310.pyc ├── lora.cpython-39.pyc ├── model_util.cpython-39.pyc ├── prompt_util.cpython-39.pyc ├── ptp_utils.cpython-39.pyc └── train_util.cpython-39.pyc ├── config_util.py ├── data ├── .ipynb_checkpoints │ ├── prompts-person_age_slider_GPT-checkpoint.yaml │ ├── prompts-smile_slider_GPT-checkpoint.yaml │ └── prompts-xl-checkpoint.yaml ├── config-xl.yaml ├── config.yaml ├── prompts-animated_eyes_GPT.yaml ├── prompts-car_alienTechFuturistic_GPT.yaml ├── prompts-jewelry_diamonds_GPT.yaml ├── prompts-person_age_slider_GPT.yaml ├── prompts-person_surprised_GPT.yaml ├── prompts-smile_slider_GPT.yaml ├── prompts-xl.yaml └── prompts.yaml ├── debug_util.py ├── flush.py ├── generate_images_xl.py ├── lora.py ├── model_util.py ├── prompt_util.py ├── ptp_utils.py ├── train_lora.py ├── train_lora_xl.py └── train_util.py /GPT_prompt_helper.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5f21e08c-259d-4099-b1ed-b782bf94be05", 6 | "metadata": {}, 7 | "source": [ 8 | "Replace `key` with your openai key (do not use quotations)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 5, 14 | "id": "7f45bebd-fe6e-41a8-a6e1-bf56d765463e", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "env: OPENAI_API_KEY=key\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%env OPENAI_API_KEY=key" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "id": "7784d6c8-e28a-494b-b97f-374fcc94213f", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import os\n", 37 | "import yaml\n", 38 | "from openai import OpenAI\n", 39 | "client = OpenAI()\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": 
"code", 44 | "execution_count": 3, 45 | "id": "6de55bbf-b978-478c-a8f0-db412bc52e9e", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def generate_prompts_sliders(slider_query, \n", 50 | " file_name_to_save=None, \n", 51 | " temperature=0.2, \n", 52 | " max_tokens=256, \n", 53 | " frequency_penalty=0.0,\n", 54 | " model=\"gpt-4-turbo-preview\",\n", 55 | " verbose=False, \n", 56 | " save=True):\n", 57 | " '''\n", 58 | " A function to automatically build prompts for text sliders using GPT4 (or any other openAI model). \n", 59 | " \n", 60 | " Inputs\n", 61 | " ------\n", 62 | " slider_query (str): A natural language query describing the slider effects the user desired (eg: \"I want to make people older\")\n", 63 | " file_name_to_save (str) (optional): a full name of the yaml file a user desires. If left as None, a name will be chosen by GPT\n", 64 | " temperature (float) (optional): GPT temperature parameter (use smaller values for less randomness)\n", 65 | " max_tokens (int) (optional): GPT output token limit\n", 66 | " frequency_penalty (float) (optional): GPT frequency penalty\n", 67 | " model (str) (optional): The model class from openAI. By default uses GPT-4-Turbo\n", 68 | " verbose (bool) (optional): A flag to print intermediate responses by GPT\n", 69 | " save (bool) (optional): A flag to save the prompts to a destination path\n", 70 | " '''\n", 71 | " gpt_assistant_prompt = '''You are an expert in prompting text-image generation models. Given a concept to edit, your task is to generate 4 detailed prompts.\n", 72 | " 1. Target prompt: a prompt that describes the target class which the concept edit is intended to modify (for example, to edit the concept \"professional\" the target concept is \"person\". Leave it empty if the target concept is too large. For example if user asks for their generations to be more futuristic, since all the images have to be edited, just leave the target \"\"\n", 73 | " 2. Positive prompt: a detailed prompt that describes the extreme positive end of the edit concept with the target concept included (for example, \"person, very professional, blazer, neat, organized)\"\n", 74 | " 3. Negative prompt: a detailed prompt that describes the extreme negative end of the edit concept with the target concept included (for example, \"person, non-professional, ragidy, unkempt\"). This is optional, you can leave it \"\" if there is no obvious negative prompt.\n", 75 | " 4. Preservation prompt: a prompt (must be comma separated) that describes any concepts except the ones to edit that should be preserved when making the edit without the target concept included (for example, \"white, black, indian, asian, hispanic; male, female\" as the race or gender of a person may be changed when we edit the professionalism.). This should not include edit concepts and should not include any of the positive or negative concepts. if there are no obvious entanglement issues with the edit, leave the prompt \"\"\n", 76 | " make preservation prompt comma seperated for each class of perservation. For example if you want to preserve both race and gender, then give something like \"white race, black race, indian race, asian race; male, female\"\n", 77 | "\n", 78 | " All the prompts must be strictly string type. Be specific. 
Do not use any alphanumeric symbols.\n", 79 | " \n", 80 | " This is an example template for your response when asked to generate prompts for making people smile:\n", 81 | " Target: person\n", 82 | " Positive: person, smiling, happy face, big smile\n", 83 | " Negative: person, frowning, grumpy, sad\n", 84 | " Preservation: white, black, indian, asian, hispanic ; male, female\n", 85 | " Name: person_age_GPT\n", 86 | " \n", 87 | " Here is another example template for your response when asked - \"I want to make images more detailed\":\n", 88 | " Target: \n", 89 | " Positive: highly detailed, intricate patterns, fine textures, realistic shading\n", 90 | " Negative: simplistic, minimalistic, abstract, rough outlines\n", 91 | " Preservation: \n", 92 | " Name: detailed_GPT\n", 93 | " '''\n", 94 | " gpt_user_prompt = slider_query\n", 95 | " gpt_prompt = gpt_assistant_prompt, gpt_user_prompt\n", 96 | " message=[{\"role\": \"assistant\", \"content\": gpt_assistant_prompt}, {\"role\": \"user\", \"content\": gpt_user_prompt}]\n", 97 | " \n", 98 | " response = client.chat.completions.create(\n", 99 | " model= model,\n", 100 | " messages = message,\n", 101 | " temperature=temperature,\n", 102 | " max_tokens=max_tokens,\n", 103 | " frequency_penalty=frequency_penalty\n", 104 | " )\n", 105 | " content = response.choices[0].message.content\n", 106 | " if verbose:\n", 107 | " print(content)\n", 108 | " prompts = content.splitlines()\n", 109 | " result = {}\n", 110 | " result['target'] = \"\"\n", 111 | " result['positive'] = \"\"\n", 112 | " result['unconditional'] = \"\"\n", 113 | " result['neutral'] = \"\"\n", 114 | " for prompt in prompts:\n", 115 | " key = prompt.split(':')\n", 116 | " if key[0].lower().strip() == 'preservation':\n", 117 | " final_attributes = []\n", 118 | " attributes = key[1].split(';')\n", 119 | " for attribute in attributes:\n", 120 | " if len(attribute.strip()) == 0:\n", 121 | " continue\n", 122 | " final_attributes.append(attribute.strip().split(','))\n", 123 | " elif key[0].lower().strip() == 'name':\n", 124 | " name = key[1].strip()\n", 125 | " for prompt in prompts:\n", 126 | " key = prompt.split(':')\n", 127 | " if len(key)!=2:\n", 128 | " continue\n", 129 | " if key[0].lower().strip() == 'target':\n", 130 | " result['target'] = key[1].strip()\n", 131 | " elif key[0].lower().strip() == 'positive':\n", 132 | " result['positive'] = key[1].strip()\n", 133 | " elif key[0].lower().strip() == 'negative':\n", 134 | " result['unconditional'] = key[1].strip()\n", 135 | " result['neutral'] = result['target']\n", 136 | " results = [result]\n", 137 | " \n", 138 | " for attribute_class in final_attributes:\n", 139 | " results_final = []\n", 140 | " for attribute in attribute_class:\n", 141 | " for result in results:\n", 142 | " r = {}\n", 143 | " for key in result.keys():\n", 144 | " r[key] = attribute.strip() + f' {result[key].strip()}'\n", 145 | " r[key] = r[key].strip()\n", 146 | " results_final.append(r)\n", 147 | " \n", 148 | " results = results_final\n", 149 | " results_final = []\n", 150 | " for result in results:\n", 151 | " r_final = result\n", 152 | " r_final['guidance'] = 4\n", 153 | " r_final['rank'] = 4\n", 154 | " r_final['action'] = 'enhance'\n", 155 | " r_final['resolution'] = 512\n", 156 | " r_final['dynamic_resolution'] = False\n", 157 | " r_final['batch_size'] = 1\n", 158 | " results_final.append(r_final)\n", 159 | " if file_name_to_save is None:\n", 160 | " if name is None:\n", 161 | " file_name_to_save = 'custom-prompts-GPT.yaml'\n", 162 | " else:\n", 163 | " 
file_name_to_save = f'prompts-{name}.yaml'\n", 164 | " if save:\n", 165 | " with open(f'trainscripts/textsliders/data/{file_name_to_save}', 'w+') as f:\n", 166 | " yaml.dump(results_final, f, allow_unicode=True, sort_keys=False)\n", 167 | " if verbose:\n", 168 | " print(f'Prompt file saved to: \"trainscripts/textsliders/data/{file_name_to_save}\"')\n", 169 | " return f'trainscripts/textsliders/data/{file_name_to_save}'" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 4, 175 | "id": "eae07933-c3d7-43d1-99ba-b87b749ec1fd", 176 | "metadata": { 177 | "scrolled": true 178 | }, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "Target: person\n", 185 | "Positive: person, aged, wrinkles, grey hair, elderly features\n", 186 | "Negative: person, youthful, smooth skin, vibrant, young\n", 187 | "Preservation: white, black, indian, asian, hispanic ; male, female\n", 188 | "Name: person_age_GPT\n", 189 | "Prompt file saved to: \"trainscripts/textsliders/data/prompts-person_age_GPT.yaml\"\n" 190 | ] 191 | }, 192 | { 193 | "data": { 194 | "text/plain": [ 195 | "'trainscripts/textsliders/data/prompts-person_age_GPT.yaml'" 196 | ] 197 | }, 198 | "execution_count": 4, 199 | "metadata": {}, 200 | "output_type": "execute_result" 201 | } 202 | ], 203 | "source": [ 204 | "query = \"I want to build a slider to make people old\"\n", 205 | "generate_prompts_sliders(slider_query=query, model=\"gpt-4-turbo-preview\", save=True, verbose=True)" 206 | ] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "Python 3 (ipykernel)", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.9.18" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 5 230 | } 231 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Rohit Gandikota 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Concept Sliders 2 | ### [Project Website](https://sliders.baulab.info) | [Arxiv Preprint](https://arxiv.org/pdf/2311.12092.pdf) | [Trained Sliders](https://sliders.baulab.info/weights/xl_sliders/) | [Colab Demo](https://colab.research.google.com/github/rohitgandikota/sliders/blob/main/demo_concept_sliders.ipynb) | [Huggingface Demo](https://huggingface.co/spaces/baulab/ConceptSliders)
3 | Official code implementation of "Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models", European Conference on Computer Vision (ECCV 2024). 4 | 5 |
6 | 7 |
8 | 9 | ## Colab Demo 10 | Try out our Colab demo here [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rohitgandikota/sliders/blob/main/demo_concept_sliders.ipynb) 11 | 12 | ## FLUX Support 🚀🚀🚀 13 | You can train sliders for FLUX-1 models. This is still experimental, so please be patient if it does not work as well as SDXL; FLUX is not designed the same way as SDXL.<br>
14 | 15 | To play with FLUX sliders you need to update your packages: 16 | ``` 17 | pip install -r flux-sliders/flux-requirements.txt 18 | ``` 19 | 20 | Then just open the notebook in the `flux-sliders` folder and have fun! 21 | 22 | ## UPDATE 23 | You can now use GPT-4 (or any other OpenAI model) to create prompts for your text sliders. All you need to do is describe the slider you want to create (e.g., "I want to make people look happy").<br>
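For reference, here is a minimal usage sketch of the helper defined in the `GPT_prompt_helper.ipynb` notebook linked below. The function name, arguments, and output path are taken from that notebook; it assumes `OPENAI_API_KEY` is set in your environment.

```
# Run the cells in GPT_prompt_helper.ipynb that define generate_prompts_sliders first.
# The helper sends the natural-language request to an OpenAI chat model, parses the
# returned target/positive/negative/preservation prompts, and writes a training YAML
# next to the existing prompt files in trainscripts/textsliders/data/.
yaml_path = generate_prompts_sliders(
    slider_query="I want to build a slider to make people old",
    model="gpt-4-turbo-preview",  # any OpenAI chat model can be passed here
    save=True,
    verbose=True,
)
print(yaml_path)  # e.g. trainscripts/textsliders/data/prompts-person_age_GPT.yaml
```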
24 | Please refer to the [GPT-notebook](https://github.com/rohitgandikota/sliders/blob/main/GPT_prompt_helper.ipynb) 25 | 26 | ## Setup 27 | To set up your Python environment: 28 | ``` 29 | conda create -n sliders python=3.9 30 | conda activate sliders 31 | 32 | git clone https://github.com/rohitgandikota/sliders.git 33 | cd sliders 34 | pip install -r requirements.txt 35 | ``` 36 | If you are running on Windows, please refer to the Windows setup guidelines [here](https://github.com/rohitgandikota/sliders/issues/27#issuecomment-1833572579) 37 | ## Textual Concept Sliders 38 | ### Training SD-1.x and SD-2.x LoRA 39 | To train an age slider, go to `trainscripts/textsliders/data/prompts.yaml` and set `target=person`, `positive=old person`, `unconditional=young person` (the opposite of positive), `neutral=person`, and `action=enhance` with `guidance=4`.<br>
40 | If you do not want your edit targeted to person, replace it with any target you want (e.g., dog); if you need the edit to be global, replace `person` with `""`. A sample entry is sketched below.<br>
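For orientation, here is a minimal sketch of what one edited entry in `prompts.yaml` could look like for the age slider described above. The field names follow the entries emitted by `GPT_prompt_helper.ipynb`; treat the values as an example and keep the schema of the existing file.

```
- target: person                # concept whose generations the slider should edit
  positive: old person          # extreme positive end of the slider
  unconditional: young person   # opposite (negative) end of the slider
  neutral: person               # neutral anchor, usually the same as target
  action: enhance
  guidance: 4
  rank: 4
  resolution: 512
  dynamic_resolution: false
  batch_size: 1
```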
41 | Finally, run the command: 42 | ``` 43 | python trainscripts/textsliders/train_lora.py --attributes 'male, female' --name 'ageslider' --rank 4 --alpha 1 --config_file 'trainscripts/textsliders/data/config.yaml' 44 | ``` 45 | 46 | The `--attributes` argument is used to disentangle concepts from the slider. For instance, an age slider tends to make all old people male, so adding the `"female, male"` attributes allows disentanglement. 47 | 48 | 49 | #### Evaluate 50 | To evaluate your trained models, use the notebook `SD1-sliders-inference.ipynb` 51 | 52 | 53 | ### Training SD-XL 54 | To train sliders for SD-XL, use the script `train_lora_xl.py`. The setup is the same as for SDv1.4. 55 | 56 | ``` 57 | python trainscripts/textsliders/train_lora_xl.py --attributes 'male, female' --name 'agesliderXL' --rank 4 --alpha 1 --config_file 'trainscripts/textsliders/data/config-xl.yaml' 58 | ``` 59 | 60 | #### Evaluate 61 | To evaluate your trained models, use the notebook `XL-sliders-inference.ipynb` 62 | 63 | 64 | ## Visual Concept Sliders 65 | ### Training SD-1.x and SD-2.x LoRA 66 | To train image-based sliders, you need to create a small dataset of ~4-6 image pairs (before/after edits for the desired concept). Save the before images and the after images separately. You can also create a dataset with varied edit intensities and save each intensity level separately. 67 | 68 | To train an image slider for eye size, go to `trainscripts/imagesliders/data/config.yaml` and set `target=eye`, `positive='eye'`, `unconditional=''`, `neutral=eye`, and `action=enhance` with `guidance=4`.<br>
69 | If you want the diffusion model to figure out the edit concept on its own, leave `target`, `positive`, `unconditional`, and `neutral` as `''`.<br>
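Before running the command below, the paired images are expected to be organized roughly like this. The folder names match the example command; the file names are only illustrative, and what matters is that each before/after pair shares the same file name across the two folders.

```
datasets/eyesize/
├── smallsize/     # "before" images (scale -1)
│   ├── 000.png
│   └── 001.png
└── bigsize/       # "after" images (scale +1)
    ├── 000.png
    └── 001.png
```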
70 | Finally, run the command: 71 | ``` 72 | python trainscripts/imagesliders/train_lora-scale.py --name 'eyeslider' --rank 4 --alpha 1 --config_file 'trainscripts/imagesliders/data/config.yaml' --folder_main 'datasets/eyesize/' --folders 'bigsize, smallsize' --scales '1, -1' 73 | ``` 74 | For this to work, you need to store your before images in `smallsize` and your after images in `bigsize`. The corresponding paired files in both folders should have the same names, and both subfolders should sit under `datasets/eyesize`. Feel free to build your own datasets with your own naming conventions. 75 | ### Training SD-XL 76 | To train image sliders for SD-XL, use the script `train_lora-scale-xl.py`. The setup is the same as for SDv1.4. 77 | 78 | ``` 79 | python trainscripts/imagesliders/train_lora-scale-xl.py --name 'eyesliderXL' --rank 4 --alpha 1 --config_file 'trainscripts/imagesliders/data/config-xl.yaml' --folder_main '/share/u/rohit/imageXLdataset/eyesize_data/' 80 | ``` 81 | 82 | ## Editing Real Images 83 | Concept sliders can be used to edit real images. We use null-text inversion to edit the images: instead of editing with a prompt, we edit with sliders!<br>
84 | Check out `demo_image_editing.ipynb` for more details. 85 | 86 | ## Running Gradio Demo Locally 87 | You can also run the HF-hosted Gradio slider tool (huge shoutout to the Gradio and HF teams) locally using the following commands: 88 | ``` 89 | git lfs install 90 | git clone https://huggingface.co/spaces/baulab/ConceptSliders 91 | cd ConceptSliders 92 | pip install -r requirements.txt 93 | python app.py 94 | ``` 95 | For more inference-time Gradio demos, please refer to Camenduru's repo [here](https://github.com/camenduru/sliders-colab) 96 | 97 | ## Running with ControlNet Integration 98 | Our user community is amazing! Here is the resource that integrates ControlNet: https://github.com/rohitgandikota/sliders/issues/76#issuecomment-2099766893 99 | ## Citing our work 100 | The preprint can be cited as follows: 101 | ``` 102 | @inproceedings{gandikota2023erasing, 103 | title={Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models}, 104 | author={Rohit Gandikota and Joanna Materzy\'nska and Tingrui Zhou and Antonio Torralba and David Bau}, 105 | booktitle={Proceedings of the 2024 IEEE European Conference on Computer Vision}, 106 | note={arXiv preprint arXiv:2311.12092}, 107 | year={2024} 108 | } 109 | ``` 110 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from trainscripts.textsliders import lora -------------------------------------------------------------------------------- /eval-scripts/clip_score.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import requests 3 | import os, glob 4 | import pandas as pd 5 | import numpy as np 6 | import re 7 | import argparse 8 | from transformers import CLIPProcessor, CLIPModel 9 | 10 | 11 | if __name__=='__main__': 12 | parser = argparse.ArgumentParser( 13 | prog = 'clipScore', 14 | description = 'Generate CLIP score for images') 15 | parser.add_argument('--im_path', help='path for images', type=str, required=True) 16 | parser.add_argument('--prompt', help='prompt to check clip score against', type=str, required=True) 17 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True) 18 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0') 19 | parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000) 20 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0) 21 | 22 | args = parser.parse_args() 23 | 24 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 25 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 26 | 27 | def sorted_nicely( l ): 28 | convert = lambda text: int(text) if text.isdigit() else text 29 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 30 | return sorted(l, key = alphanum_key) 31 | 32 | 33 | path = args.im_path 34 | model_names = os.listdir(path) 35 | model_names = [m for m in model_names if 'all' not in m and '.csv' not in m] 36 | csv_path = args.prompts_path 37 | save_path = '' 38 | prompt = args. 
prompt.strip() 39 | print(f'Eval agaisnt prompt: {prompt}') 40 | model_names.sort() 41 | print(model_names) 42 | df = pd.read_csv(csv_path) 43 | for model_name in model_names: 44 | print(model_name) 45 | 46 | im_folder = os.path.join(path, model_name) 47 | 48 | images = os.listdir(im_folder) 49 | images = sorted_nicely(images) 50 | ratios = {} 51 | model_name = model_name.replace('half','0.5') 52 | df[f'clip_{model_name}'] = np.nan 53 | for image in images: 54 | try: 55 | case_number = int(image.split('_')[0].replace('.png','')) 56 | if case_number not in list(df['case_number']): 57 | continue 58 | im = Image.open(os.path.join(im_folder, image)) 59 | inputs = processor(text=[prompt], images=im, return_tensors="pt", padding=True) 60 | outputs = model(**inputs) 61 | clip_score = outputs.logits_per_image[0][0].detach().cpu() # this is the image-text similarity score 62 | ratios[case_number] = ratios.get(case_number, []) + [clip_score] 63 | # print(image, clip_score) 64 | except: 65 | pass 66 | for key in ratios.keys(): 67 | df.loc[key,f'clip_{model_name}'] = np.mean(ratios[key]) 68 | # df = df.dropna(axis=0) 69 | print(f"Mean CLIP score: {df[f'clip_{model_name}'].mean()}") 70 | print('-------------------------------------------------') 71 | print('\n') 72 | df.to_csv(f'{path}/clip_scores.csv', index=False) 73 | -------------------------------------------------------------------------------- /eval-scripts/generate_images-uce.py: -------------------------------------------------------------------------------- 1 | import glob, re 2 | import torch 3 | from diffusers import DiffusionPipeline 4 | from safetensors.torch import load_file 5 | import matplotlib.pyplot as plt 6 | import matplotlib.image as mpimg 7 | import random 8 | import copy 9 | import gc 10 | from tqdm.auto import tqdm 11 | import random 12 | from PIL import Image 13 | from transformers import CLIPTextModel, CLIPTokenizer 14 | 15 | import diffusers 16 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler 17 | from diffusers.loaders import AttnProcsLayers 18 | from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor 19 | import torch 20 | from typing import Any, Dict, List, Optional, Tuple, Union 21 | from transformers import CLIPTextModel, CLIPTokenizer 22 | import os 23 | import sys 24 | import argparse 25 | import pandas as pd 26 | sys.path.insert(1, os.getcwd()) 27 | from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV 28 | 29 | def sorted_nicely( l ): 30 | convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text 31 | alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ] 32 | return sorted(l, key = alphanum_key) 33 | 34 | def flush(): 35 | torch.cuda.empty_cache() 36 | gc.collect() 37 | flush() 38 | 39 | from transformers import CLIPTextModel, CLIPTokenizer 40 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler 41 | from diffusers import LMSDiscreteScheduler 42 | import torch 43 | from PIL import Image 44 | import argparse 45 | import os, json, random 46 | import pandas as pd 47 | def image_grid(imgs, rows=2, cols=2): 48 | assert len(imgs) == rows*cols 49 | 50 | w, h = imgs[0].size 51 | grid = Image.new('RGB', size=(cols*w, rows*h)) 52 | grid_w, grid_h = grid.size 53 | 54 | for i, img in enumerate(imgs): 55 | grid.paste(img, box=(i%cols*w, i//cols*h)) 56 | return grid 57 | def generate_images(unet, new_state_dict, 
vae, tokenizer, text_encoder, prompt, evaluation_seed=3124, model_path=None, device='cuda:0', guidance_scale = 7.5, image_size=512, ddim_steps=50, num_samples=4, from_case=0, till_case=1000000, base='1.4', start_noise=800): 58 | 59 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 60 | 61 | vae.to(device) 62 | text_encoder.to(device) 63 | unet.to(device) 64 | torch_device = device 65 | prompt = [str(prompt)]*num_samples 66 | seed = evaluation_seed 67 | 68 | height = image_size # default height of Stable Diffusion 69 | width = image_size # default width of Stable Diffusion 70 | 71 | num_inference_steps = ddim_steps # Number of denoising steps 72 | 73 | guidance_scale = guidance_scale # Scale for classifier-free guidance 74 | 75 | generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise 76 | 77 | batch_size = len(prompt) 78 | 79 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 80 | 81 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 82 | 83 | max_length = text_input.input_ids.shape[-1] 84 | uncond_input = tokenizer( 85 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 86 | ) 87 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 88 | 89 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 90 | 91 | latents = torch.randn( 92 | (batch_size, unet.in_channels, height // 8, width // 8), 93 | generator=generator, 94 | ) 95 | latents = latents.to(torch_device) 96 | 97 | scheduler.set_timesteps(num_inference_steps) 98 | 99 | latents = latents * scheduler.init_noise_sigma 100 | 101 | from tqdm.auto import tqdm 102 | 103 | scheduler.set_timesteps(num_inference_steps) 104 | flag = False 105 | for t in tqdm(scheduler.timesteps): 106 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 107 | if t<=start_noise: 108 | if not flag: 109 | flag = True 110 | unet.load_state_dict(new_state_dict) 111 | latent_model_input = torch.cat([latents] * 2) 112 | 113 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) 114 | 115 | # predict the noise residual 116 | with torch.no_grad(): 117 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 118 | 119 | # perform guidance 120 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 121 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 122 | 123 | # compute the previous noisy sample x_t -> x_t-1 124 | latents = scheduler.step(noise_pred, t, latents).prev_sample 125 | 126 | # scale and decode the image latents with vae 127 | latents = 1 / 0.18215 * latents 128 | with torch.no_grad(): 129 | image = vae.decode(latents).sample 130 | 131 | image = (image / 2 + 0.5).clamp(0, 1) 132 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 133 | images = (image * 255).round().astype("uint8") 134 | pil_images = [Image.fromarray(image) for image in images] 135 | return pil_images 136 | 137 | 138 | def generate_images_(model_name, prompts_path, save_path, negative_prompt, device, guidance_scale , image_size, ddim_steps, num_samples,from_case, till_case, base, rank, start_noise): 139 | # Load scheduler, tokenizer and models. 
140 | scales = [-2, -1, -.5, 0, .5, 1, 2] 141 | 142 | revision = None 143 | pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4" 144 | weight_dtype = torch.float32 145 | 146 | # Load scheduler, tokenizer and models. 147 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 148 | tokenizer = CLIPTokenizer.from_pretrained( 149 | pretrained_model_name_or_path, subfolder="tokenizer", revision=revision 150 | ) 151 | text_encoder = CLIPTextModel.from_pretrained( 152 | pretrained_model_name_or_path, subfolder="text_encoder", revision=revision 153 | ) 154 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision) 155 | unet = UNet2DConditionModel.from_pretrained( 156 | pretrained_model_name_or_path, subfolder="unet", revision=revision 157 | ) 158 | # freeze parameters of models to save more memory 159 | unet.requires_grad_(False) 160 | unet.to(device, dtype=weight_dtype) 161 | vae.requires_grad_(False) 162 | 163 | text_encoder.requires_grad_(False) 164 | 165 | vae.to(device, dtype=weight_dtype) 166 | text_encoder.to(device, dtype=weight_dtype) 167 | 168 | df = pd.read_csv(prompts_path) 169 | prompts = df.prompt 170 | seeds = df.evaluation_seed 171 | case_numbers = df.case_number 172 | 173 | name = os.path.basename(model_name) 174 | folder_path = f'{save_path}/{name}' 175 | os.makedirs(folder_path, exist_ok=True) 176 | os.makedirs(folder_path+f'/all', exist_ok=True) 177 | scales_str = [] 178 | for scale in scales: 179 | scale_str = f'{scale}' 180 | scale_str = scale_str.replace('0.5','half') 181 | scales_str.append(scale_str) 182 | os.makedirs(folder_path+f'/{scale_str}', exist_ok=True) 183 | height = image_size # default height of Stable Diffusion 184 | width = image_size # default width of Stable Diffusion 185 | 186 | num_inference_steps = ddim_steps # Number of denoising steps 187 | 188 | guidance_scale = guidance_scale # Scale for classifier-free guidance 189 | torch_device = device 190 | 191 | model_version = "CompVis/stable-diffusion-v1-4" 192 | 193 | unet = UNet2DConditionModel.from_pretrained(model_version, subfolder="unet") 194 | old_state_dict = copy.deepcopy(unet.state_dict()) 195 | model_path = model_name 196 | new_state_dict_ = copy.deepcopy(torch.load(model_path, map_location='cpu')) 197 | delta_dict = {} 198 | for key, value in old_state_dict.items(): 199 | delta_dict[key] = new_state_dict_[key] - value 200 | del new_state_dict_ 201 | for _, row in df.iterrows(): 202 | prompt = str(row.prompt) 203 | seed = row.evaluation_seed 204 | case_number = row.case_number 205 | if not (case_number>=from_case and case_number<=till_case): 206 | continue 207 | images_list = [] 208 | for scale in scales: 209 | 210 | new_state_dict = {} 211 | for key, value in old_state_dict.items(): 212 | new_state_dict[key] = value + scale * delta_dict[key] 213 | 214 | 215 | im = generate_images(unet=unet, new_state_dict=new_state_dict, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, prompt=prompt, evaluation_seed=seed,num_samples=num_samples, start_noise=start_noise, device=device) 216 | images_list.append(im) 217 | del new_state_dict 218 | for num in range(num_samples): 219 | fig, ax = plt.subplots(1, len(images_list), figsize=(4*(len(scales)),4)) 220 | for i, a in enumerate(ax): 221 | images_list[i][num].save(f'{folder_path}/{scales_str[i]}/{case_number}_{num}.png') 222 | a.imshow(images_list[i][num]) 223 | a.set_title(f"{scales[i]}",fontsize=15) 224 | a.axis('off') 
225 | fig.savefig(f'{folder_path}/all/{case_number}_{num}.png',bbox_inches='tight') 226 | plt.close() 227 | del unet 228 | flush() 229 | if __name__=='__main__': 230 | parser = argparse.ArgumentParser( 231 | prog = 'generateImages', 232 | description = 'Generate Images using Diffusers Code') 233 | parser.add_argument('--model_name', help='name of model', type=str, required=True) 234 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True) 235 | parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None) 236 | parser.add_argument('--save_path', help='folder where to save images', type=str, required=True) 237 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0') 238 | parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4') 239 | parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5) 240 | parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512) 241 | parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000) 242 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0) 243 | parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=5) 244 | parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50) 245 | parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4) 246 | parser.add_argument('--start_noise', help='to start the finetuned model from', type=int, required=False, default=800) 247 | 248 | args = parser.parse_args() 249 | 250 | model_name = args.model_name 251 | rank = args.rank 252 | if 'rank1' in model_name: 253 | rank = 1 254 | prompts_path = args.prompts_path 255 | save_path = args.save_path 256 | device = args.device 257 | guidance_scale = args.guidance_scale 258 | image_size = args.image_size 259 | ddim_steps = args.ddim_steps 260 | num_samples= args.num_samples 261 | from_case = args.from_case 262 | till_case = args.till_case 263 | start_noise = args.start_noise 264 | base = args.base 265 | negative_prompts_path = args.negative_prompts 266 | if negative_prompts_path is not None: 267 | negative_prompt = '' 268 | with open(negative_prompts_path, 'r') as fp: 269 | vals = json.load(fp) 270 | for val in vals: 271 | negative_prompt+=val+' ,' 272 | print(f'Negative prompt is being used: {negative_prompt}') 273 | else: 274 | negative_prompt = None 275 | generate_images_(model_name=model_name, prompts_path=prompts_path, save_path=save_path, negative_prompt=negative_prompt, device=device, guidance_scale = guidance_scale, image_size=image_size, ddim_steps=ddim_steps, num_samples=num_samples,from_case=from_case, till_case=till_case, base=base, rank=rank, start_noise=start_noise) -------------------------------------------------------------------------------- /eval-scripts/generate_images_sd1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import argparse 4 | import os, json, random 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | import glob, re,sys 8 | from tqdm.auto import tqdm 9 | 10 | from safetensors.torch import 
load_file 11 | import matplotlib.image as mpimg 12 | import copy 13 | import gc 14 | from transformers import CLIPTextModel, CLIPTokenizer 15 | 16 | import diffusers 17 | from diffusers import DiffusionPipeline 18 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler 19 | from diffusers.loaders import AttnProcsLayers 20 | from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor 21 | from typing import Any, Dict, List, Optional, Tuple, Union 22 | sys.path.insert(1, os.getcwd()) 23 | from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV 24 | 25 | 26 | def flush(): 27 | torch.cuda.empty_cache() 28 | gc.collect() 29 | flush() 30 | 31 | 32 | 33 | 34 | def sorted_nicely( l ): 35 | convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text 36 | alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ] 37 | return sorted(l, key = alphanum_key) 38 | 39 | def flush(): 40 | torch.cuda.empty_cache() 41 | gc.collect() 42 | 43 | def generate_images(model_name, prompts_path, save_path, negative_prompt, device, guidance_scale , image_size, ddim_steps, num_samples,from_case, till_case, base, rank, start_noise): 44 | # Load scheduler, tokenizer and models. 45 | scales = [-2, -1, 0, 1, 2] 46 | revision = None 47 | pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4" 48 | weight_dtype = torch.float16 49 | 50 | # Load scheduler, tokenizer and models. 51 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 52 | tokenizer = CLIPTokenizer.from_pretrained( 53 | pretrained_model_name_or_path, subfolder="tokenizer", revision=revision 54 | ) 55 | text_encoder = CLIPTextModel.from_pretrained( 56 | pretrained_model_name_or_path, subfolder="text_encoder", revision=revision 57 | ) 58 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision) 59 | unet = UNet2DConditionModel.from_pretrained( 60 | pretrained_model_name_or_path, subfolder="unet", revision=revision 61 | ) 62 | # freeze parameters of models to save more memory 63 | unet.requires_grad_(False) 64 | unet.to(device, dtype=weight_dtype) 65 | vae.requires_grad_(False) 66 | 67 | text_encoder.requires_grad_(False) 68 | 69 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 70 | # as these weights are only used for inference, keeping weights in full precision is not required. 
71 | 72 | 73 | # Move unet, vae and text_encoder to device and cast to weight_dtype 74 | 75 | vae.to(device, dtype=weight_dtype) 76 | text_encoder.to(device, dtype=weight_dtype) 77 | 78 | name = os.path.basename(model_name) 79 | alpha = 1 80 | train_method = 'xattn' 81 | n = model_name.split('/')[-2] 82 | if 'noxattn' in n: 83 | train_method = 'noxattn' 84 | if 'hspace' in n: 85 | train_method+='-hspace' 86 | scales = [-5, -2, -1, 0, 1, 2, 5] 87 | if 'last' in n: 88 | train_method+='-last' 89 | scales = [-5, -2, -1, 0, 1, 2, 5] 90 | network_type = "c3lier" 91 | if train_method == 'xattn': 92 | network_type = 'lierla' 93 | 94 | modules = DEFAULT_TARGET_REPLACE 95 | if network_type == "c3lier": 96 | modules += UNET_TARGET_REPLACE_MODULE_CONV 97 | 98 | network = LoRANetwork( 99 | unet, 100 | rank=rank, 101 | multiplier=1.0, 102 | alpha=alpha, 103 | train_method=train_method, 104 | ).to(device, dtype=weight_dtype) 105 | 106 | network.load_state_dict(torch.load(model_name)) 107 | 108 | df = pd.read_csv(prompts_path) 109 | prompts = df.prompt 110 | seeds = df.evaluation_seed 111 | case_numbers = df.case_number 112 | 113 | folder_path = f'{save_path}/{name}' 114 | os.makedirs(folder_path, exist_ok=True) 115 | os.makedirs(folder_path+f'/all', exist_ok=True) 116 | scales_str = [] 117 | for scale in scales: 118 | scale_str = f'{scale}' 119 | scale_str = scale_str.replace('0.5','half') 120 | scales_str.append(scale_str) 121 | os.makedirs(folder_path+f'/{scale_str}', exist_ok=True) 122 | height = image_size # default height of Stable Diffusion 123 | width = image_size # default width of Stable Diffusion 124 | 125 | num_inference_steps = ddim_steps # Number of denoising steps 126 | 127 | guidance_scale = guidance_scale # Scale for classifier-free guidance 128 | torch_device = device 129 | for _, row in df.iterrows(): 130 | print(str(row.prompt),str(row.evaluation_seed)) 131 | prompt = [str(row.prompt)]*num_samples 132 | batch_size = len(prompt) 133 | seed = row.evaluation_seed 134 | case_number = row.case_number 135 | if not (case_number>=from_case and case_number<=till_case): 136 | continue 137 | images_list = [] 138 | for scale in scales: 139 | torch_device = device 140 | negative_prompt = None 141 | height = 512 142 | width = 512 143 | guidance_scale = 7.5 144 | 145 | generator = torch.manual_seed(seed) 146 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 147 | 148 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 149 | 150 | max_length = text_input.input_ids.shape[-1] 151 | if negative_prompt is None: 152 | uncond_input = tokenizer( 153 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 154 | ) 155 | else: 156 | uncond_input = tokenizer( 157 | [negative_prompt] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 158 | ) 159 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 160 | 161 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 162 | 163 | latents = torch.randn( 164 | (batch_size, unet.in_channels, height // 8, width // 8), 165 | generator=generator, 166 | ) 167 | latents = latents.to(torch_device) 168 | 169 | noise_scheduler.set_timesteps(ddim_steps) 170 | 171 | latents = latents * noise_scheduler.init_noise_sigma 172 | latents = latents.to(weight_dtype) 173 | latent_model_input = torch.cat([latents] * 2) 174 | for t in tqdm(noise_scheduler.timesteps): 175 | if t>start_noise: 176 
| network.set_lora_slider(scale=0) 177 | else: 178 | network.set_lora_slider(scale=scale) 179 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 180 | latent_model_input = torch.cat([latents] * 2) 181 | 182 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t) 183 | # predict the noise residual 184 | with network: 185 | with torch.no_grad(): 186 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 187 | # perform guidance 188 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 189 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 190 | 191 | # compute the previous noisy sample x_t -> x_t-1 192 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 193 | 194 | # scale and decode the image latents with vae 195 | latents = 1 / 0.18215 * latents 196 | with torch.no_grad(): 197 | image = vae.decode(latents).sample 198 | image = (image / 2 + 0.5).clamp(0, 1) 199 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 200 | images = (image * 255).round().astype("uint8") 201 | pil_images = [Image.fromarray(image) for image in images] 202 | images_list.append(pil_images) 203 | for num in range(num_samples): 204 | fig, ax = plt.subplots(1, len(images_list), figsize=(4*(len(scales)),4)) 205 | for i, a in enumerate(ax): 206 | images_list[i][num].save(f'{folder_path}/{scales_str[i]}/{case_number}_{num}.png') 207 | a.imshow(images_list[i][num]) 208 | a.set_title(f"{scales[i]}",fontsize=15) 209 | a.axis('off') 210 | fig.savefig(f'{folder_path}/all/{case_number}_{num}.png',bbox_inches='tight') 211 | plt.close() 212 | del network, unet 213 | flush() 214 | if __name__=='__main__': 215 | parser = argparse.ArgumentParser( 216 | prog = 'generateImages', 217 | description = 'Generate Images using Diffusers Code') 218 | parser.add_argument('--model_name', help='name of model', type=str, required=True) 219 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True) 220 | parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None) 221 | parser.add_argument('--save_path', help='folder where to save images', type=str, required=True) 222 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0') 223 | parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4') 224 | parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5) 225 | parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512) 226 | parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000) 227 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0) 228 | parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=2) 229 | parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50) 230 | parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4) 231 | parser.add_argument('--start_noise', help='what time stamp to flip to edited model', type=int, required=False, default=850) 232 | 233 | args = parser.parse_args() 234 | 235 | 
model_name = args.model_name 236 | rank = args.rank 237 | if 'rank1' in model_name: 238 | rank = 1 239 | prompts_path = args.prompts_path 240 | save_path = args.save_path 241 | device = args.device 242 | guidance_scale = args.guidance_scale 243 | image_size = args.image_size 244 | ddim_steps = args.ddim_steps 245 | num_samples= args.num_samples 246 | from_case = args.from_case 247 | till_case = args.till_case 248 | start_noise = args.start_noise 249 | base = args.base 250 | negative_prompts_path = args.negative_prompts 251 | if negative_prompts_path is not None: 252 | negative_prompt = '' 253 | with open(negative_prompts_path, 'r') as fp: 254 | vals = json.load(fp) 255 | for val in vals: 256 | negative_prompt+=val+' ,' 257 | print(f'Negative prompt is being used: {negative_prompt}') 258 | else: 259 | negative_prompt = None 260 | generate_images(model_name=model_name, prompts_path=prompts_path, save_path=save_path, negative_prompt=negative_prompt, device=device, guidance_scale = guidance_scale, image_size=image_size, ddim_steps=ddim_steps, num_samples=num_samples,from_case=from_case, till_case=till_case, base=base, rank=rank, start_noise=start_noise) 261 | -------------------------------------------------------------------------------- /eval-scripts/generate_images_textinversion.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import pandas as pd 3 | import torch 4 | import os 5 | import argparse 6 | 7 | if __name__=='__main__': 8 | parser = argparse.ArgumentParser( 9 | prog = 'Generate Text Inversion Images',) 10 | 11 | parser.add_argument('--model_name', help='path to custom model', type=str, required=True) 12 | parser.add_argument('--prompts_path', help='path to csv prompts', type=str, required=True) 13 | parser.add_argument('--token', help='path to csv prompts', type=str, required=True) 14 | args = parser.parse_args() 15 | model_id = args.model_name 16 | custom_token = args.token #'' 17 | 18 | 19 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda") 20 | 21 | df = pd.read_csv(args.prompts_path) 22 | 23 | prompts = list(df.prompt) 24 | seeds = list(df.evaluation_seed) 25 | case_numbers = list(df.case_number) 26 | file = os.path.basename(model_id) 27 | 28 | os.makedirs(f'images/text_inversion/{file}/',exist_ok=True) 29 | for idx,prompt in enumerate(prompts): 30 | 31 | prompt += f" with {custom_token}" 32 | case_number = case_numbers[idx] 33 | generator = torch.manual_seed(seeds[idx]) 34 | images = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=5).images 35 | for i, im in enumerate(images): 36 | im.save(f'images/text_inversion/{file}/{case_number}_{i}.png') 37 | -------------------------------------------------------------------------------- /eval-scripts/generate_images_textinversion_xl.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline,DDPMScheduler 2 | import pandas as pd 3 | import os 4 | import glob 5 | import torch 6 | import random 7 | 8 | 9 | def load_XLembedding(base,token="my",embedding_file="myToken.pt",path="./Embeddings/"): 10 | emb=torch.load(path+embedding_file) 11 | set_XLembedding(base,emb,token) 12 | 13 | def set_XLembedding(base,emb,token="my"): 14 | with torch.no_grad(): 15 | # Embeddings[tokenNo] to learn 16 | tokens=base.components["tokenizer"].encode(token) 17 | assert len(tokens)==3, "token is not a single 
token in 'tokenizer'" 18 | tokenNo=tokens[1] 19 | tokens=base.components["tokenizer_2"].encode(token) 20 | assert len(tokens)==3, "token is not a single token in 'tokenizer_2'" 21 | tokenNo2=tokens[1] 22 | embs=base.components["text_encoder"].text_model.embeddings.token_embedding.weight 23 | embs2=base.components["text_encoder_2"].text_model.embeddings.token_embedding.weight 24 | assert embs[tokenNo].shape==emb["emb"].shape, "different 'text_encoder'" 25 | assert embs2[tokenNo2].shape==emb["emb2"].shape, "different 'text_encoder_2'" 26 | embs[tokenNo]=emb["emb"].to(embs.dtype).to(embs.device) 27 | embs2[tokenNo2]=emb["emb2"].to(embs2.dtype).to(embs2.device) 28 | 29 | 30 | base_model_path="stabilityai/stable-diffusion-xl-base-1.0" 31 | pipe = DiffusionPipeline.from_pretrained( 32 | base_model_path, 33 | torch_dtype=torch.float16, #torch.bfloat16 34 | variant="fp32", 35 | use_safetensors=True, 36 | add_watermarker=False, 37 | ) 38 | pipe.enable_xformers_memory_efficient_attention() 39 | torch.set_grad_enabled(False) 40 | _=pipe.to("cuda:1") 41 | 42 | 43 | df = pd.read_csv('prompts/prompts-personreal.csv') 44 | prompts = list(df.prompt) 45 | seeds = list(df.evaluation_seed) 46 | case_numbers = list(df.case_number) 47 | 48 | learned="sks" 49 | embs_path="./textualinversion_models/" 50 | emb_file="eyesize_textual_inversion.pt" 51 | 52 | load_XLembedding(pipe,token=learned,embedding_file=emb_file,path=embs_path) 53 | 54 | p1="photo of a person, realistic, 8k with {} eyes" 55 | n_steps=50 56 | 57 | seed = random.randint(0,2**15) 58 | sample_prompt = p1 59 | prompt=sample_prompt.format(learned) 60 | 61 | 62 | for idx, prompt in enumerate(prompts): 63 | case_number = case_numbers[idx] 64 | seed = seeds[idx] 65 | 66 | print(prompt, seed) 67 | with torch.no_grad(): 68 | generator = torch.manual_seed(seed) 69 | images = pipe( 70 | prompt=prompt+ ' with sks eyes', 71 | num_inference_steps=n_steps, 72 | num_images_per_prompt=5, 73 | generator = generator 74 | ).images 75 | for i, im in enumerate(images): 76 | im.save(f'images/textualinversion/eyesize_xl/{case_number}_{i}.png') 77 | -------------------------------------------------------------------------------- /eval-scripts/lpip_score.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | 11 | import torchvision.transforms as transforms 12 | import torchvision.models as models 13 | import numpy as np 14 | import copy 15 | import os 16 | import pandas as pd 17 | import argparse 18 | import lpips 19 | 20 | 21 | # desired size of the output image 22 | imsize = 64 23 | loader = transforms.Compose([ 24 | transforms.Resize(imsize), # scale imported image 25 | transforms.ToTensor()]) # transform it into a torch tensor 26 | 27 | 28 | def image_loader(image_name): 29 | image = Image.open(image_name) 30 | # fake batch dimension required to fit network's input dimensions 31 | image = loader(image).unsqueeze(0) 32 | image = (image-0.5)*2 33 | return image.to(torch.float) 34 | 35 | 36 | if __name__=='__main__': 37 | parser = argparse.ArgumentParser( 38 | prog = 'LPIPS', 39 | description = 'Takes the path to two images and gives LPIPS') 40 | parser.add_argument('--im_path', help='path to original image', type=str, required=True) 41 | parser.add_argument('--prompts_path', help='path to csv prompts', type=str, 
required=True) 42 | parser.add_argument('--true', help='path to true SD images', type=str, required=True) 43 | 44 | loss_fn_alex = lpips.LPIPS(net='alex') 45 | args = parser.parse_args() 46 | 47 | true = args.true 48 | models = os.listdir(args.im_path) 49 | models = [m for m in models if m not in [true,'all'] and '.csv' not in m] 50 | 51 | original_path = os.path.join(args.im_path,true) 52 | df_prompts = pd.read_csv(args.prompts_path) 53 | for model_name in models: 54 | edited_path = os.path.join(args.im_path,model_name) 55 | file_names = [name for name in os.listdir(edited_path) if '.png' in name] 56 | model_name = model_name.replace('half','0.5') 57 | df_prompts[f'lpips_{model_name}'] = df_prompts['case_number'] *0 58 | for index, row in df_prompts.iterrows(): 59 | case_number = row.case_number 60 | files = [file for file in file_names if file.startswith(f'{case_number}_')] 61 | lpips_scores = [] 62 | for file in files: 63 | print(file) 64 | try: 65 | original = image_loader(os.path.join(original_path,file)) 66 | edited = image_loader(os.path.join(edited_path,file)) 67 | 68 | l = loss_fn_alex(original, edited) 69 | 70 | lpips_scores.append(l.item()) 71 | except Exception: 72 | print('No File') 73 | pass 74 | print(f'Case {case_number}: {np.mean(lpips_scores)}') 75 | df_prompts.loc[index,f'lpips_{model_name}'] = np.mean(lpips_scores) 76 | df_prompts.to_csv(os.path.join(args.im_path, f'lpips_score.csv'), index=False) 77 | -------------------------------------------------------------------------------- /flux-sliders/flux-requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes 2 | dadaptation 3 | diffusers 4 | xformers 5 | torchvision 6 | accelerate 7 | transformers 8 | sentencepiece 9 | lion_pytorch 10 | lpips 11 | matplotlib 12 | numpy 13 | opencv_python 14 | opencv_python_headless 15 | pandas 16 | Pillow 17 | prodigyopt 18 | pydantic 19 | PyYAML 20 | Requests 21 | safetensors 22 | torch 23 | tqdm 24 | wandb 25 | datasets 26 | ftfy 27 | openai 28 | scikit-learn 29 | git+https://github.com/davidbau/baukit 30 | anthropic 31 | peft -------------------------------------------------------------------------------- /flux-sliders/utils/lora.py: -------------------------------------------------------------------------------- 1 | # ref: 2 | # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py 3 | # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py 4 | 5 | import os 6 | import math 7 | from typing import Optional, List, Type, Set, Literal 8 | 9 | import torch 10 | import torch.nn as nn 11 | from diffusers import UNet2DConditionModel 12 | from safetensors.torch import save_file 13 | from datetime import datetime 14 | 15 | UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ 16 | # "Transformer2DModel", # どうやらこっちの方らしい? 
# attn1, 2 17 | "Attention" 18 | ] 19 | UNET_TARGET_REPLACE_MODULE_CONV = [ 20 | "ResnetBlock2D", 21 | "Downsample2D", 22 | "Upsample2D", 23 | "DownBlock2D", 24 | "UpBlock2D", 25 | 26 | ] # locon, 3clier 27 | 28 | LORA_PREFIX_UNET = "lora_unet" 29 | 30 | DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER 31 | 32 | TRAINING_METHODS = Literal[ 33 | "noxattn", # train all layers except x-attns and time_embed layers 34 | "innoxattn", # train all layers except self attention layers 35 | "selfattn", # ESD-u, train only self attention layers 36 | "xattn", # ESD-x, train only x attention layers 37 | "xattn-up", # all up blocks only 38 | "xattn-down",# all down blocks only 39 | "xattn-mid",# mid blocks only 40 | "full", # train all layers 41 | "xattn-strict", # q and k values 42 | "noxattn-hspace", 43 | "noxattn-hspace-last", 44 | # "xlayer", 45 | # "outxattn", 46 | # "outsattn", 47 | # "inxattn", 48 | # "inmidsattn", 49 | # "selflayer", 50 | ] 51 | 52 | def load_ortho_dict(n): 53 | path = f'~/orthogonal_basis/{n:09}.ckpt' 54 | if os.path.isfile(path): 55 | return torch.load(path) 56 | else: 57 | x = torch.randn(n,n) 58 | eig, _, _ = torch.svd(x) 59 | torch.save(eig, path) 60 | return eig 61 | 62 | def init_ortho_proj(rank, weight): 63 | seed = torch.seed() 64 | torch.manual_seed(datetime.now().timestamp()) 65 | q_index = torch.randint(high=weight.size(0),size=(rank,)) 66 | torch.manual_seed(seed) 67 | 68 | ortho_q_init = load_ortho_dict(weight.size(0)).to(dtype=weight.dtype)[:,q_index] 69 | return nn.Parameter(ortho_q_init) 70 | 71 | 72 | class LoRAModule(nn.Module): 73 | """ 74 | replaces forward method of the original Linear, instead of replacing the original Linear module. 75 | """ 76 | 77 | def __init__( 78 | self, 79 | lora_name, 80 | org_module: nn.Module, 81 | multiplier=1.0, 82 | lora_dim=4, 83 | alpha=1, 84 | train_method='xattn' 85 | ): 86 | """if alpha == 0 or None, alpha is rank (no scaling).""" 87 | super().__init__() 88 | self.lora_name = lora_name 89 | self.lora_dim = lora_dim 90 | 91 | if "Linear" in org_module.__class__.__name__: 92 | in_dim = org_module.in_features 93 | out_dim = org_module.out_features 94 | self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) 95 | self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) 96 | 97 | elif "Conv" in org_module.__class__.__name__: # 一応 98 | in_dim = org_module.in_channels 99 | out_dim = org_module.out_channels 100 | 101 | self.lora_dim = min(self.lora_dim, in_dim, out_dim) 102 | if self.lora_dim != lora_dim: 103 | print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") 104 | 105 | kernel_size = org_module.kernel_size 106 | stride = org_module.stride 107 | padding = org_module.padding 108 | self.lora_down = nn.Conv2d( 109 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False 110 | ) 111 | self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) 112 | 113 | if type(alpha) == torch.Tensor: 114 | alpha = alpha.detach().numpy() 115 | alpha = lora_dim if alpha is None or alpha == 0 else alpha 116 | self.scale = alpha / self.lora_dim 117 | self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える 118 | 119 | # same as microsoft's 120 | nn.init.kaiming_uniform_(self.lora_down.weight, a=1) 121 | if train_method == 'full': 122 | nn.init.zeros_(self.lora_up.weight) 123 | else: 124 | self.lora_up.weight = init_ortho_proj(lora_dim, self.lora_up.weight) 125 | self.lora_up.weight.requires_grad_(False) 126 | 127 | self.multiplier = multiplier 128 | self.org_module = org_module # remove in 
applying 129 | 130 | def apply_to(self): 131 | self.org_forward = self.org_module.forward 132 | self.org_module.forward = self.forward 133 | del self.org_module 134 | 135 | def forward(self, x): 136 | return ( 137 | self.org_forward(x) 138 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale 139 | ) 140 | 141 | 142 | class LoRANetwork(nn.Module): 143 | def __init__( 144 | self, 145 | unet: UNet2DConditionModel, 146 | rank: int = 4, 147 | multiplier: float = 1.0, 148 | alpha: float = 1.0, 149 | train_method: TRAINING_METHODS = "full", 150 | layers = ['Linear', 'Conv'] 151 | ) -> None: 152 | super().__init__() 153 | self.lora_scale = 1 154 | self.multiplier = multiplier 155 | self.lora_dim = rank 156 | self.alpha = alpha 157 | self.train_method=train_method 158 | # LoRAのみ 159 | self.module = LoRAModule 160 | 161 | # unetのloraを作る 162 | self.unet_loras = self.create_modules( 163 | LORA_PREFIX_UNET, 164 | unet, 165 | DEFAULT_TARGET_REPLACE, 166 | self.lora_dim, 167 | self.multiplier, 168 | train_method=train_method, 169 | layers = layers 170 | ) 171 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") 172 | 173 | # assertion 名前の被りがないか確認しているようだ 174 | lora_names = set() 175 | for lora in self.unet_loras: 176 | assert ( 177 | lora.lora_name not in lora_names 178 | ), f"duplicated lora name: {lora.lora_name}. {lora_names}" 179 | lora_names.add(lora.lora_name) 180 | 181 | # 適用する 182 | for lora in self.unet_loras: 183 | lora.apply_to() 184 | self.add_module( 185 | lora.lora_name, 186 | lora, 187 | ) 188 | 189 | del unet 190 | 191 | torch.cuda.empty_cache() 192 | 193 | def create_modules( 194 | self, 195 | prefix: str, 196 | root_module: nn.Module, 197 | target_replace_modules: List[str], 198 | rank: int, 199 | multiplier: float, 200 | train_method: TRAINING_METHODS, 201 | layers: List[str], 202 | ) -> list: 203 | filt_layers = [] 204 | if 'Linear' in layers: 205 | filt_layers.extend(["Linear", "LoRACompatibleLinear"]) 206 | if 'Conv' in layers: 207 | filt_layers.extend(["Conv2d", "LoRACompatibleConv"]) 208 | loras = [] 209 | names = [] 210 | for name, module in root_module.named_modules(): 211 | if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習 212 | if "attn2" in name or "time_embed" in name: 213 | continue 214 | elif train_method == "innoxattn": # Cross Attention 以外学習 215 | if "attn2" in name: 216 | continue 217 | elif train_method == "selfattn": # Self Attention のみ学習 218 | if "attn1" not in name: 219 | continue 220 | elif train_method in ["xattn", "xattn-strict", "xattn-up", "xattn-down", "xattn-mid"]: # Cross Attention 221 | if "attn" not in name: 222 | continue 223 | if train_method == 'xattn-up': 224 | if 'up_block' not in name: 225 | continue 226 | if train_method == 'xattn-down': 227 | if 'down_block' not in name: 228 | continue 229 | if train_method == 'xattn-mid': 230 | if 'mid_block' not in name: 231 | continue 232 | elif train_method == "full": # 全部学習 233 | pass 234 | else: 235 | raise NotImplementedError( 236 | f"train_method: {train_method} is not implemented." 
237 | ) 238 | if module.__class__.__name__ in target_replace_modules: 239 | for child_name, child_module in module.named_modules(): 240 | if child_module.__class__.__name__ in filt_layers: 241 | 242 | 243 | if train_method == 'xattn-strict': 244 | if 'out' in child_name: 245 | continue 246 | if 'to_q' in child_name: 247 | continue 248 | if train_method == 'noxattn-hspace': 249 | if 'mid_block' not in name: 250 | continue 251 | if train_method == 'noxattn-hspace-last': 252 | if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name: 253 | continue 254 | lora_name = prefix + "." + name + "." + child_name 255 | lora_name = lora_name.replace(".", "_") 256 | # print(f"{lora_name}") 257 | lora = self.module( 258 | lora_name, child_module, multiplier, rank, self.alpha, train_method 259 | ) 260 | # print(name, child_name) 261 | # print(child_module.weight.shape) 262 | if lora_name not in names: 263 | loras.append(lora) 264 | names.append(lora_name) 265 | # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}') 266 | return loras 267 | 268 | def prepare_optimizer_params(self): 269 | all_params = [] 270 | 271 | if self.unet_loras: # 実質これしかない 272 | params = [] 273 | if self.train_method == 'full': 274 | [params.extend(lora.parameters()) for lora in self.unet_loras] 275 | else: 276 | [params.extend(lora.lora_down.parameters()) for lora in self.unet_loras] 277 | param_data = {"params": params} 278 | all_params.append(param_data) 279 | 280 | return all_params 281 | 282 | def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): 283 | state_dict = self.state_dict() 284 | 285 | if dtype is not None: 286 | for key in list(state_dict.keys()): 287 | v = state_dict[key] 288 | v = v.detach().clone().to("cpu").to(dtype) 289 | state_dict[key] = v 290 | 291 | if os.path.splitext(file)[1] == ".safetensors": 292 | save_file(state_dict, file, metadata) 293 | else: 294 | torch.save(state_dict, file) 295 | def set_lora_slider(self, scale): 296 | self.lora_scale = scale 297 | 298 | def __enter__(self): 299 | for lora in self.unet_loras: 300 | lora.multiplier = 1.0 * self.lora_scale 301 | 302 | def __exit__(self, exc_type, exc_value, tb): 303 | for lora in self.unet_loras: 304 | lora.multiplier = 0 305 | -------------------------------------------------------------------------------- /flux-sliders/utils/model_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union, Optional 2 | 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 5 | from diffusers import ( 6 | UNet2DConditionModel, 7 | SchedulerMixin, 8 | StableDiffusionPipeline, 9 | StableDiffusionXLPipeline, 10 | ) 11 | from diffusers.schedulers import ( 12 | DDIMScheduler, 13 | DDPMScheduler, 14 | LMSDiscreteScheduler, 15 | EulerAncestralDiscreteScheduler, 16 | ) 17 | 18 | 19 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" 20 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" 21 | 22 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"] 23 | 24 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] 25 | 26 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this 27 | 28 | 29 | def load_diffusers_model( 30 | pretrained_model_name_or_path: str, 31 | v2: bool = False, 32 | clip_skip: Optional[int] = None, 33 | weight_dtype: torch.dtype = torch.float32, 34 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: 35 | # VAE はいらない 
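# ("VAE はいらない" = "the VAE is not needed here": this loader returns only the tokenizer,
#  text encoder and U-Net.)
# Hedged usage sketch, not part of the original file — the model id is only an example:
# tokenizer, text_encoder, unet = load_diffusers_model(
#     "CompVis/stable-diffusion-v1-4", v2=False, weight_dtype=torch.float16
# )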
36 | 37 | if v2: 38 | tokenizer = CLIPTokenizer.from_pretrained( 39 | TOKENIZER_V2_MODEL_NAME, 40 | subfolder="tokenizer", 41 | torch_dtype=weight_dtype, 42 | cache_dir=DIFFUSERS_CACHE_DIR, 43 | ) 44 | text_encoder = CLIPTextModel.from_pretrained( 45 | pretrained_model_name_or_path, 46 | subfolder="text_encoder", 47 | # default is clip skip 2 48 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, 49 | torch_dtype=weight_dtype, 50 | cache_dir=DIFFUSERS_CACHE_DIR, 51 | ) 52 | else: 53 | tokenizer = CLIPTokenizer.from_pretrained( 54 | TOKENIZER_V1_MODEL_NAME, 55 | subfolder="tokenizer", 56 | torch_dtype=weight_dtype, 57 | cache_dir=DIFFUSERS_CACHE_DIR, 58 | ) 59 | text_encoder = CLIPTextModel.from_pretrained( 60 | pretrained_model_name_or_path, 61 | subfolder="text_encoder", 62 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, 63 | torch_dtype=weight_dtype, 64 | cache_dir=DIFFUSERS_CACHE_DIR, 65 | ) 66 | 67 | unet = UNet2DConditionModel.from_pretrained( 68 | pretrained_model_name_or_path, 69 | subfolder="unet", 70 | torch_dtype=weight_dtype, 71 | cache_dir=DIFFUSERS_CACHE_DIR, 72 | ) 73 | 74 | return tokenizer, text_encoder, unet 75 | 76 | 77 | def load_checkpoint_model( 78 | checkpoint_path: str, 79 | v2: bool = False, 80 | clip_skip: Optional[int] = None, 81 | weight_dtype: torch.dtype = torch.float32, 82 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: 83 | pipe = StableDiffusionPipeline.from_ckpt( 84 | checkpoint_path, 85 | upcast_attention=True if v2 else False, 86 | torch_dtype=weight_dtype, 87 | cache_dir=DIFFUSERS_CACHE_DIR, 88 | ) 89 | 90 | unet = pipe.unet 91 | tokenizer = pipe.tokenizer 92 | text_encoder = pipe.text_encoder 93 | if clip_skip is not None: 94 | if v2: 95 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) 96 | else: 97 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) 98 | 99 | del pipe 100 | 101 | return tokenizer, text_encoder, unet 102 | 103 | 104 | def load_models( 105 | pretrained_model_name_or_path: str, 106 | scheduler_name: AVAILABLE_SCHEDULERS, 107 | v2: bool = False, 108 | v_pred: bool = False, 109 | weight_dtype: torch.dtype = torch.float32, 110 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: 111 | if pretrained_model_name_or_path.endswith( 112 | ".ckpt" 113 | ) or pretrained_model_name_or_path.endswith(".safetensors"): 114 | tokenizer, text_encoder, unet = load_checkpoint_model( 115 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype 116 | ) 117 | else: # diffusers 118 | tokenizer, text_encoder, unet = load_diffusers_model( 119 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype 120 | ) 121 | 122 | # VAE はいらない 123 | 124 | scheduler = create_noise_scheduler( 125 | scheduler_name, 126 | prediction_type="v_prediction" if v_pred else "epsilon", 127 | ) 128 | 129 | return tokenizer, text_encoder, unet, scheduler 130 | 131 | 132 | def load_diffusers_model_xl( 133 | pretrained_model_name_or_path: str, 134 | weight_dtype: torch.dtype = torch.float32, 135 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: 136 | # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet 137 | 138 | tokenizers = [ 139 | CLIPTokenizer.from_pretrained( 140 | pretrained_model_name_or_path, 141 | subfolder="tokenizer", 142 | torch_dtype=weight_dtype, 143 | cache_dir=DIFFUSERS_CACHE_DIR, 144 | ), 145 | CLIPTokenizer.from_pretrained( 146 | pretrained_model_name_or_path, 147 | 
subfolder="tokenizer_2", 148 | torch_dtype=weight_dtype, 149 | cache_dir=DIFFUSERS_CACHE_DIR, 150 | pad_token_id=0, # same as open clip 151 | ), 152 | ] 153 | 154 | text_encoders = [ 155 | CLIPTextModel.from_pretrained( 156 | pretrained_model_name_or_path, 157 | subfolder="text_encoder", 158 | torch_dtype=weight_dtype, 159 | cache_dir=DIFFUSERS_CACHE_DIR, 160 | ), 161 | CLIPTextModelWithProjection.from_pretrained( 162 | pretrained_model_name_or_path, 163 | subfolder="text_encoder_2", 164 | torch_dtype=weight_dtype, 165 | cache_dir=DIFFUSERS_CACHE_DIR, 166 | ), 167 | ] 168 | 169 | unet = UNet2DConditionModel.from_pretrained( 170 | pretrained_model_name_or_path, 171 | subfolder="unet", 172 | torch_dtype=weight_dtype, 173 | cache_dir=DIFFUSERS_CACHE_DIR, 174 | ) 175 | 176 | return tokenizers, text_encoders, unet 177 | 178 | 179 | def load_checkpoint_model_xl( 180 | checkpoint_path: str, 181 | weight_dtype: torch.dtype = torch.float32, 182 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: 183 | pipe = StableDiffusionXLPipeline.from_single_file( 184 | checkpoint_path, 185 | torch_dtype=weight_dtype, 186 | cache_dir=DIFFUSERS_CACHE_DIR, 187 | ) 188 | 189 | unet = pipe.unet 190 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2] 191 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2] 192 | if len(text_encoders) == 2: 193 | text_encoders[1].pad_token_id = 0 194 | 195 | del pipe 196 | 197 | return tokenizers, text_encoders, unet 198 | 199 | 200 | def load_models_xl( 201 | pretrained_model_name_or_path: str, 202 | scheduler_name: AVAILABLE_SCHEDULERS, 203 | weight_dtype: torch.dtype = torch.float32, 204 | ) -> tuple[ 205 | list[CLIPTokenizer], 206 | list[SDXL_TEXT_ENCODER_TYPE], 207 | UNet2DConditionModel, 208 | SchedulerMixin, 209 | ]: 210 | if pretrained_model_name_or_path.endswith( 211 | ".ckpt" 212 | ) or pretrained_model_name_or_path.endswith(".safetensors"): 213 | ( 214 | tokenizers, 215 | text_encoders, 216 | unet, 217 | ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype) 218 | else: # diffusers 219 | ( 220 | tokenizers, 221 | text_encoders, 222 | unet, 223 | ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype) 224 | 225 | scheduler = create_noise_scheduler(scheduler_name) 226 | 227 | return tokenizers, text_encoders, unet, scheduler 228 | 229 | 230 | def create_noise_scheduler( 231 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", 232 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", 233 | ) -> SchedulerMixin: 234 | # 正直、どれがいいのかわからない。元の実装だとDDIMとDDPMとLMSを選べたのだけど、どれがいいのかわからぬ。 235 | 236 | name = scheduler_name.lower().replace(" ", "_") 237 | if name == "ddim": 238 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim 239 | scheduler = DDIMScheduler( 240 | beta_start=0.00085, 241 | beta_end=0.012, 242 | beta_schedule="scaled_linear", 243 | num_train_timesteps=1000, 244 | clip_sample=False, 245 | prediction_type=prediction_type, # これでいいの? 
246 | ) 247 | elif name == "ddpm": 248 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm 249 | scheduler = DDPMScheduler( 250 | beta_start=0.00085, 251 | beta_end=0.012, 252 | beta_schedule="scaled_linear", 253 | num_train_timesteps=1000, 254 | clip_sample=False, 255 | prediction_type=prediction_type, 256 | ) 257 | elif name == "lms": 258 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete 259 | scheduler = LMSDiscreteScheduler( 260 | beta_start=0.00085, 261 | beta_end=0.012, 262 | beta_schedule="scaled_linear", 263 | num_train_timesteps=1000, 264 | prediction_type=prediction_type, 265 | ) 266 | elif name == "euler_a": 267 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral 268 | scheduler = EulerAncestralDiscreteScheduler( 269 | beta_start=0.00085, 270 | beta_end=0.012, 271 | beta_schedule="scaled_linear", 272 | num_train_timesteps=1000, 273 | prediction_type=prediction_type, 274 | ) 275 | else: 276 | raise ValueError(f"Unknown scheduler name: {name}") 277 | 278 | return scheduler 279 | -------------------------------------------------------------------------------- /flux-sliders/utils/prompt_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union, List 2 | 3 | import yaml 4 | from pathlib import Path 5 | 6 | 7 | from pydantic import BaseModel, root_validator 8 | import torch 9 | import copy 10 | 11 | ACTION_TYPES = Literal[ 12 | "erase", 13 | "enhance", 14 | ] 15 | 16 | 17 | # XL は二種類必要なので 18 | class PromptEmbedsXL: 19 | text_embeds: torch.FloatTensor 20 | pooled_embeds: torch.FloatTensor 21 | 22 | def __init__(self, *args) -> None: 23 | self.text_embeds = args[0] 24 | self.pooled_embeds = args[1] 25 | 26 | 27 | # SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL 28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL] 29 | 30 | 31 | class PromptEmbedsCache: # 使いまわしたいので 32 | prompts: dict[str, PROMPT_EMBEDDING] = {} 33 | 34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None: 35 | self.prompts[__name] = __value 36 | 37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]: 38 | if __name in self.prompts: 39 | return self.prompts[__name] 40 | else: 41 | return None 42 | 43 | 44 | class PromptSettings(BaseModel): # yaml のやつ 45 | target: str 46 | positive: str = None # if None, target will be used 47 | unconditional: str = "" # default is "" 48 | neutral: str = None # if None, unconditional will be used 49 | action: ACTION_TYPES = "erase" # default is "erase" 50 | guidance_scale: float = 1.0 # default is 1.0 51 | resolution: int = 512 # default is 512 52 | dynamic_resolution: bool = False # default is False 53 | batch_size: int = 1 # default is 1 54 | dynamic_crops: bool = False # default is False. 
only used when model is XL 55 | 56 | @root_validator(pre=True) 57 | def fill_prompts(cls, values): 58 | keys = values.keys() 59 | if "target" not in keys: 60 | raise ValueError("target must be specified") 61 | if "positive" not in keys: 62 | values["positive"] = values["target"] 63 | if "unconditional" not in keys: 64 | values["unconditional"] = "" 65 | if "neutral" not in keys: 66 | values["neutral"] = values["unconditional"] 67 | 68 | return values 69 | 70 | 71 | class PromptEmbedsPair: 72 | target: PROMPT_EMBEDDING # not want to generate the concept 73 | positive: PROMPT_EMBEDDING # generate the concept 74 | unconditional: PROMPT_EMBEDDING # uncondition (default should be empty) 75 | neutral: PROMPT_EMBEDDING # base condition (default should be empty) 76 | 77 | guidance_scale: float 78 | resolution: int 79 | dynamic_resolution: bool 80 | batch_size: int 81 | dynamic_crops: bool 82 | 83 | loss_fn: torch.nn.Module 84 | action: ACTION_TYPES 85 | 86 | def __init__( 87 | self, 88 | loss_fn: torch.nn.Module, 89 | target: PROMPT_EMBEDDING, 90 | positive: PROMPT_EMBEDDING, 91 | unconditional: PROMPT_EMBEDDING, 92 | neutral: PROMPT_EMBEDDING, 93 | settings: PromptSettings, 94 | ) -> None: 95 | self.loss_fn = loss_fn 96 | self.target = target 97 | self.positive = positive 98 | self.unconditional = unconditional 99 | self.neutral = neutral 100 | 101 | self.guidance_scale = settings.guidance_scale 102 | self.resolution = settings.resolution 103 | self.dynamic_resolution = settings.dynamic_resolution 104 | self.batch_size = settings.batch_size 105 | self.dynamic_crops = settings.dynamic_crops 106 | self.action = settings.action 107 | 108 | def _erase( 109 | self, 110 | target_latents: torch.FloatTensor, # "van gogh" 111 | positive_latents: torch.FloatTensor, # "van gogh" 112 | unconditional_latents: torch.FloatTensor, # "" 113 | neutral_latents: torch.FloatTensor, # "" 114 | ) -> torch.FloatTensor: 115 | """Target latents are going not to have the positive concept.""" 116 | return self.loss_fn( 117 | target_latents, 118 | neutral_latents 119 | - self.guidance_scale * (positive_latents - unconditional_latents) 120 | ) 121 | 122 | 123 | def _enhance( 124 | self, 125 | target_latents: torch.FloatTensor, # "van gogh" 126 | positive_latents: torch.FloatTensor, # "van gogh" 127 | unconditional_latents: torch.FloatTensor, # "" 128 | neutral_latents: torch.FloatTensor, # "" 129 | ): 130 | """Target latents are going to have the positive concept.""" 131 | return self.loss_fn( 132 | target_latents, 133 | neutral_latents 134 | + self.guidance_scale * (positive_latents - unconditional_latents) 135 | ) 136 | 137 | def loss( 138 | self, 139 | **kwargs, 140 | ): 141 | if self.action == "erase": 142 | return self._erase(**kwargs) 143 | 144 | elif self.action == "enhance": 145 | return self._enhance(**kwargs) 146 | 147 | else: 148 | raise ValueError("action must be erase or enhance") 149 | 150 | 151 | def load_prompts_from_yaml(path, attributes = []): 152 | with open(path, "r") as f: 153 | prompts = yaml.safe_load(f) 154 | print(prompts) 155 | if len(prompts) == 0: 156 | raise ValueError("prompts file is empty") 157 | if len(attributes)!=0: 158 | newprompts = [] 159 | for i in range(len(prompts)): 160 | for att in attributes: 161 | copy_ = copy.deepcopy(prompts[i]) 162 | copy_['target'] = att + ' ' + copy_['target'] 163 | copy_['positive'] = att + ' ' + copy_['positive'] 164 | copy_['neutral'] = att + ' ' + copy_['neutral'] 165 | copy_['unconditional'] = att + ' ' + copy_['unconditional'] 166 | 
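# Each attribute yields its own prefixed copy of every prompt, so the resulting list grows to
# len(prompts) * len(attributes) entries (e.g. 2 prompts x 3 attributes -> 6 PromptSettings).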
newprompts.append(copy_) 167 | else: 168 | newprompts = copy.deepcopy(prompts) 169 | 170 | print(newprompts) 171 | print(len(prompts), len(newprompts)) 172 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts] 173 | 174 | return prompt_settings 175 | -------------------------------------------------------------------------------- /flux-sliders/utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image, ImageDraw, ImageFont 18 | import cv2 19 | from typing import Optional, Union, Tuple, List, Callable, Dict 20 | from IPython.display import display 21 | from tqdm.notebook import tqdm 22 | 23 | 24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 25 | h, w, c = image.shape 26 | offset = int(h * .2) 27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 28 | font = cv2.FONT_HERSHEY_SIMPLEX 29 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 30 | img[:h] = image 31 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 32 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 33 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 34 | return img 35 | 36 | 37 | def view_images(images, num_rows=1, offset_ratio=0.02): 38 | if type(images) is list: 39 | num_empty = len(images) % num_rows 40 | elif images.ndim == 4: 41 | num_empty = images.shape[0] % num_rows 42 | else: 43 | images = [images] 44 | num_empty = 0 45 | 46 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 47 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 48 | num_items = len(images) 49 | 50 | h, w, c = images[0].shape 51 | offset = int(h * offset_ratio) 52 | num_cols = num_items // num_rows 53 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 54 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 55 | for i in range(num_rows): 56 | for j in range(num_cols): 57 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 58 | i * num_cols + j] 59 | 60 | pil_img = Image.fromarray(image_) 61 | display(pil_img) 62 | 63 | 64 | def diffusion_step(unet, model, controller, latents, context, t, guidance_scale, low_resource=False): 65 | if low_resource: 66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 68 | else: 69 | latents_input = torch.cat([latents] * 2) 70 | noise_pred = unet(latents_input, t, encoder_hidden_states=context)["sample"] 71 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 72 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 73 | latents = 
model.scheduler.step(noise_pred, t, latents)["prev_sample"] 74 | latents = controller.step_callback(latents) 75 | return latents 76 | 77 | 78 | def latent2image(vae, latents): 79 | latents = 1 / 0.18215 * latents 80 | image = vae.decode(latents)['sample'] 81 | image = (image / 2 + 0.5).clamp(0, 1) 82 | image = image.cpu().permute(0, 2, 3, 1).numpy() 83 | image = (image * 255).astype(np.uint8) 84 | return image 85 | 86 | 87 | def init_latent(latent, model, height, width, generator, batch_size): 88 | if latent is None: 89 | latent = torch.randn( 90 | (1, model.unet.in_channels, height // 8, width // 8), 91 | generator=generator, 92 | ) 93 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 94 | return latent, latents 95 | 96 | 97 | @torch.no_grad() 98 | def text2image_ldm( 99 | model, 100 | prompt: List[str], 101 | controller, 102 | num_inference_steps: int = 50, 103 | guidance_scale: Optional[float] = 7., 104 | generator: Optional[torch.Generator] = None, 105 | latent: Optional[torch.FloatTensor] = None, 106 | ): 107 | register_attention_control(model, controller) 108 | height = width = 256 109 | batch_size = len(prompt) 110 | 111 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 112 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 113 | 114 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 115 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 116 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 117 | context = torch.cat([uncond_embeddings, text_embeddings]) 118 | 119 | model.scheduler.set_timesteps(num_inference_steps) 120 | for t in tqdm(model.scheduler.timesteps): 121 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 122 | 123 | image = latent2image(model.vqvae, latents) 124 | 125 | return image, latent 126 | 127 | 128 | @torch.no_grad() 129 | def text2image_ldm_stable( 130 | model, 131 | prompt: List[str], 132 | controller, 133 | num_inference_steps: int = 50, 134 | guidance_scale: float = 7.5, 135 | generator: Optional[torch.Generator] = None, 136 | latent: Optional[torch.FloatTensor] = None, 137 | low_resource: bool = False, 138 | ): 139 | register_attention_control(model, controller) 140 | height = width = 512 141 | batch_size = len(prompt) 142 | 143 | text_input = model.tokenizer( 144 | prompt, 145 | padding="max_length", 146 | max_length=model.tokenizer.model_max_length, 147 | truncation=True, 148 | return_tensors="pt", 149 | ) 150 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 151 | max_length = text_input.input_ids.shape[-1] 152 | uncond_input = model.tokenizer( 153 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 154 | ) 155 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 156 | 157 | context = [uncond_embeddings, text_embeddings] 158 | if not low_resource: 159 | context = torch.cat(context) 160 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 161 | 162 | # set timesteps 163 | extra_set_kwargs = {"offset": 1} 164 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 165 | for t in tqdm(model.scheduler.timesteps): 166 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 167 | 168 | image = 
latent2image(model.vae, latents) 169 | 170 | return image, latent 171 | 172 | 173 | def register_attention_control(model, controller): 174 | def ca_forward(self, place_in_unet): 175 | to_out = self.to_out 176 | if type(to_out) is torch.nn.modules.container.ModuleList: 177 | to_out = self.to_out[0] 178 | else: 179 | to_out = self.to_out 180 | 181 | def forward(x, context=None, mask=None): 182 | batch_size, sequence_length, dim = x.shape 183 | h = self.heads 184 | q = self.to_q(x) 185 | is_cross = context is not None 186 | context = context if is_cross else x 187 | k = self.to_k(context) 188 | v = self.to_v(context) 189 | q = self.reshape_heads_to_batch_dim(q) 190 | k = self.reshape_heads_to_batch_dim(k) 191 | v = self.reshape_heads_to_batch_dim(v) 192 | 193 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 194 | 195 | if mask is not None: 196 | mask = mask.reshape(batch_size, -1) 197 | max_neg_value = -torch.finfo(sim.dtype).max 198 | mask = mask[:, None, :].repeat(h, 1, 1) 199 | sim.masked_fill_(~mask, max_neg_value) 200 | 201 | # attention, what we cannot get enough of 202 | attn = sim.softmax(dim=-1) 203 | attn = controller(attn, is_cross, place_in_unet) 204 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 205 | out = self.reshape_batch_dim_to_heads(out) 206 | return to_out(out) 207 | 208 | return forward 209 | 210 | class DummyController: 211 | 212 | def __call__(self, *args): 213 | return args[0] 214 | 215 | def __init__(self): 216 | self.num_att_layers = 0 217 | 218 | if controller is None: 219 | controller = DummyController() 220 | 221 | def register_recr(net_, count, place_in_unet): 222 | if net_.__class__.__name__ == 'CrossAttention': 223 | net_.forward = ca_forward(net_, place_in_unet) 224 | return count + 1 225 | elif hasattr(net_, 'children'): 226 | for net__ in net_.children(): 227 | count = register_recr(net__, count, place_in_unet) 228 | return count 229 | 230 | cross_att_count = 0 231 | sub_nets = model.unet.named_children() 232 | for net in sub_nets: 233 | if "down" in net[0]: 234 | cross_att_count += register_recr(net[1], 0, "down") 235 | elif "up" in net[0]: 236 | cross_att_count += register_recr(net[1], 0, "up") 237 | elif "mid" in net[0]: 238 | cross_att_count += register_recr(net[1], 0, "mid") 239 | 240 | controller.num_att_layers = cross_att_count 241 | 242 | 243 | def get_word_inds(text: str, word_place: int, tokenizer): 244 | split_text = text.split(" ") 245 | if type(word_place) is str: 246 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 247 | elif type(word_place) is int: 248 | word_place = [word_place] 249 | out = [] 250 | if len(word_place) > 0: 251 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 252 | cur_len, ptr = 0, 0 253 | 254 | for i in range(len(words_encode)): 255 | cur_len += len(words_encode[i]) 256 | if ptr in word_place: 257 | out.append(i + 1) 258 | if cur_len >= len(split_text[ptr]): 259 | ptr += 1 260 | cur_len = 0 261 | return np.array(out) 262 | 263 | 264 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 265 | word_inds: Optional[torch.Tensor]=None): 266 | if type(bounds) is float: 267 | bounds = 0, bounds 268 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 269 | if word_inds is None: 270 | word_inds = torch.arange(alpha.shape[2]) 271 | alpha[: start, prompt_ind, word_inds] = 0 272 | alpha[start: end, prompt_ind, word_inds] = 1 273 | alpha[end:, prompt_ind, word_inds] = 0 
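# The alpha mask passed in (alpha_time_words at the call site) has shape
# (num_steps + 1, num_prompts - 1, max_num_words): entries are 1 for the selected word indices
# inside [start, end) of the denoising schedule and 0 elsewhere. For example, with 50 steps and
# bounds (0., 0.4), the first 20 of the 51 rows are set to 1 for that word.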
274 | return alpha 275 | 276 | 277 | def get_time_words_attention_alpha(prompts, num_steps, 278 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 279 | tokenizer, max_num_words=77): 280 | if type(cross_replace_steps) is not dict: 281 | cross_replace_steps = {"default_": cross_replace_steps} 282 | if "default_" not in cross_replace_steps: 283 | cross_replace_steps["default_"] = (0., 1.) 284 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 285 | for i in range(len(prompts) - 1): 286 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 287 | i) 288 | for key, item in cross_replace_steps.items(): 289 | if key != "default_": 290 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 291 | for i, ind in enumerate(inds): 292 | if len(ind) > 0: 293 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 294 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 295 | return alpha_time_words -------------------------------------------------------------------------------- /images/main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/images/main_figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes==0.41.1 2 | dadaptation==3.1 3 | diffusers==0.20.2 4 | ipython==8.7.0 5 | lion_pytorch==0.1.2 6 | lpips==0.1.4 7 | matplotlib==3.6.2 8 | numpy==1.23.5 9 | opencv_python==4.5.5.64 10 | opencv_python_headless==4.7.0.68 11 | pandas==1.5.2 12 | Pillow==10.1.0 13 | prodigyopt==1.0 14 | pydantic==2.6.3 15 | PyYAML==6.0.1 16 | Requests==2.31.0 17 | safetensors==0.3.1 18 | torch==2.0.1 19 | torchvision==0.15.2 20 | tqdm==4.64.1 21 | transformers==4.27.4 22 | wandb==0.12.21 23 | xformers==0.0.21 24 | accelerate==0.16.0 25 | -------------------------------------------------------------------------------- /trainscripts/__init__.py: -------------------------------------------------------------------------------- 1 | # from textsliders import lora -------------------------------------------------------------------------------- /trainscripts/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /trainscripts/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/imagesliders/__pycache__/config_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/config_util.cpython-39.pyc -------------------------------------------------------------------------------- 
/trainscripts/imagesliders/__pycache__/debug_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/debug_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/imagesliders/__pycache__/lora.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/lora.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/imagesliders/__pycache__/model_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/model_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/imagesliders/__pycache__/prompt_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/prompt_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/imagesliders/__pycache__/train_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/train_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/imagesliders/config_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import yaml 4 | 5 | from pydantic import BaseModel 6 | import torch 7 | 8 | from lora import TRAINING_METHODS 9 | 10 | PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"] 11 | NETWORK_TYPES = Literal["lierla", "c3lier"] 12 | 13 | 14 | class PretrainedModelConfig(BaseModel): 15 | name_or_path: str 16 | v2: bool = False 17 | v_pred: bool = False 18 | 19 | clip_skip: Optional[int] = None 20 | 21 | 22 | class NetworkConfig(BaseModel): 23 | type: NETWORK_TYPES = "lierla" 24 | rank: int = 4 25 | alpha: float = 1.0 26 | 27 | training_method: TRAINING_METHODS = "full" 28 | 29 | 30 | class TrainConfig(BaseModel): 31 | precision: PRECISION_TYPES = "bfloat16" 32 | noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim" 33 | 34 | iterations: int = 500 35 | lr: float = 1e-4 36 | optimizer: str = "adamw" 37 | optimizer_args: str = "" 38 | lr_scheduler: str = "constant" 39 | 40 | max_denoising_steps: int = 50 41 | 42 | 43 | class SaveConfig(BaseModel): 44 | name: str = "untitled" 45 | path: str = "./output" 46 | per_steps: int = 200 47 | precision: PRECISION_TYPES = "float32" 48 | 49 | 50 | class LoggingConfig(BaseModel): 51 | use_wandb: bool = False 52 | 53 | verbose: bool = False 54 | 55 | 56 | class OtherConfig(BaseModel): 57 | use_xformers: bool = False 58 | 59 | 60 | class RootConfig(BaseModel): 61 | prompts_file: str 62 | pretrained_model: PretrainedModelConfig 63 | 64 
| network: NetworkConfig 65 | 66 | train: Optional[TrainConfig] 67 | 68 | save: Optional[SaveConfig] 69 | 70 | logging: Optional[LoggingConfig] 71 | 72 | other: Optional[OtherConfig] 73 | 74 | 75 | def parse_precision(precision: str) -> torch.dtype: 76 | if precision == "fp32" or precision == "float32": 77 | return torch.float32 78 | elif precision == "fp16" or precision == "float16": 79 | return torch.float16 80 | elif precision == "bf16" or precision == "bfloat16": 81 | return torch.bfloat16 82 | 83 | raise ValueError(f"Invalid precision type: {precision}") 84 | 85 | 86 | def load_config_from_yaml(config_path: str) -> RootConfig: 87 | with open(config_path, "r") as f: 88 | config = yaml.load(f, Loader=yaml.FullLoader) 89 | 90 | root = RootConfig(**config) 91 | 92 | if root.train is None: 93 | root.train = TrainConfig() 94 | 95 | if root.save is None: 96 | root.save = SaveConfig() 97 | 98 | if root.logging is None: 99 | root.logging = LoggingConfig() 100 | 101 | if root.other is None: 102 | root.other = OtherConfig() 103 | 104 | return root 105 | -------------------------------------------------------------------------------- /trainscripts/imagesliders/data/config-xl.yaml: -------------------------------------------------------------------------------- 1 | prompts_file: "trainscripts/imagesliders/data/prompts-xl.yaml" 2 | pretrained_model: 3 | name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models 4 | v2: false # true if model is v2.x 5 | v_pred: false # true if model uses v-prediction 6 | network: 7 | type: "c3lier" # or "c3lier" or "lierla" 8 | rank: 4 9 | alpha: 1.0 10 | training_method: "noxattn" 11 | train: 12 | precision: "bfloat16" 13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" 14 | iterations: 1000 15 | lr: 0.0002 16 | optimizer: "AdamW" 17 | lr_scheduler: "constant" 18 | max_denoising_steps: 50 19 | save: 20 | name: "temp" 21 | path: "./models" 22 | per_steps: 500 23 | precision: "bfloat16" 24 | logging: 25 | use_wandb: false 26 | verbose: false 27 | other: 28 | use_xformers: true -------------------------------------------------------------------------------- /trainscripts/imagesliders/data/config.yaml: -------------------------------------------------------------------------------- 1 | prompts_file: "trainscripts/imagesliders/data/prompts.yaml" 2 | pretrained_model: 3 | name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models 4 | v2: false # true if model is v2.x 5 | v_pred: false # true if model uses v-prediction 6 | network: 7 | type: "c3lier" # or "c3lier" or "lierla" 8 | rank: 4 9 | alpha: 1.0 10 | training_method: "noxattn" 11 | train: 12 | precision: "bfloat16" 13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" 14 | iterations: 1000 15 | lr: 0.0002 16 | optimizer: "AdamW" 17 | lr_scheduler: "constant" 18 | max_denoising_steps: 50 19 | save: 20 | name: "temp" 21 | path: "./models" 22 | per_steps: 500 23 | precision: "bfloat16" 24 | logging: 25 | use_wandb: false 26 | verbose: false 27 | other: 28 | use_xformers: true -------------------------------------------------------------------------------- /trainscripts/imagesliders/data/prompts.yaml: -------------------------------------------------------------------------------- 1 | # - target: "person" # what word for erasing the positive concept from 2 | # positive: "person, very old" # concept to erase 3 | # unconditional: "person" # word to take the difference from the positive concept 4 | # neutral: "person" # starting 
point for conditioning the target 5 | # action: "enhance" # erase or enhance 6 | # guidance_scale: 4 7 | # resolution: 512 8 | # dynamic_resolution: false 9 | # batch_size: 1 10 | - target: "" # what word for erasing the positive concept from 11 | positive: "" # concept to erase 12 | unconditional: "" # word to take the difference from the positive concept 13 | neutral: "" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 1 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | # - target: "" # what word for erasing the positive concept from 20 | # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase 21 | # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept 22 | # neutral: "" # starting point for conditioning the target 23 | # action: "enhance" # erase or enhance 24 | # guidance_scale: 4 25 | # resolution: 512 26 | # dynamic_resolution: false 27 | # batch_size: 1 28 | # - target: "food" # what word for erasing the positive concept from 29 | # positive: "food, expensive and fine dining" # concept to erase 30 | # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept 31 | # neutral: "food" # starting point for conditioning the target 32 | # action: "enhance" # erase or enhance 33 | # guidance_scale: 4 34 | # resolution: 512 35 | # dynamic_resolution: false 36 | # batch_size: 1 37 | # - target: "room" # what word for erasing the positive concept from 38 | # positive: "room, dirty disorganised and cluttered" # concept to erase 39 | # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept 40 | # neutral: "room" # starting point for conditioning the target 41 | # action: "enhance" # erase or enhance 42 | # guidance_scale: 4 43 | # resolution: 512 44 | # dynamic_resolution: false 45 | # batch_size: 1 46 | # - target: "male person" # what word for erasing the positive concept from 47 | # positive: "male person, with a surprised look" # concept to erase 48 | # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept 49 | # neutral: "male person" # starting point for conditioning the target 50 | # action: "enhance" # erase or enhance 51 | # guidance_scale: 4 52 | # resolution: 512 53 | # dynamic_resolution: false 54 | # batch_size: 1 55 | # - target: "female person" # what word for erasing the positive concept from 56 | # positive: "female person, with a surprised look" # concept to erase 57 | # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept 58 | # neutral: "female person" # starting point for conditioning the target 59 | # action: "enhance" # erase or enhance 60 | # guidance_scale: 4 61 | # resolution: 512 62 | # dynamic_resolution: false 63 | # batch_size: 1 64 | # - target: "sky" # what word for erasing the positive concept from 65 | # positive: "peaceful sky" # concept to erase 66 | # unconditional: "sky" # word to take the difference from the positive concept 67 | # neutral: "sky" # starting point for conditioning the target 68 | # action: "enhance" # erase or enhance 69 | # guidance_scale: 4 70 | # resolution: 512 71 | # dynamic_resolution: false 72 | # batch_size: 1 73 | # - target: "sky" # what word for erasing the positive concept from 74 | # positive: 
"chaotic dark sky" # concept to erase 75 | # unconditional: "sky" # word to take the difference from the positive concept 76 | # neutral: "sky" # starting point for conditioning the target 77 | # action: "erase" # erase or enhance 78 | # guidance_scale: 4 79 | # resolution: 512 80 | # dynamic_resolution: false 81 | # batch_size: 1 82 | # - target: "person" # what word for erasing the positive concept from 83 | # positive: "person, very young" # concept to erase 84 | # unconditional: "person" # word to take the difference from the positive concept 85 | # neutral: "person" # starting point for conditioning the target 86 | # action: "erase" # erase or enhance 87 | # guidance_scale: 4 88 | # resolution: 512 89 | # dynamic_resolution: false 90 | # batch_size: 1 91 | # overweight 92 | # - target: "art" # what word for erasing the positive concept from 93 | # positive: "realistic art" # concept to erase 94 | # unconditional: "art" # word to take the difference from the positive concept 95 | # neutral: "art" # starting point for conditioning the target 96 | # action: "enhance" # erase or enhance 97 | # guidance_scale: 4 98 | # resolution: 512 99 | # dynamic_resolution: false 100 | # batch_size: 1 101 | # - target: "art" # what word for erasing the positive concept from 102 | # positive: "abstract art" # concept to erase 103 | # unconditional: "art" # word to take the difference from the positive concept 104 | # neutral: "art" # starting point for conditioning the target 105 | # action: "erase" # erase or enhance 106 | # guidance_scale: 4 107 | # resolution: 512 108 | # dynamic_resolution: false 109 | # batch_size: 1 110 | # sky 111 | # - target: "weather" # what word for erasing the positive concept from 112 | # positive: "bright pleasant weather" # concept to erase 113 | # unconditional: "weather" # word to take the difference from the positive concept 114 | # neutral: "weather" # starting point for conditioning the target 115 | # action: "enhance" # erase or enhance 116 | # guidance_scale: 4 117 | # resolution: 512 118 | # dynamic_resolution: false 119 | # batch_size: 1 120 | # - target: "weather" # what word for erasing the positive concept from 121 | # positive: "dark gloomy weather" # concept to erase 122 | # unconditional: "weather" # word to take the difference from the positive concept 123 | # neutral: "weather" # starting point for conditioning the target 124 | # action: "erase" # erase or enhance 125 | # guidance_scale: 4 126 | # resolution: 512 127 | # dynamic_resolution: false 128 | # batch_size: 1 129 | # hair 130 | # - target: "person" # what word for erasing the positive concept from 131 | # positive: "person with long hair" # concept to erase 132 | # unconditional: "person" # word to take the difference from the positive concept 133 | # neutral: "person" # starting point for conditioning the target 134 | # action: "enhance" # erase or enhance 135 | # guidance_scale: 4 136 | # resolution: 512 137 | # dynamic_resolution: false 138 | # batch_size: 1 139 | # - target: "person" # what word for erasing the positive concept from 140 | # positive: "person with short hair" # concept to erase 141 | # unconditional: "person" # word to take the difference from the positive concept 142 | # neutral: "person" # starting point for conditioning the target 143 | # action: "erase" # erase or enhance 144 | # guidance_scale: 4 145 | # resolution: 512 146 | # dynamic_resolution: false 147 | # batch_size: 1 148 | # - target: "girl" # what word for erasing the positive concept from 149 | # positive: "baby 
girl" # concept to erase 150 | # unconditional: "girl" # word to take the difference from the positive concept 151 | # neutral: "girl" # starting point for conditioning the target 152 | # action: "enhance" # erase or enhance 153 | # guidance_scale: -4 154 | # resolution: 512 155 | # dynamic_resolution: false 156 | # batch_size: 1 157 | # - target: "boy" # what word for erasing the positive concept from 158 | # positive: "old man" # concept to erase 159 | # unconditional: "boy" # word to take the difference from the positive concept 160 | # neutral: "boy" # starting point for conditioning the target 161 | # action: "enhance" # erase or enhance 162 | # guidance_scale: 4 163 | # resolution: 512 164 | # dynamic_resolution: false 165 | # batch_size: 1 166 | # - target: "boy" # what word for erasing the positive concept from 167 | # positive: "baby boy" # concept to erase 168 | # unconditional: "boy" # word to take the difference from the positive concept 169 | # neutral: "boy" # starting point for conditioning the target 170 | # action: "enhance" # erase or enhance 171 | # guidance_scale: -4 172 | # resolution: 512 173 | # dynamic_resolution: false 174 | # batch_size: 1 -------------------------------------------------------------------------------- /trainscripts/imagesliders/debug_util.py: -------------------------------------------------------------------------------- 1 | # デバッグ用... 2 | 3 | import torch 4 | 5 | 6 | def check_requires_grad(model: torch.nn.Module): 7 | for name, module in list(model.named_modules())[:5]: 8 | if len(list(module.parameters())) > 0: 9 | print(f"Module: {name}") 10 | for name, param in list(module.named_parameters())[:2]: 11 | print(f" Parameter: {name}, Requires Grad: {param.requires_grad}") 12 | 13 | 14 | def check_training_mode(model: torch.nn.Module): 15 | for name, module in list(model.named_modules())[:5]: 16 | print(f"Module: {name}, Training Mode: {module.training}") 17 | -------------------------------------------------------------------------------- /trainscripts/imagesliders/lora.py: -------------------------------------------------------------------------------- 1 | # ref: 2 | # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py 3 | # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py 4 | 5 | import os 6 | import math 7 | from typing import Optional, List, Type, Set, Literal 8 | 9 | import torch 10 | import torch.nn as nn 11 | from diffusers import UNet2DConditionModel 12 | from safetensors.torch import save_file 13 | 14 | 15 | UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ 16 | # "Transformer2DModel", # どうやらこっちの方らしい? 
# attn1, 2 17 | "Attention" 18 | ] 19 | UNET_TARGET_REPLACE_MODULE_CONV = [ 20 | "ResnetBlock2D", 21 | "Downsample2D", 22 | "Upsample2D", 23 | # "DownBlock2D", 24 | # "UpBlock2D" 25 | ] # locon, 3clier 26 | 27 | LORA_PREFIX_UNET = "lora_unet" 28 | 29 | DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER 30 | 31 | TRAINING_METHODS = Literal[ 32 | "noxattn", # train all layers except x-attns and time_embed layers 33 | "innoxattn", # train all layers except self attention layers 34 | "selfattn", # ESD-u, train only self attention layers 35 | "xattn", # ESD-x, train only x attention layers 36 | "full", # train all layers 37 | "xattn-strict", # q and k values 38 | "noxattn-hspace", 39 | "noxattn-hspace-last", 40 | # "xlayer", 41 | # "outxattn", 42 | # "outsattn", 43 | # "inxattn", 44 | # "inmidsattn", 45 | # "selflayer", 46 | ] 47 | 48 | 49 | class LoRAModule(nn.Module): 50 | """ 51 | replaces forward method of the original Linear, instead of replacing the original Linear module. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | lora_name, 57 | org_module: nn.Module, 58 | multiplier=1.0, 59 | lora_dim=4, 60 | alpha=1, 61 | ): 62 | """if alpha == 0 or None, alpha is rank (no scaling).""" 63 | super().__init__() 64 | self.lora_name = lora_name 65 | self.lora_dim = lora_dim 66 | 67 | if "Linear" in org_module.__class__.__name__: 68 | in_dim = org_module.in_features 69 | out_dim = org_module.out_features 70 | self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) 71 | self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) 72 | 73 | elif "Conv" in org_module.__class__.__name__: # 一応 74 | in_dim = org_module.in_channels 75 | out_dim = org_module.out_channels 76 | 77 | self.lora_dim = min(self.lora_dim, in_dim, out_dim) 78 | if self.lora_dim != lora_dim: 79 | print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") 80 | 81 | kernel_size = org_module.kernel_size 82 | stride = org_module.stride 83 | padding = org_module.padding 84 | self.lora_down = nn.Conv2d( 85 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False 86 | ) 87 | self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) 88 | 89 | if type(alpha) == torch.Tensor: 90 | alpha = alpha.detach().numpy() 91 | alpha = lora_dim if alpha is None or alpha == 0 else alpha 92 | self.scale = alpha / self.lora_dim 93 | self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える 94 | 95 | # same as microsoft's 96 | nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) 97 | nn.init.zeros_(self.lora_up.weight) 98 | 99 | self.multiplier = multiplier 100 | self.org_module = org_module # remove in applying 101 | 102 | def apply_to(self): 103 | self.org_forward = self.org_module.forward 104 | self.org_module.forward = self.forward 105 | del self.org_module 106 | 107 | def forward(self, x): 108 | return ( 109 | self.org_forward(x) 110 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale 111 | ) 112 | 113 | 114 | class LoRANetwork(nn.Module): 115 | def __init__( 116 | self, 117 | unet: UNet2DConditionModel, 118 | rank: int = 4, 119 | multiplier: float = 1.0, 120 | alpha: float = 1.0, 121 | train_method: TRAINING_METHODS = "full", 122 | ) -> None: 123 | super().__init__() 124 | self.lora_scale = 1 125 | self.multiplier = multiplier 126 | self.lora_dim = rank 127 | self.alpha = alpha 128 | 129 | # LoRAのみ 130 | self.module = LoRAModule 131 | 132 | # unetのloraを作る 133 | self.unet_loras = self.create_modules( 134 | LORA_PREFIX_UNET, 135 | unet, 136 | DEFAULT_TARGET_REPLACE, 137 | 
self.lora_dim, 138 | self.multiplier, 139 | train_method=train_method, 140 | ) 141 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") 142 | 143 | # assertion 名前の被りがないか確認しているようだ 144 | lora_names = set() 145 | for lora in self.unet_loras: 146 | assert ( 147 | lora.lora_name not in lora_names 148 | ), f"duplicated lora name: {lora.lora_name}. {lora_names}" 149 | lora_names.add(lora.lora_name) 150 | 151 | # 適用する 152 | for lora in self.unet_loras: 153 | lora.apply_to() 154 | self.add_module( 155 | lora.lora_name, 156 | lora, 157 | ) 158 | 159 | del unet 160 | 161 | torch.cuda.empty_cache() 162 | 163 | def create_modules( 164 | self, 165 | prefix: str, 166 | root_module: nn.Module, 167 | target_replace_modules: List[str], 168 | rank: int, 169 | multiplier: float, 170 | train_method: TRAINING_METHODS, 171 | ) -> list: 172 | loras = [] 173 | names = [] 174 | for name, module in root_module.named_modules(): 175 | if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習 176 | if "attn2" in name or "time_embed" in name: 177 | continue 178 | elif train_method == "innoxattn": # Cross Attention 以外学習 179 | if "attn2" in name: 180 | continue 181 | elif train_method == "selfattn": # Self Attention のみ学習 182 | if "attn1" not in name: 183 | continue 184 | elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習 185 | if "attn2" not in name: 186 | continue 187 | elif train_method == "full": # 全部学習 188 | pass 189 | else: 190 | raise NotImplementedError( 191 | f"train_method: {train_method} is not implemented." 192 | ) 193 | if module.__class__.__name__ in target_replace_modules: 194 | for child_name, child_module in module.named_modules(): 195 | if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]: 196 | if train_method == 'xattn-strict': 197 | if 'out' in child_name: 198 | continue 199 | if train_method == 'noxattn-hspace': 200 | if 'mid_block' not in name: 201 | continue 202 | if train_method == 'noxattn-hspace-last': 203 | if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name: 204 | continue 205 | lora_name = prefix + "." + name + "." 
+ child_name 206 | lora_name = lora_name.replace(".", "_") 207 | # print(f"{lora_name}") 208 | lora = self.module( 209 | lora_name, child_module, multiplier, rank, self.alpha 210 | ) 211 | # print(name, child_name) 212 | # print(child_module.weight.shape) 213 | loras.append(lora) 214 | names.append(lora_name) 215 | # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}') 216 | return loras 217 | 218 | def prepare_optimizer_params(self): 219 | all_params = [] 220 | 221 | if self.unet_loras: # 実質これしかない 222 | params = [] 223 | [params.extend(lora.parameters()) for lora in self.unet_loras] 224 | param_data = {"params": params} 225 | all_params.append(param_data) 226 | 227 | return all_params 228 | 229 | def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): 230 | state_dict = self.state_dict() 231 | 232 | if dtype is not None: 233 | for key in list(state_dict.keys()): 234 | v = state_dict[key] 235 | v = v.detach().clone().to("cpu").to(dtype) 236 | state_dict[key] = v 237 | 238 | # for key in list(state_dict.keys()): 239 | # if not key.startswith("lora"): 240 | # # lora以外除外 241 | # del state_dict[key] 242 | 243 | if os.path.splitext(file)[1] == ".safetensors": 244 | save_file(state_dict, file, metadata) 245 | else: 246 | torch.save(state_dict, file) 247 | def set_lora_slider(self, scale): 248 | self.lora_scale = scale 249 | 250 | def __enter__(self): 251 | for lora in self.unet_loras: 252 | lora.multiplier = 1.0 * self.lora_scale 253 | 254 | def __exit__(self, exc_type, exc_value, tb): 255 | for lora in self.unet_loras: 256 | lora.multiplier = 0 257 | -------------------------------------------------------------------------------- /trainscripts/imagesliders/model_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union, Optional 2 | 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 5 | from diffusers import ( 6 | UNet2DConditionModel, 7 | SchedulerMixin, 8 | StableDiffusionPipeline, 9 | StableDiffusionXLPipeline, 10 | AutoencoderKL, 11 | ) 12 | from diffusers.schedulers import ( 13 | DDIMScheduler, 14 | DDPMScheduler, 15 | LMSDiscreteScheduler, 16 | EulerAncestralDiscreteScheduler, 17 | ) 18 | 19 | 20 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" 21 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" 22 | 23 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"] 24 | 25 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] 26 | 27 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this 28 | 29 | 30 | def load_diffusers_model( 31 | pretrained_model_name_or_path: str, 32 | v2: bool = False, 33 | clip_skip: Optional[int] = None, 34 | weight_dtype: torch.dtype = torch.float32, 35 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: 36 | # VAE はいらない 37 | 38 | if v2: 39 | tokenizer = CLIPTokenizer.from_pretrained( 40 | TOKENIZER_V2_MODEL_NAME, 41 | subfolder="tokenizer", 42 | torch_dtype=weight_dtype, 43 | cache_dir=DIFFUSERS_CACHE_DIR, 44 | ) 45 | text_encoder = CLIPTextModel.from_pretrained( 46 | pretrained_model_name_or_path, 47 | subfolder="text_encoder", 48 | # default is clip skip 2 49 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, 50 | torch_dtype=weight_dtype, 51 | cache_dir=DIFFUSERS_CACHE_DIR, 52 | ) 53 | else: 54 | tokenizer = CLIPTokenizer.from_pretrained( 55 | TOKENIZER_V1_MODEL_NAME, 56 | subfolder="tokenizer", 57 | 
torch_dtype=weight_dtype, 58 | cache_dir=DIFFUSERS_CACHE_DIR, 59 | ) 60 | text_encoder = CLIPTextModel.from_pretrained( 61 | pretrained_model_name_or_path, 62 | subfolder="text_encoder", 63 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, 64 | torch_dtype=weight_dtype, 65 | cache_dir=DIFFUSERS_CACHE_DIR, 66 | ) 67 | 68 | unet = UNet2DConditionModel.from_pretrained( 69 | pretrained_model_name_or_path, 70 | subfolder="unet", 71 | torch_dtype=weight_dtype, 72 | cache_dir=DIFFUSERS_CACHE_DIR, 73 | ) 74 | 75 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 76 | 77 | return tokenizer, text_encoder, unet, vae 78 | 79 | 80 | def load_checkpoint_model( 81 | checkpoint_path: str, 82 | v2: bool = False, 83 | clip_skip: Optional[int] = None, 84 | weight_dtype: torch.dtype = torch.float32, 85 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: 86 | pipe = StableDiffusionPipeline.from_ckpt( 87 | checkpoint_path, 88 | upcast_attention=True if v2 else False, 89 | torch_dtype=weight_dtype, 90 | cache_dir=DIFFUSERS_CACHE_DIR, 91 | ) 92 | 93 | unet = pipe.unet 94 | tokenizer = pipe.tokenizer 95 | text_encoder = pipe.text_encoder 96 | vae = pipe.vae 97 | if clip_skip is not None: 98 | if v2: 99 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) 100 | else: 101 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) 102 | 103 | del pipe 104 | 105 | return tokenizer, text_encoder, unet, vae 106 | 107 | 108 | def load_models( 109 | pretrained_model_name_or_path: str, 110 | scheduler_name: AVAILABLE_SCHEDULERS, 111 | v2: bool = False, 112 | v_pred: bool = False, 113 | weight_dtype: torch.dtype = torch.float32, 114 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: 115 | if pretrained_model_name_or_path.endswith( 116 | ".ckpt" 117 | ) or pretrained_model_name_or_path.endswith(".safetensors"): 118 | tokenizer, text_encoder, unet, vae = load_checkpoint_model( 119 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype 120 | ) 121 | else: # diffusers 122 | tokenizer, text_encoder, unet, vae = load_diffusers_model( 123 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype 124 | ) 125 | 126 | # VAE はいらない 127 | 128 | scheduler = create_noise_scheduler( 129 | scheduler_name, 130 | prediction_type="v_prediction" if v_pred else "epsilon", 131 | ) 132 | 133 | return tokenizer, text_encoder, unet, scheduler, vae 134 | 135 | 136 | def load_diffusers_model_xl( 137 | pretrained_model_name_or_path: str, 138 | weight_dtype: torch.dtype = torch.float32, 139 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: 140 | # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet 141 | 142 | tokenizers = [ 143 | CLIPTokenizer.from_pretrained( 144 | pretrained_model_name_or_path, 145 | subfolder="tokenizer", 146 | torch_dtype=weight_dtype, 147 | cache_dir=DIFFUSERS_CACHE_DIR, 148 | ), 149 | CLIPTokenizer.from_pretrained( 150 | pretrained_model_name_or_path, 151 | subfolder="tokenizer_2", 152 | torch_dtype=weight_dtype, 153 | cache_dir=DIFFUSERS_CACHE_DIR, 154 | pad_token_id=0, # same as open clip 155 | ), 156 | ] 157 | 158 | text_encoders = [ 159 | CLIPTextModel.from_pretrained( 160 | pretrained_model_name_or_path, 161 | subfolder="text_encoder", 162 | torch_dtype=weight_dtype, 163 | cache_dir=DIFFUSERS_CACHE_DIR, 164 | ), 165 | CLIPTextModelWithProjection.from_pretrained( 166 | pretrained_model_name_or_path, 167 | subfolder="text_encoder_2", 168 
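load_models above dispatches on the path suffix: a local .ckpt/.safetensors file goes through load_checkpoint_model, anything else is treated as a diffusers repo id, and in both cases a noise scheduler is created alongside the tokenizer, text encoder, UNet and VAE. A sketch of the intended call, assuming the trainscripts/imagesliders directory is on the import path (the scripts in this repo use flat local imports of this kind):

```python
import torch
from model_util import load_models  # trainscripts/imagesliders/model_util.py

tokenizer, text_encoder, unet, noise_scheduler, vae = load_models(
    "CompVis/stable-diffusion-v1-4",  # or a local .ckpt / .safetensors checkpoint
    scheduler_name="ddim",
    v2=False,
    v_pred=False,
    weight_dtype=torch.bfloat16,
)
unet.requires_grad_(False)  # the slider trains only the LoRA adapters, never the UNet itself
```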
| torch_dtype=weight_dtype, 169 | cache_dir=DIFFUSERS_CACHE_DIR, 170 | ), 171 | ] 172 | 173 | unet = UNet2DConditionModel.from_pretrained( 174 | pretrained_model_name_or_path, 175 | subfolder="unet", 176 | torch_dtype=weight_dtype, 177 | cache_dir=DIFFUSERS_CACHE_DIR, 178 | ) 179 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 180 | return tokenizers, text_encoders, unet, vae 181 | 182 | 183 | def load_checkpoint_model_xl( 184 | checkpoint_path: str, 185 | weight_dtype: torch.dtype = torch.float32, 186 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: 187 | pipe = StableDiffusionXLPipeline.from_single_file( 188 | checkpoint_path, 189 | torch_dtype=weight_dtype, 190 | cache_dir=DIFFUSERS_CACHE_DIR, 191 | ) 192 | 193 | unet = pipe.unet 194 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2] 195 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2] 196 | if len(text_encoders) == 2: 197 | text_encoders[1].pad_token_id = 0 198 | vae = pipe.vae 199 | del pipe 200 | 201 | return tokenizers, text_encoders, unet, vae 202 | 203 | 204 | def load_models_xl( 205 | pretrained_model_name_or_path: str, 206 | scheduler_name: AVAILABLE_SCHEDULERS, 207 | weight_dtype: torch.dtype = torch.float32, 208 | ) -> tuple[ 209 | list[CLIPTokenizer], 210 | list[SDXL_TEXT_ENCODER_TYPE], 211 | UNet2DConditionModel, 212 | SchedulerMixin, 213 | ]: 214 | if pretrained_model_name_or_path.endswith( 215 | ".ckpt" 216 | ) or pretrained_model_name_or_path.endswith(".safetensors"): 217 | ( 218 | tokenizers, 219 | text_encoders, 220 | unet, 221 | vae 222 | ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype) 223 | else: # diffusers 224 | ( 225 | tokenizers, 226 | text_encoders, 227 | unet, 228 | vae 229 | ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype) 230 | 231 | scheduler = create_noise_scheduler(scheduler_name) 232 | 233 | return tokenizers, text_encoders, unet, scheduler, vae 234 | 235 | 236 | def create_noise_scheduler( 237 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", 238 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", 239 | ) -> SchedulerMixin: 240 | # Honestly, it is not clear which scheduler is best here; the original implementation allowed DDIM, DDPM or LMS, with no obvious winner. 241 | 242 | name = scheduler_name.lower().replace(" ", "_") 243 | if name == "ddim": 244 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim 245 | scheduler = DDIMScheduler( 246 | beta_start=0.00085, 247 | beta_end=0.012, 248 | beta_schedule="scaled_linear", 249 | num_train_timesteps=1000, 250 | clip_sample=False, 251 | prediction_type=prediction_type, # is this the right choice? 
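load_models_xl mirrors the SD1.x/2.x path but returns lists, because SDXL conditions on two tokenizers and two text encoders. A sketch of the call; the model id and dtype simply follow the defaults used in the repo's XL config, and the import again assumes the trainscripts/imagesliders directory is importable.

```python
import torch
from model_util import load_models_xl  # trainscripts/imagesliders/model_util.py

tokenizers, text_encoders, unet, noise_scheduler, vae = load_models_xl(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler_name="ddim",
    weight_dtype=torch.bfloat16,
)
print(len(tokenizers), len(text_encoders))  # 2 2
```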
252 | ) 253 | elif name == "ddpm": 254 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm 255 | scheduler = DDPMScheduler( 256 | beta_start=0.00085, 257 | beta_end=0.012, 258 | beta_schedule="scaled_linear", 259 | num_train_timesteps=1000, 260 | clip_sample=False, 261 | prediction_type=prediction_type, 262 | ) 263 | elif name == "lms": 264 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete 265 | scheduler = LMSDiscreteScheduler( 266 | beta_start=0.00085, 267 | beta_end=0.012, 268 | beta_schedule="scaled_linear", 269 | num_train_timesteps=1000, 270 | prediction_type=prediction_type, 271 | ) 272 | elif name == "euler_a": 273 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral 274 | scheduler = EulerAncestralDiscreteScheduler( 275 | beta_start=0.00085, 276 | beta_end=0.012, 277 | beta_schedule="scaled_linear", 278 | num_train_timesteps=1000, 279 | prediction_type=prediction_type, 280 | ) 281 | else: 282 | raise ValueError(f"Unknown scheduler name: {name}") 283 | 284 | return scheduler 285 | -------------------------------------------------------------------------------- /trainscripts/imagesliders/prompt_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union, List 2 | 3 | import yaml 4 | from pathlib import Path 5 | 6 | 7 | from pydantic import BaseModel, root_validator 8 | import torch 9 | import copy 10 | 11 | ACTION_TYPES = Literal[ 12 | "erase", 13 | "enhance", 14 | ] 15 | 16 | 17 | # XL は二種類必要なので 18 | class PromptEmbedsXL: 19 | text_embeds: torch.FloatTensor 20 | pooled_embeds: torch.FloatTensor 21 | 22 | def __init__(self, *args) -> None: 23 | self.text_embeds = args[0] 24 | self.pooled_embeds = args[1] 25 | 26 | 27 | # SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL 28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL] 29 | 30 | 31 | class PromptEmbedsCache: # 使いまわしたいので 32 | prompts: dict[str, PROMPT_EMBEDDING] = {} 33 | 34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None: 35 | self.prompts[__name] = __value 36 | 37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]: 38 | if __name in self.prompts: 39 | return self.prompts[__name] 40 | else: 41 | return None 42 | 43 | 44 | class PromptSettings(BaseModel): # yaml のやつ 45 | target: str 46 | positive: str = None # if None, target will be used 47 | unconditional: str = "" # default is "" 48 | neutral: str = None # if None, unconditional will be used 49 | action: ACTION_TYPES = "erase" # default is "erase" 50 | guidance_scale: float = 1.0 # default is 1.0 51 | resolution: int = 512 # default is 512 52 | dynamic_resolution: bool = False # default is False 53 | batch_size: int = 1 # default is 1 54 | dynamic_crops: bool = False # default is False. 
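PromptEmbedsXL and PromptEmbedsCache above exist so that each prompt is encoded once and then reused across training steps: for SDXL an embedding is a pair (per-token plus pooled), for SD1.x/2.x it is a single tensor. A small sketch with stand-in tensors; the shapes are typical SDXL sizes, not values read from a real encoder, and the import assumes prompt_util is on the path as in the training scripts.

```python
import torch
from prompt_util import PromptEmbedsCache, PromptEmbedsXL  # trainscripts/imagesliders/prompt_util.py

cache = PromptEmbedsCache()
prompt = "male person, very old"
if cache[prompt] is None:  # __getitem__ returns None on a cache miss
    # Stand-ins for real CLIP outputs: SDXL needs both per-token and pooled embeddings.
    cache[prompt] = PromptEmbedsXL(torch.randn(1, 77, 2048), torch.randn(1, 1280))
print(cache[prompt].text_embeds.shape, cache[prompt].pooled_embeds.shape)
```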
only used when model is XL 55 | 56 | @root_validator(pre=True) 57 | def fill_prompts(cls, values): 58 | keys = values.keys() 59 | if "target" not in keys: 60 | raise ValueError("target must be specified") 61 | if "positive" not in keys: 62 | values["positive"] = values["target"] 63 | if "unconditional" not in keys: 64 | values["unconditional"] = "" 65 | if "neutral" not in keys: 66 | values["neutral"] = values["unconditional"] 67 | 68 | return values 69 | 70 | 71 | class PromptEmbedsPair: 72 | target: PROMPT_EMBEDDING # not want to generate the concept 73 | positive: PROMPT_EMBEDDING # generate the concept 74 | unconditional: PROMPT_EMBEDDING # uncondition (default should be empty) 75 | neutral: PROMPT_EMBEDDING # base condition (default should be empty) 76 | 77 | guidance_scale: float 78 | resolution: int 79 | dynamic_resolution: bool 80 | batch_size: int 81 | dynamic_crops: bool 82 | 83 | loss_fn: torch.nn.Module 84 | action: ACTION_TYPES 85 | 86 | def __init__( 87 | self, 88 | loss_fn: torch.nn.Module, 89 | target: PROMPT_EMBEDDING, 90 | positive: PROMPT_EMBEDDING, 91 | unconditional: PROMPT_EMBEDDING, 92 | neutral: PROMPT_EMBEDDING, 93 | settings: PromptSettings, 94 | ) -> None: 95 | self.loss_fn = loss_fn 96 | self.target = target 97 | self.positive = positive 98 | self.unconditional = unconditional 99 | self.neutral = neutral 100 | 101 | self.guidance_scale = settings.guidance_scale 102 | self.resolution = settings.resolution 103 | self.dynamic_resolution = settings.dynamic_resolution 104 | self.batch_size = settings.batch_size 105 | self.dynamic_crops = settings.dynamic_crops 106 | self.action = settings.action 107 | 108 | def _erase( 109 | self, 110 | target_latents: torch.FloatTensor, # "van gogh" 111 | positive_latents: torch.FloatTensor, # "van gogh" 112 | unconditional_latents: torch.FloatTensor, # "" 113 | neutral_latents: torch.FloatTensor, # "" 114 | ) -> torch.FloatTensor: 115 | """Target latents are going not to have the positive concept.""" 116 | return self.loss_fn( 117 | target_latents, 118 | neutral_latents 119 | - self.guidance_scale * (positive_latents - unconditional_latents) 120 | ) 121 | 122 | 123 | def _enhance( 124 | self, 125 | target_latents: torch.FloatTensor, # "van gogh" 126 | positive_latents: torch.FloatTensor, # "van gogh" 127 | unconditional_latents: torch.FloatTensor, # "" 128 | neutral_latents: torch.FloatTensor, # "" 129 | ): 130 | """Target latents are going to have the positive concept.""" 131 | return self.loss_fn( 132 | target_latents, 133 | neutral_latents 134 | + self.guidance_scale * (positive_latents - unconditional_latents) 135 | ) 136 | 137 | def loss( 138 | self, 139 | **kwargs, 140 | ): 141 | if self.action == "erase": 142 | return self._erase(**kwargs) 143 | 144 | elif self.action == "enhance": 145 | return self._enhance(**kwargs) 146 | 147 | else: 148 | raise ValueError("action must be erase or enhance") 149 | 150 | 151 | def load_prompts_from_yaml(path, attributes = []): 152 | with open(path, "r") as f: 153 | prompts = yaml.safe_load(f) 154 | print(prompts) 155 | if len(prompts) == 0: 156 | raise ValueError("prompts file is empty") 157 | if len(attributes)!=0: 158 | newprompts = [] 159 | for i in range(len(prompts)): 160 | for att in attributes: 161 | copy_ = copy.deepcopy(prompts[i]) 162 | copy_['target'] = att + ' ' + copy_['target'] 163 | copy_['positive'] = att + ' ' + copy_['positive'] 164 | copy_['neutral'] = att + ' ' + copy_['neutral'] 165 | copy_['unconditional'] = att + ' ' + copy_['unconditional'] 166 | 
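The two loss branches above differ only in the sign of the guided offset: _erase pulls the target prediction away from (positive - unconditional) scaled by guidance_scale, while _enhance pushes it toward it, both anchored at the neutral prediction. A toy check with random stand-in noise predictions (the shapes are just illustrative latent sizes):

```python
import torch

loss_fn = torch.nn.MSELoss()
guidance_scale = 4.0

# Stand-ins for the four UNet noise predictions used above: (batch, latent channels, h, w).
target = torch.randn(1, 4, 64, 64)
positive = torch.randn(1, 4, 64, 64)
unconditional = torch.randn(1, 4, 64, 64)
neutral = torch.randn(1, 4, 64, 64)

enhance_loss = loss_fn(target, neutral + guidance_scale * (positive - unconditional))
erase_loss = loss_fn(target, neutral - guidance_scale * (positive - unconditional))
print(enhance_loss.item(), erase_loss.item())
```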
newprompts.append(copy_) 167 | else: 168 | newprompts = copy.deepcopy(prompts) 169 | 170 | print(newprompts) 171 | print(len(prompts), len(newprompts)) 172 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts] 173 | 174 | return prompt_settings 175 | -------------------------------------------------------------------------------- /trainscripts/textsliders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__init__.py -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/config_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/config_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/debug_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/debug_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/lora.cpython-310.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/lora.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/lora.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/model_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/model_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/prompt_util.cpython-39.pyc: -------------------------------------------------------------------------------- 
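load_prompts_from_yaml above cross-multiplies every entry in the prompt file with the optional attributes list, prepending each attribute to the target, positive, neutral and unconditional strings; this is what the race/gender-expanded GPT prompt files further below do by hand. A usage sketch against the shipped prompts.yaml template (which has two active entries), assuming prompt_util is importable as in the training scripts:

```python
from prompt_util import load_prompts_from_yaml  # trainscripts/imagesliders/prompt_util.py

# 2 YAML entries x 5 attributes -> 10 PromptSettings objects.
settings = load_prompts_from_yaml(
    "trainscripts/textsliders/data/prompts.yaml",
    attributes=["white", "black", "indian", "asian", "hispanic"],
)
print(len(settings))       # 10
print(settings[0].target)  # "white male person"
```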
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/prompt_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/ptp_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/ptp_utils.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/__pycache__/train_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/train_util.cpython-39.pyc -------------------------------------------------------------------------------- /trainscripts/textsliders/config_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import yaml 4 | 5 | from pydantic import BaseModel 6 | import torch 7 | 8 | from lora import TRAINING_METHODS 9 | 10 | PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"] 11 | NETWORK_TYPES = Literal["lierla", "c3lier"] 12 | 13 | 14 | class PretrainedModelConfig(BaseModel): 15 | name_or_path: str 16 | v2: bool = False 17 | v_pred: bool = False 18 | 19 | clip_skip: Optional[int] = None 20 | 21 | 22 | class NetworkConfig(BaseModel): 23 | type: NETWORK_TYPES = "lierla" 24 | rank: int = 4 25 | alpha: float = 1.0 26 | 27 | training_method: TRAINING_METHODS = "full" 28 | 29 | 30 | class TrainConfig(BaseModel): 31 | precision: PRECISION_TYPES = "bfloat16" 32 | noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim" 33 | 34 | iterations: int = 500 35 | lr: float = 1e-4 36 | optimizer: str = "adamw" 37 | optimizer_args: str = "" 38 | lr_scheduler: str = "constant" 39 | 40 | max_denoising_steps: int = 50 41 | 42 | 43 | class SaveConfig(BaseModel): 44 | name: str = "untitled" 45 | path: str = "./output" 46 | per_steps: int = 200 47 | precision: PRECISION_TYPES = "float32" 48 | 49 | 50 | class LoggingConfig(BaseModel): 51 | use_wandb: bool = False 52 | 53 | verbose: bool = False 54 | 55 | 56 | class OtherConfig(BaseModel): 57 | use_xformers: bool = False 58 | 59 | 60 | class RootConfig(BaseModel): 61 | prompts_file: str 62 | pretrained_model: PretrainedModelConfig 63 | 64 | network: NetworkConfig 65 | 66 | train: Optional[TrainConfig] 67 | 68 | save: Optional[SaveConfig] 69 | 70 | logging: Optional[LoggingConfig] 71 | 72 | other: Optional[OtherConfig] 73 | 74 | 75 | def parse_precision(precision: str) -> torch.dtype: 76 | if precision == "fp32" or precision == "float32": 77 | return torch.float32 78 | elif precision == "fp16" or precision == "float16": 79 | return torch.float16 80 | elif precision == "bf16" or precision == "bfloat16": 81 | return torch.bfloat16 82 | 83 | raise ValueError(f"Invalid precision type: {precision}") 84 | 85 | 86 | def load_config_from_yaml(config_path: str) -> RootConfig: 87 | with open(config_path, "r") as f: 88 | config = yaml.load(f, Loader=yaml.FullLoader) 89 | 90 | root = RootConfig(**config) 91 | 92 | if root.train is None: 93 | root.train = TrainConfig() 94 | 95 | if root.save is None: 96 | root.save = SaveConfig() 97 | 98 | if root.logging is None: 
99 | root.logging = LoggingConfig() 100 | 101 | if root.other is None: 102 | root.other = OtherConfig() 103 | 104 | return root 105 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/.ipynb_checkpoints/prompts-person_age_slider_GPT-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, very old, wrinkly, gray hair, aged skin 3 | unconditional: male white person, very young, smooth skin, youthful 4 | neutral: male white person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male black person 12 | positive: male black person, very old, wrinkly, gray hair, aged skin 13 | unconditional: male black person, very young, smooth skin, youthful 14 | neutral: male black person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male indian person 22 | positive: male indian person, very old, wrinkly, gray hair, aged skin 23 | unconditional: male indian person, very young, smooth skin, youthful 24 | neutral: male indian person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male asian person 32 | positive: male asian person, very old, wrinkly, gray hair, aged skin 33 | unconditional: male asian person, very young, smooth skin, youthful 34 | neutral: male asian person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male hispanic person 42 | positive: male hispanic person, very old, wrinkly, gray hair, aged skin 43 | unconditional: male hispanic person, very young, smooth skin, youthful 44 | neutral: male hispanic person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: ' female white person' 52 | positive: ' female white person, very old, wrinkly, gray hair, aged skin' 53 | unconditional: ' female white person, very young, smooth skin, youthful' 54 | neutral: ' female white person' 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: ' female black person' 62 | positive: ' female black person, very old, wrinkly, gray hair, aged skin' 63 | unconditional: ' female black person, very young, smooth skin, youthful' 64 | neutral: ' female black person' 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: ' female indian person' 72 | positive: ' female indian person, very old, wrinkly, gray hair, aged skin' 73 | unconditional: ' female indian person, very young, smooth skin, youthful' 74 | neutral: ' female indian person' 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - target: ' female asian person' 82 | positive: ' female asian person, very old, wrinkly, gray hair, aged skin' 83 | unconditional: ' female asian person, very young, smooth skin, youthful' 84 | neutral: ' female asian person' 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: ' female hispanic person' 92 | positive: ' female 
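parse_precision above accepts both the short and the long spellings used in the YAML configs and maps them to torch dtypes; anything else raises. A quick check, assuming config_util is importable as in the training scripts:

```python
import torch
from config_util import parse_precision  # trainscripts/textsliders/config_util.py

assert parse_precision("bf16") == torch.bfloat16
assert parse_precision("bfloat16") == torch.bfloat16
assert parse_precision("fp16") == torch.float16
assert parse_precision("float32") == torch.float32
```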
hispanic person, very old, wrinkly, gray hair, aged skin' 93 | unconditional: ' female hispanic person, very young, smooth skin, youthful' 94 | neutral: ' female hispanic person' 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/.ipynb_checkpoints/prompts-smile_slider_GPT-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, smiling, happy face, big smile 3 | unconditional: male white person, frowning, grumpy, sad 4 | neutral: male white person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male black person 12 | positive: male black person, smiling, happy face, big smile 13 | unconditional: male black person, frowning, grumpy, sad 14 | neutral: male black person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male indian person 22 | positive: male indian person, smiling, happy face, big smile 23 | unconditional: male indian person, frowning, grumpy, sad 24 | neutral: male indian person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male asian person 32 | positive: male asian person, smiling, happy face, big smile 33 | unconditional: male asian person, frowning, grumpy, sad 34 | neutral: male asian person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male hispanic person 42 | positive: male hispanic person, smiling, happy face, big smile 43 | unconditional: male hispanic person, frowning, grumpy, sad 44 | neutral: male hispanic person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: female white person 52 | positive: female white person, smiling, happy face, big smile 53 | unconditional: female white person, frowning, grumpy, sad 54 | neutral: female white person 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: female black person 62 | positive: female black person, smiling, happy face, big smile 63 | unconditional: female black person, frowning, grumpy, sad 64 | neutral: female black person 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: female indian person 72 | positive: female indian person, smiling, happy face, big smile 73 | unconditional: female indian person, frowning, grumpy, sad 74 | neutral: female indian person 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - target: female asian person 82 | positive: female asian person, smiling, happy face, big smile 83 | unconditional: female asian person, frowning, grumpy, sad 84 | neutral: female asian person 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: female hispanic person 92 | positive: female hispanic person, smiling, happy face, big smile 93 | unconditional: 
female hispanic person, frowning, grumpy, sad 94 | neutral: female hispanic person 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/config-xl.yaml: -------------------------------------------------------------------------------- 1 | prompts_file: "trainscripts/textsliders/data/prompts-xl.yaml" 2 | pretrained_model: 3 | name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models 4 | v2: false # true if model is v2.x 5 | v_pred: false # true if model uses v-prediction 6 | network: 7 | type: "c3lier" # or "c3lier" or "lierla" 8 | rank: 4 9 | alpha: 1.0 10 | training_method: "noxattn" 11 | train: 12 | precision: "bfloat16" 13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" 14 | iterations: 1000 15 | lr: 0.0002 16 | optimizer: "AdamW" 17 | lr_scheduler: "constant" 18 | max_denoising_steps: 50 19 | save: 20 | name: "temp" 21 | path: "./models" 22 | per_steps: 500 23 | precision: "bfloat16" 24 | logging: 25 | use_wandb: false 26 | verbose: false 27 | other: 28 | use_xformers: true -------------------------------------------------------------------------------- /trainscripts/textsliders/data/config.yaml: -------------------------------------------------------------------------------- 1 | prompts_file: "trainscripts/textsliders/data/prompts.yaml" 2 | pretrained_model: 3 | name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models 4 | v2: false # true if model is v2.x 5 | v_pred: false # true if model uses v-prediction 6 | network: 7 | type: "c3lier" # or "c3lier" or "lierla" 8 | rank: 4 9 | alpha: 1.0 10 | training_method: "noxattn" 11 | train: 12 | precision: "bfloat16" 13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" 14 | iterations: 1000 15 | lr: 0.0002 16 | optimizer: "AdamW" 17 | lr_scheduler: "constant" 18 | max_denoising_steps: 50 19 | save: 20 | name: "temp" 21 | path: "./models" 22 | per_steps: 500 23 | precision: "bfloat16" 24 | logging: 25 | use_wandb: false 26 | verbose: false 27 | other: 28 | use_xformers: true -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts-animated_eyes_GPT.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, eyes extremely large, animatedly big eyes, cartoonish 3 | proportions 4 | unconditional: male white person, eyes normal size, realistic proportions 5 | neutral: male white person 6 | guidance: 4 7 | rank: 4 8 | action: enhance 9 | resolution: 512 10 | dynamic_resolution: false 11 | batch_size: 1 12 | - target: male black person 13 | positive: male black person, eyes extremely large, animatedly big eyes, cartoonish 14 | proportions 15 | unconditional: male black person, eyes normal size, realistic proportions 16 | neutral: male black person 17 | guidance: 4 18 | rank: 4 19 | action: enhance 20 | resolution: 512 21 | dynamic_resolution: false 22 | batch_size: 1 23 | - target: male indian person 24 | positive: male indian person, eyes extremely large, animatedly big eyes, cartoonish 25 | proportions 26 | unconditional: male indian person, eyes normal size, realistic proportions 27 | neutral: male indian person 28 | guidance: 4 29 | rank: 4 30 | action: enhance 31 | resolution: 512 32 | dynamic_resolution: false 
33 | batch_size: 1 34 | - target: male asian person 35 | positive: male asian person, eyes extremely large, animatedly big eyes, cartoonish 36 | proportions 37 | unconditional: male asian person, eyes normal size, realistic proportions 38 | neutral: male asian person 39 | guidance: 4 40 | rank: 4 41 | action: enhance 42 | resolution: 512 43 | dynamic_resolution: false 44 | batch_size: 1 45 | - target: male hispanic person 46 | positive: male hispanic person, eyes extremely large, animatedly big eyes, cartoonish 47 | proportions 48 | unconditional: male hispanic person, eyes normal size, realistic proportions 49 | neutral: male hispanic person 50 | guidance: 4 51 | rank: 4 52 | action: enhance 53 | resolution: 512 54 | dynamic_resolution: false 55 | batch_size: 1 56 | - target: female white person 57 | positive: female white person, eyes extremely large, animatedly big eyes, cartoonish 58 | proportions 59 | unconditional: female white person, eyes normal size, realistic proportions 60 | neutral: female white person 61 | guidance: 4 62 | rank: 4 63 | action: enhance 64 | resolution: 512 65 | dynamic_resolution: false 66 | batch_size: 1 67 | - target: female black person 68 | positive: female black person, eyes extremely large, animatedly big eyes, cartoonish 69 | proportions 70 | unconditional: female black person, eyes normal size, realistic proportions 71 | neutral: female black person 72 | guidance: 4 73 | rank: 4 74 | action: enhance 75 | resolution: 512 76 | dynamic_resolution: false 77 | batch_size: 1 78 | - target: female indian person 79 | positive: female indian person, eyes extremely large, animatedly big eyes, cartoonish 80 | proportions 81 | unconditional: female indian person, eyes normal size, realistic proportions 82 | neutral: female indian person 83 | guidance: 4 84 | rank: 4 85 | action: enhance 86 | resolution: 512 87 | dynamic_resolution: false 88 | batch_size: 1 89 | - target: female asian person 90 | positive: female asian person, eyes extremely large, animatedly big eyes, cartoonish 91 | proportions 92 | unconditional: female asian person, eyes normal size, realistic proportions 93 | neutral: female asian person 94 | guidance: 4 95 | rank: 4 96 | action: enhance 97 | resolution: 512 98 | dynamic_resolution: false 99 | batch_size: 1 100 | - target: female hispanic person 101 | positive: female hispanic person, eyes extremely large, animatedly big eyes, cartoonish 102 | proportions 103 | unconditional: female hispanic person, eyes normal size, realistic proportions 104 | neutral: female hispanic person 105 | guidance: 4 106 | rank: 4 107 | action: enhance 108 | resolution: 512 109 | dynamic_resolution: false 110 | batch_size: 1 111 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts-car_alienTechFuturistic_GPT.yaml: -------------------------------------------------------------------------------- 1 | - target: car 2 | positive: car, alien technology, futuristic design, advanced propulsion, sleek aerodynamic 3 | shape 4 | unconditional: car, outdated, old-fashioned, conventional design 5 | neutral: car 6 | guidance: 4 7 | rank: 4 8 | action: enhance 9 | resolution: 512 10 | dynamic_resolution: false 11 | batch_size: 1 12 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts-jewelry_diamonds_GPT.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white 
person adorned with jewelry, sparkling diamonds, elegant necklaces, 3 | luxurious bracelets, shimmering earrings 4 | unconditional: male white person with no jewelry, plain appearance, unadorned 5 | neutral: male white person 6 | guidance: 4 7 | rank: 4 8 | action: enhance 9 | resolution: 512 10 | dynamic_resolution: false 11 | batch_size: 1 12 | - target: male black person 13 | positive: male black person adorned with jewelry, sparkling diamonds, elegant necklaces, 14 | luxurious bracelets, shimmering earrings 15 | unconditional: male black person with no jewelry, plain appearance, unadorned 16 | neutral: male black person 17 | guidance: 4 18 | rank: 4 19 | action: enhance 20 | resolution: 512 21 | dynamic_resolution: false 22 | batch_size: 1 23 | - target: male indian person 24 | positive: male indian person adorned with jewelry, sparkling diamonds, elegant necklaces, 25 | luxurious bracelets, shimmering earrings 26 | unconditional: male indian person with no jewelry, plain appearance, unadorned 27 | neutral: male indian person 28 | guidance: 4 29 | rank: 4 30 | action: enhance 31 | resolution: 512 32 | dynamic_resolution: false 33 | batch_size: 1 34 | - target: male asian person 35 | positive: male asian person adorned with jewelry, sparkling diamonds, elegant necklaces, 36 | luxurious bracelets, shimmering earrings 37 | unconditional: male asian person with no jewelry, plain appearance, unadorned 38 | neutral: male asian person 39 | guidance: 4 40 | rank: 4 41 | action: enhance 42 | resolution: 512 43 | dynamic_resolution: false 44 | batch_size: 1 45 | - target: male hispanic person 46 | positive: male hispanic person adorned with jewelry, sparkling diamonds, elegant 47 | necklaces, luxurious bracelets, shimmering earrings 48 | unconditional: male hispanic person with no jewelry, plain appearance, unadorned 49 | neutral: male hispanic person 50 | guidance: 4 51 | rank: 4 52 | action: enhance 53 | resolution: 512 54 | dynamic_resolution: false 55 | batch_size: 1 56 | - target: female white person 57 | positive: female white person adorned with jewelry, sparkling diamonds, elegant 58 | necklaces, luxurious bracelets, shimmering earrings 59 | unconditional: female white person with no jewelry, plain appearance, unadorned 60 | neutral: female white person 61 | guidance: 4 62 | rank: 4 63 | action: enhance 64 | resolution: 512 65 | dynamic_resolution: false 66 | batch_size: 1 67 | - target: female black person 68 | positive: female black person adorned with jewelry, sparkling diamonds, elegant 69 | necklaces, luxurious bracelets, shimmering earrings 70 | unconditional: female black person with no jewelry, plain appearance, unadorned 71 | neutral: female black person 72 | guidance: 4 73 | rank: 4 74 | action: enhance 75 | resolution: 512 76 | dynamic_resolution: false 77 | batch_size: 1 78 | - target: female indian person 79 | positive: female indian person adorned with jewelry, sparkling diamonds, elegant 80 | necklaces, luxurious bracelets, shimmering earrings 81 | unconditional: female indian person with no jewelry, plain appearance, unadorned 82 | neutral: female indian person 83 | guidance: 4 84 | rank: 4 85 | action: enhance 86 | resolution: 512 87 | dynamic_resolution: false 88 | batch_size: 1 89 | - target: female asian person 90 | positive: female asian person adorned with jewelry, sparkling diamonds, elegant 91 | necklaces, luxurious bracelets, shimmering earrings 92 | unconditional: female asian person with no jewelry, plain appearance, unadorned 93 | neutral: female asian person 
94 | guidance: 4 95 | rank: 4 96 | action: enhance 97 | resolution: 512 98 | dynamic_resolution: false 99 | batch_size: 1 100 | - target: female hispanic person 101 | positive: female hispanic person adorned with jewelry, sparkling diamonds, elegant 102 | necklaces, luxurious bracelets, shimmering earrings 103 | unconditional: female hispanic person with no jewelry, plain appearance, unadorned 104 | neutral: female hispanic person 105 | guidance: 4 106 | rank: 4 107 | action: enhance 108 | resolution: 512 109 | dynamic_resolution: false 110 | batch_size: 1 111 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts-person_age_slider_GPT.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, very old, wrinkly, gray hair, aged skin 3 | unconditional: male white person, very young, smooth skin, youthful 4 | neutral: male white person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male black person 12 | positive: male black person, very old, wrinkly, gray hair, aged skin 13 | unconditional: male black person, very young, smooth skin, youthful 14 | neutral: male black person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male indian person 22 | positive: male indian person, very old, wrinkly, gray hair, aged skin 23 | unconditional: male indian person, very young, smooth skin, youthful 24 | neutral: male indian person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male asian person 32 | positive: male asian person, very old, wrinkly, gray hair, aged skin 33 | unconditional: male asian person, very young, smooth skin, youthful 34 | neutral: male asian person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male hispanic person 42 | positive: male hispanic person, very old, wrinkly, gray hair, aged skin 43 | unconditional: male hispanic person, very young, smooth skin, youthful 44 | neutral: male hispanic person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: ' female white person' 52 | positive: ' female white person, very old, wrinkly, gray hair, aged skin' 53 | unconditional: ' female white person, very young, smooth skin, youthful' 54 | neutral: ' female white person' 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: ' female black person' 62 | positive: ' female black person, very old, wrinkly, gray hair, aged skin' 63 | unconditional: ' female black person, very young, smooth skin, youthful' 64 | neutral: ' female black person' 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: ' female indian person' 72 | positive: ' female indian person, very old, wrinkly, gray hair, aged skin' 73 | unconditional: ' female indian person, very young, smooth skin, youthful' 74 | neutral: ' female indian person' 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - 
target: ' female asian person' 82 | positive: ' female asian person, very old, wrinkly, gray hair, aged skin' 83 | unconditional: ' female asian person, very young, smooth skin, youthful' 84 | neutral: ' female asian person' 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: ' female hispanic person' 92 | positive: ' female hispanic person, very old, wrinkly, gray hair, aged skin' 93 | unconditional: ' female hispanic person, very young, smooth skin, youthful' 94 | neutral: ' female hispanic person' 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts-person_surprised_GPT.yaml: -------------------------------------------------------------------------------- 1 | - target: male, white, person 2 | positive: male, white, person, looking surprised, wide eyes, open mouth 3 | unconditional: male, white, person, looking calm, neutral expression 4 | neutral: male, white, person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male, black, person 12 | positive: male, black, person, looking surprised, wide eyes, open mouth 13 | unconditional: male, black, person, looking calm, neutral expression 14 | neutral: male, black, person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male, indian, person 22 | positive: male, indian, person, looking surprised, wide eyes, open mouth 23 | unconditional: male, indian, person, looking calm, neutral expression 24 | neutral: male, indian, person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male, asian, person 32 | positive: male, asian, person, looking surprised, wide eyes, open mouth 33 | unconditional: male, asian, person, looking calm, neutral expression 34 | neutral: male, asian, person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male, hispanic, person 42 | positive: male, hispanic, person, looking surprised, wide eyes, open mouth 43 | unconditional: male, hispanic, person, looking calm, neutral expression 44 | neutral: male, hispanic, person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: ' female, white, person' 52 | positive: ' female, white, person, looking surprised, wide eyes, open mouth' 53 | unconditional: ' female, white, person, looking calm, neutral expression' 54 | neutral: ' female, white, person' 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: ' female, black, person' 62 | positive: ' female, black, person, looking surprised, wide eyes, open mouth' 63 | unconditional: ' female, black, person, looking calm, neutral expression' 64 | neutral: ' female, black, person' 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: ' female, indian, person' 72 | positive: ' female, indian, person, looking surprised, wide eyes, open mouth' 73 | unconditional: ' female, indian, 
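The GPT-generated prompt files such as prompts-person_age_slider_GPT.yaml above spell out the same edit for every gender/ethnicity combination so that the slider does not drift those attributes while changing age, smile, and so on. A quick sanity check of such a file's structure, using plain YAML loading only:

```python
import yaml

with open("trainscripts/textsliders/data/prompts-person_age_slider_GPT.yaml") as f:
    entries = yaml.safe_load(f)

print(len(entries))                                    # 10 = 2 genders x 5 listed ethnicities
print(sorted({entry["action"] for entry in entries}))  # ['enhance']
print(entries[0]["positive"])                          # male white person, very old, wrinkly, gray hair, aged skin
```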
person, looking calm, neutral expression' 74 | neutral: ' female, indian, person' 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - target: ' female, asian, person' 82 | positive: ' female, asian, person, looking surprised, wide eyes, open mouth' 83 | unconditional: ' female, asian, person, looking calm, neutral expression' 84 | neutral: ' female, asian, person' 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: ' female, hispanic, person' 92 | positive: ' female, hispanic, person, looking surprised, wide eyes, open mouth' 93 | unconditional: ' female, hispanic, person, looking calm, neutral expression' 94 | neutral: ' female, hispanic, person' 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts-smile_slider_GPT.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, smiling, happy face, big smile 3 | unconditional: male white person, frowning, grumpy, sad 4 | neutral: male white person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male black person 12 | positive: male black person, smiling, happy face, big smile 13 | unconditional: male black person, frowning, grumpy, sad 14 | neutral: male black person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male indian person 22 | positive: male indian person, smiling, happy face, big smile 23 | unconditional: male indian person, frowning, grumpy, sad 24 | neutral: male indian person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male asian person 32 | positive: male asian person, smiling, happy face, big smile 33 | unconditional: male asian person, frowning, grumpy, sad 34 | neutral: male asian person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male hispanic person 42 | positive: male hispanic person, smiling, happy face, big smile 43 | unconditional: male hispanic person, frowning, grumpy, sad 44 | neutral: male hispanic person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: female white person 52 | positive: female white person, smiling, happy face, big smile 53 | unconditional: female white person, frowning, grumpy, sad 54 | neutral: female white person 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: female black person 62 | positive: female black person, smiling, happy face, big smile 63 | unconditional: female black person, frowning, grumpy, sad 64 | neutral: female black person 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: female indian person 72 | positive: female indian person, smiling, happy face, big smile 73 | unconditional: female indian person, frowning, grumpy, sad 
74 | neutral: female indian person 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - target: female asian person 82 | positive: female asian person, smiling, happy face, big smile 83 | unconditional: female asian person, frowning, grumpy, sad 84 | neutral: female asian person 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: female hispanic person 92 | positive: female hispanic person, smiling, happy face, big smile 93 | unconditional: female hispanic person, frowning, grumpy, sad 94 | neutral: female hispanic person 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | -------------------------------------------------------------------------------- /trainscripts/textsliders/data/prompts.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, very old" # concept to erase 3 | unconditional: "male person, very young" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 11 | positive: "female person, very old" # concept to erase 12 | unconditional: "female person, very young" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | # - target: "" # what word for erasing the positive concept from 20 | # positive: "a group of people" # concept to erase 21 | # unconditional: "a person" # word to take the difference from the positive concept 22 | # neutral: "" # starting point for conditioning the target 23 | # action: "enhance" # erase or enhance 24 | # guidance_scale: 4 25 | # resolution: 512 26 | # dynamic_resolution: false 27 | # batch_size: 1 28 | # - target: "" # what word for erasing the positive concept from 29 | # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase 30 | # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept 31 | # neutral: "" # starting point for conditioning the target 32 | # action: "enhance" # erase or enhance 33 | # guidance_scale: 4 34 | # resolution: 512 35 | # dynamic_resolution: false 36 | # batch_size: 1 37 | # - target: "" # what word for erasing the positive concept from 38 | # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase 39 | # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept 40 | # unconditional: "" 41 | # neutral: "" # starting point for conditioning the target 42 | # action: "enhance" # erase or enhance 43 | # guidance_scale: 4 44 | # resolution: 512 45 | # dynamic_resolution: false 46 | # batch_size: 1 47 | # - target: "food" # what word for erasing the positive concept from 48 | # 
positive: "food, expensive and fine dining" # concept to erase 49 | # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept 50 | # neutral: "food" # starting point for conditioning the target 51 | # action: "enhance" # erase or enhance 52 | # guidance_scale: 4 53 | # resolution: 512 54 | # dynamic_resolution: false 55 | # batch_size: 1 56 | # - target: "room" # what word for erasing the positive concept from 57 | # positive: "room, dirty disorganised and cluttered" # concept to erase 58 | # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept 59 | # neutral: "room" # starting point for conditioning the target 60 | # action: "enhance" # erase or enhance 61 | # guidance_scale: 4 62 | # resolution: 512 63 | # dynamic_resolution: false 64 | # batch_size: 1 65 | # - target: "male person" # what word for erasing the positive concept from 66 | # positive: "male person, with a surprised look" # concept to erase 67 | # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept 68 | # neutral: "male person" # starting point for conditioning the target 69 | # action: "enhance" # erase or enhance 70 | # guidance_scale: 4 71 | # resolution: 512 72 | # dynamic_resolution: false 73 | # batch_size: 1 74 | # - target: "female person" # what word for erasing the positive concept from 75 | # positive: "female person, with a surprised look" # concept to erase 76 | # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept 77 | # neutral: "female person" # starting point for conditioning the target 78 | # action: "enhance" # erase or enhance 79 | # guidance_scale: 4 80 | # resolution: 512 81 | # dynamic_resolution: false 82 | # batch_size: 1 83 | # - target: "sky" # what word for erasing the positive concept from 84 | # positive: "peaceful sky" # concept to erase 85 | # unconditional: "sky" # word to take the difference from the positive concept 86 | # neutral: "sky" # starting point for conditioning the target 87 | # action: "enhance" # erase or enhance 88 | # guidance_scale: 4 89 | # resolution: 512 90 | # dynamic_resolution: false 91 | # batch_size: 1 92 | # - target: "sky" # what word for erasing the positive concept from 93 | # positive: "chaotic dark sky" # concept to erase 94 | # unconditional: "sky" # word to take the difference from the positive concept 95 | # neutral: "sky" # starting point for conditioning the target 96 | # action: "erase" # erase or enhance 97 | # guidance_scale: 4 98 | # resolution: 512 99 | # dynamic_resolution: false 100 | # batch_size: 1 101 | # - target: "person" # what word for erasing the positive concept from 102 | # positive: "person, very young" # concept to erase 103 | # unconditional: "person" # word to take the difference from the positive concept 104 | # neutral: "person" # starting point for conditioning the target 105 | # action: "erase" # erase or enhance 106 | # guidance_scale: 4 107 | # resolution: 512 108 | # dynamic_resolution: false 109 | # batch_size: 1 110 | # overweight 111 | # - target: "art" # what word for erasing the positive concept from 112 | # positive: "realistic art" # concept to erase 113 | # unconditional: "art" # word to take the difference from the positive concept 114 | # neutral: "art" # starting point for conditioning the target 115 | # action: "enhance" # erase or enhance 116 | # guidance_scale: 4 117 | # resolution: 512 118 | # 
dynamic_resolution: false 119 | # batch_size: 1 120 | # - target: "art" # what word for erasing the positive concept from 121 | # positive: "abstract art" # concept to erase 122 | # unconditional: "art" # word to take the difference from the positive concept 123 | # neutral: "art" # starting point for conditioning the target 124 | # action: "erase" # erase or enhance 125 | # guidance_scale: 4 126 | # resolution: 512 127 | # dynamic_resolution: false 128 | # batch_size: 1 129 | # sky 130 | # - target: "weather" # what word for erasing the positive concept from 131 | # positive: "bright pleasant weather" # concept to erase 132 | # unconditional: "weather" # word to take the difference from the positive concept 133 | # neutral: "weather" # starting point for conditioning the target 134 | # action: "enhance" # erase or enhance 135 | # guidance_scale: 4 136 | # resolution: 512 137 | # dynamic_resolution: false 138 | # batch_size: 1 139 | # - target: "weather" # what word for erasing the positive concept from 140 | # positive: "dark gloomy weather" # concept to erase 141 | # unconditional: "weather" # word to take the difference from the positive concept 142 | # neutral: "weather" # starting point for conditioning the target 143 | # action: "erase" # erase or enhance 144 | # guidance_scale: 4 145 | # resolution: 512 146 | # dynamic_resolution: false 147 | # batch_size: 1 148 | # hair 149 | # - target: "person" # what word for erasing the positive concept from 150 | # positive: "person with long hair" # concept to erase 151 | # unconditional: "person" # word to take the difference from the positive concept 152 | # neutral: "person" # starting point for conditioning the target 153 | # action: "enhance" # erase or enhance 154 | # guidance_scale: 4 155 | # resolution: 512 156 | # dynamic_resolution: false 157 | # batch_size: 1 158 | # - target: "person" # what word for erasing the positive concept from 159 | # positive: "person with short hair" # concept to erase 160 | # unconditional: "person" # word to take the difference from the positive concept 161 | # neutral: "person" # starting point for conditioning the target 162 | # action: "erase" # erase or enhance 163 | # guidance_scale: 4 164 | # resolution: 512 165 | # dynamic_resolution: false 166 | # batch_size: 1 167 | # - target: "girl" # what word for erasing the positive concept from 168 | # positive: "baby girl" # concept to erase 169 | # unconditional: "girl" # word to take the difference from the positive concept 170 | # neutral: "girl" # starting point for conditioning the target 171 | # action: "enhance" # erase or enhance 172 | # guidance_scale: -4 173 | # resolution: 512 174 | # dynamic_resolution: false 175 | # batch_size: 1 176 | # - target: "boy" # what word for erasing the positive concept from 177 | # positive: "old man" # concept to erase 178 | # unconditional: "boy" # word to take the difference from the positive concept 179 | # neutral: "boy" # starting point for conditioning the target 180 | # action: "enhance" # erase or enhance 181 | # guidance_scale: 4 182 | # resolution: 512 183 | # dynamic_resolution: false 184 | # batch_size: 1 185 | # - target: "boy" # what word for erasing the positive concept from 186 | # positive: "baby boy" # concept to erase 187 | # unconditional: "boy" # word to take the difference from the positive concept 188 | # neutral: "boy" # starting point for conditioning the target 189 | # action: "enhance" # erase or enhance 190 | # guidance_scale: -4 191 | # resolution: 512 192 | # dynamic_resolution: 
false 193 | # batch_size: 1 -------------------------------------------------------------------------------- /trainscripts/textsliders/debug_util.py: -------------------------------------------------------------------------------- 1 | # For debugging... 2 | 3 | import torch 4 | 5 | 6 | def check_requires_grad(model: torch.nn.Module): 7 | for name, module in list(model.named_modules())[:5]: 8 | if len(list(module.parameters())) > 0: 9 | print(f"Module: {name}") 10 | for name, param in list(module.named_parameters())[:2]: 11 | print(f" Parameter: {name}, Requires Grad: {param.requires_grad}") 12 | 13 | 14 | def check_training_mode(model: torch.nn.Module): 15 | for name, module in list(model.named_modules())[:5]: 16 | print(f"Module: {name}, Training Mode: {module.training}") 17 | -------------------------------------------------------------------------------- /trainscripts/textsliders/flush.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | 4 | torch.cuda.empty_cache() 5 | gc.collect() 6 | -------------------------------------------------------------------------------- /trainscripts/textsliders/lora.py: -------------------------------------------------------------------------------- 1 | # ref: 2 | # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py 3 | # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py 4 | 5 | import os 6 | import math 7 | from typing import Optional, List, Type, Set, Literal 8 | 9 | import torch 10 | import torch.nn as nn 11 | from diffusers import UNet2DConditionModel 12 | from safetensors.torch import save_file 13 | 14 | 15 | UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ 16 | # "Transformer2DModel", # apparently this is the one to use? # attn1, 2 17 | "Attention" 18 | ] 19 | UNET_TARGET_REPLACE_MODULE_CONV = [ 20 | "ResnetBlock2D", 21 | "Downsample2D", 22 | "Upsample2D", 23 | "DownBlock2D", 24 | "UpBlock2D", 25 | 26 | ] # locon, 3clier 27 | 28 | LORA_PREFIX_UNET = "lora_unet" 29 | 30 | DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER 31 | 32 | TRAINING_METHODS = Literal[ 33 | "noxattn", # train all layers except x-attns and time_embed layers 34 | "innoxattn", # train all layers except self attention layers 35 | "selfattn", # ESD-u, train only self attention layers 36 | "xattn", # ESD-x, train only x attention layers 37 | "full", # train all layers 38 | "xattn-strict", # q and k values 39 | "noxattn-hspace", 40 | "noxattn-hspace-last", 41 | # "xlayer", 42 | # "outxattn", 43 | # "outsattn", 44 | # "inxattn", 45 | # "inmidsattn", 46 | # "selflayer", 47 | ] 48 | 49 | 50 | class LoRAModule(nn.Module): 51 | """ 52 | replaces forward method of the original Linear, instead of replacing the original Linear module.
53 | """ 54 | 55 | def __init__( 56 | self, 57 | lora_name, 58 | org_module: nn.Module, 59 | multiplier=1.0, 60 | lora_dim=4, 61 | alpha=1, 62 | ): 63 | """if alpha == 0 or None, alpha is rank (no scaling).""" 64 | super().__init__() 65 | self.lora_name = lora_name 66 | self.lora_dim = lora_dim 67 | 68 | if "Linear" in org_module.__class__.__name__: 69 | in_dim = org_module.in_features 70 | out_dim = org_module.out_features 71 | self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) 72 | self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) 73 | 74 | elif "Conv" in org_module.__class__.__name__: # 一応 75 | in_dim = org_module.in_channels 76 | out_dim = org_module.out_channels 77 | 78 | self.lora_dim = min(self.lora_dim, in_dim, out_dim) 79 | if self.lora_dim != lora_dim: 80 | print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") 81 | 82 | kernel_size = org_module.kernel_size 83 | stride = org_module.stride 84 | padding = org_module.padding 85 | self.lora_down = nn.Conv2d( 86 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False 87 | ) 88 | self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) 89 | 90 | if type(alpha) == torch.Tensor: 91 | alpha = alpha.detach().numpy() 92 | alpha = lora_dim if alpha is None or alpha == 0 else alpha 93 | self.scale = alpha / self.lora_dim 94 | self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える 95 | 96 | # same as microsoft's 97 | nn.init.kaiming_uniform_(self.lora_down.weight, a=1) 98 | nn.init.zeros_(self.lora_up.weight) 99 | 100 | self.multiplier = multiplier 101 | self.org_module = org_module # remove in applying 102 | 103 | def apply_to(self): 104 | self.org_forward = self.org_module.forward 105 | self.org_module.forward = self.forward 106 | del self.org_module 107 | 108 | def forward(self, x): 109 | return ( 110 | self.org_forward(x) 111 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale 112 | ) 113 | 114 | 115 | class LoRANetwork(nn.Module): 116 | def __init__( 117 | self, 118 | unet: UNet2DConditionModel, 119 | rank: int = 4, 120 | multiplier: float = 1.0, 121 | alpha: float = 1.0, 122 | train_method: TRAINING_METHODS = "full", 123 | ) -> None: 124 | super().__init__() 125 | self.lora_scale = 1 126 | self.multiplier = multiplier 127 | self.lora_dim = rank 128 | self.alpha = alpha 129 | 130 | # LoRAのみ 131 | self.module = LoRAModule 132 | 133 | # unetのloraを作る 134 | self.unet_loras = self.create_modules( 135 | LORA_PREFIX_UNET, 136 | unet, 137 | DEFAULT_TARGET_REPLACE, 138 | self.lora_dim, 139 | self.multiplier, 140 | train_method=train_method, 141 | ) 142 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") 143 | 144 | # assertion 名前の被りがないか確認しているようだ 145 | lora_names = set() 146 | for lora in self.unet_loras: 147 | assert ( 148 | lora.lora_name not in lora_names 149 | ), f"duplicated lora name: {lora.lora_name}. 
{lora_names}" 150 | lora_names.add(lora.lora_name) 151 | 152 | # apply the LoRA modules and register them 153 | for lora in self.unet_loras: 154 | lora.apply_to() 155 | self.add_module( 156 | lora.lora_name, 157 | lora, 158 | ) 159 | 160 | del unet 161 | 162 | torch.cuda.empty_cache() 163 | 164 | def create_modules( 165 | self, 166 | prefix: str, 167 | root_module: nn.Module, 168 | target_replace_modules: List[str], 169 | rank: int, 170 | multiplier: float, 171 | train_method: TRAINING_METHODS, 172 | ) -> list: 173 | loras = [] 174 | names = [] 175 | for name, module in root_module.named_modules(): 176 | if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # train everything except cross-attention and time_embed 177 | if "attn2" in name or "time_embed" in name: 178 | continue 179 | elif train_method == "innoxattn": # train everything except cross-attention 180 | if "attn2" in name: 181 | continue 182 | elif train_method == "selfattn": # train self-attention only 183 | if "attn1" not in name: 184 | continue 185 | elif train_method == "xattn" or train_method == "xattn-strict": # train cross-attention only 186 | if "attn2" not in name: 187 | continue 188 | elif train_method == "full": # train everything 189 | pass 190 | else: 191 | raise NotImplementedError( 192 | f"train_method: {train_method} is not implemented." 193 | ) 194 | if module.__class__.__name__ in target_replace_modules: 195 | for child_name, child_module in module.named_modules(): 196 | if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]: 197 | if train_method == 'xattn-strict': 198 | if 'out' in child_name: 199 | continue 200 | if train_method == 'noxattn-hspace': 201 | if 'mid_block' not in name: 202 | continue 203 | if train_method == 'noxattn-hspace-last': 204 | if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name: 205 | continue 206 | lora_name = prefix + "." + name + "."
+ child_name 207 | lora_name = lora_name.replace(".", "_") 208 | # print(f"{lora_name}") 209 | lora = self.module( 210 | lora_name, child_module, multiplier, rank, self.alpha 211 | ) 212 | # print(name, child_name) 213 | # print(child_module.weight.shape) 214 | if lora_name not in names: 215 | loras.append(lora) 216 | names.append(lora_name) 217 | # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}') 218 | return loras 219 | 220 | def prepare_optimizer_params(self): 221 | all_params = [] 222 | 223 | if self.unet_loras: # 実質これしかない 224 | params = [] 225 | [params.extend(lora.parameters()) for lora in self.unet_loras] 226 | param_data = {"params": params} 227 | all_params.append(param_data) 228 | 229 | return all_params 230 | 231 | def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): 232 | state_dict = self.state_dict() 233 | 234 | if dtype is not None: 235 | for key in list(state_dict.keys()): 236 | v = state_dict[key] 237 | v = v.detach().clone().to("cpu").to(dtype) 238 | state_dict[key] = v 239 | 240 | # for key in list(state_dict.keys()): 241 | # if not key.startswith("lora"): 242 | # # lora以外除外 243 | # del state_dict[key] 244 | 245 | if os.path.splitext(file)[1] == ".safetensors": 246 | save_file(state_dict, file, metadata) 247 | else: 248 | torch.save(state_dict, file) 249 | def set_lora_slider(self, scale): 250 | self.lora_scale = scale 251 | 252 | def __enter__(self): 253 | for lora in self.unet_loras: 254 | lora.multiplier = 1.0 * self.lora_scale 255 | 256 | def __exit__(self, exc_type, exc_value, tb): 257 | for lora in self.unet_loras: 258 | lora.multiplier = 0 259 | -------------------------------------------------------------------------------- /trainscripts/textsliders/model_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union, Optional 2 | 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 5 | from diffusers import ( 6 | UNet2DConditionModel, 7 | SchedulerMixin, 8 | StableDiffusionPipeline, 9 | StableDiffusionXLPipeline, 10 | ) 11 | from diffusers.schedulers import ( 12 | DDIMScheduler, 13 | DDPMScheduler, 14 | LMSDiscreteScheduler, 15 | EulerAncestralDiscreteScheduler, 16 | ) 17 | 18 | 19 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" 20 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" 21 | 22 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"] 23 | 24 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] 25 | 26 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this 27 | 28 | 29 | def load_diffusers_model( 30 | pretrained_model_name_or_path: str, 31 | v2: bool = False, 32 | clip_skip: Optional[int] = None, 33 | weight_dtype: torch.dtype = torch.float32, 34 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: 35 | # VAE はいらない 36 | 37 | if v2: 38 | tokenizer = CLIPTokenizer.from_pretrained( 39 | TOKENIZER_V2_MODEL_NAME, 40 | subfolder="tokenizer", 41 | torch_dtype=weight_dtype, 42 | cache_dir=DIFFUSERS_CACHE_DIR, 43 | ) 44 | text_encoder = CLIPTextModel.from_pretrained( 45 | pretrained_model_name_or_path, 46 | subfolder="text_encoder", 47 | # default is clip skip 2 48 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, 49 | torch_dtype=weight_dtype, 50 | cache_dir=DIFFUSERS_CACHE_DIR, 51 | ) 52 | else: 53 | tokenizer = CLIPTokenizer.from_pretrained( 54 | TOKENIZER_V1_MODEL_NAME, 55 | subfolder="tokenizer", 56 
| torch_dtype=weight_dtype, 57 | cache_dir=DIFFUSERS_CACHE_DIR, 58 | ) 59 | text_encoder = CLIPTextModel.from_pretrained( 60 | pretrained_model_name_or_path, 61 | subfolder="text_encoder", 62 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, 63 | torch_dtype=weight_dtype, 64 | cache_dir=DIFFUSERS_CACHE_DIR, 65 | ) 66 | 67 | unet = UNet2DConditionModel.from_pretrained( 68 | pretrained_model_name_or_path, 69 | subfolder="unet", 70 | torch_dtype=weight_dtype, 71 | cache_dir=DIFFUSERS_CACHE_DIR, 72 | ) 73 | 74 | return tokenizer, text_encoder, unet 75 | 76 | 77 | def load_checkpoint_model( 78 | checkpoint_path: str, 79 | v2: bool = False, 80 | clip_skip: Optional[int] = None, 81 | weight_dtype: torch.dtype = torch.float32, 82 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: 83 | pipe = StableDiffusionPipeline.from_ckpt( 84 | checkpoint_path, 85 | upcast_attention=True if v2 else False, 86 | torch_dtype=weight_dtype, 87 | cache_dir=DIFFUSERS_CACHE_DIR, 88 | ) 89 | 90 | unet = pipe.unet 91 | tokenizer = pipe.tokenizer 92 | text_encoder = pipe.text_encoder 93 | if clip_skip is not None: 94 | if v2: 95 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) 96 | else: 97 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) 98 | 99 | del pipe 100 | 101 | return tokenizer, text_encoder, unet 102 | 103 | 104 | def load_models( 105 | pretrained_model_name_or_path: str, 106 | scheduler_name: AVAILABLE_SCHEDULERS, 107 | v2: bool = False, 108 | v_pred: bool = False, 109 | weight_dtype: torch.dtype = torch.float32, 110 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: 111 | if pretrained_model_name_or_path.endswith( 112 | ".ckpt" 113 | ) or pretrained_model_name_or_path.endswith(".safetensors"): 114 | tokenizer, text_encoder, unet = load_checkpoint_model( 115 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype 116 | ) 117 | else: # diffusers 118 | tokenizer, text_encoder, unet = load_diffusers_model( 119 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype 120 | ) 121 | 122 | # VAE はいらない 123 | 124 | scheduler = create_noise_scheduler( 125 | scheduler_name, 126 | prediction_type="v_prediction" if v_pred else "epsilon", 127 | ) 128 | 129 | return tokenizer, text_encoder, unet, scheduler 130 | 131 | 132 | def load_diffusers_model_xl( 133 | pretrained_model_name_or_path: str, 134 | weight_dtype: torch.dtype = torch.float32, 135 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: 136 | # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet 137 | 138 | tokenizers = [ 139 | CLIPTokenizer.from_pretrained( 140 | pretrained_model_name_or_path, 141 | subfolder="tokenizer", 142 | torch_dtype=weight_dtype, 143 | cache_dir=DIFFUSERS_CACHE_DIR, 144 | ), 145 | CLIPTokenizer.from_pretrained( 146 | pretrained_model_name_or_path, 147 | subfolder="tokenizer_2", 148 | torch_dtype=weight_dtype, 149 | cache_dir=DIFFUSERS_CACHE_DIR, 150 | pad_token_id=0, # same as open clip 151 | ), 152 | ] 153 | 154 | text_encoders = [ 155 | CLIPTextModel.from_pretrained( 156 | pretrained_model_name_or_path, 157 | subfolder="text_encoder", 158 | torch_dtype=weight_dtype, 159 | cache_dir=DIFFUSERS_CACHE_DIR, 160 | ), 161 | CLIPTextModelWithProjection.from_pretrained( 162 | pretrained_model_name_or_path, 163 | subfolder="text_encoder_2", 164 | torch_dtype=weight_dtype, 165 | cache_dir=DIFFUSERS_CACHE_DIR, 166 | ), 167 | ] 168 | 169 | unet = UNet2DConditionModel.from_pretrained( 
170 | pretrained_model_name_or_path, 171 | subfolder="unet", 172 | torch_dtype=weight_dtype, 173 | cache_dir=DIFFUSERS_CACHE_DIR, 174 | ) 175 | 176 | return tokenizers, text_encoders, unet 177 | 178 | 179 | def load_checkpoint_model_xl( 180 | checkpoint_path: str, 181 | weight_dtype: torch.dtype = torch.float32, 182 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: 183 | pipe = StableDiffusionXLPipeline.from_single_file( 184 | checkpoint_path, 185 | torch_dtype=weight_dtype, 186 | cache_dir=DIFFUSERS_CACHE_DIR, 187 | ) 188 | 189 | unet = pipe.unet 190 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2] 191 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2] 192 | if len(text_encoders) == 2: 193 | text_encoders[1].pad_token_id = 0 194 | 195 | del pipe 196 | 197 | return tokenizers, text_encoders, unet 198 | 199 | 200 | def load_models_xl( 201 | pretrained_model_name_or_path: str, 202 | scheduler_name: AVAILABLE_SCHEDULERS, 203 | weight_dtype: torch.dtype = torch.float32, 204 | ) -> tuple[ 205 | list[CLIPTokenizer], 206 | list[SDXL_TEXT_ENCODER_TYPE], 207 | UNet2DConditionModel, 208 | SchedulerMixin, 209 | ]: 210 | if pretrained_model_name_or_path.endswith( 211 | ".ckpt" 212 | ) or pretrained_model_name_or_path.endswith(".safetensors"): 213 | ( 214 | tokenizers, 215 | text_encoders, 216 | unet, 217 | ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype) 218 | else: # diffusers 219 | ( 220 | tokenizers, 221 | text_encoders, 222 | unet, 223 | ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype) 224 | 225 | scheduler = create_noise_scheduler(scheduler_name) 226 | 227 | return tokenizers, text_encoders, unet, scheduler 228 | 229 | 230 | def create_noise_scheduler( 231 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", 232 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", 233 | ) -> SchedulerMixin: 234 | # Honestly, it is not clear which scheduler is best; the original implementation let you choose between DDIM, DDPM and LMS, but there is no obvious winner. 235 | 236 | name = scheduler_name.lower().replace(" ", "_") 237 | if name == "ddim": 238 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim 239 | scheduler = DDIMScheduler( 240 | beta_start=0.00085, 241 | beta_end=0.012, 242 | beta_schedule="scaled_linear", 243 | num_train_timesteps=1000, 244 | clip_sample=False, 245 | prediction_type=prediction_type, # is this right? 246 | ) 247 | elif name == "ddpm": 248 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm 249 | scheduler = DDPMScheduler( 250 | beta_start=0.00085, 251 | beta_end=0.012, 252 | beta_schedule="scaled_linear", 253 | num_train_timesteps=1000, 254 | clip_sample=False, 255 | prediction_type=prediction_type, 256 | ) 257 | elif name == "lms": 258 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete 259 | scheduler = LMSDiscreteScheduler( 260 | beta_start=0.00085, 261 | beta_end=0.012, 262 | beta_schedule="scaled_linear", 263 | num_train_timesteps=1000, 264 | prediction_type=prediction_type, 265 | ) 266 | elif name == "euler_a": 267 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral 268 | scheduler = EulerAncestralDiscreteScheduler( 269 | beta_start=0.00085, 270 | beta_end=0.012, 271 | beta_schedule="scaled_linear", 272 | num_train_timesteps=1000, 273 | prediction_type=prediction_type, 274 | ) 275 | else: 276 | raise ValueError(f"Unknown scheduler name: {name}") 277 | 278 | return scheduler 279 | -------------------------------------------------------------------------------- /trainscripts/textsliders/prompt_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union, List 2 | 3 | import yaml 4 | from pathlib import Path 5 | 6 | 7 | from pydantic import BaseModel, root_validator 8 | import torch 9 | import copy 10 | 11 | ACTION_TYPES = Literal[ 12 | "erase", 13 | "enhance", 14 | ] 15 | 16 | 17 | # XL needs two kinds of text embeddings 18 | class PromptEmbedsXL: 19 | text_embeds: torch.FloatTensor 20 | pooled_embeds: torch.FloatTensor 21 | 22 | def __init__(self, *args) -> None: 23 | self.text_embeds = args[0] 24 | self.pooled_embeds = args[1] 25 | 26 | 27 | # SDv1.x and SDv2.x use a FloatTensor, XL uses PromptEmbedsXL 28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL] 29 | 30 | 31 | class PromptEmbedsCache: # cached so the embeddings can be reused 32 | prompts: dict[str, PROMPT_EMBEDDING] = {} 33 | 34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None: 35 | self.prompts[__name] = __value 36 | 37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]: 38 | if __name in self.prompts: 39 | return self.prompts[__name] 40 | else: 41 | return None 42 | 43 | 44 | class PromptSettings(BaseModel): # the settings loaded from the yaml 45 | target: str 46 | positive: str = None # if None, target will be used 47 | unconditional: str = "" # default is "" 48 | neutral: str = None # if None, unconditional will be used 49 | action: ACTION_TYPES = "erase" # default is "erase" 50 | guidance_scale: float = 1.0 # default is 1.0 51 | resolution: int = 512 # default is 512 52 | dynamic_resolution: bool = False # default is False 53 | batch_size: int = 1 # default is 1 54 | dynamic_crops: bool = False # default is False.
only used when model is XL 55 | 56 | @root_validator(pre=True) 57 | def fill_prompts(cls, values): 58 | keys = values.keys() 59 | if "target" not in keys: 60 | raise ValueError("target must be specified") 61 | if "positive" not in keys: 62 | values["positive"] = values["target"] 63 | if "unconditional" not in keys: 64 | values["unconditional"] = "" 65 | if "neutral" not in keys: 66 | values["neutral"] = values["unconditional"] 67 | 68 | return values 69 | 70 | 71 | class PromptEmbedsPair: 72 | target: PROMPT_EMBEDDING # not want to generate the concept 73 | positive: PROMPT_EMBEDDING # generate the concept 74 | unconditional: PROMPT_EMBEDDING # uncondition (default should be empty) 75 | neutral: PROMPT_EMBEDDING # base condition (default should be empty) 76 | 77 | guidance_scale: float 78 | resolution: int 79 | dynamic_resolution: bool 80 | batch_size: int 81 | dynamic_crops: bool 82 | 83 | loss_fn: torch.nn.Module 84 | action: ACTION_TYPES 85 | 86 | def __init__( 87 | self, 88 | loss_fn: torch.nn.Module, 89 | target: PROMPT_EMBEDDING, 90 | positive: PROMPT_EMBEDDING, 91 | unconditional: PROMPT_EMBEDDING, 92 | neutral: PROMPT_EMBEDDING, 93 | settings: PromptSettings, 94 | ) -> None: 95 | self.loss_fn = loss_fn 96 | self.target = target 97 | self.positive = positive 98 | self.unconditional = unconditional 99 | self.neutral = neutral 100 | 101 | self.guidance_scale = settings.guidance_scale 102 | self.resolution = settings.resolution 103 | self.dynamic_resolution = settings.dynamic_resolution 104 | self.batch_size = settings.batch_size 105 | self.dynamic_crops = settings.dynamic_crops 106 | self.action = settings.action 107 | 108 | def _erase( 109 | self, 110 | target_latents: torch.FloatTensor, # "van gogh" 111 | positive_latents: torch.FloatTensor, # "van gogh" 112 | unconditional_latents: torch.FloatTensor, # "" 113 | neutral_latents: torch.FloatTensor, # "" 114 | ) -> torch.FloatTensor: 115 | """Target latents are going not to have the positive concept.""" 116 | return self.loss_fn( 117 | target_latents, 118 | neutral_latents 119 | - self.guidance_scale * (positive_latents - unconditional_latents) 120 | ) 121 | 122 | 123 | def _enhance( 124 | self, 125 | target_latents: torch.FloatTensor, # "van gogh" 126 | positive_latents: torch.FloatTensor, # "van gogh" 127 | unconditional_latents: torch.FloatTensor, # "" 128 | neutral_latents: torch.FloatTensor, # "" 129 | ): 130 | """Target latents are going to have the positive concept.""" 131 | return self.loss_fn( 132 | target_latents, 133 | neutral_latents 134 | + self.guidance_scale * (positive_latents - unconditional_latents) 135 | ) 136 | 137 | def loss( 138 | self, 139 | **kwargs, 140 | ): 141 | if self.action == "erase": 142 | return self._erase(**kwargs) 143 | 144 | elif self.action == "enhance": 145 | return self._enhance(**kwargs) 146 | 147 | else: 148 | raise ValueError("action must be erase or enhance") 149 | 150 | 151 | def load_prompts_from_yaml(path, attributes = []): 152 | with open(path, "r") as f: 153 | prompts = yaml.safe_load(f) 154 | print(prompts) 155 | if len(prompts) == 0: 156 | raise ValueError("prompts file is empty") 157 | if len(attributes)!=0: 158 | newprompts = [] 159 | for i in range(len(prompts)): 160 | for att in attributes: 161 | copy_ = copy.deepcopy(prompts[i]) 162 | copy_['target'] = att + ' ' + copy_['target'] 163 | copy_['positive'] = att + ' ' + copy_['positive'] 164 | copy_['neutral'] = att + ' ' + copy_['neutral'] 165 | copy_['unconditional'] = att + ' ' + copy_['unconditional'] 166 | 
newprompts.append(copy_) 167 | else: 168 | newprompts = copy.deepcopy(prompts) 169 | 170 | print(newprompts) 171 | print(len(prompts), len(newprompts)) 172 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts] 173 | 174 | return prompt_settings 175 | -------------------------------------------------------------------------------- /trainscripts/textsliders/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image, ImageDraw, ImageFont 18 | import cv2 19 | from typing import Optional, Union, Tuple, List, Callable, Dict 20 | from IPython.display import display 21 | from tqdm.notebook import tqdm 22 | 23 | 24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 25 | h, w, c = image.shape 26 | offset = int(h * .2) 27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 28 | font = cv2.FONT_HERSHEY_SIMPLEX 29 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 30 | img[:h] = image 31 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 32 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 33 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 34 | return img 35 | 36 | 37 | def view_images(images, num_rows=1, offset_ratio=0.02): 38 | if type(images) is list: 39 | num_empty = len(images) % num_rows 40 | elif images.ndim == 4: 41 | num_empty = images.shape[0] % num_rows 42 | else: 43 | images = [images] 44 | num_empty = 0 45 | 46 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 47 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 48 | num_items = len(images) 49 | 50 | h, w, c = images[0].shape 51 | offset = int(h * offset_ratio) 52 | num_cols = num_items // num_rows 53 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 54 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 55 | for i in range(num_rows): 56 | for j in range(num_cols): 57 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 58 | i * num_cols + j] 59 | 60 | pil_img = Image.fromarray(image_) 61 | display(pil_img) 62 | 63 | 64 | def diffusion_step(unet, model, controller, latents, context, t, guidance_scale, low_resource=False): 65 | if low_resource: 66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 68 | else: 69 | latents_input = torch.cat([latents] * 2) 70 | noise_pred = unet(latents_input, t, encoder_hidden_states=context)["sample"] 71 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 72 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 73 | latents 
= model.scheduler.step(noise_pred, t, latents)["prev_sample"] 74 | latents = controller.step_callback(latents) 75 | return latents 76 | 77 | 78 | def latent2image(vae, latents): 79 | latents = 1 / 0.18215 * latents 80 | image = vae.decode(latents)['sample'] 81 | image = (image / 2 + 0.5).clamp(0, 1) 82 | image = image.cpu().permute(0, 2, 3, 1).numpy() 83 | image = (image * 255).astype(np.uint8) 84 | return image 85 | 86 | 87 | def init_latent(latent, model, height, width, generator, batch_size): 88 | if latent is None: 89 | latent = torch.randn( 90 | (1, model.unet.in_channels, height // 8, width // 8), 91 | generator=generator, 92 | ) 93 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 94 | return latent, latents 95 | 96 | 97 | @torch.no_grad() 98 | def text2image_ldm( 99 | model, 100 | prompt: List[str], 101 | controller, 102 | num_inference_steps: int = 50, 103 | guidance_scale: Optional[float] = 7., 104 | generator: Optional[torch.Generator] = None, 105 | latent: Optional[torch.FloatTensor] = None, 106 | ): 107 | register_attention_control(model, controller) 108 | height = width = 256 109 | batch_size = len(prompt) 110 | 111 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 112 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 113 | 114 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 115 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 116 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 117 | context = torch.cat([uncond_embeddings, text_embeddings]) 118 | 119 | model.scheduler.set_timesteps(num_inference_steps) 120 | for t in tqdm(model.scheduler.timesteps): 121 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 122 | 123 | image = latent2image(model.vqvae, latents) 124 | 125 | return image, latent 126 | 127 | 128 | @torch.no_grad() 129 | def text2image_ldm_stable( 130 | model, 131 | prompt: List[str], 132 | controller, 133 | num_inference_steps: int = 50, 134 | guidance_scale: float = 7.5, 135 | generator: Optional[torch.Generator] = None, 136 | latent: Optional[torch.FloatTensor] = None, 137 | low_resource: bool = False, 138 | ): 139 | register_attention_control(model, controller) 140 | height = width = 512 141 | batch_size = len(prompt) 142 | 143 | text_input = model.tokenizer( 144 | prompt, 145 | padding="max_length", 146 | max_length=model.tokenizer.model_max_length, 147 | truncation=True, 148 | return_tensors="pt", 149 | ) 150 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 151 | max_length = text_input.input_ids.shape[-1] 152 | uncond_input = model.tokenizer( 153 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 154 | ) 155 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 156 | 157 | context = [uncond_embeddings, text_embeddings] 158 | if not low_resource: 159 | context = torch.cat(context) 160 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 161 | 162 | # set timesteps 163 | extra_set_kwargs = {"offset": 1} 164 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 165 | for t in tqdm(model.scheduler.timesteps): 166 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 167 | 168 | image = 
latent2image(model.vae, latents) 169 | 170 | return image, latent 171 | 172 | 173 | def register_attention_control(model, controller): 174 | def ca_forward(self, place_in_unet): 175 | to_out = self.to_out 176 | if type(to_out) is torch.nn.modules.container.ModuleList: 177 | to_out = self.to_out[0] 178 | else: 179 | to_out = self.to_out 180 | 181 | def forward(x, context=None, mask=None): 182 | batch_size, sequence_length, dim = x.shape 183 | h = self.heads 184 | q = self.to_q(x) 185 | is_cross = context is not None 186 | context = context if is_cross else x 187 | k = self.to_k(context) 188 | v = self.to_v(context) 189 | q = self.reshape_heads_to_batch_dim(q) 190 | k = self.reshape_heads_to_batch_dim(k) 191 | v = self.reshape_heads_to_batch_dim(v) 192 | 193 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 194 | 195 | if mask is not None: 196 | mask = mask.reshape(batch_size, -1) 197 | max_neg_value = -torch.finfo(sim.dtype).max 198 | mask = mask[:, None, :].repeat(h, 1, 1) 199 | sim.masked_fill_(~mask, max_neg_value) 200 | 201 | # attention, what we cannot get enough of 202 | attn = sim.softmax(dim=-1) 203 | attn = controller(attn, is_cross, place_in_unet) 204 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 205 | out = self.reshape_batch_dim_to_heads(out) 206 | return to_out(out) 207 | 208 | return forward 209 | 210 | class DummyController: 211 | 212 | def __call__(self, *args): 213 | return args[0] 214 | 215 | def __init__(self): 216 | self.num_att_layers = 0 217 | 218 | if controller is None: 219 | controller = DummyController() 220 | 221 | def register_recr(net_, count, place_in_unet): 222 | if net_.__class__.__name__ == 'CrossAttention': 223 | net_.forward = ca_forward(net_, place_in_unet) 224 | return count + 1 225 | elif hasattr(net_, 'children'): 226 | for net__ in net_.children(): 227 | count = register_recr(net__, count, place_in_unet) 228 | return count 229 | 230 | cross_att_count = 0 231 | sub_nets = model.unet.named_children() 232 | for net in sub_nets: 233 | if "down" in net[0]: 234 | cross_att_count += register_recr(net[1], 0, "down") 235 | elif "up" in net[0]: 236 | cross_att_count += register_recr(net[1], 0, "up") 237 | elif "mid" in net[0]: 238 | cross_att_count += register_recr(net[1], 0, "mid") 239 | 240 | controller.num_att_layers = cross_att_count 241 | 242 | 243 | def get_word_inds(text: str, word_place: int, tokenizer): 244 | split_text = text.split(" ") 245 | if type(word_place) is str: 246 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 247 | elif type(word_place) is int: 248 | word_place = [word_place] 249 | out = [] 250 | if len(word_place) > 0: 251 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 252 | cur_len, ptr = 0, 0 253 | 254 | for i in range(len(words_encode)): 255 | cur_len += len(words_encode[i]) 256 | if ptr in word_place: 257 | out.append(i + 1) 258 | if cur_len >= len(split_text[ptr]): 259 | ptr += 1 260 | cur_len = 0 261 | return np.array(out) 262 | 263 | 264 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 265 | word_inds: Optional[torch.Tensor]=None): 266 | if type(bounds) is float: 267 | bounds = 0, bounds 268 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 269 | if word_inds is None: 270 | word_inds = torch.arange(alpha.shape[2]) 271 | alpha[: start, prompt_ind, word_inds] = 0 272 | alpha[start: end, prompt_ind, word_inds] = 1 273 | alpha[end:, prompt_ind, word_inds] = 0 
274 | return alpha 275 | 276 | 277 | def get_time_words_attention_alpha(prompts, num_steps, 278 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 279 | tokenizer, max_num_words=77): 280 | if type(cross_replace_steps) is not dict: 281 | cross_replace_steps = {"default_": cross_replace_steps} 282 | if "default_" not in cross_replace_steps: 283 | cross_replace_steps["default_"] = (0., 1.) 284 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 285 | for i in range(len(prompts) - 1): 286 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 287 | i) 288 | for key, item in cross_replace_steps.items(): 289 | if key != "default_": 290 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 291 | for i, ind in enumerate(inds): 292 | if len(ind) > 0: 293 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 294 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 295 | return alpha_time_words --------------------------------------------------------------------------------
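debug_util.py and flush.py above are small training-time helpers: the first prints requires_grad and training-mode flags so you can confirm that only the slider's LoRA parameters are trainable, the second simply frees cached CUDA memory when imported. A minimal sketch of calling the debug helpers on a toy module, assuming trainscripts/textsliders is importable as a package (it ships __init__.py files):

import torch.nn as nn
from trainscripts.textsliders import debug_util

# a stand-in for the wrapped UNet: freeze one layer, leave the others trainable
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
toy[0].requires_grad_(False)

debug_util.check_requires_grad(toy)   # prints requires_grad for the first few parameters
debug_util.check_training_mode(toy)   # prints module.training for the first few modules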
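LoRAModule in lora.py leaves the original layer in place and patches its forward so the output becomes org_forward(x) + lora_up(lora_down(x)) * multiplier * (alpha / rank). A small standalone check of that update on a plain nn.Linear; this is a sketch, not repository code, and the module name is arbitrary:

import torch
import torch.nn as nn
import torch.nn.functional as F
from trainscripts.textsliders.lora import LoRAModule

base = nn.Linear(16, 16)
lora = LoRAModule("lora_unet_demo", base, multiplier=1.0, lora_dim=4, alpha=1.0)
lora.apply_to()  # swaps base.forward for the LoRA-augmented forward

x = torch.randn(2, 16)
with torch.no_grad():
    out = base(x)  # now routed through the low-rank path as well

# lora_up is zero-initialised, so a freshly created slider leaves the layer unchanged
print(lora.scale)                                          # alpha / rank = 0.25
print(torch.allclose(out, F.linear(x, base.weight, base.bias)))  # True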
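Putting model_util.py and lora.py together, a text slider is a LoRANetwork wrapped around the UNet whose multiplier is rescaled at sampling time. The following is a hypothetical end-to-end sketch: the checkpoint id and output file name are examples, and the actual denoising loop is omitted:

import torch
from trainscripts.textsliders import model_util
from trainscripts.textsliders.lora import LoRANetwork

# tokenizer, text encoder, UNet and scheduler for an SD v1.x checkpoint
tokenizer, text_encoder, unet, scheduler = model_util.load_models(
    "runwayml/stable-diffusion-v1-5",
    scheduler_name="ddim",
    v2=False,
    v_pred=False,
    weight_dtype=torch.float16,
)

# attach LoRA adapters to the cross-attention layers only ("xattn", ESD-x style)
network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0, train_method="xattn")

# at inference, pick a slider strength; the LoRA is only active inside the context
# manager (__enter__ sets multiplier to lora_scale, __exit__ zeroes it again)
network.set_lora_slider(scale=2.0)
with network:
    pass  # run the usual denoising loop with `unet` here

# save the slider; a .safetensors extension selects safetensors serialisation
network.save_weights("age_slider.safetensors", dtype=torch.float16)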
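The *_xl variants in model_util.py return a pair of tokenizers and a pair of text encoders, matching SDXL's two-encoder design; PromptEmbedsXL in prompt_util.py then carries the token embeddings together with the pooled embedding from the second encoder. A short loading sketch, where the SDXL checkpoint id is an assumed example:

import torch
from trainscripts.textsliders import model_util

tokenizers, text_encoders, unet, scheduler = model_util.load_models_xl(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler_name="euler_a",
    weight_dtype=torch.float16,
)

print(len(tokenizers), len(text_encoders))  # 2 2
print(type(scheduler).__name__)             # EulerAncestralDiscreteScheduler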
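The commented-out blocks in the prompts YAML above and the PromptSettings model in prompt_util.py describe the same schema: target, positive, unconditional, neutral, action, guidance_scale, resolution, dynamic_resolution, batch_size. Below is a hypothetical minimal config written and loaded from Python; the file name and the attributes list are illustrative, and attributes is how the training scripts preserve, for example, gender while editing a concept:

import yaml
from trainscripts.textsliders.prompt_util import load_prompts_from_yaml

example = [
    {
        "target": "person",
        "positive": "person, very old",         # extreme positive end of the slider
        "unconditional": "person, very young",  # extreme negative end
        "neutral": "person",                    # starting point for conditioning
        "action": "enhance",                    # "erase" or "enhance"
        "guidance_scale": 4,
        "resolution": 512,
        "dynamic_resolution": False,
        "batch_size": 1,
    }
]

with open("prompts-age-example.yaml", "w") as f:
    yaml.safe_dump(example, f)

# each attribute is prepended to every prompt, yielding one PromptSettings per attribute
settings = load_prompts_from_yaml("prompts-age-example.yaml", attributes=["male", "female"])
print(len(settings))  # 2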
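PromptEmbedsPair reduces both actions to one objective: the noise prediction for the target prompt is regressed onto neutral ± guidance_scale * (positive − unconditional), with + for "enhance" and − for "erase". A tiny self-contained check of that arithmetic with dummy tensors; shapes and values are arbitrary:

import torch

loss_fn = torch.nn.MSELoss()
guidance_scale = 4.0

# stand-ins for the four noise predictions that PromptEmbedsPair.loss() receives
target_latents = torch.zeros(1, 4, 64, 64)
positive_latents = torch.full((1, 4, 64, 64), 0.5)
unconditional_latents = torch.full((1, 4, 64, 64), 0.1)
neutral_latents = torch.zeros(1, 4, 64, 64)

# "enhance": pull the target toward neutral + scale * (positive - unconditional)
enhance_loss = loss_fn(
    target_latents,
    neutral_latents + guidance_scale * (positive_latents - unconditional_latents),
)
# "erase": pull it the opposite way
erase_loss = loss_fn(
    target_latents,
    neutral_latents - guidance_scale * (positive_latents - unconditional_latents),
)

print(enhance_loss.item(), erase_loss.item())  # both approximately 2.56 = (4 * 0.4) ** 2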
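register_attention_control in ptp_utils.py patches the forward of every module named CrossAttention and routes each attention map through a controller; DummyController shows the minimal __call__ interface, and diffusion_step additionally expects a step_callback(latents) method. A hypothetical pass-through controller that records cross-attention maps might look like the sketch below; it relies on those interface assumptions and is not repository code:

import torch

class AttentionStore:
    """Pass-through controller that keeps the cross-attention maps it sees."""

    def __init__(self):
        self.num_att_layers = 0  # overwritten by register_attention_control
        self.maps = {"down": [], "mid": [], "up": []}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        # called from the patched forward; must return the (possibly edited) attention maps
        if is_cross:
            self.maps[place_in_unet].append(attn.detach().cpu())
        return attn

    def step_callback(self, latents):
        # called once per timestep by diffusion_step; pass latents through unchanged
        return latents

# usage, assuming `pipe` is a StableDiffusionPipeline whose UNet still exposes
# CrossAttention modules (older diffusers versions name them that way):
# ptp_utils.register_attention_control(pipe, AttentionStore())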