├── Deformation_Stage
├── data
│ ├── aligned_dataset_vitonhd.py
│ ├── base_data_loader.py
│ ├── base_dataset.py
│ ├── custom_dataset_data_loader.py
│ └── data_loader.py
├── models
│ ├── ds_selector.py
│ ├── dsdnet.py
│ ├── external_function.py
│ └── light_net.py
├── scripts
│ ├── test.sh
│ └── train.sh
├── test.py
├── train.py
└── utils
│ ├── lpips
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── base_model.cpython-36.pyc
│ │ ├── dist_model.cpython-36.pyc
│ │ ├── networks_basic.cpython-36.pyc
│ │ └── pretrained_networks.cpython-36.pyc
│ ├── base_model.py
│ ├── dist_model.py
│ ├── networks_basic.py
│ ├── pretrained_networks.py
│ └── weights
│ │ ├── v0.0
│ │ ├── alex.pth
│ │ ├── squeeze.pth
│ │ └── vgg.pth
│ │ └── v0.1
│ │ ├── alex.pth
│ │ ├── squeeze.pth
│ │ └── vgg.pth
│ ├── lpips_2dirs.py
│ ├── test_ssim.py
│ └── utils.py
├── README.md
├── Synthesis_Stage
├── configs
│ └── vitonhd_512.yaml
├── ldm
│ ├── data
│ │ ├── aligned_dataset_vitonhd.py
│ │ ├── base.py
│ │ └── base_dataset.py
│ ├── lr_scheduler.py
│ ├── models
│ │ ├── __pycache__
│ │ │ └── autoencoder.cpython-38.pyc
│ │ ├── autoencoder.py
│ │ └── diffusion
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── ddim.cpython-38.pyc
│ │ │ ├── ddpm.cpython-38.pyc
│ │ │ ├── ddpm_nodelta.cpython-38.pyc
│ │ │ ├── ddpm_norec.cpython-38.pyc
│ │ │ ├── dtdpm.cpython-38.pyc
│ │ │ └── plms.cpython-38.pyc
│ │ │ ├── classifier.py
│ │ │ ├── clip
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-38.pyc
│ │ │ │ └── base_clip.cpython-38.pyc
│ │ │ ├── base_clip.py
│ │ │ └── clip
│ │ │ │ ├── .ipynb_checkpoints
│ │ │ │ ├── __init__-checkpoint.py
│ │ │ │ ├── clip-checkpoint.py
│ │ │ │ ├── model-checkpoint.py
│ │ │ │ └── simple_tokenizer-checkpoint.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-36.pyc
│ │ │ │ ├── __init__.cpython-38.pyc
│ │ │ │ ├── clip.cpython-36.pyc
│ │ │ │ ├── clip.cpython-38.pyc
│ │ │ │ ├── model.cpython-36.pyc
│ │ │ │ ├── model.cpython-38.pyc
│ │ │ │ ├── simple_tokenizer.cpython-36.pyc
│ │ │ │ └── simple_tokenizer.cpython-38.pyc
│ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz
│ │ │ │ ├── clip.py
│ │ │ │ ├── model.py
│ │ │ │ └── simple_tokenizer.py
│ │ │ ├── ddim.py
│ │ │ ├── ddpm.py
│ │ │ └── plms.py
│ ├── modules
│ │ ├── __pycache__
│ │ │ ├── attention.cpython-38.pyc
│ │ │ ├── ema.cpython-38.pyc
│ │ │ ├── vgg.cpython-38.pyc
│ │ │ └── x_transformer.cpython-38.pyc
│ │ ├── attention.py
│ │ ├── diffusionmodules
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-38.pyc
│ │ │ │ ├── model.cpython-38.pyc
│ │ │ │ ├── openaimodel.cpython-38.pyc
│ │ │ │ ├── openaimodel2.cpython-38.pyc
│ │ │ │ └── util.cpython-38.pyc
│ │ │ ├── model.py
│ │ │ ├── openaimodel.py
│ │ │ └── util.py
│ │ ├── distributions
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-38.pyc
│ │ │ │ └── distributions.cpython-38.pyc
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── encoders
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-38.pyc
│ │ │ │ ├── modules.cpython-38.pyc
│ │ │ │ └── xf.cpython-38.pyc
│ │ │ ├── modules.py
│ │ │ └── xf.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── contperceptual.py
│ │ │ └── vqperceptual.py
│ │ ├── vgg.py
│ │ └── x_transformer.py
│ ├── resizer.py
│ └── util.py
├── main.py
├── scripts
│ ├── test.sh
│ └── train.sh
└── test.py
├── assets
├── pipeline.png
└── results.png
└── environment.yaml

/Deformation_Stage/data/base_data_loader.py:
-------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /Deformation_Stage/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import random 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseDataset' 13 | 14 | def initialize(self, opt): 15 | pass 16 | 17 | def get_params(opt, size): 18 | w, h = size 19 | new_h = h 20 | new_w = w 21 | if opt.resize_or_crop == 'resize_and_crop': 22 | new_h = new_w = opt.loadSize 23 | elif opt.resize_or_crop == 'scale_width_and_crop': 24 | new_w = opt.loadSize 25 | new_h = opt.loadSize * h // w 26 | 27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 29 | 30 | #flip = random.random() > 0.5 31 | flip = 0 32 | return {'crop_pos': (x, y), 'flip': flip} 33 | 34 | def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True): 35 | transform_list = [] 36 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 37 | osize = [256,192] 38 | transform_list.append(transforms.Scale(osize, method)) 39 | if 'crop' in opt.resize_or_crop: 40 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 41 | 42 | if opt.resize_or_crop == 'none': 43 | base = float(2 ** opt.n_downsample_global) 44 | if opt.netG == 'local': 45 | base *= (2 ** opt.n_local_enhancers) 46 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 47 | 48 | if opt.mode == 'train' and not opt.no_flip: 49 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 50 | 51 | transform_list += [transforms.ToTensor()] 52 | 53 | if normalize: 54 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 55 | (0.5, 0.5, 0.5))] 56 | return transforms.Compose(transform_list) 57 | 58 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 59 | transform_list = [] 60 | if 'resize' in opt.resize_or_crop: 61 | osize = [opt.loadSize, opt.loadSize] 62 | transform_list.append(transforms.Scale(osize, method)) 63 | elif 'scale_width' in opt.resize_or_crop: 64 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 65 | osize = [256,192] 66 | transform_list.append(transforms.Scale(osize, method)) 67 | if 'crop' in opt.resize_or_crop: 68 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 69 | 70 | if opt.resize_or_crop == 'none': 71 | base = float(2 ** opt.n_downsample_global) 72 | if opt.netG == 'local': 73 | base *= (2 ** opt.n_local_enhancers) 74 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 75 | 76 | if opt.mode == 'train' and not opt.no_flip: 77 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 78 | 79 | transform_list += [transforms.ToTensor()] 80 | 81 | if normalize: 82 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 83 | (0.5, 0.5, 
0.5))] 84 | return transforms.Compose(transform_list) 85 | 86 | def normalize(): 87 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 88 | 89 | def __make_power_2(img, base, method=Image.BICUBIC): 90 | ow, oh = img.size 91 | h = int(round(oh / base) * base) 92 | w = int(round(ow / base) * base) 93 | if (h == oh) and (w == ow): 94 | return img 95 | return img.resize((w, h), method) 96 | 97 | def __scale_width(img, target_width, method=Image.BICUBIC): 98 | ow, oh = img.size 99 | if (ow == target_width): 100 | return img 101 | w = target_width 102 | h = int(target_width * oh / ow) 103 | return img.resize((w, h), method) 104 | 105 | def __crop(img, pos, size): 106 | ow, oh = img.size 107 | x1, y1 = pos 108 | tw = th = size 109 | if (ow > tw or oh > th): 110 | return img.crop((x1, y1, x1 + tw, y1 + th)) 111 | return img 112 | 113 | def __flip(img, flip): 114 | if flip: 115 | return img.transpose(Image.FLIP_LEFT_RIGHT) 116 | return img 117 | -------------------------------------------------------------------------------- /Deformation_Stage/data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.aligned_dataset import AlignedDataset 8 | dataset = AlignedDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class CustomDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'CustomDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /Deformation_Stage/data/data_loader.py: -------------------------------------------------------------------------------- 1 | def CreateDataLoader(opt): 2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 3 | data_loader = CustomDatasetDataLoader() 4 | print(data_loader.name()) 5 | data_loader.initialize(opt) 6 | return data_loader 7 | -------------------------------------------------------------------------------- /Deformation_Stage/models/ds_selector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import sys 4 | from torch import nn 5 | 6 | 7 | class ExtractionOperation(nn.Module): 8 | def __init__(self, in_channel=256, out_channel=256, num_label=8, match_kernel=3): 9 | super(ExtractionOperation, self).__init__() 10 | self.value_conv = EqualConv2d(in_channel, in_channel, match_kernel, 1, match_kernel//2, bias=True) 11 | self.semantic_extraction_filter = EqualConv2d(in_channel, num_label, match_kernel, 1, match_kernel//2, bias=False) 12 | 13 | self.softmax = nn.Softmax(dim=-1) 14 | self.num_label = num_label 15 | self.proj = nn.Linear(in_channel, out_channel) 16 | 17 | def forward(self, value): 18 | key = value 19 | b, c, h, w = value.shape 20 | key = self.semantic_extraction_filter(self.feature_norm(key)) 21 | extraction_softmax = key.view(b, -1, h*w) # bkm 22 | values_flatten = 
self.value_conv(value).view(b, -1, h*w) 23 | neural_textures = torch.einsum('bkm,bvm->bkv', extraction_softmax, values_flatten) 24 | attn = self.proj(neural_textures) 25 | coarse_mask = gumbel_softmax(attn) 26 | fine_mask = gumbel_softmax(neural_textures) 27 | 28 | return coarse_mask, fine_mask # extraction_softmax 29 | 30 | def feature_norm(self, input_tensor): 31 | input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True) 32 | norm = torch.norm(input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon 33 | out = torch.div(input_tensor, norm) 34 | return out 35 | 36 | 37 | class EqualConv2d(nn.Module): 38 | def __init__( 39 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 40 | ): 41 | super().__init__() 42 | 43 | self.weight = nn.Parameter( 44 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 45 | ) 46 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 47 | 48 | self.stride = stride 49 | self.padding = padding 50 | 51 | if bias: 52 | self.bias = nn.Parameter(torch.zeros(out_channel)) 53 | 54 | else: 55 | self.bias = None 56 | 57 | def forward(self, input): 58 | out = torch.nn.functional.conv2d( 59 | input, 60 | self.weight * self.scale, 61 | bias=self.bias, 62 | stride=self.stride, 63 | padding=self.padding, 64 | ) 65 | 66 | return out 67 | 68 | def __repr__(self): 69 | return ( 70 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 71 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 72 | ) 73 | 74 | 75 | def gumbel_softmax(logits: torch.Tensor, tau: float = 1, dim: int = -2) -> torch.Tensor: 76 | gumbel_dist = torch.distributions.gumbel.Gumbel( 77 | torch.tensor(0., device=logits.device, dtype=logits.dtype), 78 | torch.tensor(1., device=logits.device, dtype=logits.dtype)) 79 | gumbels = gumbel_dist.sample(logits.shape) 80 | 81 | gumbels = (logits + gumbels) / tau 82 | y_soft = gumbels.softmax(dim) 83 | 84 | index = y_soft.max(dim, keepdim=True)[1] 85 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) 86 | ret = y_hard - y_soft.detach() + y_soft 87 | 88 | return ret 89 | 90 | 91 | if __name__ == '__main__': 92 | net = ExtractionOperation(64, num_label=8, match_kernel=3).cuda() 93 | # for k,v in net.state_dict().items(): 94 | # print(k) 95 | garment = torch.ones(2, 64, 256, 192).cuda() 96 | mask1, mask2 = net(garment) 97 | print(mask1.shape) 98 | -------------------------------------------------------------------------------- /Deformation_Stage/models/light_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LightNet(nn.Module): 7 | def __init__(self, hidden_dim=64): 8 | super(LightNet, self).__init__() 9 | self.hidden_dim = hidden_dim 10 | self.encoder = torch.nn.Sequential( 11 | torch.nn.Conv2d(3, out_channels=self.hidden_dim // 2, kernel_size=3, stride=1, padding=1), 12 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 13 | torch.nn.Conv2d(in_channels=self.hidden_dim // 2, out_channels=self.hidden_dim, kernel_size=1, stride=1) 14 | ) 15 | 16 | self.decoder = torch.nn.Sequential( 17 | torch.nn.Conv2d(self.hidden_dim, out_channels=self.hidden_dim // 2, kernel_size=1, stride=1), 18 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 19 | torch.nn.Conv2d(in_channels=self.hidden_dim // 2, out_channels=3, kernel_size=3, stride=1, padding=1) 20 | ) 21 | 22 | def forward(self, 
gar_img, num_layer): 23 | results = [] 24 | for num in range(num_layer): 25 | cur_gar_img = F.interpolate(gar_img, scale_factor=0.5 ** (4 - num), mode='bilinear') 26 | x = self.encoder(cur_gar_img) 27 | x = self.decoder(x) 28 | results.append(x) 29 | return results -------------------------------------------------------------------------------- /Deformation_Stage/scripts/test.sh: -------------------------------------------------------------------------------- 1 | python -u test.py -b 16 --gpu 0 --name d4vton_deform --mode test \ 2 | --exp_name \ 3 | --dataroot \ 4 | --image_pairs_txt \ 5 | --ckpt_dir checkpoints/vitonhd_deformation.pt -------------------------------------------------------------------------------- /Deformation_Stage/scripts/train.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=6231 train.py \ 2 | --dataroot \ 3 | -b 2 --num_gpus 4 --name d4vton_deform --group_num 8 -------------------------------------------------------------------------------- /Deformation_Stage/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch.backends.cudnn as cudnn 4 | import torch 5 | from torch.nn import functional as F 6 | import tqdm 7 | import numpy as np 8 | from data.aligned_dataset_vitonhd import AlignedDataset 9 | from models.dsdnet import DSDNet 10 | from torch.utils import data 11 | from torchvision.utils import save_image 12 | cudnn.benchmark = True 13 | 14 | 15 | def load_networks(opt, network, load_path): 16 | device = torch.device("cuda:" + str(opt.gpu) if torch.cuda.is_available() else "cpu") 17 | if not os.path.exists(load_path): 18 | print("not exsits %s" % load_path) 19 | return 20 | print('loading the model from %s' % load_path) 21 | 22 | state_dict = torch.load(load_path, map_location=device) 23 | # load params 24 | network.load_state_dict(state_dict["state_dict"]) 25 | 26 | return network 27 | 28 | 29 | def get_opt(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--name', type=str, required=True) 32 | parser.add_argument('-b', '--batch_size', type=int, default=8) 33 | parser.add_argument('--load_height', type=int, default=512) 34 | parser.add_argument('--load_width', type=int, default=384) 35 | parser.add_argument('--shuffle', action='store_false') 36 | parser.add_argument('--mode', type=str, default='test') 37 | parser.add_argument('--sample_nums', type=int, default=1) 38 | parser.add_argument('--group_nums', type=int, default=8) 39 | # dataset 40 | parser.add_argument('--dataroot', type=str, default='/VITON-HD-512/') 41 | parser.add_argument('--image_pairs_txt', type=str, default='test_pairs_unpaired_1018.txt') 42 | parser.add_argument('--label_nc', type=int, default=14, help='# of input label channels') 43 | parser.add_argument('--resize_or_crop', type=str, default='None', 44 | help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 45 | parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') 46 | parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 47 | parser.add_argument('--no_flip', action='store_true', 48 | help='if specified, do not flip the images for data argumentation') 49 | parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG') 50 | parser.add_argument('--n_local_enhancers', 
type=int, default=1, help='number of local enhancers to use') 51 | # test setting 52 | parser.add_argument('--gpu', type=int, default=0, help='which gpu to use') 53 | parser.add_argument('--exp_name', type=str, default='unpaired-cloth-warp') 54 | parser.add_argument('--save_dir', type=str, default='./results/') 55 | parser.add_argument('--num_gpus', type=int, default=1, help='the number of gpus') 56 | parser.add_argument('--ckpt_dir', type=str, default='checkpoints/vitonhd_deformation.pt') 57 | opt = parser.parse_args() 58 | return opt 59 | 60 | 61 | def deformation_test(opt, warp_model): 62 | test_dataset = AlignedDataset() 63 | test_dataset.initialize(opt) 64 | test_loader = data.DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=8) 65 | 66 | with torch.no_grad(): 67 | for i, inputs in enumerate(tqdm.tqdm(test_loader)): 68 | img_names = inputs['img_path'] 69 | pre_clothes_edge = inputs['edge'] 70 | clothes = inputs['color'] 71 | clothes = clothes * pre_clothes_edge 72 | real_image = inputs['image'] 73 | pose = inputs['pose'] 74 | size = inputs['color'].size() 75 | oneHot_size1 = (size[0], 25, size[2], size[3]) 76 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_() 77 | densepose = densepose.scatter_(1, inputs['densepose'].data.long().cuda(), 1.0) 78 | densepose = densepose * 2.0 - 1.0 79 | 80 | pose = pose.cuda() 81 | clothes = clothes.cuda() 82 | preserve_mask = inputs['preserve_mask'].cuda() 83 | 84 | condition = torch.cat([densepose, pose, preserve_mask], 1) 85 | results_all = warp_model(condition, clothes) 86 | 87 | for j in range(real_image.shape[0]): 88 | save_image(results_all[-1][j:j+1], os.path.join(opt.save_dir, opt.name, opt.exp_name, img_names[j].split('/')[-1]), 89 | nrow=1, normalize=True, range=(-1, 1)) 90 | 91 | 92 | def main(): 93 | opt = get_opt() 94 | print(opt) 95 | torch.cuda.set_device("cuda:" + str(opt.gpu)) 96 | if not os.path.exists(os.path.join(opt.save_dir, opt.name, opt.exp_name)): 97 | os.makedirs(os.path.join(opt.save_dir, opt.name, opt.exp_name)) 98 | 99 | # define model 100 | warp_model = DSDNet(cond_in_channel=51, sample_nums=opt.sample_nums, group_nums=opt.group_nums).cuda() 101 | warp_model.eval() 102 | load_networks(opt, warp_model, opt.ckpt_dir) 103 | 104 | deformation_test(opt, warp_model) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | 110 | -------------------------------------------------------------------------------- /Deformation_Stage/train.py: -------------------------------------------------------------------------------- 1 | from models.dsdnet import DSDNet 2 | import os 3 | import torch, argparse, wandb 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.distributed import DistributedSampler 8 | import tqdm 9 | from models import external_function 10 | from utils import lpips 11 | from utils.utils import AverageMeter 12 | from torchvision import utils 13 | from data.aligned_dataset_vitonhd import AlignedDataset 14 | 15 | 16 | def load_last_checkpoint(opt, network, optimizer): 17 | load_path = opt.save_dir + opt.name + f"/{str(opt.continue_from_epoch).zfill(3)}_viton_{str(opt.name)}.pt" 18 | if not os.path.exists(load_path): 19 | print("not exsits %s" % load_path) 20 | return 21 | print('loading the model from %s' % load_path) 22 | 23 | checkpoint = torch.load(load_path, map_location='cuda:{}'.format(opt.local_rank)) 24 | network.load_state_dict(checkpoint["state_dict"]) 25 | 
optimizer.load_state_dict(checkpoint["optim"]) 26 | 27 | 28 | def load_checkpoint_parallel(opt, network, load_path): 29 | if not os.path.exists(load_path): 30 | print("not exsits %s" % load_path) 31 | return 32 | print('loading the model from %s' % load_path) 33 | checkpoint = torch.load(load_path, map_location='cuda:{}'.format(opt.local_rank)) 34 | checkpoint_new = network.state_dict() 35 | for param in checkpoint_new: 36 | checkpoint_new[param] = checkpoint['state_dict'][param] 37 | network.load_state_dict(checkpoint_new) 38 | 39 | 40 | def get_opt(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--name', type=str, required=True) 43 | parser.add_argument('-b', '--batch_size', type=int, default=8) 44 | parser.add_argument('--lr', type=float, default=0.00003) 45 | parser.add_argument('-j', '--workers', type=int, default=4) 46 | parser.add_argument('--load_height', type=int, default=512) 47 | parser.add_argument('--load_width', type=int, default=384) 48 | parser.add_argument('--mode', type=str, default='train') 49 | parser.add_argument('--sample_nums', type=int, default=1) 50 | parser.add_argument('--group_nums', type=int, default=8) 51 | # dataset 52 | parser.add_argument('--dataroot', type=str, default='/VITON-HD-512/') 53 | parser.add_argument('--image_pairs_txt', type=str, default='train_pairs_1018.txt') 54 | parser.add_argument('--label_nc', type=int, default=14, help='# of input label channels') 55 | parser.add_argument('--resize_or_crop', type=str, default='None', 56 | help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 57 | parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') 58 | parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 59 | parser.add_argument('--no_flip', action='store_true', 60 | help='if specified, do not flip the images for data argumentation') 61 | parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG') 62 | parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use') 63 | # log & checkpoints 64 | parser.add_argument('--save_dir', type=str, default='./results/') 65 | parser.add_argument('--display_freq', type=int, default=200) 66 | parser.add_argument('--save_freq', type=int, default=5) 67 | parser.add_argument('--local_rank', type=int, default=0) 68 | parser.add_argument('--num_gpus', type=int, default=1, help='the number of gpus') 69 | parser.add_argument('--continue_from_epoch', default=0, type=int, help='Continue training from epoch (default=0)') 70 | parser.add_argument('--epochs', default=80, type=int, help='training epochs (default=80)') 71 | parser.add_argument('--light_dir', type=str, default='./checkpoints/hd_lightnet.pt') 72 | 73 | opt = parser.parse_args() 74 | return opt 75 | 76 | 77 | def train(opt): 78 | torch.cuda.set_device(opt.local_rank) 79 | torch.distributed.init_process_group( 80 | 'nccl', 81 | init_method='env://' 82 | ) 83 | device = torch.device(f'cuda:{opt.local_rank}') 84 | 85 | train_data = AlignedDataset() 86 | train_data.initialize(opt) 87 | train_sampler = DistributedSampler(train_data) 88 | train_loader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=False, 89 | num_workers=4, pin_memory=True, sampler=train_sampler) 90 | dataset_size = len(train_loader) 91 | 92 | warp_model = DSDNet(cond_in_channel=51, sample_nums=opt.sample_nums, group_nums=opt.group_nums).cuda() 93 | 
warp_model.train() 94 | 95 | optimizer = torch.optim.AdamW(warp_model.parameters(), lr=opt.lr) 96 | 97 | if opt.continue_from_epoch > 0: 98 | load_last_checkpoint(opt, warp_model, optimizer) 99 | else: 100 | load_checkpoint_parallel(opt, warp_model.dsdms.lightnet, opt.light_dir) 101 | warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(warp_model).to(device) 102 | 103 | if opt.num_gpus != 0: 104 | warp_model = torch.nn.parallel.DistributedDataParallel(warp_model, device_ids=[opt.local_rank]) 105 | 106 | # criterion 107 | criterion_L1 = nn.L1Loss() 108 | criterion_percept = lpips.exportPerceptualLoss(model="net-lin", net="vgg", use_gpu=False) 109 | criterion_style = external_function.VGGLoss().to(device) 110 | 111 | if opt.local_rank == 0: 112 | wandb.init(project="d4-vton", name=opt.name, settings=wandb.Settings(code_dir=".")) 113 | print('#training images = %d' % dataset_size) 114 | 115 | # scheduler 116 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.2, 117 | last_epoch=opt.continue_from_epoch - 1) 118 | 119 | for epoch in range(opt.continue_from_epoch, opt.epochs): 120 | loss_l1_avg = AverageMeter() 121 | loss_vgg_avg = AverageMeter() 122 | train_sampler.set_epoch(epoch) 123 | iterations = 0 124 | 125 | for data in tqdm.tqdm(train_loader): 126 | iterations += 1 127 | pre_clothes_edge = data['edge'] 128 | clothes = data['color'] 129 | clothes = clothes * pre_clothes_edge 130 | person_clothes_edge = data['person_clothes_mask'] 131 | real_image = data['image'] 132 | person_clothes = real_image * person_clothes_edge 133 | pose = data['pose'] 134 | size = data['color'].size() 135 | oneHot_size1 = (size[0], 25, size[2], size[3]) 136 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_() 137 | densepose = densepose.scatter_(1, data['densepose'].data.long().cuda(), 1.0) 138 | densepose = densepose * 2.0 - 1.0 139 | 140 | person_clothes = person_clothes.cuda() 141 | pose = pose.cuda() 142 | clothes = clothes.cuda() 143 | preserve_mask = data['preserve_mask'].cuda() 144 | 145 | condition = torch.cat([densepose, pose, preserve_mask], 1) 146 | results_all = warp_model(condition, clothes) 147 | 148 | loss_all = 0 149 | num_layer = 5 150 | for num in range(num_layer): 151 | if num == 1 or num == 3: 152 | continue 153 | cur_img = F.interpolate(person_clothes, scale_factor=0.5 ** (4 - num), mode='bilinear') 154 | loss_l1 = criterion_L1(results_all[num], cur_img) 155 | if num == 0: 156 | cur_img = F.interpolate(cur_img, scale_factor=2, mode='bilinear') 157 | results_all[num] = F.interpolate(results_all[num], scale_factor=2, mode='bilinear') 158 | loss_perceptual = criterion_percept(cur_img, results_all[num]).mean() 159 | loss_content, loss_style = criterion_style(results_all[num], cur_img) 160 | loss_vgg = loss_perceptual + 100 * loss_style + 0.1 * loss_content 161 | loss_all = loss_all + (num + 1) * loss_l1 + (num + 1) * loss_vgg 162 | 163 | loss = loss_all 164 | loss_l1_avg.update(loss_all.item()) 165 | loss_vgg_avg.update(loss_vgg.item()) 166 | 167 | optimizer.zero_grad() 168 | loss.backward() 169 | optimizer.step() 170 | 171 | if iterations % 50 == 1 and opt.local_rank == 0: 172 | wandb.log({'l1_loss': loss_l1_avg.avg, 173 | 'vgg_loss': loss_vgg_avg.avg, 174 | 'epoch': epoch, 'steps': iterations}) 175 | 176 | if iterations % opt.display_freq == 0 and opt.local_rank == 0: 177 | parse_pred = torch.cat([real_image.cuda(), clothes, results_all[-1], person_clothes], 3) 178 | utils.save_image( 179 | parse_pred, 180 | f"{os.path.join(opt.save_dir, 
opt.name)}/log_sample/{str(epoch + 1).zfill(3)}_{str(iterations).zfill(4)}_{str(opt.name)}.jpg", 181 | nrow=1, 182 | normalize=True, 183 | range=(-1, 1), 184 | ) 185 | 186 | if (epoch + 1) % opt.save_freq == 0 and opt.local_rank == 0: 187 | torch.save( 188 | { 189 | "state_dict": warp_model.module.state_dict(), 190 | "optim": optimizer.state_dict(), 191 | }, 192 | opt.save_dir + opt.name + f"/{str(epoch + 1).zfill(3)}_viton_{str(opt.name)}.pt") 193 | 194 | scheduler.step() 195 | 196 | 197 | if __name__ == '__main__': 198 | opt = get_opt() 199 | if opt.local_rank == 0: 200 | if not os.path.exists(os.path.join(opt.save_dir, opt.name)): 201 | os.makedirs(os.path.join(opt.save_dir, opt.name)) 202 | os.makedirs(os.path.join(opt.save_dir, opt.name, 'log_sample')) 203 | train(opt) -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/__init__.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/__init__.py 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | import torch 11 | from torch.autograd import Variable 12 | 13 | from . import dist_model 14 | 15 | 16 | class exportPerceptualLoss(torch.nn.Module): 17 | def __init__( 18 | self, model="net-lin", net="alex", colorspace="rgb", spatial=False, use_gpu=True 19 | ): # VGG using our perceptually-learned weights (LPIPS metric) 20 | super(exportPerceptualLoss, self).__init__() 21 | print("Setting up Perceptual loss...") 22 | self.use_gpu = use_gpu 23 | self.spatial = spatial 24 | self.model = dist_model.exportModel() 25 | self.model.initialize( 26 | model=model, 27 | net=net, 28 | use_gpu=use_gpu, 29 | colorspace=colorspace, 30 | spatial=self.spatial, 31 | ) 32 | print("...[%s] initialized" % self.model.name()) 33 | print("...Done") 34 | 35 | def forward(self, pred, target): 36 | return self.model.forward(target, pred) 37 | 38 | 39 | class PerceptualLoss(torch.nn.Module): 40 | def __init__( 41 | self, 42 | model="net-lin", 43 | net="alex", 44 | colorspace="rgb", 45 | spatial=False, 46 | use_gpu=True, 47 | gpu_ids=[0], 48 | ): # VGG using our perceptually-learned weights (LPIPS metric) 49 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 50 | super(PerceptualLoss, self).__init__() 51 | print("Setting up Perceptual loss...") 52 | self.use_gpu = use_gpu 53 | self.spatial = spatial 54 | self.gpu_ids = gpu_ids 55 | self.model = dist_model.DistModel() 56 | self.model.initialize( 57 | model=model, 58 | net=net, 59 | use_gpu=use_gpu, 60 | colorspace=colorspace, 61 | spatial=self.spatial, 62 | gpu_ids=gpu_ids, 63 | ) 64 | print("...[%s] initialized" % self.model.name()) 65 | print("...Done") 66 | 67 | def forward(self, pred, target, normalize=False): 68 | """ 69 | Pred and target are Variables. 
70 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 71 | If normalize is False, assumes the images are already between [-1,+1] 72 | 73 | Inputs pred and target are Nx3xHxW 74 | Output pytorch Variable N long 75 | """ 76 | 77 | if normalize: 78 | target = 2 * target - 1 79 | pred = 2 * pred - 1 80 | 81 | return self.model.forward(target, pred) 82 | 83 | 84 | def normalize_tensor(in_feat, eps=1e-10): 85 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 86 | return in_feat / (norm_factor + eps) 87 | 88 | 89 | def l2(p0, p1, range=255.0): 90 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 91 | 92 | 93 | def psnr(p0, p1, peak=255.0): 94 | return 10 * np.log10(peak ** 2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 95 | 96 | 97 | def dssim(p0, p1, range=255.0): 98 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 99 | 100 | 101 | def rgb2lab(in_img, mean_cent=False): 102 | from skimage import color 103 | 104 | img_lab = color.rgb2lab(in_img) 105 | if mean_cent: 106 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 107 | return img_lab 108 | 109 | 110 | def tensor2np(tensor_obj): 111 | # change dimension of a tensor object into a numpy array 112 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 113 | 114 | 115 | def np2tensor(np_obj): 116 | # change dimenion of np array into tensor array 117 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 118 | 119 | 120 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 121 | # image tensor to lab tensor 122 | from skimage import color 123 | 124 | img = tensor2im(image_tensor) 125 | img_lab = color.rgb2lab(img) 126 | if mc_only: 127 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 128 | if to_norm and not mc_only: 129 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 130 | img_lab = img_lab / 100.0 131 | 132 | return np2tensor(img_lab) 133 | 134 | 135 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 136 | from skimage import color 137 | import warnings 138 | 139 | warnings.filterwarnings("ignore") 140 | 141 | lab = tensor2np(lab_tensor) * 100.0 142 | lab[:, :, 0] = lab[:, :, 0] + 50 143 | 144 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 145 | if return_inbnd: 146 | # convert back to lab, see if we match 147 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 148 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 149 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 150 | return (im2tensor(rgb_back), mask) 151 | else: 152 | return im2tensor(rgb_back) 153 | 154 | 155 | def rgb2lab(input): 156 | from skimage import color 157 | 158 | return color.rgb2lab(input / 255.0) 159 | 160 | 161 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 162 | image_numpy = image_tensor[0].cpu().float().numpy() 163 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 164 | return image_numpy.astype(imtype) 165 | 166 | 167 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 168 | return torch.Tensor( 169 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 170 | ) 171 | 172 | 173 | def tensor2vec(vector_tensor): 174 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 175 | 176 | 177 | def voc_ap(rec, prec, use_07_metric=False): 178 | """ap = voc_ap(rec, prec, [use_07_metric]) 179 | Compute VOC AP given precision and recall. 180 | If use_07_metric is true, uses the 181 | VOC 07 11 point method (default:False). 
182 | """ 183 | if use_07_metric: 184 | # 11 point metric 185 | ap = 0.0 186 | for t in np.arange(0.0, 1.1, 0.1): 187 | if np.sum(rec >= t) == 0: 188 | p = 0 189 | else: 190 | p = np.max(prec[rec >= t]) 191 | ap = ap + p / 11.0 192 | else: 193 | # correct AP calculation 194 | # first append sentinel values at the end 195 | mrec = np.concatenate(([0.0], rec, [1.0])) 196 | mpre = np.concatenate(([0.0], prec, [0.0])) 197 | 198 | # compute the precision envelope 199 | for i in range(mpre.size - 1, 0, -1): 200 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 201 | 202 | # to calculate area under PR curve, look for points 203 | # where X axis (recall) changes value 204 | i = np.where(mrec[1:] != mrec[:-1])[0] 205 | 206 | # and sum (\Delta recall) * prec 207 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 208 | return ap 209 | 210 | 211 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 212 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 213 | image_numpy = image_tensor[0].cpu().float().numpy() 214 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 215 | return image_numpy.astype(imtype) 216 | 217 | 218 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 219 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 220 | return torch.Tensor( 221 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 222 | ) 223 | -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/__pycache__/dist_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/__pycache__/dist_model.cpython-36.pyc -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/__pycache__/networks_basic.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/__pycache__/networks_basic.cpython-36.pyc -------------------------------------------------------------------------------- 
/Deformation_Stage/utils/lpips/__pycache__/pretrained_networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/__pycache__/pretrained_networks.cpython-36.pyc -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/base_model.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/trainer.py 4 | """ 5 | import os 6 | import torch 7 | from torch.autograd import Variable 8 | from pdb import set_trace as st 9 | 10 | 11 | class BaseModel: 12 | def __init__(self): 13 | pass 14 | 15 | def name(self): 16 | return "BaseModel" 17 | 18 | def initialize(self, use_gpu=True, gpu_ids=[0]): 19 | self.use_gpu = use_gpu 20 | self.gpu_ids = gpu_ids 21 | 22 | def forward(self): 23 | pass 24 | 25 | def get_image_paths(self): 26 | pass 27 | 28 | def optimize_parameters(self): 29 | pass 30 | 31 | def get_current_visuals(self): 32 | return self.input 33 | 34 | def get_current_errors(self): 35 | return {} 36 | 37 | def save(self, label): 38 | pass 39 | 40 | # helper saving function that can be used by subclasses 41 | def save_network(self, network, path, network_label, epoch_label): 42 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 43 | save_path = os.path.join(path, save_filename) 44 | torch.save(network.state_dict(), save_path) 45 | 46 | # helper loading function that can be used by subclasses 47 | def load_network(self, network, network_label, epoch_label): 48 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 49 | save_path = os.path.join(self.save_dir, save_filename) 50 | print("Loading network from %s" % save_path) 51 | network.load_state_dict(torch.load(save_path)) 52 | 53 | def update_learning_rate(): 54 | pass 55 | 56 | def get_image_paths(self): 57 | return self.image_paths 58 | 59 | def save_done(self, flag=False): 60 | np.save(os.path.join(self.save_dir, "done_flag"), flag) 61 | np.savetxt( 62 | os.path.join(self.save_dir, "done_flag"), 63 | [ 64 | flag, 65 | ], 66 | fmt="%i", 67 | ) 68 | -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/networks_basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/networks_basic.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 4 | """ 5 | from __future__ import absolute_import 6 | import sys 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from pdb import set_trace as st 13 | from skimage import color 14 | from . 
import pretrained_networks as pn 15 | 16 | from utils import lpips as util 17 | 18 | 19 | def spatial_average(in_tens, keepdim=True): 20 | return in_tens.mean([2, 3], keepdim=keepdim) 21 | 22 | 23 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 24 | in_H = in_tens.shape[2] 25 | scale_factor = 1.0 * out_H / in_H 26 | 27 | return nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)( 28 | in_tens 29 | ) 30 | 31 | 32 | # Learned perceptual metric 33 | class PNetLin(nn.Module): 34 | def __init__( 35 | self, 36 | pnet_type="vgg", 37 | pnet_rand=False, 38 | pnet_tune=False, 39 | use_dropout=True, 40 | spatial=False, 41 | version="0.1", 42 | lpips=True, 43 | ): 44 | super(PNetLin, self).__init__() 45 | 46 | self.pnet_type = pnet_type 47 | self.pnet_tune = pnet_tune 48 | self.pnet_rand = pnet_rand 49 | self.spatial = spatial 50 | self.lpips = lpips 51 | self.version = version 52 | self.scaling_layer = ScalingLayer() 53 | 54 | if self.pnet_type in ["vgg", "vgg16"]: 55 | net_type = pn.vgg16 56 | self.chns = [64, 128, 256, 512, 512] 57 | elif self.pnet_type == "alex": 58 | net_type = pn.alexnet 59 | self.chns = [64, 192, 384, 256, 256] 60 | elif self.pnet_type == "squeeze": 61 | net_type = pn.squeezenet 62 | self.chns = [64, 128, 256, 384, 384, 512, 512] 63 | self.L = len(self.chns) 64 | 65 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 66 | 67 | if lpips: 68 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 69 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 70 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 71 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 72 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 73 | self.lins = nn.ModuleList( 74 | [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 75 | ) 76 | 77 | if self.pnet_type == "squeeze": # 7 layers for squeezenet 78 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 79 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 80 | self.lins.extend([self.lin5, self.lin6]) 81 | 82 | def forward(self, in0, in1, retPerLayer=False): 83 | # v0.0 - original release had a bug, where input was not scaled 84 | in0_input, in1_input = ( 85 | (self.scaling_layer(in0), self.scaling_layer(in1)) 86 | if self.version == "0.1" 87 | else (in0, in1) 88 | ) 89 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 90 | feats0, feats1, diffs = {}, {}, {} 91 | 92 | for kk in range(self.L): 93 | feats0[kk], feats1[kk] = ( 94 | util.normalize_tensor(outs0[kk]), 95 | util.normalize_tensor(outs1[kk]), 96 | ) 97 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 98 | 99 | if self.lpips: 100 | if self.spatial: 101 | res = [ 102 | upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) 103 | for kk in range(self.L) 104 | ] 105 | else: 106 | res = [ 107 | spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) 108 | for kk in range(self.L) 109 | ] 110 | else: 111 | if self.spatial: 112 | res = [ 113 | upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) 114 | for kk in range(self.L) 115 | ] 116 | else: 117 | res = [ 118 | spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) 119 | for kk in range(self.L) 120 | ] 121 | 122 | val = res[0] 123 | for l in range(1, self.L): 124 | val += res[l] 125 | 126 | if retPerLayer: 127 | return (val, res) 128 | else: 129 | return val 130 | 131 | 132 | class ScalingLayer(nn.Module): 133 | def 
__init__(self): 134 | super(ScalingLayer, self).__init__() 135 | self.register_buffer( 136 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None].cuda() 137 | ) 138 | self.register_buffer( 139 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None].cuda() 140 | ) 141 | 142 | def forward(self, inp): 143 | return (inp - self.shift) / self.scale 144 | 145 | 146 | class NetLinLayer(nn.Module): 147 | """ A single linear layer which does a 1x1 conv """ 148 | 149 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 150 | super(NetLinLayer, self).__init__() 151 | 152 | layers = ( 153 | [ 154 | nn.Dropout(), 155 | ] 156 | if (use_dropout) 157 | else [] 158 | ) 159 | layers += [ 160 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 161 | ] 162 | self.model = nn.Sequential(*layers) 163 | 164 | 165 | class Dist2LogitLayer(nn.Module): 166 | """ takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) """ 167 | 168 | def __init__(self, chn_mid=32, use_sigmoid=True): 169 | super(Dist2LogitLayer, self).__init__() 170 | 171 | layers = [ 172 | nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), 173 | ] 174 | layers += [ 175 | nn.LeakyReLU(0.2, True), 176 | ] 177 | layers += [ 178 | nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), 179 | ] 180 | layers += [ 181 | nn.LeakyReLU(0.2, True), 182 | ] 183 | layers += [ 184 | nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), 185 | ] 186 | if use_sigmoid: 187 | layers += [ 188 | nn.Sigmoid(), 189 | ] 190 | self.model = nn.Sequential(*layers) 191 | 192 | def forward(self, d0, d1, eps=0.1): 193 | return self.model.forward( 194 | torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1) 195 | ) 196 | 197 | 198 | class BCERankingLoss(nn.Module): 199 | def __init__(self, chn_mid=32): 200 | super(BCERankingLoss, self).__init__() 201 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 202 | # self.parameters = list(self.net.parameters()) 203 | self.loss = torch.nn.BCELoss() 204 | 205 | def forward(self, d0, d1, judge): 206 | per = (judge + 1.0) / 2.0 207 | self.logit = self.net.forward(d0, d1) 208 | return self.loss(self.logit, per) 209 | 210 | 211 | # L2, DSSIM training 212 | class FakeNet(nn.Module): 213 | def __init__(self, use_gpu=True, colorspace="Lab"): 214 | super(FakeNet, self).__init__() 215 | self.use_gpu = use_gpu 216 | self.colorspace = colorspace 217 | 218 | 219 | class L2(FakeNet): 220 | def forward(self, in0, in1, retPerLayer=None): 221 | assert in0.size()[0] == 1 # currently only supports batchSize 1 222 | 223 | if self.colorspace == "RGB": 224 | (N, C, X, Y) = in0.size() 225 | value = torch.mean( 226 | torch.mean( 227 | torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2 228 | ).view(N, 1, 1, Y), 229 | dim=3, 230 | ).view(N) 231 | return value 232 | elif self.colorspace == "Lab": 233 | value = util.l2( 234 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 235 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 236 | range=100.0, 237 | ).astype("float") 238 | ret_var = Variable(torch.Tensor((value,))) 239 | if self.use_gpu: 240 | ret_var = ret_var.cuda() 241 | return ret_var 242 | 243 | 244 | class DSSIM(FakeNet): 245 | def forward(self, in0, in1, retPerLayer=None): 246 | assert in0.size()[0] == 1 # currently only supports batchSize 1 247 | 248 | if self.colorspace == "RGB": 249 | value = util.dssim( 250 | 1.0 * util.tensor2im(in0.data), 251 | 1.0 * util.tensor2im(in1.data), 252 | range=255.0, 
253 | ).astype("float") 254 | elif self.colorspace == "Lab": 255 | value = util.dssim( 256 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 257 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 258 | range=100.0, 259 | ).astype("float") 260 | ret_var = Variable(torch.Tensor((value,))) 261 | if self.use_gpu: 262 | ret_var = ret_var.cuda() 263 | return ret_var 264 | 265 | 266 | def print_network(net): 267 | num_params = 0 268 | for param in net.parameters(): 269 | num_params += param.numel() 270 | print("Network", net) 271 | print("Total number of parameters: %d" % num_params) 272 | -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/pretrained_networks.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py 4 | """ 5 | from collections import namedtuple 6 | import torch 7 | from torchvision import models as tv 8 | 9 | 10 | class squeezenet(torch.nn.Module): 11 | def __init__(self, requires_grad=False, pretrained=True): 12 | super(squeezenet, self).__init__() 13 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 14 | self.slice1 = torch.nn.Sequential() 15 | self.slice2 = torch.nn.Sequential() 16 | self.slice3 = torch.nn.Sequential() 17 | self.slice4 = torch.nn.Sequential() 18 | self.slice5 = torch.nn.Sequential() 19 | self.slice6 = torch.nn.Sequential() 20 | self.slice7 = torch.nn.Sequential() 21 | self.N_slices = 7 22 | for x in range(2): 23 | self.slice1.add_module(str(x), pretrained_features[x]) 24 | for x in range(2, 5): 25 | self.slice2.add_module(str(x), pretrained_features[x]) 26 | for x in range(5, 8): 27 | self.slice3.add_module(str(x), pretrained_features[x]) 28 | for x in range(8, 10): 29 | self.slice4.add_module(str(x), pretrained_features[x]) 30 | for x in range(10, 11): 31 | self.slice5.add_module(str(x), pretrained_features[x]) 32 | for x in range(11, 12): 33 | self.slice6.add_module(str(x), pretrained_features[x]) 34 | for x in range(12, 13): 35 | self.slice7.add_module(str(x), pretrained_features[x]) 36 | if not requires_grad: 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, X): 41 | h = self.slice1(X) 42 | h_relu1 = h 43 | h = self.slice2(h) 44 | h_relu2 = h 45 | h = self.slice3(h) 46 | h_relu3 = h 47 | h = self.slice4(h) 48 | h_relu4 = h 49 | h = self.slice5(h) 50 | h_relu5 = h 51 | h = self.slice6(h) 52 | h_relu6 = h 53 | h = self.slice7(h) 54 | h_relu7 = h 55 | vgg_outputs = namedtuple( 56 | "SqueezeOutputs", 57 | ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], 58 | ) 59 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 60 | 61 | return out 62 | 63 | 64 | class alexnet(torch.nn.Module): 65 | def __init__(self, requires_grad=False, pretrained=True): 66 | super(alexnet, self).__init__() 67 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 68 | self.slice1 = torch.nn.Sequential() 69 | self.slice2 = torch.nn.Sequential() 70 | self.slice3 = torch.nn.Sequential() 71 | self.slice4 = torch.nn.Sequential() 72 | self.slice5 = torch.nn.Sequential() 73 | self.N_slices = 5 74 | for x in range(2): 75 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 76 | for x in range(2, 5): 77 | 
self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 78 | for x in range(5, 8): 79 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 80 | for x in range(8, 10): 81 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 82 | for x in range(10, 12): 83 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 84 | if not requires_grad: 85 | for param in self.parameters(): 86 | param.requires_grad = False 87 | 88 | def forward(self, X): 89 | h = self.slice1(X) 90 | h_relu1 = h 91 | h = self.slice2(h) 92 | h_relu2 = h 93 | h = self.slice3(h) 94 | h_relu3 = h 95 | h = self.slice4(h) 96 | h_relu4 = h 97 | h = self.slice5(h) 98 | h_relu5 = h 99 | alexnet_outputs = namedtuple( 100 | "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] 101 | ) 102 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 103 | 104 | return out 105 | 106 | 107 | class vgg16(torch.nn.Module): 108 | def __init__(self, requires_grad=False, pretrained=True): 109 | super(vgg16, self).__init__() 110 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 111 | self.slice1 = torch.nn.Sequential() 112 | self.slice2 = torch.nn.Sequential() 113 | self.slice3 = torch.nn.Sequential() 114 | self.slice4 = torch.nn.Sequential() 115 | self.slice5 = torch.nn.Sequential() 116 | self.N_slices = 5 117 | for x in range(4): 118 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 119 | for x in range(4, 9): 120 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 121 | for x in range(9, 16): 122 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 123 | for x in range(16, 23): 124 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 125 | for x in range(23, 30): 126 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 127 | if not requires_grad: 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | 131 | def forward(self, X): 132 | h = self.slice1(X) 133 | h_relu1_2 = h 134 | h = self.slice2(h) 135 | h_relu2_2 = h 136 | h = self.slice3(h) 137 | h_relu3_3 = h 138 | h = self.slice4(h) 139 | h_relu4_3 = h 140 | h = self.slice5(h) 141 | h_relu5_3 = h 142 | vgg_outputs = namedtuple( 143 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 144 | ) 145 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 146 | 147 | return out 148 | 149 | 150 | class resnet(torch.nn.Module): 151 | def __init__(self, requires_grad=False, pretrained=True, num=18): 152 | super(resnet, self).__init__() 153 | if num == 18: 154 | self.net = tv.resnet18(pretrained=pretrained) 155 | elif num == 34: 156 | self.net = tv.resnet34(pretrained=pretrained) 157 | elif num == 50: 158 | self.net = tv.resnet50(pretrained=pretrained) 159 | elif num == 101: 160 | self.net = tv.resnet101(pretrained=pretrained) 161 | elif num == 152: 162 | self.net = tv.resnet152(pretrained=pretrained) 163 | self.N_slices = 5 164 | 165 | self.conv1 = self.net.conv1 166 | self.bn1 = self.net.bn1 167 | self.relu = self.net.relu 168 | self.maxpool = self.net.maxpool 169 | self.layer1 = self.net.layer1 170 | self.layer2 = self.net.layer2 171 | self.layer3 = self.net.layer3 172 | self.layer4 = self.net.layer4 173 | 174 | def forward(self, X): 175 | h = self.conv1(X) 176 | h = self.bn1(h) 177 | h = self.relu(h) 178 | h_relu1 = h 179 | h = self.maxpool(h) 180 | h = self.layer1(h) 181 | h_conv2 = h 182 | h = self.layer2(h) 183 | h_conv3 = h 184 | h = self.layer3(h) 185 | h_conv4 = h 186 | h = self.layer4(h) 187 | h_conv5 = h 188 
| 189 | outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) 190 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 191 | 192 | return out 193 | -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Deformation_Stage/utils/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /Deformation_Stage/utils/lpips_2dirs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import lpips 4 | import torch 5 | import numpy as np 6 | 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | parser.add_argument('--input1', type=str, default='./imgs/ex_dir1') 9 | parser.add_argument('--input2', type=str, default='./imgs/ex_dir2') 10 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt') 11 | parser.add_argument('-v','--version', type=str, default='0.1') 12 | parser.add_argument('--gpu', type=int, default=0, help='which gpu to use') 13 | 14 | opt = parser.parse_args() 15 | 16 | ## Initializing the model 17 | loss_fn = lpips.LPIPS(net='alex',version=opt.version) 18 | device = torch.device("cuda:" + str(opt.gpu) if torch.cuda.is_available() else "cpu") 19 | loss_fn.to(device) 20 | 21 | # crawl directories 22 | files = os.listdir(opt.input1) 23 | LPIPS = [] 24 | for file in files: 25 | if os.path.exists(os.path.join(opt.input2, file)): 26 | 
img1 = lpips.im2tensor(lpips.load_image(os.path.join(opt.input1, file))) 27 | img2 = lpips.im2tensor(lpips.load_image(os.path.join(opt.input2, file))) 28 | img1 = img1.to(device) 29 | img2 = img2.to(device) 30 | dist01 = loss_fn.forward(img1, img2) 31 | LPIPS.append(dist01.item()) 32 | mean_LPIPS = torch.mean(torch.tensor(LPIPS)) 33 | print("Mean LPIPS:", np.array(mean_LPIPS)) 34 | 35 | -------------------------------------------------------------------------------- /Deformation_Stage/utils/test_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import os 4 | from skimage.metrics import structural_similarity 5 | from skimage.metrics import peak_signal_noise_ratio 6 | import numpy as np 7 | import argparse 8 | from PIL import Image 9 | SSIMS =[] 10 | PSNRS = [] 11 | import torch 12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | parser.add_argument('--input1', type=str) 14 | parser.add_argument('--input2', type=str) 15 | parser.add_argument('--gpu', type=int, default=0, help='which gpu to use') 16 | opt = parser.parse_args() 17 | #ssim_loss = pytorch_ssim.SSIM(window_size = 11) 18 | img2_files = glob.glob(opt.input2+"/*") 19 | device = torch.device("cuda:" + str(opt.gpu) if torch.cuda.is_available() else "cpu") 20 | for file2 in img2_files: 21 | # Extract the filename from the path 22 | fake_filename = os.path.basename(file2) 23 | # Find the corresponding real image file 24 | real_file = os.path.join(opt.input1, fake_filename) 25 | 26 | if os.path.exists(real_file): 27 | real = cv2.imread(real_file) 28 | fake = cv2.imread(file2) 29 | 30 | if real.shape[0] != fake.shape[0] or real.shape[1] != fake.shape[1]: 31 | pil_img = Image.fromarray(fake) 32 | pil_img = pil_img.resize((real.shape[1], real.shape[0])) 33 | fake = np.array(pil_img)  # use the resized image for the metrics below 34 | 35 | # Compute PSNR 36 | PSNR = peak_signal_noise_ratio(real, fake) 37 | PSNRS.append(PSNR) 38 | 39 | # Compute SSIM 40 | SSIM = structural_similarity(real, fake, multichannel=True) 41 | SSIMS.append(SSIM) 42 | 43 | print('PSNR: ', sum(PSNRS) / len(PSNRS)) 44 | print('SSIM: ', sum(SSIMS) / len(SSIMS)) 45 | -------------------------------------------------------------------------------- /Deformation_Stage/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.nn.parallel.data_parallel import DataParallel 3 | from torch.nn.parallel.parallel_apply import parallel_apply 4 | from torch.nn.parallel._functions import Scatter 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | 10 | 11 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 12 | r""" 13 | Slices tensors into approximately equal chunks and 14 | distributes them across given GPUs. Duplicates 15 | references to objects that are not tensors.
16 | """ 17 | 18 | def scatter_map(obj): 19 | if isinstance(obj, torch.Tensor): 20 | try: 21 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 22 | except Exception: 23 | print('obj', obj.size()) 24 | print('dim', dim) 25 | print('chunk_sizes', chunk_sizes) 26 | quit() 27 | if isinstance(obj, tuple) and len(obj) > 0: 28 | return list(zip(*map(scatter_map, obj))) 29 | if isinstance(obj, list) and len(obj) > 0: 30 | return list(map(list, zip(*map(scatter_map, obj)))) 31 | if isinstance(obj, dict) and len(obj) > 0: 32 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 33 | return [obj for targets in target_gpus] 34 | 35 | # After scatter_map is called, a scatter_map cell will exist. This cell 36 | # has a reference to the actual function scatter_map, which has references 37 | # to a closure that has a reference to the scatter_map cell (because the 38 | # fn is recursive). To avoid this reference cycle, we set the function to 39 | # None, clearing the cell 40 | try: 41 | return scatter_map(inputs) 42 | finally: 43 | scatter_map = None 44 | 45 | 46 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 47 | """Scatter with support for kwargs dictionary""" 48 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 49 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 50 | if len(inputs) < len(kwargs): 51 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 52 | elif len(kwargs) < len(inputs): 53 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 54 | inputs = tuple(inputs) 55 | kwargs = tuple(kwargs) 56 | return inputs, kwargs 57 | 58 | 59 | class BalancedDataParallel(DataParallel): 60 | 61 | def __init__(self, gpu0_bsz, *args, **kwargs): 62 | self.gpu0_bsz = gpu0_bsz 63 | super().__init__(*args, **kwargs) 64 | 65 | def forward(self, *inputs, **kwargs): 66 | if not self.device_ids: 67 | return self.module(*inputs, **kwargs) 68 | if self.gpu0_bsz == 0: 69 | device_ids = self.device_ids[1:] 70 | else: 71 | device_ids = self.device_ids 72 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 73 | if len(self.device_ids) == 1: 74 | return self.module(*inputs[0], **kwargs[0]) 75 | replicas = self.replicate(self.module, self.device_ids) 76 | if self.gpu0_bsz == 0: 77 | replicas = replicas[1:] 78 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 79 | return self.gather(outputs, self.output_device) 80 | 81 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 82 | return parallel_apply(replicas, inputs, kwargs, device_ids) 83 | 84 | def scatter(self, inputs, kwargs, device_ids): 85 | bsz = inputs[0].size(self.dim) 86 | num_dev = len(self.device_ids) 87 | gpu0_bsz = self.gpu0_bsz 88 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 89 | if gpu0_bsz < bsz_unit: 90 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 91 | delta = bsz - sum(chunk_sizes) 92 | for i in range(delta): 93 | chunk_sizes[i + 1] += 1 94 | if gpu0_bsz == 0: 95 | chunk_sizes = chunk_sizes[1:] 96 | else: 97 | return super().scatter(inputs, kwargs, device_ids) 98 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 99 | 100 | 101 | class AverageMeter(object): 102 | """Computes and stores the average and current value. 
103 | 104 | Examples:: 105 | >>> # Initialize a meter to record loss 106 | >>> losses = AverageMeter() 107 | >>> # Update meter after every minibatch update 108 | >>> losses.update(loss_value, batch_size) 109 | """ 110 | 111 | def __init__(self): 112 | self.reset() 113 | 114 | def reset(self): 115 | self.val = 0 116 | self.avg = 0 117 | self.sum = 0 118 | self.count = 0 119 | 120 | def update(self, val, n=1): 121 | self.val = val 122 | self.sum += val * n 123 | self.count += n 124 | self.avg = self.sum / self.count 125 | 126 | def gen_noise(shape): 127 | noise = np.zeros(shape, dtype=np.uint8) 128 | ### noise 129 | noise = cv2.randn(noise, 0, 255) 130 | noise = np.asarray(noise / 255, dtype=np.uint8) 131 | noise = torch.tensor(noise, dtype=torch.float32) 132 | return noise 133 | 134 | 135 | def save_images(img_tensors, img_names, save_dir): 136 | for img_tensor, img_name in zip(img_tensors, img_names): 137 | tensor = (img_tensor.clone()+1)*0.5 * 255 138 | tensor = tensor.cpu().clamp(0,255) 139 | 140 | try: 141 | array = tensor.numpy().astype('uint8') 142 | except: 143 | array = tensor.detach().numpy().astype('uint8') 144 | 145 | if array.shape[0] == 1: 146 | array = array.squeeze(0) 147 | elif array.shape[0] == 3: 148 | array = array.swapaxes(0, 1).swapaxes(1, 2) 149 | 150 | im = Image.fromarray(array) 151 | im.save(os.path.join(save_dir, img_name), format='JPEG') 152 | 153 | 154 | def load_checkpoint(model, checkpoint_path): 155 | if not os.path.exists(checkpoint_path): 156 | raise ValueError("'{}' is not a valid checkpoint path".format(checkpoint_path)) 157 | model.load_state_dict(torch.load(checkpoint_path)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

[ECCV 2024] D4-VTON

2 | 3 | 4 | This is the official PyTorch code for the paper: 5 | >**D$^4$-VTON: Dynamic Semantics Disentangling for Differential Diffusion based Virtual Try-On**
[Zhaotong Yang](https://github.com/Jerome-Young), [Zicheng Jiang](https://github.com/bronyajiang), [Xinzhe Li](https://github.com/lixinzhe-ouc), [Huiyu Zhou](https://le.ac.uk/people/huiyu-zhou), [Junyu Dong](https://it.ouc.edu.cn/djy_23898/main.htm), [Huaidong Zhang](https://github.com/Xiaodomgdomg/), [Yong Du*](https://www.csyongdu.cn/) ( * indicates corresponding author)
6 | >Proceedings of the European Conference on Computer Vision 7 | 8 | ### Pipeline 9 | 10 |
11 | 12 |
13 | 14 | ### News 15 | 16 | - ⭐**Aug 02, 2024:** We release inference and training code! 17 | - ⭐**Jul 01, 2024:** D4-VTON was accepted to ECCV 2024! 18 | 19 | ## Getting started 20 | 21 | ### Setup 22 | 23 | 1. Clone the repository and enter its directory. 24 | ```shell 25 | git clone https://github.com/Jerome-Young/D4-VTON.git 26 | cd D4-VTON 27 | ``` 28 | 2. Install the requirements using the following commands. 29 | ```shell 30 | conda env create -f environment.yaml 31 | conda activate d4-vton 32 | ``` 33 | 3. Please download the pre-trained [vgg](https://drive.google.com/file/d/1rvow8jStPt8t2prDcSRlnf8yzXhrYeGo/view?usp=sharing) checkpoint and put it in `Synthesis_Stage/model/vgg/`. 34 | 35 | ### Data Preparation 36 | To test D4-VTON, you can download the VITON-HD (512 x 384) dataset from [GP-VTON](https://github.com/xiezhy6/GP-VTON). 37 | Alternatively, you can re-train the entire model on the high-resolution (1024 x 768) dataset. 38 | 39 | ## Inference 40 | 41 | ### Stage 1 42 | 43 | Download the pre-trained checkpoint from [Google Drive](https://drive.google.com/file/d/1oPB-E6S2jz63wkLpz5-NxPMcYDlBewK6/view?usp=sharing), and put it in `Deformation_Stage/checkpoints/`. 44 | 45 | To test the Deformation Network, run the following command: 46 | ```shell 47 | cd Deformation_Stage 48 | 49 | python -u test.py -b 16 --gpu 0 --name d4vton_deform --mode test \ 50 | --exp_name \ 51 | --dataroot \ 52 | --image_pairs_txt \ 53 | --ckpt_dir checkpoints/vitonhd_deformation.pt 54 | 55 | # or you can run the bash scripts 56 | bash scripts/test.sh 57 | ``` 58 | 59 | Then you should put the result directory `unpaired-cloth-warp` (for the unpaired setting) or `cloth-warp` (for the paired setting) under the test directory of the VITON-HD dataset (i.e., `VITON-HD-512/test`). 60 | 61 | ### Stage 2 62 | 63 | Download the pre-trained checkpoint from [Google Drive](https://drive.google.com/file/d/1rPGaxMZ5wgyrdoMtMTLJlvSkM8hd_UvV/view?usp=sharing), and put it in `Synthesis_Stage/checkpoints/`. 64 | 65 | To test the Synthesis Network, run the following command: 66 | ```shell 67 | cd Synthesis_Stage 68 | 69 | python test.py --gpu_id 0 --ddim_steps 100 \ 70 | --outdir results/d4vton_unpaired_syn --config configs/vitonhd_512.yaml \ 71 | --dataroot \ 72 | --ckpt checkpoints/vitonhd_synthesis.ckpt --delta_step 89 \ 73 | --n_samples 12 --seed 23 --scale 1 --H 512 --unpaired 74 | 75 | # or you can run the bash scripts 76 | bash scripts/test.sh 77 | ``` 78 | 79 | ## Training 80 | 81 | ### Stage 1 82 | 83 | Please download the pre-trained lightweight net from [Google Drive](https://drive.google.com/file/d/1SfkG0LCpfkgcPrKOg8-TPZwqRdrEVg5q/view?usp=sharing) for initialization and put it under the `Deformation_Stage/checkpoints` directory. 84 | 85 | To train the Deformation Network, run the following command: 86 | ```shell 87 | cd Deformation_Stage 88 | 89 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=6231 train.py \ 90 | --dataroot \ 91 | -b 2 --num_gpus 4 --name d4vton_deform --group_num 8 92 | 93 | # or you can run the bash scripts 94 | bash scripts/train.sh 95 | ``` 96 | 97 | Similar to the inference process, you should warp the clothes in the training set under the paired setting, rename the result directory to `cloth-warp`, and put it under the train directory of the VITON-HD dataset (i.e., `VITON-HD-512/train`). 98 | 99 | ### Stage 2 100 | 101 | We use the pretrained [Paint-by-Example](https://drive.google.com/file/d/15QzaTWsvZonJcXsNv-ilMRCYaQLhzR_i/view) checkpoint for initialization.
Please put it under the `Synthesis_Stage/checkpoints` directory. 102 | 103 | To train the Synthesis Network, you first need to modify the `dataroot` in the `Synthesis_Stage/configs/vitonhd_512.yaml` file to your VITON-HD directory, and then run the following command: 104 | ```shell 105 | cd Synthesis_Stage 106 | 107 | python -u main.py --logdir models/d4vton_syn --pretrained_model checkpoints/model.ckpt \ 108 | --base configs/vitonhd_512.yaml --scale_lr False 109 | 110 | # or you can run the bash scripts 111 | bash scripts/train.sh 112 | ``` 113 | 114 | ## Results 115 | 116 |
117 | 118 |
119 | 120 | ## Acknowledgements 121 | Our code references the implementation of [DAFlow](https://github.com/OFA-Sys/DAFlow) and [DCI-VTON](https://github.com/bcmi/DCI-VTON-Virtual-Try-On). Thanks for their awesome work. 122 | 123 | ## Citation 124 | 125 | If you find our work useful for your research, please cite us: 126 | ``` 127 | @inproceedings{yang2025d4vton, 128 | title={D$^4$-VTON: Dynamic Semantics Disentangling for Differential Diffusion Based Virtual Try-On}, 129 | author={Yang, Zhaotong and Jiang, Zicheng and Li, Xinzhe and Zhou, Huiyu and Dong, Junyu and Zhang, Huaidong and Du, Yong}, 130 | booktitle={European Conference on Computer Vision}, 131 | pages={36--52}, 132 | year={2025}, 133 | organization={Springer} 134 | } 135 | ``` 136 | 137 | ## License 138 | 139 | All material is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made. 140 | -------------------------------------------------------------------------------- /Synthesis_Stage/configs/vitonhd_512.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentTryOnDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "inpaint" 11 | cond_stage_key: "image" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: true # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | u_cond_percent: 0.2 18 | scale_factor: 0.18215 19 | use_ema: False 20 | 21 | scheduler_config: # 3000 warmup steps 22 | target: ldm.lr_scheduler.LambdaLinearScheduler 23 | params: 24 | warm_up_steps: [ 3000 ] 25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 26 | f_start: [ 1.e-6 ] 27 | f_max: [ 1. ] 28 | f_min: [ 1.
] 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | image_size: 64 # unused 34 | in_channels: 9 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: [ 4, 2, 1 ] 38 | num_res_blocks: 2 39 | channel_mult: [ 1, 2, 4, 4 ] 40 | num_heads: 8 41 | use_spatial_transformer: True 42 | transformer_depth: 1 43 | context_dim: 768 44 | use_checkpoint: True 45 | legacy: False 46 | add_conv_in_front_of_unet: False 47 | 48 | first_stage_config: 49 | target: ldm.models.autoencoder.AutoencoderKL 50 | params: 51 | embed_dim: 4 52 | monitor: val/rec_loss 53 | ddconfig: 54 | double_z: true 55 | z_channels: 4 56 | resolution: 512 57 | in_channels: 3 58 | out_ch: 3 59 | ch: 128 60 | ch_mult: 61 | - 1 62 | - 2 63 | - 4 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: [] 67 | dropout: 0.0 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 73 | 74 | data: 75 | target: main.DataModuleFromConfig 76 | params: 77 | batch_size: 2 78 | wrap: False 79 | train: 80 | target: ldm.data.aligned_dataset_vitonhd.AlignedDataset 81 | params: 82 | mode: train 83 | dataroot: /data/user/VITON-HD-512/ 84 | resolution: 512 85 | validation: 86 | target: ldm.data.aligned_dataset_vitonhd.AlignedDataset 87 | params: 88 | mode: test 89 | dataroot: /data/user/VITON-HD-512/ 90 | resolution: 512 91 | test: 92 | target: ldm.data.aligned_dataset_vitonhd.AlignedDataset 93 | params: 94 | mode: test 95 | dataroot: /data/user/VITON-HD-512/ 96 | resolution: 512 97 | 98 | lightning: 99 | trainer: 100 | max_epochs: 60 101 | num_nodes: 1 102 | accelerator: 'ddp' 103 | gpus: "0,1,2,3" 104 | # accumulate_grad_batches: 2 105 | plugins: 'ddp_sharded' 106 | # precision: 16 -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/data/aligned_dataset_vitonhd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torchvision.transforms as transforms 4 | from .base_dataset import BaseDataset 5 | from PIL import Image, ImageDraw 6 | import torch 7 | import numpy as np 8 | import cv2 9 | import pycocotools.mask as maskUtils 10 | import math 11 | import json 12 | from torchvision import utils as vutils 13 | 14 | 15 | def mask2bbox(mask): 16 | up = np.max(np.where(mask)[0]) 17 | down = np.min(np.where(mask)[0]) 18 | left = np.min(np.where(mask)[1]) 19 | right = np.max(np.where(mask)[1]) 20 | center = ((up + down) // 2, (left + right) // 2) 21 | 22 | factor = random.random() * 0.1 + 0.1 23 | 24 | up = int(min(up * (1 + factor) - center[0] * factor + 1, mask.shape[0])) 25 | down = int(max(down * (1 + factor) - center[0] * factor, 0)) 26 | left = int(max(left * (1 + factor) - center[1] * factor, 0)) 27 | right = int(min(right * (1 + factor) - center[1] * factor + 1, mask.shape[1])) 28 | return (down, up, left, right) 29 | 30 | 31 | class AlignedDataset(BaseDataset): 32 | def __init__(self, dataroot, resolution=512, mode='train', unpaired=False): 33 | super(AlignedDataset, self).__init__() 34 | self.root = dataroot 35 | self.mode = mode 36 | self.unpaired = unpaired 37 | self.toTensor = transforms.ToTensor() 38 | self.rgb_transform = transforms.Compose([ 39 | transforms.ToTensor(), 40 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 41 | self.clip_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 42 | (0.26862954, 0.26130258, 0.27577711)) 43 | 44 | if 
resolution == 512: 45 | self.load_height = 512 46 | self.load_width = 384 47 | self.radius = 8 48 | else: 49 | self.load_height = 1024 50 | self.load_width = 768 51 | self.radius = 16 52 | self.crop_size = (self.load_height, self.load_width) 53 | 54 | if mode == 'train': 55 | pair_txt_path = os.path.join(self.root, 'train_pairs_1018.txt') 56 | elif self.unpaired: 57 | pair_txt_path = os.path.join(self.root, 'test_pairs_unpaired_1018.txt') 58 | else: 59 | pair_txt_path = os.path.join(self.root, 'test_pairs_paired_1018.txt') 60 | with open(pair_txt_path, 'r') as f: 61 | lines = f.readlines() 62 | 63 | im_names = [] 64 | c_names = [] 65 | for line in lines: 66 | im_name, c_name, c_type = line.strip().split() 67 | im_names.append(im_name) 68 | c_names.append(c_name) 69 | 70 | self.im_names = im_names 71 | self.c_names = dict() 72 | self.c_names['paired'] = im_names 73 | self.c_names['unpaired'] = c_names 74 | 75 | self.dataset_size = len(self.im_names) 76 | 77 | ############### get palm mask ################ 78 | def get_mask_from_kps(self, kps, img_h, img_w): 79 | rles = maskUtils.frPyObjects(kps, img_h, img_w) 80 | rle = maskUtils.merge(rles) 81 | mask = maskUtils.decode(rle)[..., np.newaxis].astype(np.float32) 82 | mask = mask * 255.0 83 | return mask 84 | 85 | def get_rectangle_mask(self, a, b, c, d, img_h, img_w): 86 | x1, y1 = a + (b - d) / 4, b + (c - a) / 4 87 | x2, y2 = a - (b - d) / 4, b - (c - a) / 4 88 | 89 | x3, y3 = c + (b - d) / 4, d + (c - a) / 4 90 | x4, y4 = c - (b - d) / 4, d - (c - a) / 4 91 | 92 | kps = [x1, y1, x2, y2] 93 | 94 | v0_x, v0_y = c - a, d - b 95 | v1_x, v1_y = x3 - x1, y3 - y1 96 | v2_x, v2_y = x4 - x1, y4 - y1 97 | 98 | cos1 = (v0_x * v1_x + v0_y * v1_y) / \ 99 | (math.sqrt(v0_x * v0_x + v0_y * v0_y) * math.sqrt(v1_x * v1_x + v1_y * v1_y)) 100 | cos2 = (v0_x * v2_x + v0_y * v2_y) / \ 101 | (math.sqrt(v0_x * v0_x + v0_y * v0_y) * math.sqrt(v2_x * v2_x + v2_y * v2_y)) 102 | 103 | if cos1 < cos2: 104 | kps.extend([x3, y3, x4, y4]) 105 | else: 106 | kps.extend([x4, y4, x3, y3]) 107 | 108 | kps = np.array(kps).reshape(1, -1).tolist() 109 | mask = self.get_mask_from_kps(kps, img_h=img_h, img_w=img_w) 110 | 111 | return mask 112 | 113 | def get_hand_mask(self, hand_keypoints, h, w): 114 | # shoulder, elbow, wrist 115 | s_x, s_y, s_c = hand_keypoints[0] 116 | e_x, e_y, e_c = hand_keypoints[1] 117 | w_x, w_y, w_c = hand_keypoints[2] 118 | 119 | up_mask = np.ones((h, w, 1), dtype=np.float32) 120 | bottom_mask = np.ones((h, w, 1), dtype=np.float32) 121 | if s_c > 0.1 and e_c > 0.1: 122 | up_mask = self.get_rectangle_mask(s_x, s_y, e_x, e_y, h, w) 123 | if self.load_height == 512: 124 | kernel = np.ones((50, 50), np.uint8) 125 | else: 126 | kernel = np.ones((100, 100), np.uint8) 127 | up_mask = cv2.dilate(up_mask, kernel, iterations=1) 128 | up_mask = (up_mask > 0).astype(np.float32)[..., np.newaxis] 129 | if e_c > 0.1 and w_c > 0.1: 130 | bottom_mask = self.get_rectangle_mask(e_x, e_y, w_x, w_y, h, w) 131 | if self.load_height == 512: 132 | kernel = np.ones((30, 30), np.uint8) 133 | else: 134 | kernel = np.ones((60, 60), np.uint8) 135 | bottom_mask = cv2.dilate(bottom_mask, kernel, iterations=1) 136 | bottom_mask = (bottom_mask > 0).astype(np.float32)[..., np.newaxis] 137 | 138 | return up_mask, bottom_mask 139 | 140 | def get_palm_mask(self, hand_mask, hand_up_mask, hand_bottom_mask): 141 | inter_up_mask = ((hand_mask + hand_up_mask) == 2).astype(np.float32) 142 | hand_mask = hand_mask - inter_up_mask 143 | inter_bottom_mask = ((hand_mask + hand_bottom_mask) 144 | == 
2).astype(np.float32) 145 | palm_mask = hand_mask - inter_bottom_mask 146 | 147 | return palm_mask 148 | 149 | def get_palm(self, parsing, keypoints): 150 | h, w = parsing.shape[0:2] 151 | 152 | left_hand_keypoints = keypoints[[5, 6, 7], :].copy() 153 | right_hand_keypoints = keypoints[[2, 3, 4], :].copy() 154 | 155 | left_hand_up_mask, left_hand_bottom_mask = self.get_hand_mask( 156 | left_hand_keypoints, h, w) 157 | right_hand_up_mask, right_hand_bottom_mask = self.get_hand_mask( 158 | right_hand_keypoints, h, w) 159 | 160 | # mask refined by parsing 161 | left_hand_mask = (parsing == 15).astype(np.float32) 162 | right_hand_mask = (parsing == 16).astype(np.float32) 163 | 164 | left_palm_mask = self.get_palm_mask( 165 | left_hand_mask, left_hand_up_mask, left_hand_bottom_mask) 166 | right_palm_mask = self.get_palm_mask( 167 | right_hand_mask, right_hand_up_mask, right_hand_bottom_mask) 168 | palm_mask = ((left_palm_mask + right_palm_mask) > 0).astype(np.uint8) 169 | 170 | return palm_mask 171 | 172 | def __getitem__(self, index): 173 | # C_type = self.C_types[index] 174 | if self.unpaired: 175 | key = 'unpaired' 176 | else: 177 | key = 'paired' 178 | 179 | # person image 180 | P_path = os.path.join(self.root, self.mode, 'image', self.im_names[index]) 181 | # P = transforms.Resize(self.crop_size, interpolation=2)(Image.open(P_path).convert('RGB')) 182 | P = Image.open(P_path).convert('RGB') 183 | P_tensor = self.rgb_transform(P) 184 | 185 | # person 2d pose 186 | pose_path = P_path.replace('image', 'openpose_json')[:-4] + '_keypoints.json' 187 | with open(pose_path, 'r') as f: 188 | datas = json.load(f) 189 | pose_data = np.array(datas['people'][0]['pose_keypoints_2d']).reshape(-1, 3) 190 | 191 | # person parsing 192 | parsing_path = P_path.replace('image', 'parse-bytedance')[:-4] + '.png' 193 | parsing = Image.open(parsing_path).convert('L') 194 | parsing_tensor = self.toTensor(parsing) * 255.0 195 | 196 | parsing_np = (parsing_tensor.numpy().transpose(1, 2, 0)[..., 0:1]).astype(np.uint8) 197 | palm_mask_np = self.get_palm(parsing_np, pose_data) 198 | palm_mask = torch.tensor(palm_mask_np.transpose(2, 0, 1)).float() 199 | 200 | # clothes 201 | C_path = os.path.join(self.root, self.mode, 'cloth', self.c_names[key][index]) 202 | # C = transforms.Resize(self.crop_size, interpolation=2)(Image.open(C_path).convert('RGB')) 203 | C = Image.open(C_path).convert('RGB') 204 | C_tensor = self.rgb_transform(C) 205 | 206 | CM_path = C_path.replace('cloth', 'cloth_mask-bytedance')[:-4] + '.png' 207 | # CM = transforms.Resize(self.crop_size, interpolation=0)(Image.open(CM_path).convert('L')) 208 | CM = Image.open(CM_path).convert('L') 209 | CM_tensor = self.toTensor(CM) 210 | 211 | mask = np.array([(parsing_np == index).astype(int) for index in [5, 6, 11, 15, 16, 21, 22, 24, 25]]) 212 | mask = np.sum(mask, axis=0) 213 | kernel_size = int(5 * (self.load_width / 256)) 214 | mask = cv2.dilate(mask.astype(np.uint8), kernel=np.ones((kernel_size, kernel_size)), iterations=3) 215 | mask = cv2.erode(mask.astype(np.uint8), kernel=np.ones((kernel_size, kernel_size)), iterations=1) 216 | mask = mask.astype(np.float32) 217 | inpaint_mask = 1 - self.toTensor(mask) 218 | 219 | W_path = P_path.replace('image', 'cloth-warp' if not self.unpaired else 'unpaired-cloth-warp') 220 | # W_tensor = transforms.Resize(self.crop_size, interpolation=2)(Image.open(W_path)) 221 | W_tensor = Image.open(W_path) 222 | W_tensor = self.rgb_transform(W_tensor) 223 | feat = W_tensor * (1 - inpaint_mask) + P_tensor * inpaint_mask 224 | 225 
| down, up, left, right = mask2bbox(CM_tensor[0].numpy()) 226 | ref_image = C_tensor[:, down:up, left:right] 227 | ref_image = (ref_image + 1.0) / 2.0 228 | ref_image = transforms.Resize((224, 224))(ref_image) 229 | ref_image = self.clip_normalize(ref_image) 230 | 231 | inpaint = feat * (1 - palm_mask) + P_tensor * palm_mask 232 | 233 | input_dict = { 234 | 'GT': P_tensor, 235 | "inpaint_image": inpaint, 236 | "inpaint_mask": inpaint_mask, 237 | 'warp_feat': feat, 238 | 'ref_imgs': ref_image, 239 | 'file_name': P_path.split('/')[-1], 240 | } 241 | # vutils.save_image(input_dict['GT'], P_path.replace('test/image', 'pair_gt_500'), normalize=True) 242 | 243 | return input_dict 244 | 245 | def __len__(self): 246 | return self.dataset_size 247 | 248 | def name(self): 249 | return 'AlignedDataset' -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import random 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseDataset' 13 | 14 | def initialize(self, opt): 15 | pass 16 | 17 | def get_params(opt, size): 18 | w, h = size 19 | new_h = h 20 | new_w = w 21 | if opt.resize_or_crop == 'resize_and_crop': 22 | new_h = new_w = opt.loadSize 23 | elif opt.resize_or_crop == 'scale_width_and_crop': 24 | new_w = opt.loadSize 25 | new_h = opt.loadSize * h // w 26 | 27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 29 | 30 | #flip = random.random() > 0.5 31 | flip = 0 32 | return {'crop_pos': (x, y), 'flip': flip} 33 | 34 | def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True): 35 | transform_list = [] 36 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 37 | osize = [256,192] 38 | transform_list.append(transforms.Scale(osize, method)) 39 | if 'crop' in opt.resize_or_crop: 40 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 41 | 42 | if opt.resize_or_crop == 'none': 43 | base = float(2 ** opt.n_downsample_global) 44 | if opt.netG == 'local': 45 | base *= (2 ** opt.n_local_enhancers) 46 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 47 | 48 | if opt.mode == 'train' and not opt.no_flip: 49 | 
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 50 | 51 | transform_list += [transforms.ToTensor()] 52 | 53 | if normalize: 54 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 55 | (0.5, 0.5, 0.5))] 56 | return transforms.Compose(transform_list) 57 | 58 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 59 | transform_list = [] 60 | if 'resize' in opt.resize_or_crop: 61 | osize = [opt.loadSize, opt.loadSize] 62 | transform_list.append(transforms.Scale(osize, method)) 63 | elif 'scale_width' in opt.resize_or_crop: 64 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 65 | osize = [256,192] 66 | transform_list.append(transforms.Scale(osize, method)) 67 | if 'crop' in opt.resize_or_crop: 68 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 69 | 70 | if opt.resize_or_crop == 'none': 71 | base = float(2 ** opt.n_downsample_global) 72 | if opt.netG == 'local': 73 | base *= (2 ** opt.n_local_enhancers) 74 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 75 | 76 | if opt.mode == 'train' and not opt.no_flip: 77 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 78 | 79 | transform_list += [transforms.ToTensor()] 80 | 81 | if normalize: 82 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 83 | (0.5, 0.5, 0.5))] 84 | return transforms.Compose(transform_list) 85 | 86 | def normalize(): 87 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 88 | 89 | def __make_power_2(img, base, method=Image.BICUBIC): 90 | ow, oh = img.size 91 | h = int(round(oh / base) * base) 92 | w = int(round(ow / base) * base) 93 | if (h == oh) and (w == ow): 94 | return img 95 | return img.resize((w, h), method) 96 | 97 | def __scale_width(img, target_width, method=Image.BICUBIC): 98 | ow, oh = img.size 99 | if (ow == target_width): 100 | return img 101 | w = target_width 102 | h = int(target_width * oh / ow) 103 | return img.resize((w, h), method) 104 | 105 | def __crop(img, pos, size): 106 | ow, oh = img.size 107 | x1, y1 = pos 108 | tw = th = size 109 | if (ow > tw or oh > th): 110 | return img.crop((x1, y1, x1 + tw, y1 + th)) 111 | return img 112 | 113 | def __flip(img, flip): 114 | if flip: 115 | return img.transpose(Image.FLIP_LEFT_RIGHT) 116 | return img 117 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/__pycache__/autoencoder.cpython-38.pyc 
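A minimal usage sketch for `ldm/lr_scheduler.py` above (not part of the repository): it shows the learning-rate multiplier that `LambdaLinearScheduler` produces under the `scheduler_config` values in `configs/vitonhd_512.yaml`, i.e. a linear warm-up from `f_start` to `f_max` over the first 3000 steps, then a constant 1.0 because `f_min == f_max`. It assumes the snippet is run from `Synthesis_Stage/` so that the `ldm` package is importable.

```python
# Sketch only: evaluate the LR multiplier of LambdaLinearScheduler with the
# values used in configs/vitonhd_512.yaml (assumes `ldm` is on the Python path).
from ldm.lr_scheduler import LambdaLinearScheduler

scheduler = LambdaLinearScheduler(
    warm_up_steps=[3000],
    f_min=[1.0],
    f_max=[1.0],
    f_start=[1.e-6],
    cycle_lengths=[10000000000000],  # one effectively endless cycle, as in the config
)

for step in (0, 1500, 3000, 100000):
    # Ramps roughly linearly from 1e-6 to 1.0 over the first 3000 steps and stays
    # at 1.0 afterwards; training scales base_learning_rate by this factor.
    print(step, scheduler(step))
```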
-------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/ddpm_nodelta.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/ddpm_nodelta.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/ddpm_norec.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/ddpm_norec.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/dtdpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/dtdpm.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from 
torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = 
__models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = 
self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_clip import CLIPEncoder 2 | 3 | __all__ = ['CLIPEncoder'] 4 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/__pycache__/base_clip.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/__pycache__/base_clip.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/base_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .clip import clip 4 | # from clip import clip 5 | import torchvision 6 | from PIL import Image 7 | 8 | model_name = "ViT-B/16" 9 | # model_name = "ViT-B/32" 10 | 11 | 12 | def load_clip_to_cpu(): 13 | url = clip._MODELS[model_name] 14 | model_path = clip._download(url) 15 | 16 | try: 17 | # loading JIT archive 18 | model = torch.jit.load(model_path, map_location="cpu").eval() 19 | state_dict = None 20 | 21 | except RuntimeError: 22 | state_dict = torch.load(model_path, map_location="cpu") 23 | 24 | model = clip.build_model(state_dict or model.state_dict()) 25 | 26 | return model 27 | 28 | 29 | class CLIPEncoder(nn.Module): 30 | def __init__(self, ref_img=None): 31 | super().__init__() 32 | self.clip_model = load_clip_to_cpu() 33 | self.clip_model.requires_grad = True 34 | self.preprocess = torchvision.transforms.Normalize( 35 | (0.48145466*2-1, 0.4578275*2-1, 0.40821073*2-1), 36 | (0.26862954*2, 0.26130258*2, 0.27577711*2) 37 | ) 38 | self.ref = ref_img 39 | 40 | def get_gram_matrix_residual(self, im1): 41 | im1 = torch.nn.functional.interpolate(im1, size=(224, 224), mode='bicubic') 42 | im1 = self.preprocess(im1) 43 | 44 | f1, feats1 = self.clip_model.encode_image_with_features(im1) 45 | f2, feats2 = self.clip_model.encode_image_with_features(self.ref) 46 | 47 | feat1 = feats1[2][1:, 0, :].to(torch.float16) 48 | feat2 = feats2[2][1:, 0, :].to(torch.float16) 49 | gram1 = torch.mm(feat1.t(), feat1) 50 | gram2 = torch.mm(feat2.t(), feat2) 51 | return gram1 - gram2 52 | 53 | 54 | 55 | if __name__ == "__main__": 56 | m = CLIPEncoder().cuda() 57 | im1 = torch.randn((1, 3, 224, 224)).cuda() 58 | im2 = torch.randn((1, 3, 224, 224)).cuda() 59 | m.get_gram_matrix_residual(im1, im2) 60 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .clip import * -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/.ipynb_checkpoints/clip-checkpoint.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | # if torch.__version__.split(".") < ["1", "7", "1"]: 23 | # warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": 
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | 
Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, :len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/.ipynb_checkpoints/simple_tokenizer-checkpoint.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
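(Illustrative aside, not part of this file.) The two helpers documented just above are small but load-bearing: bytes_to_unicode gives a reversible 256-entry map from raw bytes to printable characters, and get_pairs lists the adjacent symbol pairs that the BPE merge loop works through. A minimal sketch, assuming the definitions in this file:

# Sketch only; assumes bytes_to_unicode and get_pairs as defined in this file.
b2u = bytes_to_unicode()
assert len(b2u) == 256                         # every byte value gets a printable stand-in
assert len(set(b2u.values())) == 256           # no collisions, so the map is invertible
print(get_pairs(('h', 'e', 'l', 'l', 'o')))    # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')} (set order varies)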
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/clip.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/clip.cpython-36.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/simple_tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/simple_tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/models/diffusion/clip/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | # if torch.__version__.split(".") < ["1", "7", "1"]: 23 | # warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 
0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 
| patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, :len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/models/diffusion/clip/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
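Putting the pieces of clip.py together, load() and tokenize() above are typically driven as in the sketch below. This is an illustrative usage only, not code from the repo: the image path and prompts are placeholders, the weights are fetched by _download on first use, and device should be "cpu" when no GPU is available.

# Hypothetical usage of load()/tokenize() defined above.
import torch
from PIL import Image

model, preprocess = load("ViT-B/32", device="cuda", jit=False)
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to("cuda")   # [1, 3, 224, 224]
text = tokenize(["a dress", "a pair of trousers"]).to("cuda")         # [2, 77] token ids

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    similarity = image_features @ text_features.t()                   # unnormalised scores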
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/__pycache__/ema.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/__pycache__/vgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/__pycache__/vgg.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/__pycache__/x_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/__pycache__/x_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
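A quick illustration of what zero_module (declared just above) is for, assuming the definitions in this file: zero-initialising a block's final projection makes the block contribute nothing at the start of training, so a residual connection around it begins as the identity.

# Sketch only; a zeroed projection means `x + proj(x)` starts out as plain `x`.
import torch
import torch.nn as nn

proj = zero_module(nn.Conv2d(64, 64, kernel_size=1))
x = torch.randn(2, 64, 16, 16)
assert torch.all(proj(x) == 0)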
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/openaimodel2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/openaimodel2.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
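For the SpatialTransformer listed further above, the docstring's project / reshape to (b, t, d) / attend / reshape-back description corresponds to the rough shape trace below. The sizes are made up for illustration and the snippet assumes this repo's modules import normally.

# Hypothetical shape trace for SpatialTransformer with made-up dimensions.
import torch

st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
x = torch.randn(2, 320, 32, 32)      # image-like features: b, c, h, w
ctx = torch.randn(2, 77, 768)        # e.g. CLIP/text tokens: b, t, context_dim
out = st(x, context=ctx)             # b,c,h,w -> b,(h*w),c -> attention blocks -> back
assert out.shape == x.shape          # residual output keeps the spatial shape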
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | # steps_out = ddim_timesteps 59 | if verbose: 60 | print(f'Selected timesteps for ddim sampler: {steps_out}') 61 | return steps_out 62 | 63 | 64 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 65 | # select alphas for computing the variance schedule 66 | alphas = alphacums[ddim_timesteps] 67 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 68 | 69 | # according the the formula provided in https://arxiv.org/abs/2010.02502 70 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 71 | if verbose: 72 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 73 | print(f'For the chosen value of eta, which is {eta}, ' 74 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 75 | return sigmas, alphas, alphas_prev 76 | 77 | 78 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 79 | """ 80 | Create a beta schedule that discretizes the given alpha_t_bar function, 81 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 82 | :param num_diffusion_timesteps: the number of betas to produce. 83 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 84 | produces the cumulative product of (1-beta) up to that 85 | part of the diffusion process. 86 | :param max_beta: the maximum beta to use; use values lower than 1 to 87 | prevent singularities. 
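As a concrete instance of the helper documented here (illustrative only, not code from the repo): passing the squared-cosine alpha_bar from the improved-DDPM paper reproduces the familiar cosine beta schedule.

# Hypothetical example: squared-cosine alpha_bar, clipped at max_beta.
import math
import numpy as np

cosine_alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
betas = betas_for_alpha_bar(1000, cosine_alpha_bar, max_beta=0.999)
assert betas.shape == (1000,) and np.all(betas <= 0.999)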
88 | """ 89 | betas = [] 90 | for i in range(num_diffusion_timesteps): 91 | t1 = i / num_diffusion_timesteps 92 | t2 = (i + 1) / num_diffusion_timesteps 93 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 94 | return np.array(betas) 95 | 96 | 97 | def extract_into_tensor(a, t, x_shape): 98 | b, *_ = t.shape 99 | out = a.gather(-1, t) 100 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 101 | 102 | 103 | def checkpoint(func, inputs, params, flag): 104 | """ 105 | Evaluate a function without caching intermediate activations, allowing for 106 | reduced memory at the expense of extra compute in the backward pass. 107 | :param func: the function to evaluate. 108 | :param inputs: the argument sequence to pass to `func`. 109 | :param params: a sequence of parameters `func` depends on but does not 110 | explicitly take as arguments. 111 | :param flag: if False, disable gradient checkpointing. 112 | """ 113 | if flag: 114 | args = tuple(inputs) + tuple(params) 115 | return CheckpointFunction.apply(func, len(inputs), *args) 116 | else: 117 | return func(*inputs) 118 | 119 | 120 | class CheckpointFunction(torch.autograd.Function): 121 | @staticmethod 122 | def forward(ctx, run_function, length, *args): 123 | ctx.run_function = run_function 124 | ctx.input_tensors = list(args[:length]) 125 | ctx.input_params = list(args[length:]) 126 | 127 | with torch.no_grad(): 128 | output_tensors = ctx.run_function(*ctx.input_tensors) 129 | return output_tensors 130 | 131 | @staticmethod 132 | def backward(ctx, *output_grads): 133 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 134 | with torch.enable_grad(): 135 | # Fixes a bug where the first op in run_function modifies the 136 | # Tensor storage in place, which is not allowed for detach()'d 137 | # Tensors. 138 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 139 | output_tensors = ctx.run_function(*shallow_copies) 140 | input_grads = torch.autograd.grad( 141 | output_tensors, 142 | ctx.input_tensors + ctx.input_params, 143 | output_grads, 144 | allow_unused=True, 145 | ) 146 | del ctx.input_tensors 147 | del ctx.input_params 148 | del output_tensors 149 | return (None, None) + input_grads 150 | 151 | 152 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 153 | """ 154 | Create sinusoidal timestep embeddings. 155 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 156 | These may be fractional. 157 | :param dim: the dimension of the output. 158 | :param max_period: controls the minimum frequency of the embeddings. 159 | :return: an [N x dim] Tensor of positional embeddings. 160 | """ 161 | if not repeat_only: 162 | half = dim // 2 163 | freqs = torch.exp( 164 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 165 | ).to(device=timesteps.device) 166 | args = timesteps[:, None].float() * freqs[None] 167 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 168 | if dim % 2: 169 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 170 | else: 171 | embedding = repeat(timesteps, 'b -> b d', d=dim) 172 | return embedding 173 | 174 | 175 | def zero_module(module): 176 | """ 177 | Zero out the parameters of a module and return it. 178 | """ 179 | for p in module.parameters(): 180 | p.detach().zero_() 181 | return module 182 | 183 | 184 | def scale_module(module, scale): 185 | """ 186 | Scale the parameters of a module and return it. 
187 | """ 188 | for p in module.parameters(): 189 | p.detach().mul_(scale) 190 | return module 191 | 192 | 193 | def mean_flat(tensor): 194 | """ 195 | Take the mean over all non-batch dimensions. 196 | """ 197 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 198 | 199 | 200 | def normalization(channels): 201 | """ 202 | Make a standard normalization layer. 203 | :param channels: number of input channels. 204 | :return: an nn.Module for normalization. 205 | """ 206 | return GroupNorm32(32, channels) 207 | 208 | 209 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 210 | class SiLU(nn.Module): 211 | def forward(self, x): 212 | return x * torch.sigmoid(x) 213 | 214 | 215 | class GroupNorm32(nn.GroupNorm): 216 | def forward(self, x): 217 | return super().forward(x.float()).type(x.dtype) 218 | 219 | def conv_nd(dims, *args, **kwargs): 220 | """ 221 | Create a 1D, 2D, or 3D convolution module. 222 | """ 223 | if dims == 1: 224 | return nn.Conv1d(*args, **kwargs) 225 | elif dims == 2: 226 | return nn.Conv2d(*args, **kwargs) 227 | elif dims == 3: 228 | return nn.Conv3d(*args, **kwargs) 229 | raise ValueError(f"unsupported dimensions: {dims}") 230 | 231 | 232 | def linear(*args, **kwargs): 233 | """ 234 | Create a linear module. 235 | """ 236 | return nn.Linear(*args, **kwargs) 237 | 238 | 239 | def nonlinearity(x): 240 | # swish 241 | return x * torch.sigmoid(x) 242 | 243 | 244 | def avg_pool_nd(dims, *args, **kwargs): 245 | """ 246 | Create a 1D, 2D, or 3D average pooling module. 247 | """ 248 | if dims == 1: 249 | return nn.AvgPool1d(*args, **kwargs) 250 | elif dims == 2: 251 | return nn.AvgPool2d(*args, **kwargs) 252 | elif dims == 3: 253 | return nn.AvgPool3d(*args, **kwargs) 254 | raise ValueError(f"unsupported dimensions: {dims}") 255 | 256 | 257 | class HybridConditioner(nn.Module): 258 | 259 | def __init__(self, c_concat_config, c_crossattn_config): 260 | super().__init__() 261 | self.concat_conditioner = instantiate_from_config(c_concat_config) 262 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 263 | 264 | def forward(self, c_concat, c_crossattn): 265 | c_concat = self.concat_conditioner(c_concat) 266 | c_crossattn = self.crossattn_conditioner(c_crossattn) 267 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 268 | 269 | 270 | def noise_like(shape, device, repeat=False): 271 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 272 | noise = lambda: torch.randn(shape, device=device) 273 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
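# (Illustrative aside, not part of normal_kl.)  DiagonalGaussianDistribution above is the kind of
# posterior object the KL-regularised autoencoder works with: `parameters` packs mean and
# log-variance along the channel axis, e.g.
#     posterior = DiagonalGaussianDistribution(torch.randn(1, 8, 64, 64))
#     z  = posterior.sample()   # [1, 4, 64, 64] latent draw
#     kl = posterior.kl()       # [1], KL against a standard normal, summed over C, H, W
# normal_kl here is the element-wise counterpart for comparing two explicit Gaussians.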
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
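A typical store / copy_to / restore round trip with LitEma looks roughly like the sketch below; the nn.Linear is a stand-in for whatever network is actually being trained.

# Hypothetical EMA round trip on a tiny placeholder model.
import torch.nn as nn

model = nn.Linear(4, 4)                   # placeholder for the real network
ema = LitEma(model, decay=0.999)

ema(model)                                # call after every optimiser step to update the shadow weights

ema.store(model.parameters())             # stash the raw weights
ema.copy_to(model)                        # evaluate with the smoothed (EMA) weights
ema.restore(model.parameters())           # then put the raw weights back and keep training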
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/encoders/__pycache__/xf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/Synthesis_Stage/ldm/modules/encoders/__pycache__/xf.cpython-38.pyc -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel, CLIPVisionModel, CLIPModel 7 | import kornia 8 | from ldm.modules.x_transformer import Encoder, \ 9 | TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? 
--> test 10 | from .xf import LayerNorm, Transformer 11 | import math 12 | 13 | 14 | class AbstractEncoder(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def encode(self, *args, **kwargs): 19 | raise NotImplementedError 20 | 21 | 22 | class ClassEmbedder(nn.Module): 23 | def __init__(self, embed_dim, n_classes=1000, key='class'): 24 | super().__init__() 25 | self.key = key 26 | self.embedding = nn.Embedding(n_classes, embed_dim) 27 | 28 | def forward(self, batch, key=None): 29 | if key is None: 30 | key = self.key 31 | # this is for use in crossattn 32 | c = batch[key][:, None] 33 | c = self.embedding(c) 34 | return c 35 | 36 | 37 | class TransformerEmbedder(AbstractEncoder): 38 | """Some transformer encoder layers""" 39 | 40 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 41 | super().__init__() 42 | self.device = device 43 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 44 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 45 | 46 | def forward(self, tokens): 47 | tokens = tokens.to(self.device) # meh 48 | z = self.transformer(tokens, return_embeddings=True) 49 | return z 50 | 51 | def encode(self, x): 52 | return self(x) 53 | 54 | 55 | class BERTTokenizer(AbstractEncoder): 56 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 57 | 58 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 59 | super().__init__() 60 | from transformers import BertTokenizerFast # TODO: add to requirements 61 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 62 | self.device = device 63 | self.vq_interface = vq_interface 64 | self.max_length = max_length 65 | 66 | def forward(self, text): 67 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 68 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 69 | tokens = batch_encoding["input_ids"].to(self.device) 70 | return tokens 71 | 72 | @torch.no_grad() 73 | def encode(self, text): 74 | tokens = self(text) 75 | if not self.vq_interface: 76 | return tokens 77 | return None, None, [None, None, tokens] 78 | 79 | def decode(self, text): 80 | return text 81 | 82 | 83 | class BERTEmbedder(AbstractEncoder): 84 | """Uses the BERT tokenizer model and add some transformer encoder layers""" 85 | 86 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 87 | device="cuda", use_tokenizer=True, embedding_dropout=0.0): 88 | super().__init__() 89 | self.use_tknz_fn = use_tokenizer 90 | if self.use_tknz_fn: 91 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 92 | self.device = device 93 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 94 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 95 | emb_dropout=embedding_dropout) 96 | 97 | def forward(self, text): 98 | if self.use_tknz_fn: 99 | tokens = self.tknz_fn(text) # .to(self.device) 100 | else: 101 | tokens = text 102 | z = self.transformer(tokens, return_embeddings=True) 103 | return z 104 | 105 | def encode(self, text): 106 | # output of length 77 107 | return self(text) 108 | 109 | 110 | class SpatialRescaler(nn.Module): 111 | def __init__(self, 112 | n_stages=1, 113 | method='bilinear', 114 | multiplier=0.5, 115 | in_channels=3, 116 | out_channels=None, 117 | bias=False): 118 | super().__init__() 119 | self.n_stages = n_stages 120 | assert self.n_stages >= 0 121 | assert method in ['nearest', 
'linear', 'bilinear', 'trilinear', 'bicubic', 'area'] 122 | self.multiplier = multiplier 123 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 124 | self.remap_output = out_channels is not None 125 | if self.remap_output: 126 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 127 | self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias) 128 | 129 | def forward(self, x): 130 | for stage in range(self.n_stages): 131 | x = self.interpolator(x, scale_factor=self.multiplier) 132 | 133 | if self.remap_output: 134 | x = self.channel_mapper(x) 135 | return x 136 | 137 | def encode(self, x): 138 | return self(x) 139 | 140 | 141 | class FrozenCLIPImageEmbedder(AbstractEncoder): 142 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 143 | 144 | def __init__(self, version="openai/clip-vit-large-patch14"): 145 | super().__init__() 146 | self.transformer = CLIPVisionModel.from_pretrained(version) 147 | self.final_ln = LayerNorm(1024) 148 | self.mapper = Transformer( 149 | 1, 150 | 1024, 151 | 5, 152 | 1, 153 | ) 154 | 155 | self.freeze() 156 | 157 | def freeze(self): 158 | self.transformer = self.transformer.eval() 159 | for param in self.parameters(): 160 | param.requires_grad = False 161 | for param in self.mapper.parameters(): 162 | param.requires_grad = True 163 | for param in self.final_ln.parameters(): 164 | param.requires_grad = True 165 | 166 | def forward(self, image): 167 | outputs = self.transformer(pixel_values=image) 168 | z = outputs.pooler_output 169 | z = z.unsqueeze(1) 170 | z = self.mapper(z) 171 | z = self.final_ln(z) 172 | return z 173 | 174 | def encode(self, image): 175 | return self(image) 176 | 177 | 178 | if __name__ == "__main__": 179 | from ldm.util import count_params 180 | 181 | model = FrozenCLIPEmbedder() 182 | count_params(model, verbose=True) 183 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/encoders/xf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer implementation adapted from CLIP ViT: 3 | https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | def convert_module_to_f16(l): 13 | """ 14 | Convert primitive modules to float16. 15 | """ 16 | if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 17 | l.weight.data = l.weight.data.half() 18 | if l.bias is not None: 19 | l.bias.data = l.bias.data.half() 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """ 24 | Implementation that supports fp16 inputs but fp32 gains/biases. 
25 | """ 26 | 27 | def forward(self, x: th.Tensor): 28 | return super().forward(x.float()).to(x.dtype) 29 | 30 | 31 | class MultiheadAttention(nn.Module): 32 | def __init__(self, n_ctx, width, heads): 33 | super().__init__() 34 | self.n_ctx = n_ctx 35 | self.width = width 36 | self.heads = heads 37 | self.c_qkv = nn.Linear(width, width * 3) 38 | self.c_proj = nn.Linear(width, width) 39 | self.attention = QKVMultiheadAttention(heads, n_ctx) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = self.attention(x) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class MLP(nn.Module): 49 | def __init__(self, width): 50 | super().__init__() 51 | self.width = width 52 | self.c_fc = nn.Linear(width, width * 4) 53 | self.c_proj = nn.Linear(width * 4, width) 54 | self.gelu = nn.GELU() 55 | 56 | def forward(self, x): 57 | return self.c_proj(self.gelu(self.c_fc(x))) 58 | 59 | 60 | class QKVMultiheadAttention(nn.Module): 61 | def __init__(self, n_heads: int, n_ctx: int): 62 | super().__init__() 63 | self.n_heads = n_heads 64 | self.n_ctx = n_ctx 65 | 66 | def forward(self, qkv): 67 | bs, n_ctx, width = qkv.shape 68 | attn_ch = width // self.n_heads // 3 69 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 70 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 71 | q, k, v = th.split(qkv, attn_ch, dim=-1) 72 | weight = th.einsum( 73 | "bthc,bshc->bhts", q * scale, k * scale 74 | ) # More stable with f16 than dividing afterwards 75 | wdtype = weight.dtype 76 | weight = th.softmax(weight.float(), dim=-1).type(wdtype) 77 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 78 | 79 | 80 | class ResidualAttentionBlock(nn.Module): 81 | def __init__( 82 | self, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | ): 87 | super().__init__() 88 | 89 | self.attn = MultiheadAttention( 90 | n_ctx, 91 | width, 92 | heads, 93 | ) 94 | self.ln_1 = LayerNorm(width) 95 | self.mlp = MLP(width) 96 | self.ln_2 = LayerNorm(width) 97 | 98 | def forward(self, x: th.Tensor): 99 | x = x + self.attn(self.ln_1(x)) 100 | x = x + self.mlp(self.ln_2(x)) 101 | return x 102 | 103 | 104 | class Transformer(nn.Module): 105 | def __init__( 106 | self, 107 | n_ctx: int, 108 | width: int, 109 | layers: int, 110 | heads: int, 111 | ): 112 | super().__init__() 113 | self.n_ctx = n_ctx 114 | self.width = width 115 | self.layers = layers 116 | self.resblocks = nn.ModuleList( 117 | [ 118 | ResidualAttentionBlock( 119 | n_ctx, 120 | width, 121 | heads, 122 | ) 123 | for _ in range(layers) 124 | ] 125 | ) 126 | 127 | def forward(self, x: th.Tensor): 128 | for block in self.resblocks: 129 | x = block(x) 130 | return x 131 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /Synthesis_Stage/ldm/modules/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def vgg_preprocess(tensor, vgg_normal_correct=False): 7 | if vgg_normal_correct: 8 | tensor = (tensor + 1) / 2 9 | # input is RGB tensor which ranges in [0,1] 10 | # output is BGR tensor which ranges in [0,255] 11 | tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1) 12 | # tensor_bgr = tensor[:, [2, 1, 0], ...] 
13 | tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view( 14 | 1, 3, 1, 1 15 | ) 16 | tensor_rst = tensor_bgr_ml * 255 17 | return tensor_rst 18 | 19 | 20 | class VGG19_feature_color_torchversion(nn.Module): 21 | """ 22 | NOTE: there is no need to pre-process the input 23 | input tensor should range in [0,1] 24 | """ 25 | 26 | def __init__(self, pool="max", vgg_normal_correct=True, ic=3): 27 | super(VGG19_feature_color_torchversion, self).__init__() 28 | self.vgg_normal_correct = vgg_normal_correct 29 | 30 | self.conv1_1 = nn.Conv2d(ic, 64, kernel_size=3, padding=1) 31 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 32 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 33 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 34 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 35 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 36 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 37 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 38 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 39 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 41 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 42 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 44 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 45 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 46 | if pool == "max": 47 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 48 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 49 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 50 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 51 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 52 | elif pool == "avg": 53 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 54 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 55 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 56 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 57 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 58 | 59 | def forward(self, x, out_keys, preprocess=True): 60 | """ 61 | NOTE: input tensor should range in [0,1] 62 | """ 63 | out = {} 64 | if preprocess: 65 | x = vgg_preprocess(x, vgg_normal_correct=self.vgg_normal_correct) 66 | out["r11"] = F.relu(self.conv1_1(x)) 67 | out["r12"] = F.relu(self.conv1_2(out["r11"])) 68 | out["p1"] = self.pool1(out["r12"]) 69 | out["r21"] = F.relu(self.conv2_1(out["p1"])) 70 | out["r22"] = F.relu(self.conv2_2(out["r21"])) 71 | out["p2"] = self.pool2(out["r22"]) 72 | out["r31"] = F.relu(self.conv3_1(out["p2"])) 73 | out["r32"] = F.relu(self.conv3_2(out["r31"])) 74 | out["r33"] = F.relu(self.conv3_3(out["r32"])) 75 | out["r34"] = F.relu(self.conv3_4(out["r33"])) 76 | out["p3"] = self.pool3(out["r34"]) 77 | out["r41"] = F.relu(self.conv4_1(out["p3"])) 78 | out["r42"] = F.relu(self.conv4_2(out["r41"])) 79 | out["r43"] = F.relu(self.conv4_3(out["r42"])) 80 | out["r44"] = F.relu(self.conv4_4(out["r43"])) 81 | out["p4"] = self.pool4(out["r44"]) 82 | out["r51"] = F.relu(self.conv5_1(out["p4"])) 83 | out["r52"] = F.relu(self.conv5_2(out["r51"])) 84 | out["r53"] = F.relu(self.conv5_3(out["r52"])) 85 | out["r54"] = F.relu(self.conv5_4(out["r53"])) 86 | out["p5"] = self.pool5(out["r54"]) 87 | return [out[key] for key in out_keys] 88 | 
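For reference, a minimal usage sketch of the VGG19 feature extractor defined above. This snippet is illustrative and not part of the repository: it assumes Synthesis_Stage is on PYTHONPATH and that a checkpoint matching the conv1_1 ... conv5_4 layer names is available; the checkpoint path is an assumption.

import torch
from ldm.modules.vgg import VGG19_feature_color_torchversion

# Hypothetical usage sketch (not part of the repository).
vgg = VGG19_feature_color_torchversion(pool='max', vgg_normal_correct=True)
vgg.load_state_dict(torch.load('models/vgg19_conv.pth'), strict=False)  # path is an assumption
vgg.eval()

x = torch.rand(1, 3, 512, 512) * 2 - 1  # input in [-1, 1]; vgg_normal_correct=True remaps it to [0, 1] internally
keys = ['r12', 'r22', 'r32', 'r42', 'r52']
with torch.no_grad():
    feats = vgg(x, keys, preprocess=True)  # returns the activations named in out_keys, in order
for k, f in zip(keys, feats):
    print(k, tuple(f.shape))

The returned list follows the order of out_keys, so downstream perceptual losses can weight each scale individually.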
-------------------------------------------------------------------------------- /Synthesis_Stage/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /Synthesis_Stage/scripts/test.sh: -------------------------------------------------------------------------------- 1 | python test.py --gpu_id 0 --ddim_steps 100 \ 2 | --outdir results/d4vton_unpaired_syn --config configs/vitonhd_512.yaml \ 3 | --dataroot \ 4 | --ckpt checkpoints/vitonhd_synthesis.ckpt --delta_step 89 \ 5 | --n_samples 12 --seed 23 --scale 1 --H 512 --unpaired -------------------------------------------------------------------------------- /Synthesis_Stage/scripts/train.sh: -------------------------------------------------------------------------------- 1 | python -u main.py --logdir models/d4vton_syn --pretrained_model checkpoints/model.ckpt \ 2 | --base configs/vitonhd_512.yaml --scale_lr False -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/assets/pipeline.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jerome-Young/D4-VTON/c01a3b278ef2d334c217819f67f0bb567dc9de0d/assets/results.png -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: d4-vton 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=5.1 8 | - blas=1.0 9 | 
- brotli-python=1.0.9 10 | - bzip2=1.0.8 11 | - ca-certificates=2023.08.22 12 | - certifi=2023.11.17 13 | - cffi=1.15.1 14 | - charset-normalizer=2.0.4 15 | - cryptography=41.0.3 16 | - cudatoolkit=11.3.1 17 | - ffmpeg=4.3 18 | - freetype=2.12.1 19 | - giflib=5.2.1 20 | - gmp=6.2.1 21 | - gnutls=3.6.15 22 | - idna=3.4 23 | - intel-openmp=2021.4.0 24 | - jpeg=9e 25 | - lame=3.100 26 | - lcms2=2.12 27 | - ld_impl_linux-64=2.38 28 | - lerc=3.0 29 | - libdeflate=1.17 30 | - libffi=3.3 31 | - libgcc-ng=11.2.0 32 | - libgfortran-ng=11.2.0 33 | - libgfortran5=11.2.0 34 | - libgomp=11.2.0 35 | - libiconv=1.16 36 | - libidn2=2.3.4 37 | - libpng=1.6.39 38 | - libstdcxx-ng=11.2.0 39 | - libtasn1=4.19.0 40 | - libtiff=4.5.1 41 | - libunistring=0.9.10 42 | - libuv=1.44.2 43 | - libwebp=1.3.2 44 | - libwebp-base=1.3.2 45 | - lz4-c=1.9.4 46 | - mkl=2021.4.0 47 | - mkl-service=2.4.0 48 | - mkl_fft=1.3.1 49 | - mkl_random=1.2.2 50 | - ncurses=6.4 51 | - nettle=3.7.3 52 | - openh264=2.1.1 53 | - openjpeg=2.4.0 54 | - openssl=1.1.1w 55 | - pip=20.3.3 56 | - pycparser=2.21 57 | - pyopenssl=23.2.0 58 | - pysocks=1.7.1 59 | - python=3.8.5 60 | - pytorch=1.11.0 61 | - pytorch-mutex=1.0 62 | - readline=8.2 63 | - requests=2.31.0 64 | - setuptools=68.0.0 65 | - six=1.16.0 66 | - sqlite=3.41.2 67 | - tk=8.6.12 68 | - torchvision=0.12.0 69 | - typing_extensions=4.7.1 70 | - urllib3=1.26.18 71 | - wheel=0.41.2 72 | - xz=5.4.2 73 | - zlib=1.2.13 74 | - zstd=1.5.5 75 | - pip: 76 | - absl-py==2.0.0 77 | - aiohttp==3.9.0 78 | - aiosignal==1.3.1 79 | - albumentations==0.4.3 80 | - altair==5.1.2 81 | - antlr4-python3-runtime==4.8 82 | - asttokens==2.4.1 83 | - async-timeout==4.0.3 84 | - attrs==23.1.0 85 | - backcall==0.2.0 86 | - backports-zoneinfo==0.2.1 87 | - beautifulsoup4==4.12.3 88 | - bezier==2023.7.28 89 | - bleach==6.1.0 90 | - blessed==1.20.0 91 | - blinker==1.7.0 92 | - cachetools==5.3.2 93 | - click==8.1.7 94 | - contourpy==1.1.1 95 | - cycler==0.12.1 96 | - decorator==5.1.1 97 | - defusedxml==0.7.1 98 | - diffusers==0.23.1 99 | - docker-pycreds==0.4.0 100 | - docopt==0.6.2 101 | - einops==0.3.0 102 | - executing==2.0.1 103 | - fairscale==0.4.13 104 | - fastjsonschema==2.20.0 105 | - filelock==3.13.1 106 | - fonttools==4.45.0 107 | - frozenlist==1.4.0 108 | - fsspec==2023.10.0 109 | - ftfy==6.1.3 110 | - future==0.18.3 111 | - gitdb==4.0.11 112 | - gitpython==3.1.40 113 | - google-auth==2.23.4 114 | - google-auth-oauthlib==1.0.0 115 | - grpcio==1.59.3 116 | - huggingface-hub==0.19.4 117 | - imageio==2.9.0 118 | - imageio-ffmpeg==0.4.2 119 | - imgaug==0.2.6 120 | - importlib-metadata==6.8.0 121 | - importlib-resources==6.1.1 122 | - invisible-watermark==0.2.0 123 | - ipython==8.12.3 124 | - jedi==0.19.1 125 | - jinja2==3.1.2 126 | - jsonschema==4.20.0 127 | - jsonschema-specifications==2023.11.1 128 | - jupyter-client==8.6.2 129 | - jupyter-core==5.7.2 130 | - jupyterlab-pygments==0.3.0 131 | - kiwisolver==1.4.5 132 | - kornia==0.6.0 133 | - lazy-loader==0.3 134 | - lightning-utilities==0.10.0 135 | - markdown==3.5.1 136 | - markdown-it-py==3.0.0 137 | - markupsafe==2.1.3 138 | - matplotlib==3.7.4 139 | - matplotlib-inline==0.1.7 140 | - mdurl==0.1.2 141 | - mistune==3.0.2 142 | - multidict==6.0.4 143 | - nbclient==0.10.0 144 | - nbconvert==7.16.4 145 | - nbformat==5.10.4 146 | - networkx==3.1 147 | - numpy==1.24.4 148 | - nvidia-ml-py==12.535.133 149 | - oauthlib==3.2.2 150 | - omegaconf==2.1.1 151 | - openai-clip==1.0.1 152 | - opencv-python==4.1.2.30 153 | - opencv-python-headless==4.8.1.78 154 | - 
packaging==23.2 155 | - pandas==2.0.3 156 | - pandocfilters==1.5.1 157 | - parso==0.8.4 158 | - pexpect==4.9.0 159 | - pickleshare==0.7.5 160 | - pillow==9.5.0 161 | - pipreqs==0.5.0 162 | - pkgutil-resolve-name==1.3.10 163 | - platformdirs==4.2.2 164 | - prompt-toolkit==3.0.47 165 | - protobuf==4.25.1 166 | - psutil==5.9.6 167 | - ptyprocess==0.7.0 168 | - pudb==2019.2 169 | - pure-eval==0.2.3 170 | - pyarrow==14.0.1 171 | - pyasn1==0.5.1 172 | - pyasn1-modules==0.3.0 173 | - pycocotools==2.0.7 174 | - pydeck==0.8.1b0 175 | - pydeprecate==0.3.1 176 | - pygments==2.17.2 177 | - pyparsing==3.1.1 178 | - python-dateutil==2.8.2 179 | - pytorch-lightning==1.4.2 180 | - pytz==2023.3.post1 181 | - pywavelets==1.4.1 182 | - pyyaml==6.0.1 183 | - pyzmq==26.0.3 184 | - referencing==0.31.0 185 | - regex==2023.10.3 186 | - requests-oauthlib==1.3.1 187 | - rich==13.7.0 188 | - rpds-py==0.13.1 189 | - rsa==4.9 190 | - safetensors==0.4.0 191 | - scikit-image==0.19.3 192 | - scipy==1.10.1 193 | - sentry-sdk==2.12.0 194 | - setproctitle==1.3.3 195 | - smmap==5.0.1 196 | - soupsieve==2.5 197 | - stack-data==0.6.3 198 | - streamlit==1.28.2 199 | - tenacity==8.2.3 200 | - tensorboard==2.14.0 201 | - tensorboard-data-server==0.7.2 202 | - test-tube==0.7.5 203 | - tifffile==2023.7.10 204 | - tinycss2==1.3.0 205 | - tokenizers==0.12.1 206 | - toml==0.10.2 207 | - toolz==0.12.0 208 | - torchmetrics==0.6.0 209 | - tornado==6.3.3 210 | - tqdm==4.66.1 211 | - traitlets==5.14.3 212 | - transformers==4.19.2 213 | - tzdata==2023.3 214 | - tzlocal==5.2 215 | - urwid==2.2.3 216 | - validators==0.22.0 217 | - wandb==0.17.5 218 | - watchdog==3.0.0 219 | - wcwidth==0.2.12 220 | - webencodings==0.5.1 221 | - werkzeug==3.0.1 222 | - yarg==0.1.9 223 | - yarl==1.9.3 224 | - zipp==3.17.0 225 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 226 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip --------------------------------------------------------------------------------
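Assuming a working conda installation, this environment can typically be created with conda env create -f environment.yaml and activated with conda activate d4-vton (the environment name comes from the name field at the top of the file). The two trailing pip entries install taming-transformers and CLIP as editable packages directly from their GitHub repositories.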