├── utils
    ├── __init__.py
    ├── labelmap.py
    └── posemap.py
├── assets
    └── teaser.png
├── examples
    ├── example1.jpg
    ├── example2.jpg
    ├── example2_caption.txt
    ├── example1_fine_mask.jpg
    ├── example2_fine_mask.jpg
    ├── example1_binary_mask.jpg
    ├── example2_binary_mask.jpg
    └── example1_caption.txt
├── requirements.txt
├── .gitignore
├── precompute_utils
    ├── precompute_image_features.py
    ├── captioning_qwen.py
    └── precompute_text_features.py
├── inference.py
├── README.md
├── inference_dataset.py
├── metrics.py
├── src
    ├── transformer_sd3_garm.py
    ├── transformer_vtoff.py
    └── transformer_vtoff_repa.py
└── dataset
    ├── vitonhd.py
    └── dresscode.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/examples/example1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/examples/example1.jpg
--------------------------------------------------------------------------------
/examples/example2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/examples/example2.jpg
--------------------------------------------------------------------------------
/examples/example2_caption.txt:
--------------------------------------------------------------------------------
1 | a satin wide-leg pants with a high waist and a straight fit, reaching the ankles
--------------------------------------------------------------------------------
/examples/example1_fine_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/examples/example1_fine_mask.jpg
--------------------------------------------------------------------------------
/examples/example2_fine_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/examples/example2_fine_mask.jpg
--------------------------------------------------------------------------------
/examples/example1_binary_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/examples/example1_binary_mask.jpg
--------------------------------------------------------------------------------
/examples/example2_binary_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidelobba/TEMU-VTOFF/HEAD/examples/example2_binary_mask.jpg
--------------------------------------------------------------------------------
/examples/example1_caption.txt:
--------------------------------------------------------------------------------
1 | A sheer dress with a fitted waist, loose fit, hem with a decorative trim, neckline with a tie detail, three-quarter sleeve length, and a long cloth length.
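
The example assets above come in groups of four: a person photo plus a `_caption.txt`, a `_fine_mask.jpg`, and a `_binary_mask.jpg` that share the photo's filename stem, which is how `inference.py` (later in this listing) locates them. Below is a minimal sketch of loading one such group — illustrative only, not code from the repository; the helper name `load_example` and the 768×1024 default size are assumptions taken from the inference defaults:

```python
from pathlib import Path
from PIL import Image

def load_example(image_path: str, size=(768, 1024)):
    """Gather the sibling files that inference.py expects next to an example photo."""
    p = Path(image_path)
    stem, folder = p.stem, p.parent
    image = Image.open(p).resize(size)                               # person photo
    caption = (folder / f"{stem}_caption.txt").read_text().strip()   # garment description
    fine_mask = Image.open(folder / f"{stem}_fine_mask.jpg").resize(size)
    binary_mask = Image.open(folder / f"{stem}_binary_mask.jpg").resize(size)
    return image, caption, fine_mask, binary_mask

# e.g. image, caption, fine_mask, binary_mask = load_example("examples/example1.jpg")
```
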
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | transformers 4 | accelerate 5 | sentencepiece 6 | prodigyopt 7 | opencv-python-headless 8 | wandb 9 | matplotlib 10 | einops 11 | scikit-image 12 | clean-fid==0.1.35 13 | torchmetrics 14 | deepspeed 15 | diffusers @ git+https://github.com/huggingface/diffusers 16 | timm 17 | qwen_vl_utils 18 | prettytable 19 | tqdm 20 | DISTS-pytorch 21 | pytorch-fid -------------------------------------------------------------------------------- /utils/labelmap.py: -------------------------------------------------------------------------------- 1 | label_map={ 2 | "background": 0, 3 | "hat": 1, 4 | "hair": 2, 5 | "sunglasses": 3, 6 | "upper_clothes": 4, 7 | "skirt": 5, 8 | "pants": 6, 9 | "dress": 7, 10 | "belt": 8, 11 | "left_shoe": 9, 12 | "right_shoe": 10, 13 | "head": 11, 14 | "left_leg": 12, 15 | "right_leg": 13, 16 | "left_arm": 14, 17 | "right_arm": 15, 18 | "bag": 16, 19 | "scarf": 17, 20 | } -------------------------------------------------------------------------------- /utils/posemap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | 6 | def kpoint_to_heatmap(kpoint, shape, sigma): 7 | """Converts a 2D keypoint to a gaussian heatmap 8 | 9 | Parameters 10 | ---------- 11 | kpoint: np.array 12 | 2D coordinates of keypoint [x, y]. 13 | shape: tuple 14 | Heatmap dimension (HxW). 15 | sigma: float 16 | Variance value of the gaussian. 17 | 18 | Returns 19 | ------- 20 | heatmap: np.array 21 | A gaussian heatmap HxW. 22 | """ 23 | map_h = shape[0] 24 | map_w = shape[1] 25 | if np.any(kpoint > 0): 26 | x, y = kpoint 27 | # x = x * map_w / 384.0 28 | # y = y * map_h / 512.0 29 | xy_grid = np.mgrid[:map_w, :map_h].transpose(2, 1, 0) 30 | heatmap = np.exp(-np.sum((xy_grid - (x, y)) ** 2, axis=-1) / sigma ** 2) 31 | heatmap /= (heatmap.max() + np.finfo('float32').eps) 32 | else: 33 | heatmap = np.zeros((map_h, map_w)) 34 | return torch.Tensor(heatmap) 35 | 36 | def get_coco_body25_mapping(): 37 | #left numbers are coco format while right numbers are body25 format 38 | return { 39 | 0:0, 40 | 1:1, 41 | 2:2, 42 | 3:3, 43 | 4:4, 44 | 5:5, 45 | 6:6, 46 | 7:7, 47 | 8:9, 48 | 9:10, 49 | 10:11, 50 | 11:12, 51 | 12:13, 52 | 13:14, 53 | 14:15, 54 | 15:16, 55 | 16:17, 56 | 17:18 57 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | 165 | ### Python Patch ### 166 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 167 | poetry.toml 168 | 169 | # ruff 170 | .ruff_cache/ 171 | 172 | # LSP config files 173 | pyrightconfig.json 174 | 175 | # End of https://www.toptal.com/developers/gitignore/api/python 176 | -------------------------------------------------------------------------------- /precompute_utils/precompute_image_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 5 | import torch 6 | from dataset.dresscode import DressCodeDataset 7 | from dataset.vitonhd import VitonHDDataset 8 | from accelerate import Accelerator 9 | from tqdm import tqdm 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 13 | parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) 14 | parser.add_argument("--order", type=str, required=True, choices=["unpaired", "paired"]) 15 | parser.add_argument("--num_workers", type=int, default=1) 16 | parser.add_argument("--batch_size", type=int, default=1) 17 | parser.add_argument("--width", type=int, required=True) 18 | parser.add_argument("--height", type=int, required=True) 19 | 20 | parser.add_argument("--dataroot", type=str, required=True, help="Path to the dataset root directory.") 21 | parser.add_argument("--phase", type=str, required=True, help="Phase of the dataset (e.g. train, test).") 22 | parser.add_argument("--category", type=str, required=True, choices=["all", "dresses", "upper_body", "lower_body"], help="Category of the dataset.") 23 | parser.add_argument("--dataset", type=str, required=True, choices=["dresscode", "vitonhd"], help="Dataset to use.") 24 | parser.add_argument("--mask_type", type=str, required=True, default="bounding_box", help="Type of mask to use.") 25 | 26 | 27 | args = parser.parse_args() 28 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 29 | if env_local_rank != -1 and env_local_rank != args.local_rank: 30 | args.local_rank = env_local_rank 31 | 32 | return args 33 | 34 | def get_clip_image_embeds(image_encoder_large, image_encoder_bigG, image): 35 | image_embeds_large = image_encoder_large(image).image_embeds 36 | image_embeds_bigG = image_encoder_bigG(image).image_embeds 37 | return torch.cat([image_embeds_large, image_embeds_bigG], dim=1) 38 | 39 | @torch.inference_mode() 40 | def main(args): 41 | accelerator = Accelerator( 42 | mixed_precision=args.mixed_precision, 43 | cpu=args.force_cpu 44 | ) 45 | device = accelerator.device 46 | weight_dtype = torch.float32 47 | if accelerator.mixed_precision == "fp16": 48 | weight_dtype = torch.float16 49 | elif accelerator.mixed_precision == "bf16": 50 | weight_dtype = torch.bfloat16 51 | 52 | vit_processing = CLIPImageProcessor() 53 | 54 | image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype) 55 | image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype) 56 | 57 | outputlist = ['image', 'cloth', 'category', 'im_name', 'c_name'] 58 | if args.category == "all": 59 | args.category = ["upper_body", "lower_body", "dresses"] 60 | else: 61 | args.category = [args.category] 62 | 63 | if args.dataset == "dresscode": 64 | 
test_dataset = DressCodeDataset( 65 | dataroot_path=args.dataroot, 66 | phase=args.phase, 67 | order=args.order, 68 | radius=5, 69 | outputlist=outputlist, 70 | sketch_threshold_range=(20, 20), 71 | category=args.category, 72 | size=(args.height, args.width), 73 | mask_type=args.mask_type, 74 | ) 75 | 76 | elif args.dataset == "vitonhd": 77 | test_dataset = VitonHDDataset( 78 | dataroot_path=args.dataroot, 79 | phase=args.phase, 80 | order=args.order, 81 | radius=5, 82 | outputlist=outputlist, 83 | sketch_threshold_range=(20, 20), 84 | size=(args.height, args.width), 85 | mask_type=args.mask_type, 86 | ) 87 | else: 88 | raise NotImplementedError(f"Dataset {args.dataset} not implemented") 89 | 90 | batch = test_dataset[0] 91 | test_dataloader = torch.utils.data.DataLoader( 92 | test_dataset, 93 | shuffle=False, 94 | batch_size=args.batch_size, 95 | num_workers=args.num_workers, 96 | ) 97 | 98 | dataroot = test_dataloader.dataset.dataroot 99 | test_dataloader = accelerator.prepare(test_dataloader) 100 | 101 | 102 | print("Start extracting image embeddings") 103 | for idx, batch in enumerate(tqdm(test_dataloader)): 104 | vton_img = batch.get("image").to('cpu') 105 | vton_img = (vton_img+1) /2 106 | category = batch.get("category") 107 | 108 | bs = len(category) 109 | 110 | vton_img = vton_img.to(device=device) 111 | vton_img_vit = vit_processing(images=vton_img, return_tensors="pt").data['pixel_values'] 112 | vton_img_vit = vton_img_vit.to(device=device) 113 | vton_img_embeds = get_clip_image_embeds(image_encoder_large, image_encoder_bigG, vton_img_vit, device) 114 | 115 | for i in range(vton_img.shape[0]): 116 | vton_name = batch["im_name"][i] 117 | vton_image_embeds = vton_img_embeds[i] 118 | if args.dataset == "dresscode": 119 | image_feat_path = Path(dataroot) / "vton_image_embeddings" / "CLIP_VIT_L_VIT_G_concat" / batch["category"][i] / f"{vton_name.split('.')[0]}.pt" 120 | elif args.dataset == "vitonhd": 121 | image_feat_path = Path(dataroot) / "vton_image_embeddings" / "CLIP_VIT_L_VIT_G_concat" / test_dataloader.dataset.phase / f"{vton_name.split('.')[0]}.pt" 122 | else: raise NotImplementedError("dataset not supported") 123 | if not image_feat_path.parent.exists(): 124 | image_feat_path.parent.mkdir(parents=True, exist_ok=True) 125 | 126 | if not image_feat_path.exists(): 127 | torch.save(vton_image_embeds, image_feat_path) 128 | 129 | 130 | if __name__ == '__main__': 131 | args = parse_args() 132 | main(args) -------------------------------------------------------------------------------- /precompute_utils/captioning_qwen.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration 2 | from qwen_vl_utils import process_vision_info 3 | import os 4 | from accelerate.utils import set_seed 5 | import torch 6 | import json 7 | import argparse 8 | import sys 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | from dataset.dresscode import DressCodeDataset 11 | from dataset.vitonhd import VitonHDDataset 12 | from tqdm import tqdm 13 | 14 | def parse_args(input_args=None): 15 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 16 | parser.add_argument( 17 | "--pretrained_model_name_or_path", 18 | type=str, 19 | default=None, 20 | required=False, 21 | help="", 22 | ) 23 | parser.add_argument( 24 | "--dataset_name", 25 | type=str, 26 | required=False, 27 | choices=["dresscode", "vitonhd"], 28 | help="", 29 | ) 30 | 
parser.add_argument( 31 | "--dataset_root", 32 | type=str, 33 | default=None, 34 | required=False, 35 | help="", 36 | ) 37 | parser.add_argument( 38 | "--filename", 39 | type=str, 40 | default=None, 41 | required=False, 42 | help="Path to visual attributes annotations for RAG enhanced VLM captioning", 43 | ) 44 | parser.add_argument("--height",type=int,default=1024) 45 | parser.add_argument("--width",type=int,default=768) 46 | parser.add_argument("--temperatures",type=float, nargs="+", default=[0.2]) 47 | 48 | 49 | args = parser.parse_args() 50 | return args 51 | 52 | def qwen_captioning(args): 53 | set_seed(42) 54 | 55 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 56 | args.pretrained_model_name_or_path, 57 | torch_dtype=torch.bfloat16, 58 | device_map="cuda", 59 | ) 60 | 61 | processor = AutoProcessor.from_pretrained(args.pretrained_model_name_or_path, use_fast=True) 62 | visual_attributes = {"dresses": ["Cloth Type", "Waist", "Fit", "Hem", "Neckline", "Sleeve Length", "Cloth Length"], 63 | "upper_body": ["Cloth Type", "Waist", "Fit", "Hem", "Neckline", "Sleeve Length", "Cloth Length"], 64 | "lower_body": ["Cloth Type", "Waist", "Fit", "Cloth Length"]} 65 | 66 | outputlist=["c_name", "category"] 67 | category=["upper_body", "lower_body", "dresses"] 68 | if args.dataset_name == "dresscode": 69 | train_dataset = DressCodeDataset( 70 | dataroot_path=args.dataset_root, 71 | phase="train", 72 | order="paired", 73 | radius=5, 74 | outputlist=outputlist, 75 | sketch_threshold_range=(20, 20), 76 | category=category, 77 | size=(args.height, args.width), 78 | ) 79 | test_dataset = DressCodeDataset( 80 | dataroot_path=args.dataset_root, 81 | phase="test", 82 | order="paired", 83 | radius=5, 84 | outputlist=outputlist, 85 | sketch_threshold_range=(20, 20), 86 | category=category, 87 | size=(args.height, args.width), 88 | ) 89 | elif args.dataset_name == "vitonhd": 90 | train_dataset = VitonHDDataset( 91 | dataroot_path=args.dataset_root, 92 | phase="train", 93 | order="paired", 94 | radius=5, 95 | outputlist=outputlist, 96 | sketch_threshold_range=(20, 20), 97 | size=(args.height, args.width), 98 | ) 99 | test_dataset = VitonHDDataset( 100 | dataroot_path=args.dataset_root, 101 | phase="test", 102 | order="paired", 103 | radius=5, 104 | outputlist=outputlist, 105 | sketch_threshold_range=(20, 20), 106 | size=(args.height, args.width), 107 | ) 108 | 109 | dataroot_names = train_dataset.dataroot_names + test_dataset.dataroot_names 110 | phases = ["train" for _ in range(len(train_dataset))] + ["test" for _ in range(len(test_dataset))] 111 | c_names = train_dataset.c_names + test_dataset.c_names 112 | temperatures = args.temperatures 113 | vlm_captions_dicts = [{} for _ in range(len(temperatures))] 114 | 115 | for i, t in enumerate(temperatures): 116 | t = str(temperatures[i]).replace(".","_") 117 | if os.path.exists(os.path.join(args.dataset_root, args.filename.replace(".json", f"_{t}.json"))): 118 | with open(os.path.join(args.dataset_root, args.filename.replace(".json", f"_{t}.json"))) as f: 119 | vlm_captions_dicts[i] = json.load(f) 120 | 121 | idx = 0 122 | for dataroot, phase, c_name in tqdm(zip(dataroot_names, phases, c_names)): 123 | idx+=1 124 | item_id = c_name.split("_")[0] 125 | if item_id in vlm_captions_dicts[0].keys(): 126 | continue 127 | 128 | if args.dataset_name == "vitonhd": 129 | category="upper_body" 130 | img_path = os.path.join( 131 | dataroot, phase, 'cloth', c_name) 132 | 133 | elif args.dataset_name == "dresscode": 134 | category = dataroot.split("/")[-1] 135 | 
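# DressCode keeps one folder per garment category, so the category name is the last
# path component of `dataroot`. The lines below prefer the cleaned in-shop garment
# image (cleaned_inshop_imgs/<name>_cleaned.jpg) and fall back to the raw images/
# folder when no cleaned version exists.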
img_path = os.path.join( 136 | dataroot, 'cleaned_inshop_imgs', c_name.replace(".jpg", "_cleaned.jpg")) 137 | if not os.path.exists(img_path): img_path = os.path.join(dataroot, 'images', c_name) 138 | 139 | messages = [ 140 | { "role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, 141 | { "role": "user", 142 | "content": [ 143 | { 144 | "type": "image", 145 | "image": f"{img_path}", 146 | }, 147 | {"type": "text", "text": 148 | "Use only visual attributes that are present in the image. " 149 | f"Predict values of the following attributes: {visual_attributes[category]}. " 150 | "It's forbidden to generate the following visual attributes: colors, background and textures/patterns. " 151 | "It's forbidden to generate unspecified predictions. It's forbidden to generate newlines characters. " 152 | f"Generate in this way: a with " 153 | 154 | }, 155 | ], 156 | } 157 | ] 158 | 159 | # Preparation for inference 160 | text = processor.apply_chat_template( 161 | messages, tokenize=False, add_generation_prompt=True 162 | ) 163 | image_inputs, video_inputs = process_vision_info(messages) 164 | inputs = processor( 165 | text=[text], 166 | images=image_inputs, 167 | videos=video_inputs, 168 | padding=True, 169 | return_tensors="pt", 170 | ) 171 | inputs = inputs.to("cuda") 172 | 173 | for i, temperature in enumerate(temperatures): 174 | generated_ids = model.generate(**inputs, max_new_tokens=75, temperature=temperature, do_sample=True, 175 | pad_token_id=processor.tokenizer.eos_token_id) 176 | generated_ids_trimmed = [ 177 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 178 | ] 179 | output_text = processor.batch_decode( 180 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 181 | ) 182 | vlm_captions_dicts[i][item_id] = output_text 183 | 184 | if idx % 1000 == 0: 185 | for i, vlm_captions_dict in enumerate(vlm_captions_dicts): 186 | t = str(temperatures[i]).replace(".","_") 187 | with open(os.path.join(args.dataset_root, args.filename.replace(".json", f"_{t}.json")), "w") as f: 188 | json.dump(vlm_captions_dict, f) 189 | 190 | for i, vlm_captions_dict in enumerate(vlm_captions_dicts): 191 | t = str(temperatures[i]).replace(".","_") 192 | with open(os.path.join(args.dataset_root, args.filename.replace(".json", f"_{t}.json")), "w") as f: 193 | json.dump(vlm_captions_dict, f) 194 | 195 | if __name__ == "__main__": 196 | args = parse_args() 197 | qwen_captioning(args) 198 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | from diffusers import ( 5 | AutoencoderKL, 6 | FlowMatchEulerDiscreteScheduler, 7 | ) 8 | from src.transformer_vtoff import SD3Transformer2DModel 9 | from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_feature_extractor 10 | from src.pipeline_stable_diffusion_3_tryoff_masked import StableDiffusion3TryOffPipelineMasked 11 | from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast 12 | from transformers import CLIPVisionModelWithProjection 13 | from PIL import Image 14 | from torchvision import transforms 15 | 16 | def load_text_encoders(class_one, class_two, class_three): 17 | text_encoder_one, text_encoder_two, text_encoder_three = None, None, None 18 | if class_one is not None and class_two is not None: 19 | text_encoder_one = 
class_one.from_pretrained( 20 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant, 21 | ) 22 | text_encoder_two = class_two.from_pretrained( 23 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant, 24 | ) 25 | text_encoder_one.requires_grad_(False) 26 | text_encoder_two.requires_grad_(False) 27 | 28 | if class_three is not None: 29 | text_encoder_three = class_three.from_pretrained( 30 | args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant, 31 | ) 32 | text_encoder_three.requires_grad_(False) 33 | return text_encoder_one, text_encoder_two, text_encoder_three 34 | 35 | def import_model_class_from_model_name_or_path( 36 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 37 | ): 38 | text_encoder_config = PretrainedConfig.from_pretrained( 39 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 40 | ) 41 | model_class = text_encoder_config.architectures[0] 42 | if model_class == "CLIPTextModelWithProjection": 43 | from transformers import CLIPTextModelWithProjection 44 | 45 | return CLIPTextModelWithProjection 46 | elif model_class == "T5EncoderModel": 47 | from transformers import T5EncoderModel 48 | 49 | return T5EncoderModel 50 | else: 51 | raise ValueError(f"{model_class} is not supported.") 52 | 53 | 54 | def main(args): 55 | """ 56 | Generate virtual try-off image using SD3 pipeline. 57 | It can be used to generate a try-off image from a single image. 58 | 59 | Args: 60 | args: Parsed command line arguments containing model paths, and generation parameters. 61 | """ 62 | os.makedirs(args.output_dir, exist_ok=True) 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | weight_dtype = torch.float32 65 | if args.mixed_precision == "fp16": 66 | weight_dtype = torch.float16 67 | elif args.mixed_precision == "bf16": 68 | weight_dtype = torch.bfloat16 69 | 70 | tokenizer_one = CLIPTokenizer.from_pretrained( 71 | args.pretrained_model_name_or_path, 72 | subfolder="tokenizer", 73 | revision=args.revision, 74 | ) 75 | tokenizer_two = CLIPTokenizer.from_pretrained( 76 | args.pretrained_model_name_or_path, 77 | subfolder="tokenizer_2", 78 | revision=args.revision, 79 | ) 80 | tokenizer_three = T5TokenizerFast.from_pretrained( 81 | args.pretrained_model_name_or_path, 82 | subfolder="tokenizer_3", 83 | revision=args.revision, 84 | low_cpu_mem_usage=True, 85 | ) 86 | 87 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 88 | args.pretrained_model_name_or_path, args.revision 89 | ) 90 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 91 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 92 | ) 93 | text_encoder_cls_three = import_model_class_from_model_name_or_path( 94 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" 95 | ) 96 | 97 | text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( 98 | text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, 99 | ) 100 | 101 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 102 | args.pretrained_model_name_or_path, subfolder="scheduler" 103 | ) 104 | 105 | vae = AutoencoderKL.from_pretrained( 106 | args.pretrained_model_name_or_path, 107 | subfolder="vae", 108 | revision=args.revision, 109 | variant=args.variant, 110 | ) 111 | transformer = 
SD3Transformer2DModel.from_pretrained( 112 | args.pretrained_model_name_or_path_sd3_tryoff, subfolder="transformer", revision=args.revision, variant=args.variant 113 | ) 114 | 115 | transformer_vton_feature_extractor = SD3Transformer2DModel_feature_extractor.from_pretrained( 116 | args.pretrained_model_name_or_path_sd3_tryoff, subfolder="transformer_vton", revision=args.revision, variant=args.variant 117 | ) 118 | 119 | image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").to(device=device, dtype=weight_dtype) 120 | image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(device=device, dtype=weight_dtype) 121 | 122 | pipeline = StableDiffusion3TryOffPipelineMasked( 123 | scheduler=noise_scheduler, 124 | vae=vae, 125 | transformer_vton_feature_extractor=transformer_vton_feature_extractor, 126 | transformer_garm=transformer, 127 | image_encoder_large=image_encoder_large, 128 | image_encoder_bigG=image_encoder_bigG, 129 | tokenizer=tokenizer_one, 130 | tokenizer_2=tokenizer_two, 131 | tokenizer_3=tokenizer_three, 132 | text_encoder=text_encoder_one, 133 | text_encoder_2=text_encoder_two, 134 | text_encoder_3=text_encoder_three, 135 | ) 136 | 137 | pipeline.to(device, dtype=weight_dtype) 138 | 139 | image_path = args.example_image 140 | image_name = os.path.splitext(os.path.basename(image_path))[0] 141 | image_directory = os.path.dirname(image_path) 142 | 143 | image = Image.open(image_path) 144 | image = image.resize((args.width, args.height)) 145 | 146 | caption = open(os.path.join(image_directory, f"{image_name}_caption.txt"), "r").read() 147 | 148 | image_fine_mask = Image.open(os.path.join(image_directory, f"{image_name}_fine_mask.jpg")) 149 | image_fine_mask = image_fine_mask.resize((args.width, args.height)) 150 | 151 | image_binary_mask = Image.open(os.path.join(image_directory, f"{image_name}_binary_mask.jpg")) 152 | image_binary_mask = image_binary_mask.resize((args.width, args.height)) 153 | 154 | generator = torch.Generator(device=device).manual_seed(args.seed) if args.seed else None 155 | 156 | image = pipeline( 157 | prompt=caption, 158 | height=args.height, 159 | width=args.width, 160 | guidance_scale=args.guidance_scale, 161 | num_inference_steps=args.num_inference_steps, 162 | generator=generator, 163 | vton_image=image, 164 | mask_input=image_binary_mask, 165 | image_input_masked=image_fine_mask, 166 | ).images[0] 167 | 168 | image.save(f"{args.output_dir}/{image_name}_output.png") 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument( 174 | "--pretrained_model_name_or_path", 175 | type=str, 176 | default=None, 177 | required=True, 178 | help="Path to pretrained model or model identifier from huggingface.co/models.", 179 | ) 180 | parser.add_argument( 181 | "--pretrained_model_name_or_path_sd3_tryoff", 182 | type=str, 183 | default=None, 184 | required=True, 185 | help="Path to pretrained model or model identifier from huggingface.co/models.", 186 | ) 187 | parser.add_argument( 188 | "--revision", 189 | type=str, 190 | default=None, 191 | required=False, 192 | help="Revision of pretrained model identifier from huggingface.co/models.", 193 | ) 194 | parser.add_argument( 195 | "--variant", 196 | type=str, 197 | default=None, 198 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 199 | ) 200 | 201 | parser.add_argument( 202 | "--seed", 203 | type=int, 204 | default=42, 205 | required=False, 206 | ) 207 | parser.add_argument( 208 | "--width", 209 | type=int, 210 | default=768, 211 | required=True, 212 | ) 213 | parser.add_argument( 214 | "--height", 215 | type=int, 216 | default=1024, 217 | required=True, 218 | ) 219 | parser.add_argument( 220 | "--output_dir", 221 | type=str, 222 | default="outputs", 223 | required=True, 224 | ) 225 | parser.add_argument( 226 | "--mixed_precision", 227 | type=str, 228 | default="bf16", 229 | required=False, 230 | ) 231 | parser.add_argument( 232 | "--example_image", 233 | type=str, 234 | default="examples/example1.jpg", 235 | required=True, 236 | ) 237 | parser.add_argument( 238 | "--guidance_scale", 239 | type=float, 240 | default=2.0, 241 | required=True, 242 | ) 243 | parser.add_argument( 244 | "--num_inference_steps", 245 | type=int, 246 | default=28, 247 | required=True, 248 | ) 249 | args = parser.parse_args() 250 | main(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
# TEMU-VTOFF
### Text-Enhanced MUlti-category Virtual Try-Off

> **Inverse Virtual Try-On: Generating Multi-Category Product-Style Images from Clothed Individuals**
> [Davide Lobba](https://scholar.google.com/citations?user=WEMoLPEAAAAJ&hl=en&oi=ao)<sup>1,2,*</sup>, [Fulvio Sanguigni](https://scholar.google.com/citations?user=tSpzMUEAAAAJ&hl=en)<sup>2,3,*</sup>, [Bin Ren](https://scholar.google.com/citations?user=Md9maLYAAAAJ&hl=en)<sup>1,2</sup>, [Marcella Cornia](https://scholar.google.com/citations?user=DzgmSJEAAAAJ&hl=en)<sup>3</sup>, [Rita Cucchiara](https://scholar.google.com/citations?user=OM3sZEoAAAAJ&hl=en)<sup>3</sup>, [Nicu Sebe](https://scholar.google.com/citations?user=stFCYOAAAAAJ&hl=en)<sup>1</sup>
> <sup>1</sup>University of Trento, <sup>2</sup>University of Pisa, <sup>3</sup>University of Modena and Reggio Emilia
> <sup>*</sup> Equal contribution
**Table of Contents**
1. About The Project
2. Key Features
3. Getting Started
   - Prerequisites
   - Installation
4. Inference
5. Dataset Inference
   - Dataset Captioning
   - Feature extraction
   - Generate Images
6. Contact
7. Citation
67 | 68 | 69 | 70 | 71 | ## 💡 About The Project 72 | 73 | TEMU-VTOFF is a novel dual-DiT (Diffusion Transformer) architecture designed for the Virtual Try-Off task: generating clean, in-shop images of garments worn by a person. By combining a pretrained feature extractor with a text-enhanced generation module, our method can handle occlusions, multiple garment categories, and ambiguous appearances. It further refines generation fidelity via a feature alignment module based on DINOv2. 74 | 75 | ## ✨ Key Features 76 | Our contribution can be summarized as follows: 77 | - **🎯 Multi-Category Try-Off**. We present a unified framework capable of handling multiple garment types (upper-body, lower-body, and full-body clothes) without requiring category-specific pipelines. 78 | - **🔗 Multimodal Hybrid Attention**. We introduce a novel attention mechanism that integrates garment textual descriptions into the generative process by linking them with person-specific features. This helps the model synthesize occluded or ambiguous garment regions more accurately. 79 | - **⚡ Garment Aligner Module**. We design a lightweight aligner that conditions generation on clean garment images, replacing conventional denoising objectives. This leads to better alignment consistency on the overall dataset and preserves more precise visual retention. 80 | - **📊 Extensive experiments**. Experiments on the Dress Code and VITON-HD datasets demonstrate that TEMU-VTOFF outperforms prior methods in both the quality of generated images and alignment with the target garment, highlighting its strong generalization capabilities. 81 | 82 | 83 | ## 💻 Getting Started 84 | ### Prerequisites 85 | 86 | Clone the repository: 87 | ```sh 88 | git clone https://github.com/davidelobba/TEMU-VTOFF.git 89 | ``` 90 | 91 | ### Installation 92 | 93 | 1. We recommend installing the required packages using Python's native virtual environment (venv) as follows: 94 | ```sh 95 | python -m venv venv 96 | source venv/bin/activate 97 | ``` 98 | 2. Upgrade pip and install dependencies 99 | ```sh 100 | pip install --upgrade pip 101 | pip install -r requirements.txt 102 | ``` 103 | 3. Create a .env file like the following: 104 | ```js 105 | export WANDB_API_KEY="ENTER YOUR WANDB TOKEN" 106 | export HF_TOKEN="ENTER YOUR HUGGINGFACE TOKEN" 107 | export HF_HOME="PATH WHERE YOU WANT TO SAVE THE HF MODELS" 108 | ``` 109 | 🧠 Note: Access to Stable Diffusion 3 Medium must be requested via [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers). 110 | 111 | ## Inference 112 | Let's generate the in-shop garment image. 113 | ```sh 114 | source venv/bin/activate 115 | source .env 116 | 117 | python inference.py \ 118 | --pretrained_model_name_or_path "stabilityai/stable-diffusion-3-medium-diffusers" \ 119 | --pretrained_model_name_or_path_sd3_tryoff "davidelobba/TEMU-VTOFF" \ 120 | --seed 42 \ 121 | --width "768" \ 122 | --height "1024" \ 123 | --output_dir "put here the output path" \ 124 | --mixed_precision "bf16" \ 125 | --example_image "examples/example1.jpg" \ 126 | --guidance_scale 2.0 \ 127 | --num_inference_steps 28 128 | ``` 129 | 130 | ## Dataset Inference 131 | ### Dataset Captioning 132 | Generate textual descriptions for each sample using a multimodal VLM (e.g., Qwen2.5-VL). 
133 | 134 | ```sh 135 | python precompute_utils/captioning_qwen.py \ 136 | --pretrained_model_name_or_path "Qwen/Qwen2.5-VL-7B-Instruct" \ 137 | --dataset_name "dresscode" \ 138 | --dataset_root "put here your dataset path" \ 139 | --filename "qwen_captions_2_5.json"\ 140 | --temperatures 0.2 141 | ``` 142 | 143 | ### Feature extraction 144 | Extract textual features using OpenCLIP, CLIP and T5 text encoders. 145 | 146 | ```sh 147 | phases=("test" "train") 148 | for phase in "${phases[@]}"; do 149 | python precompute_utils/precompute_text_features.py \ 150 | --pretrained_model_name_or_path "stabilityai/stable-diffusion-3-medium-diffusers" \ 151 | --dataset_name "dresscode" \ 152 | --dataset_root "put here your dataset path" \ 153 | --phase $phase \ 154 | --order "paired" \ 155 | --category "all" \ 156 | --output_dir "" \ 157 | --seed 42 \ 158 | --height 1024 \ 159 | --width 768 \ 160 | --batch_size 4 \ 161 | --mixed_precision "fp16" \ 162 | --num_workers 8 \ 163 | --text_encoders "T5" "CLIP" \ 164 | --captions_type "qwen_text_embeddings" 165 | ``` 166 | 167 | Extract visual features using OpenCLIP and CLIP vision encoders. 168 | ```sh 169 | phases=("test" "train") 170 | for phase in "${phases[@]}"; do 171 | python precompute_utils/precompute_image_features.py \ 172 | --dataset "dresscode" \ 173 | --dataroot "put here your dataset path" \ 174 | --phase $phase \ 175 | --order "paired" \ 176 | --category "all" \ 177 | --seed 42 \ 178 | --height 1024 \ 179 | --width 768 \ 180 | --batch_size 4 \ 181 | --mixed_precision "fp16" \ 182 | --num_workers 8 183 | ``` 184 | ### Generate Images 185 | Let's generate the in-shop garment images of DressCode or VITON-HD using the TEMU-VTOFF model. 186 | ```sh 187 | source venv/bin/activate 188 | source .env 189 | 190 | python inference_dataset.py \ 191 | --pretrained_model_name_or_path "stabilityai/stable-diffusion-3-medium-diffusers" \ 192 | --pretrained_model_name_or_path_sd3_tryoff "davidelobba/TEMU-VTOFF" \ 193 | --dataset_name "dresscode" \ 194 | --dataset_root "put here your dataset path" \ 195 | --output_dir "put here the output path" \ 196 | --coarse_caption_file "qwen_captions_2_5_0_2.json" \ 197 | --phase "test" \ 198 | --order "paired" \ 199 | --height "1024" \ 200 | --width "768" \ 201 | --mask_type bounding_box \ 202 | --category "all" \ 203 | --batch_size 4 \ 204 | --mixed_precision "bf16" \ 205 | --seed 42 \ 206 | --num_workers 8 \ 207 | --fine_mask \ 208 | --guidance_scale 2.0 \ 209 | --num_inference_steps 28 210 | ``` 211 | 212 | ## 📬 Contact 213 | **Lead Authors:** 214 | - 📧 **Davide Lobba**: [davide.lobba@unitn.it](mailto:davide.lobba@unitn.it) | 🎓 [Google Scholar](https://scholar.google.com/citations?user=WEMoLPEAAAAJ&hl=en&oi=ao) 215 | - 📧 **Fulvio Sanguigni**: [fulvio.sanguigni@unimore.it](mailto:fulvio.sanguigni@unimore.it) | 🎓 [Google Scholar](https://scholar.google.com/citations?user=tSpzMUEAAAAJ&hl=en) 216 | 217 | For questions about the project, feel free to reach out to any of the lead authors! 
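
A practical note on the caption files produced by the Dataset Captioning step above: `captioning_qwen.py` writes one JSON per sampling temperature (the filename gets a `_<temperature>` suffix, e.g. `qwen_captions_2_5_0_2.json` for temperature 0.2), keyed by the garment item id taken from the cloth filename. A minimal sketch for inspecting such a file — illustrative code, not part of the repository; the path is a placeholder:

```python
import json

# Peek at a caption file written by precompute_utils/captioning_qwen.py.
# Keys are garment item ids (c_name.split("_")[0]); values are the captions as
# returned by processor.batch_decode, i.e. a list holding the generated string.
caption_path = "/path/to/DressCode/qwen_captions_2_5_0_2.json"  # placeholder path
with open(caption_path) as f:
    captions = json.load(f)

for item_id, caption in list(captions.items())[:3]:
    print(item_id, "->", caption)
```
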
218 | 219 | 220 | ## Citation 221 | Please cite our paper if you find our work helpful: 222 | ```bibtex 223 | @article{lobba2025inverse, 224 | title={Inverse Virtual Try-On: Generating Multi-Category Product-Style Images from Clothed Individuals}, 225 | author={Lobba, Davide and Sanguigni, Fulvio and Ren, Bin and Cornia, Marcella and Cucchiara, Rita and Sebe, Nicu}, 226 | journal={arXiv preprint arXiv:2505.21062}, 227 | year={2025} 228 | } 229 | ``` 230 | -------------------------------------------------------------------------------- /inference_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from pathlib import Path 4 | import os 5 | from diffusers import ( 6 | AutoencoderKL, 7 | FlowMatchEulerDiscreteScheduler, 8 | ) 9 | from src.transformer_vtoff import SD3Transformer2DModel 10 | from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_feature_extractor 11 | from src.pipeline_stable_diffusion_3_tryoff_masked import StableDiffusion3TryOffPipelineMasked 12 | from transformers import CLIPVisionModelWithProjection 13 | 14 | from dataset.dresscode import DressCodeDataset 15 | from dataset.vitonhd import VitonHDDataset 16 | from tqdm import tqdm 17 | 18 | 19 | def main(args): 20 | """ 21 | Generate virtual try-off images using SD3 pipeline. 22 | It can be used to generate images from the DressCode or VITON-HD dataset. 23 | 24 | Args: 25 | args: Parsed command line arguments containing model paths, 26 | dataset configuration, and generation parameters. 27 | """ 28 | os.makedirs(args.output_dir, exist_ok=True) 29 | weight_dtype = torch.float32 30 | if args.mixed_precision == "fp16": 31 | weight_dtype = torch.float16 32 | elif args.mixed_precision == "bf16": 33 | weight_dtype = torch.bfloat16 34 | 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 37 | args.pretrained_model_name_or_path, subfolder="scheduler" 38 | ) 39 | 40 | vae = AutoencoderKL.from_pretrained( 41 | args.pretrained_model_name_or_path, 42 | subfolder="vae", 43 | revision=args.revision, 44 | variant=args.variant, 45 | ) 46 | transformer = SD3Transformer2DModel.from_pretrained( 47 | args.pretrained_model_name_or_path_sd3_tryoff, subfolder="transformer", revision=args.revision, variant=args.variant 48 | ) 49 | 50 | transformer_vton_feature_extractor = SD3Transformer2DModel_feature_extractor.from_pretrained( 51 | args.pretrained_model_name_or_path_sd3_tryoff, subfolder="transformer_vton", revision=args.revision, variant=args.variant 52 | ) 53 | 54 | image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").to(device=device, dtype=weight_dtype) 55 | image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(device=device, dtype=weight_dtype) 56 | 57 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 58 | pipeline = StableDiffusion3TryOffPipelineMasked( 59 | scheduler=noise_scheduler, 60 | vae=vae, 61 | transformer_vton_feature_extractor=transformer_vton_feature_extractor, 62 | transformer_garm=transformer, 63 | image_encoder_large=image_encoder_large, 64 | image_encoder_bigG=image_encoder_bigG, 65 | tokenizer=None, 66 | tokenizer_2=None, 67 | tokenizer_3=None, 68 | text_encoder=None, 69 | text_encoder_2=None, 70 | text_encoder_3=None, 71 | ) 72 | 73 | pipeline.to(device, dtype=weight_dtype) 74 | 75 | if not 
args.fine_mask: 76 | outputlist = ['category', 'image', 'im_name', 'vton_image_embeddings', 'clip_embeds', 't5_embeds', 'clip_pooled', 77 | 'inpaint_mask', 'im_mask'] 78 | else: 79 | outputlist = ['category', 'image', 'im_name', 'vton_image_embeddings', 'clip_embeds', 't5_embeds', 'clip_pooled', 80 | 'parse_cloth', 'im_mask_fine'] 81 | 82 | if args.dataset_name == "dresscode": 83 | if args.category == "all": 84 | args.category = ["upper_body", "lower_body", "dresses"] 85 | else: 86 | args.category = [args.category] 87 | dataset = DressCodeDataset( 88 | dataroot_path=args.dataset_root, 89 | phase=args.phase, 90 | order=args.order, 91 | radius=5, 92 | outputlist=outputlist, 93 | sketch_threshold_range=(20, 20), 94 | category=args.category, 95 | mask_type=args.mask_type, 96 | size=(args.height, args.width), 97 | coarse_caption_file=args.coarse_caption_file 98 | ) 99 | elif args.dataset_name == "vitonhd": 100 | dataset = VitonHDDataset( 101 | dataroot_path=args.dataset_root, 102 | phase=args.phase, 103 | order=args.order, 104 | radius=5, 105 | outputlist=outputlist, 106 | size=(args.height, args.width), 107 | mask_type=args.mask_type, 108 | caption_file=args.coarse_caption_file 109 | ) 110 | else: 111 | raise ValueError(f"Unsupported dataset: {args.dataset_name}") 112 | 113 | dataloader = torch.utils.data.DataLoader( 114 | dataset, 115 | shuffle=False, 116 | batch_size=args.batch_size, 117 | num_workers=args.num_workers, 118 | pin_memory=True, 119 | ) 120 | 121 | empty_string_pooled_prompt_embeds = torch.load(Path(args.dataset_root).parent / "empty_string_embeddings" / "CLIP_VIT_L_VIT_G_concat_pooled" / "empty_string.pt", map_location="cpu").to(device, dtype=weight_dtype) 122 | empty_string_clip_prompt_embeds = torch.load(Path(args.dataset_root).parent / "empty_string_embeddings" / "CLIP_VIT_L_VIT_G_concat" / "empty_string.pt", map_location="cpu").to(device, dtype=weight_dtype) 123 | empty_string_t5_prompt_embeds = torch.load(Path(args.dataset_root).parent / "empty_string_embeddings" / "T5_XXL" / "empty_string.pt", map_location="cpu").to(device, dtype=weight_dtype) 124 | negative_text_embeds = torch.cat([empty_string_clip_prompt_embeds, empty_string_t5_prompt_embeds], dim=1) 125 | negative_pooled_embeds = empty_string_pooled_prompt_embeds 126 | 127 | negative_text_embeds = negative_text_embeds.repeat(args.batch_size, 1, 1) 128 | negative_pooled_embeds = negative_pooled_embeds.repeat(args.batch_size, 1) 129 | 130 | for batch in tqdm(dataloader): 131 | image = batch.get("image").to(device=device) 132 | image_name = batch.get("im_name") 133 | image = (image+1)/2 #from [-1, 1] to [0, 1] 134 | clip_text_embeddings = batch.get("clip_embeds").to(device=device, dtype=weight_dtype) 135 | t5_text_embeddings = batch.get("t5_embeds").to(device=device, dtype=weight_dtype) 136 | clip_pooled_embeddings = batch.get("clip_pooled").to(device=device).to(dtype=weight_dtype) 137 | text_embeddings=torch.cat([clip_text_embeddings, t5_text_embeddings], dim=1) 138 | if not args.fine_mask: 139 | masked_vton_img = batch.get("im_mask").to(device=device) 140 | inpaint_mask = batch.get("inpaint_mask").to(device=device) 141 | else: 142 | masked_vton_img = batch.get("im_mask_fine").to(device=device) 143 | inpaint_mask = batch.get("parse_cloth").to(device=device) 144 | if len(inpaint_mask.shape) == 3: 145 | inpaint_mask = inpaint_mask.unsqueeze(1) 146 | 147 | masked_vton_img = masked_vton_img.to(dtype=weight_dtype) 148 | inpaint_mask = inpaint_mask.to(dtype=weight_dtype) 149 | 150 | generator = 
torch.Generator(device=device).manual_seed(args.seed) if args.seed else None 151 | 152 | images = pipeline( 153 | prompt_embeds = text_embeddings, 154 | pooled_prompt_embeds = clip_pooled_embeddings, 155 | negative_prompt_embeds = negative_text_embeds, 156 | negative_pooled_prompt_embeds = negative_pooled_embeds, 157 | height=args.height, 158 | width=args.width, 159 | guidance_scale=args.guidance_scale, 160 | num_inference_steps=args.num_inference_steps, 161 | generator=generator, 162 | vton_image=image, 163 | mask_input=inpaint_mask, 164 | image_input_masked=masked_vton_img, 165 | ).images 166 | 167 | for i in range(len(images)): 168 | category = batch.get("category")[i] 169 | os.makedirs(f"{os.path.join(args.output_dir, category)}", exist_ok=True) 170 | images[i].save(f"{os.path.join(args.output_dir, category)}/{image_name[i].split('.')[0]}.png") 171 | 172 | 173 | if __name__ == "__main__": 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument( 176 | "--pretrained_model_name_or_path", 177 | type=str, 178 | default=None, 179 | required=True, 180 | help="Path to pretrained model or model identifier from huggingface.co/models.", 181 | ) 182 | parser.add_argument( 183 | "--pretrained_model_name_or_path_sd3_tryoff", 184 | type=str, 185 | default=None, 186 | required=True, 187 | help="Path to pretrained model or model identifier from huggingface.co/models.", 188 | ) 189 | parser.add_argument( 190 | "--revision", 191 | type=str, 192 | default=None, 193 | required=False, 194 | help="Revision of pretrained model identifier from huggingface.co/models.", 195 | ) 196 | parser.add_argument( 197 | "--variant", 198 | type=str, 199 | default=None, 200 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 201 | ) 202 | parser.add_argument( 203 | "--dataset_name", 204 | type=str, 205 | required=True, 206 | choices=["dresscode", "vitonhd"], 207 | ) 208 | parser.add_argument( 209 | "--dataset_root", 210 | type=str, 211 | required=True, 212 | ) 213 | parser.add_argument( 214 | "--coarse_caption_file", 215 | type=str, 216 | required=True, 217 | ) 218 | parser.add_argument( 219 | "--order", 220 | type=str, 221 | required=True, 222 | ) 223 | parser.add_argument( 224 | "--phase", 225 | type=str, 226 | required=True, 227 | ) 228 | parser.add_argument( 229 | "--batch_size", 230 | type=int, 231 | required=True, 232 | ) 233 | parser.add_argument( 234 | "--num_workers", 235 | type=int, 236 | required=True, 237 | ) 238 | parser.add_argument( 239 | "--category", 240 | type=str, 241 | required=True, 242 | ) 243 | parser.add_argument( 244 | "--seed", 245 | type=int, 246 | required=False, 247 | default=42, 248 | ) 249 | parser.add_argument( 250 | "--width", 251 | type=int, 252 | required=True, 253 | default=768, 254 | ) 255 | parser.add_argument( 256 | "--height", 257 | type=int, 258 | required=True, 259 | default=1024, 260 | ) 261 | parser.add_argument( 262 | "--output_dir", 263 | type=str, 264 | required=True, 265 | ) 266 | parser.add_argument( 267 | "--mixed_precision", 268 | type=str, 269 | default="bf16", 270 | choices=["fp16", "bf16"], 271 | required=False, 272 | ) 273 | parser.add_argument( 274 | "--mask_type", 275 | type=str, 276 | required=True, 277 | ) 278 | parser.add_argument( 279 | "--fine_mask", 280 | action="store_true", 281 | ) 282 | parser.add_argument( 283 | "--guidance_scale", 284 | type=float, 285 | required=True, 286 | default=2.0, 287 | ) 288 | parser.add_argument( 289 | "--num_inference_steps", 290 | type=int, 291 | required=True, 292 | default=28, 293 | ) 294 | args = parser.parse_args() 295 | main(args) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | import os 4 | import torch 5 | from cleanfid import fid as CleanFID 6 | from DISTS_pytorch import DISTS 7 | from prettytable import PrettyTable 8 | from pytorch_fid import fid_score as PytorchFID 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from torchmetrics.image import PeakSignalNoiseRatio 12 | from torchmetrics.image import StructuralSimilarityIndexMeasure 13 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 14 | from torchvision import transforms 15 | from tqdm import tqdm 16 | import shutil 17 | from pathlib import Path 18 | 19 | 20 | class EvalDataset(Dataset): 21 | def __init__(self, gt_folder, pred_files, height=1024): 22 | self.gt_folder = gt_folder 23 | self.pred_files = pred_files 24 | self.height = height 25 | self.data = self.prepare_data() 26 | self.to_tensor = transforms.ToTensor() 27 | self.pred_transform = transforms.Compose([ 28 | transforms.Resize((1024, 768)), 29 | transforms.ToTensor(), 30 | ]) 31 | self.gt_transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 34 | ]) 35 | 36 | def extract_id_from_filename(self, filename): 37 | if "inshop" in filename: 38 | filename = filename.split("_")[0] 39 | return filename 40 | filename = filename.split("_")[0] 41 | return filename 42 | 43 | def prepare_data(self): 44 | gt_files = scan_files_in_dir(self.gt_folder, postfix={'.jpg', '.png'}) 45 | gt_dict = {self.extract_id_from_filename( 46 | 
file.name): file for file in gt_files} 47 | pred_files = self.pred_files 48 | pred_dict = {self.extract_id_from_filename(file.name): file for file in pred_files} 49 | pred_files = [pred_dict[pred_id] for pred_id in gt_dict.keys()] 50 | assert len(pred_files) == len( 51 | gt_dict), f"Number of gt files {len(gt_dict)} and pred files {len(pred_files)} do not match" 52 | 53 | tuples = [] 54 | for pred_file in pred_files: 55 | pred_id = self.extract_id_from_filename(pred_file.name) 56 | tuples.append((gt_dict[pred_id].path, pred_file.path)) 57 | return tuples 58 | 59 | def resize(self, img): 60 | w, h = img.size 61 | new_w = int(w * self.height / h) 62 | return img.resize((new_w, self.height), Image.LANCZOS) 63 | 64 | def __len__(self): 65 | return len(self.data) 66 | 67 | def __getitem__(self, idx): 68 | gt_path, pred_path = self.data[idx] 69 | gt, pred = self.resize(Image.open(gt_path)), self.resize( 70 | Image.open(pred_path)) 71 | if gt.height != self.height: 72 | gt = self.resize(gt) 73 | if pred.height != self.height: 74 | pred = self.resize(pred) 75 | gt = self.gt_transform(gt) 76 | gt = (gt+1)/2 77 | pred = self.pred_transform(pred) 78 | return gt, pred 79 | 80 | 81 | def scan_files_in_dir(directory, postfix: Set[str] = None, progress_bar: tqdm = None) -> list: 82 | """Scan a directory and return a list of files with the given postfix. Can accept also directory with subdirectories. 83 | WARNING: Put the folder containing only needed files and subfolders to avoid iterating over unecessary subfolders and files 84 | """ 85 | file_list = [] 86 | progress_bar = tqdm(total=0, desc=f"Scanning", 87 | ncols=100) if progress_bar is None else progress_bar 88 | for entry in os.scandir(directory): 89 | if entry.is_file(): 90 | if postfix is None or os.path.splitext(entry.path)[1] in postfix: 91 | file_list.append(entry) 92 | progress_bar.total += 1 93 | progress_bar.update(1) 94 | elif entry.is_dir(): 95 | file_list += scan_files_in_dir(entry.path, 96 | postfix=postfix, progress_bar=progress_bar) 97 | return file_list 98 | 99 | 100 | def copy_resize_gt(gt_folder, height, width): 101 | new_folder = f"{gt_folder}_{height}" 102 | if not os.path.exists(new_folder): 103 | os.makedirs(new_folder, exist_ok=True) 104 | for file in tqdm(os.listdir(gt_folder)): 105 | if os.path.exists(os.path.join(new_folder, file)): 106 | continue 107 | img = Image.open(os.path.join(gt_folder, file)) 108 | img = img.resize((width, height), Image.LANCZOS) 109 | img.save(os.path.join(new_folder, file)) 110 | return new_folder 111 | 112 | 113 | @torch.no_grad() 114 | def psnr(dataloader): 115 | psnr_score = 0 116 | psnr = PeakSignalNoiseRatio(data_range=1.0).to("cuda") 117 | for gt, pred in tqdm(dataloader, desc="Calculating PSNR"): 118 | batch_size = gt.size(0) 119 | gt, pred = gt.to("cuda"), pred.to("cuda") 120 | psnr_score += psnr(pred, gt) * batch_size 121 | return psnr_score / len(dataloader.dataset) 122 | 123 | 124 | @torch.no_grad() 125 | def ssim(dataloader): 126 | ssim_score = 0 127 | ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda") 128 | for gt, pred in tqdm(dataloader, desc="Calculating SSIM"): 129 | batch_size = gt.size(0) 130 | gt, pred = gt.to("cuda"), pred.to("cuda") 131 | ssim.update(pred, gt) 132 | return ssim.compute().item() 133 | 134 | 135 | @torch.no_grad() 136 | def lpips(dataloader): 137 | lpips_score = LearnedPerceptualImagePatchSimilarity(net_type='alex', normalize=True).to("cuda") 138 | score = 0 139 | for gt, pred in tqdm(dataloader, desc="Calculating LPIPS"): 140 | batch_size = 
gt.size(0) 141 | pred = pred.to("cuda") 142 | gt = gt.to("cuda") 143 | lpips_score.update(gt, pred) 144 | return lpips_score.compute().item() 145 | 146 | def dists(dataloader): 147 | dists = DISTS().to("cuda") 148 | dists_score = 0 149 | for gt, pred in tqdm(dataloader, desc="DISTS"): 150 | batch_size = gt.size(0) 151 | pred = pred.to("cuda") 152 | gt = gt.to("cuda") 153 | score = dists(gt, pred, batch_average=True) 154 | dists_score+=score 155 | dists_score = dists_score / len(dataloader) 156 | return dists_score.item() 157 | 158 | def structure_dirs(dataset, dataroot, category="all"): 159 | """_summary_ 160 | 161 | Args: 162 | dataroot (_type_): _description_ 163 | category (str, optional): _description_. Defaults to "all". 164 | Accumulate image paths into a temporary folder to generate stats of different categories. For example, tmp_folder = /tmp/dresscode/dresses" 165 | line.strip().split()[1] returns the cloth name, line.strip().split()[0] returns the vton image name 166 | """ 167 | if dataset=="dresscode": 168 | dresscode_filesplit = os.path.join(dataroot, f"test_pairs_paired.txt") 169 | with open(dresscode_filesplit, 'r') as f: 170 | lines = f.read().splitlines() 171 | paths = [] 172 | if category in ['lower_body', 'upper_body', 'dresses']: 173 | for line in lines: 174 | path = os.path.join(dataroot, category, 'cleaned_inshop_imgs', line.strip().split()[1]).replace(".jpg", "_cleaned.jpg") 175 | path_2 = os.path.join(dataroot, category, 'images', line.strip().split()[1]).replace(".jpg", "_cleaned.jpg") 176 | path_3 = os.path.join(dataroot, category, 'images', line.strip().split()[1]) 177 | if os.path.exists(path): 178 | paths.append(path) 179 | elif os.path.exists(path_2): 180 | paths.append(path_2) 181 | elif os.path.exists(path_3): 182 | paths.append(path_3) 183 | tmp_folder = f"/tmp/dresscode/{category}" 184 | shutil.rmtree(tmp_folder, ignore_errors=True) 185 | os.makedirs(tmp_folder, exist_ok=True) 186 | for path in tqdm(paths): 187 | shutil.copy(path, tmp_folder) 188 | return tmp_folder 189 | 190 | elif category == "all": 191 | for category in ['lower_body', 'upper_body', 'dresses']: 192 | for line in lines: 193 | path = os.path.join(dataroot, category, 'cleaned_inshop_imgs', line.strip().split()[1]).replace(".jpg", "_cleaned.jpg") 194 | path_2 = os.path.join(dataroot, category, 'images', line.strip().split()[1]).replace(".jpg", "_cleaned.jpg") 195 | path_3 = os.path.join(dataroot, category, 'images', line.strip().split()[1]) 196 | if os.path.exists(path): 197 | paths.append(path) 198 | elif os.path.exists(path_2): 199 | paths.append(path_2) 200 | elif os.path.exists(path_3): 201 | paths.append(path_3) 202 | tmp_folder = f"/tmp/dresscode/all" 203 | shutil.rmtree(tmp_folder, ignore_errors=True) 204 | os.makedirs(tmp_folder, exist_ok=True) 205 | for path in tqdm(paths): 206 | shutil.copy(path, tmp_folder) 207 | return tmp_folder 208 | 209 | else: raise NotImplementedError("category must be in ['lower_body', 'upper_body', 'dresses', 'all']") 210 | 211 | elif dataset=="vitonhd": 212 | return os.path.join(dataroot, 'test', 'cloth') 213 | 214 | else: raise NotImplementedError("dataset must be in ['dresscode', 'vitonhd']") 215 | 216 | 217 | def compute_metrics(args): 218 | args.gt_folder = structure_dirs(args.dataset, args.gt_folder, args.category) 219 | pred_files = scan_files_in_dir( 220 | args.pred_folder, postfix={'.jpg', '.png'}) 221 | if args.category == "all": 222 | for pred_file in tqdm(pred_files): 223 | new_pred_folder = f"/tmp/{args.dataset}/predictions" 224 | 
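# When evaluating category == "all", every prediction is copied into a single flat
# temporary folder so that FID/KID are computed over one directory that matches the
# combined ground-truth folder assembled by structure_dirs().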
os.makedirs(new_pred_folder, exist_ok=True) 225 | shutil.copy(pred_file.path, new_pred_folder) 226 | pred_folder = new_pred_folder 227 | elif os.path.exists(os.path.join(args.pred_folder, args.category)): 228 | pred_folder = os.path.join(args.pred_folder, args.category) 229 | else: 230 | pred_folder = args.pred_folder 231 | 232 | 233 | gt_sample = os.listdir(args.gt_folder)[0] 234 | gt_img = Image.open(os.path.join(args.gt_folder, gt_sample)) 235 | if args.height != gt_img.height: 236 | title = "--"*30 + \ 237 | f"Resizing GT Images to height {args.height}" + "--"*30 238 | print(title) 239 | args.gt_folder = copy_resize_gt(args.gt_folder, args.height, args.width) 240 | print("-"*len(title)) 241 | 242 | dataset = EvalDataset(args.gt_folder, pred_files, args.height) 243 | dataloader = torch.utils.data.DataLoader( 244 | dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, drop_last=False 245 | ) 246 | 247 | header = [] 248 | row = [] 249 | print("FID and KID") 250 | header += ["PyTorch-FID"] 251 | pytorch_fid_ = PytorchFID.calculate_fid_given_paths( 252 | [args.gt_folder, pred_folder], batch_size=args.batch_size, device="cuda", dims=2048, num_workers=args.num_workers) 253 | row += ["{:.4f}".format(pytorch_fid_)] 254 | print("PyTorch-FID: {:.4f}".format(pytorch_fid_)) 255 | header += ["Clean-FID", "Clean-KID"] 256 | clean_fid_ = CleanFID.compute_fid(args.gt_folder, pred_folder) 257 | print("Clean-FID: {:.4f}".format(clean_fid_)) 258 | row += ["{:.4f}".format(clean_fid_)] 259 | clean_kid_ = CleanFID.compute_kid(args.gt_folder, pred_folder) * 1000 260 | print("Clean-KID: {:.4f}".format(clean_kid_)) 261 | row += ["{:.4f}".format(clean_kid_)] 262 | 263 | if args.paired: 264 | print("SSIM, LPIPS, DISTS") 265 | header += ["SSIM", "LPIPS", "DISTS"] 266 | ssim_ = ssim(dataloader) 267 | print("SSIM: {:.4f}".format(ssim_)) 268 | lpips_ = lpips(dataloader) 269 | print("LPIPS: {:.4f}".format(lpips_)) 270 | dists_ = dists(dataloader) 271 | print("DISTS: {:.4f}".format(dists_)) 272 | row += ["{:.4f}".format(ssim_), 273 | "{:.4f}".format(lpips_),"{:.4f}".format(dists_),] 274 | 275 | print("GT Folder : ", args.gt_folder) 276 | print("Pred Folder: ", args.pred_folder) 277 | table = PrettyTable() 278 | table.field_names = header 279 | table.add_row(row) 280 | print(table) 281 | 282 | with open(os.path.join(Path(args.pred_folder), f"metrics_{args.category}.txt"), "w") as f: 283 | f.write(str(table)) 284 | 285 | 286 | if __name__ == "__main__": 287 | import argparse 288 | parser = argparse.ArgumentParser() 289 | parser.add_argument("--gt_folder", type=str, required=True) 290 | parser.add_argument("--pred_folder", type=str, required=True) 291 | parser.add_argument("--dataset", type=str, required=True) 292 | parser.add_argument("--category", type=str, choices=["lower_body", "upper_body", "dresses", "all"], required=True) 293 | parser.add_argument("--paired", action="store_true") 294 | parser.add_argument("--batch_size", type=int, default=16) 295 | parser.add_argument("--num_workers", type=int, default=4) 296 | parser.add_argument("--height", type=int, default=1024) 297 | parser.add_argument("--width", type=int, default=768) 298 | args = parser.parse_args() 299 | 300 | if args.gt_folder.endswith("/"): 301 | args.gt_folder = args.gt_folder[:-1] 302 | if args.pred_folder.endswith("/"): 303 | args.pred_folder = args.pred_folder[:-1] 304 | 305 | compute_metrics(args) 306 | -------------------------------------------------------------------------------- /src/transformer_sd3_garm.py: 
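The two custom SD3 transformers that follow (transformer_sd3_garm.py and transformer_vtoff.py) are meant to be used as a pair: when called with return_dict=False, the garment/reference transformer also returns per-block attention key/value lists, which the try-off transformer then consumes through its ref_key/ref_value arguments. The sketch below only illustrates that handoff; the checkpoint paths and tensor shapes are placeholder assumptions, and the actual conditioning tensors are assembled in inference.py.

import torch
from src.transformer_sd3_garm import SD3Transformer2DModel as GarmTransformer
from src.transformer_vtoff import SD3Transformer2DModel as TryOffTransformer

# Assumed checkpoint locations -- substitute the released TEMU-VTOFF weights.
garm = GarmTransformer.from_pretrained("<garment_transformer_ckpt>").eval()
vtoff = TryOffTransformer.from_pretrained("<vtoff_transformer_ckpt>").eval()

# Illustrative inputs: 1024x768 images give 128x96 latents with 16 channels.
latents = torch.randn(1, 16, 128, 96)       # latents being denoised
ref_latents = torch.randn(1, 16, 128, 96)   # reference latents fed to the garment transformer
prompt_embeds = torch.randn(1, 77, 4096)    # placeholder text embeddings (see precompute_utils)
pooled_embeds = torch.randn(1, 2048)
timestep = torch.tensor([999])

with torch.no_grad():
    # return_dict=False -> (sample, ref_key, ref_value), one key/value entry per transformer block.
    _, ref_key, ref_value = garm(
        hidden_states=ref_latents,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled_embeds,
        timestep=timestep,
        return_dict=False,
    )
    # The try-off transformer indexes ref_key[i] / ref_value[i] inside block i.
    noise_pred = vtoff(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled_embeds,
        timestep=timestep,
        ref_key=ref_key,
        ref_value=ref_value,
    ).sample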
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Any, Dict, List, Optional, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 23 | from src.attention_garm import JointTransformerBlock 24 | from src.attention_processor_garm import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 25 | from diffusers.models.modeling_utils import ModelMixin 26 | from diffusers.models.normalization import AdaLayerNormContinuous 27 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 28 | from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed 29 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 30 | 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | 35 | class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 36 | """ 37 | The Transformer model introduced in Stable Diffusion 3. 38 | 39 | Reference: https://arxiv.org/abs/2403.03206 40 | 41 | Parameters: 42 | sample_size (`int`): The width of the latent images. This is fixed during training since 43 | it is used to learn a number of position embeddings. 44 | patch_size (`int`): Patch size to turn the input data into small patches. 45 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 46 | num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. 47 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 48 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 49 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 50 | caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. 51 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 52 | out_channels (`int`, defaults to 16): Number of output channels. 
53 | 54 | """ 55 | 56 | _supports_gradient_checkpointing = True 57 | 58 | @register_to_config 59 | def __init__( 60 | self, 61 | sample_size: int = 128, 62 | patch_size: int = 2, 63 | in_channels: int = 16, 64 | num_layers: int = 18, 65 | attention_head_dim: int = 64, 66 | num_attention_heads: int = 18, 67 | joint_attention_dim: int = 4096, 68 | caption_projection_dim: int = 1152, 69 | pooled_projection_dim: int = 2048, 70 | out_channels: int = 16, 71 | pos_embed_max_size: int = 96, 72 | ): 73 | super().__init__() 74 | default_out_channels = in_channels 75 | self.out_channels = out_channels if out_channels is not None else default_out_channels 76 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 77 | 78 | self.pos_embed = PatchEmbed( 79 | height=self.config.sample_size, 80 | width=self.config.sample_size, 81 | patch_size=self.config.patch_size, 82 | in_channels=self.config.in_channels, 83 | embed_dim=self.inner_dim, 84 | pos_embed_max_size=pos_embed_max_size, # hard-code for now. 85 | ) 86 | self.time_text_embed = CombinedTimestepTextProjEmbeddings( 87 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 88 | ) 89 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) 90 | 91 | # `attention_head_dim` is doubled to account for the mixing. 92 | # It needs to crafted when we get the actual checkpoints. 93 | self.transformer_blocks = nn.ModuleList( 94 | [ 95 | JointTransformerBlock( 96 | dim=self.inner_dim, 97 | num_attention_heads=self.config.num_attention_heads, 98 | attention_head_dim=self.config.attention_head_dim, 99 | context_pre_only=i == num_layers - 1, 100 | ) 101 | for i in range(self.config.num_layers) 102 | ] 103 | ) 104 | 105 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 106 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 107 | 108 | self.gradient_checkpointing = False 109 | 110 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 111 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 112 | """ 113 | Sets the attention processor to use [feed forward 114 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 115 | 116 | Parameters: 117 | chunk_size (`int`, *optional*): 118 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 119 | over each tensor of dim=`dim`. 120 | dim (`int`, *optional*, defaults to `0`): 121 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 122 | or dim=1 (sequence length). 
123 | """ 124 | if dim not in [0, 1]: 125 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 126 | 127 | # By default chunk size is 1 128 | chunk_size = chunk_size or 1 129 | 130 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 131 | if hasattr(module, "set_chunk_feed_forward"): 132 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 133 | 134 | for child in module.children(): 135 | fn_recursive_feed_forward(child, chunk_size, dim) 136 | 137 | for module in self.children(): 138 | fn_recursive_feed_forward(module, chunk_size, dim) 139 | 140 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking 141 | def disable_forward_chunking(self): 142 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 143 | if hasattr(module, "set_chunk_feed_forward"): 144 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 145 | 146 | for child in module.children(): 147 | fn_recursive_feed_forward(child, chunk_size, dim) 148 | 149 | for module in self.children(): 150 | fn_recursive_feed_forward(module, None, 0) 151 | 152 | @property 153 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 154 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 155 | r""" 156 | Returns: 157 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 158 | indexed by its weight name. 159 | """ 160 | # set recursively 161 | processors = {} 162 | 163 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 164 | if hasattr(module, "get_processor"): 165 | processors[f"{name}.processor"] = module.get_processor() 166 | 167 | for sub_name, child in module.named_children(): 168 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 169 | 170 | return processors 171 | 172 | for name, module in self.named_children(): 173 | fn_recursive_add_processors(name, module, processors) 174 | 175 | return processors 176 | 177 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 178 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 179 | r""" 180 | Sets the attention processor to use to compute attention. 181 | 182 | Parameters: 183 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 184 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 185 | for **all** `Attention` layers. 186 | 187 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 188 | processor. This is strongly recommended when setting trainable attention processors. 189 | 190 | """ 191 | count = len(self.attn_processors.keys()) 192 | 193 | if isinstance(processor, dict) and len(processor) != count: 194 | raise ValueError( 195 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 196 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
197 | ) 198 | 199 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 200 | if hasattr(module, "set_processor"): 201 | if not isinstance(processor, dict): 202 | module.set_processor(processor) 203 | else: 204 | module.set_processor(processor.pop(f"{name}.processor")) 205 | 206 | for sub_name, child in module.named_children(): 207 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 208 | 209 | for name, module in self.named_children(): 210 | fn_recursive_attn_processor(name, module, processor) 211 | 212 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 213 | def fuse_qkv_projections(self): 214 | """ 215 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 216 | are fused. For cross-attention modules, key and value projection matrices are fused. 217 | 218 | 219 | 220 | This API is 🧪 experimental. 221 | 222 | 223 | """ 224 | self.original_attn_processors = None 225 | 226 | for _, attn_processor in self.attn_processors.items(): 227 | if "Added" in str(attn_processor.__class__.__name__): 228 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 229 | 230 | self.original_attn_processors = self.attn_processors 231 | 232 | for module in self.modules(): 233 | if isinstance(module, Attention): 234 | module.fuse_projections(fuse=True) 235 | 236 | self.set_attn_processor(FusedJointAttnProcessor2_0()) 237 | 238 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 239 | def unfuse_qkv_projections(self): 240 | """Disables the fused QKV projection if enabled. 241 | 242 | 243 | 244 | This API is 🧪 experimental. 245 | 246 | 247 | 248 | """ 249 | if self.original_attn_processors is not None: 250 | self.set_attn_processor(self.original_attn_processors) 251 | 252 | def _set_gradient_checkpointing(self, module, value=False): 253 | if hasattr(module, "gradient_checkpointing"): 254 | module.gradient_checkpointing = value 255 | 256 | def forward( 257 | self, 258 | hidden_states: torch.FloatTensor, 259 | encoder_hidden_states: torch.FloatTensor = None, 260 | pooled_projections: torch.FloatTensor = None, 261 | timestep: torch.LongTensor = None, 262 | block_controlnet_hidden_states: List = None, 263 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 264 | return_dict: bool = True, 265 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 266 | """ 267 | The [`SD3Transformer2DModel`] forward method. 268 | 269 | Args: 270 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 271 | Input `hidden_states`. 272 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 273 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 274 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 275 | from the embeddings of input conditions. 276 | timestep ( `torch.LongTensor`): 277 | Used to indicate denoising step. 278 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 279 | A list of tensors that if specified are added to the residuals of transformer blocks. 
280 | joint_attention_kwargs (`dict`, *optional*): 281 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 282 | `self.processor` in 283 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 284 | return_dict (`bool`, *optional*, defaults to `True`): 285 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 286 | tuple. 287 | 288 | Returns: 289 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 290 | `tuple` where the first element is the sample tensor. 291 | """ 292 | ref_key = [] 293 | ref_value = [] 294 | if joint_attention_kwargs is not None: 295 | joint_attention_kwargs = joint_attention_kwargs.copy() 296 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 297 | else: 298 | lora_scale = 1.0 299 | 300 | if USE_PEFT_BACKEND: 301 | # weight the lora layers by setting `lora_scale` for each PEFT layer 302 | scale_lora_layers(self, lora_scale) 303 | else: 304 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 305 | logger.warning( 306 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 307 | ) 308 | 309 | height, width = hidden_states.shape[-2:] 310 | 311 | hidden_states = self.pos_embed(hidden_states) 312 | temb = self.time_text_embed(timestep, pooled_projections) 313 | 314 | for index_block, block in enumerate(self.transformer_blocks): 315 | if self.training and self.gradient_checkpointing: 316 | 317 | def create_custom_forward(module, return_dict=None): 318 | def custom_forward(*inputs): 319 | if return_dict is not None: 320 | return module(*inputs, return_dict=return_dict) 321 | else: 322 | return module(*inputs) 323 | 324 | return custom_forward 325 | 326 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 327 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 328 | create_custom_forward(block), 329 | hidden_states, 330 | encoder_hidden_states, 331 | temb, 332 | ref_key, 333 | ref_value, 334 | **ckpt_kwargs, 335 | ) 336 | 337 | else: 338 | encoder_hidden_states, hidden_states = block( 339 | hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, ref_key=ref_key, ref_value=ref_value 340 | ) 341 | 342 | # controlnet residual 343 | if block_controlnet_hidden_states is not None and block.context_pre_only is False: 344 | interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) 345 | hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] 346 | 347 | hidden_states = self.norm_out(hidden_states, temb) 348 | hidden_states = self.proj_out(hidden_states) 349 | 350 | # unpatchify 351 | patch_size = self.config.patch_size 352 | height = height // patch_size 353 | width = width // patch_size 354 | 355 | hidden_states = hidden_states.reshape( 356 | shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) 357 | ) 358 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 359 | output = hidden_states.reshape( 360 | shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) 361 | ) 362 | 363 | if USE_PEFT_BACKEND: 364 | # remove `lora_scale` from each PEFT layer 365 | unscale_lora_layers(self, lora_scale) 366 | 367 | if not 
return_dict: 368 | return (output,ref_key, ref_value) 369 | 370 | return Transformer2DModelOutput(sample=output) 371 | -------------------------------------------------------------------------------- /src/transformer_vtoff.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Any, Dict, List, Optional, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 23 | from src.attention_vton_mixed import JointTransformerBlock 24 | from src.attention_processor_vton_mixed import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 25 | from diffusers.models.modeling_utils import ModelMixin 26 | from diffusers.models.normalization import AdaLayerNormContinuous 27 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 28 | from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed 29 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 30 | 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | 35 | class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 36 | """ 37 | The Transformer model introduced in Stable Diffusion 3. 38 | 39 | Reference: https://arxiv.org/abs/2403.03206 40 | 41 | Parameters: 42 | sample_size (`int`): The width of the latent images. This is fixed during training since 43 | it is used to learn a number of position embeddings. 44 | patch_size (`int`): Patch size to turn the input data into small patches. 45 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 46 | num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. 47 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 48 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 49 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 50 | caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. 51 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 52 | out_channels (`int`, defaults to 16): Number of output channels. 
53 | 54 | """ 55 | 56 | _supports_gradient_checkpointing = True 57 | 58 | @register_to_config 59 | def __init__( 60 | self, 61 | sample_size: int = 128, 62 | patch_size: int = 2, 63 | in_channels: int = 16, 64 | num_layers: int = 18, 65 | attention_head_dim: int = 64, 66 | num_attention_heads: int = 18, 67 | joint_attention_dim: int = 4096, 68 | caption_projection_dim: int = 1152, 69 | pooled_projection_dim: int = 2048, 70 | out_channels: int = 16, 71 | pos_embed_max_size: int = 96, 72 | ): 73 | super().__init__() 74 | default_out_channels = in_channels 75 | self.out_channels = out_channels if out_channels is not None else default_out_channels 76 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 77 | 78 | self.pos_embed = PatchEmbed( 79 | height=self.config.sample_size, 80 | width=self.config.sample_size, 81 | patch_size=self.config.patch_size, 82 | in_channels=self.config.in_channels, 83 | embed_dim=self.inner_dim, 84 | pos_embed_max_size=pos_embed_max_size, # hard-code for now. 85 | ) 86 | self.time_text_embed = CombinedTimestepTextProjEmbeddings( 87 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 88 | ) 89 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) 90 | 91 | # `attention_head_dim` is doubled to account for the mixing. 92 | # It needs to crafted when we get the actual checkpoints. 93 | self.transformer_blocks = nn.ModuleList( 94 | [ 95 | JointTransformerBlock( 96 | dim=self.inner_dim, 97 | num_attention_heads=self.config.num_attention_heads, 98 | attention_head_dim=self.config.attention_head_dim, 99 | context_pre_only=i == num_layers - 1, 100 | ) 101 | for i in range(self.config.num_layers) 102 | ] 103 | ) 104 | 105 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 106 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 107 | 108 | self.gradient_checkpointing = False 109 | 110 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 111 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 112 | """ 113 | Sets the attention processor to use [feed forward 114 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 115 | 116 | Parameters: 117 | chunk_size (`int`, *optional*): 118 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 119 | over each tensor of dim=`dim`. 120 | dim (`int`, *optional*, defaults to `0`): 121 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 122 | or dim=1 (sequence length). 
123 | """ 124 | if dim not in [0, 1]: 125 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 126 | 127 | # By default chunk size is 1 128 | chunk_size = chunk_size or 1 129 | 130 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 131 | if hasattr(module, "set_chunk_feed_forward"): 132 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 133 | 134 | for child in module.children(): 135 | fn_recursive_feed_forward(child, chunk_size, dim) 136 | 137 | for module in self.children(): 138 | fn_recursive_feed_forward(module, chunk_size, dim) 139 | 140 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking 141 | def disable_forward_chunking(self): 142 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 143 | if hasattr(module, "set_chunk_feed_forward"): 144 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 145 | 146 | for child in module.children(): 147 | fn_recursive_feed_forward(child, chunk_size, dim) 148 | 149 | for module in self.children(): 150 | fn_recursive_feed_forward(module, None, 0) 151 | 152 | @property 153 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 154 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 155 | r""" 156 | Returns: 157 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 158 | indexed by its weight name. 159 | """ 160 | # set recursively 161 | processors = {} 162 | 163 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 164 | if hasattr(module, "get_processor"): 165 | processors[f"{name}.processor"] = module.get_processor() 166 | 167 | for sub_name, child in module.named_children(): 168 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 169 | 170 | return processors 171 | 172 | for name, module in self.named_children(): 173 | fn_recursive_add_processors(name, module, processors) 174 | 175 | return processors 176 | 177 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 178 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 179 | r""" 180 | Sets the attention processor to use to compute attention. 181 | 182 | Parameters: 183 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 184 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 185 | for **all** `Attention` layers. 186 | 187 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 188 | processor. This is strongly recommended when setting trainable attention processors. 189 | 190 | """ 191 | count = len(self.attn_processors.keys()) 192 | 193 | if isinstance(processor, dict) and len(processor) != count: 194 | raise ValueError( 195 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 196 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
197 | ) 198 | 199 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 200 | if hasattr(module, "set_processor"): 201 | if not isinstance(processor, dict): 202 | module.set_processor(processor) 203 | else: 204 | module.set_processor(processor.pop(f"{name}.processor")) 205 | 206 | for sub_name, child in module.named_children(): 207 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 208 | 209 | for name, module in self.named_children(): 210 | fn_recursive_attn_processor(name, module, processor) 211 | 212 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 213 | def fuse_qkv_projections(self): 214 | """ 215 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 216 | are fused. For cross-attention modules, key and value projection matrices are fused. 217 | 218 | 219 | 220 | This API is 🧪 experimental. 221 | 222 | 223 | """ 224 | self.original_attn_processors = None 225 | 226 | for _, attn_processor in self.attn_processors.items(): 227 | if "Added" in str(attn_processor.__class__.__name__): 228 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 229 | 230 | self.original_attn_processors = self.attn_processors 231 | 232 | for module in self.modules(): 233 | if isinstance(module, Attention): 234 | module.fuse_projections(fuse=True) 235 | 236 | self.set_attn_processor(FusedJointAttnProcessor2_0()) 237 | 238 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 239 | def unfuse_qkv_projections(self): 240 | """Disables the fused QKV projection if enabled. 241 | 242 | 243 | 244 | This API is 🧪 experimental. 245 | 246 | 247 | 248 | """ 249 | if self.original_attn_processors is not None: 250 | self.set_attn_processor(self.original_attn_processors) 251 | 252 | def _set_gradient_checkpointing(self, module, value=False): 253 | if hasattr(module, "gradient_checkpointing"): 254 | module.gradient_checkpointing = value 255 | 256 | def forward( 257 | self, 258 | hidden_states: torch.FloatTensor, 259 | encoder_hidden_states: torch.FloatTensor = None, 260 | pooled_projections: torch.FloatTensor = None, 261 | timestep: torch.LongTensor = None, 262 | block_controlnet_hidden_states: List = None, 263 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 264 | return_dict: bool = True, 265 | ref_key: List = None, 266 | ref_value: List = None, 267 | pose_cond = None 268 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 269 | """ 270 | The [`SD3Transformer2DModel`] forward method. 271 | 272 | Args: 273 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 274 | Input `hidden_states`. 275 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 276 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 277 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 278 | from the embeddings of input conditions. 279 | timestep ( `torch.LongTensor`): 280 | Used to indicate denoising step. 281 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 282 | A list of tensors that if specified are added to the residuals of transformer blocks. 
283 | joint_attention_kwargs (`dict`, *optional*): 284 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 285 | `self.processor` in 286 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 287 | return_dict (`bool`, *optional*, defaults to `True`): 288 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 289 | tuple. 290 | 291 | Returns: 292 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 293 | `tuple` where the first element is the sample tensor. 294 | """ 295 | if joint_attention_kwargs is not None: 296 | joint_attention_kwargs = joint_attention_kwargs.copy() 297 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 298 | else: 299 | lora_scale = 1.0 300 | 301 | if USE_PEFT_BACKEND: 302 | # weight the lora layers by setting `lora_scale` for each PEFT layer 303 | scale_lora_layers(self, lora_scale) 304 | else: 305 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 306 | logger.warning( 307 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 308 | ) 309 | 310 | height, width = hidden_states.shape[-2:] 311 | hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 312 | temb = self.time_text_embed(timestep, pooled_projections) 313 | if encoder_hidden_states is not None: 314 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 315 | 316 | for index_block, block in enumerate(self.transformer_blocks): 317 | if self.training and self.gradient_checkpointing: 318 | 319 | def create_custom_forward(module, return_dict=None): 320 | def custom_forward(*inputs): 321 | if return_dict is not None: 322 | return module(*inputs, return_dict=return_dict) 323 | else: 324 | return module(*inputs) 325 | 326 | return custom_forward 327 | 328 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 329 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 330 | create_custom_forward(block), 331 | hidden_states, 332 | encoder_hidden_states, 333 | temb, 334 | ref_key[index_block], 335 | ref_value[index_block], 336 | **ckpt_kwargs, 337 | ) 338 | 339 | else: 340 | encoder_hidden_states, hidden_states = block( 341 | hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, ref_key=ref_key[index_block], ref_value=ref_value[index_block] 342 | ) 343 | 344 | # controlnet residual 345 | if block_controlnet_hidden_states is not None and block.context_pre_only is False: 346 | interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) 347 | hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] 348 | 349 | hidden_states = self.norm_out(hidden_states, temb) 350 | hidden_states = self.proj_out(hidden_states) 351 | 352 | # unpatchify 353 | patch_size = self.config.patch_size 354 | height = height // patch_size 355 | width = width // patch_size 356 | 357 | hidden_states = hidden_states.reshape( 358 | shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) 359 | ) 360 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 361 | output = hidden_states.reshape( 362 | shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * 
patch_size) 363 | ) 364 | 365 | if USE_PEFT_BACKEND: 366 | # remove `lora_scale` from each PEFT layer 367 | unscale_lora_layers(self, lora_scale) 368 | 369 | if not return_dict: 370 | return (output,ref_key, ref_value) 371 | 372 | return Transformer2DModelOutput(sample=output) 373 | -------------------------------------------------------------------------------- /precompute_utils/precompute_text_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import json 4 | import os 5 | import sys 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | from pathlib import Path 8 | import torch 9 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast 10 | from dataset.dresscode import DressCodeDataset 11 | from dataset.vitonhd import VitonHDDataset 12 | from accelerate import Accelerator 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(input_args=None): 17 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 18 | parser.add_argument( 19 | "--pretrained_model_name_or_path", 20 | type=str, 21 | default=None, 22 | required=True, 23 | help="Path to pretrained model or model identifier from huggingface.co/models.", 24 | ) 25 | parser.add_argument( 26 | "--dataset_name", 27 | type=str, 28 | default=None, 29 | required=True, 30 | choices=["dresscode", "vitonhd"], 31 | help="Name of the dataset to use.", 32 | ) 33 | parser.add_argument( 34 | "--dataset_root", 35 | type=str, 36 | default=None, 37 | required=True, 38 | help="Path to the dataset root.", 39 | ) 40 | parser.add_argument( 41 | "--phase", 42 | type=str, 43 | required=True, 44 | choices=["train", "test"], 45 | help="Phase of the dataset to use.", 46 | ) 47 | parser.add_argument( 48 | "--order", 49 | type=str, 50 | required=True, 51 | choices=["unpaired", "paired"], 52 | help="Test order of the dataset to use.", 53 | ) 54 | parser.add_argument( 55 | "--category", 56 | type=str, 57 | required=False, 58 | choices=["all", "dresses", "upper_body", "lower_body"], 59 | help="Category of the dataset." 60 | ) 61 | parser.add_argument( 62 | "--revision", 63 | type=str, 64 | default=None, 65 | required=False, 66 | help="Revision of pretrained model identifier from huggingface.co/models.", 67 | ) 68 | parser.add_argument( 69 | "--captions_type", 70 | type=str, 71 | required=False, 72 | choices=["text_embeddings", "qwen_text_embeddings", "struct_text_embeddings"], 73 | default="text_embeddings", 74 | help="caption type, for example original (text embeddings), qwen_text_embeddings, struct_text_embeddings", 75 | ) 76 | parser.add_argument( 77 | "--variant", 78 | type=str, 79 | default=None, 80 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 81 | ) 82 | parser.add_argument( 83 | "--max_sequence_length", 84 | type=int, 85 | default=77, 86 | help="Maximum sequence length to use with with the T5 text encoder", 87 | ) 88 | parser.add_argument( 89 | "--height", 90 | type=int, 91 | default=1024, 92 | help="Height of the input images.", 93 | ) 94 | parser.add_argument( 95 | "--width", 96 | type=int, 97 | default=768, 98 | help="Width of the input images.", 99 | ) 100 | parser.add_argument( 101 | "--batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 
102 | ) 103 | parser.add_argument( 104 | "--mixed_precision", 105 | type=str, 106 | default="fp16", 107 | choices=["fp16", "bf16", "no"], 108 | help="Mixed precision.", 109 | ) 110 | parser.add_argument( 111 | "--num_workers", 112 | type=int, 113 | default=0, 114 | help="Number of workers for the dataloader.", 115 | ) 116 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 117 | parser.add_argument("--text_encoders", type=str, nargs='+', default="CLIP") 118 | 119 | if input_args is not None: 120 | args = parser.parse_args(input_args) 121 | else: 122 | args = parser.parse_args() 123 | 124 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 125 | if env_local_rank != -1 and env_local_rank != args.local_rank: 126 | args.local_rank = env_local_rank 127 | 128 | return args 129 | 130 | def load_text_encoders(class_one, class_two, class_three, cache_dir=None): 131 | text_encoder_one, text_encoder_two, text_encoder_three = None, None, None 132 | if class_one is not None and class_two is not None: 133 | text_encoder_one = class_one.from_pretrained( 134 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant, 135 | ) 136 | text_encoder_two = class_two.from_pretrained( 137 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant, 138 | ) 139 | text_encoder_one.requires_grad_(False) 140 | text_encoder_two.requires_grad_(False) 141 | 142 | if class_three is not None: 143 | text_encoder_three = class_three.from_pretrained( 144 | args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant, 145 | ) 146 | text_encoder_three.requires_grad_(False) 147 | return text_encoder_one, text_encoder_two, text_encoder_three 148 | 149 | def encode_prompt( 150 | text_encoders, 151 | tokenizers, 152 | prompt: str, 153 | max_sequence_length, 154 | device=None, 155 | num_images_per_prompt: int = 1, 156 | ): 157 | prompt = [prompt] if isinstance(prompt, str) else prompt 158 | 159 | clip_tokenizers = tokenizers[:2] 160 | clip_text_encoders = text_encoders[:2] 161 | 162 | clip_prompt_embeds_list = [] 163 | clip_pooled_prompt_embeds_list = [] 164 | for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): 165 | if text_encoder is None: 166 | continue 167 | prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( 168 | text_encoder=text_encoder, 169 | tokenizer=tokenizer, 170 | prompt=prompt, 171 | device=device if device is not None else text_encoder.device, 172 | num_images_per_prompt=num_images_per_prompt, 173 | text_input_ids=None, 174 | ) 175 | clip_prompt_embeds_list.append(prompt_embeds) 176 | clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) 177 | 178 | if len(clip_prompt_embeds_list) > 0: 179 | clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1) 180 | pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1) 181 | else: pooled_prompt_embeds = None 182 | 183 | t5_prompt_embed = _encode_prompt_with_t5( 184 | text_encoders[-1], 185 | tokenizers[-1], 186 | max_sequence_length, 187 | prompt=prompt, 188 | num_images_per_prompt=num_images_per_prompt, 189 | device=device if device is not None else text_encoders[-1].device, 190 | ) 191 | 192 | if len(clip_prompt_embeds_list) > 0: 193 | clip_prompt_embeds = torch.nn.functional.pad( 194 | clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) 195 | ) 196 | else: 
clip_prompt_embeds = None 197 | 198 | return t5_prompt_embed, clip_prompt_embeds, pooled_prompt_embeds 199 | 200 | def _encode_prompt_with_t5( 201 | text_encoder, 202 | tokenizer, 203 | max_sequence_length, 204 | prompt=None, 205 | num_images_per_prompt=1, 206 | device=None, 207 | ): 208 | prompt = [prompt] if isinstance(prompt, str) else prompt 209 | batch_size = len(prompt) 210 | 211 | text_inputs = tokenizer( 212 | prompt, 213 | padding="max_length", 214 | max_length=max_sequence_length, 215 | truncation=True, 216 | add_special_tokens=True, 217 | return_tensors="pt", 218 | ) 219 | text_input_ids = text_inputs.input_ids 220 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 221 | 222 | dtype = text_encoder.dtype 223 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 224 | 225 | _, seq_len, _ = prompt_embeds.shape 226 | 227 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 228 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 229 | 230 | return prompt_embeds 231 | 232 | def encode_prompt_no_t5( 233 | text_encoders, 234 | tokenizers, 235 | prompt: str, 236 | max_sequence_length, 237 | device=None, 238 | num_images_per_prompt: int = 1, 239 | ): 240 | prompt = [prompt] if isinstance(prompt, str) else prompt 241 | 242 | clip_tokenizers = tokenizers[:2] 243 | clip_text_encoders = text_encoders[:2] 244 | 245 | clip_prompt_embeds_list = [] 246 | clip_pooled_prompt_embeds_list = [] 247 | for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): 248 | prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( 249 | text_encoder=text_encoder, 250 | tokenizer=tokenizer, 251 | prompt=prompt, 252 | device=device if device is not None else text_encoder.device, 253 | num_images_per_prompt=num_images_per_prompt, 254 | text_input_ids=None, 255 | ) 256 | clip_prompt_embeds_list.append(prompt_embeds) 257 | clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) 258 | 259 | clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1) 260 | pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1) 261 | 262 | t5_zeros = torch.zeros( 263 | ( 264 | clip_prompt_embeds.shape[0], # batch size 265 | max_sequence_length, # T5 sequence length 266 | 4096 #transformer.config.joint_attention_dim 267 | ), 268 | device=device if device is not None else clip_prompt_embeds.device, 269 | dtype=clip_prompt_embeds.dtype 270 | ) 271 | 272 | clip_prompt_embeds = torch.nn.functional.pad( 273 | clip_prompt_embeds, (0, t5_zeros.shape[-1] - clip_prompt_embeds.shape[-1]) 274 | ) 275 | prompt_embeds = torch.cat([clip_prompt_embeds, t5_zeros], dim=-2) 276 | 277 | return prompt_embeds, pooled_prompt_embeds 278 | 279 | def _encode_prompt_with_clip( 280 | text_encoder, 281 | tokenizer, 282 | prompt: str, 283 | device=None, 284 | text_input_ids=None, 285 | num_images_per_prompt: int = 1, 286 | ): 287 | prompt = [prompt] if isinstance(prompt, str) else prompt 288 | batch_size = len(prompt) 289 | 290 | if tokenizer is not None: 291 | text_inputs = tokenizer( 292 | prompt, 293 | padding="max_length", 294 | max_length=77, 295 | truncation=True, 296 | return_tensors="pt", 297 | ) 298 | 299 | text_input_ids = text_inputs.input_ids 300 | else: 301 | if text_input_ids is None: 302 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 303 | 304 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) 305 | 306 | pooled_prompt_embeds = prompt_embeds[0] 
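# With output_hidden_states=True, CLIPTextModelWithProjection returns the projected pooled embedding first (used above), while the penultimate hidden state below provides the per-token embeddings consumed by SD3.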
307 | prompt_embeds = prompt_embeds.hidden_states[-2] 308 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 309 | 310 | _, seq_len, _ = prompt_embeds.shape 311 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 312 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 313 | 314 | return prompt_embeds, pooled_prompt_embeds 315 | 316 | def import_model_class_from_model_name_or_path( 317 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 318 | ): 319 | text_encoder_config = PretrainedConfig.from_pretrained( 320 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 321 | ) 322 | model_class = text_encoder_config.architectures[0] 323 | if model_class == "CLIPTextModelWithProjection": 324 | from transformers import CLIPTextModelWithProjection 325 | 326 | return CLIPTextModelWithProjection 327 | elif model_class == "T5EncoderModel": 328 | from transformers import T5EncoderModel 329 | 330 | return T5EncoderModel 331 | else: 332 | raise ValueError(f"{model_class} is not supported.") 333 | 334 | 335 | @torch.inference_mode() 336 | def main(args): 337 | width, height = args.width, args.height 338 | 339 | accelerator = Accelerator( 340 | mixed_precision=args.mixed_precision, 341 | ) 342 | device = accelerator.device 343 | 344 | weight_dtype = torch.float32 345 | if accelerator.mixed_precision == "fp16": 346 | weight_dtype = torch.float16 347 | elif accelerator.mixed_precision == "bf16": 348 | weight_dtype = torch.bfloat16 349 | 350 | def compute_text_embeddings(prompt, text_encoders, tokenizers): 351 | with torch.no_grad(): 352 | if text_encoders[-1] is not None: 353 | t5_prompt_embeds, prompt_embeds, pooled_prompt_embeds = encode_prompt( 354 | text_encoders, tokenizers, prompt, args.max_sequence_length 355 | ) 356 | else: 357 | prompt_embeds, pooled_prompt_embeds = encode_prompt_no_t5( 358 | text_encoders, tokenizers, prompt, args.max_sequence_length 359 | ) 360 | t5_prompt_embeds=None 361 | 362 | if prompt_embeds is not None: 363 | prompt_embeds = prompt_embeds.to(accelerator.device) 364 | if pooled_prompt_embeds is not None: 365 | pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) 366 | if t5_prompt_embeds is not None: t5_prompt_embeds = t5_prompt_embeds.to(accelerator.device) 367 | 368 | return t5_prompt_embeds, prompt_embeds, pooled_prompt_embeds 369 | 370 | tokenizer_one = CLIPTokenizer.from_pretrained( 371 | args.pretrained_model_name_or_path, 372 | subfolder="tokenizer", 373 | revision=args.revision, 374 | cache_dir = getattr(args, "cache_dir", None)  # the parser defines no --cache_dir flag; fall back to the default HF cache 375 | ) 376 | tokenizer_two = CLIPTokenizer.from_pretrained( 377 | args.pretrained_model_name_or_path, 378 | subfolder="tokenizer_2", 379 | revision=args.revision, 380 | cache_dir = getattr(args, "cache_dir", None) 381 | ) 382 | tokenizer_three = T5TokenizerFast.from_pretrained( 383 | args.pretrained_model_name_or_path, 384 | subfolder="tokenizer_3", 385 | revision=args.revision, 386 | low_cpu_mem_usage=True, 387 | cache_dir = getattr(args, "cache_dir", None) 388 | ) 389 | 390 | text_encoder_cls_one, text_encoder_cls_two = None, None 391 | if "CLIP" in args.text_encoders: 392 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 393 | args.pretrained_model_name_or_path, args.revision 394 | ) 395 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 396 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 397 | ) 398 | text_encoder_cls_three = None 399 | if "T5" in args.text_encoders: 400 | 
text_encoder_cls_three = import_model_class_from_model_name_or_path( 401 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" 402 | ) 403 | 404 | text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( 405 | text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, 406 | ) 407 | 408 | if "CLIP" in args.text_encoders: 409 | text_encoder_one.to(accelerator.device, dtype=weight_dtype) 410 | text_encoder_two.to(accelerator.device, dtype=weight_dtype) 411 | if "T5" in args.text_encoders: 412 | text_encoder_three.to(accelerator.device, dtype=weight_dtype) 413 | 414 | tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three] 415 | text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three] 416 | 417 | outputlist = ['captions', 'c_name', 'category'] 418 | if args.category == "all": 419 | args.category = ["upper_body", "lower_body", "dresses"] 420 | elif type(args.category) is not list: 421 | args.category = [args.category] 422 | 423 | if args.captions_type == "text_embeddings": 424 | coarse_caption_file = "coarse_captions.json" 425 | elif args.captions_type == "struct_text_embeddings": 426 | coarse_caption_file = "all_struct_captions.json" 427 | elif args.captions_type == "qwen_text_embeddings": 428 | coarse_caption_file = "qwen_captions_2_5_0_2.json" 429 | 430 | 431 | if args.dataset_name == "dresscode": 432 | test_dataset = DressCodeDataset( 433 | dataroot_path=args.dataset_root, 434 | phase=args.phase, 435 | order=args.order, 436 | radius=5, 437 | outputlist=outputlist, 438 | sketch_threshold_range=(20, 20), 439 | category=args.category, 440 | size=(height, width), 441 | coarse_caption_file=coarse_caption_file 442 | ) 443 | elif args.dataset_name == "vitonhd": 444 | test_dataset = VitonHDDataset( 445 | dataroot_path=args.dataset_root, 446 | phase=args.phase, 447 | order=args.order, 448 | radius=5, 449 | outputlist=outputlist, 450 | sketch_threshold_range=(20, 20), 451 | size=(height, width), 452 | caption_file=coarse_caption_file 453 | ) 454 | else: 455 | raise NotImplementedError(f"Dataset {args.dataset_name} not implemented") 456 | 457 | batch = test_dataset[0] 458 | dataroot = test_dataset.dataroot 459 | test_dataloader = torch.utils.data.DataLoader( 460 | test_dataset, 461 | shuffle=False, 462 | batch_size=args.batch_size, 463 | num_workers=args.num_workers, 464 | ) 465 | 466 | 467 | test_dataloader = accelerator.prepare(test_dataloader) 468 | 469 | # empty text 470 | prompts = [""] 471 | t5_prompt_embeds, clip_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( 472 | prompts, text_encoders, tokenizers 473 | ) 474 | pooled_feat_path = Path(args.dataset_root).parent / "empty_string_embeddings" / "CLIP_VIT_L_VIT_G_concat_pooled" / "empty_string.pt" 475 | clip_feat_path = Path(args.dataset_root).parent / "empty_string_embeddings" / "CLIP_VIT_L_VIT_G_concat" / "empty_string.pt" 476 | t5_feat_path = Path(args.dataset_root).parent / "empty_string_embeddings" / "T5_XXL" / "empty_string.pt" 477 | 478 | if not pooled_feat_path.parent.exists(): 479 | pooled_feat_path.parent.mkdir(parents=True, exist_ok=True) 480 | if not clip_feat_path.parent.exists(): 481 | clip_feat_path.parent.mkdir(parents=True, exist_ok=True) 482 | if not t5_feat_path.parent.exists(): 483 | t5_feat_path.parent.mkdir(parents=True, exist_ok=True) 484 | 485 | if pooled_prompt_embeds is not None: 486 | pooled_embeds = pooled_prompt_embeds 487 | torch.save(pooled_embeds, pooled_feat_path) 488 | if clip_prompt_embeds is not None: 489 | clip_embeds = 
clip_prompt_embeds 490 | torch.save(clip_embeds, clip_feat_path) 491 | if t5_prompt_embeds is not None: 492 | t5_embeds = t5_prompt_embeds 493 | torch.save(t5_embeds, t5_feat_path) 494 | 495 | # text embeddings 496 | for idx, batch in enumerate(tqdm(test_dataloader)): 497 | prompts = batch["captions"] 498 | t5_prompt_embeds, clip_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( 499 | prompts, text_encoders, tokenizers 500 | ) 501 | 502 | for i in range(len(prompts)): 503 | name = batch["c_name"][i] 504 | 505 | if args.dataset_name == "dresscode": 506 | pooled_feat_path = Path(dataroot) / args.captions_type / "CLIP_VIT_L_VIT_G_concat_pooled" / batch["category"][i] / f"{name.split('.')[0]}.pt" 507 | clip_feat_path = Path(dataroot) / args.captions_type / "CLIP_VIT_L_VIT_G_concat" / batch["category"][i] / f"{name.split('.')[0]}.pt" 508 | t5_feat_path = Path(dataroot) / args.captions_type / "T5_XXL" / batch["category"][i] / f"{name.split('.')[0]}.pt" 509 | elif args.dataset_name == "vitonhd": 510 | pooled_feat_path = Path(dataroot) / args.captions_type / "CLIP_VIT_L_VIT_G_concat_pooled" / args.phase / f"{name.split('.')[0]}.pt" 511 | clip_feat_path = Path(dataroot) / args.captions_type / "CLIP_VIT_L_VIT_G_concat" / args.phase / f"{name.split('.')[0]}.pt" 512 | t5_feat_path = Path(dataroot) / args.captions_type / "T5_XXL" / args.phase / f"{name.split('.')[0]}.pt" 513 | else: raise NotImplementedError("dataset not supported") 514 | 515 | if not pooled_feat_path.parent.exists(): 516 | pooled_feat_path.parent.mkdir(parents=True, exist_ok=True) 517 | if not clip_feat_path.parent.exists(): 518 | clip_feat_path.parent.mkdir(parents=True, exist_ok=True) 519 | if not t5_feat_path.parent.exists(): 520 | t5_feat_path.parent.mkdir(parents=True, exist_ok=True) 521 | 522 | if pooled_prompt_embeds is not None: 523 | pooled_embeds = pooled_prompt_embeds[i] 524 | torch.save(pooled_embeds, pooled_feat_path) 525 | if clip_prompt_embeds is not None: 526 | clip_embeds = clip_prompt_embeds[i] 527 | torch.save(clip_embeds, clip_feat_path) 528 | if t5_prompt_embeds is not None: 529 | t5_embeds = t5_prompt_embeds[i] 530 | torch.save(t5_embeds, t5_feat_path) 531 | 532 | 533 | if __name__ == '__main__': 534 | args = parse_args() 535 | main(args) -------------------------------------------------------------------------------- /src/transformer_vtoff_repa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import Any, Dict, List, Optional, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 23 | from src.attention_vton_mixed import JointTransformerBlock 24 | from src.attention_processor_vton_mixed import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 25 | from diffusers.models.modeling_utils import ModelMixin 26 | from diffusers.models.normalization import AdaLayerNormContinuous 27 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 28 | from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed 29 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 30 | from typing import Literal 31 | import torch.nn.functional as F 32 | 33 | 34 | logger = logging.get_logger(__name__) 35 | 36 | class HiddenStateDownsamplerCNN(nn.Module): 37 | def __init__(self, in_features=1536, out_features=768, projector_dim=2048, target_h=32, target_w=32, downsample_method="adaptive_maxpool"): 38 | super().__init__() 39 | self.in_features = in_features 40 | self.out_features = out_features 41 | self.projector_dim = projector_dim 42 | self.target_h = target_h 43 | self.target_w = target_w 44 | self.downsample_method = downsample_method 45 | 46 | self.conv1 = nn.Conv2d(in_features, projector_dim, kernel_size=1) 47 | self.norm1 = nn.BatchNorm2d(projector_dim) 48 | self.act1 = nn.GELU() 49 | self.conv2 = nn.Conv2d(projector_dim, projector_dim, kernel_size=3, padding=1) 50 | self.norm2 = nn.BatchNorm2d(projector_dim) 51 | self.act2 = nn.GELU() 52 | 53 | if self.downsample_method == "adaptive_maxpool": 54 | self.downsampler = nn.AdaptiveMaxPool2d((target_h, target_w)) 55 | elif self.downsample_method == "adaptive_avgpool": 56 | self.downsampler = nn.AdaptiveAvgPool2d((target_h, target_w)) 57 | 58 | self.conv3 = nn.Conv2d(projector_dim, out_features, kernel_size=1) 59 | 60 | self.final_norm = nn.LayerNorm(self.out_features) 61 | 62 | def forward(self, hidden_states): 63 | B, T_in, D_in = hidden_states.shape 64 | H_in, W_in = 64, 48 65 | if not (H_in * W_in == T_in and D_in == self.in_features): 66 | raise ValueError(f"Unexpected input shape: {hidden_states.shape}. Expected T={H_in*W_in}, D={self.in_features}") 67 | 68 | x_grid = hidden_states.permute(0, 2, 1).reshape(B, D_in, H_in, W_in) 69 | x = self.act1(self.norm1(self.conv1(x_grid))) 70 | x = self.act2(self.norm2(self.conv2(x))) 71 | x_down = self.downsampler(x) 72 | x_out_grid = self.conv3(x_down) 73 | output = x_out_grid.reshape(B, self.out_features, self.target_h * self.target_w).permute(0, 2, 1) 74 | output = self.final_norm(output) 75 | return output 76 | 77 | 78 | class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 79 | """ 80 | The Transformer model introduced in Stable Diffusion 3. 81 | 82 | Reference: https://arxiv.org/abs/2403.03206 83 | 84 | Parameters: 85 | sample_size (`int`): The width of the latent images. This is fixed during training since 86 | it is used to learn a number of position embeddings. 87 | patch_size (`int`): Patch size to turn the input data into small patches. 88 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 89 | num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. 
90 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 91 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 92 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 93 | caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. 94 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 95 | out_channels (`int`, defaults to 16): Number of output channels. 96 | 97 | """ 98 | 99 | _supports_gradient_checkpointing = True 100 | 101 | @register_to_config 102 | def __init__( 103 | self, 104 | sample_size: int = 128, 105 | patch_size: int = 2, 106 | in_channels: int = 16, 107 | num_layers: int = 18, 108 | attention_head_dim: int = 64, 109 | num_attention_heads: int = 18, 110 | joint_attention_dim: int = 4096, 111 | caption_projection_dim: int = 1152, 112 | pooled_projection_dim: int = 2048, 113 | out_channels: int = 16, 114 | pos_embed_max_size: int = 96, 115 | projector_dim: int = 2048, 116 | z_dims: List[int] = [768], 117 | encoder_depth: int = 8, 118 | probing_method: Literal["average_pool", "conv"] = "conv", 119 | ): 120 | super().__init__() 121 | default_out_channels = in_channels 122 | self.out_channels = out_channels if out_channels is not None else default_out_channels 123 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 124 | 125 | self.pos_embed = PatchEmbed( 126 | height=self.config.sample_size, 127 | width=self.config.sample_size, 128 | patch_size=self.config.patch_size, 129 | in_channels=self.config.in_channels, 130 | embed_dim=self.inner_dim, 131 | pos_embed_max_size=pos_embed_max_size, # hard-code for now. 132 | ) 133 | self.time_text_embed = CombinedTimestepTextProjEmbeddings( 134 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 135 | ) 136 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) 137 | 138 | # `attention_head_dim` is doubled to account for the mixing. 139 | # It needs to crafted when we get the actual checkpoints. 
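# The JointTransformerBlock and Attention classes used below come from the project's
# src.attention_vton_mixed / src.attention_processor_vton_mixed modules rather than
# stock diffusers: in forward() every block additionally receives
# ref_key[index_block] / ref_value[index_block], apparently key/value tensors cached
# from a reference pass and mixed into the attention computation.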
140 | self.transformer_blocks = nn.ModuleList( 141 | [ 142 | JointTransformerBlock( 143 | dim=self.inner_dim, 144 | num_attention_heads=self.config.num_attention_heads, 145 | attention_head_dim=self.config.attention_head_dim, 146 | context_pre_only=i == num_layers - 1, 147 | ) 148 | for i in range(self.config.num_layers) 149 | ] 150 | ) 151 | 152 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 153 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 154 | 155 | self.gradient_checkpointing = False 156 | 157 | self.probing_method = getattr(self.config, "probing_method", "average_pool") 158 | 159 | if self.probing_method == "conv": 160 | self.downsampler_cnn = nn.ModuleList([HiddenStateDownsamplerCNN( 161 | in_features=self.inner_dim, 162 | out_features=z_dim, 163 | projector_dim=self.projector_dim, 164 | target_h=32, 165 | target_w=32, 166 | downsample_method="adaptive_avgpool" 167 | ) for z_dim in z_dims]) 168 | self.projectors = None 169 | else: 170 | self.downsampler_cnn = None 171 | 172 | self.projectors = nn.ModuleList([ 173 | nn.Sequential( 174 | nn.Linear(self.inner_dim, projector_dim), 175 | nn.SiLU(), 176 | nn.Linear(projector_dim, projector_dim), 177 | nn.SiLU(), 178 | nn.Linear(projector_dim, z_dim), 179 | ) for z_dim in z_dims 180 | ]) 181 | 182 | self.encoder_depth = encoder_depth 183 | 184 | def initialize_weights(self): 185 | def _base_init(m): 186 | if isinstance(m, nn.Linear): 187 | nn.init.xavier_uniform_(m.weight) 188 | if m.bias is not None: 189 | nn.init.zeros_(m.bias) 190 | elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): 191 | nn.init.ones_(m.weight) 192 | nn.init.zeros_(m.bias) 193 | elif isinstance(m, nn.Conv2d): 194 | nn.init.normal_(m.weight, std=0.02) 195 | if m.bias is not None: 196 | nn.init.zeros_(m.bias) 197 | 198 | if self.projectors is not None: 199 | self.projectors.apply(_base_init) 200 | if hasattr(self, "downsampler_cnn") and self.downsampler_cnn is not None: 201 | self.downsampler_cnn.apply(_base_init) 202 | 203 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 204 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 205 | """ 206 | Sets the attention processor to use [feed forward 207 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 208 | 209 | Parameters: 210 | chunk_size (`int`, *optional*): 211 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 212 | over each tensor of dim=`dim`. 213 | dim (`int`, *optional*, defaults to `0`): 214 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 215 | or dim=1 (sequence length). 
216 | """ 217 | if dim not in [0, 1]: 218 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 219 | 220 | # By default chunk size is 1 221 | chunk_size = chunk_size or 1 222 | 223 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 224 | if hasattr(module, "set_chunk_feed_forward"): 225 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 226 | 227 | for child in module.children(): 228 | fn_recursive_feed_forward(child, chunk_size, dim) 229 | 230 | for module in self.children(): 231 | fn_recursive_feed_forward(module, chunk_size, dim) 232 | 233 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking 234 | def disable_forward_chunking(self): 235 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 236 | if hasattr(module, "set_chunk_feed_forward"): 237 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 238 | 239 | for child in module.children(): 240 | fn_recursive_feed_forward(child, chunk_size, dim) 241 | 242 | for module in self.children(): 243 | fn_recursive_feed_forward(module, None, 0) 244 | 245 | @property 246 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 247 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 248 | r""" 249 | Returns: 250 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 251 | indexed by its weight name. 252 | """ 253 | # set recursively 254 | processors = {} 255 | 256 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 257 | if hasattr(module, "get_processor"): 258 | processors[f"{name}.processor"] = module.get_processor() 259 | 260 | for sub_name, child in module.named_children(): 261 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 262 | 263 | return processors 264 | 265 | for name, module in self.named_children(): 266 | fn_recursive_add_processors(name, module, processors) 267 | 268 | return processors 269 | 270 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 271 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 272 | r""" 273 | Sets the attention processor to use to compute attention. 274 | 275 | Parameters: 276 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 277 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 278 | for **all** `Attention` layers. 279 | 280 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 281 | processor. This is strongly recommended when setting trainable attention processors. 282 | 283 | """ 284 | count = len(self.attn_processors.keys()) 285 | 286 | if isinstance(processor, dict) and len(processor) != count: 287 | raise ValueError( 288 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 289 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
290 | ) 291 | 292 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 293 | if hasattr(module, "set_processor"): 294 | if not isinstance(processor, dict): 295 | module.set_processor(processor) 296 | else: 297 | module.set_processor(processor.pop(f"{name}.processor")) 298 | 299 | for sub_name, child in module.named_children(): 300 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 301 | 302 | for name, module in self.named_children(): 303 | fn_recursive_attn_processor(name, module, processor) 304 | 305 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 306 | def fuse_qkv_projections(self): 307 | """ 308 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 309 | are fused. For cross-attention modules, key and value projection matrices are fused. 310 | 311 | 312 | 313 | This API is 🧪 experimental. 314 | 315 | 316 | """ 317 | self.original_attn_processors = None 318 | 319 | for _, attn_processor in self.attn_processors.items(): 320 | if "Added" in str(attn_processor.__class__.__name__): 321 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 322 | 323 | self.original_attn_processors = self.attn_processors 324 | 325 | for module in self.modules(): 326 | if isinstance(module, Attention): 327 | module.fuse_projections(fuse=True) 328 | 329 | self.set_attn_processor(FusedJointAttnProcessor2_0()) 330 | 331 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 332 | def unfuse_qkv_projections(self): 333 | """Disables the fused QKV projection if enabled. 334 | 335 | 336 | 337 | This API is 🧪 experimental. 338 | 339 | 340 | 341 | """ 342 | if self.original_attn_processors is not None: 343 | self.set_attn_processor(self.original_attn_processors) 344 | 345 | def _set_gradient_checkpointing(self, module, value=False): 346 | if hasattr(module, "gradient_checkpointing"): 347 | module.gradient_checkpointing = value 348 | 349 | def forward( 350 | self, 351 | hidden_states: torch.FloatTensor, 352 | encoder_hidden_states: torch.FloatTensor = None, 353 | pooled_projections: torch.FloatTensor = None, 354 | timestep: torch.LongTensor = None, 355 | block_controlnet_hidden_states: List = None, 356 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 357 | return_dict: bool = True, 358 | ref_key: List = None, 359 | ref_value: List = None, 360 | pose_cond = None 361 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 362 | """ 363 | The [`SD3Transformer2DModel`] forward method. 364 | 365 | Args: 366 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 367 | Input `hidden_states`. 368 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 369 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 370 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 371 | from the embeddings of input conditions. 372 | timestep ( `torch.LongTensor`): 373 | Used to indicate denoising step. 374 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 375 | A list of tensors that if specified are added to the residuals of transformer blocks. 
376 | joint_attention_kwargs (`dict`, *optional*): 377 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 378 | `self.processor` in 379 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 380 | return_dict (`bool`, *optional*, defaults to `True`): 381 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 382 | tuple. 383 | 384 | Returns: 385 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 386 | `tuple` where the first element is the sample tensor. 387 | """ 388 | if joint_attention_kwargs is not None: 389 | joint_attention_kwargs = joint_attention_kwargs.copy() 390 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 391 | else: 392 | lora_scale = 1.0 393 | 394 | if USE_PEFT_BACKEND: 395 | # weight the lora layers by setting `lora_scale` for each PEFT layer 396 | scale_lora_layers(self, lora_scale) 397 | else: 398 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 399 | logger.warning( 400 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 401 | ) 402 | 403 | height, width = hidden_states.shape[-2:] 404 | hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 405 | temb = self.time_text_embed(timestep, pooled_projections) 406 | if encoder_hidden_states is not None: 407 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 408 | 409 | for index_block, block in enumerate(self.transformer_blocks): 410 | if self.training and self.gradient_checkpointing: 411 | 412 | def create_custom_forward(module, return_dict=None): 413 | def custom_forward(*inputs): 414 | if return_dict is not None: 415 | return module(*inputs, return_dict=return_dict) 416 | else: 417 | return module(*inputs) 418 | 419 | return custom_forward 420 | 421 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 422 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 423 | create_custom_forward(block), 424 | hidden_states, 425 | encoder_hidden_states, 426 | temb, 427 | ref_key[index_block], 428 | ref_value[index_block], 429 | **ckpt_kwargs, 430 | ) 431 | 432 | else: 433 | encoder_hidden_states, hidden_states = block( 434 | hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, ref_key=ref_key[index_block], ref_value=ref_value[index_block] 435 | ) 436 | if (index_block + 1) == self.encoder_depth: 437 | N, T, D = hidden_states.shape # N = batch size, T = sequence length, D = hidden dimension 438 | 439 | zs = [] 440 | if self.probing_method == "average_pool": 441 | hidden_states_transposed = hidden_states.transpose(1, 2) 442 | downsampled_hidden_states_transposed = F.avg_pool1d(hidden_states_transposed, kernel_size=3, stride=3) 443 | downsampled_hidden_states = downsampled_hidden_states_transposed.transpose(1, 2) 444 | zs = [] 445 | downsampled_hidden_states_fp32 = downsampled_hidden_states.float() 446 | 447 | for i, projector in enumerate(self.projectors): 448 | original_projector_dtype = next(projector.parameters()).dtype 449 | projector.float() 450 | try: 451 | z_output_fp32 = projector(downsampled_hidden_states_fp32) 452 | zs.append(z_output_fp32.to(hidden_states.dtype)) 453 | finally: 454 | projector.to(original_projector_dtype) 455 | 456 | 
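# "conv" probing: each HiddenStateDownsamplerCNN head reshapes the token sequence back
# into its spatial grid (hard-coded to 64x48 in the head above), refines it with two
# conv layers, adaptively pools it to 32x32 and projects it to the corresponding z_dim.
# Each head is temporarily cast to the hidden-state dtype for the call and restored
# afterwards, and the branch is skipped entirely while all head parameters are frozen.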
elif self.probing_method == "conv" and any(p.requires_grad for p in self.downsampler_cnn.parameters()): 457 | zs = [] 458 | hidden_states_fp32 = hidden_states.float() 459 | for downsampler in self.downsampler_cnn: 460 | original_cnn_dtype = next(downsampler.parameters()).dtype 461 | downsampler.to(hidden_states.dtype) 462 | zs.append(downsampler(hidden_states)) 463 | downsampler.to(original_cnn_dtype) 464 | 465 | # controlnet residual 466 | if block_controlnet_hidden_states is not None and block.context_pre_only is False: 467 | interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) 468 | hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] 469 | 470 | hidden_states = self.norm_out(hidden_states, temb) 471 | hidden_states = self.proj_out(hidden_states) 472 | 473 | # unpatchify 474 | patch_size = self.config.patch_size 475 | height = height // patch_size 476 | width = width // patch_size 477 | 478 | hidden_states = hidden_states.reshape( 479 | shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) 480 | ) 481 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 482 | output = hidden_states.reshape( 483 | shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) 484 | ) 485 | 486 | if USE_PEFT_BACKEND: 487 | # remove `lora_scale` from each PEFT layer 488 | unscale_lora_layers(self, lora_scale) 489 | 490 | if not return_dict: 491 | return (output, ref_key, ref_value, zs) 492 | 493 | return Transformer2DModelOutput(sample=output) 494 | -------------------------------------------------------------------------------- /dataset/vitonhd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from pathlib import Path 4 | import glob 5 | import json 6 | from typing import Tuple, Literal 7 | import shutil 8 | 9 | import cv2 10 | import torch 11 | import torch.utils.data as data 12 | import torchvision.transforms as transforms 13 | from PIL import Image, ImageDraw, ImageOps 14 | import numpy as np 15 | from torchvision.ops import masks_to_boxes 16 | from utils.posemap import get_coco_body25_mapping 17 | from utils.posemap import kpoint_to_heatmap 18 | 19 | PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute() 20 | 21 | class VitonHDDataset(data.Dataset): 22 | """VITON-HD dataset class 23 | """ 24 | 25 | def __init__(self, 26 | dataroot_path: str, 27 | phase: Literal["train", "test"], 28 | radius=5, 29 | caption_file: Literal['coarse_captions.json', 'qwen_captions_2_5_0_2.json', 'all_captions_structural.json'] = 'coarse_captions.json', 30 | sketch_threshold_range: Tuple[int, int] = (20, 127), 31 | order: Literal['paired', 'unpaired'] = 'paired', 32 | outputlist: Tuple[str] = ('c_name', 'im_name', 'cloth', 'image', 'im_cloth', 'shape', 'pose_map', 33 | 'parse_array', 'im_mask', 'inpaint_mask', 'parse_mask_total', 34 | 'cloth_sketch', 'im_sketch', 'greyscale_im_sketch', 'captions', 'category', 35 | 'skeleton', 'category', 'texture', 'parse_cloth', 'cpath', 'cloth_embeddings', 'dwpose', 'vton_image_embeddings'), 36 | size: Tuple[int, int] = (512, 384), 37 | mask_type: Literal["keypoints", "bounding_box"] = "bounding_box", 38 | texture_order: Literal["shuffled", "original"] = "original", 39 | mask_items: Tuple[str] = ["original"], 40 | caption_type: Literal["original", "structural", "qwen"] = "original", 41 | n_chunks = 3, 42 | dilation = 7, 43 | ): 44 | """ 45 | Initialize the PyTroch 
Dataset Class 46 | :param dataroot_path: dataset root folder 47 | :type dataroot_path: string 48 | :param phase: phase (train | test) 49 | :type phase: string 50 | :param order: setting (paired | unpaired) 51 | :type order: string 52 | :param category: clothing category (upper_body | lower_body | dresses) 53 | :type category: list(str) 54 | :param size: image size (height, width) 55 | :type size: tuple(int) 56 | """ 57 | super(VitonHDDataset, self).__init__() 58 | self.dataroot = dataroot_path 59 | self.phase = phase 60 | self.sketch_threshold_range = sketch_threshold_range 61 | self.category = ('upper_body') 62 | self.outputlist = outputlist 63 | self.height = size[0] 64 | self.width = size[1] 65 | self.radius = radius 66 | self.dilation = dilation 67 | self.transform = transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 70 | ]) 71 | self.transform2D = transforms.Compose([ 72 | transforms.ToTensor(), 73 | transforms.Normalize((0.5,), (0.5,)) 74 | ]) 75 | self.order = order 76 | self.mask_type = mask_type 77 | self.texture_order = texture_order 78 | self.mask_items = mask_items 79 | 80 | im_names = [] 81 | c_names = [] 82 | dataroot_names = [] 83 | 84 | possible_outputs = ['c_name', 'im_name', 'cloth', 'image', 'im_cloth', 'shape', 'im_head', 'im_pose', 85 | 'pose_map', 'parse_array', 'dense_labels', 'dense_uv', 'skeleton', 'im_mask', 86 | 'im_mask_fine','inshop_mask', 87 | 'inpaint_mask', 'parse_mask_total', 'cloth_sketch', 'im_sketch', 'greyscale_im_sketch', 88 | 'captions', 'category', 'texture', 'parse_cloth', 'cpath', 'cloth_embeddings', 'dwpose', 89 | 'vton_image_embeddings', 'clip_pooled', 'clip_embeds', 't5_embeds'] 90 | 91 | assert all(x in possible_outputs for x in outputlist) 92 | 93 | # Load Captions 94 | with open(os.path.join(self.dataroot, caption_file)) as f: 95 | self.captions_dict = json.load(f) 96 | 97 | if caption_file == 'coarse_captions.json': 98 | self.caption_type = 'text_embeddings' 99 | elif caption_file == 'all_struct_captions.json': 100 | self.caption_type = 'struct_text_embeddings' 101 | elif 'qwen' in caption_file: 102 | self.caption_type = 'qwen_text_embeddings' 103 | else: raise NotImplementedError("not supported caption file") 104 | 105 | dataroot = self.dataroot 106 | if phase == 'train': 107 | filename = os.path.join(dataroot, f"{phase}_pairs.txt") 108 | else: 109 | filename = os.path.join(dataroot, f"{phase}_pairs.txt") 110 | texture_mapping = {} 111 | texture_filename = os.path.join(dataroot, f"test_shuffled_textures.txt") 112 | 113 | with open(filename, 'r') as f: 114 | data_len = len(f.readlines()) 115 | 116 | with open(filename, 'r') as f: 117 | for line in f.readlines(): 118 | if phase == 'train': 119 | im_name, _ = line.strip().split() 120 | c_name = im_name 121 | else: 122 | if order == 'paired': 123 | im_name, _ = line.strip().split() 124 | c_name = im_name 125 | else: 126 | im_name, c_name = line.strip().split() 127 | 128 | im_names.append(im_name) 129 | c_names.append(c_name) 130 | dataroot_names.append(dataroot) 131 | 132 | if texture_order == "shuffled": 133 | with open(texture_filename, 'r') as f: 134 | for line in f.readlines(): 135 | im_name, texture_name = line.strip().split() 136 | texture_mapping[im_name] = texture_name 137 | 138 | self.im_names = im_names 139 | self.c_names = c_names 140 | self.dataroot_names = dataroot_names 141 | self.texture_mapping = texture_mapping 142 | self.n_chunks = 3 143 | 144 | def __getitem__(self, index): 145 | """ 146 | For each index return the 
corresponding sample in the dataset 147 | :param index: data index 148 | :type index: int 149 | :return: dict containing dataset samples 150 | :rtype: dict 151 | """ 152 | c_name = self.c_names[index] 153 | im_name = self.im_names[index] 154 | dataroot = self.dataroot_names[index] 155 | category = 'upper_body' 156 | item_id = os.path.splitext(im_name)[0] 157 | 158 | sketch_threshold = random.randint( 159 | self.sketch_threshold_range[0], self.sketch_threshold_range[1]) 160 | 161 | if "captions" in self.outputlist: 162 | captions = self.captions_dict[c_name.split('_')[0]] 163 | if self.phase == 'train': 164 | random.shuffle(captions) 165 | n_chunks = min(len(captions), self.n_chunks) 166 | captions = captions[:n_chunks] 167 | captions = ", ".join(captions) 168 | 169 | if "cloth" in self.outputlist: 170 | cloth = Image.open(os.path.join( 171 | dataroot, self.phase, 'cloth', c_name)) 172 | cloth = cloth.resize((self.width, self.height)) 173 | cloth = self.transform(cloth) # [-1,1] 174 | cpath = os.path.join( 175 | self.dataroot, self.phase, 'cloth', c_name) 176 | if "inshop_mask" in self.outputlist: 177 | mask_filename = cpath.replace("cloth", "cloth-mask") 178 | mask_image = Image.open(mask_filename) 179 | mask_image = transforms.ToTensor()(mask_image) 180 | inshop_mask = mask_image[0, :, :] 181 | 182 | if "cloth_embeddings" in self.outputlist: 183 | feat_path = os.path.join(dataroot, 'cloth_embeddings', 'CLIP_VIT_L_VIT_G_concat', self.phase, f"{c_name.split('.')[0]}.pt") 184 | local_file = f'/tmp/vitonhd_cloth_{os.getpid()}_{item_id}.pt' 185 | shutil.copy(str(feat_path), local_file) 186 | cloth_embeddings = torch.load(local_file, 187 | map_location="cpu") 188 | os.remove(local_file) 189 | 190 | if "vton_image_embeddings" in self.outputlist: 191 | feat_path = os.path.join(dataroot, 'vton_embeddings', 'CLIP_VIT_L_VIT_G_concat', self.phase, f"{im_name.split('.')[0]}.pt") 192 | local_file = f'/tmp/vitonhd_vton_{os.getpid()}_{item_id}.pt' 193 | shutil.copy(str(feat_path), local_file) 194 | vton_image_embeddings = torch.load(local_file, 195 | map_location="cpu") 196 | os.remove(local_file) 197 | 198 | if "clip_pooled" in self.outputlist: 199 | feat_path = Path(dataroot) / self.caption_type / "CLIP_VIT_L_VIT_G_concat_pooled" / self.phase / f"{c_name.split('.')[0]}.pt" 200 | local_file = f'/tmp/vitonhd_pooled_{os.getpid()}_{item_id}.pt' 201 | shutil.copy(str(feat_path), local_file) 202 | clip_pooled = torch.load(str(local_file), map_location="cpu") 203 | os.remove(local_file) 204 | 205 | if "clip_embeds" in self.outputlist: 206 | feat_path = Path(dataroot) / self.caption_type / "CLIP_VIT_L_VIT_G_concat" / self.phase / f"{c_name.split('.')[0]}.pt" 207 | local_file = f'/tmp/vitonhd_clip_embeds_{os.getpid()}_{item_id}.pt' 208 | shutil.copy(str(feat_path), local_file) 209 | clip_embeds = torch.load(str(local_file), map_location="cpu") 210 | os.remove(local_file) 211 | 212 | if "t5_embeds" in self.outputlist: 213 | feat_path = Path(dataroot) / self.caption_type / "T5_XXL" / self.phase / f"{c_name.split('.')[0]}.pt" 214 | local_file = f'/tmp/vitonhd_t5_{os.getpid()}_{item_id}.pt' 215 | shutil.copy(str(feat_path), local_file) 216 | t5_embeds = torch.load(str(local_file), map_location="cpu") 217 | os.remove(local_file) 218 | 219 | if "dwpose" in self.outputlist: 220 | dwpose = Image.open(os.path.join(dataroot, self.phase, 'dwpose', f"{im_name.split('.')[0]}_rendered.png")) 221 | 222 | if "image" in self.outputlist or "im_head" in self.outputlist or "im_cloth" in self.outputlist: 223 | image =
Image.open(os.path.join( 224 | dataroot, self.phase, 'image', im_name)) 225 | image = image.resize((self.width, self.height)) 226 | image = self.transform(image) # [-1,1] 227 | 228 | if "texture" in self.outputlist: 229 | if self.texture_order == "shuffled": 230 | c_texture_name = self.texture_mapping[im_name] 231 | else: 232 | c_texture_name = c_name 233 | textures = glob.glob(os.path.join( 234 | dataroot, self.phase, 'textures', 'images', c_texture_name.replace('.jpg', ''), "*")) 235 | textures = sorted(textures, key=lambda x: int(str(Path(x).name).split('.')[0])) 236 | if self.phase == 'train': 237 | texture = Image.open(random.choice(textures)) 238 | else: 239 | texture = Image.open(textures[len(textures) // 2]) 240 | 241 | texture = texture.resize((224, 224)) 242 | texture = self.transform(texture) # [-1,1] 243 | 244 | if "cloth_sketch" in self.outputlist: 245 | cloth_sketch = Image.open( 246 | os.path.join(dataroot, self.phase, 'cloth_sketch', c_name.replace(".jpg", ".png"))) 247 | cloth_sketch = cloth_sketch.resize((self.width, self.height)) 248 | cloth_sketch = ImageOps.invert(cloth_sketch) 249 | cloth_sketch = cloth_sketch.point( 250 | lambda p: 255 if p > sketch_threshold else 0) 251 | cloth_sketch = cloth_sketch.convert("RGB") 252 | cloth_sketch = self.transform(cloth_sketch) # [-1,1] 253 | 254 | if "im_sketch" in self.outputlist or "greyscale_im_sketch" in self.outputlist: 255 | if "unpaired" == self.order and self.phase == 'test': 256 | greyscale_im_sketch = Image.open(os.path.join(dataroot, self.phase, 'im_sketch_unpaired', 257 | f'{im_name.replace(".jpg", "")}_{c_name.replace(".jpg", ".png")}')) 258 | else: 259 | greyscale_im_sketch = Image.open(os.path.join(dataroot, self.phase, 'im_sketch', c_name.replace(".jpg", ".png"))) 260 | 261 | greyscale_im_sketch = greyscale_im_sketch.resize((self.width, self.height)) 262 | greyscale_im_sketch = ImageOps.invert(greyscale_im_sketch) 263 | im_sketch = greyscale_im_sketch.point( 264 | lambda p: 255 if p > sketch_threshold else 0) 265 | im_sketch = transforms.functional.to_tensor(im_sketch) # [0,1] 266 | greyscale_im_sketch = transforms.functional.to_tensor(greyscale_im_sketch) # [0,1] 267 | im_sketch = 1 - im_sketch 268 | greyscale_im_sketch = 1 - greyscale_im_sketch 269 | 270 | labels = { 271 | 0: ['background', [0, 10]], 272 | 1: ['hair', [1, 2]], 273 | 2: ['face', [4, 13]], 274 | 3: ['upper', [5, 6, 7]], 275 | 4: ['bottom', [9, 12]], 276 | 5: ['left_arm', [14]], 277 | 6: ['right_arm', [15]], 278 | 7: ['left_leg', [16]], 279 | 8: ['right_leg', [17]], 280 | 9: ['left_shoe', [18]], 281 | 10: ['right_shoe', [19]], 282 | 11: ['socks', [8]], 283 | 12: ['noise', [3, 11]] 284 | } 285 | 286 | if "skeleton" in self.outputlist: 287 | skeleton = Image.open( 288 | os.path.join(dataroot, self.phase, 'openpose_img', im_name.replace('.jpg', '_rendered.png'))) 289 | skeleton = skeleton.resize((self.width, self.height)) 290 | skeleton = self.transform(skeleton) 291 | 292 | if "im_pose" in self.outputlist or "parser_mask" in self.outputlist or "im_mask" in self.outputlist or "im_cloth" in self.outputlist \ 293 | or "parse_mask_total" in self.outputlist or "parse_array" in self.outputlist \ 294 | or "pose_map" in self.outputlist or "parse_array" in self.outputlist or "shape" in self.outputlist \ 295 | or "im_head" in self.outputlist or "inpaint_mask" in self.outputlist or "im_mask_fine" in self.outputlist: 296 | parse_name = im_name.replace('.jpg', '.png') 297 | im_parse = Image.open(os.path.join( 298 | dataroot, self.phase, 'image-parse-v3', 
parse_name)) 299 | im_parse = im_parse.resize( 300 | (self.width, self.height), Image.NEAREST) 301 | im_parse_final = transforms.ToTensor()(im_parse) * 255 302 | parse_array = np.array(im_parse) 303 | 304 | parse_shape = (parse_array > 0).astype(np.float32) 305 | 306 | parse_hair = (parse_array == 1).astype(np.float32) + \ 307 | (parse_array == 2).astype(np.float32) 308 | 309 | parse_head = (parse_array == 1).astype(np.float32) + \ 310 | (parse_array == 2).astype(np.float32) + \ 311 | (parse_array == 4).astype(np.float32) + \ 312 | (parse_array == 13).astype(np.float32) 313 | 314 | parser_mask_fixed = (parse_array == 1).astype(np.float32) + \ 315 | (parse_array == 2).astype(np.float32) + \ 316 | (parse_array == 18).astype(np.float32) + \ 317 | (parse_array == 19).astype(np.float32) 318 | 319 | parser_mask_changeable = (parse_array == 0).astype(np.float32) 320 | 321 | arms = (parse_array == 14).astype(np.float32) + \ 322 | (parse_array == 15).astype(np.float32) 323 | 324 | parse_cloth = (parse_array == 5).astype(np.float32) + \ 325 | (parse_array == 6).astype(np.float32) + \ 326 | (parse_array == 7).astype(np.float32) 327 | parse_mask = (parse_array == 5).astype(np.float32) + \ 328 | (parse_array == 6).astype(np.float32) + \ 329 | (parse_array == 7).astype(np.float32) 330 | 331 | parser_mask_fixed = parser_mask_fixed + (parse_array == 9).astype(np.float32) + \ 332 | (parse_array == 12).astype(np.float32) 333 | 334 | parser_mask_changeable += np.logical_and( 335 | parse_array, np.logical_not(parser_mask_fixed)) 336 | 337 | parse_head = torch.from_numpy(parse_head) # [0,1] 338 | parse_cloth = torch.from_numpy(parse_cloth) # [0,1] 339 | parse_mask = torch.from_numpy(parse_mask) # [0,1] 340 | parser_mask_fixed = torch.from_numpy(parser_mask_fixed) 341 | parser_mask_changeable = torch.from_numpy(parser_mask_changeable) 342 | 343 | # dilation 344 | parse_without_cloth = np.logical_and( 345 | parse_shape, np.logical_not(parse_mask)) 346 | parse_mask = parse_mask.cpu().numpy() 347 | 348 | if "im_head" in self.outputlist: 349 | # Masked cloth 350 | im_head = image * parse_head - (1 - parse_head) 351 | if "im_cloth" in self.outputlist: 352 | im_cloth = image * parse_cloth + (1 - parse_cloth) 353 | if "im_mask_fine" in self.outputlist: 354 | im_mask_fine = image * (1-parse_cloth) 355 | 356 | # Shape 357 | parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8)) 358 | parse_shape = parse_shape.resize( 359 | (self.width // 16, self.height // 16), Image.BILINEAR) 360 | parse_shape = parse_shape.resize( 361 | (self.width, self.height), Image.BILINEAR) 362 | shape = self.transform2D(parse_shape) # [-1,1] 363 | 364 | # Load pose points 365 | pose_name = im_name.replace('.jpg', '_keypoints.json') 366 | with open(os.path.join(dataroot, self.phase, 'openpose_json', pose_name), 'r') as f: 367 | pose_label = json.load(f) 368 | pose_data = pose_label['people'][0]['pose_keypoints_2d'] 369 | pose_data = np.array(pose_data) 370 | pose_data = pose_data.reshape((-1, 3))[:, :2] 371 | 372 | pose_data[:, 0] = pose_data[:, 0] * (self.width / 768) 373 | pose_data[:, 1] = pose_data[:, 1] * (self.height / 1024) 374 | 375 | pose_mapping = get_coco_body25_mapping() 376 | point_num = len(pose_mapping) 377 | 378 | pose_map = torch.zeros(point_num, self.height, self.width) 379 | r = self.radius * (self.height / 512.0) 380 | im_pose = Image.new('L', (self.width, self.height)) 381 | pose_draw = ImageDraw.Draw(im_pose) 382 | neck = Image.new('L', (self.width, self.height)) 383 | neck_draw = ImageDraw.Draw(neck) 384 | for 
i in range(point_num): 385 | one_map = Image.new('L', (self.width, self.height)) 386 | draw = ImageDraw.Draw(one_map) 387 | 388 | point_x = np.multiply(pose_data[pose_mapping[i], 0], 1) 389 | point_y = np.multiply(pose_data[pose_mapping[i], 1], 1) 390 | 391 | if point_x > 1 and point_y > 1: 392 | draw.rectangle((point_x - r, point_y - r, 393 | point_x + r, point_y + r), 'white', 'white') 394 | pose_draw.rectangle( 395 | (point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') 396 | if i == 2 or i == 5: 397 | neck_draw.ellipse((point_x - r * 4, point_y - r * 4, point_x + r * 4, point_y + r * 4), 'white', 398 | 'white') 399 | one_map = self.transform2D(one_map) 400 | pose_map[i] = one_map[0] 401 | 402 | d = [] 403 | 404 | for idx in range(point_num): 405 | ux = pose_data[pose_mapping[idx], 0] 406 | uy = (pose_data[pose_mapping[idx], 1]) 407 | 408 | px = ux 409 | py = uy 410 | 411 | d.append(kpoint_to_heatmap( 412 | np.array([px, py]), (self.height, self.width), 9)) 413 | 414 | pose_map = torch.stack(d) 415 | 416 | im_pose = self.transform2D(im_pose) 417 | 418 | im_arms = Image.new('L', (self.width, self.height)) 419 | arms_draw = ImageDraw.Draw(im_arms) 420 | 421 | with open(os.path.join(dataroot, self.phase, 'openpose_json', pose_name), 'r') as f: 422 | data = json.load(f) 423 | data = data['people'][0]['pose_keypoints_2d'] 424 | data = np.array(data) 425 | data = data.reshape((-1, 3))[:, :2] 426 | 427 | data[:, 0] = data[:, 0] * (self.width / 768) 428 | data[:, 1] = data[:, 1] * (self.height / 1024) 429 | 430 | shoulder_right = tuple(data[pose_mapping[2]]) 431 | shoulder_left = tuple(data[pose_mapping[5]]) 432 | elbow_right = tuple(data[pose_mapping[3]]) 433 | elbow_left = tuple(data[pose_mapping[6]]) 434 | wrist_right = tuple(data[pose_mapping[4]]) 435 | wrist_left = tuple(data[pose_mapping[7]]) 436 | 437 | ARM_LINE_WIDTH = int(90 / 512 * self.height) 438 | if wrist_right[0] <= 1. and wrist_right[1] <= 1.: 439 | if elbow_right[0] <= 1. and elbow_right[1] <= 1.: 440 | arms_draw.line( 441 | np.concatenate((wrist_left, elbow_left, shoulder_left, shoulder_right)).astype( 442 | np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') 443 | else: 444 | arms_draw.line(np.concatenate( 445 | (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right)).astype( 446 | np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') 447 | elif wrist_left[0] <= 1. and wrist_left[1] <= 1.: 448 | if elbow_left[0] <= 1. 
and elbow_left[1] <= 1.: 449 | arms_draw.line( 450 | np.concatenate((shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( 451 | np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') 452 | else: 453 | arms_draw.line(np.concatenate( 454 | (elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( 455 | np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') 456 | else: 457 | arms_draw.line(np.concatenate( 458 | (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( 459 | np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') 460 | 461 | hands = np.logical_and(np.logical_not(im_arms), arms) 462 | parse_mask += im_arms 463 | if not ("hands" in self.mask_items or "all" in self.mask_items): 464 | parser_mask_fixed += hands 465 | 466 | # delete neck 467 | parse_head_2 = torch.clone(parse_head) 468 | 469 | parser_mask_fixed = np.logical_or( 470 | parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16)) 471 | parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16), 472 | np.logical_not( 473 | np.array(parse_head_2, dtype=np.uint16)))) 474 | 475 | if self.mask_type == "bounding_box": 476 | parse_mask = np.logical_and( 477 | parser_mask_changeable, np.logical_not(parse_mask)) 478 | parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed) 479 | inpaint_mask = 1 - parse_mask_total 480 | 481 | bboxes = masks_to_boxes(inpaint_mask.unsqueeze(0)) 482 | bboxes = bboxes.type(torch.int32) 483 | xmin = bboxes[0, 0] 484 | xmax = bboxes[0, 2] 485 | ymin = bboxes[0, 1] 486 | ymax = bboxes[0, 3] 487 | 488 | inpaint_mask[ymin:ymax + 1, xmin:xmax + 1] = torch.logical_and( 489 | torch.ones_like( 490 | inpaint_mask[ymin:ymax + 1, xmin:xmax + 1]), 491 | torch.logical_not(parser_mask_fixed[ymin:ymax + 1, xmin:xmax + 1])) 492 | 493 | inpaint_mask = inpaint_mask.unsqueeze(0) 494 | im_mask = image * np.logical_not(inpaint_mask.repeat(3, 1, 1)) 495 | parse_mask_total = parse_mask_total.numpy() 496 | parse_mask_total = parse_array * parse_mask_total 497 | parse_mask_total = torch.from_numpy(parse_mask_total) 498 | elif self.mask_type == "keypoints": 499 | parse_mask = cv2.dilate(parse_mask, np.ones( 500 | (5, 5), np.uint16), iterations=5) 501 | parse_mask = np.logical_and( 502 | parser_mask_changeable, np.logical_not(parse_mask)) 503 | parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed) 504 | im_mask = image * parse_mask_total 505 | inpaint_mask = 1 - parse_mask_total 506 | inpaint_mask = inpaint_mask.unsqueeze(0) 507 | parse_mask_total = parse_mask_total.numpy() 508 | parse_mask_total = parse_array * parse_mask_total 509 | parse_mask_total = torch.from_numpy(parse_mask_total) 510 | else: 511 | raise NotImplementedError 512 | 513 | if "dense_uv" in self.outputlist: 514 | uv = np.load(os.path.join(dataroot, 'dense', 515 | im_name.replace('_0.jpg', '_5_uv.npz'))) 516 | uv = uv['uv'] 517 | uv = torch.from_numpy(uv) 518 | uv = transforms.functional.resize(uv, (self.height, self.width)) 519 | 520 | if "dense_labels" in self.outputlist: 521 | labels = Image.open(os.path.join( 522 | dataroot, 'dense', im_name.replace('_0.jpg', '_5.png'))) 523 | labels = labels.resize((self.width, self.height), Image.NEAREST) 524 | labels = np.array(labels) 525 | 526 | result = {} 527 | for k in self.outputlist: 528 | result[k] = vars()[k] 529 | 530 | return result 531 | 532 | def __len__(self): 533 | return len(self.c_names) -------------------------------------------------------------------------------- 
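A minimal usage sketch for the VITON-HD loader above (not taken from the repository: the dataroot path, DataLoader settings and requested keys are illustrative; the class itself expects the standard VITON-HD layout with image/, cloth/, image-parse-v3/ and openpose_json/ under each split folder):

import torch
from dataset.vitonhd import VitonHDDataset

dataset = VitonHDDataset(
    dataroot_path="/data/vitonhd",   # hypothetical dataset location
    phase="test",
    order="paired",
    outputlist=("c_name", "im_name", "image", "cloth", "im_mask_fine", "inpaint_mask", "captions"),
    size=(512, 384),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4)

batch = next(iter(loader))
print(batch["image"].shape)         # (4, 3, 512, 384), values normalized to [-1, 1]
print(batch["inpaint_mask"].shape)  # (4, 1, 512, 384), region to be reconstructed
print(batch["captions"][0])         # comma-joined caption chunks for the first sample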
/dataset/dresscode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from pathlib import Path 5 | from typing import Tuple, Literal 6 | import shutil 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | import torch.utils.data as data 12 | import torchvision.transforms as transforms 13 | from PIL import Image, ImageDraw, ImageOps 14 | from torchvision.ops import masks_to_boxes 15 | 16 | from utils.labelmap import label_map 17 | from utils.posemap import kpoint_to_heatmap 18 | import random 19 | 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | 23 | class DressCodeDataset(data.Dataset): 24 | """ 25 | Dataset class for the Dress Code Multimodal Dataset 26 | im_parse: inpaint mask 27 | """ 28 | 29 | def __init__(self, 30 | dataroot_path: str, 31 | phase: Literal["train", "test"], 32 | radius: int = 5, 33 | caption_file: str = 'fine_captions.json', 34 | coarse_caption_file: Literal['coarse_captions.json', 'qwen_captions_2_5_0_2.json', 'all_captions_structural.json'] = 'coarse_captions.json', 35 | sketch_threshold_range: Tuple[int, int] = (20, 127), 36 | order: Literal['paired', 'unpaired'] = 'paired', 37 | outputlist: Tuple[str] = ('c_name', 'im_name', 'image', 'im_cloth', 'cloth', 'shape', 'pose_map', 38 | 'parse_array', 'im_mask', 'inpaint_mask', 'parse_mask_total', 39 | 'im_sketch', 'greyscale_im_sketch', 'captions', 'category', 'stitch_label', 40 | 'texture', 'cloth_embeddings', 'dwpose', 'vton_image_embeddings'), 41 | category: Literal['dresses', 'upper_body', 'lower_body'] = ('dresses', 'upper_body', 'lower_body'), 42 | size: Tuple[int, int] = (512, 384), 43 | mask_type: Literal["mask", "bounding_box"] = "bounding_box", 44 | texture_order: Literal["shuffled", "retrieved", "original"] = "original", 45 | mask_items: Tuple[str] = ["original"], 46 | mask_category: Literal['upper_body', 'lower_body'] = None, 47 | n_chunks=3, 48 | ): 49 | 50 | super().__init__() 51 | self.dataroot = dataroot_path 52 | self.phase = phase 53 | self.sketch_threshold_range = sketch_threshold_range 54 | self.category = category 55 | self.outputlist = outputlist 56 | self.height = size[0] 57 | self.width = size[1] 58 | self.radius = radius 59 | self.transform = transforms.Compose([ 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 62 | ]) 63 | self.transform2d = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.5,), (0.5,)) 66 | ]) 67 | self.order = order 68 | self.mask_type = mask_type 69 | self.texture_order = texture_order 70 | self.mask_items = mask_items 71 | self.mask_category = mask_category 72 | 73 | im_names = [] 74 | c_names = [] 75 | dataroot_names = [] 76 | category_names = [] 77 | 78 | possible_outputs = ['c_name', 'im_name', 'cpath', 'cloth', 'image', 'im_cloth', 'shape', 'im_head', 'im_pose', 79 | 'pose_map', 'parse_array', 'dense_labels', 'dense_uv', 'skeleton', 'im_mask', 80 | 'im_mask_fine', 'inshop_mask', 81 | 'inpaint_mask', 'greyscale_im_sketch', 'parse_mask_total', 'cloth_sketch', 'im_sketch', 82 | 'captions', 'category', 'hands', 'parse_head_2', 'stitch_label', 'texture', 'parse_cloth', 83 | 'cloth_embeddings', 'dwpose', 'vton_image_embeddings', 'clip_pooled', 'clip_embeds', 't5_embeds'] 84 | 85 | assert all(x in possible_outputs for x in outputlist) 86 | 87 | # Load Captions 88 | with open(os.path.join(self.dataroot, caption_file)) as f: 89 | self.captions_dict = json.load(f) 90 | self.captions_dict = {k: v 
for k, v in self.captions_dict.items() if len(v) >= 3} 91 | if coarse_caption_file is not None: 92 | with open(os.path.join(self.dataroot, coarse_caption_file)) as f: 93 | coarse_captions_dict = json.load(f) 94 | self.captions_dict.update(coarse_captions_dict) 95 | 96 | if coarse_caption_file == 'coarse_captions.json': 97 | self.caption_type = 'text_embeddings' 98 | elif coarse_caption_file == 'all_captions_structural.json': 99 | self.caption_type = 'struct_text_embeddings' 100 | elif 'demo_qwen' in coarse_caption_file: 101 | self.caption_type = 'demo_qwen_text_embeddings' 102 | self.captions_dict = coarse_captions_dict 103 | elif 'qwen' in coarse_caption_file: 104 | self.caption_type = 'qwen_text_embeddings' 105 | else: raise NotImplementedError("not supported caption file") 106 | del coarse_captions_dict 107 | 108 | for c in sorted(category): 109 | assert c in ['dresses', 'upper_body', 'lower_body'] 110 | 111 | dataroot = os.path.join(self.dataroot, c) 112 | if phase == 'train': 113 | filename = os.path.join(dataroot, f"{phase}_pairs.txt") 114 | else: 115 | filename = os.path.join(dataroot, f"{phase}_pairs_{order}.txt") 116 | 117 | 118 | with open(filename, 'r') as f: 119 | for line in f.readlines(): 120 | im_name, c_name = line.strip().split() 121 | if c_name.split('_')[0] not in self.captions_dict: 122 | continue 123 | if im_name in im_names: 124 | continue 125 | 126 | im_names.append(im_name) 127 | c_names.append(c_name) 128 | dataroot_names.append(dataroot) 129 | category_names.append(c) 130 | 131 | self.im_names = im_names 132 | self.c_names = c_names 133 | self.dataroot_names = dataroot_names 134 | self.category_names = category_names 135 | self.cloth_complete_names = {c_name: os.path.relpath(os.path.join(dataroot_names[idx], "cleaned_inshop_imgs", \ 136 | c_name.replace(".jpg", "_cleaned.jpg")), self.dataroot) for idx, c_name in enumerate(c_names)} 137 | if not os.path.isfile(os.path.join(dataroot, "complete_dict.json")): 138 | self.complete_dict = {self.cloth_complete_names[c_name]: list(dict.fromkeys(self.captions_dict[c_name.split('_')[0]])) \ 139 | for c_name in sorted(c_names)} 140 | with open(os.path.join(dataroot, f"complete_dict_{phase}.json"), "w") as f: 141 | json.dump(self.complete_dict, f, indent=2) 142 | else: 143 | with open(os.path.join(dataroot, f"complete_dict_{phase}.json")) as f: 144 | self.complete_dict = json.load(f) 145 | 146 | def __getitem__(self, index: int) -> dict: 147 | """For each index return the corresponding element in the dataset 148 | 149 | Args: 150 | index (int): element index 151 | 152 | Raises: 153 | NotImplementedError 154 | 155 | Returns: 156 | dict: 157 | c_name: filename of inshop cloth 158 | im_name: filename of model with cloth 159 | cloth: img of inshop cloth 160 | image: img of the model with that cloth 161 | im_cloth: cut cloth from the model 162 | im_mask: black mask of the cloth in the model img 163 | cloth_sketch: sketch of the inshop cloth 164 | im_sketch: sketch of "im_cloth" 165 | """ 166 | 167 | c_name = self.c_names[index] 168 | im_name = self.im_names[index] 169 | dataroot = self.dataroot_names[index] 170 | category = dataroot.split('/')[-1] 171 | item_id = os.path.splitext(im_name)[0] 172 | 173 | sketch_threshold = random.randint( 174 | self.sketch_threshold_range[0], self.sketch_threshold_range[1]) 175 | 176 | if "captions" in self.outputlist: 177 | captions = self.captions_dict[c_name.split('_')[0]] 178 | if self.phase == 'train': 179 | random.shuffle(captions) 180 | captions = ", ".join(captions) 181 | 182 | if "cloth" in 
self.outputlist: 183 | cloth_path = os.path.join( 184 | dataroot, 'cleaned_inshop_imgs', c_name.replace(".jpg", "_cleaned.jpg")) 185 | if not os.path.exists(cloth_path): cloth_path = os.path.join(dataroot, 'images', c_name) 186 | cloth = Image.open(cloth_path) 187 | cloth = cloth.resize((self.width, self.height)) 188 | cloth = self.transform(cloth) # [-1,1] 189 | 190 | if "cloth_embeddings" in self.outputlist: 191 | feat_path = os.path.join(str(Path(dataroot).parent), 'cloth_embeddings', 'CLIP_VIT_L_VIT_G_concat', category, f"{c_name.split('.')[0]}.pt") 192 | local_file = f'/tmp/dresscode_cloth_{os.getpid()}_{item_id}.pt' 193 | shutil.copy(str(feat_path), local_file) 194 | cloth_embeddings = torch.load(local_file, 195 | map_location="cpu") 196 | os.remove(local_file) 197 | 198 | if "vton_image_embeddings" in self.outputlist: 199 | feat_path = os.path.join(str(Path(dataroot).parent), 'vton_embeddings', 'CLIP_VIT_L_VIT_G_concat', category, f"{im_name.split('.')[0]}.pt") 200 | local_file = f'/tmp/dresscode_vton_{os.getpid()}_{item_id}.pt' 201 | shutil.copy(str(feat_path), local_file) 202 | vton_image_embeddings = torch.load(local_file, 203 | map_location="cpu") 204 | os.remove(local_file) 205 | 206 | 207 | if "clip_pooled" in self.outputlist: 208 | feat_path = Path(dataroot).parent / self.caption_type / "CLIP_VIT_L_VIT_G_concat_pooled" / category / f"{c_name.split('.')[0]}.pt" 209 | local_file = f'/tmp/dresscode_pooled_{os.getpid()}_{item_id}.pt' 210 | shutil.copy(str(feat_path), local_file) 211 | clip_pooled = torch.load(str(local_file), map_location="cpu") 212 | os.remove(local_file) 213 | 214 | if "clip_embeds" in self.outputlist: 215 | feat_path = Path(dataroot).parent / self.caption_type / "CLIP_VIT_L_VIT_G_concat" / category / f"{c_name.split('.')[0]}.pt" 216 | local_file = f'/tmp/dresscode_clip_embeds_{os.getpid()}_{item_id}.pt' 217 | shutil.copy(str(feat_path), local_file) 218 | clip_embeds = torch.load(str(local_file), map_location="cpu") 219 | os.remove(local_file) 220 | 221 | if "t5_embeds" in self.outputlist: 222 | feat_path = Path(dataroot).parent / self.caption_type / "T5_XXL" / category / f"{c_name.split('.')[0]}.pt" 223 | local_file = f'/tmp/dresscode_t5_{os.getpid()}_{item_id}.pt' 224 | shutil.copy(str(feat_path), local_file) 225 | t5_embeds = torch.load(str(local_file), map_location="cpu") 226 | os.remove(local_file) 227 | 228 | if "dwpose" in self.outputlist: 229 | dwpose = Image.open(os.path.join(dataroot, 'dwpose', f"{im_name.split('_')[0]}_2.png")) 230 | 231 | if "inshop_mask" in self.outputlist: 232 | mask_filename = os.path.join( 233 | dataroot, 'cleaned_inshop_masks', c_name.replace(".jpg", "_cleaned_mask.jpg")) 234 | mask_image = Image.open(mask_filename) 235 | mask_image = transforms.ToTensor()(mask_image) 236 | inshop_mask = mask_image[0, :, :] 237 | 238 | if "image" in self.outputlist or "im_head" in self.outputlist \ 239 | or "im_cloth" in self.outputlist: 240 | image = Image.open(os.path.join(dataroot, 'images', im_name)) 241 | image = image.resize((self.width, self.height)) 242 | image = self.transform(image) # [-1,1] 243 | 244 | if "im_sketch" in self.outputlist or "greyscale_im_sketch" in self.outputlist: 245 | if "unpaired" == self.order and self.phase == 'test': 246 | greyscale_im_sketch = Image.open(os.path.join(dataroot, 'im_sketch_unpaired', 247 | f'{im_name.replace(".jpg", "")}_{c_name.replace(".jpg", ".png")}')) 248 | else: 249 | greyscale_im_sketch = Image.open(os.path.join(dataroot, 'im_sketch', c_name.replace(".jpg", ".png"))) 250 | 251 | 
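# The sketch is inverted, binarized against the per-sample threshold drawn from
# sketch_threshold_range, converted to a tensor and flipped back, so that in im_sketch
# strokes end up as 0 on a background of 1; greyscale_im_sketch keeps continuous values.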
greyscale_im_sketch = greyscale_im_sketch.resize((self.width, self.height)) 252 | greyscale_im_sketch = ImageOps.invert(greyscale_im_sketch) 253 | im_sketch = greyscale_im_sketch.point( 254 | lambda p: 255 if p > sketch_threshold else 0) 255 | im_sketch = transforms.functional.to_tensor(im_sketch) # [0,1] 256 | greyscale_im_sketch = transforms.functional.to_tensor(greyscale_im_sketch) # [0,1] 257 | im_sketch = 1 - im_sketch 258 | greyscale_im_sketch = 1 - greyscale_im_sketch 259 | 260 | if "im_pose" in self.outputlist or "parser_mask" in self.outputlist or "im_cloth" in self.outputlist \ 261 | or "im_mask" in self.outputlist or "im_mask_fine" in self.outputlist or "parse_mask_total" in \ 262 | self.outputlist or "parse_array" in self.outputlist or \ 263 | "pose_map" in self.outputlist or "parse_array" in \ 264 | self.outputlist or "shape" in self.outputlist or "im_head" in self.outputlist or "inpaint_mask" in self.outputlist: 265 | # Label Map 266 | parse_name = im_name.replace('_0.jpg', '_4.png') 267 | im_parse = Image.open(os.path.join( 268 | dataroot, 'label_maps', parse_name)) 269 | im_parse = im_parse.resize( 270 | (self.width, self.height), Image.NEAREST) 271 | parse_array = np.array(im_parse) 272 | 273 | parse_shape = (parse_array > 0).astype(np.float32) 274 | 275 | parse_head = (parse_array == 1).astype(np.float32) + \ 276 | (parse_array == 2).astype(np.float32) + \ 277 | (parse_array == 3).astype(np.float32) + \ 278 | (parse_array == 11).astype(np.float32) 279 | 280 | parser_mask_fixed = (parse_array == label_map["hair"]).astype(np.float32) + \ 281 | (parse_array == label_map["left_shoe"]).astype(np.float32) + \ 282 | (parse_array == label_map["right_shoe"]).astype(np.float32) + \ 283 | (parse_array == label_map["hat"]).astype(np.float32) + \ 284 | (parse_array == label_map["sunglasses"]).astype(np.float32) + \ 285 | (parse_array == label_map["scarf"]).astype(np.float32) + \ 286 | (parse_array == label_map["bag"]).astype( 287 | np.float32) 288 | 289 | parser_mask_changeable = ( 290 | parse_array == label_map["background"]).astype(np.float32) 291 | 292 | arms = (parse_array == 14).astype(np.float32) + \ 293 | (parse_array == 15).astype(np.float32) 294 | 295 | category = dataroot.split('/')[-1] 296 | if dataroot.split('/')[-1] == 'dresses': 297 | label_cat = 7 298 | parse_cloth = (parse_array == 7).astype(np.float32) 299 | parse_mask = (parse_array == 7).astype(np.float32) + \ 300 | (parse_array == 12).astype(np.float32) + \ 301 | (parse_array == 13).astype(np.float32) 302 | parser_mask_changeable += np.logical_and( 303 | parse_array, np.logical_not(parser_mask_fixed)) 304 | 305 | # upper body 306 | elif dataroot.split('/')[-1] == 'upper_body' or (dataroot.split('/')[-1] == 'lower_body' and self.mask_category=='upper_body'): 307 | label_cat = 4 308 | parse_cloth = (parse_array == 4).astype(np.float32) 309 | parse_mask = (parse_array == 4).astype(np.float32) 310 | 311 | parser_mask_fixed += (parse_array == label_map["skirt"]).astype(np.float32) + \ 312 | (parse_array == label_map["pants"]).astype( 313 | np.float32) 314 | 315 | parser_mask_changeable += np.logical_and( 316 | parse_array, np.logical_not(parser_mask_fixed)) 317 | elif dataroot.split('/')[-1] == 'lower_body': 318 | label_cat = 6 319 | parse_cloth = (parse_array == 6).astype(np.float32) 320 | parse_mask = (parse_array == 6).astype(np.float32) + \ 321 | (parse_array == 12).astype(np.float32) + \ 322 | (parse_array == 13).astype(np.float32) 323 | 324 | parser_mask_fixed += (parse_array == 
label_map["upper_clothes"]).astype(np.float32) + \ 325 | (parse_array == 14).astype(np.float32) + \ 326 | (parse_array == 15).astype(np.float32) 327 | parser_mask_changeable += np.logical_and( 328 | parse_array, np.logical_not(parser_mask_fixed)) 329 | else: 330 | raise NotImplementedError 331 | 332 | parse_head = torch.from_numpy(parse_head) # [0,1] 333 | parse_cloth = torch.from_numpy(parse_cloth) # [0,1] 334 | parse_mask = torch.from_numpy(parse_mask) # [0,1] 335 | parser_mask_fixed = torch.from_numpy(parser_mask_fixed) 336 | parser_mask_changeable = torch.from_numpy(parser_mask_changeable) 337 | 338 | parse_without_cloth = np.logical_and( 339 | parse_shape, np.logical_not(parse_mask)) 340 | parse_mask = parse_mask.cpu().numpy() 341 | 342 | if "im_head" in self.outputlist: 343 | # Masked cloth 344 | im_head = image * parse_head - (1 - parse_head) 345 | if "im_cloth" in self.outputlist: 346 | im_cloth = image * parse_cloth + (1 - parse_cloth) 347 | if "im_mask_fine" in self.outputlist: 348 | im_mask_fine = image * (1-parse_cloth) 349 | 350 | # Shape 351 | parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8)) 352 | parse_shape = parse_shape.resize( 353 | (self.width // 16, self.height // 16), Image.BILINEAR) 354 | parse_shape = parse_shape.resize( 355 | (self.width, self.height), Image.BILINEAR) 356 | shape = self.transform2d(parse_shape) # [-1,1] 357 | 358 | # Load pose points 359 | pose_name = im_name.replace('_0.jpg', '_2.json') 360 | with open(os.path.join(dataroot, 'keypoints', pose_name), 'r') as f: 361 | pose_label = json.load(f) 362 | pose_data = pose_label['keypoints'] 363 | pose_data = np.array(pose_data) 364 | pose_data = pose_data.reshape((-1, 4)) 365 | 366 | point_num = pose_data.shape[0] 367 | pose_map = torch.zeros(point_num, self.height, self.width) 368 | r = self.radius * (self.height / 512.0) 369 | im_pose = Image.new('L', (self.width, self.height)) 370 | pose_draw = ImageDraw.Draw(im_pose) 371 | neck = Image.new('L', (self.width, self.height)) 372 | neck_draw = ImageDraw.Draw(neck) 373 | for i in range(point_num): 374 | one_map = Image.new('L', (self.width, self.height)) 375 | draw = ImageDraw.Draw(one_map) 376 | point_x = np.multiply(pose_data[i, 0], self.width / 384.0) 377 | point_y = np.multiply(pose_data[i, 1], self.height / 512.0) 378 | if point_x > 1 and point_y > 1: 379 | draw.rectangle((point_x - r, point_y - r, 380 | point_x + r, point_y + r), 'white', 'white') 381 | pose_draw.rectangle( 382 | (point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') 383 | if i == 2 or i == 5: 384 | neck_draw.ellipse((point_x - r * 4, point_y - r * 4, point_x + r * 4, point_y + r * 4), 'white', 385 | 'white') 386 | one_map = self.transform2d(one_map) 387 | pose_map[i] = one_map[0] 388 | 389 | d = [] 390 | for pose_d in pose_data: 391 | ux = pose_d[0] / 384.0 392 | uy = pose_d[1] / 512.0 393 | 394 | # scale posemap points 395 | px = ux * self.width 396 | py = uy * self.height 397 | 398 | d.append(kpoint_to_heatmap( 399 | np.array([px, py]), (self.height, self.width), 9)) 400 | 401 | pose_map = torch.stack(d) 402 | 403 | # just for visualization 404 | im_pose = self.transform2d(im_pose) 405 | 406 | im_arms = Image.new('L', (self.width, self.height)) 407 | arms_draw = ImageDraw.Draw(im_arms) 408 | if dataroot.split('/')[-1] == 'dresses' or dataroot.split('/')[-1] == 'upper_body' or dataroot.split('/')[ 409 | -1] == 'lower_body': 410 | with open(os.path.join(dataroot, 'keypoints', pose_name), 'r') as f: 411 | data = json.load(f) 412 | shoulder_right = 
np.multiply( 413 | tuple(data['keypoints'][2][:2]), self.height / 512.0) 414 | shoulder_left = np.multiply( 415 | tuple(data['keypoints'][5][:2]), self.height / 512.0) 416 | elbow_right = np.multiply( 417 | tuple(data['keypoints'][3][:2]), self.height / 512.0) 418 | elbow_left = np.multiply( 419 | tuple(data['keypoints'][6][:2]), self.height / 512.0) 420 | wrist_right = np.multiply( 421 | tuple(data['keypoints'][4][:2]), self.height / 512.0) 422 | wrist_left = np.multiply( 423 | tuple(data['keypoints'][7][:2]), self.height / 512.0) 424 | if wrist_right[0] <= 1. and wrist_right[1] <= 1.: 425 | if elbow_right[0] <= 1. and elbow_right[1] <= 1.: 426 | arms_draw.line( 427 | np.concatenate((wrist_left, elbow_left, shoulder_left, shoulder_right)).astype( 428 | np.uint16).tolist(), 'white', 45, 'curve') 429 | else: 430 | arms_draw.line(np.concatenate( 431 | (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right)).astype( 432 | np.uint16).tolist(), 'white', 45, 'curve') 433 | elif wrist_left[0] <= 1. and wrist_left[1] <= 1.: 434 | if elbow_left[0] <= 1. and elbow_left[1] <= 1.: 435 | arms_draw.line( 436 | np.concatenate((shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( 437 | np.uint16).tolist(), 'white', 45, 'curve') 438 | else: 439 | arms_draw.line(np.concatenate( 440 | (elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( 441 | np.uint16).tolist(), 'white', 45, 'curve') 442 | else: 443 | arms_draw.line(np.concatenate( 444 | (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( 445 | np.uint16).tolist(), 'white', 45, 'curve') 446 | 447 | hands = np.logical_and(np.logical_not(im_arms), arms) 448 | 449 | if dataroot.split('/')[-1] == 'dresses' or dataroot.split('/')[-1] == 'upper_body': 450 | parse_mask += im_arms 451 | if not ("hands" in self.mask_items or "all" in self.mask_items): 452 | parser_mask_fixed += hands 453 | 454 | # delete neck 455 | parse_head_2 = torch.clone(parse_head) 456 | if dataroot.split('/')[-1] == 'dresses' or dataroot.split('/')[-1] == 'upper_body': 457 | with open(os.path.join(dataroot, 'keypoints', pose_name), 'r') as f: 458 | data = json.load(f) 459 | points = [] 460 | points.append(np.multiply( 461 | tuple(data['keypoints'][2][:2]), self.height / 512.0)) 462 | points.append(np.multiply( 463 | tuple(data['keypoints'][5][:2]), self.height / 512.0)) 464 | x_coords, y_coords = zip(*points) 465 | A = np.vstack([x_coords, np.ones(len(x_coords))]).T 466 | m, c = np.linalg.lstsq(A, y_coords, rcond=None)[0] 467 | for i in range(parse_array.shape[1]): 468 | y = i * m + c 469 | parse_head_2[int( 470 | y - 20 * (self.height / 512.0)):, i] = 0 471 | 472 | parser_mask_fixed = np.logical_or( 473 | parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16)) 474 | parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16), 475 | np.logical_not( 476 | np.array(parse_head_2, dtype=np.uint16)))) 477 | 478 | parse_mask = cv2.dilate(parse_mask, np.ones( 479 | (5, 5), np.uint16), iterations=5) 480 | parse_mask = np.logical_and( 481 | parser_mask_changeable, np.logical_not(parse_mask)) 482 | parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed) 483 | im_mask = image * parse_mask_total 484 | inpaint_mask = 1 - parse_mask_total 485 | 486 | if self.mask_type == 'bounding_box': 487 | bboxes = masks_to_boxes(inpaint_mask.unsqueeze(0)) 488 | bboxes = bboxes.type(torch.int32) 489 | xmin = bboxes[0, 0] 490 | xmax = bboxes[0, 2] 491 | ymin = bboxes[0, 1] 492 | ymax = 
bboxes[0, 3] 493 | 494 | inpaint_mask[ymin:ymax + 1, xmin:xmax + 1] = torch.logical_and( 495 | torch.ones_like(inpaint_mask[ymin:ymax + 1, xmin:xmax + 1]), 496 | torch.logical_not(parser_mask_fixed[ymin:ymax + 1, xmin:xmax + 1])) 497 | 498 | inpaint_mask = inpaint_mask.unsqueeze(0) 499 | im_mask = image * np.logical_not(inpaint_mask.repeat(3, 1, 1)) 500 | parse_mask_total = parse_mask_total.numpy() 501 | parse_mask_total = parse_array * parse_mask_total 502 | parse_mask_total = torch.from_numpy(parse_mask_total) 503 | elif self.mask_type == "mask": 504 | inpaint_mask = inpaint_mask.unsqueeze(0) 505 | parse_mask_total = parse_mask_total.numpy() 506 | parse_mask_total = parse_array * parse_mask_total 507 | parse_mask_total = torch.from_numpy(parse_mask_total) 508 | else: 509 | raise ValueError("Unknown mask type") 510 | 511 | if "stitch_label" in self.outputlist: 512 | stitch_labelmap = Image.open(os.path.join( 513 | self.dataroot, 'test_stitchmap', im_name.replace(".jpg", ".png"))) 514 | stitch_labelmap = transforms.ToTensor()(stitch_labelmap) * 255 515 | stitch_label = stitch_labelmap == 13 516 | 517 | cpath = self.cloth_complete_names[c_name] 518 | result = {} 519 | for k in self.outputlist: 520 | result[k] = vars()[k] 521 | 522 | return result 523 | 524 | def __len__(self) -> int: 525 | """Return dataset length 526 | 527 | Returns: 528 | int: dataset length 529 | """ 530 | 531 | return len(self.c_names) --------------------------------------------------------------------------------
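A matching, equally hypothetical sketch for the Dress Code loader (paths and settings are illustrative; the main differences from the VITON-HD class are the category selection and the per-category pair lists, with images/, label_maps/ and keypoints/ expected inside each category folder):

import torch
from dataset.dresscode import DressCodeDataset

dataset = DressCodeDataset(
    dataroot_path="/data/DressCode",  # hypothetical dataset location
    phase="test",
    order="unpaired",                 # reads test_pairs_unpaired.txt in each category folder
    category=("upper_body",),
    outputlist=("c_name", "im_name", "image", "cloth", "inpaint_mask", "captions", "category"),
    size=(512, 384),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4)

batch = next(iter(loader))
print(batch["category"][:2])   # ['upper_body', 'upper_body']
print(batch["cloth"].shape)    # (4, 3, 512, 384), values normalized to [-1, 1]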