├── GPT_prompt_helper.ipynb
├── LICENSE
├── README.md
├── SD1-sliders-inference.ipynb
├── XL-sliders-inference.ipynb
├── __init__.py
├── demo_SDXL_Turbo.ipynb
├── demo_concept_sliders.ipynb
├── demo_image_editing.ipynb
├── eval-scripts
│   ├── .ipynb_checkpoints
│   │   └── generate_images_xl-checkpoint.py
│   ├── clip_score.py
│   ├── generate_images-uce.py
│   ├── generate_images_customdiffusion.py
│   ├── generate_images_sd1.py
│   ├── generate_images_textinversion.py
│   ├── generate_images_textinversion_xl.py
│   ├── generate_images_xl.py
│   └── lpip_score.py
├── flux-sliders
│   ├── flux-requirements.txt
│   ├── train-flux-concept-sliders.ipynb
│   └── utils
│       ├── custom_flux_pipeline.py
│       ├── lora.py
│       ├── model_util.py
│       ├── prompt_util.py
│       ├── ptp_utils.py
│       └── train_util.py
├── images
│   └── main_figure.png
├── prompts
│   ├── prompts-car.csv
│   ├── prompts-food.csv
│   ├── prompts-person.csv
│   ├── prompts-room.csv
│   └── prompts-sky.csv
├── requirements.txt
└── trainscripts
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-310.pyc
    │   └── __init__.cpython-39.pyc
    ├── imagesliders
    │   ├── __pycache__
    │   │   ├── config_util.cpython-39.pyc
    │   │   ├── debug_util.cpython-39.pyc
    │   │   ├── lora.cpython-39.pyc
    │   │   ├── model_util.cpython-39.pyc
    │   │   ├── prompt_util.cpython-39.pyc
    │   │   └── train_util.cpython-39.pyc
    │   ├── config_util.py
    │   ├── data
    │   │   ├── config-xl.yaml
    │   │   ├── config.yaml
    │   │   ├── prompts-xl.yaml
    │   │   └── prompts.yaml
    │   ├── debug_util.py
    │   ├── lora.py
    │   ├── model_util.py
    │   ├── prompt_util.py
    │   ├── train_lora-scale-xl.py
    │   ├── train_lora-scale.py
    │   └── train_util.py
    └── textsliders
        ├── .ipynb_checkpoints
        │   ├── train_lora-checkpoint.py
        │   └── train_lora_xl-checkpoint.py
        ├── __init__.py
        ├── __pycache__
        │   ├── __init__.cpython-310.pyc
        │   ├── __init__.cpython-39.pyc
        │   ├── config_util.cpython-39.pyc
        │   ├── debug_util.cpython-39.pyc
        │   ├── lora.cpython-310.pyc
        │   ├── lora.cpython-39.pyc
        │   ├── model_util.cpython-39.pyc
        │   ├── prompt_util.cpython-39.pyc
        │   ├── ptp_utils.cpython-39.pyc
        │   └── train_util.cpython-39.pyc
        ├── config_util.py
        ├── data
        │   ├── .ipynb_checkpoints
        │   │   ├── prompts-person_age_slider_GPT-checkpoint.yaml
        │   │   ├── prompts-smile_slider_GPT-checkpoint.yaml
        │   │   └── prompts-xl-checkpoint.yaml
        │   ├── config-xl.yaml
        │   ├── config.yaml
        │   ├── prompts-animated_eyes_GPT.yaml
        │   ├── prompts-car_alienTechFuturistic_GPT.yaml
        │   ├── prompts-jewelry_diamonds_GPT.yaml
        │   ├── prompts-person_age_slider_GPT.yaml
        │   ├── prompts-person_surprised_GPT.yaml
        │   ├── prompts-smile_slider_GPT.yaml
        │   ├── prompts-xl.yaml
        │   └── prompts.yaml
        ├── debug_util.py
        ├── flush.py
        ├── generate_images_xl.py
        ├── lora.py
        ├── model_util.py
        ├── prompt_util.py
        ├── ptp_utils.py
        ├── train_lora.py
        ├── train_lora_xl.py
        └── train_util.py
/GPT_prompt_helper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5f21e08c-259d-4099-b1ed-b782bf94be05",
6 | "metadata": {},
7 | "source": [
8 | "Replace `key` with your openai key (do not use quotations)"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 5,
14 | "id": "7f45bebd-fe6e-41a8-a6e1-bf56d765463e",
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "env: OPENAI_API_KEY=key\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "%env OPENAI_API_KEY=key"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "id": "7784d6c8-e28a-494b-b97f-374fcc94213f",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "import os\n",
37 | "import yaml\n",
38 | "from openai import OpenAI\n",
39 | "client = OpenAI()\n"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "id": "6de55bbf-b978-478c-a8f0-db412bc52e9e",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "def generate_prompts_sliders(slider_query, \n",
50 | " file_name_to_save=None, \n",
51 | " temperature=0.2, \n",
52 | " max_tokens=256, \n",
53 | " frequency_penalty=0.0,\n",
54 | " model=\"gpt-4-turbo-preview\",\n",
55 | " verbose=False, \n",
56 | " save=True):\n",
57 | " '''\n",
58 | " A function to automatically build prompts for text sliders using GPT-4 (or any other OpenAI model). \n",
59 | " \n",
60 | " Inputs\n",
61 | " ------\n",
62 | " slider_query (str): A natural language query describing the slider effect the user desires (e.g., \"I want to make people older\")\n",
63 | " file_name_to_save (str) (optional): a full name of the yaml file a user desires. If left as None, a name will be chosen by GPT\n",
64 | " temperature (float) (optional): GPT temperature parameter (use smaller values for less randomness)\n",
65 | " max_tokens (int) (optional): GPT output token limit\n",
66 | " frequency_penalty (float) (optional): GPT frequency penalty\n",
67 | " model (str) (optional): The model class from openAI. By default uses GPT-4-Turbo\n",
68 | " verbose (bool) (optional): A flag to print intermediate responses by GPT\n",
69 | " save (bool) (optional): A flag to save the prompts to a destination path\n",
70 | " '''\n",
71 | " gpt_assistant_prompt = '''You are an expert in prompting text-image generation models. Given a concept to edit, your task is to generate 4 detailed prompts.\n",
72 | " 1. Target prompt: a prompt that describes the target class which the concept edit is intended to modify (for example, to edit the concept \"professional\" the target concept is \"person\". Leave it empty if the target concept is too large. For example if user asks for their generations to be more futuristic, since all the images have to be edited, just leave the target \"\"\n",
73 | " 2. Positive prompt: a detailed prompt that describes the extreme positive end of the edit concept with the target concept included (for example, \"person, very professional, blazer, neat, organized)\"\n",
74 | " 3. Negative prompt: a detailed prompt that describes the extreme negative end of the edit concept with the target concept included (for example, \"person, non-professional, raggedy, unkempt\"). This is optional, you can leave it \"\" if there is no obvious negative prompt.\n",
75 | " 4. Preservation prompt: a prompt (must be comma separated) that describes any concepts except the ones to edit that should be preserved when making the edit without the target concept included (for example, \"white, black, indian, asian, hispanic; male, female\" as the race or gender of a person may be changed when we edit the professionalism.). This should not include edit concepts and should not include any of the positive or negative concepts. if there are no obvious entanglement issues with the edit, leave the prompt \"\"\n",
76 | " make the preservation prompt comma separated for each class of preservation. For example, if you want to preserve both race and gender, give something like \"white race, black race, indian race, asian race; male, female\"\n",
77 | "\n",
78 | " All the prompts must be strictly string type. Be specific. Do not use any alphanumeric symbols.\n",
79 | " \n",
80 | " This is an example template for your response when asked to generate prompts for making people smile:\n",
81 | " Target: person\n",
82 | " Positive: person, smiling, happy face, big smile\n",
83 | " Negative: person, frowning, grumpy, sad\n",
84 | " Preservation: white, black, indian, asian, hispanic ; male, female\n",
85 | " Name: person_smile_GPT\n",
86 | " \n",
87 | " Here is another example template for your response when asked - \"I want to make images more detailed\":\n",
88 | " Target: \n",
89 | " Positive: highly detailed, intricate patterns, fine textures, realistic shading\n",
90 | " Negative: simplistic, minimalistic, abstract, rough outlines\n",
91 | " Preservation: \n",
92 | " Name: detailed_GPT\n",
93 | " '''\n",
94 | " gpt_user_prompt = slider_query\n",
95 | " gpt_prompt = gpt_assistant_prompt, gpt_user_prompt\n",
96 | " message=[{\"role\": \"assistant\", \"content\": gpt_assistant_prompt}, {\"role\": \"user\", \"content\": gpt_user_prompt}]\n",
97 | " \n",
98 | " response = client.chat.completions.create(\n",
99 | " model= model,\n",
100 | " messages = message,\n",
101 | " temperature=temperature,\n",
102 | " max_tokens=max_tokens,\n",
103 | " frequency_penalty=frequency_penalty\n",
104 | " )\n",
105 | " content = response.choices[0].message.content\n",
106 | " if verbose:\n",
107 | " print(content)\n",
108 | " prompts = content.splitlines()\n",
109 | " result = {}\n",
110 | " result['target'] = \"\"\n",
111 | " result['positive'] = \"\"\n",
112 | " result['unconditional'] = \"\"\n",
113 | " result['neutral'] = \"\"\n",
114 | " for prompt in prompts:\n",
115 | " key = prompt.split(':')\n",
116 | " if key[0].lower().strip() == 'preservation':\n",
117 | " final_attributes = []\n",
118 | " attributes = key[1].split(';')\n",
119 | " for attribute in attributes:\n",
120 | " if len(attribute.strip()) == 0:\n",
121 | " continue\n",
122 | " final_attributes.append(attribute.strip().split(','))\n",
123 | " elif key[0].lower().strip() == 'name':\n",
124 | " name = key[1].strip()\n",
125 | " for prompt in prompts:\n",
126 | " key = prompt.split(':')\n",
127 | " if len(key)!=2:\n",
128 | " continue\n",
129 | " if key[0].lower().strip() == 'target':\n",
130 | " result['target'] = key[1].strip()\n",
131 | " elif key[0].lower().strip() == 'positive':\n",
132 | " result['positive'] = key[1].strip()\n",
133 | " elif key[0].lower().strip() == 'negative':\n",
134 | " result['unconditional'] = key[1].strip()\n",
135 | " result['neutral'] = result['target']\n",
136 | " results = [result]\n",
137 | " \n",
138 | " for attribute_class in final_attributes:\n",
139 | " results_final = []\n",
140 | " for attribute in attribute_class:\n",
141 | " for result in results:\n",
142 | " r = {}\n",
143 | " for key in result.keys():\n",
144 | " r[key] = attribute.strip() + f' {result[key].strip()}'\n",
145 | " r[key] = r[key].strip()\n",
146 | " results_final.append(r)\n",
147 | " \n",
148 | " results = results_final\n",
149 | " results_final = []\n",
150 | " for result in results:\n",
151 | " r_final = result\n",
152 | " r_final['guidance'] = 4\n",
153 | " r_final['rank'] = 4\n",
154 | " r_final['action'] = 'enhance'\n",
155 | " r_final['resolution'] = 512\n",
156 | " r_final['dynamic_resolution'] = False\n",
157 | " r_final['batch_size'] = 1\n",
158 | " results_final.append(r_final)\n",
159 | " if file_name_to_save is None:\n",
160 | " if name is None:\n",
161 | " file_name_to_save = 'custom-prompts-GPT.yaml'\n",
162 | " else:\n",
163 | " file_name_to_save = f'prompts-{name}.yaml'\n",
164 | " if save:\n",
165 | " with open(f'trainscripts/textsliders/data/{file_name_to_save}', 'w+') as f:\n",
166 | " yaml.dump(results_final, f, allow_unicode=True, sort_keys=False)\n",
167 | " if verbose:\n",
168 | " print(f'Prompt file saved to: \"trainscripts/textsliders/data/{file_name_to_save}\"')\n",
169 | " return f'trainscripts/textsliders/data/{file_name_to_save}'"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 4,
175 | "id": "eae07933-c3d7-43d1-99ba-b87b749ec1fd",
176 | "metadata": {
177 | "scrolled": true
178 | },
179 | "outputs": [
180 | {
181 | "name": "stdout",
182 | "output_type": "stream",
183 | "text": [
184 | "Target: person\n",
185 | "Positive: person, aged, wrinkles, grey hair, elderly features\n",
186 | "Negative: person, youthful, smooth skin, vibrant, young\n",
187 | "Preservation: white, black, indian, asian, hispanic ; male, female\n",
188 | "Name: person_age_GPT\n",
189 | "Prompt file saved to: \"trainscripts/textsliders/data/prompts-person_age_GPT.yaml\"\n"
190 | ]
191 | },
192 | {
193 | "data": {
194 | "text/plain": [
195 | "'trainscripts/textsliders/data/prompts-person_age_GPT.yaml'"
196 | ]
197 | },
198 | "execution_count": 4,
199 | "metadata": {},
200 | "output_type": "execute_result"
201 | }
202 | ],
203 | "source": [
204 | "query = \"I want to build a slider to make people old\"\n",
205 | "generate_prompts_sliders(slider_query=query, model=\"gpt-4-turbo-preview\", save=True, verbose=True)"
206 | ]
207 | }
208 | ],
209 | "metadata": {
210 | "kernelspec": {
211 | "display_name": "Python 3 (ipykernel)",
212 | "language": "python",
213 | "name": "python3"
214 | },
215 | "language_info": {
216 | "codemirror_mode": {
217 | "name": "ipython",
218 | "version": 3
219 | },
220 | "file_extension": ".py",
221 | "mimetype": "text/x-python",
222 | "name": "python",
223 | "nbconvert_exporter": "python",
224 | "pygments_lexer": "ipython3",
225 | "version": "3.9.18"
226 | }
227 | },
228 | "nbformat": 4,
229 | "nbformat_minor": 5
230 | }
231 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Rohit Gandikota
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Concept Sliders
2 | ### [Project Website](https://sliders.baulab.info) | [Arxiv Preprint](https://arxiv.org/pdf/2311.12092.pdf) | [Trained Sliders](https://sliders.baulab.info/weights/xl_sliders/) | [Colab Demo](https://colab.research.google.com/github/rohitgandikota/sliders/blob/main/demo_concept_sliders.ipynb) | [Huggingface Demo](https://huggingface.co/spaces/baulab/ConceptSliders)
3 | Official code implementation of "Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models", European Conference on Computer Vision (ECCV 2024).
4 |
5 |
6 |
6 | ![Concept Sliders](images/main_figure.png)
7 |
8 |
9 | ## Colab Demo
10 | Try out our colab demo here: [Open in Colab](https://colab.research.google.com/github/rohitgandikota/sliders/blob/main/demo_concept_sliders.ipynb)
11 |
12 | ## FLUX Support 🚀🚀🚀
13 | You can train sliders for FLUX.1 models. Right now this support is experimental, so please be patient if it does not work as well as SDXL; FLUX is not designed the same way as SDXL.
14 |
15 | To play with FLUX sliders, you need to update your packages:
16 | ```
17 | pip install -r flux-sliders/flux-requirements.txt
18 | ```
19 |
20 | Now just open the notebook in the `flux-sliders` folder and have fun!
21 |
22 | ## UPDATE
23 | You can now use GPT-4 (or any other OpenAI model) to create prompts for your text sliders. All you need to do is describe the slider you want to create (e.g., "I want to make people look happy").
24 | Please refer to the [GPT-notebook](https://github.com/rohitgandikota/sliders/blob/main/GPT_prompt_helper.ipynb)
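
If you prefer to see this in code, here is roughly how the helper defined in that notebook is called (run it inside the notebook after the function cell; it assumes `OPENAI_API_KEY` is set and that you launch from the repo root so the YAML lands under `trainscripts/textsliders/data/`):
```
# Minimal sketch: call the helper defined in GPT_prompt_helper.ipynb.
# Assumes OPENAI_API_KEY is exported and generate_prompts_sliders is defined in the session.
yaml_path = generate_prompts_sliders(
    slider_query="I want to make people look happy",
    model="gpt-4-turbo-preview",
    save=True,
    verbose=True,
)
print(yaml_path)  # e.g. trainscripts/textsliders/data/prompts-<name>.yaml
```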
25 |
26 | ## Setup
27 | To set up your python environment:
28 | ```
29 | conda create -n sliders python=3.9
30 | conda activate sliders
31 |
32 | git clone https://github.com/rohitgandikota/sliders.git
33 | cd sliders
34 | pip install -r requirements.txt
35 | ```
36 | If you are running on Windows, please refer to the Windows setup guidelines [here](https://github.com/rohitgandikota/sliders/issues/27#issuecomment-1833572579).
37 | ## Textual Concept Sliders
38 | ### Training SD-1.x and SD-2.x LoRa
39 | To train an age slider, go to `trainscripts/textsliders/data/prompts.yaml` and set `target=person`, `positive=old person`, `unconditional=young person` (the opposite of positive), `neutral=person`, and `action=enhance` with `guidance=4`.
40 | If you do not want your edit targeted at people, replace `person` with any target you want (e.g. dog); if you want the edit to be global, replace `person` with `""`.
41 | Finally, run the command:
42 | ```
43 | python trainscripts/textsliders/train_lora.py --attributes 'male, female' --name 'ageslider' --rank 4 --alpha 1 --config_file 'trainscripts/textsliders/data/config.yaml'
44 | ```
45 |
46 | The `--attributes` argument is used to disentangle concepts from the slider. For instance, an age slider can end up making all old people male; adding the `'male, female'` attributes helps disentangle gender from age.
47 |
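For reference, each entry in `prompts.yaml` uses the same schema that `GPT_prompt_helper.ipynb` writes out. A minimal sketch that builds the age-slider entry above programmatically (the values are illustrative, and this overwrites the default `prompts.yaml`):
```
# Sketch of a text-slider prompt entry, mirroring the fields written by GPT_prompt_helper.ipynb.
import yaml

entry = {
    "target": "person",               # class the edit is aimed at
    "positive": "old person",         # extreme positive end of the slider
    "unconditional": "young person",  # opposite end of the slider
    "neutral": "person",              # anchor prompt
    "guidance": 4,
    "rank": 4,
    "action": "enhance",
    "resolution": 512,
    "dynamic_resolution": False,
    "batch_size": 1,
}
with open("trainscripts/textsliders/data/prompts.yaml", "w") as f:
    yaml.dump([entry], f, sort_keys=False)
```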
48 |
49 | #### Evaluate
50 | To evaluate your trained models use the notebook `SD1-sliders-inference.ipynb`
51 |
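If you would rather script this than use the notebook, here is a rough sketch of the inference loop, condensed from `eval-scripts/generate_images_sd1.py` (run from the repo root; the checkpoint path, prompt, and slider scale are placeholders, `train_method` must match how the slider was trained, and the eval script additionally keeps the slider off for the earliest timesteps via `--start_noise`):
```
# Rough sketch of slider inference for SD-1.x, condensed from eval-scripts/generate_images_sd1.py.
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from trainscripts.textsliders.lora import LoRANetwork

device, dtype = "cuda:0", torch.float16
base = "CompVis/stable-diffusion-v1-4"
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder").to(device, dtype=dtype)
vae = AutoencoderKL.from_pretrained(base, subfolder="vae").to(device, dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet").to(device, dtype=dtype)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear", num_train_timesteps=1000)

# Wrap the UNet with the slider LoRA and pick a slider strength (negative flips the edit direction).
network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1, train_method="xattn").to(device, dtype=dtype)
network.load_state_dict(torch.load("models/ageslider_last.pt"))  # placeholder checkpoint path
network.set_lora_slider(scale=2)

prompt, seed, steps, guidance = "picture of a person", 42, 50, 7.5
text_in = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length,
                    truncation=True, return_tensors="pt")
cond = text_encoder(text_in.input_ids.to(device))[0]
uncond_in = tokenizer([""], padding="max_length", max_length=text_in.input_ids.shape[-1], return_tensors="pt")
uncond = text_encoder(uncond_in.input_ids.to(device))[0]
emb = torch.cat([uncond, cond])

generator = torch.manual_seed(seed)
latents = torch.randn((1, 4, 64, 64), generator=generator).to(device, dtype=dtype)  # 512x512 -> 64x64 latents
scheduler.set_timesteps(steps)
latents = latents * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    latent_in = scheduler.scale_model_input(torch.cat([latents] * 2), timestep=t)
    with network, torch.no_grad():  # LoRA slider is active inside the context, as in the eval script
        noise_pred = unet(latent_in, t, encoder_hidden_states=emb).sample
    n_uncond, n_cond = noise_pred.chunk(2)
    noise_pred = n_uncond + guidance * (n_cond - n_uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample

with torch.no_grad():
    image = vae.decode(latents / 0.18215).sample  # un-scale and decode
image = ((image / 2 + 0.5).clamp(0, 1) * 255).round()
Image.fromarray(image[0].permute(1, 2, 0).byte().cpu().numpy()).save("slider_sample.png")
```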
52 |
53 | ### Training SD-XL
54 | To train sliders for SD-XL, use the script `train_lora_xl.py`. The setup is the same as for SDv1.4:
55 |
56 | ```
57 | python trainscripts/textsliders/train_lora_xl.py --attributes 'male, female' --name 'agesliderXL' --rank 4 --alpha 1 --config_file 'trainscripts/textsliders/data/config-xl.yaml'
58 | ```
59 |
60 | #### Evaluate
61 | To evaluate your trained models use the notebook `XL-sliders-inference.ipynb`
62 |
63 |
64 | ## Visual Concept Sliders
65 | ### Training SD-1.x and SD-2.x LoRa
66 | To train image-based sliders, you need to create a small paired image dataset (roughly 4-6 before/after pairs for the desired concept). Save the before images and after images in separate folders. You can also create a dataset with varying edit intensities and save each intensity level separately.
67 |
68 | To train an image slider for eye size, go to `trainscripts/imagesliders/data/config.yaml` and set `target=eye`, `positive='eye'`, `unconditional=''`, `neutral=eye`, and `action=enhance` with `guidance=4`.
69 | If you want the diffusion model to figure out the edit concept on its own, leave `target`, `positive`, `unconditional`, and `neutral` as `''`.
70 | Finally, run the command:
71 | ```
72 | python trainscripts/imagesliders/train_lora-scale.py --name 'eyeslider' --rank 4 --alpha 1 --config_file 'trainscripts/imagesliders/data/config.yaml' --folder_main 'datasets/eyesize/' --folders 'bigsize, smallsize' --scales '1, -1'
73 | ```
74 | For this to work, store your before images in `smallsize` and your after images in `bigsize`. The corresponding paired files in both folders should have the same names, and both subfolders should be under `datasets/eyesize`. Feel free to build your own datasets with your own naming conventions.
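
Before training, it can help to sanity-check that the two folders really are paired. A small sketch, assuming the `datasets/eyesize` layout from the command above:
```
# Check that every image in smallsize/ has a same-named counterpart in bigsize/ (and vice versa).
import os

root = "datasets/eyesize"
before = set(os.listdir(os.path.join(root, "smallsize")))
after = set(os.listdir(os.path.join(root, "bigsize")))
unpaired = before ^ after  # files present in only one of the two folders
print(f"{len(before & after)} pairs found; unpaired files: {sorted(unpaired) or 'none'}")
```
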
75 | ### Training SD-XL
76 | To train image sliders for SD-XL, use the script `train_lora-scale-xl.py`. The setup is the same as for SDv1.4:
77 |
78 | ```
79 | python trainscripts/imagesliders/train_lora-scale-xl.py --name 'eyesliderXL' --rank 4 --alpha 1 --config_file 'trainscripts/imagesliders/data/config-xl.yaml' --folder_main '/share/u/rohit/imageXLdataset/eyesize_data/'
80 | ```
81 |
82 | ## Editing Real Images
83 | Concept sliders can be used to edit real images. We use null-text inversion to invert the image, and then edit with sliders instead of prompts!
84 | Check out `demo_image_editing.ipynb` for more details.
85 |
86 | ## Running Gradio Demo Locally
87 | You can also run the HF-hosted gradio slider tool (huge shoutout to the Gradio and HF teams) locally using the following commands:
88 | ```
89 | git lfs install
90 | git clone https://huggingface.co/spaces/baulab/ConceptSliders
91 | cd ConceptSliders
92 | pip install -r requirements.txt
93 | python app.py
94 | ```
95 | For more inference-time gradio demos, please refer to camenduru's repo [here](https://github.com/camenduru/sliders-colab)
96 |
97 | ## Running with ControlNet Integration
98 | Our user community is amazing! Here is the resource that integrates ControlNet: https://github.com/rohitgandikota/sliders/issues/76#issuecomment-2099766893
99 | ## Citing our work
100 | The preprint can be cited as follows
101 | ```
102 | @inproceedings{gandikota2024sliders,
103 | title={Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models},
104 | author={Rohit Gandikota and Joanna Materzy\'nska and Tingrui Zhou and Antonio Torralba and David Bau},
105 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
106 | note={arXiv preprint arXiv:2311.12092},
107 | year={2024}
108 | }
109 | ```
110 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from trainscripts.textsliders import lora
--------------------------------------------------------------------------------
/eval-scripts/clip_score.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import requests
3 | import os, glob
4 | import pandas as pd
5 | import numpy as np
6 | import re
7 | import argparse
8 | from transformers import CLIPProcessor, CLIPModel
9 |
10 |
11 | if __name__=='__main__':
12 | parser = argparse.ArgumentParser(
13 | prog = 'clipScore',
14 | description = 'Generate CLIP score for images')
15 | parser.add_argument('--im_path', help='path for images', type=str, required=True)
16 | parser.add_argument('--prompt', help='prompt to check clip score against', type=str, required=True)
17 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
18 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
19 | parser.add_argument('--till_case', help='generate images until this case_number', type=int, required=False, default=1000000)
20 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
21 |
22 | args = parser.parse_args()
23 |
24 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
25 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
26 |
27 | def sorted_nicely( l ):
28 | convert = lambda text: int(text) if text.isdigit() else text
29 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
30 | return sorted(l, key = alphanum_key)
31 |
32 |
33 | path = args.im_path
34 | model_names = os.listdir(path)
35 | model_names = [m for m in model_names if 'all' not in m and '.csv' not in m]
36 | csv_path = args.prompts_path
37 | save_path = ''
38 | prompt = args.prompt.strip()
39 | print(f'Eval against prompt: {prompt}')
40 | model_names.sort()
41 | print(model_names)
42 | df = pd.read_csv(csv_path)
43 | for model_name in model_names:
44 | print(model_name)
45 |
46 | im_folder = os.path.join(path, model_name)
47 |
48 | images = os.listdir(im_folder)
49 | images = sorted_nicely(images)
50 | ratios = {}
51 | model_name = model_name.replace('half','0.5')
52 | df[f'clip_{model_name}'] = np.nan
53 | for image in images:
54 | try:
55 | case_number = int(image.split('_')[0].replace('.png',''))
56 | if case_number not in list(df['case_number']):
57 | continue
58 | im = Image.open(os.path.join(im_folder, image))
59 | inputs = processor(text=[prompt], images=im, return_tensors="pt", padding=True)
60 | outputs = model(**inputs)
61 | clip_score = outputs.logits_per_image[0][0].detach().cpu() # this is the image-text similarity score
62 | ratios[case_number] = ratios.get(case_number, []) + [clip_score]
63 | # print(image, clip_score)
64 | except:
65 | pass
66 | for key in ratios.keys():
67 | df.loc[key,f'clip_{model_name}'] = np.mean(ratios[key])
68 | # df = df.dropna(axis=0)
69 | print(f"Mean CLIP score: {df[f'clip_{model_name}'].mean()}")
70 | print('-------------------------------------------------')
71 | print('\n')
72 | df.to_csv(f'{path}/clip_scores.csv', index=False)
73 |
--------------------------------------------------------------------------------
/eval-scripts/generate_images-uce.py:
--------------------------------------------------------------------------------
1 | import glob, re
2 | import torch
3 | from diffusers import DiffusionPipeline
4 | from safetensors.torch import load_file
5 | import matplotlib.pyplot as plt
6 | import matplotlib.image as mpimg
7 | import random
8 | import copy
9 | import gc
10 | from tqdm.auto import tqdm
11 | import random
12 | from PIL import Image
13 | from transformers import CLIPTextModel, CLIPTokenizer
14 |
15 | import diffusers
16 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
17 | from diffusers.loaders import AttnProcsLayers
18 | from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
19 | import torch
20 | from typing import Any, Dict, List, Optional, Tuple, Union
21 | from transformers import CLIPTextModel, CLIPTokenizer
22 | import os
23 | import sys
24 | import argparse
25 | import pandas as pd
26 | sys.path.insert(1, os.getcwd())
27 | from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
28 |
29 | def sorted_nicely( l ):
30 | convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text
31 | alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ]
32 | return sorted(l, key = alphanum_key)
33 |
34 | def flush():
35 | torch.cuda.empty_cache()
36 | gc.collect()
37 | flush()
38 |
39 | from transformers import CLIPTextModel, CLIPTokenizer
40 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
41 | from diffusers import LMSDiscreteScheduler
42 | import torch
43 | from PIL import Image
44 | import argparse
45 | import os, json, random
46 | import pandas as pd
47 | def image_grid(imgs, rows=2, cols=2):
48 | assert len(imgs) == rows*cols
49 |
50 | w, h = imgs[0].size
51 | grid = Image.new('RGB', size=(cols*w, rows*h))
52 | grid_w, grid_h = grid.size
53 |
54 | for i, img in enumerate(imgs):
55 | grid.paste(img, box=(i%cols*w, i//cols*h))
56 | return grid
57 | def generate_images(unet, new_state_dict, vae, tokenizer, text_encoder, prompt, evaluation_seed=3124, model_path=None, device='cuda:0', guidance_scale = 7.5, image_size=512, ddim_steps=50, num_samples=4, from_case=0, till_case=1000000, base='1.4', start_noise=800):
58 |
59 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
60 |
61 | vae.to(device)
62 | text_encoder.to(device)
63 | unet.to(device)
64 | torch_device = device
65 | prompt = [str(prompt)]*num_samples
66 | seed = evaluation_seed
67 |
68 | height = image_size # default height of Stable Diffusion
69 | width = image_size # default width of Stable Diffusion
70 |
71 | num_inference_steps = ddim_steps # Number of denoising steps
72 |
73 | guidance_scale = guidance_scale # Scale for classifier-free guidance
74 |
75 | generator = torch.manual_seed(seed) # Seed generator to create the initial latent noise
76 |
77 | batch_size = len(prompt)
78 |
79 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
80 |
81 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
82 |
83 | max_length = text_input.input_ids.shape[-1]
84 | uncond_input = tokenizer(
85 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
86 | )
87 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
88 |
89 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
90 |
91 | latents = torch.randn(
92 | (batch_size, unet.in_channels, height // 8, width // 8),
93 | generator=generator,
94 | )
95 | latents = latents.to(torch_device)
96 |
97 | scheduler.set_timesteps(num_inference_steps)
98 |
99 | latents = latents * scheduler.init_noise_sigma
100 |
101 | from tqdm.auto import tqdm
102 |
103 | scheduler.set_timesteps(num_inference_steps)
104 | flag = False
105 | for t in tqdm(scheduler.timesteps):
106 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
107 | if t<=start_noise:
108 | if not flag:
109 | flag = True
110 | unet.load_state_dict(new_state_dict)
111 | latent_model_input = torch.cat([latents] * 2)
112 |
113 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
114 |
115 | # predict the noise residual
116 | with torch.no_grad():
117 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
118 |
119 | # perform guidance
120 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
121 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
122 |
123 | # compute the previous noisy sample x_t -> x_t-1
124 | latents = scheduler.step(noise_pred, t, latents).prev_sample
125 |
126 | # scale and decode the image latents with vae
127 | latents = 1 / 0.18215 * latents
128 | with torch.no_grad():
129 | image = vae.decode(latents).sample
130 |
131 | image = (image / 2 + 0.5).clamp(0, 1)
132 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
133 | images = (image * 255).round().astype("uint8")
134 | pil_images = [Image.fromarray(image) for image in images]
135 | return pil_images
136 |
137 |
138 | def generate_images_(model_name, prompts_path, save_path, negative_prompt, device, guidance_scale , image_size, ddim_steps, num_samples,from_case, till_case, base, rank, start_noise):
139 | # Load scheduler, tokenizer and models.
140 | scales = [-2, -1, -.5, 0, .5, 1, 2]
141 |
142 | revision = None
143 | pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
144 | weight_dtype = torch.float32
145 |
146 | # Load scheduler, tokenizer and models.
147 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
148 | tokenizer = CLIPTokenizer.from_pretrained(
149 | pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
150 | )
151 | text_encoder = CLIPTextModel.from_pretrained(
152 | pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
153 | )
154 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
155 | unet = UNet2DConditionModel.from_pretrained(
156 | pretrained_model_name_or_path, subfolder="unet", revision=revision
157 | )
158 | # freeze parameters of models to save more memory
159 | unet.requires_grad_(False)
160 | unet.to(device, dtype=weight_dtype)
161 | vae.requires_grad_(False)
162 |
163 | text_encoder.requires_grad_(False)
164 |
165 | vae.to(device, dtype=weight_dtype)
166 | text_encoder.to(device, dtype=weight_dtype)
167 |
168 | df = pd.read_csv(prompts_path)
169 | prompts = df.prompt
170 | seeds = df.evaluation_seed
171 | case_numbers = df.case_number
172 |
173 | name = os.path.basename(model_name)
174 | folder_path = f'{save_path}/{name}'
175 | os.makedirs(folder_path, exist_ok=True)
176 | os.makedirs(folder_path+f'/all', exist_ok=True)
177 | scales_str = []
178 | for scale in scales:
179 | scale_str = f'{scale}'
180 | scale_str = scale_str.replace('0.5','half')
181 | scales_str.append(scale_str)
182 | os.makedirs(folder_path+f'/{scale_str}', exist_ok=True)
183 | height = image_size # default height of Stable Diffusion
184 | width = image_size # default width of Stable Diffusion
185 |
186 | num_inference_steps = ddim_steps # Number of denoising steps
187 |
188 | guidance_scale = guidance_scale # Scale for classifier-free guidance
189 | torch_device = device
190 |
191 | model_version = "CompVis/stable-diffusion-v1-4"
192 |
193 | unet = UNet2DConditionModel.from_pretrained(model_version, subfolder="unet")
194 | old_state_dict = copy.deepcopy(unet.state_dict())
195 | model_path = model_name
196 | new_state_dict_ = copy.deepcopy(torch.load(model_path, map_location='cpu'))
197 | delta_dict = {}
198 | for key, value in old_state_dict.items():
199 | delta_dict[key] = new_state_dict_[key] - value
200 | del new_state_dict_
201 | for _, row in df.iterrows():
202 | prompt = str(row.prompt)
203 | seed = row.evaluation_seed
204 | case_number = row.case_number
205 | if not (case_number>=from_case and case_number<=till_case):
206 | continue
207 | images_list = []
208 | for scale in scales:
209 |
210 | new_state_dict = {}
211 | for key, value in old_state_dict.items():
212 | new_state_dict[key] = value + scale * delta_dict[key]
213 |
214 |
215 | im = generate_images(unet=unet, new_state_dict=new_state_dict, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, prompt=prompt, evaluation_seed=seed,num_samples=num_samples, start_noise=start_noise, device=device)
216 | images_list.append(im)
217 | del new_state_dict
218 | for num in range(num_samples):
219 | fig, ax = plt.subplots(1, len(images_list), figsize=(4*(len(scales)),4))
220 | for i, a in enumerate(ax):
221 | images_list[i][num].save(f'{folder_path}/{scales_str[i]}/{case_number}_{num}.png')
222 | a.imshow(images_list[i][num])
223 | a.set_title(f"{scales[i]}",fontsize=15)
224 | a.axis('off')
225 | fig.savefig(f'{folder_path}/all/{case_number}_{num}.png',bbox_inches='tight')
226 | plt.close()
227 | del unet
228 | flush()
229 | if __name__=='__main__':
230 | parser = argparse.ArgumentParser(
231 | prog = 'generateImages',
232 | description = 'Generate Images using Diffusers Code')
233 | parser.add_argument('--model_name', help='name of model', type=str, required=True)
234 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
235 | parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None)
236 | parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
237 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
238 | parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4')
239 | parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
240 | parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
241 | parser.add_argument('--till_case', help='generate images until this case_number', type=int, required=False, default=1000000)
242 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
243 | parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=5)
244 | parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
245 | parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4)
246 | parser.add_argument('--start_noise', help='timestep below which the finetuned weights are used', type=int, required=False, default=800)
247 |
248 | args = parser.parse_args()
249 |
250 | model_name = args.model_name
251 | rank = args.rank
252 | if 'rank1' in model_name:
253 | rank = 1
254 | prompts_path = args.prompts_path
255 | save_path = args.save_path
256 | device = args.device
257 | guidance_scale = args.guidance_scale
258 | image_size = args.image_size
259 | ddim_steps = args.ddim_steps
260 | num_samples= args.num_samples
261 | from_case = args.from_case
262 | till_case = args.till_case
263 | start_noise = args.start_noise
264 | base = args.base
265 | negative_prompts_path = args.negative_prompts
266 | if negative_prompts_path is not None:
267 | negative_prompt = ''
268 | with open(negative_prompts_path, 'r') as fp:
269 | vals = json.load(fp)
270 | for val in vals:
271 | negative_prompt+=val+' ,'
272 | print(f'Negative prompt is being used: {negative_prompt}')
273 | else:
274 | negative_prompt = None
275 | generate_images_(model_name=model_name, prompts_path=prompts_path, save_path=save_path, negative_prompt=negative_prompt, device=device, guidance_scale = guidance_scale, image_size=image_size, ddim_steps=ddim_steps, num_samples=num_samples,from_case=from_case, till_case=till_case, base=base, rank=rank, start_noise=start_noise)
--------------------------------------------------------------------------------
/eval-scripts/generate_images_sd1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import argparse
4 | import os, json, random
5 | import pandas as pd
6 | import matplotlib.pyplot as plt
7 | import glob, re,sys
8 | from tqdm.auto import tqdm
9 |
10 | from safetensors.torch import load_file
11 | import matplotlib.image as mpimg
12 | import copy
13 | import gc
14 | from transformers import CLIPTextModel, CLIPTokenizer
15 |
16 | import diffusers
17 | from diffusers import DiffusionPipeline
18 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
19 | from diffusers.loaders import AttnProcsLayers
20 | from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
21 | from typing import Any, Dict, List, Optional, Tuple, Union
22 | sys.path.insert(1, os.getcwd())
23 | from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
24 |
25 |
26 | def flush():
27 | torch.cuda.empty_cache()
28 | gc.collect()
29 | flush()
30 |
31 |
32 |
33 |
34 | def sorted_nicely( l ):
35 | convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text
36 | alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ]
37 | return sorted(l, key = alphanum_key)
38 |
39 | def flush():
40 | torch.cuda.empty_cache()
41 | gc.collect()
42 |
43 | def generate_images(model_name, prompts_path, save_path, negative_prompt, device, guidance_scale , image_size, ddim_steps, num_samples,from_case, till_case, base, rank, start_noise):
44 | # Load scheduler, tokenizer and models.
45 | scales = [-2, -1, 0, 1, 2]
46 | revision = None
47 | pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
48 | weight_dtype = torch.float16
49 |
50 | # Load scheduler, tokenizer and models.
51 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
52 | tokenizer = CLIPTokenizer.from_pretrained(
53 | pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
54 | )
55 | text_encoder = CLIPTextModel.from_pretrained(
56 | pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
57 | )
58 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
59 | unet = UNet2DConditionModel.from_pretrained(
60 | pretrained_model_name_or_path, subfolder="unet", revision=revision
61 | )
62 | # freeze parameters of models to save more memory
63 | unet.requires_grad_(False)
64 | unet.to(device, dtype=weight_dtype)
65 | vae.requires_grad_(False)
66 |
67 | text_encoder.requires_grad_(False)
68 |
69 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
70 | # as these weights are only used for inference, keeping weights in full precision is not required.
71 |
72 |
73 | # Move unet, vae and text_encoder to device and cast to weight_dtype
74 |
75 | vae.to(device, dtype=weight_dtype)
76 | text_encoder.to(device, dtype=weight_dtype)
77 |
78 | name = os.path.basename(model_name)
79 | alpha = 1
80 | train_method = 'xattn'
81 | n = model_name.split('/')[-2]
82 | if 'noxattn' in n:
83 | train_method = 'noxattn'
84 | if 'hspace' in n:
85 | train_method+='-hspace'
86 | scales = [-5, -2, -1, 0, 1, 2, 5]
87 | if 'last' in n:
88 | train_method+='-last'
89 | scales = [-5, -2, -1, 0, 1, 2, 5]
90 | network_type = "c3lier"
91 | if train_method == 'xattn':
92 | network_type = 'lierla'
93 |
94 | modules = DEFAULT_TARGET_REPLACE
95 | if network_type == "c3lier":
96 | modules += UNET_TARGET_REPLACE_MODULE_CONV
97 |
98 | network = LoRANetwork(
99 | unet,
100 | rank=rank,
101 | multiplier=1.0,
102 | alpha=alpha,
103 | train_method=train_method,
104 | ).to(device, dtype=weight_dtype)
105 |
106 | network.load_state_dict(torch.load(model_name))
107 |
108 | df = pd.read_csv(prompts_path)
109 | prompts = df.prompt
110 | seeds = df.evaluation_seed
111 | case_numbers = df.case_number
112 |
113 | folder_path = f'{save_path}/{name}'
114 | os.makedirs(folder_path, exist_ok=True)
115 | os.makedirs(folder_path+f'/all', exist_ok=True)
116 | scales_str = []
117 | for scale in scales:
118 | scale_str = f'{scale}'
119 | scale_str = scale_str.replace('0.5','half')
120 | scales_str.append(scale_str)
121 | os.makedirs(folder_path+f'/{scale_str}', exist_ok=True)
122 | height = image_size # default height of Stable Diffusion
123 | width = image_size # default width of Stable Diffusion
124 |
125 | num_inference_steps = ddim_steps # Number of denoising steps
126 |
127 | guidance_scale = guidance_scale # Scale for classifier-free guidance
128 | torch_device = device
129 | for _, row in df.iterrows():
130 | print(str(row.prompt),str(row.evaluation_seed))
131 | prompt = [str(row.prompt)]*num_samples
132 | batch_size = len(prompt)
133 | seed = row.evaluation_seed
134 | case_number = row.case_number
135 | if not (case_number>=from_case and case_number<=till_case):
136 | continue
137 | images_list = []
138 | for scale in scales:
139 | torch_device = device
140 | negative_prompt = None
141 | height = 512
142 | width = 512
143 | guidance_scale = 7.5
144 |
145 | generator = torch.manual_seed(seed)
146 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
147 |
148 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
149 |
150 | max_length = text_input.input_ids.shape[-1]
151 | if negative_prompt is None:
152 | uncond_input = tokenizer(
153 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
154 | )
155 | else:
156 | uncond_input = tokenizer(
157 | [negative_prompt] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
158 | )
159 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
160 |
161 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
162 |
163 | latents = torch.randn(
164 | (batch_size, unet.in_channels, height // 8, width // 8),
165 | generator=generator,
166 | )
167 | latents = latents.to(torch_device)
168 |
169 | noise_scheduler.set_timesteps(ddim_steps)
170 |
171 | latents = latents * noise_scheduler.init_noise_sigma
172 | latents = latents.to(weight_dtype)
173 | latent_model_input = torch.cat([latents] * 2)
174 | for t in tqdm(noise_scheduler.timesteps):
175 | if t>start_noise:
176 | network.set_lora_slider(scale=0)
177 | else:
178 | network.set_lora_slider(scale=scale)
179 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
180 | latent_model_input = torch.cat([latents] * 2)
181 |
182 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
183 | # predict the noise residual
184 | with network:
185 | with torch.no_grad():
186 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
187 | # perform guidance
188 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
189 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
190 |
191 | # compute the previous noisy sample x_t -> x_t-1
192 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
193 |
194 | # scale and decode the image latents with vae
195 | latents = 1 / 0.18215 * latents
196 | with torch.no_grad():
197 | image = vae.decode(latents).sample
198 | image = (image / 2 + 0.5).clamp(0, 1)
199 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
200 | images = (image * 255).round().astype("uint8")
201 | pil_images = [Image.fromarray(image) for image in images]
202 | images_list.append(pil_images)
203 | for num in range(num_samples):
204 | fig, ax = plt.subplots(1, len(images_list), figsize=(4*(len(scales)),4))
205 | for i, a in enumerate(ax):
206 | images_list[i][num].save(f'{folder_path}/{scales_str[i]}/{case_number}_{num}.png')
207 | a.imshow(images_list[i][num])
208 | a.set_title(f"{scales[i]}",fontsize=15)
209 | a.axis('off')
210 | fig.savefig(f'{folder_path}/all/{case_number}_{num}.png',bbox_inches='tight')
211 | plt.close()
212 | del network, unet
213 | flush()
214 | if __name__=='__main__':
215 | parser = argparse.ArgumentParser(
216 | prog = 'generateImages',
217 | description = 'Generate Images using Diffusers Code')
218 | parser.add_argument('--model_name', help='name of model', type=str, required=True)
219 | parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
220 | parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None)
221 | parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
222 | parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
223 | parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4')
224 | parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
225 | parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
226 | parser.add_argument('--till_case', help='generate images until this case_number', type=int, required=False, default=1000000)
227 | parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
228 | parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=2)
229 | parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
230 | parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4)
231 | parser.add_argument('--start_noise', help='what time stamp to flip to edited model', type=int, required=False, default=850)
232 |
233 | args = parser.parse_args()
234 |
235 | model_name = args.model_name
236 | rank = args.rank
237 | if 'rank1' in model_name:
238 | rank = 1
239 | prompts_path = args.prompts_path
240 | save_path = args.save_path
241 | device = args.device
242 | guidance_scale = args.guidance_scale
243 | image_size = args.image_size
244 | ddim_steps = args.ddim_steps
245 | num_samples= args.num_samples
246 | from_case = args.from_case
247 | till_case = args.till_case
248 | start_noise = args.start_noise
249 | base = args.base
250 | negative_prompts_path = args.negative_prompts
251 | if negative_prompts_path is not None:
252 | negative_prompt = ''
253 | with open(negative_prompts_path, 'r') as fp:
254 | vals = json.load(fp)
255 | for val in vals:
256 | negative_prompt+=val+' ,'
257 | print(f'Negative prompt is being used: {negative_prompt}')
258 | else:
259 | negative_prompt = None
260 | generate_images(model_name=model_name, prompts_path=prompts_path, save_path=save_path, negative_prompt=negative_prompt, device=device, guidance_scale = guidance_scale, image_size=image_size, ddim_steps=ddim_steps, num_samples=num_samples,from_case=from_case, till_case=till_case, base=base, rank=rank, start_noise=start_noise)
261 |
--------------------------------------------------------------------------------
/eval-scripts/generate_images_textinversion.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline
2 | import pandas as pd
3 | import torch
4 | import os
5 | import argparse
6 |
7 | if __name__=='__main__':
8 | parser = argparse.ArgumentParser(
9 | prog = 'Generate Text Inversion Images',)
10 |
11 | parser.add_argument('--model_name', help='path to custom model', type=str, required=True)
12 | parser.add_argument('--prompts_path', help='path to csv prompts', type=str, required=True)
13 | parser.add_argument('--token', help='textual inversion token to add to the prompts', type=str, required=True)
14 | args = parser.parse_args()
15 | model_id = args.model_name
16 | custom_token = args.token #''
17 |
18 |
19 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
20 |
21 | df = pd.read_csv(args.prompts_path)
22 |
23 | prompts = list(df.prompt)
24 | seeds = list(df.evaluation_seed)
25 | case_numbers = list(df.case_number)
26 | file = os.path.basename(model_id)
27 |
28 | os.makedirs(f'images/text_inversion/{file}/',exist_ok=True)
29 | for idx,prompt in enumerate(prompts):
30 |
31 | prompt += f" with {custom_token}"
32 | case_number = case_numbers[idx]
33 | generator = torch.manual_seed(seeds[idx])
34 | images = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=5).images
35 | for i, im in enumerate(images):
36 | im.save(f'images/text_inversion/{file}/{case_number}_{i}.png')
37 |
--------------------------------------------------------------------------------
/eval-scripts/generate_images_textinversion_xl.py:
--------------------------------------------------------------------------------
1 | from diffusers import DiffusionPipeline,DDPMScheduler
2 | import pandas as pd
3 | import os
4 | import glob
5 | import torch
6 | import random
7 |
8 |
9 | def load_XLembedding(base,token="my",embedding_file="myToken.pt",path="./Embeddings/"):
10 | emb=torch.load(path+embedding_file)
11 | set_XLembedding(base,emb,token)
12 |
13 | def set_XLembedding(base,emb,token="my"):
14 | with torch.no_grad():
15 | # Embeddings[tokenNo] to learn
16 | tokens=base.components["tokenizer"].encode(token)
17 | assert len(tokens)==3, "token is not a single token in 'tokenizer'"
18 | tokenNo=tokens[1]
19 | tokens=base.components["tokenizer_2"].encode(token)
20 | assert len(tokens)==3, "token is not a single token in 'tokenizer_2'"
21 | tokenNo2=tokens[1]
22 | embs=base.components["text_encoder"].text_model.embeddings.token_embedding.weight
23 | embs2=base.components["text_encoder_2"].text_model.embeddings.token_embedding.weight
24 | assert embs[tokenNo].shape==emb["emb"].shape, "different 'text_encoder'"
25 | assert embs2[tokenNo2].shape==emb["emb2"].shape, "different 'text_encoder_2'"
26 | embs[tokenNo]=emb["emb"].to(embs.dtype).to(embs.device)
27 | embs2[tokenNo2]=emb["emb2"].to(embs2.dtype).to(embs2.device)
28 |
29 |
30 | base_model_path="stabilityai/stable-diffusion-xl-base-1.0"
31 | pipe = DiffusionPipeline.from_pretrained(
32 | base_model_path,
33 | torch_dtype=torch.float16, #torch.bfloat16
34 | variant="fp16",
35 | use_safetensors=True,
36 | add_watermarker=False,
37 | )
38 | pipe.enable_xformers_memory_efficient_attention()
39 | torch.set_grad_enabled(False)
40 | _=pipe.to("cuda:1")
41 |
42 |
43 | df = pd.read_csv('prompts/prompts-personreal.csv')
44 | prompts = list(df.prompt)
45 | seeds = list(df.evaluation_seed)
46 | case_numbers = list(df.case_number)
47 |
48 | learned="sks"
49 | embs_path="./textualinversion_models/"
50 | emb_file="eyesize_textual_inversion.pt"
51 |
52 | load_XLembedding(pipe,token=learned,embedding_file=emb_file,path=embs_path)
53 |
54 | p1="photo of a person, realistic, 8k with {} eyes"
55 | n_steps=50
56 |
57 | seed = random.randint(0,2**15)
58 | sample_prompt = p1
59 | prompt=sample_prompt.format(learned)
60 |
61 | os.makedirs('images/textualinversion/eyesize_xl/', exist_ok=True)
62 | for idx, prompt in enumerate(prompts):
63 | case_number = case_numbers[idx]
64 | seed = seeds[idx]
65 |
66 | print(prompt, seed)
67 | with torch.no_grad():
68 | generator = torch.manual_seed(seed)
69 | images = pipe(
70 | prompt=prompt+ ' with sks eyes',
71 | num_inference_steps=n_steps,
72 | num_images_per_prompt=5,
73 | generator = generator
74 | ).images
75 | for i, im in enumerate(images):
76 | im.save(f'images/textualinversion/eyesize_xl/{case_number}_{i}.png')
77 |
--------------------------------------------------------------------------------
/eval-scripts/lpip_score.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 |
8 | from PIL import Image
9 | import matplotlib.pyplot as plt
10 |
11 | import torchvision.transforms as transforms
12 | import torchvision.models as models
13 | import numpy as np
14 | import copy
15 | import os
16 | import pandas as pd
17 | import argparse
18 | import lpips
19 |
20 |
21 | # desired size of the output image
22 | imsize = 64
23 | loader = transforms.Compose([
24 | transforms.Resize(imsize), # scale imported image
25 | transforms.ToTensor()]) # transform it into a torch tensor
26 |
27 |
28 | def image_loader(image_name):
29 | image = Image.open(image_name)
30 | # fake batch dimension required to fit network's input dimensions
31 | image = loader(image).unsqueeze(0)
32 | image = (image-0.5)*2
33 | return image.to(torch.float)
34 |
35 |
36 | if __name__=='__main__':
37 | parser = argparse.ArgumentParser(
38 | prog = 'LPIPS',
39 | description = 'Takes the path to two images and gives LPIPS')
40 | parser.add_argument('--im_path', help='path to original image', type=str, required=True)
41 | parser.add_argument('--prompts_path', help='path to csv prompts', type=str, required=True)
42 | parser.add_argument('--true', help='path to true SD images', type=str, required=True)
43 |
44 | loss_fn_alex = lpips.LPIPS(net='alex')
45 | args = parser.parse_args()
46 |
47 | true = args.true
48 | models = os.listdir(args.im_path)
49 | models = [m for m in models if m not in [true,'all'] and '.csv' not in m]
50 |
51 | original_path = os.path.join(args.im_path,true)
52 | df_prompts = pd.read_csv(args.prompts_path)
53 | for model_name in models:
54 | edited_path = os.path.join(args.im_path,model_name)
55 | file_names = [name for name in os.listdir(edited_path) if '.png' in name]
56 | model_name = model_name.replace('half','0.5')
57 | df_prompts[f'lpips_{model_name}'] = df_prompts['case_number'] *0
58 | for index, row in df_prompts.iterrows():
59 | case_number = row.case_number
60 | files = [file for file in file_names if file.startswith(f'{case_number}_')]
61 | lpips_scores = []
62 | for file in files:
63 | print(file)
64 | try:
65 | original = image_loader(os.path.join(original_path,file))
66 | edited = image_loader(os.path.join(edited_path,file))
67 |
68 | l = loss_fn_alex(original, edited)
69 |
70 | lpips_scores.append(l.item())
71 | except Exception:
72 | print('No File')
73 | pass
74 | print(f'Case {case_number}: {np.mean(lpips_scores)}')
75 | df_prompts.loc[index,f'lpips_{model_name}'] = np.mean(lpips_scores)
76 | df_prompts.to_csv(os.path.join(args.im_path, f'lpips_score.csv'), index=False)
77 |
--------------------------------------------------------------------------------
/flux-sliders/flux-requirements.txt:
--------------------------------------------------------------------------------
1 | bitsandbytes
2 | dadaptation
3 | diffusers
4 | xformers
5 | torchvision
6 | accelerate
7 | transformers
8 | sentencepiece
9 | lion_pytorch
10 | lpips
11 | matplotlib
12 | numpy
13 | opencv_python
14 | opencv_python_headless
15 | pandas
16 | Pillow
17 | prodigyopt
18 | pydantic
19 | PyYAML
20 | Requests
21 | safetensors
22 | torch
23 | tqdm
24 | wandb
25 | datasets
26 | ftfy
27 | openai
28 | scikit-learn
29 | git+https://github.com/davidbau/baukit
30 | anthropic
31 | peft
--------------------------------------------------------------------------------
/flux-sliders/utils/lora.py:
--------------------------------------------------------------------------------
1 | # ref:
2 | # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3 | # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4 |
5 | import os
6 | import math
7 | from typing import Optional, List, Type, Set, Literal
8 |
9 | import torch
10 | import torch.nn as nn
11 | from diffusers import UNet2DConditionModel
12 | from safetensors.torch import save_file
13 | from datetime import datetime
14 |
15 | UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16 | # "Transformer2DModel", # apparently this is the right one? # attn1, 2
17 | "Attention"
18 | ]
19 | UNET_TARGET_REPLACE_MODULE_CONV = [
20 | "ResnetBlock2D",
21 | "Downsample2D",
22 | "Upsample2D",
23 | "DownBlock2D",
24 | "UpBlock2D",
25 |
26 | ] # locon, c3lier
27 |
28 | LORA_PREFIX_UNET = "lora_unet"
29 |
30 | DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
31 |
32 | TRAINING_METHODS = Literal[
33 | "noxattn", # train all layers except x-attns and time_embed layers
34 | "innoxattn", # train all layers except self attention layers
35 | "selfattn", # ESD-u, train only self attention layers
36 | "xattn", # ESD-x, train only x attention layers
37 | "xattn-up", # all up blocks only
38 |     "xattn-down", # all down blocks only
39 |     "xattn-mid", # mid blocks only
40 | "full", # train all layers
41 | "xattn-strict", # q and k values
42 | "noxattn-hspace",
43 | "noxattn-hspace-last",
44 | # "xlayer",
45 | # "outxattn",
46 | # "outsattn",
47 | # "inxattn",
48 | # "inmidsattn",
49 | # "selflayer",
50 | ]
51 |
52 | def load_ortho_dict(n):
53 |     path = os.path.expanduser(f'~/orthogonal_basis/{n:09}.ckpt')  # expand '~' so isfile/torch.save work
54 |     if os.path.isfile(path):
55 |         return torch.load(path)
56 |     else:
57 |         os.makedirs(os.path.dirname(path), exist_ok=True)  # create the cache dir on first use
58 |         eig, _, _ = torch.svd(torch.randn(n, n))  # U of a random matrix is orthogonal
59 |         torch.save(eig, path)
60 |         return eig
61 |
62 | def init_ortho_proj(rank, weight):
63 | seed = torch.seed()
64 | torch.manual_seed(datetime.now().timestamp())
65 | q_index = torch.randint(high=weight.size(0),size=(rank,))
66 | torch.manual_seed(seed)
67 |
68 | ortho_q_init = load_ortho_dict(weight.size(0)).to(dtype=weight.dtype)[:,q_index]
69 | return nn.Parameter(ortho_q_init)
70 |
71 |
72 | class LoRAModule(nn.Module):
73 | """
74 | replaces forward method of the original Linear, instead of replacing the original Linear module.
75 | """
76 |
77 | def __init__(
78 | self,
79 | lora_name,
80 | org_module: nn.Module,
81 | multiplier=1.0,
82 | lora_dim=4,
83 | alpha=1,
84 | train_method='xattn'
85 | ):
86 | """if alpha == 0 or None, alpha is rank (no scaling)."""
87 | super().__init__()
88 | self.lora_name = lora_name
89 | self.lora_dim = lora_dim
90 |
91 | if "Linear" in org_module.__class__.__name__:
92 | in_dim = org_module.in_features
93 | out_dim = org_module.out_features
94 | self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
95 | self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
96 |
97 |         elif "Conv" in org_module.__class__.__name__: # just in case
98 | in_dim = org_module.in_channels
99 | out_dim = org_module.out_channels
100 |
101 | self.lora_dim = min(self.lora_dim, in_dim, out_dim)
102 | if self.lora_dim != lora_dim:
103 | print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
104 |
105 | kernel_size = org_module.kernel_size
106 | stride = org_module.stride
107 | padding = org_module.padding
108 | self.lora_down = nn.Conv2d(
109 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
110 | )
111 | self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
112 |
113 | if type(alpha) == torch.Tensor:
114 | alpha = alpha.detach().numpy()
115 | alpha = lora_dim if alpha is None or alpha == 0 else alpha
116 | self.scale = alpha / self.lora_dim
117 |         self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
118 |
119 | # same as microsoft's
120 | nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
121 | if train_method == 'full':
122 | nn.init.zeros_(self.lora_up.weight)
123 | else:
124 | self.lora_up.weight = init_ortho_proj(lora_dim, self.lora_up.weight)
125 | self.lora_up.weight.requires_grad_(False)
126 |
127 | self.multiplier = multiplier
128 | self.org_module = org_module # remove in applying
129 |
130 | def apply_to(self):
131 | self.org_forward = self.org_module.forward
132 | self.org_module.forward = self.forward
133 | del self.org_module
134 |
135 | def forward(self, x):
136 | return (
137 | self.org_forward(x)
138 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
139 | )
140 |
141 |
142 | class LoRANetwork(nn.Module):
143 | def __init__(
144 | self,
145 | unet: UNet2DConditionModel,
146 | rank: int = 4,
147 | multiplier: float = 1.0,
148 | alpha: float = 1.0,
149 | train_method: TRAINING_METHODS = "full",
150 | layers = ['Linear', 'Conv']
151 | ) -> None:
152 | super().__init__()
153 | self.lora_scale = 1
154 | self.multiplier = multiplier
155 | self.lora_dim = rank
156 | self.alpha = alpha
157 | self.train_method=train_method
158 |         # LoRA only
159 | self.module = LoRAModule
160 |
161 |         # create the LoRA modules for the U-Net
162 | self.unet_loras = self.create_modules(
163 | LORA_PREFIX_UNET,
164 | unet,
165 | DEFAULT_TARGET_REPLACE,
166 | self.lora_dim,
167 | self.multiplier,
168 | train_method=train_method,
169 | layers = layers
170 | )
171 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
172 |
173 |         # assertion: make sure there are no duplicate LoRA names
174 | lora_names = set()
175 | for lora in self.unet_loras:
176 | assert (
177 | lora.lora_name not in lora_names
178 | ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
179 | lora_names.add(lora.lora_name)
180 |
181 |         # apply the LoRA modules
182 | for lora in self.unet_loras:
183 | lora.apply_to()
184 | self.add_module(
185 | lora.lora_name,
186 | lora,
187 | )
188 |
189 | del unet
190 |
191 | torch.cuda.empty_cache()
192 |
193 | def create_modules(
194 | self,
195 | prefix: str,
196 | root_module: nn.Module,
197 | target_replace_modules: List[str],
198 | rank: int,
199 | multiplier: float,
200 | train_method: TRAINING_METHODS,
201 | layers: List[str],
202 | ) -> list:
203 | filt_layers = []
204 | if 'Linear' in layers:
205 | filt_layers.extend(["Linear", "LoRACompatibleLinear"])
206 | if 'Conv' in layers:
207 | filt_layers.extend(["Conv2d", "LoRACompatibleConv"])
208 | loras = []
209 | names = []
210 | for name, module in root_module.named_modules():
211 |             if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # train everything except cross-attention and time-embedding layers
212 | if "attn2" in name or "time_embed" in name:
213 | continue
214 |             elif train_method == "innoxattn": # train everything except cross-attention
215 | if "attn2" in name:
216 | continue
217 |             elif train_method == "selfattn": # train only self-attention
218 | if "attn1" not in name:
219 | continue
220 | elif train_method in ["xattn", "xattn-strict", "xattn-up", "xattn-down", "xattn-mid"]: # Cross Attention
221 | if "attn" not in name:
222 | continue
223 | if train_method == 'xattn-up':
224 | if 'up_block' not in name:
225 | continue
226 | if train_method == 'xattn-down':
227 | if 'down_block' not in name:
228 | continue
229 | if train_method == 'xattn-mid':
230 | if 'mid_block' not in name:
231 | continue
232 |             elif train_method == "full": # train all layers
233 | pass
234 | else:
235 | raise NotImplementedError(
236 | f"train_method: {train_method} is not implemented."
237 | )
238 | if module.__class__.__name__ in target_replace_modules:
239 | for child_name, child_module in module.named_modules():
240 | if child_module.__class__.__name__ in filt_layers:
241 |
242 |
243 | if train_method == 'xattn-strict':
244 | if 'out' in child_name:
245 | continue
246 | if 'to_q' in child_name:
247 | continue
248 | if train_method == 'noxattn-hspace':
249 | if 'mid_block' not in name:
250 | continue
251 | if train_method == 'noxattn-hspace-last':
252 | if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
253 | continue
254 | lora_name = prefix + "." + name + "." + child_name
255 | lora_name = lora_name.replace(".", "_")
256 | # print(f"{lora_name}")
257 | lora = self.module(
258 | lora_name, child_module, multiplier, rank, self.alpha, train_method
259 | )
260 | # print(name, child_name)
261 | # print(child_module.weight.shape)
262 | if lora_name not in names:
263 | loras.append(lora)
264 | names.append(lora_name)
265 | # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
266 | return loras
267 |
268 | def prepare_optimizer_params(self):
269 | all_params = []
270 |
271 |         if self.unet_loras: # effectively the only case
272 | params = []
273 | if self.train_method == 'full':
274 | [params.extend(lora.parameters()) for lora in self.unet_loras]
275 | else:
276 | [params.extend(lora.lora_down.parameters()) for lora in self.unet_loras]
277 | param_data = {"params": params}
278 | all_params.append(param_data)
279 |
280 | return all_params
281 |
282 | def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
283 | state_dict = self.state_dict()
284 |
285 | if dtype is not None:
286 | for key in list(state_dict.keys()):
287 | v = state_dict[key]
288 | v = v.detach().clone().to("cpu").to(dtype)
289 | state_dict[key] = v
290 |
291 | if os.path.splitext(file)[1] == ".safetensors":
292 | save_file(state_dict, file, metadata)
293 | else:
294 | torch.save(state_dict, file)
295 | def set_lora_slider(self, scale):
296 | self.lora_scale = scale
297 |
298 | def __enter__(self):
299 | for lora in self.unet_loras:
300 | lora.multiplier = 1.0 * self.lora_scale
301 |
302 | def __exit__(self, exc_type, exc_value, tb):
303 | for lora in self.unet_loras:
304 | lora.multiplier = 0
305 |
--------------------------------------------------------------------------------
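For orientation, a minimal sketch of how this LoRANetwork is typically attached to a UNet and driven by the slider scale. The SDXL model id, rank, learning rate, scale, and import path are illustrative assumptions; train_method="full" keeps the sketch simple, since the other modes freeze lora_up with the orthogonal initialisation and its on-disk basis cache.

import torch
from diffusers import UNet2DConditionModel
from lora import LoRANetwork  # i.e. flux-sliders/utils/lora.py

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16
)
network = LoRANetwork(unet, rank=4, alpha=1.0, train_method="full")

# prepare_optimizer_params returns torch param groups, so it can feed an optimizer directly
optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=2e-4)

network.set_lora_slider(scale=2.0)  # slider strength for the next forward passes
with network:                       # __enter__ sets every lora.multiplier to the scale
    pass                            # run unet(...) here; the LoRA deltas are active inside the block
# on __exit__ the multipliers are reset to 0, i.e. the LoRA is disabled again

network.save_weights("slider.safetensors", dtype=torch.float16)
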
/flux-sliders/utils/model_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Union, Optional
2 |
3 | import torch
4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5 | from diffusers import (
6 | UNet2DConditionModel,
7 | SchedulerMixin,
8 | StableDiffusionPipeline,
9 | StableDiffusionXLPipeline,
10 | )
11 | from diffusers.schedulers import (
12 | DDIMScheduler,
13 | DDPMScheduler,
14 | LMSDiscreteScheduler,
15 | EulerAncestralDiscreteScheduler,
16 | )
17 |
18 |
19 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
20 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
21 |
22 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
23 |
24 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
25 |
26 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
27 |
28 |
29 | def load_diffusers_model(
30 | pretrained_model_name_or_path: str,
31 | v2: bool = False,
32 | clip_skip: Optional[int] = None,
33 | weight_dtype: torch.dtype = torch.float32,
34 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
35 |     # the VAE is not needed here
36 |
37 | if v2:
38 | tokenizer = CLIPTokenizer.from_pretrained(
39 | TOKENIZER_V2_MODEL_NAME,
40 | subfolder="tokenizer",
41 | torch_dtype=weight_dtype,
42 | cache_dir=DIFFUSERS_CACHE_DIR,
43 | )
44 | text_encoder = CLIPTextModel.from_pretrained(
45 | pretrained_model_name_or_path,
46 | subfolder="text_encoder",
47 | # default is clip skip 2
48 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
49 | torch_dtype=weight_dtype,
50 | cache_dir=DIFFUSERS_CACHE_DIR,
51 | )
52 | else:
53 | tokenizer = CLIPTokenizer.from_pretrained(
54 | TOKENIZER_V1_MODEL_NAME,
55 | subfolder="tokenizer",
56 | torch_dtype=weight_dtype,
57 | cache_dir=DIFFUSERS_CACHE_DIR,
58 | )
59 | text_encoder = CLIPTextModel.from_pretrained(
60 | pretrained_model_name_or_path,
61 | subfolder="text_encoder",
62 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
63 | torch_dtype=weight_dtype,
64 | cache_dir=DIFFUSERS_CACHE_DIR,
65 | )
66 |
67 | unet = UNet2DConditionModel.from_pretrained(
68 | pretrained_model_name_or_path,
69 | subfolder="unet",
70 | torch_dtype=weight_dtype,
71 | cache_dir=DIFFUSERS_CACHE_DIR,
72 | )
73 |
74 | return tokenizer, text_encoder, unet
75 |
76 |
77 | def load_checkpoint_model(
78 | checkpoint_path: str,
79 | v2: bool = False,
80 | clip_skip: Optional[int] = None,
81 | weight_dtype: torch.dtype = torch.float32,
82 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
83 | pipe = StableDiffusionPipeline.from_ckpt(
84 | checkpoint_path,
85 | upcast_attention=True if v2 else False,
86 | torch_dtype=weight_dtype,
87 | cache_dir=DIFFUSERS_CACHE_DIR,
88 | )
89 |
90 | unet = pipe.unet
91 | tokenizer = pipe.tokenizer
92 | text_encoder = pipe.text_encoder
93 | if clip_skip is not None:
94 | if v2:
95 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
96 | else:
97 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
98 |
99 | del pipe
100 |
101 | return tokenizer, text_encoder, unet
102 |
103 |
104 | def load_models(
105 | pretrained_model_name_or_path: str,
106 | scheduler_name: AVAILABLE_SCHEDULERS,
107 | v2: bool = False,
108 | v_pred: bool = False,
109 | weight_dtype: torch.dtype = torch.float32,
110 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
111 | if pretrained_model_name_or_path.endswith(
112 | ".ckpt"
113 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
114 | tokenizer, text_encoder, unet = load_checkpoint_model(
115 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
116 | )
117 | else: # diffusers
118 | tokenizer, text_encoder, unet = load_diffusers_model(
119 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
120 | )
121 |
122 |     # the VAE is not needed here
123 |
124 | scheduler = create_noise_scheduler(
125 | scheduler_name,
126 | prediction_type="v_prediction" if v_pred else "epsilon",
127 | )
128 |
129 | return tokenizer, text_encoder, unet, scheduler
130 |
131 |
132 | def load_diffusers_model_xl(
133 | pretrained_model_name_or_path: str,
134 | weight_dtype: torch.dtype = torch.float32,
135 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
136 | # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
137 |
138 | tokenizers = [
139 | CLIPTokenizer.from_pretrained(
140 | pretrained_model_name_or_path,
141 | subfolder="tokenizer",
142 | torch_dtype=weight_dtype,
143 | cache_dir=DIFFUSERS_CACHE_DIR,
144 | ),
145 | CLIPTokenizer.from_pretrained(
146 | pretrained_model_name_or_path,
147 | subfolder="tokenizer_2",
148 | torch_dtype=weight_dtype,
149 | cache_dir=DIFFUSERS_CACHE_DIR,
150 | pad_token_id=0, # same as open clip
151 | ),
152 | ]
153 |
154 | text_encoders = [
155 | CLIPTextModel.from_pretrained(
156 | pretrained_model_name_or_path,
157 | subfolder="text_encoder",
158 | torch_dtype=weight_dtype,
159 | cache_dir=DIFFUSERS_CACHE_DIR,
160 | ),
161 | CLIPTextModelWithProjection.from_pretrained(
162 | pretrained_model_name_or_path,
163 | subfolder="text_encoder_2",
164 | torch_dtype=weight_dtype,
165 | cache_dir=DIFFUSERS_CACHE_DIR,
166 | ),
167 | ]
168 |
169 | unet = UNet2DConditionModel.from_pretrained(
170 | pretrained_model_name_or_path,
171 | subfolder="unet",
172 | torch_dtype=weight_dtype,
173 | cache_dir=DIFFUSERS_CACHE_DIR,
174 | )
175 |
176 | return tokenizers, text_encoders, unet
177 |
178 |
179 | def load_checkpoint_model_xl(
180 | checkpoint_path: str,
181 | weight_dtype: torch.dtype = torch.float32,
182 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
183 | pipe = StableDiffusionXLPipeline.from_single_file(
184 | checkpoint_path,
185 | torch_dtype=weight_dtype,
186 | cache_dir=DIFFUSERS_CACHE_DIR,
187 | )
188 |
189 | unet = pipe.unet
190 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
191 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
192 | if len(text_encoders) == 2:
193 | text_encoders[1].pad_token_id = 0
194 |
195 | del pipe
196 |
197 | return tokenizers, text_encoders, unet
198 |
199 |
200 | def load_models_xl(
201 | pretrained_model_name_or_path: str,
202 | scheduler_name: AVAILABLE_SCHEDULERS,
203 | weight_dtype: torch.dtype = torch.float32,
204 | ) -> tuple[
205 | list[CLIPTokenizer],
206 | list[SDXL_TEXT_ENCODER_TYPE],
207 | UNet2DConditionModel,
208 | SchedulerMixin,
209 | ]:
210 | if pretrained_model_name_or_path.endswith(
211 | ".ckpt"
212 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
213 | (
214 | tokenizers,
215 | text_encoders,
216 | unet,
217 | ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
218 | else: # diffusers
219 | (
220 | tokenizers,
221 | text_encoders,
222 | unet,
223 | ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
224 |
225 | scheduler = create_noise_scheduler(scheduler_name)
226 |
227 | return tokenizers, text_encoders, unet, scheduler
228 |
229 |
230 | def create_noise_scheduler(
231 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
232 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
233 | ) -> SchedulerMixin:
234 |     # Honestly, not sure which scheduler is best. The original implementation allowed choosing DDIM, DDPM, or LMS, but it is unclear which one to prefer.
235 |
236 | name = scheduler_name.lower().replace(" ", "_")
237 | if name == "ddim":
238 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
239 | scheduler = DDIMScheduler(
240 | beta_start=0.00085,
241 | beta_end=0.012,
242 | beta_schedule="scaled_linear",
243 | num_train_timesteps=1000,
244 | clip_sample=False,
245 |             prediction_type=prediction_type, # is this right?
246 | )
247 | elif name == "ddpm":
248 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
249 | scheduler = DDPMScheduler(
250 | beta_start=0.00085,
251 | beta_end=0.012,
252 | beta_schedule="scaled_linear",
253 | num_train_timesteps=1000,
254 | clip_sample=False,
255 | prediction_type=prediction_type,
256 | )
257 | elif name == "lms":
258 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
259 | scheduler = LMSDiscreteScheduler(
260 | beta_start=0.00085,
261 | beta_end=0.012,
262 | beta_schedule="scaled_linear",
263 | num_train_timesteps=1000,
264 | prediction_type=prediction_type,
265 | )
266 | elif name == "euler_a":
267 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
268 | scheduler = EulerAncestralDiscreteScheduler(
269 | beta_start=0.00085,
270 | beta_end=0.012,
271 | beta_schedule="scaled_linear",
272 | num_train_timesteps=1000,
273 | prediction_type=prediction_type,
274 | )
275 | else:
276 | raise ValueError(f"Unknown scheduler name: {name}")
277 |
278 | return scheduler
279 |
--------------------------------------------------------------------------------
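For reference, a minimal sketch of loading a model and scheduler through the helpers above; the model id, dtype, and scheduler choice are illustrative assumptions.

import torch
from model_util import load_models, create_noise_scheduler  # i.e. flux-sliders/utils/model_util.py

tokenizer, text_encoder, unet, scheduler = load_models(
    "CompVis/stable-diffusion-v1-4",  # a diffusers repo id, or a local .ckpt/.safetensors file
    scheduler_name="ddim",
    v2=False,
    v_pred=False,
    weight_dtype=torch.float16,
)

# schedulers can also be constructed directly
ddpm = create_noise_scheduler("ddpm", prediction_type="epsilon")
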
/flux-sliders/utils/prompt_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional, Union, List
2 |
3 | import yaml
4 | from pathlib import Path
5 |
6 |
7 | from pydantic import BaseModel, root_validator
8 | import torch
9 | import copy
10 |
11 | ACTION_TYPES = Literal[
12 | "erase",
13 | "enhance",
14 | ]
15 |
16 |
17 | # SDXL needs two kinds of embeddings
18 | class PromptEmbedsXL:
19 | text_embeds: torch.FloatTensor
20 | pooled_embeds: torch.FloatTensor
21 |
22 | def __init__(self, *args) -> None:
23 | self.text_embeds = args[0]
24 | self.pooled_embeds = args[1]
25 |
26 |
27 | # SDv1.x and SDv2.x use FloatTensor; XL uses PromptEmbedsXL
28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29 |
30 |
31 | class PromptEmbedsCache: # cached so prompt embeddings can be reused
32 | prompts: dict[str, PROMPT_EMBEDDING] = {}
33 |
34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35 | self.prompts[__name] = __value
36 |
37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38 | if __name in self.prompts:
39 | return self.prompts[__name]
40 | else:
41 | return None
42 |
43 |
44 | class PromptSettings(BaseModel): # the entries from the prompts YAML
45 | target: str
46 | positive: str = None # if None, target will be used
47 | unconditional: str = "" # default is ""
48 | neutral: str = None # if None, unconditional will be used
49 | action: ACTION_TYPES = "erase" # default is "erase"
50 | guidance_scale: float = 1.0 # default is 1.0
51 | resolution: int = 512 # default is 512
52 | dynamic_resolution: bool = False # default is False
53 | batch_size: int = 1 # default is 1
54 | dynamic_crops: bool = False # default is False. only used when model is XL
55 |
56 | @root_validator(pre=True)
57 | def fill_prompts(cls, values):
58 | keys = values.keys()
59 | if "target" not in keys:
60 | raise ValueError("target must be specified")
61 | if "positive" not in keys:
62 | values["positive"] = values["target"]
63 | if "unconditional" not in keys:
64 | values["unconditional"] = ""
65 | if "neutral" not in keys:
66 | values["neutral"] = values["unconditional"]
67 |
68 | return values
69 |
70 |
71 | class PromptEmbedsPair:
72 | target: PROMPT_EMBEDDING # not want to generate the concept
73 | positive: PROMPT_EMBEDDING # generate the concept
74 | unconditional: PROMPT_EMBEDDING # uncondition (default should be empty)
75 | neutral: PROMPT_EMBEDDING # base condition (default should be empty)
76 |
77 | guidance_scale: float
78 | resolution: int
79 | dynamic_resolution: bool
80 | batch_size: int
81 | dynamic_crops: bool
82 |
83 | loss_fn: torch.nn.Module
84 | action: ACTION_TYPES
85 |
86 | def __init__(
87 | self,
88 | loss_fn: torch.nn.Module,
89 | target: PROMPT_EMBEDDING,
90 | positive: PROMPT_EMBEDDING,
91 | unconditional: PROMPT_EMBEDDING,
92 | neutral: PROMPT_EMBEDDING,
93 | settings: PromptSettings,
94 | ) -> None:
95 | self.loss_fn = loss_fn
96 | self.target = target
97 | self.positive = positive
98 | self.unconditional = unconditional
99 | self.neutral = neutral
100 |
101 | self.guidance_scale = settings.guidance_scale
102 | self.resolution = settings.resolution
103 | self.dynamic_resolution = settings.dynamic_resolution
104 | self.batch_size = settings.batch_size
105 | self.dynamic_crops = settings.dynamic_crops
106 | self.action = settings.action
107 |
108 | def _erase(
109 | self,
110 | target_latents: torch.FloatTensor, # "van gogh"
111 | positive_latents: torch.FloatTensor, # "van gogh"
112 | unconditional_latents: torch.FloatTensor, # ""
113 | neutral_latents: torch.FloatTensor, # ""
114 | ) -> torch.FloatTensor:
115 | """Target latents are going not to have the positive concept."""
116 | return self.loss_fn(
117 | target_latents,
118 | neutral_latents
119 | - self.guidance_scale * (positive_latents - unconditional_latents)
120 | )
121 |
122 |
123 | def _enhance(
124 | self,
125 | target_latents: torch.FloatTensor, # "van gogh"
126 | positive_latents: torch.FloatTensor, # "van gogh"
127 | unconditional_latents: torch.FloatTensor, # ""
128 | neutral_latents: torch.FloatTensor, # ""
129 | ):
130 | """Target latents are going to have the positive concept."""
131 | return self.loss_fn(
132 | target_latents,
133 | neutral_latents
134 | + self.guidance_scale * (positive_latents - unconditional_latents)
135 | )
136 |
137 | def loss(
138 | self,
139 | **kwargs,
140 | ):
141 | if self.action == "erase":
142 | return self._erase(**kwargs)
143 |
144 | elif self.action == "enhance":
145 | return self._enhance(**kwargs)
146 |
147 | else:
148 | raise ValueError("action must be erase or enhance")
149 |
150 |
151 | def load_prompts_from_yaml(path, attributes = []):
152 | with open(path, "r") as f:
153 | prompts = yaml.safe_load(f)
154 | print(prompts)
155 | if len(prompts) == 0:
156 | raise ValueError("prompts file is empty")
157 | if len(attributes)!=0:
158 | newprompts = []
159 | for i in range(len(prompts)):
160 | for att in attributes:
161 | copy_ = copy.deepcopy(prompts[i])
162 | copy_['target'] = att + ' ' + copy_['target']
163 | copy_['positive'] = att + ' ' + copy_['positive']
164 | copy_['neutral'] = att + ' ' + copy_['neutral']
165 | copy_['unconditional'] = att + ' ' + copy_['unconditional']
166 | newprompts.append(copy_)
167 | else:
168 | newprompts = copy.deepcopy(prompts)
169 |
170 | print(newprompts)
171 | print(len(prompts), len(newprompts))
172 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
173 |
174 | return prompt_settings
175 |
--------------------------------------------------------------------------------
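A minimal sketch of loading a prompts YAML into PromptSettings objects via load_prompts_from_yaml; the file path and attribute list are illustrative assumptions. With a non-empty attributes list, every entry is duplicated once per attribute and the attribute is prepended to its target/positive/neutral/unconditional strings.

from prompt_util import load_prompts_from_yaml  # i.e. flux-sliders/utils/prompt_util.py

# prompts.yaml (assumed contents):
# - target: "person"
#   positive: "person, very old"
#   unconditional: "person"
#   neutral: "person"
#   action: "enhance"
#   guidance_scale: 4

settings = load_prompts_from_yaml("prompts.yaml", attributes=["male", "female"])
for s in settings:
    print(s.target, "|", s.positive, "|", s.action, s.guidance_scale)
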
/flux-sliders/utils/ptp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | from PIL import Image, ImageDraw, ImageFont
18 | import cv2
19 | from typing import Optional, Union, Tuple, List, Callable, Dict
20 | from IPython.display import display
21 | from tqdm.notebook import tqdm
22 |
23 |
24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
25 | h, w, c = image.shape
26 | offset = int(h * .2)
27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
28 | font = cv2.FONT_HERSHEY_SIMPLEX
29 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
30 | img[:h] = image
31 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
32 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
33 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
34 | return img
35 |
36 |
37 | def view_images(images, num_rows=1, offset_ratio=0.02):
38 | if type(images) is list:
39 | num_empty = len(images) % num_rows
40 | elif images.ndim == 4:
41 | num_empty = images.shape[0] % num_rows
42 | else:
43 | images = [images]
44 | num_empty = 0
45 |
46 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
47 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
48 | num_items = len(images)
49 |
50 | h, w, c = images[0].shape
51 | offset = int(h * offset_ratio)
52 | num_cols = num_items // num_rows
53 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
54 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
55 | for i in range(num_rows):
56 | for j in range(num_cols):
57 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
58 | i * num_cols + j]
59 |
60 | pil_img = Image.fromarray(image_)
61 | display(pil_img)
62 |
63 |
64 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
65 | if low_resource:
66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
68 | else:
69 | latents_input = torch.cat([latents] * 2)
70 |         noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
71 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
72 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
73 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
74 | latents = controller.step_callback(latents)
75 | return latents
76 |
77 |
78 | def latent2image(vae, latents):
79 | latents = 1 / 0.18215 * latents
80 | image = vae.decode(latents)['sample']
81 | image = (image / 2 + 0.5).clamp(0, 1)
82 | image = image.cpu().permute(0, 2, 3, 1).numpy()
83 | image = (image * 255).astype(np.uint8)
84 | return image
85 |
86 |
87 | def init_latent(latent, model, height, width, generator, batch_size):
88 | if latent is None:
89 | latent = torch.randn(
90 | (1, model.unet.in_channels, height // 8, width // 8),
91 | generator=generator,
92 | )
93 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
94 | return latent, latents
95 |
96 |
97 | @torch.no_grad()
98 | def text2image_ldm(
99 | model,
100 | prompt: List[str],
101 | controller,
102 | num_inference_steps: int = 50,
103 | guidance_scale: Optional[float] = 7.,
104 | generator: Optional[torch.Generator] = None,
105 | latent: Optional[torch.FloatTensor] = None,
106 | ):
107 | register_attention_control(model, controller)
108 | height = width = 256
109 | batch_size = len(prompt)
110 |
111 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
112 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
113 |
114 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
115 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
116 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
117 | context = torch.cat([uncond_embeddings, text_embeddings])
118 |
119 | model.scheduler.set_timesteps(num_inference_steps)
120 | for t in tqdm(model.scheduler.timesteps):
121 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
122 |
123 | image = latent2image(model.vqvae, latents)
124 |
125 | return image, latent
126 |
127 |
128 | @torch.no_grad()
129 | def text2image_ldm_stable(
130 | model,
131 | prompt: List[str],
132 | controller,
133 | num_inference_steps: int = 50,
134 | guidance_scale: float = 7.5,
135 | generator: Optional[torch.Generator] = None,
136 | latent: Optional[torch.FloatTensor] = None,
137 | low_resource: bool = False,
138 | ):
139 | register_attention_control(model, controller)
140 | height = width = 512
141 | batch_size = len(prompt)
142 |
143 | text_input = model.tokenizer(
144 | prompt,
145 | padding="max_length",
146 | max_length=model.tokenizer.model_max_length,
147 | truncation=True,
148 | return_tensors="pt",
149 | )
150 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
151 | max_length = text_input.input_ids.shape[-1]
152 | uncond_input = model.tokenizer(
153 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
154 | )
155 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
156 |
157 | context = [uncond_embeddings, text_embeddings]
158 | if not low_resource:
159 | context = torch.cat(context)
160 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
161 |
162 | # set timesteps
163 | extra_set_kwargs = {"offset": 1}
164 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
165 | for t in tqdm(model.scheduler.timesteps):
166 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
167 |
168 | image = latent2image(model.vae, latents)
169 |
170 | return image, latent
171 |
172 |
173 | def register_attention_control(model, controller):
174 | def ca_forward(self, place_in_unet):
175 | to_out = self.to_out
176 | if type(to_out) is torch.nn.modules.container.ModuleList:
177 | to_out = self.to_out[0]
178 | else:
179 | to_out = self.to_out
180 |
181 | def forward(x, context=None, mask=None):
182 | batch_size, sequence_length, dim = x.shape
183 | h = self.heads
184 | q = self.to_q(x)
185 | is_cross = context is not None
186 | context = context if is_cross else x
187 | k = self.to_k(context)
188 | v = self.to_v(context)
189 | q = self.reshape_heads_to_batch_dim(q)
190 | k = self.reshape_heads_to_batch_dim(k)
191 | v = self.reshape_heads_to_batch_dim(v)
192 |
193 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
194 |
195 | if mask is not None:
196 | mask = mask.reshape(batch_size, -1)
197 | max_neg_value = -torch.finfo(sim.dtype).max
198 | mask = mask[:, None, :].repeat(h, 1, 1)
199 | sim.masked_fill_(~mask, max_neg_value)
200 |
201 | # attention, what we cannot get enough of
202 | attn = sim.softmax(dim=-1)
203 | attn = controller(attn, is_cross, place_in_unet)
204 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
205 | out = self.reshape_batch_dim_to_heads(out)
206 | return to_out(out)
207 |
208 | return forward
209 |
210 | class DummyController:
211 |
212 | def __call__(self, *args):
213 | return args[0]
214 |
215 | def __init__(self):
216 | self.num_att_layers = 0
217 |
218 | if controller is None:
219 | controller = DummyController()
220 |
221 | def register_recr(net_, count, place_in_unet):
222 | if net_.__class__.__name__ == 'CrossAttention':
223 | net_.forward = ca_forward(net_, place_in_unet)
224 | return count + 1
225 | elif hasattr(net_, 'children'):
226 | for net__ in net_.children():
227 | count = register_recr(net__, count, place_in_unet)
228 | return count
229 |
230 | cross_att_count = 0
231 | sub_nets = model.unet.named_children()
232 | for net in sub_nets:
233 | if "down" in net[0]:
234 | cross_att_count += register_recr(net[1], 0, "down")
235 | elif "up" in net[0]:
236 | cross_att_count += register_recr(net[1], 0, "up")
237 | elif "mid" in net[0]:
238 | cross_att_count += register_recr(net[1], 0, "mid")
239 |
240 | controller.num_att_layers = cross_att_count
241 |
242 |
243 | def get_word_inds(text: str, word_place: int, tokenizer):
244 | split_text = text.split(" ")
245 | if type(word_place) is str:
246 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
247 | elif type(word_place) is int:
248 | word_place = [word_place]
249 | out = []
250 | if len(word_place) > 0:
251 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
252 | cur_len, ptr = 0, 0
253 |
254 | for i in range(len(words_encode)):
255 | cur_len += len(words_encode[i])
256 | if ptr in word_place:
257 | out.append(i + 1)
258 | if cur_len >= len(split_text[ptr]):
259 | ptr += 1
260 | cur_len = 0
261 | return np.array(out)
262 |
263 |
264 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
265 | word_inds: Optional[torch.Tensor]=None):
266 | if type(bounds) is float:
267 | bounds = 0, bounds
268 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
269 | if word_inds is None:
270 | word_inds = torch.arange(alpha.shape[2])
271 | alpha[: start, prompt_ind, word_inds] = 0
272 | alpha[start: end, prompt_ind, word_inds] = 1
273 | alpha[end:, prompt_ind, word_inds] = 0
274 | return alpha
275 |
276 |
277 | def get_time_words_attention_alpha(prompts, num_steps,
278 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
279 | tokenizer, max_num_words=77):
280 | if type(cross_replace_steps) is not dict:
281 | cross_replace_steps = {"default_": cross_replace_steps}
282 | if "default_" not in cross_replace_steps:
283 | cross_replace_steps["default_"] = (0., 1.)
284 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
285 | for i in range(len(prompts) - 1):
286 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
287 | i)
288 | for key, item in cross_replace_steps.items():
289 | if key != "default_":
290 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
291 | for i, ind in enumerate(inds):
292 | if len(ind) > 0:
293 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
294 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
295 | return alpha_time_words
--------------------------------------------------------------------------------
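A small sketch of the image helpers above; the grey placeholder frames and labels are illustrative. view_images renders inline, so it assumes an IPython/Jupyter environment.

import numpy as np
from ptp_utils import text_under_image, view_images  # i.e. flux-sliders/utils/ptp_utils.py

frames = [np.full((256, 256, 3), 220, dtype=np.uint8) for _ in range(2)]  # two grey placeholders
labelled = [text_under_image(img, f"scale {s}") for img, s in zip(frames, (-2, 2))]
view_images(labelled, num_rows=1)  # stitches the frames into one grid and displays it
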
/images/main_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/images/main_figure.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bitsandbytes==0.41.1
2 | dadaptation==3.1
3 | diffusers==0.20.2
4 | ipython==8.7.0
5 | lion_pytorch==0.1.2
6 | lpips==0.1.4
7 | matplotlib==3.6.2
8 | numpy==1.23.5
9 | opencv_python==4.5.5.64
10 | opencv_python_headless==4.7.0.68
11 | pandas==1.5.2
12 | Pillow==10.1.0
13 | prodigyopt==1.0
14 | pydantic==2.6.3
15 | PyYAML==6.0.1
16 | Requests==2.31.0
17 | safetensors==0.3.1
18 | torch==2.0.1
19 | torchvision==0.15.2
20 | tqdm==4.64.1
21 | transformers==4.27.4
22 | wandb==0.12.21
23 | xformers==0.0.21
24 | accelerate==0.16.0
25 |
--------------------------------------------------------------------------------
/trainscripts/__init__.py:
--------------------------------------------------------------------------------
1 | # from textsliders import lora
--------------------------------------------------------------------------------
/trainscripts/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/trainscripts/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/__pycache__/config_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/config_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/__pycache__/debug_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/debug_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/__pycache__/lora.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/lora.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/__pycache__/model_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/model_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/__pycache__/prompt_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/prompt_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/__pycache__/train_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/imagesliders/__pycache__/train_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/imagesliders/config_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 |
3 | import yaml
4 |
5 | from pydantic import BaseModel
6 | import torch
7 |
8 | from lora import TRAINING_METHODS
9 |
10 | PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
11 | NETWORK_TYPES = Literal["lierla", "c3lier"]
12 |
13 |
14 | class PretrainedModelConfig(BaseModel):
15 | name_or_path: str
16 | v2: bool = False
17 | v_pred: bool = False
18 |
19 | clip_skip: Optional[int] = None
20 |
21 |
22 | class NetworkConfig(BaseModel):
23 | type: NETWORK_TYPES = "lierla"
24 | rank: int = 4
25 | alpha: float = 1.0
26 |
27 | training_method: TRAINING_METHODS = "full"
28 |
29 |
30 | class TrainConfig(BaseModel):
31 | precision: PRECISION_TYPES = "bfloat16"
32 | noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
33 |
34 | iterations: int = 500
35 | lr: float = 1e-4
36 | optimizer: str = "adamw"
37 | optimizer_args: str = ""
38 | lr_scheduler: str = "constant"
39 |
40 | max_denoising_steps: int = 50
41 |
42 |
43 | class SaveConfig(BaseModel):
44 | name: str = "untitled"
45 | path: str = "./output"
46 | per_steps: int = 200
47 | precision: PRECISION_TYPES = "float32"
48 |
49 |
50 | class LoggingConfig(BaseModel):
51 | use_wandb: bool = False
52 |
53 | verbose: bool = False
54 |
55 |
56 | class OtherConfig(BaseModel):
57 | use_xformers: bool = False
58 |
59 |
60 | class RootConfig(BaseModel):
61 | prompts_file: str
62 | pretrained_model: PretrainedModelConfig
63 |
64 | network: NetworkConfig
65 |
66 | train: Optional[TrainConfig]
67 |
68 | save: Optional[SaveConfig]
69 |
70 | logging: Optional[LoggingConfig]
71 |
72 | other: Optional[OtherConfig]
73 |
74 |
75 | def parse_precision(precision: str) -> torch.dtype:
76 | if precision == "fp32" or precision == "float32":
77 | return torch.float32
78 | elif precision == "fp16" or precision == "float16":
79 | return torch.float16
80 | elif precision == "bf16" or precision == "bfloat16":
81 | return torch.bfloat16
82 |
83 | raise ValueError(f"Invalid precision type: {precision}")
84 |
85 |
86 | def load_config_from_yaml(config_path: str) -> RootConfig:
87 | with open(config_path, "r") as f:
88 | config = yaml.load(f, Loader=yaml.FullLoader)
89 |
90 | root = RootConfig(**config)
91 |
92 | if root.train is None:
93 | root.train = TrainConfig()
94 |
95 | if root.save is None:
96 | root.save = SaveConfig()
97 |
98 | if root.logging is None:
99 | root.logging = LoggingConfig()
100 |
101 | if root.other is None:
102 | root.other = OtherConfig()
103 |
104 | return root
105 |
--------------------------------------------------------------------------------
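A minimal sketch of consuming such a config from Python; the path points at the config.yaml shipped in this folder and the printed fields are just examples.

from config_util import load_config_from_yaml, parse_precision  # trainscripts/imagesliders/config_util.py

config = load_config_from_yaml("trainscripts/imagesliders/data/config.yaml")
print(config.pretrained_model.name_or_path, config.network.rank, config.train.iterations)

weight_dtype = parse_precision(config.train.precision)  # e.g. torch.bfloat16
save_dtype = parse_precision(config.save.precision)
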
/trainscripts/imagesliders/data/config-xl.yaml:
--------------------------------------------------------------------------------
1 | prompts_file: "trainscripts/imagesliders/data/prompts-xl.yaml"
2 | pretrained_model:
3 | name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
4 | v2: false # true if model is v2.x
5 | v_pred: false # true if model uses v-prediction
6 | network:
7 |   type: "c3lier" # "lierla" or "c3lier"
8 | rank: 4
9 | alpha: 1.0
10 | training_method: "noxattn"
11 | train:
12 | precision: "bfloat16"
13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14 | iterations: 1000
15 | lr: 0.0002
16 | optimizer: "AdamW"
17 | lr_scheduler: "constant"
18 | max_denoising_steps: 50
19 | save:
20 | name: "temp"
21 | path: "./models"
22 | per_steps: 500
23 | precision: "bfloat16"
24 | logging:
25 | use_wandb: false
26 | verbose: false
27 | other:
28 | use_xformers: true
--------------------------------------------------------------------------------
/trainscripts/imagesliders/data/config.yaml:
--------------------------------------------------------------------------------
1 | prompts_file: "trainscripts/imagesliders/data/prompts.yaml"
2 | pretrained_model:
3 | name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
4 | v2: false # true if model is v2.x
5 | v_pred: false # true if model uses v-prediction
6 | network:
7 |   type: "c3lier" # "lierla" or "c3lier"
8 | rank: 4
9 | alpha: 1.0
10 | training_method: "noxattn"
11 | train:
12 | precision: "bfloat16"
13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14 | iterations: 1000
15 | lr: 0.0002
16 | optimizer: "AdamW"
17 | lr_scheduler: "constant"
18 | max_denoising_steps: 50
19 | save:
20 | name: "temp"
21 | path: "./models"
22 | per_steps: 500
23 | precision: "bfloat16"
24 | logging:
25 | use_wandb: false
26 | verbose: false
27 | other:
28 | use_xformers: true
--------------------------------------------------------------------------------
/trainscripts/imagesliders/data/prompts.yaml:
--------------------------------------------------------------------------------
1 | # - target: "person" # what word for erasing the positive concept from
2 | # positive: "person, very old" # concept to erase
3 | # unconditional: "person" # word to take the difference from the positive concept
4 | # neutral: "person" # starting point for conditioning the target
5 | # action: "enhance" # erase or enhance
6 | # guidance_scale: 4
7 | # resolution: 512
8 | # dynamic_resolution: false
9 | # batch_size: 1
10 | - target: "" # what word for erasing the positive concept from
11 | positive: "" # concept to erase
12 | unconditional: "" # word to take the difference from the positive concept
13 | neutral: "" # starting point for conditioning the target
14 | action: "enhance" # erase or enhance
15 | guidance_scale: 1
16 | resolution: 512
17 | dynamic_resolution: false
18 | batch_size: 1
19 | # - target: "" # what word for erasing the positive concept from
20 | # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
21 | # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
22 | # neutral: "" # starting point for conditioning the target
23 | # action: "enhance" # erase or enhance
24 | # guidance_scale: 4
25 | # resolution: 512
26 | # dynamic_resolution: false
27 | # batch_size: 1
28 | # - target: "food" # what word for erasing the positive concept from
29 | # positive: "food, expensive and fine dining" # concept to erase
30 | # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
31 | # neutral: "food" # starting point for conditioning the target
32 | # action: "enhance" # erase or enhance
33 | # guidance_scale: 4
34 | # resolution: 512
35 | # dynamic_resolution: false
36 | # batch_size: 1
37 | # - target: "room" # what word for erasing the positive concept from
38 | # positive: "room, dirty disorganised and cluttered" # concept to erase
39 | # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
40 | # neutral: "room" # starting point for conditioning the target
41 | # action: "enhance" # erase or enhance
42 | # guidance_scale: 4
43 | # resolution: 512
44 | # dynamic_resolution: false
45 | # batch_size: 1
46 | # - target: "male person" # what word for erasing the positive concept from
47 | # positive: "male person, with a surprised look" # concept to erase
48 | # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
49 | # neutral: "male person" # starting point for conditioning the target
50 | # action: "enhance" # erase or enhance
51 | # guidance_scale: 4
52 | # resolution: 512
53 | # dynamic_resolution: false
54 | # batch_size: 1
55 | # - target: "female person" # what word for erasing the positive concept from
56 | # positive: "female person, with a surprised look" # concept to erase
57 | # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
58 | # neutral: "female person" # starting point for conditioning the target
59 | # action: "enhance" # erase or enhance
60 | # guidance_scale: 4
61 | # resolution: 512
62 | # dynamic_resolution: false
63 | # batch_size: 1
64 | # - target: "sky" # what word for erasing the positive concept from
65 | # positive: "peaceful sky" # concept to erase
66 | # unconditional: "sky" # word to take the difference from the positive concept
67 | # neutral: "sky" # starting point for conditioning the target
68 | # action: "enhance" # erase or enhance
69 | # guidance_scale: 4
70 | # resolution: 512
71 | # dynamic_resolution: false
72 | # batch_size: 1
73 | # - target: "sky" # what word for erasing the positive concept from
74 | # positive: "chaotic dark sky" # concept to erase
75 | # unconditional: "sky" # word to take the difference from the positive concept
76 | # neutral: "sky" # starting point for conditioning the target
77 | # action: "erase" # erase or enhance
78 | # guidance_scale: 4
79 | # resolution: 512
80 | # dynamic_resolution: false
81 | # batch_size: 1
82 | # - target: "person" # what word for erasing the positive concept from
83 | # positive: "person, very young" # concept to erase
84 | # unconditional: "person" # word to take the difference from the positive concept
85 | # neutral: "person" # starting point for conditioning the target
86 | # action: "erase" # erase or enhance
87 | # guidance_scale: 4
88 | # resolution: 512
89 | # dynamic_resolution: false
90 | # batch_size: 1
91 | # overweight
92 | # - target: "art" # what word for erasing the positive concept from
93 | # positive: "realistic art" # concept to erase
94 | # unconditional: "art" # word to take the difference from the positive concept
95 | # neutral: "art" # starting point for conditioning the target
96 | # action: "enhance" # erase or enhance
97 | # guidance_scale: 4
98 | # resolution: 512
99 | # dynamic_resolution: false
100 | # batch_size: 1
101 | # - target: "art" # what word for erasing the positive concept from
102 | # positive: "abstract art" # concept to erase
103 | # unconditional: "art" # word to take the difference from the positive concept
104 | # neutral: "art" # starting point for conditioning the target
105 | # action: "erase" # erase or enhance
106 | # guidance_scale: 4
107 | # resolution: 512
108 | # dynamic_resolution: false
109 | # batch_size: 1
110 | # sky
111 | # - target: "weather" # what word for erasing the positive concept from
112 | # positive: "bright pleasant weather" # concept to erase
113 | # unconditional: "weather" # word to take the difference from the positive concept
114 | # neutral: "weather" # starting point for conditioning the target
115 | # action: "enhance" # erase or enhance
116 | # guidance_scale: 4
117 | # resolution: 512
118 | # dynamic_resolution: false
119 | # batch_size: 1
120 | # - target: "weather" # what word for erasing the positive concept from
121 | # positive: "dark gloomy weather" # concept to erase
122 | # unconditional: "weather" # word to take the difference from the positive concept
123 | # neutral: "weather" # starting point for conditioning the target
124 | # action: "erase" # erase or enhance
125 | # guidance_scale: 4
126 | # resolution: 512
127 | # dynamic_resolution: false
128 | # batch_size: 1
129 | # hair
130 | # - target: "person" # what word for erasing the positive concept from
131 | # positive: "person with long hair" # concept to erase
132 | # unconditional: "person" # word to take the difference from the positive concept
133 | # neutral: "person" # starting point for conditioning the target
134 | # action: "enhance" # erase or enhance
135 | # guidance_scale: 4
136 | # resolution: 512
137 | # dynamic_resolution: false
138 | # batch_size: 1
139 | # - target: "person" # what word for erasing the positive concept from
140 | # positive: "person with short hair" # concept to erase
141 | # unconditional: "person" # word to take the difference from the positive concept
142 | # neutral: "person" # starting point for conditioning the target
143 | # action: "erase" # erase or enhance
144 | # guidance_scale: 4
145 | # resolution: 512
146 | # dynamic_resolution: false
147 | # batch_size: 1
148 | # - target: "girl" # what word for erasing the positive concept from
149 | # positive: "baby girl" # concept to erase
150 | # unconditional: "girl" # word to take the difference from the positive concept
151 | # neutral: "girl" # starting point for conditioning the target
152 | # action: "enhance" # erase or enhance
153 | # guidance_scale: -4
154 | # resolution: 512
155 | # dynamic_resolution: false
156 | # batch_size: 1
157 | # - target: "boy" # what word for erasing the positive concept from
158 | # positive: "old man" # concept to erase
159 | # unconditional: "boy" # word to take the difference from the positive concept
160 | # neutral: "boy" # starting point for conditioning the target
161 | # action: "enhance" # erase or enhance
162 | # guidance_scale: 4
163 | # resolution: 512
164 | # dynamic_resolution: false
165 | # batch_size: 1
166 | # - target: "boy" # what word for erasing the positive concept from
167 | # positive: "baby boy" # concept to erase
168 | # unconditional: "boy" # word to take the difference from the positive concept
169 | # neutral: "boy" # starting point for conditioning the target
170 | # action: "enhance" # erase or enhance
171 | # guidance_scale: -4
172 | # resolution: 512
173 | # dynamic_resolution: false
174 | # batch_size: 1
--------------------------------------------------------------------------------
/trainscripts/imagesliders/debug_util.py:
--------------------------------------------------------------------------------
1 | # For debugging...
2 |
3 | import torch
4 |
5 |
6 | def check_requires_grad(model: torch.nn.Module):
7 | for name, module in list(model.named_modules())[:5]:
8 | if len(list(module.parameters())) > 0:
9 | print(f"Module: {name}")
10 | for name, param in list(module.named_parameters())[:2]:
11 | print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12 |
13 |
14 | def check_training_mode(model: torch.nn.Module):
15 | for name, module in list(model.named_modules())[:5]:
16 | print(f"Module: {name}, Training Mode: {module.training}")
17 |
--------------------------------------------------------------------------------
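A tiny usage sketch for these checks; any torch module works, the small Sequential here is just a stand-in.

import torch.nn as nn
from debug_util import check_requires_grad, check_training_mode

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
check_requires_grad(model)   # prints requires_grad for the first few parameterised modules
check_training_mode(model)   # prints train/eval mode for the first few modules
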
/trainscripts/imagesliders/lora.py:
--------------------------------------------------------------------------------
1 | # ref:
2 | # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3 | # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4 |
5 | import os
6 | import math
7 | from typing import Optional, List, Type, Set, Literal
8 |
9 | import torch
10 | import torch.nn as nn
11 | from diffusers import UNet2DConditionModel
12 | from safetensors.torch import save_file
13 |
14 |
15 | UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16 | # "Transformer2DModel", # apparently this is the right one? # attn1, 2
17 | "Attention"
18 | ]
19 | UNET_TARGET_REPLACE_MODULE_CONV = [
20 | "ResnetBlock2D",
21 | "Downsample2D",
22 | "Upsample2D",
23 | # "DownBlock2D",
24 | # "UpBlock2D"
25 | ] # locon, c3lier
26 |
27 | LORA_PREFIX_UNET = "lora_unet"
28 |
29 | DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
30 |
31 | TRAINING_METHODS = Literal[
32 | "noxattn", # train all layers except x-attns and time_embed layers
33 | "innoxattn", # train all layers except self attention layers
34 | "selfattn", # ESD-u, train only self attention layers
35 | "xattn", # ESD-x, train only x attention layers
36 | "full", # train all layers
37 | "xattn-strict", # q and k values
38 | "noxattn-hspace",
39 | "noxattn-hspace-last",
40 | # "xlayer",
41 | # "outxattn",
42 | # "outsattn",
43 | # "inxattn",
44 | # "inmidsattn",
45 | # "selflayer",
46 | ]
47 |
48 |
49 | class LoRAModule(nn.Module):
50 | """
51 | replaces forward method of the original Linear, instead of replacing the original Linear module.
52 | """
53 |
54 | def __init__(
55 | self,
56 | lora_name,
57 | org_module: nn.Module,
58 | multiplier=1.0,
59 | lora_dim=4,
60 | alpha=1,
61 | ):
62 | """if alpha == 0 or None, alpha is rank (no scaling)."""
63 | super().__init__()
64 | self.lora_name = lora_name
65 | self.lora_dim = lora_dim
66 |
67 | if "Linear" in org_module.__class__.__name__:
68 | in_dim = org_module.in_features
69 | out_dim = org_module.out_features
70 | self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
71 | self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
72 |
73 |         elif "Conv" in org_module.__class__.__name__: # just in case
74 | in_dim = org_module.in_channels
75 | out_dim = org_module.out_channels
76 |
77 | self.lora_dim = min(self.lora_dim, in_dim, out_dim)
78 | if self.lora_dim != lora_dim:
79 | print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
80 |
81 | kernel_size = org_module.kernel_size
82 | stride = org_module.stride
83 | padding = org_module.padding
84 | self.lora_down = nn.Conv2d(
85 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
86 | )
87 | self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
88 |
89 | if type(alpha) == torch.Tensor:
90 | alpha = alpha.detach().numpy()
91 | alpha = lora_dim if alpha is None or alpha == 0 else alpha
92 | self.scale = alpha / self.lora_dim
93 |         self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
94 |
95 | # same as microsoft's
96 | nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
97 | nn.init.zeros_(self.lora_up.weight)
98 |
99 | self.multiplier = multiplier
100 | self.org_module = org_module # remove in applying
101 |
102 | def apply_to(self):
103 | self.org_forward = self.org_module.forward
104 | self.org_module.forward = self.forward
105 | del self.org_module
106 |
107 | def forward(self, x):
108 | return (
109 | self.org_forward(x)
110 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
111 | )
112 |
113 |
114 | class LoRANetwork(nn.Module):
115 | def __init__(
116 | self,
117 | unet: UNet2DConditionModel,
118 | rank: int = 4,
119 | multiplier: float = 1.0,
120 | alpha: float = 1.0,
121 | train_method: TRAINING_METHODS = "full",
122 | ) -> None:
123 | super().__init__()
124 | self.lora_scale = 1
125 | self.multiplier = multiplier
126 | self.lora_dim = rank
127 | self.alpha = alpha
128 |
129 |         # LoRA only
130 | self.module = LoRAModule
131 |
132 |         # create LoRA modules for the U-Net
133 | self.unet_loras = self.create_modules(
134 | LORA_PREFIX_UNET,
135 | unet,
136 | DEFAULT_TARGET_REPLACE,
137 | self.lora_dim,
138 | self.multiplier,
139 | train_method=train_method,
140 | )
141 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
142 |
143 |         # assertion: make sure there are no duplicated LoRA names
144 | lora_names = set()
145 | for lora in self.unet_loras:
146 | assert (
147 | lora.lora_name not in lora_names
148 | ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
149 | lora_names.add(lora.lora_name)
150 |
151 |         # apply the LoRA modules
152 | for lora in self.unet_loras:
153 | lora.apply_to()
154 | self.add_module(
155 | lora.lora_name,
156 | lora,
157 | )
158 |
159 | del unet
160 |
161 | torch.cuda.empty_cache()
162 |
163 | def create_modules(
164 | self,
165 | prefix: str,
166 | root_module: nn.Module,
167 | target_replace_modules: List[str],
168 | rank: int,
169 | multiplier: float,
170 | train_method: TRAINING_METHODS,
171 | ) -> list:
172 | loras = []
173 | names = []
174 | for name, module in root_module.named_modules():
175 |             if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # train everything except cross-attention and time-embedding layers
176 | if "attn2" in name or "time_embed" in name:
177 | continue
178 |             elif train_method == "innoxattn": # train everything except cross-attention layers
179 | if "attn2" in name:
180 | continue
181 |             elif train_method == "selfattn": # train only self-attention layers
182 | if "attn1" not in name:
183 | continue
184 |             elif train_method == "xattn" or train_method == "xattn-strict": # train only cross-attention layers
185 | if "attn2" not in name:
186 | continue
187 |             elif train_method == "full": # train all layers
188 | pass
189 | else:
190 | raise NotImplementedError(
191 | f"train_method: {train_method} is not implemented."
192 | )
193 | if module.__class__.__name__ in target_replace_modules:
194 | for child_name, child_module in module.named_modules():
195 | if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
196 | if train_method == 'xattn-strict':
197 | if 'out' in child_name:
198 | continue
199 | if train_method == 'noxattn-hspace':
200 | if 'mid_block' not in name:
201 | continue
202 | if train_method == 'noxattn-hspace-last':
203 | if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
204 | continue
205 | lora_name = prefix + "." + name + "." + child_name
206 | lora_name = lora_name.replace(".", "_")
207 | # print(f"{lora_name}")
208 | lora = self.module(
209 | lora_name, child_module, multiplier, rank, self.alpha
210 | )
211 | # print(name, child_name)
212 | # print(child_module.weight.shape)
213 | loras.append(lora)
214 | names.append(lora_name)
215 | # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
216 | return loras
217 |
218 | def prepare_optimizer_params(self):
219 | all_params = []
220 |
221 |         if self.unet_loras: # in practice this is the only case
222 | params = []
223 | [params.extend(lora.parameters()) for lora in self.unet_loras]
224 | param_data = {"params": params}
225 | all_params.append(param_data)
226 |
227 | return all_params
228 |
229 | def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
230 | state_dict = self.state_dict()
231 |
232 | if dtype is not None:
233 | for key in list(state_dict.keys()):
234 | v = state_dict[key]
235 | v = v.detach().clone().to("cpu").to(dtype)
236 | state_dict[key] = v
237 |
238 | # for key in list(state_dict.keys()):
239 | # if not key.startswith("lora"):
240 |         #         # exclude keys other than lora
241 | # del state_dict[key]
242 |
243 | if os.path.splitext(file)[1] == ".safetensors":
244 | save_file(state_dict, file, metadata)
245 | else:
246 | torch.save(state_dict, file)
247 | def set_lora_slider(self, scale):
248 | self.lora_scale = scale
249 |
250 | def __enter__(self):
251 | for lora in self.unet_loras:
252 | lora.multiplier = 1.0 * self.lora_scale
253 |
254 | def __exit__(self, exc_type, exc_value, tb):
255 | for lora in self.unet_loras:
256 | lora.multiplier = 0
257 |
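258 | # Usage sketch (illustrative only, not part of the training scripts): assuming a
259 | # diffusers UNet2DConditionModel `unet` and prepared `latents`, `timestep`, and
260 | # `text_embeds`, a slider network could be created, scaled, and applied like this:
261 | #
262 | #   network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0, train_method="noxattn")
263 | #   network.set_lora_slider(scale=2.0)  # slider strength
264 | #   with network:  # __enter__ sets every LoRA multiplier to the slider scale
265 | #       noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeds).sample
266 | #   network.save_weights("slider.safetensors", dtype=torch.float16)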
--------------------------------------------------------------------------------
/trainscripts/imagesliders/model_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Union, Optional
2 |
3 | import torch
4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5 | from diffusers import (
6 | UNet2DConditionModel,
7 | SchedulerMixin,
8 | StableDiffusionPipeline,
9 | StableDiffusionXLPipeline,
10 | AutoencoderKL,
11 | )
12 | from diffusers.schedulers import (
13 | DDIMScheduler,
14 | DDPMScheduler,
15 | LMSDiscreteScheduler,
16 | EulerAncestralDiscreteScheduler,
17 | )
18 |
19 |
20 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
21 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
22 |
23 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
24 |
25 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
26 |
27 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
28 |
29 |
30 | def load_diffusers_model(
31 | pretrained_model_name_or_path: str,
32 | v2: bool = False,
33 | clip_skip: Optional[int] = None,
34 | weight_dtype: torch.dtype = torch.float32,
35 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
36 |     # the VAE is also loaded and returned (the image sliders need it)
37 |
38 | if v2:
39 | tokenizer = CLIPTokenizer.from_pretrained(
40 | TOKENIZER_V2_MODEL_NAME,
41 | subfolder="tokenizer",
42 | torch_dtype=weight_dtype,
43 | cache_dir=DIFFUSERS_CACHE_DIR,
44 | )
45 | text_encoder = CLIPTextModel.from_pretrained(
46 | pretrained_model_name_or_path,
47 | subfolder="text_encoder",
48 | # default is clip skip 2
49 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
50 | torch_dtype=weight_dtype,
51 | cache_dir=DIFFUSERS_CACHE_DIR,
52 | )
53 | else:
54 | tokenizer = CLIPTokenizer.from_pretrained(
55 | TOKENIZER_V1_MODEL_NAME,
56 | subfolder="tokenizer",
57 | torch_dtype=weight_dtype,
58 | cache_dir=DIFFUSERS_CACHE_DIR,
59 | )
60 | text_encoder = CLIPTextModel.from_pretrained(
61 | pretrained_model_name_or_path,
62 | subfolder="text_encoder",
63 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
64 | torch_dtype=weight_dtype,
65 | cache_dir=DIFFUSERS_CACHE_DIR,
66 | )
67 |
68 | unet = UNet2DConditionModel.from_pretrained(
69 | pretrained_model_name_or_path,
70 | subfolder="unet",
71 | torch_dtype=weight_dtype,
72 | cache_dir=DIFFUSERS_CACHE_DIR,
73 | )
74 |
75 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
76 |
77 | return tokenizer, text_encoder, unet, vae
78 |
79 |
80 | def load_checkpoint_model(
81 | checkpoint_path: str,
82 | v2: bool = False,
83 | clip_skip: Optional[int] = None,
84 | weight_dtype: torch.dtype = torch.float32,
85 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
86 | pipe = StableDiffusionPipeline.from_ckpt(
87 | checkpoint_path,
88 | upcast_attention=True if v2 else False,
89 | torch_dtype=weight_dtype,
90 | cache_dir=DIFFUSERS_CACHE_DIR,
91 | )
92 |
93 | unet = pipe.unet
94 | tokenizer = pipe.tokenizer
95 | text_encoder = pipe.text_encoder
96 | vae = pipe.vae
97 | if clip_skip is not None:
98 | if v2:
99 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
100 | else:
101 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
102 |
103 | del pipe
104 |
105 | return tokenizer, text_encoder, unet, vae
106 |
107 |
108 | def load_models(
109 | pretrained_model_name_or_path: str,
110 | scheduler_name: AVAILABLE_SCHEDULERS,
111 | v2: bool = False,
112 | v_pred: bool = False,
113 | weight_dtype: torch.dtype = torch.float32,
114 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin, AutoencoderKL]:
115 | if pretrained_model_name_or_path.endswith(
116 | ".ckpt"
117 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
118 | tokenizer, text_encoder, unet, vae = load_checkpoint_model(
119 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
120 | )
121 | else: # diffusers
122 | tokenizer, text_encoder, unet, vae = load_diffusers_model(
123 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
124 | )
125 |
126 |     # the VAE is returned along with the scheduler
127 |
128 | scheduler = create_noise_scheduler(
129 | scheduler_name,
130 | prediction_type="v_prediction" if v_pred else "epsilon",
131 | )
132 |
133 | return tokenizer, text_encoder, unet, scheduler, vae
134 |
135 |
136 | def load_diffusers_model_xl(
137 | pretrained_model_name_or_path: str,
138 | weight_dtype: torch.dtype = torch.float32,
139 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
140 |     # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet, vae
141 |
142 | tokenizers = [
143 | CLIPTokenizer.from_pretrained(
144 | pretrained_model_name_or_path,
145 | subfolder="tokenizer",
146 | torch_dtype=weight_dtype,
147 | cache_dir=DIFFUSERS_CACHE_DIR,
148 | ),
149 | CLIPTokenizer.from_pretrained(
150 | pretrained_model_name_or_path,
151 | subfolder="tokenizer_2",
152 | torch_dtype=weight_dtype,
153 | cache_dir=DIFFUSERS_CACHE_DIR,
154 | pad_token_id=0, # same as open clip
155 | ),
156 | ]
157 |
158 | text_encoders = [
159 | CLIPTextModel.from_pretrained(
160 | pretrained_model_name_or_path,
161 | subfolder="text_encoder",
162 | torch_dtype=weight_dtype,
163 | cache_dir=DIFFUSERS_CACHE_DIR,
164 | ),
165 | CLIPTextModelWithProjection.from_pretrained(
166 | pretrained_model_name_or_path,
167 | subfolder="text_encoder_2",
168 | torch_dtype=weight_dtype,
169 | cache_dir=DIFFUSERS_CACHE_DIR,
170 | ),
171 | ]
172 |
173 | unet = UNet2DConditionModel.from_pretrained(
174 | pretrained_model_name_or_path,
175 | subfolder="unet",
176 | torch_dtype=weight_dtype,
177 | cache_dir=DIFFUSERS_CACHE_DIR,
178 | )
179 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
180 | return tokenizers, text_encoders, unet, vae
181 |
182 |
183 | def load_checkpoint_model_xl(
184 | checkpoint_path: str,
185 | weight_dtype: torch.dtype = torch.float32,
186 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
187 | pipe = StableDiffusionXLPipeline.from_single_file(
188 | checkpoint_path,
189 | torch_dtype=weight_dtype,
190 | cache_dir=DIFFUSERS_CACHE_DIR,
191 | )
192 |
193 | unet = pipe.unet
194 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
195 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
196 | if len(text_encoders) == 2:
197 | text_encoders[1].pad_token_id = 0
198 | vae = pipe.vae
199 | del pipe
200 |
201 | return tokenizers, text_encoders, unet, vae
202 |
203 |
204 | def load_models_xl(
205 | pretrained_model_name_or_path: str,
206 | scheduler_name: AVAILABLE_SCHEDULERS,
207 | weight_dtype: torch.dtype = torch.float32,
208 | ) -> tuple[
209 | list[CLIPTokenizer],
210 | list[SDXL_TEXT_ENCODER_TYPE],
211 | UNet2DConditionModel,
212 |     SchedulerMixin, AutoencoderKL,
213 | ]:
214 | if pretrained_model_name_or_path.endswith(
215 | ".ckpt"
216 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
217 | (
218 | tokenizers,
219 | text_encoders,
220 | unet,
221 | vae
222 | ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
223 | else: # diffusers
224 | (
225 | tokenizers,
226 | text_encoders,
227 | unet,
228 | vae
229 | ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
230 |
231 | scheduler = create_noise_scheduler(scheduler_name)
232 |
233 | return tokenizers, text_encoders, unet, scheduler, vae
234 |
235 |
236 | def create_noise_scheduler(
237 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
238 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
239 | ) -> SchedulerMixin:
240 |     # Honestly not sure which scheduler is best. The original implementation let you choose DDIM, DDPM, or LMS, but it is unclear which works best.
241 |
242 | name = scheduler_name.lower().replace(" ", "_")
243 | if name == "ddim":
244 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
245 | scheduler = DDIMScheduler(
246 | beta_start=0.00085,
247 | beta_end=0.012,
248 | beta_schedule="scaled_linear",
249 | num_train_timesteps=1000,
250 | clip_sample=False,
251 |             prediction_type=prediction_type, # is this right?
252 | )
253 | elif name == "ddpm":
254 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
255 | scheduler = DDPMScheduler(
256 | beta_start=0.00085,
257 | beta_end=0.012,
258 | beta_schedule="scaled_linear",
259 | num_train_timesteps=1000,
260 | clip_sample=False,
261 | prediction_type=prediction_type,
262 | )
263 | elif name == "lms":
264 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
265 | scheduler = LMSDiscreteScheduler(
266 | beta_start=0.00085,
267 | beta_end=0.012,
268 | beta_schedule="scaled_linear",
269 | num_train_timesteps=1000,
270 | prediction_type=prediction_type,
271 | )
272 | elif name == "euler_a":
273 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
274 | scheduler = EulerAncestralDiscreteScheduler(
275 | beta_start=0.00085,
276 | beta_end=0.012,
277 | beta_schedule="scaled_linear",
278 | num_train_timesteps=1000,
279 | prediction_type=prediction_type,
280 | )
281 | else:
282 | raise ValueError(f"Unknown scheduler name: {name}")
283 |
284 | return scheduler
285 |
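286 | # Usage sketch (illustrative only): loading an SD v1.x model plus noise scheduler
287 | # with the helpers above; the model name is just an example.
288 | #
289 | #   tokenizer, text_encoder, unet, scheduler, vae = load_models(
290 | #       "CompVis/stable-diffusion-v1-4",
291 | #       scheduler_name="ddim",
292 | #       weight_dtype=torch.float16,
293 | #   )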
--------------------------------------------------------------------------------
/trainscripts/imagesliders/prompt_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional, Union, List
2 |
3 | import yaml
4 | from pathlib import Path
5 |
6 |
7 | from pydantic import BaseModel, root_validator
8 | import torch
9 | import copy
10 |
11 | ACTION_TYPES = Literal[
12 | "erase",
13 | "enhance",
14 | ]
15 |
16 |
17 | # XL needs two kinds of embeddings (text and pooled)
18 | class PromptEmbedsXL:
19 | text_embeds: torch.FloatTensor
20 | pooled_embeds: torch.FloatTensor
21 |
22 | def __init__(self, *args) -> None:
23 | self.text_embeds = args[0]
24 | self.pooled_embeds = args[1]
25 |
26 |
27 | # SDv1.x and SDv2.x use a FloatTensor; XL uses PromptEmbedsXL
28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29 |
30 |
31 | class PromptEmbedsCache: # cached so embeddings can be reused
32 | prompts: dict[str, PROMPT_EMBEDDING] = {}
33 |
34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35 | self.prompts[__name] = __value
36 |
37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38 | if __name in self.prompts:
39 | return self.prompts[__name]
40 | else:
41 | return None
42 |
43 |
44 | class PromptSettings(BaseModel): # the entries from the prompts yaml file
45 | target: str
46 |     positive: Optional[str] = None # if None, target will be used
47 |     unconditional: str = "" # default is ""
48 |     neutral: Optional[str] = None # if None, unconditional will be used
49 | action: ACTION_TYPES = "erase" # default is "erase"
50 | guidance_scale: float = 1.0 # default is 1.0
51 | resolution: int = 512 # default is 512
52 | dynamic_resolution: bool = False # default is False
53 | batch_size: int = 1 # default is 1
54 | dynamic_crops: bool = False # default is False. only used when model is XL
55 |
56 | @root_validator(pre=True)
57 | def fill_prompts(cls, values):
58 | keys = values.keys()
59 | if "target" not in keys:
60 | raise ValueError("target must be specified")
61 | if "positive" not in keys:
62 | values["positive"] = values["target"]
63 | if "unconditional" not in keys:
64 | values["unconditional"] = ""
65 | if "neutral" not in keys:
66 | values["neutral"] = values["unconditional"]
67 |
68 | return values
69 |
70 |
71 | class PromptEmbedsPair:
72 |     target: PROMPT_EMBEDDING # embedding we do not want to generate the concept from
73 |     positive: PROMPT_EMBEDDING # embedding that generates the concept
74 |     unconditional: PROMPT_EMBEDDING # unconditional embedding (default should be empty)
75 |     neutral: PROMPT_EMBEDDING # base condition embedding (default should be empty)
76 |
77 | guidance_scale: float
78 | resolution: int
79 | dynamic_resolution: bool
80 | batch_size: int
81 | dynamic_crops: bool
82 |
83 | loss_fn: torch.nn.Module
84 | action: ACTION_TYPES
85 |
86 | def __init__(
87 | self,
88 | loss_fn: torch.nn.Module,
89 | target: PROMPT_EMBEDDING,
90 | positive: PROMPT_EMBEDDING,
91 | unconditional: PROMPT_EMBEDDING,
92 | neutral: PROMPT_EMBEDDING,
93 | settings: PromptSettings,
94 | ) -> None:
95 | self.loss_fn = loss_fn
96 | self.target = target
97 | self.positive = positive
98 | self.unconditional = unconditional
99 | self.neutral = neutral
100 |
101 | self.guidance_scale = settings.guidance_scale
102 | self.resolution = settings.resolution
103 | self.dynamic_resolution = settings.dynamic_resolution
104 | self.batch_size = settings.batch_size
105 | self.dynamic_crops = settings.dynamic_crops
106 | self.action = settings.action
107 |
108 | def _erase(
109 | self,
110 | target_latents: torch.FloatTensor, # "van gogh"
111 | positive_latents: torch.FloatTensor, # "van gogh"
112 | unconditional_latents: torch.FloatTensor, # ""
113 | neutral_latents: torch.FloatTensor, # ""
114 | ) -> torch.FloatTensor:
115 |         """Push the target latents away from the positive concept."""
116 | return self.loss_fn(
117 | target_latents,
118 | neutral_latents
119 | - self.guidance_scale * (positive_latents - unconditional_latents)
120 | )
121 |
122 |
123 | def _enhance(
124 | self,
125 | target_latents: torch.FloatTensor, # "van gogh"
126 | positive_latents: torch.FloatTensor, # "van gogh"
127 | unconditional_latents: torch.FloatTensor, # ""
128 | neutral_latents: torch.FloatTensor, # ""
129 | ):
130 |         """Push the target latents toward the positive concept."""
131 | return self.loss_fn(
132 | target_latents,
133 | neutral_latents
134 | + self.guidance_scale * (positive_latents - unconditional_latents)
135 | )
136 |
137 | def loss(
138 | self,
139 | **kwargs,
140 | ):
141 | if self.action == "erase":
142 | return self._erase(**kwargs)
143 |
144 | elif self.action == "enhance":
145 | return self._enhance(**kwargs)
146 |
147 | else:
148 | raise ValueError("action must be erase or enhance")
149 |
150 |
151 | def load_prompts_from_yaml(path, attributes = []):
152 | with open(path, "r") as f:
153 | prompts = yaml.safe_load(f)
154 | print(prompts)
155 | if len(prompts) == 0:
156 | raise ValueError("prompts file is empty")
157 | if len(attributes)!=0:
158 | newprompts = []
159 | for i in range(len(prompts)):
160 | for att in attributes:
161 | copy_ = copy.deepcopy(prompts[i])
162 | copy_['target'] = att + ' ' + copy_['target']
163 | copy_['positive'] = att + ' ' + copy_['positive']
164 | copy_['neutral'] = att + ' ' + copy_['neutral']
165 | copy_['unconditional'] = att + ' ' + copy_['unconditional']
166 | newprompts.append(copy_)
167 | else:
168 | newprompts = copy.deepcopy(prompts)
169 |
170 | print(newprompts)
171 | print(len(prompts), len(newprompts))
172 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
173 |
174 | return prompt_settings
175 |
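176 | # Usage sketch (illustrative only): loading prompt settings and computing the
177 | # slider loss for one prompt pair; the embeddings and latents are assumed to be
178 | # prepared elsewhere by the training script.
179 | #
180 | #   settings = load_prompts_from_yaml("trainscripts/textsliders/data/prompts.yaml")
181 | #   pair = PromptEmbedsPair(torch.nn.MSELoss(), target_emb, positive_emb,
182 | #                           unconditional_emb, neutral_emb, settings[0])
183 | #   loss = pair.loss(target_latents=target_lat, positive_latents=positive_lat,
184 | #                    unconditional_latents=uncond_lat, neutral_latents=neutral_lat)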
--------------------------------------------------------------------------------
/trainscripts/textsliders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__init__.py
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/config_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/config_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/debug_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/debug_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/lora.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/lora.cpython-310.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/lora.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/lora.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/model_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/model_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/prompt_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/prompt_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/ptp_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/ptp_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/__pycache__/train_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/sliders-for-windows/60eb565b0601b5bb5490da4633987f04fa16fadc/trainscripts/textsliders/__pycache__/train_util.cpython-39.pyc
--------------------------------------------------------------------------------
/trainscripts/textsliders/config_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 |
3 | import yaml
4 |
5 | from pydantic import BaseModel
6 | import torch
7 |
8 | from lora import TRAINING_METHODS
9 |
10 | PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
11 | NETWORK_TYPES = Literal["lierla", "c3lier"]
12 |
13 |
14 | class PretrainedModelConfig(BaseModel):
15 | name_or_path: str
16 | v2: bool = False
17 | v_pred: bool = False
18 |
19 | clip_skip: Optional[int] = None
20 |
21 |
22 | class NetworkConfig(BaseModel):
23 | type: NETWORK_TYPES = "lierla"
24 | rank: int = 4
25 | alpha: float = 1.0
26 |
27 | training_method: TRAINING_METHODS = "full"
28 |
29 |
30 | class TrainConfig(BaseModel):
31 | precision: PRECISION_TYPES = "bfloat16"
32 | noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
33 |
34 | iterations: int = 500
35 | lr: float = 1e-4
36 | optimizer: str = "adamw"
37 | optimizer_args: str = ""
38 | lr_scheduler: str = "constant"
39 |
40 | max_denoising_steps: int = 50
41 |
42 |
43 | class SaveConfig(BaseModel):
44 | name: str = "untitled"
45 | path: str = "./output"
46 | per_steps: int = 200
47 | precision: PRECISION_TYPES = "float32"
48 |
49 |
50 | class LoggingConfig(BaseModel):
51 | use_wandb: bool = False
52 |
53 | verbose: bool = False
54 |
55 |
56 | class OtherConfig(BaseModel):
57 | use_xformers: bool = False
58 |
59 |
60 | class RootConfig(BaseModel):
61 | prompts_file: str
62 | pretrained_model: PretrainedModelConfig
63 |
64 | network: NetworkConfig
65 |
66 | train: Optional[TrainConfig]
67 |
68 | save: Optional[SaveConfig]
69 |
70 | logging: Optional[LoggingConfig]
71 |
72 | other: Optional[OtherConfig]
73 |
74 |
75 | def parse_precision(precision: str) -> torch.dtype:
76 | if precision == "fp32" or precision == "float32":
77 | return torch.float32
78 | elif precision == "fp16" or precision == "float16":
79 | return torch.float16
80 | elif precision == "bf16" or precision == "bfloat16":
81 | return torch.bfloat16
82 |
83 | raise ValueError(f"Invalid precision type: {precision}")
84 |
85 |
86 | def load_config_from_yaml(config_path: str) -> RootConfig:
87 | with open(config_path, "r") as f:
88 | config = yaml.load(f, Loader=yaml.FullLoader)
89 |
90 | root = RootConfig(**config)
91 |
92 | if root.train is None:
93 | root.train = TrainConfig()
94 |
95 | if root.save is None:
96 | root.save = SaveConfig()
97 |
98 | if root.logging is None:
99 | root.logging = LoggingConfig()
100 |
101 | if root.other is None:
102 | root.other = OtherConfig()
103 |
104 | return root
105 |
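106 | # Usage sketch (illustrative only): loading a training config and resolving its
107 | # precision strings to torch dtypes; the path matches data/config.yaml below.
108 | #
109 | #   config = load_config_from_yaml("trainscripts/textsliders/data/config.yaml")
110 | #   weight_dtype = parse_precision(config.train.precision)
111 | #   save_dtype = parse_precision(config.save.precision)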
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/.ipynb_checkpoints/prompts-person_age_slider_GPT-checkpoint.yaml:
--------------------------------------------------------------------------------
1 | - target: male white person
2 | positive: male white person, very old, wrinkly, gray hair, aged skin
3 | unconditional: male white person, very young, smooth skin, youthful
4 | neutral: male white person
5 | guidance: 4
6 | rank: 4
7 | action: enhance
8 | resolution: 512
9 | dynamic_resolution: false
10 | batch_size: 1
11 | - target: male black person
12 | positive: male black person, very old, wrinkly, gray hair, aged skin
13 | unconditional: male black person, very young, smooth skin, youthful
14 | neutral: male black person
15 | guidance: 4
16 | rank: 4
17 | action: enhance
18 | resolution: 512
19 | dynamic_resolution: false
20 | batch_size: 1
21 | - target: male indian person
22 | positive: male indian person, very old, wrinkly, gray hair, aged skin
23 | unconditional: male indian person, very young, smooth skin, youthful
24 | neutral: male indian person
25 | guidance: 4
26 | rank: 4
27 | action: enhance
28 | resolution: 512
29 | dynamic_resolution: false
30 | batch_size: 1
31 | - target: male asian person
32 | positive: male asian person, very old, wrinkly, gray hair, aged skin
33 | unconditional: male asian person, very young, smooth skin, youthful
34 | neutral: male asian person
35 | guidance: 4
36 | rank: 4
37 | action: enhance
38 | resolution: 512
39 | dynamic_resolution: false
40 | batch_size: 1
41 | - target: male hispanic person
42 | positive: male hispanic person, very old, wrinkly, gray hair, aged skin
43 | unconditional: male hispanic person, very young, smooth skin, youthful
44 | neutral: male hispanic person
45 | guidance: 4
46 | rank: 4
47 | action: enhance
48 | resolution: 512
49 | dynamic_resolution: false
50 | batch_size: 1
51 | - target: ' female white person'
52 | positive: ' female white person, very old, wrinkly, gray hair, aged skin'
53 | unconditional: ' female white person, very young, smooth skin, youthful'
54 | neutral: ' female white person'
55 | guidance: 4
56 | rank: 4
57 | action: enhance
58 | resolution: 512
59 | dynamic_resolution: false
60 | batch_size: 1
61 | - target: ' female black person'
62 | positive: ' female black person, very old, wrinkly, gray hair, aged skin'
63 | unconditional: ' female black person, very young, smooth skin, youthful'
64 | neutral: ' female black person'
65 | guidance: 4
66 | rank: 4
67 | action: enhance
68 | resolution: 512
69 | dynamic_resolution: false
70 | batch_size: 1
71 | - target: ' female indian person'
72 | positive: ' female indian person, very old, wrinkly, gray hair, aged skin'
73 | unconditional: ' female indian person, very young, smooth skin, youthful'
74 | neutral: ' female indian person'
75 | guidance: 4
76 | rank: 4
77 | action: enhance
78 | resolution: 512
79 | dynamic_resolution: false
80 | batch_size: 1
81 | - target: ' female asian person'
82 | positive: ' female asian person, very old, wrinkly, gray hair, aged skin'
83 | unconditional: ' female asian person, very young, smooth skin, youthful'
84 | neutral: ' female asian person'
85 | guidance: 4
86 | rank: 4
87 | action: enhance
88 | resolution: 512
89 | dynamic_resolution: false
90 | batch_size: 1
91 | - target: ' female hispanic person'
92 | positive: ' female hispanic person, very old, wrinkly, gray hair, aged skin'
93 | unconditional: ' female hispanic person, very young, smooth skin, youthful'
94 | neutral: ' female hispanic person'
95 | guidance: 4
96 | rank: 4
97 | action: enhance
98 | resolution: 512
99 | dynamic_resolution: false
100 | batch_size: 1
101 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/.ipynb_checkpoints/prompts-smile_slider_GPT-checkpoint.yaml:
--------------------------------------------------------------------------------
1 | - target: male white person
2 | positive: male white person, smiling, happy face, big smile
3 | unconditional: male white person, frowning, grumpy, sad
4 | neutral: male white person
5 | guidance: 4
6 | rank: 4
7 | action: enhance
8 | resolution: 512
9 | dynamic_resolution: false
10 | batch_size: 1
11 | - target: male black person
12 | positive: male black person, smiling, happy face, big smile
13 | unconditional: male black person, frowning, grumpy, sad
14 | neutral: male black person
15 | guidance: 4
16 | rank: 4
17 | action: enhance
18 | resolution: 512
19 | dynamic_resolution: false
20 | batch_size: 1
21 | - target: male indian person
22 | positive: male indian person, smiling, happy face, big smile
23 | unconditional: male indian person, frowning, grumpy, sad
24 | neutral: male indian person
25 | guidance: 4
26 | rank: 4
27 | action: enhance
28 | resolution: 512
29 | dynamic_resolution: false
30 | batch_size: 1
31 | - target: male asian person
32 | positive: male asian person, smiling, happy face, big smile
33 | unconditional: male asian person, frowning, grumpy, sad
34 | neutral: male asian person
35 | guidance: 4
36 | rank: 4
37 | action: enhance
38 | resolution: 512
39 | dynamic_resolution: false
40 | batch_size: 1
41 | - target: male hispanic person
42 | positive: male hispanic person, smiling, happy face, big smile
43 | unconditional: male hispanic person, frowning, grumpy, sad
44 | neutral: male hispanic person
45 | guidance: 4
46 | rank: 4
47 | action: enhance
48 | resolution: 512
49 | dynamic_resolution: false
50 | batch_size: 1
51 | - target: female white person
52 | positive: female white person, smiling, happy face, big smile
53 | unconditional: female white person, frowning, grumpy, sad
54 | neutral: female white person
55 | guidance: 4
56 | rank: 4
57 | action: enhance
58 | resolution: 512
59 | dynamic_resolution: false
60 | batch_size: 1
61 | - target: female black person
62 | positive: female black person, smiling, happy face, big smile
63 | unconditional: female black person, frowning, grumpy, sad
64 | neutral: female black person
65 | guidance: 4
66 | rank: 4
67 | action: enhance
68 | resolution: 512
69 | dynamic_resolution: false
70 | batch_size: 1
71 | - target: female indian person
72 | positive: female indian person, smiling, happy face, big smile
73 | unconditional: female indian person, frowning, grumpy, sad
74 | neutral: female indian person
75 | guidance: 4
76 | rank: 4
77 | action: enhance
78 | resolution: 512
79 | dynamic_resolution: false
80 | batch_size: 1
81 | - target: female asian person
82 | positive: female asian person, smiling, happy face, big smile
83 | unconditional: female asian person, frowning, grumpy, sad
84 | neutral: female asian person
85 | guidance: 4
86 | rank: 4
87 | action: enhance
88 | resolution: 512
89 | dynamic_resolution: false
90 | batch_size: 1
91 | - target: female hispanic person
92 | positive: female hispanic person, smiling, happy face, big smile
93 | unconditional: female hispanic person, frowning, grumpy, sad
94 | neutral: female hispanic person
95 | guidance: 4
96 | rank: 4
97 | action: enhance
98 | resolution: 512
99 | dynamic_resolution: false
100 | batch_size: 1
101 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/config-xl.yaml:
--------------------------------------------------------------------------------
1 | prompts_file: "trainscripts/textsliders/data/prompts-xl.yaml"
2 | pretrained_model:
3 | name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
4 | v2: false # true if model is v2.x
5 | v_pred: false # true if model uses v-prediction
6 | network:
7 |   type: "c3lier" # "lierla" or "c3lier"
8 | rank: 4
9 | alpha: 1.0
10 | training_method: "noxattn"
11 | train:
12 | precision: "bfloat16"
13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14 | iterations: 1000
15 | lr: 0.0002
16 | optimizer: "AdamW"
17 | lr_scheduler: "constant"
18 | max_denoising_steps: 50
19 | save:
20 | name: "temp"
21 | path: "./models"
22 | per_steps: 500
23 | precision: "bfloat16"
24 | logging:
25 | use_wandb: false
26 | verbose: false
27 | other:
28 | use_xformers: true
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/config.yaml:
--------------------------------------------------------------------------------
1 | prompts_file: "trainscripts/textsliders/data/prompts.yaml"
2 | pretrained_model:
3 | name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
4 | v2: false # true if model is v2.x
5 | v_pred: false # true if model uses v-prediction
6 | network:
7 |   type: "c3lier" # "lierla" or "c3lier"
8 | rank: 4
9 | alpha: 1.0
10 | training_method: "noxattn"
11 | train:
12 | precision: "bfloat16"
13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14 | iterations: 1000
15 | lr: 0.0002
16 | optimizer: "AdamW"
17 | lr_scheduler: "constant"
18 | max_denoising_steps: 50
19 | save:
20 | name: "temp"
21 | path: "./models"
22 | per_steps: 500
23 | precision: "bfloat16"
24 | logging:
25 | use_wandb: false
26 | verbose: false
27 | other:
28 | use_xformers: true
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts-animated_eyes_GPT.yaml:
--------------------------------------------------------------------------------
1 | - target: male white person
2 | positive: male white person, eyes extremely large, animatedly big eyes, cartoonish
3 | proportions
4 | unconditional: male white person, eyes normal size, realistic proportions
5 | neutral: male white person
6 | guidance: 4
7 | rank: 4
8 | action: enhance
9 | resolution: 512
10 | dynamic_resolution: false
11 | batch_size: 1
12 | - target: male black person
13 | positive: male black person, eyes extremely large, animatedly big eyes, cartoonish
14 | proportions
15 | unconditional: male black person, eyes normal size, realistic proportions
16 | neutral: male black person
17 | guidance: 4
18 | rank: 4
19 | action: enhance
20 | resolution: 512
21 | dynamic_resolution: false
22 | batch_size: 1
23 | - target: male indian person
24 | positive: male indian person, eyes extremely large, animatedly big eyes, cartoonish
25 | proportions
26 | unconditional: male indian person, eyes normal size, realistic proportions
27 | neutral: male indian person
28 | guidance: 4
29 | rank: 4
30 | action: enhance
31 | resolution: 512
32 | dynamic_resolution: false
33 | batch_size: 1
34 | - target: male asian person
35 | positive: male asian person, eyes extremely large, animatedly big eyes, cartoonish
36 | proportions
37 | unconditional: male asian person, eyes normal size, realistic proportions
38 | neutral: male asian person
39 | guidance: 4
40 | rank: 4
41 | action: enhance
42 | resolution: 512
43 | dynamic_resolution: false
44 | batch_size: 1
45 | - target: male hispanic person
46 | positive: male hispanic person, eyes extremely large, animatedly big eyes, cartoonish
47 | proportions
48 | unconditional: male hispanic person, eyes normal size, realistic proportions
49 | neutral: male hispanic person
50 | guidance: 4
51 | rank: 4
52 | action: enhance
53 | resolution: 512
54 | dynamic_resolution: false
55 | batch_size: 1
56 | - target: female white person
57 | positive: female white person, eyes extremely large, animatedly big eyes, cartoonish
58 | proportions
59 | unconditional: female white person, eyes normal size, realistic proportions
60 | neutral: female white person
61 | guidance: 4
62 | rank: 4
63 | action: enhance
64 | resolution: 512
65 | dynamic_resolution: false
66 | batch_size: 1
67 | - target: female black person
68 | positive: female black person, eyes extremely large, animatedly big eyes, cartoonish
69 | proportions
70 | unconditional: female black person, eyes normal size, realistic proportions
71 | neutral: female black person
72 | guidance: 4
73 | rank: 4
74 | action: enhance
75 | resolution: 512
76 | dynamic_resolution: false
77 | batch_size: 1
78 | - target: female indian person
79 | positive: female indian person, eyes extremely large, animatedly big eyes, cartoonish
80 | proportions
81 | unconditional: female indian person, eyes normal size, realistic proportions
82 | neutral: female indian person
83 | guidance: 4
84 | rank: 4
85 | action: enhance
86 | resolution: 512
87 | dynamic_resolution: false
88 | batch_size: 1
89 | - target: female asian person
90 | positive: female asian person, eyes extremely large, animatedly big eyes, cartoonish
91 | proportions
92 | unconditional: female asian person, eyes normal size, realistic proportions
93 | neutral: female asian person
94 | guidance: 4
95 | rank: 4
96 | action: enhance
97 | resolution: 512
98 | dynamic_resolution: false
99 | batch_size: 1
100 | - target: female hispanic person
101 | positive: female hispanic person, eyes extremely large, animatedly big eyes, cartoonish
102 | proportions
103 | unconditional: female hispanic person, eyes normal size, realistic proportions
104 | neutral: female hispanic person
105 | guidance: 4
106 | rank: 4
107 | action: enhance
108 | resolution: 512
109 | dynamic_resolution: false
110 | batch_size: 1
111 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts-car_alienTechFuturistic_GPT.yaml:
--------------------------------------------------------------------------------
1 | - target: car
2 | positive: car, alien technology, futuristic design, advanced propulsion, sleek aerodynamic
3 | shape
4 | unconditional: car, outdated, old-fashioned, conventional design
5 | neutral: car
6 | guidance: 4
7 | rank: 4
8 | action: enhance
9 | resolution: 512
10 | dynamic_resolution: false
11 | batch_size: 1
12 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts-jewelry_diamonds_GPT.yaml:
--------------------------------------------------------------------------------
1 | - target: male white person
2 | positive: male white person adorned with jewelry, sparkling diamonds, elegant necklaces,
3 | luxurious bracelets, shimmering earrings
4 | unconditional: male white person with no jewelry, plain appearance, unadorned
5 | neutral: male white person
6 | guidance: 4
7 | rank: 4
8 | action: enhance
9 | resolution: 512
10 | dynamic_resolution: false
11 | batch_size: 1
12 | - target: male black person
13 | positive: male black person adorned with jewelry, sparkling diamonds, elegant necklaces,
14 | luxurious bracelets, shimmering earrings
15 | unconditional: male black person with no jewelry, plain appearance, unadorned
16 | neutral: male black person
17 | guidance: 4
18 | rank: 4
19 | action: enhance
20 | resolution: 512
21 | dynamic_resolution: false
22 | batch_size: 1
23 | - target: male indian person
24 | positive: male indian person adorned with jewelry, sparkling diamonds, elegant necklaces,
25 | luxurious bracelets, shimmering earrings
26 | unconditional: male indian person with no jewelry, plain appearance, unadorned
27 | neutral: male indian person
28 | guidance: 4
29 | rank: 4
30 | action: enhance
31 | resolution: 512
32 | dynamic_resolution: false
33 | batch_size: 1
34 | - target: male asian person
35 | positive: male asian person adorned with jewelry, sparkling diamonds, elegant necklaces,
36 | luxurious bracelets, shimmering earrings
37 | unconditional: male asian person with no jewelry, plain appearance, unadorned
38 | neutral: male asian person
39 | guidance: 4
40 | rank: 4
41 | action: enhance
42 | resolution: 512
43 | dynamic_resolution: false
44 | batch_size: 1
45 | - target: male hispanic person
46 | positive: male hispanic person adorned with jewelry, sparkling diamonds, elegant
47 | necklaces, luxurious bracelets, shimmering earrings
48 | unconditional: male hispanic person with no jewelry, plain appearance, unadorned
49 | neutral: male hispanic person
50 | guidance: 4
51 | rank: 4
52 | action: enhance
53 | resolution: 512
54 | dynamic_resolution: false
55 | batch_size: 1
56 | - target: female white person
57 | positive: female white person adorned with jewelry, sparkling diamonds, elegant
58 | necklaces, luxurious bracelets, shimmering earrings
59 | unconditional: female white person with no jewelry, plain appearance, unadorned
60 | neutral: female white person
61 | guidance: 4
62 | rank: 4
63 | action: enhance
64 | resolution: 512
65 | dynamic_resolution: false
66 | batch_size: 1
67 | - target: female black person
68 | positive: female black person adorned with jewelry, sparkling diamonds, elegant
69 | necklaces, luxurious bracelets, shimmering earrings
70 | unconditional: female black person with no jewelry, plain appearance, unadorned
71 | neutral: female black person
72 | guidance: 4
73 | rank: 4
74 | action: enhance
75 | resolution: 512
76 | dynamic_resolution: false
77 | batch_size: 1
78 | - target: female indian person
79 | positive: female indian person adorned with jewelry, sparkling diamonds, elegant
80 | necklaces, luxurious bracelets, shimmering earrings
81 | unconditional: female indian person with no jewelry, plain appearance, unadorned
82 | neutral: female indian person
83 | guidance: 4
84 | rank: 4
85 | action: enhance
86 | resolution: 512
87 | dynamic_resolution: false
88 | batch_size: 1
89 | - target: female asian person
90 | positive: female asian person adorned with jewelry, sparkling diamonds, elegant
91 | necklaces, luxurious bracelets, shimmering earrings
92 | unconditional: female asian person with no jewelry, plain appearance, unadorned
93 | neutral: female asian person
94 | guidance: 4
95 | rank: 4
96 | action: enhance
97 | resolution: 512
98 | dynamic_resolution: false
99 | batch_size: 1
100 | - target: female hispanic person
101 | positive: female hispanic person adorned with jewelry, sparkling diamonds, elegant
102 | necklaces, luxurious bracelets, shimmering earrings
103 | unconditional: female hispanic person with no jewelry, plain appearance, unadorned
104 | neutral: female hispanic person
105 | guidance: 4
106 | rank: 4
107 | action: enhance
108 | resolution: 512
109 | dynamic_resolution: false
110 | batch_size: 1
111 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts-person_age_slider_GPT.yaml:
--------------------------------------------------------------------------------
1 | - target: male white person
2 | positive: male white person, very old, wrinkly, gray hair, aged skin
3 | unconditional: male white person, very young, smooth skin, youthful
4 | neutral: male white person
5 | guidance: 4
6 | rank: 4
7 | action: enhance
8 | resolution: 512
9 | dynamic_resolution: false
10 | batch_size: 1
11 | - target: male black person
12 | positive: male black person, very old, wrinkly, gray hair, aged skin
13 | unconditional: male black person, very young, smooth skin, youthful
14 | neutral: male black person
15 | guidance: 4
16 | rank: 4
17 | action: enhance
18 | resolution: 512
19 | dynamic_resolution: false
20 | batch_size: 1
21 | - target: male indian person
22 | positive: male indian person, very old, wrinkly, gray hair, aged skin
23 | unconditional: male indian person, very young, smooth skin, youthful
24 | neutral: male indian person
25 | guidance: 4
26 | rank: 4
27 | action: enhance
28 | resolution: 512
29 | dynamic_resolution: false
30 | batch_size: 1
31 | - target: male asian person
32 | positive: male asian person, very old, wrinkly, gray hair, aged skin
33 | unconditional: male asian person, very young, smooth skin, youthful
34 | neutral: male asian person
35 | guidance: 4
36 | rank: 4
37 | action: enhance
38 | resolution: 512
39 | dynamic_resolution: false
40 | batch_size: 1
41 | - target: male hispanic person
42 | positive: male hispanic person, very old, wrinkly, gray hair, aged skin
43 | unconditional: male hispanic person, very young, smooth skin, youthful
44 | neutral: male hispanic person
45 | guidance: 4
46 | rank: 4
47 | action: enhance
48 | resolution: 512
49 | dynamic_resolution: false
50 | batch_size: 1
51 | - target: ' female white person'
52 | positive: ' female white person, very old, wrinkly, gray hair, aged skin'
53 | unconditional: ' female white person, very young, smooth skin, youthful'
54 | neutral: ' female white person'
55 | guidance: 4
56 | rank: 4
57 | action: enhance
58 | resolution: 512
59 | dynamic_resolution: false
60 | batch_size: 1
61 | - target: ' female black person'
62 | positive: ' female black person, very old, wrinkly, gray hair, aged skin'
63 | unconditional: ' female black person, very young, smooth skin, youthful'
64 | neutral: ' female black person'
65 | guidance: 4
66 | rank: 4
67 | action: enhance
68 | resolution: 512
69 | dynamic_resolution: false
70 | batch_size: 1
71 | - target: ' female indian person'
72 | positive: ' female indian person, very old, wrinkly, gray hair, aged skin'
73 | unconditional: ' female indian person, very young, smooth skin, youthful'
74 | neutral: ' female indian person'
75 | guidance: 4
76 | rank: 4
77 | action: enhance
78 | resolution: 512
79 | dynamic_resolution: false
80 | batch_size: 1
81 | - target: ' female asian person'
82 | positive: ' female asian person, very old, wrinkly, gray hair, aged skin'
83 | unconditional: ' female asian person, very young, smooth skin, youthful'
84 | neutral: ' female asian person'
85 | guidance: 4
86 | rank: 4
87 | action: enhance
88 | resolution: 512
89 | dynamic_resolution: false
90 | batch_size: 1
91 | - target: ' female hispanic person'
92 | positive: ' female hispanic person, very old, wrinkly, gray hair, aged skin'
93 | unconditional: ' female hispanic person, very young, smooth skin, youthful'
94 | neutral: ' female hispanic person'
95 | guidance: 4
96 | rank: 4
97 | action: enhance
98 | resolution: 512
99 | dynamic_resolution: false
100 | batch_size: 1
101 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts-person_surprised_GPT.yaml:
--------------------------------------------------------------------------------
1 | - target: male, white, person
2 | positive: male, white, person, looking surprised, wide eyes, open mouth
3 | unconditional: male, white, person, looking calm, neutral expression
4 | neutral: male, white, person
5 | guidance: 4
6 | rank: 4
7 | action: enhance
8 | resolution: 512
9 | dynamic_resolution: false
10 | batch_size: 1
11 | - target: male, black, person
12 | positive: male, black, person, looking surprised, wide eyes, open mouth
13 | unconditional: male, black, person, looking calm, neutral expression
14 | neutral: male, black, person
15 | guidance: 4
16 | rank: 4
17 | action: enhance
18 | resolution: 512
19 | dynamic_resolution: false
20 | batch_size: 1
21 | - target: male, indian, person
22 | positive: male, indian, person, looking surprised, wide eyes, open mouth
23 | unconditional: male, indian, person, looking calm, neutral expression
24 | neutral: male, indian, person
25 | guidance: 4
26 | rank: 4
27 | action: enhance
28 | resolution: 512
29 | dynamic_resolution: false
30 | batch_size: 1
31 | - target: male, asian, person
32 | positive: male, asian, person, looking surprised, wide eyes, open mouth
33 | unconditional: male, asian, person, looking calm, neutral expression
34 | neutral: male, asian, person
35 | guidance: 4
36 | rank: 4
37 | action: enhance
38 | resolution: 512
39 | dynamic_resolution: false
40 | batch_size: 1
41 | - target: male, hispanic, person
42 | positive: male, hispanic, person, looking surprised, wide eyes, open mouth
43 | unconditional: male, hispanic, person, looking calm, neutral expression
44 | neutral: male, hispanic, person
45 | guidance: 4
46 | rank: 4
47 | action: enhance
48 | resolution: 512
49 | dynamic_resolution: false
50 | batch_size: 1
51 | - target: ' female, white, person'
52 | positive: ' female, white, person, looking surprised, wide eyes, open mouth'
53 | unconditional: ' female, white, person, looking calm, neutral expression'
54 | neutral: ' female, white, person'
55 | guidance: 4
56 | rank: 4
57 | action: enhance
58 | resolution: 512
59 | dynamic_resolution: false
60 | batch_size: 1
61 | - target: ' female, black, person'
62 | positive: ' female, black, person, looking surprised, wide eyes, open mouth'
63 | unconditional: ' female, black, person, looking calm, neutral expression'
64 | neutral: ' female, black, person'
65 | guidance: 4
66 | rank: 4
67 | action: enhance
68 | resolution: 512
69 | dynamic_resolution: false
70 | batch_size: 1
71 | - target: ' female, indian, person'
72 | positive: ' female, indian, person, looking surprised, wide eyes, open mouth'
73 | unconditional: ' female, indian, person, looking calm, neutral expression'
74 | neutral: ' female, indian, person'
75 | guidance: 4
76 | rank: 4
77 | action: enhance
78 | resolution: 512
79 | dynamic_resolution: false
80 | batch_size: 1
81 | - target: ' female, asian, person'
82 | positive: ' female, asian, person, looking surprised, wide eyes, open mouth'
83 | unconditional: ' female, asian, person, looking calm, neutral expression'
84 | neutral: ' female, asian, person'
85 | guidance: 4
86 | rank: 4
87 | action: enhance
88 | resolution: 512
89 | dynamic_resolution: false
90 | batch_size: 1
91 | - target: ' female, hispanic, person'
92 | positive: ' female, hispanic, person, looking surprised, wide eyes, open mouth'
93 | unconditional: ' female, hispanic, person, looking calm, neutral expression'
94 | neutral: ' female, hispanic, person'
95 | guidance: 4
96 | rank: 4
97 | action: enhance
98 | resolution: 512
99 | dynamic_resolution: false
100 | batch_size: 1
101 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts-smile_slider_GPT.yaml:
--------------------------------------------------------------------------------
1 | - target: male white person
2 | positive: male white person, smiling, happy face, big smile
3 | unconditional: male white person, frowning, grumpy, sad
4 | neutral: male white person
5 | guidance: 4
6 | rank: 4
7 | action: enhance
8 | resolution: 512
9 | dynamic_resolution: false
10 | batch_size: 1
11 | - target: male black person
12 | positive: male black person, smiling, happy face, big smile
13 | unconditional: male black person, frowning, grumpy, sad
14 | neutral: male black person
15 | guidance: 4
16 | rank: 4
17 | action: enhance
18 | resolution: 512
19 | dynamic_resolution: false
20 | batch_size: 1
21 | - target: male indian person
22 | positive: male indian person, smiling, happy face, big smile
23 | unconditional: male indian person, frowning, grumpy, sad
24 | neutral: male indian person
25 | guidance: 4
26 | rank: 4
27 | action: enhance
28 | resolution: 512
29 | dynamic_resolution: false
30 | batch_size: 1
31 | - target: male asian person
32 | positive: male asian person, smiling, happy face, big smile
33 | unconditional: male asian person, frowning, grumpy, sad
34 | neutral: male asian person
35 | guidance: 4
36 | rank: 4
37 | action: enhance
38 | resolution: 512
39 | dynamic_resolution: false
40 | batch_size: 1
41 | - target: male hispanic person
42 | positive: male hispanic person, smiling, happy face, big smile
43 | unconditional: male hispanic person, frowning, grumpy, sad
44 | neutral: male hispanic person
45 | guidance: 4
46 | rank: 4
47 | action: enhance
48 | resolution: 512
49 | dynamic_resolution: false
50 | batch_size: 1
51 | - target: female white person
52 | positive: female white person, smiling, happy face, big smile
53 | unconditional: female white person, frowning, grumpy, sad
54 | neutral: female white person
55 | guidance: 4
56 | rank: 4
57 | action: enhance
58 | resolution: 512
59 | dynamic_resolution: false
60 | batch_size: 1
61 | - target: female black person
62 | positive: female black person, smiling, happy face, big smile
63 | unconditional: female black person, frowning, grumpy, sad
64 | neutral: female black person
65 | guidance: 4
66 | rank: 4
67 | action: enhance
68 | resolution: 512
69 | dynamic_resolution: false
70 | batch_size: 1
71 | - target: female indian person
72 | positive: female indian person, smiling, happy face, big smile
73 | unconditional: female indian person, frowning, grumpy, sad
74 | neutral: female indian person
75 | guidance: 4
76 | rank: 4
77 | action: enhance
78 | resolution: 512
79 | dynamic_resolution: false
80 | batch_size: 1
81 | - target: female asian person
82 | positive: female asian person, smiling, happy face, big smile
83 | unconditional: female asian person, frowning, grumpy, sad
84 | neutral: female asian person
85 | guidance: 4
86 | rank: 4
87 | action: enhance
88 | resolution: 512
89 | dynamic_resolution: false
90 | batch_size: 1
91 | - target: female hispanic person
92 | positive: female hispanic person, smiling, happy face, big smile
93 | unconditional: female hispanic person, frowning, grumpy, sad
94 | neutral: female hispanic person
95 | guidance: 4
96 | rank: 4
97 | action: enhance
98 | resolution: 512
99 | dynamic_resolution: false
100 | batch_size: 1
101 |
--------------------------------------------------------------------------------
/trainscripts/textsliders/data/prompts.yaml:
--------------------------------------------------------------------------------
1 | - target: "male person" # what word for erasing the positive concept from
2 | positive: "male person, very old" # concept to erase
3 | unconditional: "male person, very young" # word to take the difference from the positive concept
4 | neutral: "male person" # starting point for conditioning the target
5 | action: "enhance" # erase or enhance
6 | guidance_scale: 4
7 | resolution: 512
8 | dynamic_resolution: false
9 | batch_size: 1
10 | - target: "female person" # what word for erasing the positive concept from
11 | positive: "female person, very old" # concept to erase
12 | unconditional: "female person, very young" # word to take the difference from the positive concept
13 | neutral: "female person" # starting point for conditioning the target
14 | action: "enhance" # erase or enhance
15 | guidance_scale: 4
16 | resolution: 512
17 | dynamic_resolution: false
18 | batch_size: 1
19 | # - target: "" # what word for erasing the positive concept from
20 | # positive: "a group of people" # concept to erase
21 | # unconditional: "a person" # word to take the difference from the positive concept
22 | # neutral: "" # starting point for conditioning the target
23 | # action: "enhance" # erase or enhance
24 | # guidance_scale: 4
25 | # resolution: 512
26 | # dynamic_resolution: false
27 | # batch_size: 1
28 | # - target: "" # what word for erasing the positive concept from
29 | # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
30 | # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
31 | # neutral: "" # starting point for conditioning the target
32 | # action: "enhance" # erase or enhance
33 | # guidance_scale: 4
34 | # resolution: 512
35 | # dynamic_resolution: false
36 | # batch_size: 1
37 | # - target: "" # what word for erasing the positive concept from
38 | # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
39 | # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
40 | # unconditional: ""
41 | # neutral: "" # starting point for conditioning the target
42 | # action: "enhance" # erase or enhance
43 | # guidance_scale: 4
44 | # resolution: 512
45 | # dynamic_resolution: false
46 | # batch_size: 1
47 | # - target: "food" # what word for erasing the positive concept from
48 | # positive: "food, expensive and fine dining" # concept to erase
49 | # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
50 | # neutral: "food" # starting point for conditioning the target
51 | # action: "enhance" # erase or enhance
52 | # guidance_scale: 4
53 | # resolution: 512
54 | # dynamic_resolution: false
55 | # batch_size: 1
56 | # - target: "room" # what word for erasing the positive concept from
57 | # positive: "room, dirty disorganised and cluttered" # concept to erase
58 | # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
59 | # neutral: "room" # starting point for conditioning the target
60 | # action: "enhance" # erase or enhance
61 | # guidance_scale: 4
62 | # resolution: 512
63 | # dynamic_resolution: false
64 | # batch_size: 1
65 | # - target: "male person" # what word for erasing the positive concept from
66 | # positive: "male person, with a surprised look" # concept to erase
67 | # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
68 | # neutral: "male person" # starting point for conditioning the target
69 | # action: "enhance" # erase or enhance
70 | # guidance_scale: 4
71 | # resolution: 512
72 | # dynamic_resolution: false
73 | # batch_size: 1
74 | # - target: "female person" # what word for erasing the positive concept from
75 | # positive: "female person, with a surprised look" # concept to erase
76 | # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
77 | # neutral: "female person" # starting point for conditioning the target
78 | # action: "enhance" # erase or enhance
79 | # guidance_scale: 4
80 | # resolution: 512
81 | # dynamic_resolution: false
82 | # batch_size: 1
83 | # - target: "sky" # what word for erasing the positive concept from
84 | # positive: "peaceful sky" # concept to erase
85 | # unconditional: "sky" # word to take the difference from the positive concept
86 | # neutral: "sky" # starting point for conditioning the target
87 | # action: "enhance" # erase or enhance
88 | # guidance_scale: 4
89 | # resolution: 512
90 | # dynamic_resolution: false
91 | # batch_size: 1
92 | # - target: "sky" # what word for erasing the positive concept from
93 | # positive: "chaotic dark sky" # concept to erase
94 | # unconditional: "sky" # word to take the difference from the positive concept
95 | # neutral: "sky" # starting point for conditioning the target
96 | # action: "erase" # erase or enhance
97 | # guidance_scale: 4
98 | # resolution: 512
99 | # dynamic_resolution: false
100 | # batch_size: 1
101 | # - target: "person" # what word for erasing the positive concept from
102 | # positive: "person, very young" # concept to erase
103 | # unconditional: "person" # word to take the difference from the positive concept
104 | # neutral: "person" # starting point for conditioning the target
105 | # action: "erase" # erase or enhance
106 | # guidance_scale: 4
107 | # resolution: 512
108 | # dynamic_resolution: false
109 | # batch_size: 1
110 | # overweight
111 | # - target: "art" # what word for erasing the positive concept from
112 | # positive: "realistic art" # concept to erase
113 | # unconditional: "art" # word to take the difference from the positive concept
114 | # neutral: "art" # starting point for conditioning the target
115 | # action: "enhance" # erase or enhance
116 | # guidance_scale: 4
117 | # resolution: 512
118 | # dynamic_resolution: false
119 | # batch_size: 1
120 | # - target: "art" # what word for erasing the positive concept from
121 | # positive: "abstract art" # concept to erase
122 | # unconditional: "art" # word to take the difference from the positive concept
123 | # neutral: "art" # starting point for conditioning the target
124 | # action: "erase" # erase or enhance
125 | # guidance_scale: 4
126 | # resolution: 512
127 | # dynamic_resolution: false
128 | # batch_size: 1
129 | # sky
130 | # - target: "weather" # what word for erasing the positive concept from
131 | # positive: "bright pleasant weather" # concept to erase
132 | # unconditional: "weather" # word to take the difference from the positive concept
133 | # neutral: "weather" # starting point for conditioning the target
134 | # action: "enhance" # erase or enhance
135 | # guidance_scale: 4
136 | # resolution: 512
137 | # dynamic_resolution: false
138 | # batch_size: 1
139 | # - target: "weather" # what word for erasing the positive concept from
140 | # positive: "dark gloomy weather" # concept to erase
141 | # unconditional: "weather" # word to take the difference from the positive concept
142 | # neutral: "weather" # starting point for conditioning the target
143 | # action: "erase" # erase or enhance
144 | # guidance_scale: 4
145 | # resolution: 512
146 | # dynamic_resolution: false
147 | # batch_size: 1
148 | # hair
149 | # - target: "person" # what word for erasing the positive concept from
150 | # positive: "person with long hair" # concept to erase
151 | # unconditional: "person" # word to take the difference from the positive concept
152 | # neutral: "person" # starting point for conditioning the target
153 | # action: "enhance" # erase or enhance
154 | # guidance_scale: 4
155 | # resolution: 512
156 | # dynamic_resolution: false
157 | # batch_size: 1
158 | # - target: "person" # what word for erasing the positive concept from
159 | # positive: "person with short hair" # concept to erase
160 | # unconditional: "person" # word to take the difference from the positive concept
161 | # neutral: "person" # starting point for conditioning the target
162 | # action: "erase" # erase or enhance
163 | # guidance_scale: 4
164 | # resolution: 512
165 | # dynamic_resolution: false
166 | # batch_size: 1
167 | # - target: "girl" # what word for erasing the positive concept from
168 | # positive: "baby girl" # concept to erase
169 | # unconditional: "girl" # word to take the difference from the positive concept
170 | # neutral: "girl" # starting point for conditioning the target
171 | # action: "enhance" # erase or enhance
172 | # guidance_scale: -4
173 | # resolution: 512
174 | # dynamic_resolution: false
175 | # batch_size: 1
176 | # - target: "boy" # what word for erasing the positive concept from
177 | # positive: "old man" # concept to erase
178 | # unconditional: "boy" # word to take the difference from the positive concept
179 | # neutral: "boy" # starting point for conditioning the target
180 | # action: "enhance" # erase or enhance
181 | # guidance_scale: 4
182 | # resolution: 512
183 | # dynamic_resolution: false
184 | # batch_size: 1
185 | # - target: "boy" # what word for erasing the positive concept from
186 | # positive: "baby boy" # concept to erase
187 | # unconditional: "boy" # word to take the difference from the positive concept
188 | # neutral: "boy" # starting point for conditioning the target
189 | # action: "enhance" # erase or enhance
190 | # guidance_scale: -4
191 | # resolution: 512
192 | # dynamic_resolution: false
193 | # batch_size: 1
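194 |
195 | # Note: entries in this file are parsed by trainscripts/textsliders/prompt_util.load_prompts_from_yaml
196 | # into PromptSettings objects. Conceptually, the slider direction is (positive - unconditional),
197 | # scaled by guidance_scale and applied on top of "neutral"; "action: enhance" trains the target
198 | # toward that direction, while "action: erase" trains it away from it.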
--------------------------------------------------------------------------------
/trainscripts/textsliders/debug_util.py:
--------------------------------------------------------------------------------
1 | # For debugging...
2 |
3 | import torch
4 |
5 |
6 | def check_requires_grad(model: torch.nn.Module):
7 | for name, module in list(model.named_modules())[:5]:
8 | if len(list(module.parameters())) > 0:
9 | print(f"Module: {name}")
10 | for name, param in list(module.named_parameters())[:2]:
11 | print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12 |
13 |
14 | def check_training_mode(model: torch.nn.Module):
15 | for name, module in list(model.named_modules())[:5]:
16 | print(f"Module: {name}, Training Mode: {module.training}")
17 |
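18 |
19 | # Example usage (a minimal sketch; `network` is assumed to be any torch.nn.Module,
20 | # e.g. the LoRANetwork built in train_lora.py):
21 | #
22 | #   check_requires_grad(network)
23 | #   check_training_mode(network)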
--------------------------------------------------------------------------------
/trainscripts/textsliders/flush.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import gc
3 |
4 | torch.cuda.empty_cache()
5 | gc.collect()
6 |
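7 | # Note: running `python flush.py` (or importing this module) simply releases PyTorch's
8 | # cached, unused GPU memory and triggers Python garbage collection between runs.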
--------------------------------------------------------------------------------
/trainscripts/textsliders/lora.py:
--------------------------------------------------------------------------------
1 | # ref:
2 | # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3 | # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4 |
5 | import os
6 | import math
7 | from typing import Optional, List, Type, Set, Literal
8 |
9 | import torch
10 | import torch.nn as nn
11 | from diffusers import UNet2DConditionModel
12 | from safetensors.torch import save_file
13 |
14 |
15 | UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16 |     # "Transformer2DModel",  # apparently this is the right one?  # attn1, 2
17 | "Attention"
18 | ]
19 | UNET_TARGET_REPLACE_MODULE_CONV = [
20 | "ResnetBlock2D",
21 | "Downsample2D",
22 | "Upsample2D",
23 | "DownBlock2D",
24 | "UpBlock2D",
25 |
26 | ]  # LoCon-style conv target modules
27 |
28 | LORA_PREFIX_UNET = "lora_unet"
29 |
30 | DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
31 |
32 | TRAINING_METHODS = Literal[
33 | "noxattn", # train all layers except x-attns and time_embed layers
34 | "innoxattn", # train all layers except self attention layers
35 | "selfattn", # ESD-u, train only self attention layers
36 | "xattn", # ESD-x, train only x attention layers
37 | "full", # train all layers
38 | "xattn-strict", # q and k values
39 | "noxattn-hspace",
40 | "noxattn-hspace-last",
41 | # "xlayer",
42 | # "outxattn",
43 | # "outsattn",
44 | # "inxattn",
45 | # "inmidsattn",
46 | # "selflayer",
47 | ]
48 |
49 |
50 | class LoRAModule(nn.Module):
51 | """
52 | replaces forward method of the original Linear, instead of replacing the original Linear module.
53 | """
54 |
55 | def __init__(
56 | self,
57 | lora_name,
58 | org_module: nn.Module,
59 | multiplier=1.0,
60 | lora_dim=4,
61 | alpha=1,
62 | ):
63 | """if alpha == 0 or None, alpha is rank (no scaling)."""
64 | super().__init__()
65 | self.lora_name = lora_name
66 | self.lora_dim = lora_dim
67 |
68 | if "Linear" in org_module.__class__.__name__:
69 | in_dim = org_module.in_features
70 | out_dim = org_module.out_features
71 | self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
72 | self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
73 |
74 |         elif "Conv" in org_module.__class__.__name__:  # just in case
75 | in_dim = org_module.in_channels
76 | out_dim = org_module.out_channels
77 |
78 | self.lora_dim = min(self.lora_dim, in_dim, out_dim)
79 | if self.lora_dim != lora_dim:
80 | print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
81 |
82 | kernel_size = org_module.kernel_size
83 | stride = org_module.stride
84 | padding = org_module.padding
85 | self.lora_down = nn.Conv2d(
86 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
87 | )
88 | self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
89 |
90 | if type(alpha) == torch.Tensor:
91 | alpha = alpha.detach().numpy()
92 | alpha = lora_dim if alpha is None or alpha == 0 else alpha
93 | self.scale = alpha / self.lora_dim
94 |         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
95 |
96 | # same as microsoft's
97 | nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
98 | nn.init.zeros_(self.lora_up.weight)
99 |
100 | self.multiplier = multiplier
101 |         self.org_module = org_module  # this reference is removed in apply_to()
102 |
103 | def apply_to(self):
104 | self.org_forward = self.org_module.forward
105 | self.org_module.forward = self.forward
106 | del self.org_module
107 |
108 | def forward(self, x):
109 | return (
110 | self.org_forward(x)
111 | + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
112 | )
113 |
114 |
115 | class LoRANetwork(nn.Module):
116 | def __init__(
117 | self,
118 | unet: UNet2DConditionModel,
119 | rank: int = 4,
120 | multiplier: float = 1.0,
121 | alpha: float = 1.0,
122 | train_method: TRAINING_METHODS = "full",
123 | ) -> None:
124 | super().__init__()
125 | self.lora_scale = 1
126 | self.multiplier = multiplier
127 | self.lora_dim = rank
128 | self.alpha = alpha
129 |
130 |         # LoRA only
131 | self.module = LoRAModule
132 |
133 |         # create the LoRA modules for the U-Net
134 | self.unet_loras = self.create_modules(
135 | LORA_PREFIX_UNET,
136 | unet,
137 | DEFAULT_TARGET_REPLACE,
138 | self.lora_dim,
139 | self.multiplier,
140 | train_method=train_method,
141 | )
142 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
143 |
144 |         # assertion: check that no LoRA names are duplicated
145 | lora_names = set()
146 | for lora in self.unet_loras:
147 | assert (
148 | lora.lora_name not in lora_names
149 | ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
150 | lora_names.add(lora.lora_name)
151 |
152 |         # apply the LoRA modules
153 | for lora in self.unet_loras:
154 | lora.apply_to()
155 | self.add_module(
156 | lora.lora_name,
157 | lora,
158 | )
159 |
160 | del unet
161 |
162 | torch.cuda.empty_cache()
163 |
164 | def create_modules(
165 | self,
166 | prefix: str,
167 | root_module: nn.Module,
168 | target_replace_modules: List[str],
169 | rank: int,
170 | multiplier: float,
171 | train_method: TRAINING_METHODS,
172 | ) -> list:
173 | loras = []
174 | names = []
175 | for name, module in root_module.named_modules():
176 |             if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last":  # train all layers except cross-attention and time-embedding layers
177 | if "attn2" in name or "time_embed" in name:
178 | continue
179 |             elif train_method == "innoxattn":  # train all layers except cross-attention layers
180 | if "attn2" in name:
181 | continue
182 |             elif train_method == "selfattn":  # train only self-attention layers
183 | if "attn1" not in name:
184 | continue
185 |             elif train_method == "xattn" or train_method == "xattn-strict":  # train only cross-attention layers
186 | if "attn2" not in name:
187 | continue
188 |             elif train_method == "full":  # train all layers
189 | pass
190 | else:
191 | raise NotImplementedError(
192 | f"train_method: {train_method} is not implemented."
193 | )
194 | if module.__class__.__name__ in target_replace_modules:
195 | for child_name, child_module in module.named_modules():
196 | if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
197 | if train_method == 'xattn-strict':
198 | if 'out' in child_name:
199 | continue
200 | if train_method == 'noxattn-hspace':
201 | if 'mid_block' not in name:
202 | continue
203 | if train_method == 'noxattn-hspace-last':
204 | if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
205 | continue
206 | lora_name = prefix + "." + name + "." + child_name
207 | lora_name = lora_name.replace(".", "_")
208 | # print(f"{lora_name}")
209 | lora = self.module(
210 | lora_name, child_module, multiplier, rank, self.alpha
211 | )
212 | # print(name, child_name)
213 | # print(child_module.weight.shape)
214 | if lora_name not in names:
215 | loras.append(lora)
216 | names.append(lora_name)
217 | # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
218 | return loras
219 |
220 | def prepare_optimizer_params(self):
221 | all_params = []
222 |
223 |         if self.unet_loras:  # in practice this is the only case
224 | params = []
225 | [params.extend(lora.parameters()) for lora in self.unet_loras]
226 | param_data = {"params": params}
227 | all_params.append(param_data)
228 |
229 | return all_params
230 |
231 | def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
232 | state_dict = self.state_dict()
233 |
234 | if dtype is not None:
235 | for key in list(state_dict.keys()):
236 | v = state_dict[key]
237 | v = v.detach().clone().to("cpu").to(dtype)
238 | state_dict[key] = v
239 |
240 | # for key in list(state_dict.keys()):
241 | # if not key.startswith("lora"):
242 |         #         # drop everything except LoRA keys
243 | # del state_dict[key]
244 |
245 | if os.path.splitext(file)[1] == ".safetensors":
246 | save_file(state_dict, file, metadata)
247 | else:
248 | torch.save(state_dict, file)
249 | def set_lora_slider(self, scale):
250 | self.lora_scale = scale
251 |
252 | def __enter__(self):
253 | for lora in self.unet_loras:
254 | lora.multiplier = 1.0 * self.lora_scale
255 |
256 | def __exit__(self, exc_type, exc_value, tb):
257 | for lora in self.unet_loras:
258 | lora.multiplier = 0
259 |
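260 |
261 | # Example usage (a minimal sketch; `unet` is assumed to be a diffusers UNet2DConditionModel
262 | # already loaded, e.g. via model_util.load_models; latents, timesteps and text_embeddings
263 | # below are placeholder tensors):
264 | #
265 | #   network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0, train_method="noxattn")
266 | #   optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=1e-4)
267 | #   network.set_lora_slider(scale=2.0)  # slider strength used for this forward pass
268 | #   with network:  # __enter__ sets every lora.multiplier to lora_scale, __exit__ resets it to 0
269 | #       noise_pred = unet(latents, timesteps, encoder_hidden_states=text_embeddings).sample
270 | #   network.save_weights("slider.safetensors", dtype=torch.float16)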
--------------------------------------------------------------------------------
/trainscripts/textsliders/model_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Union, Optional
2 |
3 | import torch
4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5 | from diffusers import (
6 | UNet2DConditionModel,
7 | SchedulerMixin,
8 | StableDiffusionPipeline,
9 | StableDiffusionXLPipeline,
10 | )
11 | from diffusers.schedulers import (
12 | DDIMScheduler,
13 | DDPMScheduler,
14 | LMSDiscreteScheduler,
15 | EulerAncestralDiscreteScheduler,
16 | )
17 |
18 |
19 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
20 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
21 |
22 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
23 |
24 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
25 |
26 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
27 |
28 |
29 | def load_diffusers_model(
30 | pretrained_model_name_or_path: str,
31 | v2: bool = False,
32 | clip_skip: Optional[int] = None,
33 | weight_dtype: torch.dtype = torch.float32,
34 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
35 |     # VAE is not needed
36 |
37 | if v2:
38 | tokenizer = CLIPTokenizer.from_pretrained(
39 | TOKENIZER_V2_MODEL_NAME,
40 | subfolder="tokenizer",
41 | torch_dtype=weight_dtype,
42 | cache_dir=DIFFUSERS_CACHE_DIR,
43 | )
44 | text_encoder = CLIPTextModel.from_pretrained(
45 | pretrained_model_name_or_path,
46 | subfolder="text_encoder",
47 | # default is clip skip 2
48 | num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
49 | torch_dtype=weight_dtype,
50 | cache_dir=DIFFUSERS_CACHE_DIR,
51 | )
52 | else:
53 | tokenizer = CLIPTokenizer.from_pretrained(
54 | TOKENIZER_V1_MODEL_NAME,
55 | subfolder="tokenizer",
56 | torch_dtype=weight_dtype,
57 | cache_dir=DIFFUSERS_CACHE_DIR,
58 | )
59 | text_encoder = CLIPTextModel.from_pretrained(
60 | pretrained_model_name_or_path,
61 | subfolder="text_encoder",
62 | num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
63 | torch_dtype=weight_dtype,
64 | cache_dir=DIFFUSERS_CACHE_DIR,
65 | )
66 |
67 | unet = UNet2DConditionModel.from_pretrained(
68 | pretrained_model_name_or_path,
69 | subfolder="unet",
70 | torch_dtype=weight_dtype,
71 | cache_dir=DIFFUSERS_CACHE_DIR,
72 | )
73 |
74 | return tokenizer, text_encoder, unet
75 |
76 |
77 | def load_checkpoint_model(
78 | checkpoint_path: str,
79 | v2: bool = False,
80 | clip_skip: Optional[int] = None,
81 | weight_dtype: torch.dtype = torch.float32,
82 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
83 | pipe = StableDiffusionPipeline.from_ckpt(
84 | checkpoint_path,
85 | upcast_attention=True if v2 else False,
86 | torch_dtype=weight_dtype,
87 | cache_dir=DIFFUSERS_CACHE_DIR,
88 | )
89 |
90 | unet = pipe.unet
91 | tokenizer = pipe.tokenizer
92 | text_encoder = pipe.text_encoder
93 | if clip_skip is not None:
94 | if v2:
95 | text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
96 | else:
97 | text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
98 |
99 | del pipe
100 |
101 | return tokenizer, text_encoder, unet
102 |
103 |
104 | def load_models(
105 | pretrained_model_name_or_path: str,
106 | scheduler_name: AVAILABLE_SCHEDULERS,
107 | v2: bool = False,
108 | v_pred: bool = False,
109 | weight_dtype: torch.dtype = torch.float32,
110 | ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
111 | if pretrained_model_name_or_path.endswith(
112 | ".ckpt"
113 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
114 | tokenizer, text_encoder, unet = load_checkpoint_model(
115 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
116 | )
117 | else: # diffusers
118 | tokenizer, text_encoder, unet = load_diffusers_model(
119 | pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
120 | )
121 |
122 |     # VAE is not needed
123 |
124 | scheduler = create_noise_scheduler(
125 | scheduler_name,
126 | prediction_type="v_prediction" if v_pred else "epsilon",
127 | )
128 |
129 | return tokenizer, text_encoder, unet, scheduler
130 |
131 |
132 | def load_diffusers_model_xl(
133 | pretrained_model_name_or_path: str,
134 | weight_dtype: torch.dtype = torch.float32,
135 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
136 |     # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet
137 |
138 | tokenizers = [
139 | CLIPTokenizer.from_pretrained(
140 | pretrained_model_name_or_path,
141 | subfolder="tokenizer",
142 | torch_dtype=weight_dtype,
143 | cache_dir=DIFFUSERS_CACHE_DIR,
144 | ),
145 | CLIPTokenizer.from_pretrained(
146 | pretrained_model_name_or_path,
147 | subfolder="tokenizer_2",
148 | torch_dtype=weight_dtype,
149 | cache_dir=DIFFUSERS_CACHE_DIR,
150 | pad_token_id=0, # same as open clip
151 | ),
152 | ]
153 |
154 | text_encoders = [
155 | CLIPTextModel.from_pretrained(
156 | pretrained_model_name_or_path,
157 | subfolder="text_encoder",
158 | torch_dtype=weight_dtype,
159 | cache_dir=DIFFUSERS_CACHE_DIR,
160 | ),
161 | CLIPTextModelWithProjection.from_pretrained(
162 | pretrained_model_name_or_path,
163 | subfolder="text_encoder_2",
164 | torch_dtype=weight_dtype,
165 | cache_dir=DIFFUSERS_CACHE_DIR,
166 | ),
167 | ]
168 |
169 | unet = UNet2DConditionModel.from_pretrained(
170 | pretrained_model_name_or_path,
171 | subfolder="unet",
172 | torch_dtype=weight_dtype,
173 | cache_dir=DIFFUSERS_CACHE_DIR,
174 | )
175 |
176 | return tokenizers, text_encoders, unet
177 |
178 |
179 | def load_checkpoint_model_xl(
180 | checkpoint_path: str,
181 | weight_dtype: torch.dtype = torch.float32,
182 | ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
183 | pipe = StableDiffusionXLPipeline.from_single_file(
184 | checkpoint_path,
185 | torch_dtype=weight_dtype,
186 | cache_dir=DIFFUSERS_CACHE_DIR,
187 | )
188 |
189 | unet = pipe.unet
190 | tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
191 | text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
192 | if len(text_encoders) == 2:
193 | text_encoders[1].pad_token_id = 0
194 |
195 | del pipe
196 |
197 | return tokenizers, text_encoders, unet
198 |
199 |
200 | def load_models_xl(
201 | pretrained_model_name_or_path: str,
202 | scheduler_name: AVAILABLE_SCHEDULERS,
203 | weight_dtype: torch.dtype = torch.float32,
204 | ) -> tuple[
205 | list[CLIPTokenizer],
206 | list[SDXL_TEXT_ENCODER_TYPE],
207 | UNet2DConditionModel,
208 | SchedulerMixin,
209 | ]:
210 | if pretrained_model_name_or_path.endswith(
211 | ".ckpt"
212 | ) or pretrained_model_name_or_path.endswith(".safetensors"):
213 | (
214 | tokenizers,
215 | text_encoders,
216 | unet,
217 | ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
218 | else: # diffusers
219 | (
220 | tokenizers,
221 | text_encoders,
222 | unet,
223 | ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
224 |
225 | scheduler = create_noise_scheduler(scheduler_name)
226 |
227 | return tokenizers, text_encoders, unet, scheduler
228 |
229 |
230 | def create_noise_scheduler(
231 | scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
232 | prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
233 | ) -> SchedulerMixin:
234 |     # Honestly, it is not clear which scheduler is best. The original implementation allowed choosing DDIM, DDPM, or LMS, but there is no obvious winner.
235 |
236 | name = scheduler_name.lower().replace(" ", "_")
237 | if name == "ddim":
238 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
239 | scheduler = DDIMScheduler(
240 | beta_start=0.00085,
241 | beta_end=0.012,
242 | beta_schedule="scaled_linear",
243 | num_train_timesteps=1000,
244 | clip_sample=False,
245 |             prediction_type=prediction_type,  # is this the right choice?
246 | )
247 | elif name == "ddpm":
248 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
249 | scheduler = DDPMScheduler(
250 | beta_start=0.00085,
251 | beta_end=0.012,
252 | beta_schedule="scaled_linear",
253 | num_train_timesteps=1000,
254 | clip_sample=False,
255 | prediction_type=prediction_type,
256 | )
257 | elif name == "lms":
258 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
259 | scheduler = LMSDiscreteScheduler(
260 | beta_start=0.00085,
261 | beta_end=0.012,
262 | beta_schedule="scaled_linear",
263 | num_train_timesteps=1000,
264 | prediction_type=prediction_type,
265 | )
266 | elif name == "euler_a":
267 | # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
268 | scheduler = EulerAncestralDiscreteScheduler(
269 | beta_start=0.00085,
270 | beta_end=0.012,
271 | beta_schedule="scaled_linear",
272 | num_train_timesteps=1000,
273 | prediction_type=prediction_type,
274 | )
275 | else:
276 | raise ValueError(f"Unknown scheduler name: {name}")
277 |
278 | return scheduler
279 |
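280 |
281 | # Example usage (a minimal sketch; the model ids below are standard Hugging Face repos):
282 | #
283 | #   tokenizer, text_encoder, unet, scheduler = load_models(
284 | #       "CompVis/stable-diffusion-v1-4", scheduler_name="ddim", v2=False, weight_dtype=torch.float16
285 | #   )
286 | #   tokenizers, text_encoders, unet_xl, scheduler_xl = load_models_xl(
287 | #       "stabilityai/stable-diffusion-xl-base-1.0", scheduler_name="euler_a"
288 | #   )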
--------------------------------------------------------------------------------
/trainscripts/textsliders/prompt_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional, Union, List
2 |
3 | import yaml
4 | from pathlib import Path
5 |
6 |
7 | from pydantic import BaseModel, root_validator
8 | import torch
9 | import copy
10 |
11 | ACTION_TYPES = Literal[
12 | "erase",
13 | "enhance",
14 | ]
15 |
16 |
17 | # SDXL needs two kinds of text embeddings
18 | class PromptEmbedsXL:
19 | text_embeds: torch.FloatTensor
20 | pooled_embeds: torch.FloatTensor
21 |
22 | def __init__(self, *args) -> None:
23 | self.text_embeds = args[0]
24 | self.pooled_embeds = args[1]
25 |
26 |
27 | # SDv1.x / SDv2.x use a FloatTensor; SDXL uses PromptEmbedsXL
28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29 |
30 |
31 | class PromptEmbedsCache:  # cache so prompt embeddings can be reused
32 | prompts: dict[str, PROMPT_EMBEDDING] = {}
33 |
34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35 | self.prompts[__name] = __value
36 |
37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38 | if __name in self.prompts:
39 | return self.prompts[__name]
40 | else:
41 | return None
42 |
43 |
44 | class PromptSettings(BaseModel):  # one entry of the prompts YAML file
45 | target: str
46 | positive: str = None # if None, target will be used
47 | unconditional: str = "" # default is ""
48 | neutral: str = None # if None, unconditional will be used
49 | action: ACTION_TYPES = "erase" # default is "erase"
50 | guidance_scale: float = 1.0 # default is 1.0
51 | resolution: int = 512 # default is 512
52 | dynamic_resolution: bool = False # default is False
53 | batch_size: int = 1 # default is 1
54 | dynamic_crops: bool = False # default is False. only used when model is XL
55 |
56 | @root_validator(pre=True)
57 | def fill_prompts(cls, values):
58 | keys = values.keys()
59 | if "target" not in keys:
60 | raise ValueError("target must be specified")
61 | if "positive" not in keys:
62 | values["positive"] = values["target"]
63 | if "unconditional" not in keys:
64 | values["unconditional"] = ""
65 | if "neutral" not in keys:
66 | values["neutral"] = values["unconditional"]
67 |
68 | return values
69 |
70 |
71 | class PromptEmbedsPair:
72 |     target: PROMPT_EMBEDDING  # the concept we do not want to generate
73 |     positive: PROMPT_EMBEDDING  # the concept we want to generate
74 |     unconditional: PROMPT_EMBEDDING  # unconditional prompt (default should be empty)
75 | neutral: PROMPT_EMBEDDING # base condition (default should be empty)
76 |
77 | guidance_scale: float
78 | resolution: int
79 | dynamic_resolution: bool
80 | batch_size: int
81 | dynamic_crops: bool
82 |
83 | loss_fn: torch.nn.Module
84 | action: ACTION_TYPES
85 |
86 | def __init__(
87 | self,
88 | loss_fn: torch.nn.Module,
89 | target: PROMPT_EMBEDDING,
90 | positive: PROMPT_EMBEDDING,
91 | unconditional: PROMPT_EMBEDDING,
92 | neutral: PROMPT_EMBEDDING,
93 | settings: PromptSettings,
94 | ) -> None:
95 | self.loss_fn = loss_fn
96 | self.target = target
97 | self.positive = positive
98 | self.unconditional = unconditional
99 | self.neutral = neutral
100 |
101 | self.guidance_scale = settings.guidance_scale
102 | self.resolution = settings.resolution
103 | self.dynamic_resolution = settings.dynamic_resolution
104 | self.batch_size = settings.batch_size
105 | self.dynamic_crops = settings.dynamic_crops
106 | self.action = settings.action
107 |
108 | def _erase(
109 | self,
110 | target_latents: torch.FloatTensor, # "van gogh"
111 | positive_latents: torch.FloatTensor, # "van gogh"
112 | unconditional_latents: torch.FloatTensor, # ""
113 | neutral_latents: torch.FloatTensor, # ""
114 | ) -> torch.FloatTensor:
115 |         """Target latents should not contain the positive concept."""
116 | return self.loss_fn(
117 | target_latents,
118 | neutral_latents
119 | - self.guidance_scale * (positive_latents - unconditional_latents)
120 | )
121 |
122 |
123 | def _enhance(
124 | self,
125 | target_latents: torch.FloatTensor, # "van gogh"
126 | positive_latents: torch.FloatTensor, # "van gogh"
127 | unconditional_latents: torch.FloatTensor, # ""
128 | neutral_latents: torch.FloatTensor, # ""
129 | ):
130 |         """Target latents should contain the positive concept."""
131 | return self.loss_fn(
132 | target_latents,
133 | neutral_latents
134 | + self.guidance_scale * (positive_latents - unconditional_latents)
135 | )
136 |
137 | def loss(
138 | self,
139 | **kwargs,
140 | ):
141 | if self.action == "erase":
142 | return self._erase(**kwargs)
143 |
144 | elif self.action == "enhance":
145 | return self._enhance(**kwargs)
146 |
147 | else:
148 | raise ValueError("action must be erase or enhance")
149 |
150 |
151 | def load_prompts_from_yaml(path, attributes = []):
152 | with open(path, "r") as f:
153 | prompts = yaml.safe_load(f)
154 | print(prompts)
155 | if len(prompts) == 0:
156 | raise ValueError("prompts file is empty")
157 | if len(attributes)!=0:
158 | newprompts = []
159 | for i in range(len(prompts)):
160 | for att in attributes:
161 | copy_ = copy.deepcopy(prompts[i])
162 | copy_['target'] = att + ' ' + copy_['target']
163 | copy_['positive'] = att + ' ' + copy_['positive']
164 | copy_['neutral'] = att + ' ' + copy_['neutral']
165 | copy_['unconditional'] = att + ' ' + copy_['unconditional']
166 | newprompts.append(copy_)
167 | else:
168 | newprompts = copy.deepcopy(prompts)
169 |
170 | print(newprompts)
171 | print(len(prompts), len(newprompts))
172 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
173 |
174 | return prompt_settings
175 |
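176 |
177 | # Example usage (a minimal sketch; the YAML path is the one shipped in this repo):
178 | #
179 | #   settings = load_prompts_from_yaml("trainscripts/textsliders/data/prompts.yaml")
180 | #   # Optionally expand every entry for a list of attributes; each attribute string is
181 | #   # prepended to the target/positive/neutral/unconditional prompts:
182 | #   settings = load_prompts_from_yaml(
183 | #       "trainscripts/textsliders/data/prompts.yaml", attributes=["male", "female"]
184 | #   )
185 | #   print(settings[0].action, settings[0].guidance_scale)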
--------------------------------------------------------------------------------
/trainscripts/textsliders/ptp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | from PIL import Image, ImageDraw, ImageFont
18 | import cv2
19 | from typing import Optional, Union, Tuple, List, Callable, Dict
20 | from IPython.display import display
21 | from tqdm.notebook import tqdm
22 |
23 |
24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
25 | h, w, c = image.shape
26 | offset = int(h * .2)
27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
28 | font = cv2.FONT_HERSHEY_SIMPLEX
29 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
30 | img[:h] = image
31 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
32 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
33 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
34 | return img
35 |
36 |
37 | def view_images(images, num_rows=1, offset_ratio=0.02):
38 | if type(images) is list:
39 | num_empty = len(images) % num_rows
40 | elif images.ndim == 4:
41 | num_empty = images.shape[0] % num_rows
42 | else:
43 | images = [images]
44 | num_empty = 0
45 |
46 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
47 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
48 | num_items = len(images)
49 |
50 | h, w, c = images[0].shape
51 | offset = int(h * offset_ratio)
52 | num_cols = num_items // num_rows
53 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
54 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
55 | for i in range(num_rows):
56 | for j in range(num_cols):
57 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
58 | i * num_cols + j]
59 |
60 | pil_img = Image.fromarray(image_)
61 | display(pil_img)
62 |
63 |
64 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
65 | if low_resource:
66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
68 | else:
69 | latents_input = torch.cat([latents] * 2)
70 |         noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
71 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
72 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
73 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
74 | latents = controller.step_callback(latents)
75 | return latents
76 |
77 |
78 | def latent2image(vae, latents):
79 | latents = 1 / 0.18215 * latents
80 | image = vae.decode(latents)['sample']
81 | image = (image / 2 + 0.5).clamp(0, 1)
82 | image = image.cpu().permute(0, 2, 3, 1).numpy()
83 | image = (image * 255).astype(np.uint8)
84 | return image
85 |
86 |
87 | def init_latent(latent, model, height, width, generator, batch_size):
88 | if latent is None:
89 | latent = torch.randn(
90 | (1, model.unet.in_channels, height // 8, width // 8),
91 | generator=generator,
92 | )
93 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
94 | return latent, latents
95 |
96 |
97 | @torch.no_grad()
98 | def text2image_ldm(
99 | model,
100 | prompt: List[str],
101 | controller,
102 | num_inference_steps: int = 50,
103 | guidance_scale: Optional[float] = 7.,
104 | generator: Optional[torch.Generator] = None,
105 | latent: Optional[torch.FloatTensor] = None,
106 | ):
107 | register_attention_control(model, controller)
108 | height = width = 256
109 | batch_size = len(prompt)
110 |
111 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
112 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
113 |
114 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
115 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
116 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
117 | context = torch.cat([uncond_embeddings, text_embeddings])
118 |
119 | model.scheduler.set_timesteps(num_inference_steps)
120 | for t in tqdm(model.scheduler.timesteps):
121 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
122 |
123 | image = latent2image(model.vqvae, latents)
124 |
125 | return image, latent
126 |
127 |
128 | @torch.no_grad()
129 | def text2image_ldm_stable(
130 | model,
131 | prompt: List[str],
132 | controller,
133 | num_inference_steps: int = 50,
134 | guidance_scale: float = 7.5,
135 | generator: Optional[torch.Generator] = None,
136 | latent: Optional[torch.FloatTensor] = None,
137 | low_resource: bool = False,
138 | ):
139 | register_attention_control(model, controller)
140 | height = width = 512
141 | batch_size = len(prompt)
142 |
143 | text_input = model.tokenizer(
144 | prompt,
145 | padding="max_length",
146 | max_length=model.tokenizer.model_max_length,
147 | truncation=True,
148 | return_tensors="pt",
149 | )
150 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
151 | max_length = text_input.input_ids.shape[-1]
152 | uncond_input = model.tokenizer(
153 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
154 | )
155 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
156 |
157 | context = [uncond_embeddings, text_embeddings]
158 | if not low_resource:
159 | context = torch.cat(context)
160 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
161 |
162 | # set timesteps
163 | extra_set_kwargs = {"offset": 1}
164 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
165 | for t in tqdm(model.scheduler.timesteps):
166 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
167 |
168 | image = latent2image(model.vae, latents)
169 |
170 | return image, latent
171 |
172 |
173 | def register_attention_control(model, controller):
174 | def ca_forward(self, place_in_unet):
175 | to_out = self.to_out
176 | if type(to_out) is torch.nn.modules.container.ModuleList:
177 | to_out = self.to_out[0]
178 | else:
179 | to_out = self.to_out
180 |
181 | def forward(x, context=None, mask=None):
182 | batch_size, sequence_length, dim = x.shape
183 | h = self.heads
184 | q = self.to_q(x)
185 | is_cross = context is not None
186 | context = context if is_cross else x
187 | k = self.to_k(context)
188 | v = self.to_v(context)
189 | q = self.reshape_heads_to_batch_dim(q)
190 | k = self.reshape_heads_to_batch_dim(k)
191 | v = self.reshape_heads_to_batch_dim(v)
192 |
193 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
194 |
195 | if mask is not None:
196 | mask = mask.reshape(batch_size, -1)
197 | max_neg_value = -torch.finfo(sim.dtype).max
198 | mask = mask[:, None, :].repeat(h, 1, 1)
199 | sim.masked_fill_(~mask, max_neg_value)
200 |
201 | # attention, what we cannot get enough of
202 | attn = sim.softmax(dim=-1)
203 | attn = controller(attn, is_cross, place_in_unet)
204 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
205 | out = self.reshape_batch_dim_to_heads(out)
206 | return to_out(out)
207 |
208 | return forward
209 |
210 | class DummyController:
211 |
212 | def __call__(self, *args):
213 | return args[0]
214 |
215 | def __init__(self):
216 | self.num_att_layers = 0
217 |
218 | if controller is None:
219 | controller = DummyController()
220 |
221 | def register_recr(net_, count, place_in_unet):
222 | if net_.__class__.__name__ == 'CrossAttention':
223 | net_.forward = ca_forward(net_, place_in_unet)
224 | return count + 1
225 | elif hasattr(net_, 'children'):
226 | for net__ in net_.children():
227 | count = register_recr(net__, count, place_in_unet)
228 | return count
229 |
230 | cross_att_count = 0
231 | sub_nets = model.unet.named_children()
232 | for net in sub_nets:
233 | if "down" in net[0]:
234 | cross_att_count += register_recr(net[1], 0, "down")
235 | elif "up" in net[0]:
236 | cross_att_count += register_recr(net[1], 0, "up")
237 | elif "mid" in net[0]:
238 | cross_att_count += register_recr(net[1], 0, "mid")
239 |
240 | controller.num_att_layers = cross_att_count
241 |
242 |
243 | def get_word_inds(text: str, word_place: int, tokenizer):
244 | split_text = text.split(" ")
245 | if type(word_place) is str:
246 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
247 | elif type(word_place) is int:
248 | word_place = [word_place]
249 | out = []
250 | if len(word_place) > 0:
251 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
252 | cur_len, ptr = 0, 0
253 |
254 | for i in range(len(words_encode)):
255 | cur_len += len(words_encode[i])
256 | if ptr in word_place:
257 | out.append(i + 1)
258 | if cur_len >= len(split_text[ptr]):
259 | ptr += 1
260 | cur_len = 0
261 | return np.array(out)
262 |
263 |
264 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
265 | word_inds: Optional[torch.Tensor]=None):
266 | if type(bounds) is float:
267 | bounds = 0, bounds
268 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
269 | if word_inds is None:
270 | word_inds = torch.arange(alpha.shape[2])
271 | alpha[: start, prompt_ind, word_inds] = 0
272 | alpha[start: end, prompt_ind, word_inds] = 1
273 | alpha[end:, prompt_ind, word_inds] = 0
274 | return alpha
275 |
276 |
277 | def get_time_words_attention_alpha(prompts, num_steps,
278 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
279 | tokenizer, max_num_words=77):
280 | if type(cross_replace_steps) is not dict:
281 | cross_replace_steps = {"default_": cross_replace_steps}
282 | if "default_" not in cross_replace_steps:
283 | cross_replace_steps["default_"] = (0., 1.)
284 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
285 | for i in range(len(prompts) - 1):
286 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
287 | i)
288 | for key, item in cross_replace_steps.items():
289 | if key != "default_":
290 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
291 | for i, ind in enumerate(inds):
292 | if len(ind) > 0:
293 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
294 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
295 | return alpha_time_words
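296 |
297 |
298 | # Example usage (a minimal sketch; `pipe` is assumed to be a diffusers StableDiffusionPipeline
299 | # already moved to the GPU; controller=None falls back to the DummyController defined above):
300 | #
301 | #   images, _ = text2image_ldm_stable(
302 | #       pipe, ["a photo of a person"], controller=None,
303 | #       num_inference_steps=50, guidance_scale=7.5,
304 | #       generator=torch.Generator().manual_seed(0),
305 | #   )
306 | #   view_images(images)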
--------------------------------------------------------------------------------