├── .DS_Store ├── LICENSE ├── README.md ├── datasets ├── dreampose_dataset.py ├── train_vae_dataset.py └── ubc_dataset.py ├── demo └── sample │ ├── key_frame.png │ └── train │ ├── frame_50.png │ └── frame_50_densepose.npy ├── finetune-unet.py ├── finetune-vae.py ├── media ├── DreamPose.png ├── Teaser.png ├── demo.gif └── demo.mov ├── models ├── .DS_Store └── unet_dual_encoder.py ├── pipelines └── dual_encoder_pipeline.py ├── test.py ├── train.py └── utils ├── densepose.py └── parse_args.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Johanna Karras 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DreamPose 2 | Official implementation of "DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion" by Johanna Karras, Aleksander Holynski, Ting-Chun Wang, and Ira Kemelmacher-Shlizerman. 3 | 4 | * [Project Page](https://grail.cs.washington.edu/projects/dreampose) 5 | * [Paper](https://arxiv.org/abs/2304.06025) 6 | 7 | ![Teaser Image](media/Teaser.png "Teaser") 8 | 9 | ## Demo 10 | 11 | You can generate a video using DreamPose using our pretrained models. 12 | 13 | 1. [Download](https://drive.google.com/drive/folders/15SaT3kZFRIjxuHT6UrGr6j0183clTK_D?usp=share_link) and unzip the pretrained models inside demo/custom-chkpts.zip 14 | 2. [Download](https://drive.google.com/drive/folders/1CjzcOp_ZUt-dyrzNAFE0T8bS3cbKTsVG?usp=share_link) and unzip the input poses inside demo/sample/poses.zip 15 | 3. 
Run demo.py using the command below: 16 | ``` 17 | python test.py --epoch 499 --folder demo/custom-chkpts --pose_folder demo/sample/poses --key_frame_path demo/sample/key_frame.png --s1 8 --s2 3 --n_steps 100 --output_dir demo/sample/results --custom_vae demo/custom-chkpts/vae_1499.pth 18 | ``` 19 | ## Data Preparation 20 | 21 | To prepare a sample for finetuning, create a directory containing train and test subdirectories containing the train frames (desired subject) and test frames (desired pose sequence), respectively. Note that the test frames are not expected to be of the same subject. See demo/sample for an example. 22 | 23 | Then, run [DensePose](https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose) using the "densepose_rcnn_R_50_FPN_s1x" checkpoint on all images in the sample directory. Finally, reformat the pickled DensePose output using utils/densepose.py. You need to change the "outpath" filepath to point to the pickled DensePose output. 24 | 25 | ## Download or Finetune Base Model 26 | 27 | DreamPose is finetuned on the UBC Fashion Dataset from a pretrained Stable Diffusion checkpoint. You can download our pretrained base model from [Google Drive](https://drive.google.com/file/d/10JjayW2mMqGxhUyM9ds_GHEvuqCTDaH3/view?usp=share_link), or finetune pretrained Stable Diffusion on your own image dataset. We train on 2 NVIDIA A100 GPUs. 28 | 29 | ``` 30 | accelerate launch --num_processes=4 train.py --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" --instance_data_dir=../path/to/dataset --output_dir=checkpoints --resolution=512 --train_batch_size=2 --gradient_accumulation_steps=4 --learning_rate=5e-6 --lr_scheduler="constant" --lr_warmup_steps=0 --num_train_epochs=300 --run_name dreampose --dropout_rate=0.15 --revision "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c" 31 | ``` 32 | 33 | ## Finetune on Sample 34 | 35 | In this next step, we finetune DreamPose on a one or more input frames to create a subject-specific model. 36 | 37 | 1. Finetune the UNet 38 | 39 | ``` 40 | accelerate launch finetune-unet.py --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" --instance_data_dir=demo/sample/train --output_dir=demo/custom-chkpts --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 --learning_rate=1e-5 --num_train_epochs=500 --dropout_rate=0.0 --custom_chkpt=checkpoints/unet_epoch_20.pth --revision "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c" 41 | ``` 42 | 43 | 2. Finetune the VAE decoder 44 | 45 | ``` 46 | accelerate launch --num_processes=1 finetune-vae.py --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" --instance_data_dir=demo/sample/train --output_dir=demo/custom-chkpts --resolution=512 --train_batch_size=4 --gradient_accumulation_steps=4 --learning_rate=5e-5 --num_train_epochs=1500 --run_name finetuning/ubc-vae --revision "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c" 47 | ``` 48 | 49 | ## Testing 50 | 51 | Once you have finetuned your custom, subject-specific DreamPose model, you can generate frames using the following command: 52 | 53 | ``` 54 | python test.py --epoch 499 --folder demo/custom-chkpts --pose_folder demo/sample/poses --key_frame_path demo/sample/key_frame.png --s1 8 --s2 3 --n_steps 100 --output_dir results --custom_vae demo/custom-chkpts/vae_1499.pth 55 | ``` 56 | 57 | ### Acknowledgment 58 | 59 | This code is largely adapted from the [Hugging Face diffusers repo](https://github.com/huggingface/diffusers). 
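
### Appendix: Reformatting DensePose Output (sketch)

As a supplement to the Data Preparation section above: the datasets in this repo expect one `*_densepose.npy` file per frame, holding a 2-channel UV map at the original image resolution. Below is a minimal, hedged sketch of that reformatting step. It assumes the pickle produced by detectron2's `apply_net.py dump` (a list of per-image records with `file_name`, `pred_boxes_XYXY`, and `pred_densepose` entries with a `.uv` tensor); these field names vary between DensePose versions, so adapt it to your output or edit `outpath` in `utils/densepose.py` directly.

```
# Hedged sketch: convert a DensePose "dump" pickle into per-frame *_densepose.npy files.
# The record keys (file_name, pred_boxes_XYXY, pred_densepose, .uv) are assumptions based on
# detectron2's DensePose dump format and may differ in your installed version.
import pickle
import numpy as np
import torch
from PIL import Image

outpath = "demo/sample/train/densepose_output.pkl"  # hypothetical path to the pickled dump

with open(outpath, "rb") as f:
    results = pickle.load(f)

for record in results:
    image_path = record["file_name"]
    w, h = Image.open(image_path).size
    dense = np.zeros((2, h, w), dtype=np.float32)          # full-frame UV canvas

    if len(record["pred_densepose"]) > 0:
        x0, y0, x1, y1 = record["pred_boxes_XYXY"][0].int().tolist()
        uv = record["pred_densepose"][0].uv                 # (2, box_h, box_w) tensor
        uv = torch.nn.functional.interpolate(
            uv.unsqueeze(0), (y1 - y0, x1 - x0), mode="bilinear"
        ).squeeze(0)
        dense[:, y0:y1, x0:x1] = uv.cpu().numpy()           # paste UV into the person box

    np.save(image_path.replace(".png", "_densepose.npy"), dense)
```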
60 | -------------------------------------------------------------------------------- /datasets/dreampose_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | from torchvision import transforms 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | import numpy as np 8 | import os, cv2, glob 9 | 10 | class DreamPoseDataset(Dataset): 11 | """ 12 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 13 | It pre-processes the images and the tokenizes prompts. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | instance_data_root, 19 | class_data_root=None, 20 | class_prompt=None, 21 | size=512, 22 | center_crop=False, 23 | train=True, 24 | p_jitter=0.9 25 | ): 26 | self.size = (640, 512) 27 | self.center_crop = center_crop 28 | self.train = train 29 | 30 | self.instance_data_root = Path(instance_data_root) 31 | if not self.instance_data_root.exists(): 32 | raise ValueError("Instance images root doesn't exists.") 33 | 34 | # Load UBC Fashion Dataset 35 | self.instance_images_path = glob.glob(instance_data_root+'/*png') 36 | 37 | self.num_instance_images = len(self.instance_images_path) 38 | self._length = self.num_instance_images 39 | 40 | if class_data_root is not None: 41 | self.class_data_root = Path(class_data_root) 42 | self.class_data_root.mkdir(parents=True, exist_ok=True) 43 | self.class_images_path = list(self.class_data_root.iterdir()) 44 | self.num_class_images = len(self.class_images_path) 45 | self._length = max(self.num_class_images, self.num_instance_images) 46 | self.class_prompt = class_prompt 47 | else: 48 | self.class_data_root = None 49 | 50 | self.image_transforms = transforms.Compose( 51 | [ 52 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 53 | transforms.ToTensor(), 54 | ] 55 | ) 56 | 57 | self.tensor_transforms = transforms.Compose( 58 | [ 59 | ] 60 | ) 61 | 62 | def __len__(self): 63 | return self._length 64 | 65 | # resize sparse uv flow to size 66 | def resize_pose(self, pose): 67 | h1, w1 = pose.shape 68 | h2, w2 = self.size, self.size 69 | resized_pose = np.zeros((h2, w2)) 70 | x_vals = np.where(pose != 0)[0] 71 | y_vals = np.where(pose != 0)[1] 72 | for (x, y) in list(zip(x_vals, y_vals)): 73 | # find new coordinates 74 | x2, y2 = int(x * h2 / h1), int(y * w2 / w1) 75 | resized_pose[x2, y2] = pose[x, y] 76 | return resized_pose 77 | 78 | def __getitem__(self, index): 79 | example = {} 80 | 81 | frame_path = self.instance_images_path[index % self.num_instance_images] 82 | frame_folder = frame_path.replace(os.path.basename(frame_path), '') 83 | #frame_number = int(os.path.basename(frame_path).split('frame_')[-1].replace('.png', '')) 84 | 85 | # load frame i 86 | instance_image = Image.open(frame_path) 87 | if not instance_image.mode == "RGB": 88 | instance_image = instance_image.convert("RGB") 89 | 90 | example["frame_i"] = self.image_transforms(instance_image) 91 | example["frame_prev"] = self.image_transforms(instance_image) 92 | 93 | assert example["frame_i"].shape == (3, 640, 512) 94 | 95 | # Select other frame in this folder 96 | frame_paths = glob.glob(frame_folder+'/*png') 97 | frame_paths = [p for p in frame_paths if os.path.exists(p.replace('.png', '_densepose.npy'))] 98 | frame_j_path = np.random.choice(frame_paths) 99 | 100 | # load frame j 101 | frame_j_path = np.random.choice(frame_paths) 102 | instance_image = Image.open(frame_j_path) 
103 | if not instance_image.mode == "RGB": 104 | instance_image = instance_image.convert("RGB") 105 | example["frame_j"] = self.image_transforms(instance_image) 106 | 107 | 108 | # construct 5 input poses 109 | poses = [] 110 | h, w = 640, 512 111 | for pose_number in range(5): 112 | dp_path = frame_j_path.replace('.png', '_densepose.npy') 113 | dp_i = F.interpolate(torch.from_numpy(np.load(dp_path).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0) 114 | poses.append(self.tensor_transforms(dp_i)) 115 | input_pose = torch.cat(poses, 0) 116 | example["pose_j"] = input_pose 117 | 118 | ''' Data Augmentation ''' 119 | key_frame = example["frame_i"] 120 | frame = example["frame_j"] 121 | prev_frame = example["frame_prev"] 122 | 123 | #dp = transforms.ToPILImage()(dp) 124 | 125 | # Get random transforms to target 70% of the time 126 | p = np.random.randint(0, 100) 127 | if p < 70: 128 | ang = np.random.randint(-15, 15) # rotation angle 129 | distort = np.random.rand(0, 1) 130 | top, left = np.random.randint(0, 25), np.random.randint(0, 25) 131 | h_ = np.random.randint(self.size[0]-25, self.size[0]-top) 132 | w_ = int(h_ / h * w) 133 | 134 | t = transforms.Compose([transforms.ToPILImage(),\ 135 | transforms.Resize((h,w), interpolation=transforms.InterpolationMode.BILINEAR), \ 136 | transforms.ToTensor(),\ 137 | ]) 138 | 139 | # Apply transforms 140 | frame = transforms.functional.crop(frame, top, left, h_, w_) # random crop 141 | 142 | example["frame_j"] = t(frame) 143 | 144 | for pose_id in range(5): 145 | start, end = 2*pose_id, 2*pose_id+2 146 | # convert dense pose to PIL image 147 | dp = example['pose_j'][start:end] 148 | c, h, w = dp.shape 149 | dp = torch.cat((dp, torch.zeros(1, h, w)), 0) 150 | dp = transforms.functional.crop(dp, top, left, h_, w_) # random crop 151 | dp = t(dp)[0:2] # Remove extra channel from input pose 152 | example["pose_j"][start:end] = dp.clone() 153 | 154 | # slightly perturb transforms to previous frame, to prevent copy/paste 155 | top += np.random.randint(0, 5) 156 | left += np.random.randint(0, 5) 157 | h_ += np.random.randint(0, 5) 158 | w_ += np.random.randint(0, 5) 159 | prev_frame = transforms.functional.crop(prev_frame, top, left, h_, w_) # random crop 160 | example["frame_prev"] = t(prev_frame) 161 | else: 162 | # slightly perturb transforms to previous frame, to prevent copy/paste 163 | top, left = np.random.randint(0, 5), np.random.randint(0, 5) 164 | h_ = np.random.randint(self.size[0]-5, self.size[0]-top) 165 | w_ = int(h_ / h * w) 166 | 167 | t = transforms.Compose([transforms.ToPILImage(),\ 168 | transforms.Resize((h,w), interpolation=transforms.InterpolationMode.BILINEAR), \ 169 | transforms.ToTensor(),\ 170 | ]) 171 | 172 | prev_frame = transforms.functional.crop(prev_frame, top, left, h_, w_) # random crop 173 | example["frame_prev"] = t(prev_frame) 174 | 175 | for pose_id in range(5): 176 | start, end = 2*pose_id, 2*pose_id+2 177 | dp = example['pose_j'][start:end] 178 | example["pose_j"][start:end] = dp.clone() 179 | 180 | return example 181 | -------------------------------------------------------------------------------- /datasets/train_vae_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | from torchvision import transforms 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | import numpy as np 8 | import os, cv2, glob 9 | 10 | class DreamPoseDataset(Dataset): 11 | """ 12 | A 
dataset to prepare the instance and class images with the prompts for fine-tuning the model. 13 | It pre-processes the images and the tokenizes prompts. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | instance_data_root, 19 | class_data_root=None, 20 | class_prompt=None, 21 | size=512, 22 | center_crop=False, 23 | train=True, 24 | p_jitter=0.9 25 | ): 26 | self.size = (640, 512) 27 | self.center_crop = center_crop 28 | self.train = train 29 | 30 | self.instance_data_root = Path(instance_data_root) 31 | if not self.instance_data_root.exists(): 32 | raise ValueError("Instance images root doesn't exists.") 33 | 34 | # Load UBC Fashion Dataset 35 | self.instance_images_path = [path for path in glob.glob(instance_data_root+'/*/*/*') if 'frame_i.png' in path] 36 | 37 | if len(self.instance_images_path) == 0: 38 | self.instance_images_path = [path for path in glob.glob(instance_data_root+'/*') if 'png' in path] 39 | 40 | len1 = len(self.instance_images_path) 41 | # Load Deep Fashion Dataset 42 | #self.instance_images_path.extend([path for path in glob.glob('../Deep_Fashion_Dataset/img_highres/*/*/*/*.jpg') \ 43 | # if os.path.exists(path.replace('.jpg', '_densepose.npy'))]) 44 | 45 | len2 = len(self.instance_images_path) 46 | print(f"Train Dataset: {len1} UBC Fashion images, {len2-len1} Deep Fashion images.") 47 | 48 | self.num_instance_images = len(self.instance_images_path) 49 | self._length = self.num_instance_images 50 | 51 | if class_data_root is not None: 52 | self.class_data_root = Path(class_data_root) 53 | self.class_data_root.mkdir(parents=True, exist_ok=True) 54 | self.class_images_path = list(self.class_data_root.iterdir()) 55 | self.num_class_images = len(self.class_images_path) 56 | self._length = max(self.num_class_images, self.num_instance_images) 57 | self.class_prompt = class_prompt 58 | else: 59 | self.class_data_root = None 60 | 61 | self.image_transforms = transforms.Compose( 62 | [ 63 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 64 | #transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3), 65 | transforms.ToTensor(), 66 | #transforms.Normalize([0.5], [0.5]), 67 | ] 68 | ) 69 | 70 | self.tensor_transforms = transforms.Compose( 71 | [ 72 | #transforms.Normalize([0.5], [0.5]), 73 | ] 74 | ) 75 | 76 | def __len__(self): 77 | return self._length 78 | 79 | # resize sparse uv flow to size 80 | def resize_pose(self, pose): 81 | h1, w1 = pose.shape 82 | h2, w2 = self.size[0], self.size[1] 83 | resized_pose = np.zeros((h2, w2)) 84 | x_vals = np.where(pose != 0)[0] 85 | y_vals = np.where(pose != 0)[1] 86 | for (x, y) in list(zip(x_vals, y_vals)): 87 | # find new coordinates 88 | x2, y2 = int(x * h2 / h1), int(y * w2 / w1) 89 | resized_pose[x2, y2] = pose[x, y] 90 | return resized_pose 91 | 92 | def __getitem__(self, index): 93 | example = {} 94 | 95 | frame_path = self.instance_images_path[index % self.num_instance_images] 96 | 97 | # load frame j 98 | frame_path = frame_path.replace('frame_i', 'frame_j') 99 | instance_image = Image.open(frame_path) 100 | if not instance_image.mode == "RGB": 101 | instance_image = instance_image.convert("RGB") 102 | frame_j = instance_image 103 | frame_j = frame_j.resize((self.size[1], self.size[0])) 104 | 105 | # Load pose j 106 | h, w = self.size[0], self.size[1] 107 | dp_path = self.instance_images_path[index % self.num_instance_images].replace('frame_i', 'frame_j').replace('.png', '_densepose.npy') 108 | dp_j = F.interpolate(torch.from_numpy(np.load(dp_path, 
allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0) 109 | 110 | # Load joints j 111 | #pose_path = self.instance_images_path[index % self.num_instance_images].replace('frame', 'pose').replace('.png', '_refined.npy') 112 | #pose = np.load(pose_path).astype('float32') 113 | #pose = self.resize_pose(pose / 32).astype('float32') 114 | #joints_j = torch.from_numpy(pose).unsqueeze(0) 115 | 116 | # Apply random crops 117 | max_crop = int(0.1*min(frame_j.size[0], frame_j.size[1])) 118 | top, left = np.random.randint(0, max_crop), np.random.randint(0, max_crop) 119 | h_ = np.random.randint(self.size[0]-max_crop, self.size[0]-top) 120 | w_ = int(h_ / h * w) 121 | #print(self.size[0]-max_crop, self.size[0]-top, h_, w_) 122 | frame_j = transforms.functional.crop(frame_j, top, left, h_, w_) # random crop 123 | dp_j = transforms.functional.crop(dp_j, top, left, h_, w_) # random crop 124 | #joints_j = transforms.functional.crop(joints_j, top, left, h_, w_) # random crop 125 | 126 | # Apply resize and normalization 127 | example["frame_j"] = self.image_transforms(frame_j) 128 | dp_j = self.tensor_transforms(dp_j) 129 | example["pose_j"] = F.interpolate(dp_j.unsqueeze(0), (h, w), mode='bilinear').squeeze(0) 130 | 131 | #joints_j = self.resize_pose(joints_j[0].numpy()) 132 | #example["joints_j"] = torch.from_numpy(joints_j).unsqueeze(0) 133 | 134 | return example 135 | -------------------------------------------------------------------------------- /datasets/ubc_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | from torchvision import transforms 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | import numpy as np 8 | import os, cv2, glob 9 | 10 | ''' 11 | - Passes 5 consecutive input poses per sample 12 | - Ensures at least one pair of consecutive frames per batch 13 | ''' 14 | class DreamPoseDataset(Dataset): 15 | """ 16 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 17 | It pre-processes the images and the tokenizes prompts. 
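Each item stacks two frame_i / frame_j / pose_j samples drawn from the same video folder, so every batch contains at least one pair of frames from the same clip.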
18 | """ 19 | 20 | def __init__( 21 | self, 22 | instance_data_root, 23 | class_data_root=None, 24 | class_prompt=None, 25 | size=512, 26 | center_crop=False, 27 | train=True, 28 | p_jitter=0.9, 29 | n_poses=5 30 | ): 31 | self.size = (640, 512) 32 | self.center_crop = center_crop 33 | self.train = train 34 | self.n_poses = n_poses 35 | 36 | self.instance_data_root = Path(instance_data_root) 37 | if not self.instance_data_root.exists(): 38 | raise ValueError("Instance images root doesn't exists.") 39 | 40 | # Load UBC Fashion Dataset 41 | self.instance_images_path = glob.glob('../UBC_Fashion_Dataset/train-frames/*/*png') 42 | self.instance_images_path = [p for p in self.instance_images_path if os.path.exists(p.replace('.png', '_densepose.npy'))] 43 | len1 = len(self.instance_images_path) 44 | 45 | # Load Deep Fashion Dataset 46 | self.instance_images_path.extend([path for path in glob.glob('../Deep_Fashion_Dataset/img_highres/*/*/*/*.jpg') \ 47 | if os.path.exists(path.replace('.jpg', '_densepose.npy'))]) 48 | 49 | len2 = len(self.instance_images_path) 50 | print(f"Train Dataset: {len1} UBC Fashion images, {len2-len1} Deep Fashion images.") 51 | 52 | self.num_instance_images = len(self.instance_images_path) 53 | self._length = self.num_instance_images 54 | 55 | if class_data_root is not None: 56 | self.class_data_root = Path(class_data_root) 57 | self.class_data_root.mkdir(parents=True, exist_ok=True) 58 | self.class_images_path = list(self.class_data_root.iterdir()) 59 | self.num_class_images = len(self.class_images_path) 60 | self._length = max(self.num_class_images, self.num_instance_images) 61 | self.class_prompt = class_prompt 62 | else: 63 | self.class_data_root = None 64 | 65 | self.image_transforms = transforms.Compose( 66 | [ 67 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 68 | #transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3), 69 | #transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 70 | transforms.ToTensor(), 71 | #transforms.Normalize([0.5], [0.5]), 72 | ] 73 | ) 74 | 75 | self.tensor_transforms = transforms.Compose( 76 | [ 77 | #transforms.Normalize([0.5], [0.5]), 78 | ] 79 | ) 80 | 81 | def __len__(self): 82 | return self._length 83 | 84 | # resize sparse uv flow to size 85 | def resize_pose(self, pose): 86 | h1, w1 = pose.shape 87 | h2, w2 = self.size, self.size 88 | resized_pose = np.zeros((h2, w2)) 89 | x_vals = np.where(pose != 0)[0] 90 | y_vals = np.where(pose != 0)[1] 91 | for (x, y) in list(zip(x_vals, y_vals)): 92 | # find new coordinates 93 | x2, y2 = int(x * h2 / h1), int(y * w2 / w1) 94 | resized_pose[x2, y2] = pose[x, y] 95 | return resized_pose 96 | 97 | # return two consecutive frames per call 98 | def __getitem__(self, index): 99 | example = {} 100 | 101 | ''' 102 | 103 | Prepare frame #1 104 | 105 | ''' 106 | # load frame i 107 | frame_path = self.instance_images_path[index % self.num_instance_images] 108 | instance_image = Image.open(frame_path) 109 | if not instance_image.mode == "RGB": 110 | instance_image = instance_image.convert("RGB") 111 | example["frame_i"] = self.image_transforms(instance_image) 112 | 113 | # Get additional frames in this folder 114 | sample_folder = frame_path.replace(os.path.basename(frame_path), '') 115 | samples = [path for path in glob.glob(sample_folder+'/*') if 'npy' not in path] 116 | samples = [path for path in samples if os.path.exists(path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy'))] 117 | 118 | if 
'Deep_Fashion' in frame_path: 119 | idx = os.path.basename(frame_path).split('_')[0] 120 | samples = [s for s in samples if os.path.basename(s).split('_')[0] == idx] 121 | #print("Frame Path = ", frame_path) 122 | #print("Sampels = ", samples) 123 | 124 | frame_j_path = samples[np.random.choice(range(len(samples)))] 125 | pose_j_path = frame_j_path.replace('.jpg', '_densepose.npy') 126 | 127 | # load frame j 128 | instance_image = Image.open(frame_j_path) 129 | if not instance_image.mode == "RGB": 130 | instance_image = instance_image.convert("RGB") 131 | example["frame_j"] = self.image_transforms(instance_image) 132 | 133 | # Load 5 poses surrounding j 134 | _, h, w = example["frame_i"].shape 135 | poses = [] 136 | idx1= int(self.n_poses // 2) 137 | idx2 = self.n_poses - idx1 138 | for pose_number in range(5): 139 | dp_path = frame_j_path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy') 140 | dp_i = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0) 141 | poses.append(self.tensor_transforms(dp_i)) 142 | 143 | example["pose_j"] = torch.cat(poses, 0) 144 | 145 | ''' 146 | 147 | Prepare frame #2 148 | 149 | ''' 150 | new_frame_path = samples[np.random.choice(range(len(samples)))] 151 | frame_path = new_frame_path 152 | 153 | # load frame i 154 | instance_image = Image.open(frame_path) 155 | if not instance_image.mode == "RGB": 156 | instance_image = instance_image.convert("RGB") 157 | example["frame_i"] = torch.stack((example["frame_i"], self.image_transforms(instance_image)), 0) 158 | 159 | assert example["frame_i"].shape == (2, 3, 640, 512) 160 | 161 | # Load frame j 162 | frame_j_path = samples[np.random.choice(range(len(samples)))] 163 | instance_image = Image.open(frame_j_path) 164 | if not instance_image.mode == "RGB": 165 | instance_image = instance_image.convert("RGB") 166 | example["frame_j"] = torch.stack((example['frame_j'], self.image_transforms(instance_image)), 0) 167 | 168 | # Load 5 poses surrounding j 169 | poses = [] 170 | for pose_number in range(5): 171 | dp_path = frame_j_path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy') 172 | dp_i = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0) 173 | poses.append(self.tensor_transforms(dp_i)) 174 | 175 | poses = torch.cat(poses, 0) 176 | example["pose_j"] = torch.stack((example["pose_j"], poses), 0) 177 | 178 | #print(example["frame_i"].shape, example["frame_j"].shape, example["pose_j"].shape) 179 | return example 180 | -------------------------------------------------------------------------------- /demo/sample/key_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/demo/sample/key_frame.png -------------------------------------------------------------------------------- /demo/sample/train/frame_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/demo/sample/train/frame_50.png -------------------------------------------------------------------------------- /demo/sample/train/frame_50_densepose.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/demo/sample/train/frame_50_densepose.npy -------------------------------------------------------------------------------- /finetune-unet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import math 5 | import os 6 | import random 7 | from pathlib import Path 8 | from typing import Optional 9 | from einops import rearrange 10 | from collections import OrderedDict 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | from torch.utils.data import Dataset 17 | import torch.nn as nn 18 | import numpy as np 19 | import cv2 20 | 21 | from accelerate import Accelerator 22 | from accelerate.logging import get_logger 23 | from accelerate.utils import set_seed 24 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 25 | from diffusers.optimization import get_scheduler 26 | from huggingface_hub import HfFolder, Repository, whoami 27 | from PIL import Image 28 | from torchvision import transforms 29 | from tqdm.auto import tqdm 30 | from transformers import CLIPFeatureExtractor, CLIPTokenizer, CLIPProcessor, CLIPVisionModel 31 | 32 | from torch.utils.tensorboard import SummaryWriter 33 | 34 | logger = get_logger(__name__) 35 | 36 | from utils.parse_args import parse_args 37 | from datasets.dreampose_dataset import DreamPoseDataset 38 | from pipelines.dual_encoder_pipeline import StableDiffusionImg2ImgPipeline 39 | from models.unet_dual_encoder import get_unet, Embedding_Adapter 40 | 41 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 42 | if token is None: 43 | token = HfFolder.get_token() 44 | if organization is None: 45 | username = whoami(token)["name"] 46 | return f"{username}/{model_id}" 47 | else: 48 | return f"{organization}/{model_id}" 49 | 50 | def main(args): 51 | logging_dir = Path(args.output_dir, args.logging_dir) 52 | 53 | accelerator = Accelerator( 54 | gradient_accumulation_steps=args.gradient_accumulation_steps, 55 | mixed_precision=args.mixed_precision, 56 | log_with="tensorboard", 57 | logging_dir=logging_dir, 58 | ) 59 | 60 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 61 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 62 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 63 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 64 | raise ValueError( 65 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 66 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
67 | ) 68 | 69 | if args.seed is not None: 70 | set_seed(args.seed) 71 | 72 | # Handle the repository creation 73 | if accelerator.is_main_process: 74 | if args.push_to_hub: 75 | if args.hub_model_id is None: 76 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 77 | else: 78 | repo_name = args.hub_model_id 79 | repo = Repository(args.output_dir, clone_from=repo_name) 80 | 81 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 82 | if "step_*" not in gitignore: 83 | gitignore.write("step_*\n") 84 | if "epoch_*" not in gitignore: 85 | gitignore.write("epoch_*\n") 86 | elif args.output_dir is not None: 87 | os.makedirs(args.output_dir, exist_ok=True) 88 | 89 | # Load CLIP Image Encoder 90 | clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").cuda() 91 | clip_encoder.requires_grad_(False) 92 | clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 93 | 94 | # Load models and create wrapper for stable diffusion 95 | vae = AutoencoderKL.from_pretrained( 96 | "CompVis/stable-diffusion-v1-4", 97 | subfolder="vae", 98 | revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c" 99 | ) 100 | 101 | # Load pretrained UNet layers 102 | unet = get_unet(args.pretrained_model_name_or_path, args.revision, resolution=args.resolution) 103 | 104 | if args.custom_chkpt is not None: 105 | print("Loading ", args.custom_chkpt) 106 | unet_state_dict = torch.load(args.custom_chkpt) 107 | new_state_dict = OrderedDict() 108 | for k, v in unet_state_dict.items(): 109 | name = k[7:] if k[:7] == 'module' else k 110 | new_state_dict[name] = v 111 | unet.load_state_dict(new_state_dict) 112 | unet = unet.cuda() 113 | 114 | # Embedding adapter 115 | adapter = Embedding_Adapter(input_nc=1280, output_nc=1280) 116 | 117 | if args.custom_chkpt is not None: 118 | adapter_chkpt = args.custom_chkpt.replace('unet_epoch', 'adapter') 119 | print("Loading ", adapter_chkpt) 120 | adapter_state_dict = torch.load(adapter_chkpt) 121 | new_state_dict = OrderedDict() 122 | for k, v in adapter_state_dict.items(): 123 | name = k[7:] if k[:7] == 'module' else k 124 | new_state_dict[name] = v 125 | adapter.load_state_dict(new_state_dict) 126 | adapter = adapter.cuda() 127 | 128 | #adapter.requires_grad_(True) 129 | 130 | vae.requires_grad_(False) 131 | 132 | if args.gradient_checkpointing: 133 | unet.enable_gradient_checkpointing() 134 | adapter.enable_gradient_checkpointing() 135 | 136 | if args.scale_lr: 137 | args.learning_rate = ( 138 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 139 | ) 140 | 141 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 142 | if args.use_8bit_adam: 143 | try: 144 | import bitsandbytes as bnb 145 | except ImportError: 146 | raise ImportError( 147 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
148 | ) 149 | 150 | optimizer_class = bnb.optim.AdamW8bit 151 | else: 152 | optimizer_class = torch.optim.AdamW 153 | 154 | params_to_optimize = ( 155 | itertools.chain(unet.parameters(), adapter.parameters(),) 156 | ) 157 | 158 | optimizer = optimizer_class( 159 | params_to_optimize, 160 | lr=args.learning_rate, 161 | betas=(args.adam_beta1, args.adam_beta2), 162 | weight_decay=args.adam_weight_decay, 163 | eps=args.adam_epsilon, 164 | ) 165 | 166 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 167 | 168 | # Load the tokenizer 169 | if args.tokenizer_name: 170 | tokenizer = CLIPTokenizer.from_pretrained( 171 | args.tokenizer_name, 172 | revision=args.revision, 173 | ) 174 | elif args.pretrained_model_name_or_path: 175 | tokenizer = CLIPTokenizer.from_pretrained( 176 | args.pretrained_model_name_or_path, 177 | subfolder="tokenizer", 178 | revision=args.revision, 179 | ) 180 | 181 | train_dataset = DreamPoseDataset( 182 | instance_data_root=args.instance_data_dir, 183 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 184 | class_prompt=args.class_prompt, 185 | size=args.resolution, 186 | center_crop=args.center_crop, 187 | ) 188 | 189 | def collate_fn(examples): 190 | frame_i = [example["frame_i"] for example in examples] 191 | frame_j = [example["frame_i"] for example in examples] 192 | poses = [example["pose_j"] for example in examples] 193 | 194 | # Concat class and instance examples for prior preservation. 195 | # We do this to avoid doing two forward passes. 196 | if args.with_prior_preservation: 197 | input_ids += [example["class_prompt_ids"] for example in examples] 198 | frame_i += [example["class_frame_i"] for example in examples] 199 | frame_j += [example["class_frame_j"] for example in examples] 200 | poses += [example["class_pose_j"] for example in examples] 201 | 202 | frame_i = torch.stack(frame_i, 0) 203 | frame_j = torch.stack(frame_j, 0) 204 | poses = torch.stack(poses, 0) 205 | 206 | # Dropout 207 | p = random.random() 208 | if p <= args.dropout_rate / 3: # dropout pose 209 | poses = torch.zeros(poses.shape) 210 | elif p <= 2*args.dropout_rate / 3: # dropout image 211 | frame_i = torch.zeros(frame_i.shape) 212 | elif p <= args.dropout_rate: # dropout image and pose 213 | poses = torch.zeros(poses.shape) 214 | frame_i = torch.zeros(frame_i.shape) 215 | 216 | frame_i = frame_i.to(memory_format=torch.contiguous_format).float() 217 | frame_j = frame_j.to(memory_format=torch.contiguous_format).float() 218 | poses = poses.to(memory_format=torch.contiguous_format).float() 219 | 220 | batch = { 221 | "frame_i": frame_i, 222 | "frame_j": frame_j, 223 | "poses": poses, 224 | } 225 | return batch 226 | 227 | train_dataloader = torch.utils.data.DataLoader( 228 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 229 | ) 230 | 231 | # Scheduler and math around the number of training steps. 
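# One optimizer update consumes gradient_accumulation_steps batches, so updates per epoch =
# ceil(len(train_dataloader) / gradient_accumulation_steps); when --max_train_steps is not set,
# it is derived from --num_train_epochs and recomputed again after accelerator.prepare().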
232 | overrode_max_train_steps = False 233 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 234 | if args.max_train_steps is None: 235 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 236 | overrode_max_train_steps = True 237 | 238 | lr_scheduler = get_scheduler( 239 | args.lr_scheduler, 240 | optimizer=optimizer, 241 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 242 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 243 | ) 244 | 245 | if args.train_text_encoder: 246 | unet, adapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 247 | unet, adapter, optimizer, train_dataloader, lr_scheduler 248 | ) 249 | else: 250 | unet, adapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 251 | unet, adapter, optimizer, train_dataloader, lr_scheduler 252 | ) 253 | 254 | weight_dtype = torch.float32 255 | if accelerator.mixed_precision == "fp16": 256 | weight_dtype = torch.float16 257 | elif accelerator.mixed_precision == "bf16": 258 | weight_dtype = torch.bfloat16 259 | 260 | # Move text_encode and vae to gpu. 261 | # For mixed precision training we cast the image_encoder and vae weights to half-precision 262 | # as these models are only used for inference, keeping weights in full precision is not required. 263 | vae.to(accelerator.device, dtype=weight_dtype) 264 | 265 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 266 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 267 | if overrode_max_train_steps: 268 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 269 | # Afterwards we recalculate our number of training epochs 270 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 271 | 272 | # We need to initialize the trackers we use, and also store our configuration. 273 | # The trackers initializes automatically on the main process. 274 | if accelerator.is_main_process: 275 | accelerator.init_trackers("dreambooth", config=vars(args)) 276 | 277 | # Train! 278 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 279 | 280 | logger.info("***** Running training *****") 281 | logger.info(f" Num examples = {len(train_dataset)}") 282 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 283 | logger.info(f" Num Epochs = {args.num_train_epochs}") 284 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 285 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 286 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 287 | logger.info(f" Total optimization steps = {args.max_train_steps}") 288 | # Only show the progress bar once on each machine. 
289 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 290 | progress_bar.set_description("Steps") 291 | global_step = 0 292 | 293 | def latents2img(latents): 294 | latents = 1 / 0.18215 * latents 295 | images = vae.decode(latents).sample 296 | images = (images / 2 + 0.5).clamp(0, 1) 297 | images = images.detach().cpu().numpy() 298 | images = (images * 255).round().astype("uint8") 299 | return images 300 | 301 | def inputs2img(input): 302 | target_images = (input / 2 + 0.5).clamp(0, 1) 303 | target_images = target_images.detach().cpu().numpy() 304 | target_images = (target_images * 255).round().astype("uint8") 305 | return target_images 306 | 307 | def visualize_dp(im, dp): 308 | im = im.transpose((1,2,0)) 309 | hsv = np.zeros(im.shape, dtype=np.uint8) 310 | hsv[..., 1] = 255 311 | 312 | dp = dp.cpu().detach().numpy() 313 | mag, ang = cv2.cartToPolar(dp[0], dp[1]) 314 | hsv[..., 0] = ang * 180 / np.pi / 2 315 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 316 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 317 | 318 | bgr = bgr.transpose((2,0,1)) 319 | return bgr 320 | 321 | latest_chkpt_step = 0 322 | for epoch in range(args.epoch, args.num_train_epochs): 323 | unet.train() 324 | adapter.train() 325 | first_batch = True 326 | for step, batch in enumerate(train_dataloader): 327 | if first_batch and latest_chkpt_step is not None: 328 | #os.system(f"python test_img2img.py --step {latest_chkpt_step} --strength 0.8") 329 | first_batch = False 330 | with accelerator.accumulate(unet): 331 | # Convert images to latent space 332 | latents = vae.encode(batch["frame_j"].to(dtype=weight_dtype)).latent_dist.sample() 333 | latents = latents * 0.18215 334 | 335 | # Sample noise that we'll add to the latents 336 | noise = torch.randn_like(latents) 337 | bsz = latents.shape[0] 338 | 339 | # Sample a random timestep for each image 340 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 341 | timesteps = timesteps.long() 342 | 343 | # Add noise to the latents according to the noise magnitude at each timestep 344 | # (this is the forward diffusion process) 345 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 346 | 347 | # Concatenate pose with noise 348 | _, _, h, w = noisy_latents.shape 349 | noisy_latents = torch.cat((noisy_latents, F.interpolate(batch['poses'], (h,w))), 1) 350 | 351 | # Get CLIP embeddings 352 | inputs = clip_processor(images=list(batch['frame_i'].to(latents.device)), return_tensors="pt") 353 | inputs = {k: v.to(latents.device) for k, v in inputs.items()} 354 | clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(latents.device) 355 | 356 | # Get VAE embeddings 357 | image = batch['frame_i'].to(device=latents.device, dtype=weight_dtype) 358 | vae_hidden_states = vae.encode(image).latent_dist.sample() * 0.18215 359 | 360 | encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states) 361 | 362 | # Predict the noise residual 363 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 364 | 365 | # Get the target for loss depending on the prediction type 366 | if noise_scheduler.config.prediction_type == "epsilon": 367 | target = noise 368 | elif noise_scheduler.config.prediction_type == "v_prediction": 369 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 370 | else: 371 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 372 | 373 | if 
args.with_prior_preservation: 374 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 375 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 376 | target, target_prior = torch.chunk(target, 2, dim=0) 377 | 378 | # Compute instance loss 379 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 380 | 381 | # Compute prior loss 382 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 383 | 384 | # Add the prior loss to the instance loss. 385 | loss = loss + args.prior_loss_weight * prior_loss 386 | else: 387 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 388 | 389 | accelerator.backward(loss) 390 | if accelerator.sync_gradients: 391 | params_to_clip = ( 392 | itertools.chain(unet.parameters()) 393 | ) 394 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 395 | optimizer.step() 396 | lr_scheduler.step() 397 | optimizer.zero_grad() 398 | 399 | # Checks if the accelerator has performed an optimization step behind the scenes 400 | if accelerator.sync_gradients: 401 | progress_bar.update(1) 402 | global_step += 1 403 | 404 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 405 | progress_bar.set_postfix(**logs) 406 | accelerator.log(logs, step=global_step) 407 | 408 | if global_step >= args.max_train_steps: 409 | break 410 | 411 | # save model 412 | if accelerator.is_main_process and global_step % 500 == 0: 413 | pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( 414 | args.pretrained_model_name_or_path, 415 | #adapter=accelerator.unwrap_model(adapter), 416 | unet=accelerator.unwrap_model(unet), 417 | tokenizer=tokenizer, 418 | image_encoder=accelerator.unwrap_model(clip_encoder), 419 | clip_processor=accelerator.unwrap_model(clip_processor), 420 | revision=args.revision, 421 | ) 422 | pipeline.save_pretrained(os.path.join(args.output_dir, f'checkpoint-{epoch}')) 423 | model_path = args.output_dir+f'/unet_epoch_{epoch}.pth' 424 | torch.save(unet.state_dict(), model_path) 425 | adapter_path = args.output_dir+f'/adapter_{epoch}.pth' 426 | torch.save(adapter.state_dict(), adapter_path) 427 | 428 | accelerator.wait_for_everyone() 429 | 430 | # save model 431 | if accelerator.is_main_process and global_step % 500 == 0: 432 | pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( 433 | args.pretrained_model_name_or_path, 434 | #adapter=accelerator.unwrap_model(adapter), 435 | unet=accelerator.unwrap_model(unet), 436 | tokenizer=tokenizer, 437 | image_encoder=accelerator.unwrap_model(clip_encoder), 438 | clip_processor=accelerator.unwrap_model(clip_processor), 439 | revision=args.revision, 440 | ) 441 | pipeline.save_pretrained(os.path.join(args.output_dir, f'checkpoint-{epoch}')) 442 | model_path = args.output_dir+f'/unet_epoch_{epoch}.pth' 443 | torch.save(unet.state_dict(), model_path) 444 | adapter_path = args.output_dir+f'/adapter_{epoch}.pth' 445 | torch.save(adapter.state_dict(), adapter_path) 446 | 447 | accelerator.end_training() 448 | 449 | 450 | if __name__ == "__main__": 451 | args = parse_args() 452 | main(args) -------------------------------------------------------------------------------- /finetune-vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import math 5 | import os 6 | import random 7 | from pathlib import Path 8 | from typing import Optional 9 | from 
collections import OrderedDict 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | from torch.utils.data import Dataset 15 | import torch.nn as nn 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import cv2 19 | 20 | from accelerate import Accelerator 21 | from accelerate.logging import get_logger 22 | from accelerate.utils import set_seed 23 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 24 | from diffusers.optimization import get_scheduler 25 | from huggingface_hub import HfFolder, Repository, whoami 26 | from PIL import Image 27 | from torchvision import transforms 28 | from tqdm.auto import tqdm 29 | from transformers import CLIPFeatureExtractor, CLIPTokenizer, CLIPProcessor, CLIPVisionModel 30 | 31 | from torch.utils.tensorboard import SummaryWriter 32 | 33 | logger = get_logger(__name__) 34 | 35 | from utils.parse_args import parse_args 36 | from datasets.train_vae_dataset import DreamPoseDataset 37 | from pipelines.dual_encoder_pipeline import StableDiffusionImg2ImgPipeline 38 | from models.unet_dual_encoder import get_unet, Embedding_Adapter 39 | 40 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 41 | if token is None: 42 | token = HfFolder.get_token() 43 | if organization is None: 44 | username = whoami(token)["name"] 45 | return f"{username}/{model_id}" 46 | else: 47 | return f"{organization}/{model_id}" 48 | 49 | def main(args): 50 | logging_dir = Path(args.output_dir, args.logging_dir) 51 | 52 | writer = SummaryWriter(f'results/logs/{args.run_name}') 53 | 54 | accelerator = Accelerator( 55 | gradient_accumulation_steps=args.gradient_accumulation_steps, 56 | mixed_precision=args.mixed_precision, 57 | log_with="tensorboard", 58 | logging_dir=logging_dir, 59 | ) 60 | 61 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 62 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 63 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 64 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 65 | raise ValueError( 66 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 67 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
68 | ) 69 | 70 | if args.seed is not None: 71 | set_seed(args.seed) 72 | 73 | # initialize perecpetual loss 74 | #lpips_loss = lpips.LPIPS(net='vgg').cuda() 75 | 76 | if args.with_prior_preservation: 77 | class_images_dir = Path(args.class_data_dir) 78 | if not class_images_dir.exists(): 79 | class_images_dir.mkdir(parents=True) 80 | cur_class_images = len(list(class_images_dir.iterdir())) 81 | 82 | if cur_class_images < args.num_class_images: 83 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 84 | pipeline = StableDiffusionPipeline.from_pretrained( 85 | args.pretrained_model_name_or_path, 86 | torch_dtype=torch_dtype, 87 | safety_checker=None, 88 | revision=args.revision, 89 | ) 90 | pipeline.set_progress_bar_config(disable=True) 91 | 92 | num_new_images = args.num_class_images - cur_class_images 93 | logger.info(f"Number of class images to sample: {num_new_images}.") 94 | pipeline.to(accelerator.device) 95 | 96 | del pipeline 97 | if torch.cuda.is_available(): 98 | torch.cuda.empty_cache() 99 | 100 | # Handle the repository creation 101 | if accelerator.is_main_process: 102 | if args.push_to_hub: 103 | if args.hub_model_id is None: 104 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 105 | else: 106 | repo_name = args.hub_model_id 107 | repo = Repository(args.output_dir, clone_from=repo_name) 108 | 109 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 110 | if "step_*" not in gitignore: 111 | gitignore.write("step_*\n") 112 | if "epoch_*" not in gitignore: 113 | gitignore.write("epoch_*\n") 114 | elif args.output_dir is not None: 115 | os.makedirs(args.output_dir, exist_ok=True) 116 | 117 | # Load CLIP Image Encoder 118 | image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") 119 | clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 120 | 121 | # Load models and create wrapper for stable diffusion 122 | vae = AutoencoderKL.from_pretrained( 123 | args.pretrained_model_name_or_path, 124 | subfolder="vae", 125 | revision=args.revision, 126 | ) 127 | # Load pretrained UNet layers 128 | unet = UNet2DConditionModel.from_pretrained( 129 | args.pretrained_model_name_or_path, 130 | subfolder="unet", 131 | revision=args.revision, 132 | ) 133 | # Modify input layer & copy pretrain weights 134 | weights = unet.conv_in.weight.clone() 135 | unet.conv_in = nn.Conv2d(6, weights.shape[0], kernel_size=3, padding=(1, 1)) 136 | with torch.no_grad(): 137 | unet.conv_in.weight[:, :4] = weights # original weights 138 | unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 4:].shape) # new weights initialized to zero 139 | unet.requires_grad_(False) 140 | 141 | # set VAE decoder to be trainable 142 | # Load VAE Pretrained Model 143 | if args.custom_chkpt is not None: 144 | vae_state_dict = torch.load(args.custom_chkpt) #'results/epoch_1/unet.pth')) 145 | new_state_dict = OrderedDict() 146 | for k, v in vae_state_dict.items(): 147 | name = k[7:] if k[:7] == 'module' else k # remove `module.` 148 | new_state_dict[name] = v 149 | vae.load_state_dict(new_state_dict) 150 | vae = vae.cuda() 151 | 152 | vae.requires_grad_(False) 153 | vae_trainable_params = [] 154 | for name, param in vae.named_parameters(): 155 | if 'decoder' in name: 156 | param.requires_grad = True 157 | vae_trainable_params.append(param) 158 | 159 | print(f"VAE total params = {len(list(vae.named_parameters()))}, trainable params = {len(vae_trainable_params)}") 160 | 
image_encoder.requires_grad_(False) 161 | 162 | if args.gradient_checkpointing: 163 | vae.gradient_checkpointing_enable() # uncomment if training clip model 164 | 165 | if args.scale_lr: 166 | args.learning_rate = ( 167 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 168 | ) 169 | 170 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 171 | if args.use_8bit_adam: 172 | try: 173 | import bitsandbytes as bnb 174 | except ImportError: 175 | raise ImportError( 176 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 177 | ) 178 | 179 | optimizer_class = bnb.optim.AdamW8bit 180 | else: 181 | optimizer_class = torch.optim.AdamW 182 | 183 | params_to_optimize = ( 184 | itertools.chain(vae_trainable_params) 185 | ) 186 | optimizer = optimizer_class( 187 | params_to_optimize, 188 | lr=args.learning_rate, 189 | betas=(args.adam_beta1, args.adam_beta2), 190 | weight_decay=args.adam_weight_decay, 191 | eps=args.adam_epsilon, 192 | ) 193 | 194 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 195 | 196 | 197 | train_dataset = DreamPoseDataset( 198 | instance_data_root=args.instance_data_dir, 199 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 200 | class_prompt=args.class_prompt, 201 | size=args.resolution, 202 | center_crop=args.center_crop, 203 | ) 204 | 205 | def collate_fn(examples): 206 | frame_j = [example["frame_j"] for example in examples] 207 | poses = [example["pose_j"] for example in examples] 208 | 209 | frame_j = torch.stack(frame_j) 210 | poses = torch.stack(poses) 211 | 212 | frame_j = frame_j.to(memory_format=torch.contiguous_format).float() 213 | poses = poses.to(memory_format=torch.contiguous_format).float() 214 | 215 | batch = { 216 | "target_frame": frame_j, 217 | "poses": poses, 218 | } 219 | return batch 220 | 221 | train_dataloader = torch.utils.data.DataLoader( 222 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 223 | ) 224 | 225 | # Scheduler and math around the number of training steps. 226 | overrode_max_train_steps = False 227 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 228 | if args.max_train_steps is None: 229 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 230 | overrode_max_train_steps = True 231 | 232 | lr_scheduler = get_scheduler( 233 | args.lr_scheduler, 234 | optimizer=optimizer, 235 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 236 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 237 | ) 238 | 239 | unet, vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 240 | unet, vae, optimizer, train_dataloader, lr_scheduler 241 | ) 242 | 243 | weight_dtype = torch.float32 244 | if accelerator.mixed_precision == "fp16": 245 | weight_dtype = torch.float16 246 | elif accelerator.mixed_precision == "bf16": 247 | weight_dtype = torch.bfloat16 248 | 249 | vae.to(accelerator.device, dtype=weight_dtype) 250 | if not args.train_text_encoder: 251 | image_encoder.to(accelerator.device, dtype=weight_dtype) 252 | 253 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
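# (accelerator.prepare wraps the dataloader; in multi-process training each process sees
# roughly len(dataset) / num_processes batches, so the per-process length can shrink.)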
254 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 255 | if overrode_max_train_steps: 256 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 257 | # Afterwards we recalculate our number of training epochs 258 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 259 | 260 | # We need to initialize the trackers we use, and also store our configuration. 261 | # The trackers initializes automatically on the main process. 262 | if accelerator.is_main_process: 263 | accelerator.init_trackers("dreambooth", config=vars(args)) 264 | 265 | # Train! 266 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 267 | 268 | logger.info("***** Running training *****") 269 | logger.info(f" Num examples = {len(train_dataset)}") 270 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 271 | logger.info(f" Num Epochs = {args.num_train_epochs}") 272 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 273 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 274 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 275 | logger.info(f" Total optimization steps = {args.max_train_steps}") 276 | # Only show the progress bar once on each machine. 277 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 278 | progress_bar.set_description("Steps") 279 | global_step = 0 280 | 281 | def latents2img(latents): 282 | latents = 1 / 0.18215 * latents 283 | images = vae.decode(latents).sample 284 | images = (images / 2 + 0.5).clamp(0, 1) 285 | images = images.detach().cpu().numpy() 286 | images = (images * 255).round().astype("uint8") 287 | return images 288 | 289 | def inputs2img(input): 290 | target_images = (input / 2 + 0.5).clamp(0, 1) 291 | target_images = target_images.detach().cpu().numpy() 292 | target_images = (target_images * 255).round().astype("uint8") 293 | return target_images 294 | 295 | def visualize_dp(im, dp): 296 | im = im.transpose((1,2,0)) 297 | hsv = np.zeros(im.shape, dtype=np.uint8) 298 | hsv[..., 1] = 255 299 | 300 | dp = dp.cpu().detach().numpy() 301 | mag, ang = cv2.cartToPolar(dp[0], dp[1]) 302 | hsv[..., 0] = ang * 180 / np.pi / 2 303 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 304 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 305 | 306 | bgr = bgr.transpose((2,0,1)) 307 | return bgr 308 | 309 | latest_chkpt_step = 0 310 | for epoch in range(args.epoch, args.num_train_epochs): 311 | vae.train() 312 | first_batch = True 313 | for step, batch in enumerate(train_dataloader): 314 | if first_batch and latest_chkpt_step is not None: 315 | first_batch = False 316 | with accelerator.accumulate(vae): 317 | # Convert images to latent space 318 | latents = vae.encode(batch["target_frame"].to(dtype=weight_dtype)).latent_dist.sample() 319 | latents = latents * 0.18215 320 | 321 | latents = 1 / 0.18215 * latents 322 | pred_images = vae.decode(latents).sample 323 | pred_images = pred_images.clamp(-1, 1) 324 | 325 | loss = F.mse_loss(pred_images.float(), batch['target_frame'].clamp(-1, 1).float(), reduction="mean") 326 | 327 | accelerator.backward(loss) 328 | if accelerator.sync_gradients: 329 | params_to_clip = ( 330 | itertools.chain(vae.parameters()) 331 | ) 332 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 333 | optimizer.step() 334 | 
lr_scheduler.step() 335 | optimizer.zero_grad() 336 | 337 | # Checks if the accelerator has performed an optimization step behind the scenes 338 | if accelerator.sync_gradients: 339 | progress_bar.update(1) 340 | global_step += 1 341 | 342 | # write to tensorboard 343 | writer.add_scalar("loss/train", loss.detach().item(), global_step) 344 | 345 | # write to tensorboard 346 | if global_step % 10 == 0: 347 | # Draw VAE decoder weights 348 | weights = vae.decoder.conv_out.weight.cpu().detach().numpy() 349 | weights = np.sum(weights, axis=0) 350 | weights = weights.flatten() 351 | plt.figure() 352 | plt.plot(range(len(weights)), weights) 353 | plt.title(f"VAE Decoder Weights = {np.mean(weights)}") 354 | writer.add_figure('decoder_weights', plt.gcf(), global_step=global_step) 355 | 356 | # Draw VAE encoder weights 357 | weights = vae.encoder.conv_out.weight.cpu().detach().numpy() 358 | weights = np.sum(weights, axis=0) 359 | weights = weights.flatten() 360 | plt.figure() 361 | plt.plot(range(len(weights)), weights) 362 | plt.title(f"Fixed VAE Encoder Weights= {np.mean(weights)}") 363 | writer.add_figure('encoder_weights', plt.gcf(), global_step=global_step) 364 | 365 | if global_step == 1 or global_step % 50 == 0: 366 | with torch.no_grad(): 367 | pred_images = inputs2img(pred_images) 368 | target = inputs2img(batch["target_frame"]) 369 | viz = np.concatenate([pred_images[0], target[0]], axis=2) 370 | writer.add_image(f'train/pred_img', viz, global_step=global_step) 371 | 372 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 373 | progress_bar.set_postfix(**logs) 374 | accelerator.log(logs, step=global_step) 375 | 376 | if global_step >= args.max_train_steps: 377 | break 378 | 379 | # save model 380 | if accelerator.is_main_process and global_step % 500 == 0: 381 | model_path = args.output_dir+f'/vae_{epoch}.pth' 382 | torch.save(vae.state_dict(), model_path) 383 | 384 | accelerator.wait_for_everyone() 385 | 386 | # save model 387 | if accelerator.is_main_process: 388 | print("Saving final model to ", args.output_dir) 389 | model_path = args.output_dir+f'/vae_{epoch}.pth' 390 | torch.save(vae.state_dict(), model_path) 391 | 392 | accelerator.end_training() 393 | 394 | 395 | if __name__ == "__main__": 396 | args = parse_args() 397 | main(args) -------------------------------------------------------------------------------- /media/DreamPose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/media/DreamPose.png -------------------------------------------------------------------------------- /media/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/media/Teaser.png -------------------------------------------------------------------------------- /media/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/media/demo.gif -------------------------------------------------------------------------------- /media/demo.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/media/demo.mov 
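As a usage note for the `vae_{epoch}.pth` state dicts written by finetune-vae.py above: they can be loaded back into a fresh `AutoencoderKL` before inference, mirroring the `module.`-prefix stripping that test.py performs. The sketch below is illustrative rather than part of the repository, and the checkpoint path is a placeholder:

```
import torch
from diffusers import AutoencoderKL

# Start from the same pretrained VAE that finetuning began with.
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
)

# Placeholder path; point it at a checkpoint saved by finetune-vae.py.
state_dict = torch.load("demo/custom-chkpts/vae_1499.pth", map_location="cpu")

# Strip a possible DataParallel/DistributedDataParallel "module." prefix,
# as test.py does before calling load_state_dict.
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

vae.load_state_dict(state_dict)
vae.eval()
```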
-------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johannakarras/DreamPose/5bf30b7df70cf6f2e0bb25556c6ff2cbf0f2b1bf/models/.DS_Store -------------------------------------------------------------------------------- /models/unet_dual_encoder.py: -------------------------------------------------------------------------------- 1 | # Load pretrained 2D UNet and modify with temporal attention 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import einsum 9 | import torch.utils.checkpoint 10 | from einops import rearrange 11 | 12 | import math 13 | 14 | from diffusers import AutoencoderKL 15 | from diffusers.models import UNet2DConditionModel 16 | 17 | def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5): 18 | # Load pretrained UNet layers 19 | unet = UNet2DConditionModel.from_pretrained( 20 | "CompVis/stable-diffusion-v1-4", 21 | subfolder="unet", 22 | revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c" 23 | ) 24 | 25 | # Modify input layer to have 1 additional input channels (pose) 26 | weights = unet.conv_in.weight.clone() 27 | unet.conv_in = nn.Conv2d(4 + 2*n_poses, weights.shape[0], kernel_size=3, padding=(1, 1)) # input noise + n poses 28 | with torch.no_grad(): 29 | unet.conv_in.weight[:, :4] = weights # original weights 30 | unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 3:].shape) # new weights initialized to zero 31 | 32 | return unet 33 | 34 | ''' 35 | This module takes in CLIP + VAE embeddings and outputs CLIP-compatible embeddings. 36 | ''' 37 | class Embedding_Adapter(nn.Module): 38 | def __init__(self, input_nc=38, output_nc=4, norm_layer=nn.InstanceNorm2d, chkpt=None): 39 | super(Embedding_Adapter, self).__init__() 40 | 41 | self.save_method_name = "adapter" 42 | 43 | self.pool = nn.MaxPool2d(2) 44 | self.vae2clip = nn.Linear(1280, 768) 45 | 46 | self.linear1 = nn.Linear(54, 50) # 50 x 54 shape 47 | 48 | # initialize weights 49 | with torch.no_grad(): 50 | self.linear1.weight = nn.Parameter(torch.eye(50, 54)) 51 | 52 | if chkpt is not None: 53 | pass 54 | 55 | def forward(self, clip, vae): 56 | 57 | vae = self.pool(vae) # 1 4 80 64 --> 1 4 40 32 58 | vae = rearrange(vae, 'b c h w -> b c (h w)') # 1 4 20 16 --> 1 4 1280 59 | 60 | vae = self.vae2clip(vae) # 1 4 768 61 | 62 | # Concatenate 63 | concat = torch.cat((clip, vae), 1) 64 | 65 | # Encode 66 | 67 | concat = rearrange(concat, 'b c d -> b d c') 68 | concat = self.linear1(concat) 69 | concat = rearrange(concat, 'b d c -> b c d') 70 | 71 | return concat 72 | -------------------------------------------------------------------------------- /pipelines/dual_encoder_pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Callable, List, Optional, Union 3 | from einops import rearrange 4 | 5 | import numpy as np 6 | import torch, torchvision 7 | import torch.nn.functional as F 8 | from torch.cuda.amp import autocast 9 | from torchvision import transforms 10 | from torchvision.utils import make_grid 11 | 12 | import PIL 13 | from diffusers.utils import is_accelerate_available 14 | from packaging import version 15 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPProcessor 16 | 17 | from 
diffusers.configuration_utils import FrozenDict 18 | from diffusers import AutoencoderKL, UNet2DConditionModel 19 | from diffusers import DiffusionPipeline 20 | from diffusers import ( 21 | DDIMScheduler, 22 | DPMSolverMultistepScheduler, 23 | EulerAncestralDiscreteScheduler, 24 | EulerDiscreteScheduler, 25 | LMSDiscreteScheduler, 26 | PNDMScheduler, 27 | ) 28 | 29 | from diffusers.utils import PIL_INTERPOLATION, deprecate, logging 30 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 31 | from models.unet_dual_encoder import get_unet, Embedding_Adapter 32 | 33 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 34 | 35 | 36 | def preprocess(image): 37 | if isinstance(image, torch.Tensor): 38 | return image 39 | elif isinstance(image, PIL.Image.Image): 40 | image = [image] 41 | 42 | if isinstance(image[0], PIL.Image.Image): 43 | w, h = image[0].size 44 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 45 | 46 | image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] 47 | image = np.concatenate(image, axis=0) 48 | image = np.array(image).astype(np.float32) / 255.0 49 | image = image.transpose(0, 3, 1, 2) 50 | image = 2.0 * image - 1.0 51 | image = torch.from_numpy(image) 52 | elif isinstance(image[0], torch.Tensor): 53 | image = torch.cat(image, dim=0) 54 | return image 55 | 56 | 57 | class StableDiffusionImg2ImgPipeline(DiffusionPipeline): 58 | r""" 59 | Pipeline for text-guided image to image generation using Stable Diffusion. 60 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 61 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 62 | Args: 63 | vae ([`AutoencoderKL`]): 64 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 65 | text_encoder ([`CLIPTextModel`]): 66 | Frozen text-encoder. Stable Diffusion uses the text portion of 67 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 68 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 69 | tokenizer (`CLIPTokenizer`): 70 | Tokenizer of class 71 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 72 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 73 | scheduler ([`SchedulerMixin`]): 74 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 75 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 76 | safety_checker ([`StableDiffusionSafetyChecker`]): 77 | Classification module that estimates whether generated images could be considered offensive or harmful. 78 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 79 | feature_extractor ([`CLIPFeatureExtractor`]): 80 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
81 | """ 82 | _optional_components = ["safety_checker"] 83 | 84 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ 85 | def __init__( 86 | self, 87 | #adapter: Embedding_Adapter, 88 | vae: AutoencoderKL, 89 | image_encoder: CLIPVisionModel, 90 | clip_processor: CLIPProcessor, 91 | unet: UNet2DConditionModel, 92 | scheduler: Union[ 93 | DDIMScheduler, 94 | PNDMScheduler, 95 | LMSDiscreteScheduler, 96 | EulerDiscreteScheduler, 97 | EulerAncestralDiscreteScheduler, 98 | DPMSolverMultistepScheduler, 99 | ], 100 | safety_checker: None, 101 | feature_extractor: CLIPFeatureExtractor, 102 | requires_safety_checker: bool = False, 103 | stochastic_sampling: bool = False 104 | ): 105 | super().__init__() 106 | 107 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 108 | deprecation_message = ( 109 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 110 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 111 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 112 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 113 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 114 | " file" 115 | ) 116 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 117 | new_config = dict(scheduler.config) 118 | new_config["steps_offset"] = 1 119 | scheduler._internal_dict = FrozenDict(new_config) 120 | 121 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 122 | deprecation_message = ( 123 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 124 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 125 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 126 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 127 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 128 | ) 129 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 130 | new_config = dict(scheduler.config) 131 | new_config["clip_sample"] = False 132 | scheduler._internal_dict = FrozenDict(new_config) 133 | 134 | if safety_checker is None and requires_safety_checker: 135 | logger.warning( 136 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 137 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 138 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 139 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 140 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 141 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 142 | ) 143 | 144 | if safety_checker is not None and feature_extractor is None: 145 | raise ValueError( 146 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 147 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
148 | ) 149 | 150 | self.adapter = Embedding_Adapter().cuda() 151 | 152 | self.register_modules( 153 | #adapter=self.adapter, 154 | vae=vae, 155 | image_encoder=image_encoder, 156 | clip_processor=clip_processor, 157 | unet=unet, 158 | scheduler=scheduler, 159 | safety_checker=safety_checker, 160 | feature_extractor=feature_extractor, 161 | ) 162 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 163 | self.clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") 164 | 165 | self.vae = self.vae.cuda() 166 | self.unet = self.unet.cuda() 167 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 168 | self.register_to_config(requires_safety_checker=requires_safety_checker) 169 | 170 | self.fixed_noise = None 171 | self.stochastic_sampling = stochastic_sampling 172 | 173 | print("Stochastic Sampling: ", self.stochastic_sampling) 174 | 175 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload 176 | def enable_sequential_cpu_offload(self, gpu_id=0): 177 | r""" 178 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 179 | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a 180 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 181 | """ 182 | if is_accelerate_available(): 183 | from accelerate import cpu_offload 184 | else: 185 | raise ImportError("Please install accelerate via `pip install accelerate`") 186 | 187 | device = torch.device(f"cuda:{gpu_id}") 188 | 189 | for cpu_offloaded_model in [self.unet, self.image_encoder, self.clip_processor, self.vae, self.adapter]: 190 | if cpu_offloaded_model is not None: 191 | cpu_offload(cpu_offloaded_model, device) 192 | 193 | if self.safety_checker is not None: 194 | # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate 195 | # fix by only offloading self.safety_checker for now 196 | cpu_offload(self.safety_checker.vision_model, device) 197 | 198 | @property 199 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device 200 | def _execution_device(self): 201 | r""" 202 | Returns the device on which the pipeline's models will be executed. After calling 203 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 204 | hooks. 205 | """ 206 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 207 | return self.device 208 | for module in self.unet.modules(): 209 | if ( 210 | hasattr(module, "_hf_hook") 211 | and hasattr(module._hf_hook, "execution_device") 212 | and module._hf_hook.execution_device is not None 213 | ): 214 | return torch.device(module._hf_hook.execution_device) 215 | return self.device 216 | 217 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 218 | def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): 219 | r""" 220 | Encodes the prompt into text encoder hidden states. 
221 | Args: 222 | image (`str` or `list(int)`): 223 | prompt to be encoded 224 | device: (`torch.device`): 225 | torch device 226 | num_images_per_prompt (`int`): 227 | number of images that should be generated per prompt 228 | do_classifier_free_guidance (`bool`): 229 | whether to use classifier free guidance or not 230 | negative_prompt (`str` or `List[str]`): 231 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 232 | if `guidance_scale` is less than `1`). 233 | """ 234 | batch_size = len(image) if isinstance(image, list) else 1 235 | #print("Batch size = ", batch_size) 236 | 237 | if isinstance(image, list): 238 | uncond_image = [torch.zeros((image[0].size[0], image[0].size[1], 3)) for _ in range(batch_size)] 239 | else: 240 | image = [image] 241 | uncond_image = [torch.zeros((image[0].size[0], image[0].size[1], 3))] 242 | 243 | with autocast(): 244 | # clip encoder 245 | inputs = self.processor(images=image, return_tensors="pt") 246 | clip_image_embeddings = self.clip_encoder(**inputs).last_hidden_state.cuda() 247 | 248 | uncond_inputs = self.processor(images=uncond_image, return_tensors="pt") 249 | clip_uncond_image_embeddings = self.clip_encoder(**uncond_inputs).last_hidden_state.cuda() 250 | 251 | # vae encoder 252 | #image = torch.from_numpy(np.array(image).transpose((2,0,1))) 253 | #image = image.cuda().float().unsqueeze(0) 254 | image_tensor = torch.tensor([np.array(im).transpose((2,0,1)) for im in image]).cuda().float() 255 | vae_image_embeddings = self.vae.encode(image_tensor).latent_dist.sample() * 0.18215 256 | #vae_image_embeddings = rearrange(image_embeddings, 'b h w c -> b (h w) c') 257 | 258 | #img_shape = image.shape 259 | #uncond_image = torch.zeros(img_shape).cuda().float() 260 | #uncond_image = uncond_image.cuda().float() 261 | uncond_image_tensor = torch.tensor([np.array(im).transpose((2,0,1)) for im in uncond_image]).cuda().float() 262 | vae_uncond_image_embeddings = self.vae.encode(uncond_image_tensor).latent_dist.sample() * 0.18215 263 | #vae_uncond_image_embeddings = rearrange(uncond_image_embeddings, 'b h w c -> b (h w) c') 264 | 265 | # adapt embeddings 266 | image_embeddings = self.adapter(clip_image_embeddings, vae_image_embeddings) 267 | uncond_image_embeddings = self.adapter(clip_uncond_image_embeddings, vae_uncond_image_embeddings) 268 | 269 | #print(image_embeddings.shape) 270 | # duplicate text embeddings for each generation per prompt, using mps friendly method 271 | bs_embed, seq_len, _ = image_embeddings.shape 272 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) 273 | image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 274 | 275 | bs_embed, seq_len, _ = uncond_image_embeddings .shape 276 | uncond_image_embeddings = uncond_image_embeddings.repeat(1, num_images_per_prompt, 1) 277 | uncond_image_embeddings = uncond_image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 278 | 279 | # get unconditional embeddings for classifier free guidance 280 | if do_classifier_free_guidance: 281 | image_embeddings = torch.cat([uncond_image_embeddings, image_embeddings, image_embeddings]) 282 | 283 | return image_embeddings 284 | 285 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 286 | def run_safety_checker(self, image, device, dtype): 287 | if self.safety_checker is not None: 288 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), 
return_tensors="pt").to(device) 289 | image, has_nsfw_concept = self.safety_checker( 290 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 291 | ) 292 | else: 293 | has_nsfw_concept = None 294 | return image, has_nsfw_concept 295 | 296 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 297 | def decode_latents(self, latents): 298 | with autocast(): 299 | latents = 1 / 0.18215 * latents 300 | image = self.vae.decode(latents).sample 301 | image = (image / 2 + 0.5).clamp(0, 1) 302 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 303 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 304 | return image 305 | 306 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 307 | def prepare_extra_step_kwargs(self, generator, eta): 308 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 309 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 310 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 311 | # and should be between [0, 1] 312 | 313 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 314 | extra_step_kwargs = {} 315 | if accepts_eta: 316 | extra_step_kwargs["eta"] = eta 317 | 318 | # check if the scheduler accepts generator 319 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 320 | if accepts_generator: 321 | extra_step_kwargs["generator"] = generator 322 | return extra_step_kwargs 323 | 324 | def check_inputs(self, prompt, strength, callback_steps): 325 | if not isinstance(prompt, str) and not isinstance(prompt, list): 326 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 327 | 328 | if strength < 0 or strength > 1: 329 | raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") 330 | 331 | if (callback_steps is None) or ( 332 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 333 | ): 334 | raise ValueError( 335 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 336 | f" {type(callback_steps)}." 337 | ) 338 | 339 | def get_timesteps(self, num_inference_steps, strength, device): 340 | # get the original timestep using init_timestep 341 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 342 | 343 | t_start = max(num_inference_steps - init_timestep, 0) 344 | timesteps = self.scheduler.timesteps[t_start:] 345 | 346 | return timesteps, num_inference_steps - t_start 347 | 348 | def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): 349 | with autocast(): 350 | image = image.to(device=device, dtype=dtype).cuda() 351 | init_latent_dist = self.vae.encode(image).latent_dist 352 | init_latents = init_latent_dist.sample(generator=generator) 353 | init_latents = 0.18215 * init_latents 354 | 355 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: 356 | # expand init_latents for batch_size 357 | deprecation_message = ( 358 | f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" 359 | " images (`image`). Initial images are now duplicating to match the number of text prompts. 
Note" 360 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 361 | " your script to pass as many initial images as text prompts to suppress this warning." 362 | ) 363 | deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) 364 | additional_image_per_prompt = batch_size // init_latents.shape[0] 365 | init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) 366 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: 367 | raise ValueError( 368 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 369 | ) 370 | else: 371 | init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) 372 | 373 | # add noise to latents using the timesteps 374 | if self.fixed_noise is None: 375 | #print("Latents Shape = ", init_latents.shape, init_latents[0].shape, image.shape[0]) 376 | single_fixed_noise = torch.randn(init_latents[0].shape, generator=generator, device=device, dtype=dtype) 377 | self.fixed_noise = single_fixed_noise.repeat(image.shape[0], 1, 1, 1)#torch.tensor([single_fixed_noise for _ in range(image.shape[0])]) 378 | noise = self.fixed_noise 379 | 380 | # get latents 381 | init_latents = self.scheduler.add_noise(init_latents.cuda(), noise.cuda(), timestep) 382 | latents = init_latents 383 | 384 | return latents 385 | 386 | @torch.no_grad() 387 | def __call__( 388 | self, 389 | prompt: Union[str, List[str]], 390 | adapter = None, 391 | prev_image: Union[torch.FloatTensor, PIL.Image.Image] = None, 392 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 393 | pose: Union[torch.FloatTensor, PIL.Image.Image] = None, 394 | strength: float = 1.0, 395 | num_inference_steps: Optional[int] = 100, 396 | guidance_scale: Optional[float] = 7.5, 397 | s1: float = 1.0, # strength of input pose 398 | s2: float = 1.0, # strength of input image 399 | negative_prompt: Optional[Union[str, List[str]]] = None, 400 | num_images_per_prompt: Optional[int] = 1, 401 | eta: Optional[float] = 0.0, 402 | generator: Optional[torch.Generator] = None, 403 | output_type: Optional[str] = "pil", 404 | return_dict: bool = True, 405 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 406 | callback_steps: Optional[int] = 1, 407 | frames = [], 408 | sweep = False, 409 | **kwargs, 410 | ): 411 | r""" 412 | Function invoked when calling the pipeline for generation. 413 | Args: 414 | prompt (`str` or `List[str]`): 415 | The prompt or prompts to guide the image generation. 416 | image (`torch.FloatTensor` or `PIL.Image.Image`): 417 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 418 | process. 419 | strength (`float`, *optional*, defaults to 0.8): 420 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` 421 | will be used as a starting point, adding more noise to it the larger the `strength`. The number of 422 | denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will 423 | be maximum and the denoising process will run for the full number of iterations specified in 424 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. 425 | num_inference_steps (`int`, *optional*, defaults to 50): 426 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 427 | expense of slower inference. 
This parameter will be modulated by `strength`. 428 | guidance_scale (`float`, *optional*, defaults to 7.5): 429 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 430 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 431 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 432 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 433 | usually at the expense of lower image quality. 434 | negative_prompt (`str` or `List[str]`, *optional*): 435 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 436 | if `guidance_scale` is less than `1`). 437 | num_images_per_prompt (`int`, *optional*, defaults to 1): 438 | The number of images to generate per prompt. 439 | eta (`float`, *optional*, defaults to 0.0): 440 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 441 | [`schedulers.DDIMScheduler`], will be ignored for others. 442 | generator (`torch.Generator`, *optional*): 443 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 444 | deterministic. 445 | output_type (`str`, *optional*, defaults to `"pil"`): 446 | The output format of the generate image. Choose between 447 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 448 | return_dict (`bool`, *optional*, defaults to `True`): 449 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 450 | plain tuple. 451 | callback (`Callable`, *optional*): 452 | A function that will be called every `callback_steps` steps during inference. The function will be 453 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 454 | callback_steps (`int`, *optional*, defaults to 1): 455 | The frequency at which the `callback` function will be called. If not specified, the callback will be 456 | called at every step. 457 | Returns: 458 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 459 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 460 | When returning a tuple, the first element is a list with the generated images, and the second element is a 461 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 462 | (nsfw) content, according to the `safety_checker`. 463 | """ 464 | 465 | # 1. Check inputs 466 | self.check_inputs(prompt, strength, callback_steps) 467 | 468 | # 2. Set adapter 469 | if adapter is not None: 470 | print("Setting adapter") 471 | self.adapter = adapter 472 | 473 | # 3. Define call parameters 474 | batch_size = 1 if isinstance(prompt, str) else len(prompt) 475 | device = self._execution_device 476 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 477 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 478 | # corresponds to doing no classifier free guidance. 479 | do_classifier_free_guidance = guidance_scale > 1.0 or s1 > 0.0 or s2 > 0.0 480 | 481 | # 4. Encode input image: [unconditional, condional, conditional] 482 | embeddings = self._encode_image( 483 | image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt 484 | ) 485 | 486 | # 5. 
Preprocess image 487 | image = preprocess(image) 488 | pose = preprocess(pose) 489 | image, pose = torch.tensor(image), torch.tensor(pose) 490 | 491 | # 6. Set timesteps 492 | self.scheduler.set_timesteps(num_inference_steps, device=device) 493 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) 494 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 495 | 496 | # 7. Prepare latent variables 497 | latents = self.prepare_latents( 498 | image, latent_timestep, batch_size, num_images_per_prompt, embeddings.dtype, device, generator 499 | ) 500 | 501 | # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 502 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 503 | 504 | # 9. If sweeping (s1, s2) values, prepare variables 505 | if sweep: 506 | s1_vals = [0, 3, 5, 7, 9] 507 | s2_vals = [0, 3, 5, 7, 9] 508 | images = [] # store frames 509 | else: 510 | s1_vals, s2_vals = [s1], [s2] 511 | 512 | # 10. Denoising loop 513 | copy_latents = latents.clone() 514 | for s1 in s1_vals: 515 | for s2 in s2_vals: 516 | latents = copy_latents.clone() 517 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 518 | with self.progress_bar(total=num_inference_steps) as progress_bar: 519 | for i, t in enumerate(timesteps): 520 | t = t.cuda() 521 | 522 | # If stochastic sampling enabled, randomly select from previous images 523 | if self.stochastic_sampling: 524 | idx = np.random.choice(range(len(frames))) 525 | input_image = [frames[idx] for _ in range(len(image))] 526 | #print("Selecting conditioning image #", idx) 527 | embeddings = self._encode_image( 528 | input_image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt 529 | ) 530 | input_image = preprocess(input_image) 531 | 532 | # expand the latents if we are doing classifier free guidance 533 | latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents 534 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 535 | 536 | # Add pose to noisy latents 537 | _, _, h, w = latent_model_input.shape 538 | if do_classifier_free_guidance: 539 | pose_input = torch.cat([torch.zeros(pose.shape), pose, torch.zeros(pose.shape)]) 540 | else: 541 | pose_input = torch.cat([pose, pose, pose]) 542 | latent_model_input = torch.cat((latent_model_input.cuda(), F.interpolate(pose_input, (h,w)).cuda()), 1) 543 | 544 | # predict the noise residual 545 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=embeddings.cuda()).sample 546 | 547 | # perform guidance 548 | if do_classifier_free_guidance: 549 | #print(f"s1={s1}, s2={s2}") 550 | noise_pred_uncond, noise_pred_, noise_pred_img_only = noise_pred.chunk(3) 551 | noise_pred = noise_pred_uncond + \ 552 | s1 * (noise_pred_img_only - noise_pred_uncond) + \ 553 | s2 * (noise_pred_ - noise_pred_img_only) 554 | 555 | # compute the previous noisy sample x_t -> x_t-1 556 | latents = self.scheduler.step(noise_pred.cuda(), t, latents.cuda(), **extra_step_kwargs).prev_sample 557 | 558 | # call the callback, if provided 559 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 560 | progress_bar.update() 561 | if callback is not None and i % callback_steps == 0: 562 | callback(i, t, latents) 563 | 564 | # 11. 
Post-processing 565 | latents = latents[:,:4, :, :].cuda() #.float() 566 | image = self.decode_latents(latents) 567 | 568 | #print(len(image)) # 1 569 | #print(image[0].shape) # 640, 512, 3 570 | 571 | # 13. Convert to PIL 572 | if output_type == "pil": 573 | image = self.numpy_to_pil(image) 574 | 575 | if sweep: 576 | images.append(torchvision.transforms.ToTensor()(image[0]).clone()) 577 | 578 | # 13. If sweeping, convert images to grid 579 | if sweep: 580 | Grid = make_grid(images, nrow=len(s2_vals)) 581 | image = [torchvision.transforms.ToPILImage()(Grid)] 582 | #image = Grid 583 | #print("Grid complete.") 584 | 585 | # 14. Run safety checker 586 | #image, has_nsfw_concept = self.run_safety_checker(image, device, embeddings.dtype) 587 | 588 | if not return_dict: 589 | return (image, False) 590 | 591 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False) 592 | 593 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from diffusers import UNet2DConditionModel, DDIMScheduler 4 | from pipelines.dual_encoder_pipeline import StableDiffusionImg2ImgPipeline 5 | import argparse 6 | from torchvision import transforms 7 | import torch 8 | import cv2, PIL, glob, random 9 | import numpy as np 10 | from torch.cuda.amp import autocast 11 | from torchvision import transforms 12 | from collections import OrderedDict 13 | from torch import nn 14 | import torch, cv2 15 | import torch.nn.functional as F 16 | from models.unet_dual_encoder import get_unet, Embedding_Adapter 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--folder", default='dreampose-1', help="Path to custom pretrained checkpoints folder.",) 20 | parser.add_argument("--pose_folder", default='../UBC_Fashion_Dataset/valid/91iZ9x8NI0S.mp4', help="Path to test frames, poses, and joints.",) 21 | parser.add_argument("--test_poses", default=None, help="Path to test frames, poses, and joints.",) 22 | parser.add_argument("--epoch", type=int, default=44, required=True, help="Pretrained custom model checkpoint epoch number.",) 23 | parser.add_argument("--key_frame_path", default='../UBC_Fashion_Dataset/dreampose/91iZ9x8NI0S.mp4/key_frame.png', help="Path to key frame.",) 24 | parser.add_argument("--pose_path", default='../UBC_Fashion_Dataset/valid/A1F1j+kNaDS.mp4/85_to_95_to_116/skeleton_i.npy', help="Pretrained model checkpoint step number.",) 25 | parser.add_argument("--strength", type=float, default=1.0, required=False, help="How much noise to add to input image.",) 26 | parser.add_argument("--s1", type=float, default=0.5, required=False, help="Classifier free guidance of input image.",) 27 | parser.add_argument("--s2", type=float, default=0.5, required=False, help="Classifier free guidance of input pose.",) 28 | parser.add_argument("--iters", default=1, type=int, help="# times to do stochastic sampling for all frames.") 29 | parser.add_argument("--sampler", default='PNDM', help="PNDM or DDIM.") 30 | parser.add_argument("--n_steps", default=100, type=int, help="Number of denoising steps.") 31 | parser.add_argument("--output_dir", default=None, help="Where to save results.") 32 | parser.add_argument("--j", type=int, default=-1, required=False, help="Specific frame number.",) 33 | parser.add_argument("--min_j", type=int, default=0, required=False, help="Lowest predicted frame id.",) 34 | parser.add_argument("--max_j", type=int, default=-1, 
required=False, help="Max predicted frame id.",) 35 | parser.add_argument("--custom_vae", default=None, help="Path use custom VAE checkpoint.") 36 | parser.add_argument("--batch_size", type=int, default=1, required=False, help="# frames to infer at once.",) 37 | args = parser.parse_args() 38 | 39 | save_folder = args.output_dir if args.output_dir is not None else args.folder #'results-fashion/' 40 | if not os.path.exists(save_folder): 41 | os.mkdir(save_folder) 42 | 43 | # Load custom model 44 | model_id = f"{args.folder}/checkpoint-{args.epoch}" #if args.step > 0 else "CompVis/stable-diffusion-v1-4" 45 | device = "cuda" 46 | 47 | # Load UNet 48 | unet = get_unet('CompVis/stable-diffusion-v1-4', "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c", resolution=512) 49 | unet_path = f"{args.folder}/unet_epoch_{args.epoch}.pth" 50 | print("Loading ", unet_path) 51 | unet_state_dict = torch.load(unet_path) 52 | new_state_dict = OrderedDict() 53 | for k, v in unet_state_dict.items(): 54 | name = k.replace('module.', '') #k[7:] if k[:7] == 'module' else k 55 | new_state_dict[name] = v 56 | unet.load_state_dict(new_state_dict) 57 | unet = unet.cuda() 58 | 59 | print("Loading custom model from: ", model_id) 60 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, revision="fp16") 61 | pipe.safety_checker = lambda images, clip_input: (images, False) # disable safety check 62 | 63 | #pipe.unet.load_state_dict(torch.load(f'{save_folder}/unet_epoch_{args.epoch}.pth')) #'results/epoch_1/unet.pth')) 64 | #pipe.unet = pipe.unet.cuda() 65 | 66 | adapter_chkpt = f'{args.folder}/adapter_{args.epoch}.pth' 67 | print("Loading ", adapter_chkpt) 68 | adapter_state_dict = torch.load(adapter_chkpt) 69 | new_state_dict = OrderedDict() 70 | for k, v in adapter_state_dict.items(): 71 | name = k.replace('module.', '') #name = k[7:] if k[:7] == 'module' else k 72 | new_state_dict[name] = v 73 | print(pipe.adapter.linear1.weight) 74 | pipe.adapter = Embedding_Adapter() 75 | pipe.adapter.load_state_dict(new_state_dict) 76 | print(pipe.adapter.linear1.weight) 77 | pipe.adapter = pipe.adapter.cuda() 78 | 79 | if args.custom_vae is not None: 80 | vae_chkpt = args.custom_vae 81 | print("Loading custom vae checkpoint from ", vae_chkpt, '...') 82 | vae_state_dict = torch.load(vae_chkpt) 83 | new_state_dict = OrderedDict() 84 | for k, v in vae_state_dict.items(): 85 | name = k.replace('module.', '') #name = k[7:] if k[:7] == 'module' else k 86 | new_state_dict[name] = v 87 | pipe.vae.load_state_dict(new_state_dict) 88 | pipe.vae = pipe.vae.cuda() 89 | 90 | # Change scheduler 91 | if args.sampler == 'DDIM': 92 | print("Default scheduler = ", pipe.scheduler) 93 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 94 | print("New scheduler = ", pipe.scheduler) 95 | 96 | def visualize_dp(im, dp): 97 | #im = im.transpose((2, 0, 1)) 98 | print(im.shape, dp.shape) 99 | hsv = np.zeros(im.shape, dtype=np.uint8) 100 | hsv[..., 1] = 255 101 | 102 | dp = dp.cpu().detach().numpy() 103 | mag, ang = cv2.cartToPolar(dp[0], dp[1]) 104 | hsv[..., 0] = ang * 180 / np.pi / 2 105 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 106 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 107 | 108 | return bgr 109 | 110 | n_images_per_sample = 1 111 | 112 | frame_numbers = sorted([int(path.split('frame_')[-1].replace('_densepose.npy', '')) for path in glob.glob(f'{args.pose_folder}/frame_*.npy')]) 113 | frame_numbers = list(set(frame_numbers)) 114 | pose_paths = 
[f'{args.pose_folder}/frame_{num}_densepose.npy' for num in frame_numbers] 115 | 116 | if args.max_j > -1: 117 | pose_paths = pose_paths[args.min_j:args.max_j] 118 | else: 119 | pose_paths = pose_paths[args.min_j:] 120 | 121 | imSize = (512, 640) 122 | image_transforms = transforms.Compose( 123 | [ 124 | transforms.Resize(imSize, interpolation=transforms.InterpolationMode.BILINEAR), 125 | transforms.ToTensor(), 126 | transforms.Normalize([0.5], [0.5]), 127 | ] 128 | ) 129 | tensor_transforms = transforms.Compose( 130 | [ 131 | transforms.Normalize([0.5], [0.5]), 132 | ] 133 | ) 134 | 135 | # Load key frame 136 | input_image = PIL.Image.open(args.key_frame_path).resize(imSize) 137 | 138 | if args.j >= 0: 139 | j = args.j 140 | pose_paths = pose_paths[j:j+1] 141 | 142 | # Iterate samples 143 | prev_image = input_image 144 | for i, pose_path in enumerate(pose_paths): 145 | frame_number = int(frame_numbers[i]) 146 | h, w = imSize[1], imSize[0] 147 | 148 | # construct 5 input poses 149 | poses = [] 150 | for pose_number in range(frame_number-2, frame_number+3): 151 | dp_path = pose_path.replace(str(frame_number), str(pose_number)) 152 | if not os.path.exists(dp_path): 153 | dp_path = pose_path 154 | print(dp_path) 155 | dp_i = F.interpolate(torch.from_numpy(np.load(dp_path).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0) 156 | poses.append(tensor_transforms(dp_i)) 157 | input_pose = torch.cat(poses, 0).unsqueeze(0) 158 | 159 | print(pose_path.split('_')) 160 | j = int(pose_path.split('_')[-2]) 161 | print("j = ", j) 162 | 163 | with autocast(): 164 | image = pipe(prompt="", 165 | image=input_image, 166 | pose=input_pose, 167 | strength=1.0, 168 | num_inference_steps=args.n_steps, 169 | guidance_scale=7.5, 170 | s1=args.s1, 171 | s2=args.s2, 172 | callback_steps=1, 173 | frames=[] 174 | )[0][0] 175 | 176 | 177 | 178 | # Save pose and image 179 | save_path = f"{save_folder}/pred_#{j}.png" 180 | image = image.convert('RGB') 181 | image = np.array(image) 182 | image = image - np.min(image) 183 | image = (255*(image / np.max(image))).astype(np.uint8) 184 | cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import math 5 | import os 6 | import random 7 | from pathlib import Path 8 | from typing import Optional 9 | from einops import rearrange 10 | from collections import OrderedDict 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | from torch.utils.data import Dataset 17 | import torch.nn as nn 18 | import numpy as np 19 | import cv2 20 | 21 | from accelerate import Accelerator 22 | from accelerate.logging import get_logger 23 | from accelerate.utils import set_seed 24 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 25 | from diffusers.optimization import get_scheduler 26 | from huggingface_hub import HfFolder, Repository, whoami 27 | from PIL import Image 28 | from torchvision import transforms 29 | from tqdm.auto import tqdm 30 | from transformers import CLIPFeatureExtractor, CLIPTokenizer, CLIPProcessor, CLIPVisionModel 31 | 32 | from torch.utils.tensorboard import SummaryWriter 33 | 34 | logger = get_logger(__name__) 35 | 36 | from utils.parse_args import parse_args 37 | from 
datasets.ubc_deepfashion_dataset import DreamPoseDataset 38 | from pipelines.dual_encoder_pipeline import StableDiffusionImg2ImgPipeline 39 | from models.unet_dual_encoder import get_unet, Embedding_Adapter 40 | 41 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 42 | if token is None: 43 | token = HfFolder.get_token() 44 | if organization is None: 45 | username = whoami(token)["name"] 46 | return f"{username}/{model_id}" 47 | else: 48 | return f"{organization}/{model_id}" 49 | 50 | def main(args): 51 | logging_dir = Path(args.output_dir, args.logging_dir) 52 | 53 | writer = SummaryWriter(f'results/logs/{args.run_name}') 54 | 55 | accelerator = Accelerator( 56 | gradient_accumulation_steps=args.gradient_accumulation_steps, 57 | mixed_precision=args.mixed_precision, 58 | log_with="tensorboard", 59 | logging_dir=logging_dir, 60 | ) 61 | 62 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 63 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 64 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 65 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 66 | raise ValueError( 67 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 68 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 69 | ) 70 | 71 | if args.seed is not None: 72 | set_seed(args.seed) 73 | 74 | # Handle the repository creation 75 | if accelerator.is_main_process: 76 | if args.push_to_hub: 77 | if args.hub_model_id is None: 78 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 79 | else: 80 | repo_name = args.hub_model_id 81 | repo = Repository(args.output_dir, clone_from=repo_name) 82 | 83 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 84 | if "step_*" not in gitignore: 85 | gitignore.write("step_*\n") 86 | if "epoch_*" not in gitignore: 87 | gitignore.write("epoch_*\n") 88 | elif args.output_dir is not None: 89 | os.makedirs(args.output_dir, exist_ok=True) 90 | 91 | # Load CLIP Image Encoder 92 | clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").cuda() 93 | clip_encoder.requires_grad_(False) 94 | clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 95 | 96 | # Load models and create wrapper for stable diffusion 97 | vae = AutoencoderKL.from_pretrained( 98 | "CompVis/stable-diffusion-v1-4", 99 | subfolder="vae", 100 | revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c" 101 | ) 102 | 103 | # Load pretrained UNet layers 104 | unet = get_unet3(args.pretrained_model_name_or_path, args.revision, resolution=args.resolution) 105 | #unet = get_unet('CompVis/stable-diffusion-v1-4') 106 | 107 | if args.custom_chkpt is not None: 108 | print("Loading ", args.custom_chkpt) 109 | unet_state_dict = torch.load(args.custom_chkpt) 110 | new_state_dict = OrderedDict() 111 | for k, v in unet_state_dict.items(): 112 | name = k[7:] if k[:7] == 'module' else k 113 | new_state_dict[name] = v 114 | unet.load_state_dict(new_state_dict) 115 | unet = unet.cuda() 116 | 117 | # Embedding adapter 118 | adapter = Embedding_Adapter(input_nc=1280, output_nc=1280) 119 | 120 | if args.custom_chkpt is not None: 121 | adapter_chkpt = 
args.custom_chkpt.replace('unet_epoch', 'adapter') 122 | print("Loading ", adapter_chkpt) 123 | adapter_state_dict = torch.load(adapter_chkpt) 124 | new_state_dict = OrderedDict() 125 | for k, v in adapter_state_dict.items(): 126 | name = k[7:] if k[:7] == 'module' else k 127 | new_state_dict[name] = v 128 | adapter.load_state_dict(new_state_dict) 129 | adapter = adapter.cuda() 130 | 131 | #adapter.requires_grad_(True) 132 | 133 | vae.requires_grad_(False) 134 | 135 | if args.gradient_checkpointing: 136 | unet.enable_gradient_checkpointing() 137 | adapter.enable_gradient_checkpointing() 138 | 139 | if args.scale_lr: 140 | args.learning_rate = ( 141 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 142 | ) 143 | 144 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 145 | if args.use_8bit_adam: 146 | try: 147 | import bitsandbytes as bnb 148 | except ImportError: 149 | raise ImportError( 150 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 151 | ) 152 | 153 | optimizer_class = bnb.optim.AdamW8bit 154 | else: 155 | optimizer_class = torch.optim.AdamW 156 | 157 | params_to_optimize = ( 158 | itertools.chain(unet.parameters(), adapter.parameters(),) 159 | ) 160 | 161 | optimizer = optimizer_class( 162 | params_to_optimize, 163 | lr=args.learning_rate, 164 | betas=(args.adam_beta1, args.adam_beta2), 165 | weight_decay=args.adam_weight_decay, 166 | eps=args.adam_epsilon, 167 | ) 168 | 169 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 170 | 171 | # Load the tokenizer 172 | if args.tokenizer_name: 173 | tokenizer = CLIPTokenizer.from_pretrained( 174 | args.tokenizer_name, 175 | revision=args.revision, 176 | ) 177 | elif args.pretrained_model_name_or_path: 178 | tokenizer = CLIPTokenizer.from_pretrained( 179 | args.pretrained_model_name_or_path, 180 | subfolder="tokenizer", 181 | revision=args.revision, 182 | ) 183 | 184 | train_dataset = DreamPoseDataset( 185 | instance_data_root=args.instance_data_dir, 186 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 187 | class_prompt=args.class_prompt, 188 | size=args.resolution, 189 | center_crop=args.center_crop, 190 | ) 191 | 192 | def collate_fn(examples): 193 | frame_i = [example["frame_i"] for example in examples] 194 | frame_j = [example["frame_j"] for example in examples] 195 | poses = [example["pose_j"] for example in examples] 196 | 197 | # Concat class and instance examples for prior preservation. 198 | # We do this to avoid doing two forward passes. 
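To make the comment above concrete: when prior preservation is enabled, class examples are appended to the instance examples so that a single UNet forward pass covers both, and the prediction is later split back into two halves whose losses are combined (as done further down in this script). A minimal, slightly simplified sketch of that pattern, with illustrative variable names rather than the script's own:

```
import torch
import torch.nn.functional as F

def prior_preserving_loss(model_pred, target, prior_loss_weight=1.0):
    # The batch is [instance examples; class examples] stacked along dim 0,
    # so one forward pass produced predictions for both halves.
    pred_instance, pred_prior = torch.chunk(model_pred, 2, dim=0)
    target_instance, target_prior = torch.chunk(target, 2, dim=0)

    instance_loss = F.mse_loss(pred_instance.float(), target_instance.float(), reduction="mean")
    prior_loss = F.mse_loss(pred_prior.float(), target_prior.float(), reduction="mean")
    return instance_loss + prior_loss_weight * prior_loss
```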
199 | if args.with_prior_preservation: 200 | input_ids += [example["class_prompt_ids"] for example in examples] 201 | frame_i += [example["class_frame_i"] for example in examples] 202 | frame_j += [example["class_frame_j"] for example in examples] 203 | poses += [example["class_pose_j"] for example in examples] 204 | 205 | frame_i = torch.cat(frame_i, 0) 206 | frame_j = torch.cat(frame_j, 0) 207 | poses = torch.cat(poses, 0) 208 | 209 | # Dropout 210 | p = random.random() 211 | if p <= args.dropout_rate / 3: # dropout pose 212 | poses = torch.zeros(poses.shape) 213 | elif p <= 2*args.dropout_rate / 3: # dropout image 214 | frame_i = torch.zeros(frame_i.shape) 215 | #frame_k = torch.zeros(frame_k.shape) 216 | elif p <= args.dropout_rate: # dropout image and pose 217 | poses = torch.zeros(poses.shape) 218 | frame_i = torch.zeros(frame_i.shape) 219 | #frame_k = torch.zeros(frame_k.shape) 220 | 221 | frame_i = frame_i.to(memory_format=torch.contiguous_format).float() 222 | frame_j = frame_j.to(memory_format=torch.contiguous_format).float() 223 | #frame_k = frame_k.to(memory_format=torch.contiguous_format).float() 224 | poses = poses.to(memory_format=torch.contiguous_format).float() 225 | #joints = joints.to(memory_format=torch.contiguous_format).float() 226 | 227 | batch = { 228 | "frame_i": frame_i, 229 | "frame_j": frame_j, 230 | #"frame_k": frame_k, 231 | "poses": poses, 232 | #"joints_j": joints, 233 | } 234 | return batch 235 | 236 | train_dataloader = torch.utils.data.DataLoader( 237 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 238 | ) 239 | 240 | # Scheduler and math around the number of training steps. 241 | overrode_max_train_steps = False 242 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 243 | if args.max_train_steps is None: 244 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 245 | overrode_max_train_steps = True 246 | 247 | lr_scheduler = get_scheduler( 248 | args.lr_scheduler, 249 | optimizer=optimizer, 250 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 251 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 252 | ) 253 | 254 | if args.train_text_encoder: 255 | unet, adapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 256 | unet, adapter, optimizer, train_dataloader, lr_scheduler 257 | ) 258 | else: 259 | unet, adapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 260 | unet, adapter, optimizer, train_dataloader, lr_scheduler 261 | ) 262 | 263 | weight_dtype = torch.float32 264 | if accelerator.mixed_precision == "fp16": 265 | weight_dtype = torch.float16 266 | elif accelerator.mixed_precision == "bf16": 267 | weight_dtype = torch.bfloat16 268 | 269 | # Move text_encode and vae to gpu. 270 | # For mixed precision training we cast the image_encoder and vae weights to half-precision 271 | # as these models are only used for inference, keeping weights in full precision is not required. 272 | vae.to(accelerator.device, dtype=weight_dtype) 273 | 274 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
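One detail worth calling out from the collate_fn above is the conditioning dropout: with probability dropout_rate/3 the pose is zeroed, with another dropout_rate/3 the conditioning image is zeroed, and with another dropout_rate/3 both are zeroed, which supports the three-way classifier-free guidance used in the pipeline's denoising loop. A standalone sketch of that scheme, where the function name and default rate are illustrative assumptions:

```
import random
import torch

def conditioning_dropout(frame_i, poses, dropout_rate=0.15):
    # dropout_rate is split evenly three ways: pose only, image only, or both.
    p = random.random()
    if p <= dropout_rate / 3:            # zero out the pose
        poses = torch.zeros_like(poses)
    elif p <= 2 * dropout_rate / 3:      # zero out the conditioning image
        frame_i = torch.zeros_like(frame_i)
    elif p <= dropout_rate:              # zero out both
        poses = torch.zeros_like(poses)
        frame_i = torch.zeros_like(frame_i)
    return frame_i, poses
```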
275 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 276 | if overrode_max_train_steps: 277 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 278 | # Afterwards we recalculate our number of training epochs 279 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 280 | 281 | # We need to initialize the trackers we use, and also store our configuration. 282 | # The trackers initializes automatically on the main process. 283 | if accelerator.is_main_process: 284 | accelerator.init_trackers("dreambooth", config=vars(args)) 285 | 286 | # Train! 287 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 288 | 289 | logger.info("***** Running training *****") 290 | logger.info(f" Num examples = {len(train_dataset)}") 291 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 292 | logger.info(f" Num Epochs = {args.num_train_epochs}") 293 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 294 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 295 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 296 | logger.info(f" Total optimization steps = {args.max_train_steps}") 297 | # Only show the progress bar once on each machine. 298 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 299 | progress_bar.set_description("Steps") 300 | global_step = 0 301 | 302 | def latents2img(latents): 303 | latents = 1 / 0.18215 * latents 304 | images = vae.decode(latents).sample 305 | images = (images / 2 + 0.5).clamp(0, 1) 306 | images = images.detach().cpu().numpy() 307 | images = (images * 255).round().astype("uint8") 308 | return images 309 | 310 | def inputs2img(input): 311 | target_images = (input / 2 + 0.5).clamp(0, 1) 312 | target_images = target_images.detach().cpu().numpy() 313 | target_images = (target_images * 255).round().astype("uint8") 314 | return target_images 315 | 316 | def visualize_dp(im, dp): 317 | im = im.transpose((1,2,0)) 318 | hsv = np.zeros(im.shape, dtype=np.uint8) 319 | hsv[..., 1] = 255 320 | 321 | dp = dp.cpu().detach().numpy() 322 | mag, ang = cv2.cartToPolar(dp[0], dp[1]) 323 | hsv[..., 0] = ang * 180 / np.pi / 2 324 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 325 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 326 | 327 | bgr = bgr.transpose((2,0,1)) 328 | return bgr 329 | 330 | latest_chkpt_step = 0 331 | for epoch in range(args.epoch, args.num_train_epochs): 332 | unet.train() 333 | adapter.train() 334 | first_batch = True 335 | for step, batch in enumerate(train_dataloader): 336 | if first_batch and latest_chkpt_step is not None: 337 | #os.system(f"python test_img2img.py --step {latest_chkpt_step} --strength 0.8") 338 | first_batch = False 339 | with accelerator.accumulate(unet): 340 | # Convert images to latent space 341 | latents = vae.encode(batch["frame_j"].to(dtype=weight_dtype)).latent_dist.sample() 342 | latents = latents * 0.18215 343 | 344 | # Sample noise that we'll add to the latents 345 | noise = torch.randn_like(latents) 346 | bsz = latents.shape[0] 347 | 348 | # Sample a random timestep for each image 349 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 350 | timesteps = timesteps.long() 351 | 352 | # Add noise to the latents according to the noise magnitude at 
each timestep 353 | # (this is the forward diffusion process) 354 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 355 | 356 | # Concatenate pose with noise 357 | _, _, h, w = noisy_latents.shape 358 | noisy_latents = torch.cat((noisy_latents, F.interpolate(batch['poses'], (h,w))), 1) 359 | 360 | # Get CLIP embeddings 361 | inputs = clip_processor(images=list(batch['frame_i'].to(latents.device)), return_tensors="pt") 362 | inputs = {k: v.to(latents.device) for k, v in inputs.items()} 363 | clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(latents.device) 364 | 365 | #print("clip states shape = ", clip_hidden_states.shape) 366 | 367 | # Get VAE embeddings 368 | #print("frame i shape = ", batch['frame_i'].shape) 369 | image = batch['frame_i'].to(device=latents.device, dtype=weight_dtype) 370 | vae_hidden_states = vae.encode(image).latent_dist.sample() * 0.18215 371 | #print("vae states shape = ", vae_hidden_states.shape) 372 | 373 | encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states) 374 | 375 | # Predict the noise residual 376 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 377 | 378 | # Get the target for loss depending on the prediction type 379 | if noise_scheduler.config.prediction_type == "epsilon": 380 | target = noise 381 | elif noise_scheduler.config.prediction_type == "v_prediction": 382 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 383 | else: 384 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 385 | 386 | if args.with_prior_preservation: 387 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 388 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 389 | target, target_prior = torch.chunk(target, 2, dim=0) 390 | 391 | # Compute instance loss 392 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 393 | 394 | # Compute prior loss 395 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 396 | 397 | # Add the prior loss to the instance loss. 
398 | loss = loss + args.prior_loss_weight * prior_loss 399 | else: 400 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 401 | 402 | accelerator.backward(loss) 403 | if accelerator.sync_gradients: 404 | params_to_clip = ( 405 | itertools.chain(unet.parameters()) 406 | ) 407 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 408 | optimizer.step() 409 | lr_scheduler.step() 410 | optimizer.zero_grad() 411 | 412 | # Checks if the accelerator has performed an optimization step behind the scenes 413 | if accelerator.sync_gradients: 414 | progress_bar.update(1) 415 | global_step += 1 416 | 417 | # write to tensorboard 418 | if global_step % 10 == 0: 419 | weights = adapter.linear1.weight.cpu().detach().numpy() 420 | weights = np.sum(weights, axis=0) 421 | weights = weights.flatten() 422 | plt.figure() 423 | plt.plot(range(len(weights)), weights) 424 | plt.title(f"VAE Weights = {weights[50:]}") 425 | #plt.hist(weights) 426 | writer.add_figure('embedding_weights', plt.gcf(), global_step=global_step) 427 | #add_image("emebedding_weights", h, global_step=global_step) 428 | #writer.add_histogram('embedding_weights', weights, global_step=global_step, bins='tensorflow') 429 | writer.add_scalar("loss/train", loss.detach().item(), global_step) 430 | if global_step % 50 == 0: 431 | with torch.no_grad(): 432 | pred_latents = noisy_latents[:,:4,:,:] - model_pred 433 | pred_images = latents2img(pred_latents) 434 | noise_viz = latents2img(noisy_latents[:,:4,:,:]) 435 | target = inputs2img(batch["frame_j"]) 436 | input_img = inputs2img(batch["frame_i"]) 437 | middle_pose = visualize_dp(target[0], batch['poses'][0][4:6]) 438 | 439 | pose_viz = [] 440 | for pose_id in range(0, 5): 441 | start, end = 2*pose_id, 2*(pose_id+1) 442 | pose = visualize_dp(target[0], batch['poses'][0][start:end]) 443 | pose_viz.append(pose) 444 | 445 | pose_viz = np.concatenate(pose_viz, axis=2) 446 | frame_viz = np.concatenate([input_img[0], noise_viz[0], middle_pose, pred_images[0], target[0]], axis=2) 447 | viz = np.concatenate([frame_viz, pose_viz], axis=1) 448 | writer.add_image(f'train/pred_img', viz, global_step=global_step) 449 | 450 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 451 | progress_bar.set_postfix(**logs) 452 | accelerator.log(logs, step=global_step) 453 | 454 | if global_step >= args.max_train_steps: 455 | break 456 | 457 | # save model 458 | if accelerator.is_main_process and global_step % 500 == 0: 459 | pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( 460 | args.pretrained_model_name_or_path, 461 | #adapter=accelerator.unwrap_model(adapter), 462 | unet=accelerator.unwrap_model(unet), 463 | tokenizer=tokenizer, 464 | image_encoder=accelerator.unwrap_model(clip_encoder), 465 | clip_processor=accelerator.unwrap_model(clip_processor), 466 | revision=args.revision, 467 | ) 468 | pipeline.save_pretrained(os.path.join(args.output_dir, f'checkpoint-{epoch}')) 469 | model_path = args.output_dir+f'/unet_epoch_{epoch}.pth' 470 | torch.save(unet.state_dict(), model_path) 471 | adapter_path = args.output_dir+f'/adapter_{epoch}.pth' 472 | torch.save(adapter.state_dict(), adapter_path) 473 | 474 | accelerator.wait_for_everyone() 475 | 476 | accelerator.end_training() 477 | 478 | 479 | if __name__ == "__main__": 480 | args = parse_args() 481 | main(args) -------------------------------------------------------------------------------- /utils/densepose.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 
This script reformats the pickle output of running DensePose on a directory of images. 3 | Each frame_x_densepose.npy is stored in the same place as its corresponding image, frame_x.png. 4 | ''' 5 | import os 6 | import cv2 7 | import glob 8 | import tqdm 9 | import torch 10 | import numpy as np 11 | 12 | # Filepath to raw DensePose pickle output 13 | outpath = '../UBC_Fashion_Dataset/detectron2/projects/DensePose/densepose.pkl' 14 | 15 | # Convert pickle data to numpy arrays and save 16 | data = torch.load(outpath) 17 | for i in tqdm.tqdm(range(len(data))): 18 | dp = data[i] 19 | path = dp['file_name'] # path to original image 20 | dp_uv = dp['pred_densepose'][0].uv # uv coordinates 21 | h, w, c = cv2.imread(path).shape 22 | _, h_, w_ = dp_uv.shape 23 | (x1, y1, x2, y2) = dp['pred_boxes_XYXY'][0].int().numpy() # location of person 24 | y2, x2 = y1+h_, x1+w_ # match the box extent to the uv map size 25 | dp_im = np.zeros((2, h, w)) 26 | dp_im[:,y1:y2,x1:x2] = dp_uv.cpu().numpy() 27 | savepath = path.replace('.png', '_densepose.npy') 28 | np.save(savepath, dp_im) 29 | 30 | 31 | -------------------------------------------------------------------------------- /utils/parse_args.py: -------------------------------------------------------------------------------- 1 | import os, argparse, logging 2 | logger = logging.getLogger(__name__) 3 | def parse_args(input_args=None): 4 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 5 | 6 | parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.",) 7 | parser.add_argument("--custom_chkpt", type=str, default=None, required=False, help="Path to custom pretrained model.",) 8 | parser.add_argument('--tb_dir', default="tb", help="Directory for tensorboard files") 9 | parser.add_argument('--cfg', default="cfg/train.cfg", help="Path to config file") 10 | parser.add_argument('--chkpt', default=None, help="Path to checkpoint state file") 11 | parser.add_argument("--run_name", type=str, default='dreampose-tb') 12 | parser.add_argument('--epoch', default=0, type=int, help="Which epoch to start training at") 13 | parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model identifier from huggingface.co/models.",) 14 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name",) 15 | parser.add_argument("--instance_data_dir", type=str, default=None, required=True, help="A folder containing the training data of instance images.",) 16 | parser.add_argument("--class_data_dir", type=str, default=None, required=False, help="A folder containing the training data of class images.",) 17 | parser.add_argument("--class_prompt", type=str, default=None, help="The prompt to specify images in the same class as provided instance images.",) 18 | parser.add_argument('--num_frames', default=8, type=int, help="Number of frames to use per training sample") 19 | parser.add_argument('--dropout_rate', default=0.2, type=float, help="Fraction of training samples for which conditioning info is dropped.") 20 | parser.add_argument("--train_decoder", action='store_true', help="Whether or not to train the VAE decoder with an additional L1-Loss.") 21 | parser.add_argument( 22 | "--with_prior_preservation", 23 | default=False, 24 | action="store_true", 25 | help="Flag to add prior preservation loss.", 26 | ) 27 | parser.add_argument( 28 | "--face_loss", 29 | default=False, 30 | action="store_true", 31 | help="Flag to add face loss.",
loss.", 32 | ) 33 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 34 | parser.add_argument("--guidance_scale", type=float, default=7.5, help="Classifier-free guidance scale.") 35 | parser.add_argument( 36 | "--num_class_images", 37 | type=int, 38 | default=100, 39 | help=( 40 | "Minimal class images for prior preservation loss. If not have enough images, additional images will be" 41 | " sampled with class_prompt." 42 | ), 43 | ) 44 | parser.add_argument( 45 | "--output_dir", 46 | type=str, 47 | default="text-inversion-model", 48 | help="The output directory where the model predictions and checkpoints will be written.", 49 | ) 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument( 52 | "--resolution", 53 | type=int, 54 | default=512, 55 | help=( 56 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 57 | " resolution" 58 | ), 59 | ) 60 | parser.add_argument( 61 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 62 | ) 63 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 64 | parser.add_argument( 65 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 66 | ) 67 | parser.add_argument( 68 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 69 | ) 70 | parser.add_argument("--num_train_epochs", type=int, default=1) 71 | parser.add_argument( 72 | "--max_train_steps", 73 | type=int, 74 | default=None, 75 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 76 | ) 77 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 78 | parser.add_argument( 79 | "--gradient_accumulation_steps", 80 | type=int, 81 | default=1, 82 | help="Number of updates steps to accumulate before performing a backward/update pass.", 83 | ) 84 | parser.add_argument( 85 | "--gradient_checkpointing", 86 | action="store_true", 87 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 88 | ) 89 | parser.add_argument( 90 | "--learning_rate", 91 | type=float, 92 | default=5e-6, 93 | help="Initial learning rate (after the potential warmup period) to use.", 94 | ) 95 | parser.add_argument( 96 | "--scale_lr", 97 | action="store_true", 98 | default=False, 99 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 100 | ) 101 | parser.add_argument( 102 | "--lr_scheduler", 103 | type=str, 104 | default="constant", 105 | help=( 106 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 107 | ' "constant", "constant_with_warmup"]' 108 | ), 109 | ) 110 | parser.add_argument( 111 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 112 | ) 113 | parser.add_argument( 114 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
115 | ) 116 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 117 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 118 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 119 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.") 120 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 121 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 122 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 123 | parser.add_argument( 124 | "--hub_model_id", 125 | type=str, 126 | default=None, 127 | help="The name of the repository to keep in sync with the local `output_dir`.", 128 | ) 129 | parser.add_argument( 130 | "--logging_dir", 131 | type=str, 132 | default="logs", 133 | help=( 134 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 135 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 136 | ), 137 | ) 138 | parser.add_argument( 139 | "--mixed_precision", 140 | type=str, 141 | default=None, 142 | choices=["no", "fp16", "bf16"], 143 | help=( 144 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 145 | " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the" 146 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 147 | ), 148 | ) 149 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 150 | 151 | if input_args is not None: 152 | args = parser.parse_args(input_args) 153 | else: 154 | args = parser.parse_args() 155 | 156 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 157 | if env_local_rank != -1 and env_local_rank != args.local_rank: 158 | args.local_rank = env_local_rank 159 | 160 | if args.with_prior_preservation: 161 | if args.class_data_dir is None: 162 | raise ValueError("You must specify a data directory for class images.") 163 | if args.class_prompt is None: 164 | raise ValueError("You must specify a prompt for class images.") 165 | else: 166 | if args.class_data_dir is not None: 167 | logger.warning("You need not use --class_data_dir without --with_prior_preservation.") 168 | if args.class_prompt is not None: 169 | logger.warning("You need not use --class_prompt without --with_prior_preservation.") 170 | 171 | return args --------------------------------------------------------------------------------
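
Usage note (not part of the repository): the `_densepose.npy` files written by utils/densepose.py are what the training loop above consumes as pose conditioning, resized to the latent resolution and concatenated with the noisy latents along the channel dimension. The sketch below shows one way to assemble that tensor by hand. It is a minimal example under assumptions: the five-frame window and 10-channel layout are inferred from the visualization code in train.py, and the `load_pose_tensor` helper and example frame paths are hypothetical, not part of the repo; the dataset classes in datasets/ define the actual loading logic.

```python
# Minimal sketch (assumptions noted above): build a (1, 10, h, w) pose tensor
# from five consecutive frame_x_densepose.npy files and attach it to the latents.
import numpy as np
import torch
import torch.nn.functional as F

def load_pose_tensor(frame_paths, latent_hw=(64, 64)):
    # Each .npy file holds a (2, H, W) UV map saved by utils/densepose.py;
    # stacking five consecutive frames gives a (10, H, W) array.
    maps = [np.load(p.replace('.png', '_densepose.npy')) for p in frame_paths]
    poses = torch.from_numpy(np.concatenate(maps, axis=0)).float().unsqueeze(0)
    # Resize to the latent resolution so it can be concatenated with the noisy latents.
    return F.interpolate(poses, latent_hw)

# Example (hypothetical paths):
# poses = load_pose_tensor([f"demo/sample/train/frame_{i}.png" for i in range(48, 53)])
# model_input = torch.cat((noisy_latents, poses.to(noisy_latents.device)), dim=1)  # 4 + 10 channels
```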