├── models
│   ├── insight_face
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── helpers.cpython-38.pyc
│   │   │   └── model_irse.cpython-38.pyc
│   │   ├── model_irse.py
│   │   └── helpers.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── ddpm
│   │   ├── __pycache__
│   │   │   └── diffusion.cpython-38.pyc
│   │   └── diffusion.py
│   └── improved_ddpm
│       ├── __pycache__
│       │   ├── nn.cpython-38.pyc
│       │   ├── unet.cpython-38.pyc
│       │   ├── logger.cpython-38.pyc
│       │   ├── fp16_util.cpython-38.pyc
│       │   └── script_util.cpython-38.pyc
│       ├── script_util.py
│       ├── nn.py
│       ├── fp16_util.py
│       ├── logger.py
│       └── unet.py
├── losses
│   ├── __pycache__
│   │   ├── id_loss.cpython-38.pyc
│   │   └── clip_loss.cpython-38.pyc
│   ├── id_loss.py
│   └── clip_loss.py
├── utils
│   ├── __pycache__
│   │   ├── text_dic.cpython-38.pyc
│   │   ├── align_utils.cpython-38.pyc
│   │   ├── model_utils.cpython-38.pyc
│   │   ├── diffusion_utils.cpython-38.pyc
│   │   ├── text_templates.cpython-38.pyc
│   │   └── image_processing.cpython-38.pyc
│   ├── model_utils.py
│   ├── diffusion_utils.py
│   ├── text_templates.py
│   ├── align_utils.py
│   └── image_processing.py
├── datasets
│   ├── __pycache__
│   │   ├── data_utils.cpython-38.pyc
│   │   ├── mt_dataset.cpython-38.pyc
│   │   ├── AFHQ_dataset.cpython-38.pyc
│   │   ├── LSUN_dataset.cpython-38.pyc
│   │   ├── imagenet_dic.cpython-38.pyc
│   │   ├── ladn_dataset.cpython-38.pyc
│   │   ├── celeba_dataset.cpython-38.pyc
│   │   ├── CelebA_HQ_dataset.cpython-38.pyc
│   │   └── IMAGENET_dataset.cpython-38.pyc
│   ├── data_utils.py
│   ├── mt_dataset.py
│   └── celeba_dataset.py
├── configs
│   ├── __pycache__
│   │   └── paths_config.cpython-38.pyc
│   ├── paths_config.py
│   ├── MT.yml
│   └── celeba.yml
├── requirements.txt
├── README.md
├── main.py
├── makeup_removal.py
└── makeup_transfer.py

/models/insight_face/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cmake==3.25.2
2 | dlib==19.24.2
3 | lpips==0.1.4
4 | opencv-python==4.10.0.84
5 | PyYAML==6.0.1
6 | scipy==1.10.1
7 | torch==2.4.0
8 | torchvision==0.19.0
9 | tqdm==4.66.4
10 | 
11 | 
--------------------------------------------------------------------------------
/configs/paths_config.py:
--------------------------------------------------------------------------------
1 | DATASET_PATHS = {
2 |     'CelebA_HQ': './assets/datasets/CelebAMask-HQ/',
3 |     'MT': './assets/datasets/MT-dataset/'
4 | }
5 | 
6 | MODEL_PATHS = {
7 |     'ir_se50': 'pretrained/model_ir_se50.pth',
8 |     'shape_predictor': "pretrained/shape_predictor_68_face_landmarks.dat.bz2",
9 | }
10 | 
11 | 
--------------------------------------------------------------------------------
/configs/MT.yml:
--------------------------------------------------------------------------------
1 | data:
2 |     dataset: "MT"
3 |     category: "MT"
4 |     image_size: 256
5 |     channels: 3
6 |     logit_transform: false
7 |     uniform_dequantization: false
8 |     gaussian_dequantization: false
9 |     random_flip: true
10 |     rescaled: true
11 |     num_workers: 0
12 | 
13 | model:
14 |     type: "simple"
15 |     in_channels: 3
16 |     out_ch: 3
17 |     ch: 128
18 |     ch_mult: [1, 1, 2, 2, 4, 4]
19 |     num_res_blocks: 2
20 |     attn_resolutions: [16, ]
21 |     dropout: 0.0
22 |     var_type: fixedsmall
23 |     ema_rate: 0.999
24 |     ema: True
25 |     resamp_with_conv: True
26 | 
27 | diffusion:
28 |     beta_schedule: linear
29 |     beta_start: 0.0001
30 |     beta_end: 0.02
31 |     num_diffusion_timesteps: 1000
32 | 
33 | sampling:
34 |     batch_size: 4
35 |     last_only: True
--------------------------------------------------------------------------------
/configs/celeba.yml:
--------------------------------------------------------------------------------
1 | data:
2 |     dataset: "CelebA_HQ"
3 |     category: "CelebA_HQ"
4 |     image_size: 256
5 |     channels: 3
6 |     logit_transform: false
7 |     uniform_dequantization: false
8 |     gaussian_dequantization: false
9 |     random_flip: true
10 |     rescaled: true
11 |     num_workers: 0
12 | 
13 | model:
14 |     type: "simple"
15 |     in_channels: 3
16 |     out_ch: 3
17 |     ch: 128
18 |     ch_mult: [1, 1, 2, 2, 4, 4]
19 |     num_res_blocks: 2
20 |     attn_resolutions: [16, ]
21 |     dropout: 0.0
22 |     var_type: fixedsmall
23 |     ema_rate: 0.999
24 |     ema: True
25 |     resamp_with_conv: True
26 | 
27 | diffusion:
28 |     beta_schedule: linear
29 |     beta_start: 0.0001
30 |     beta_end: 0.02
31 |     num_diffusion_timesteps: 1000
32 | 
33 | sampling:
34 |     batch_size: 4
35 |     last_only: True
--------------------------------------------------------------------------------
/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from .celeba_dataset import get_celeba_dataset
3 | from .mt_dataset import get_mt_dataset
4 | def get_dataset(dataset_type, dataset_paths, config):
5 |     if dataset_type == 'CelebA_HQ':
6 |         train_dataset, test_dataset = get_celeba_dataset(dataset_paths['CelebA_HQ'], config)
7 |     elif dataset_type == 'MT':
8 |         train_dataset, test_dataset = get_mt_dataset(dataset_paths['MT'],config)
9 |     else:
10 |         raise ValueError
11 |     return train_dataset, test_dataset
12 | 
13 | def 
get_dataloader(train_dataset, test_dataset, bs_train=1, num_workers=0): 14 | train_loader = DataLoader( 15 | train_dataset, 16 | batch_size=bs_train, 17 | drop_last=True, 18 | shuffle=True, 19 | sampler=None, 20 | num_workers=num_workers, 21 | pin_memory=True, 22 | ) 23 | test_loader = DataLoader( 24 | test_dataset, 25 | batch_size=1, 26 | drop_last=True, 27 | sampler=None, 28 | shuffle=True, 29 | num_workers=num_workers, 30 | pin_memory=True, 31 | ) 32 | 33 | return {'train': train_loader, 'test': test_loader} 34 | 35 | 36 | -------------------------------------------------------------------------------- /losses/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths_config import MODEL_PATHS 4 | from models.insight_face.model_irse import Backbone, MobileFaceNet 5 | import torch.nn.functional as F 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self, use_mobile_id=False): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(MODEL_PATHS['ir_se50'])) 13 | 14 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 15 | self.facenet.eval() 16 | 17 | def extract_feats(self, x): 18 | x = x[:, :, 35:223, 32:220] # Crop interesting region 19 | x = self.face_pool(x) 20 | x_feats = self.facenet(x) 21 | return x_feats 22 | 23 | def forward(self, x, x_hat): 24 | n_samples = x.shape[0] 25 | x_feats = self.extract_feats(x) 26 | x_feats = x_feats.detach() 27 | 28 | x_hat_feats = self.extract_feats(x_hat) 29 | losses = [] 30 | for i in range(n_samples): 31 | loss_sample = 1 - x_hat_feats[i].dot(x_feats[i]) 32 | losses.append(loss_sample.unsqueeze(0)) 33 | 34 | losses = torch.cat(losses, dim=0) 35 | return losses 36 | 37 | def cos_simi(emb_1, emb_2): 38 | return torch.mean(torch.sum(torch.mul(emb_2, emb_1), dim=1) / emb_2.norm(dim=1) / emb_1.norm(dim=1)) 39 | 40 | def cal_adv_loss(source, target, model_name, target_models): 41 | input_size = target_models[model_name][0] 42 | fr_model = target_models[model_name][1] 43 | source_resize = F.interpolate(source, size=input_size, mode='bilinear') 44 | target_resize = F.interpolate(target, size=input_size, mode='bilinear') 45 | emb_source = fr_model(source_resize) 46 | emb_target = fr_model(target_resize).detach() 47 | cos_loss = 1 - cos_simi(emb_source, emb_target) 48 | return cos_loss 49 | -------------------------------------------------------------------------------- /datasets/mt_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torchvision.transforms as tfs 3 | import os 4 | from utils.align_utils import * 5 | from PIL import Image 6 | from utils.image_processing import ToTensor 7 | 8 | 9 | class MultiResolutionDataset(Dataset): 10 | def __init__(self, mark, path, transform, resolution=256): 11 | self.images = os.listdir(path + 'images/makeup/') 12 | 13 | if mark == 0: 14 | self.train_paths = [path + 'images/makeup/' + 15 | _ for _ in self.images][500:] 16 | self.mask_paths = [path + 'segs/makeup/' + 17 | _ for _ in self.images][500:] 18 | else: 19 | self.train_paths = [path + 'images/makeup/' + 20 | _ for _ in self.images][:500] 21 | self.mask_paths = [path + 'segs/makeup/' + 22 | _ for _ in self.images][:500] 23 | 24 | self.resolution = resolution 25 | self.transform = transform 26 | 27 | def __len__(self): 
28 | return len(self.train_paths) 29 | 30 | def __getitem__(self, index): 31 | img_name = self.train_paths[index] 32 | mask_name = self.mask_paths[index] 33 | 34 | aligned_image = Image.open(img_name).resize( 35 | (self.resolution, self.resolution)) 36 | mask = Image.open(mask_name) 37 | 38 | img = self.transform(aligned_image) 39 | return img, ToTensor(mask) 40 | 41 | 42 | ################################################################################ 43 | 44 | def get_mt_dataset(data_root, config): 45 | transform = tfs.Compose([tfs.ToTensor(), tfs.Normalize( 46 | (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)]) 47 | 48 | train_dataset = MultiResolutionDataset( 49 | 0, data_root, transform, config.data.image_size) 50 | test_dataset = MultiResolutionDataset( 51 | 1, data_root, transform, config.data.image_size) 52 | 53 | return train_dataset, test_dataset 54 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from assets.models import irse, ir152, facenet 3 | 4 | def get_model_list(target_model): 5 | if target_model == 0: 6 | model_list = ['facenet','mobile_face','irse50','ir152'] 7 | elif target_model == 1: 8 | model_list = ['facenet','mobile_face','ir152','irse50'] 9 | elif target_model ==2: 10 | model_list = ['facenet','ir152','irse50','mobile_face'] 11 | else: 12 | model_list = ['ir152','irse50','mobile_face','facenet'] 13 | 14 | models = {} 15 | for model in model_list: 16 | if model == 'ir152': 17 | models[model] = [] 18 | models[model].append((112, 112)) 19 | fr_model = ir152.IR_152((112, 112)) 20 | fr_model.load_state_dict(torch.load('./assets/models/ir152.pth')) 21 | fr_model.to("cuda") 22 | fr_model.eval() 23 | models[model].append(fr_model) 24 | if model == 'irse50': 25 | models[model] = [] 26 | models[model].append((112, 112)) 27 | fr_model = irse.Backbone(50, 0.6, 'ir_se') 28 | fr_model.load_state_dict(torch.load('./assets/models/irse50.pth')) 29 | fr_model.to("cuda") 30 | fr_model.eval() 31 | models[model].append(fr_model) 32 | if model == 'facenet': 33 | models[model] = [] 34 | models[model].append((160, 160)) 35 | fr_model = facenet.InceptionResnetV1(num_classes=8631, device="cuda") 36 | fr_model.load_state_dict(torch.load('./assets/models/facenet.pth')) 37 | fr_model.to("cuda") 38 | fr_model.eval() 39 | models[model].append(fr_model) 40 | if model == 'mobile_face': 41 | models[model] = [] 42 | models[model].append((112, 112)) 43 | fr_model = irse.MobileFaceNet(512) 44 | fr_model.load_state_dict(torch.load('./assets/models/mobile_face.pth')) 45 | fr_model.to("cuda") 46 | fr_model.eval() 47 | models[model].append(fr_model) 48 | return models -------------------------------------------------------------------------------- /models/improved_ddpm/script_util.py: -------------------------------------------------------------------------------- 1 | from .unet import UNetModel 2 | 3 | NUM_CLASSES = 1000 4 | 5 | AFHQ_DICT = dict( 6 | attention_resolutions="16", 7 | class_cond=False, 8 | dropout=0.0, 9 | image_size=256, 10 | learn_sigma=True, 11 | num_channels=128, 12 | num_head_channels=64, 13 | num_res_blocks=1, 14 | resblock_updown=True, 15 | use_fp16=False, 16 | use_scale_shift_norm=True, 17 | num_heads=4, 18 | num_heads_upsample=-1, 19 | channel_mult="", 20 | use_checkpoint=False, 21 | use_new_attention_order=False, 22 | ) 23 | 24 | 25 | IMAGENET_DICT = dict( 26 | attention_resolutions="32,16,8", 27 | class_cond=True, 28 | 
image_size=512, 29 | learn_sigma=True, 30 | num_channels=256, 31 | num_head_channels=64, 32 | num_res_blocks=2, 33 | resblock_updown=True, 34 | use_fp16=False, 35 | use_scale_shift_norm=True, 36 | dropout=0.0, 37 | num_heads=4, 38 | num_heads_upsample=-1, 39 | channel_mult="", 40 | use_checkpoint=False, 41 | use_new_attention_order=False, 42 | ) 43 | 44 | 45 | def create_model( 46 | image_size, 47 | num_channels, 48 | num_res_blocks, 49 | channel_mult="", 50 | learn_sigma=False, 51 | class_cond=False, 52 | use_checkpoint=False, 53 | attention_resolutions="16", 54 | num_heads=1, 55 | num_head_channels=-1, 56 | num_heads_upsample=-1, 57 | use_scale_shift_norm=False, 58 | dropout=0, 59 | resblock_updown=False, 60 | use_fp16=False, 61 | use_new_attention_order=False, 62 | ): 63 | if channel_mult == "": 64 | if image_size == 512: 65 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 66 | elif image_size == 256: 67 | channel_mult = (1, 1, 2, 2, 4, 4) 68 | elif image_size == 128: 69 | channel_mult = (1, 1, 2, 3, 4) 70 | elif image_size == 64: 71 | channel_mult = (1, 2, 3, 4) 72 | else: 73 | raise ValueError(f"unsupported image size: {image_size}") 74 | else: 75 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 76 | 77 | attention_ds = [] 78 | for res in attention_resolutions.split(","): 79 | attention_ds.append(image_size // int(res)) 80 | 81 | return UNetModel( 82 | image_size=image_size, 83 | in_channels=3, 84 | model_channels=num_channels, 85 | out_channels=(3 if not learn_sigma else 6), 86 | num_res_blocks=num_res_blocks, 87 | attention_resolutions=tuple(attention_ds), 88 | dropout=dropout, 89 | channel_mult=channel_mult, 90 | num_classes=(NUM_CLASSES if class_cond else None), 91 | use_checkpoint=use_checkpoint, 92 | use_fp16=use_fp16, 93 | num_heads=num_heads, 94 | num_head_channels=num_head_channels, 95 | num_heads_upsample=num_heads_upsample, 96 | use_scale_shift_norm=use_scale_shift_norm, 97 | resblock_updown=resblock_updown, 98 | use_new_attention_order=use_new_attention_order, 99 | ) 100 | 101 | 102 | def i_DDPM(dataset_name = 'AFHQ'): 103 | if dataset_name in ['AFHQ', 'FFHQ']: 104 | return create_model(**AFHQ_DICT) 105 | elif dataset_name == 'IMAGENET': 106 | return create_model(**IMAGENET_DICT) 107 | else: 108 | print('Not implemented.') 109 | exit() 110 | -------------------------------------------------------------------------------- /utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps): 6 | betas = np.linspace(beta_start, beta_end, 7 | num_diffusion_timesteps, dtype=np.float64) 8 | assert betas.shape == (num_diffusion_timesteps,) 9 | return betas 10 | 11 | 12 | def extract(a, t, x_shape): 13 | """Extract coefficients from a based on t and reshape to make it 14 | broadcastable with x_shape.""" 15 | bs, = t.shape 16 | assert x_shape[0] == bs 17 | out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long()) 18 | assert out.shape == (bs,) 19 | out = out.reshape((bs,) + (1,) * (len(x_shape) - 1)) 20 | return out 21 | 22 | 23 | def denoising_step(xt, t, t_next, *, 24 | models, 25 | logvars, 26 | b, 27 | sampling_type='ddpm', 28 | eta=0.0, 29 | learn_sigma=False, 30 | ratio=1.0, 31 | out_x0_t=False, 32 | ): 33 | 34 | # Compute noise and variance 35 | if type(models) != list: 36 | model = models 37 | et = model(xt, t) 38 | if learn_sigma: 39 | et, logvar_learned = 
torch.split(et, et.shape[1] // 2, dim=1) 40 | logvar = logvar_learned 41 | else: 42 | logvar = extract(logvars, t, xt.shape) 43 | else: 44 | et = 0 45 | logvar = 0 46 | if ratio != 0.0: 47 | et_i = ratio * models[1](xt, t) 48 | if learn_sigma: 49 | et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1) 50 | logvar += logvar_learned 51 | else: 52 | logvar += ratio * extract(logvars, t, xt.shape) 53 | et += et_i 54 | if ratio != 1.0: 55 | et_i = (1 - ratio) * models[1](xt, t) 56 | if learn_sigma: 57 | et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1) 58 | logvar += logvar_learned 59 | else: 60 | logvar += (1 - ratio) * extract(logvars, t, xt.shape) 61 | et += et_i 62 | 63 | # Compute the next x 64 | bt = extract(b, t, xt.shape) 65 | at = extract((1.0 - b).cumprod(dim=0), t, xt.shape) 66 | 67 | if t_next.sum() == -t_next.shape[0]: 68 | at_next = torch.ones_like(at) 69 | else: 70 | at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape) 71 | 72 | xt_next = torch.zeros_like(xt) 73 | if sampling_type == 'ddpm': 74 | weight = bt / torch.sqrt(1 - at) 75 | 76 | mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et) 77 | noise = torch.randn_like(xt) 78 | mask = 1 - (t == 0).float() 79 | mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1)) 80 | xt_next = mean + mask * torch.exp(0.5 * logvar) * noise 81 | xt_next = xt_next.float() 82 | 83 | elif sampling_type == 'ddim': 84 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 85 | if eta == 0: 86 | xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et 87 | elif at > (at_next): 88 | print('Inversion process is only possible with eta = 0') 89 | raise ValueError 90 | else: 91 | c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt() 92 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 93 | xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(xt) 94 | 95 | if out_x0_t == True: 96 | return xt_next, x0_t 97 | else: 98 | return xt_next 99 | 100 | 101 | -------------------------------------------------------------------------------- /utils/text_templates.py: -------------------------------------------------------------------------------- 1 | imagenet_templates = [ 2 | 'a bad photo of a {}.', 3 | 'a sculpture of a {}.', 4 | 'a photo of the hard to see {}.', 5 | 'a low resolution photo of the {}.', 6 | 'a rendering of a {}.', 7 | 'graffiti of a {}.', 8 | 'a bad photo of the {}.', 9 | 'a cropped photo of the {}.', 10 | 'a tattoo of a {}.', 11 | 'the embroidered {}.', 12 | 'a photo of a hard to see {}.', 13 | 'a bright photo of a {}.', 14 | 'a photo of a clean {}.', 15 | 'a photo of a dirty {}.', 16 | 'a dark photo of the {}.', 17 | 'a drawing of a {}.', 18 | 'a photo of my {}.', 19 | 'the plastic {}.', 20 | 'a photo of the cool {}.', 21 | 'a close-up photo of a {}.', 22 | 'a black and white photo of the {}.', 23 | 'a painting of the {}.', 24 | 'a painting of a {}.', 25 | 'a pixelated photo of the {}.', 26 | 'a sculpture of the {}.', 27 | 'a bright photo of the {}.', 28 | 'a cropped photo of a {}.', 29 | 'a plastic {}.', 30 | 'a photo of the dirty {}.', 31 | 'a jpeg corrupted photo of a {}.', 32 | 'a blurry photo of the {}.', 33 | 'a photo of the {}.', 34 | 'a good photo of the {}.', 35 | 'a rendering of the {}.', 36 | 'a {} in a video game.', 37 | 'a photo of one {}.', 38 | 'a doodle of a {}.', 39 | 'a close-up photo of the {}.', 40 | 'a photo of a {}.', 41 | 'the origami {}.', 42 | 'the {} in a video game.', 43 | 'a sketch of a {}.', 44 | 'a doodle of the {}.', 45 | 'a origami {}.', 46 | 'a low 
resolution photo of a {}.', 47 | 'the toy {}.', 48 | 'a rendition of the {}.', 49 | 'a photo of the clean {}.', 50 | 'a photo of a large {}.', 51 | 'a rendition of a {}.', 52 | 'a photo of a nice {}.', 53 | 'a photo of a weird {}.', 54 | 'a blurry photo of a {}.', 55 | 'a cartoon {}.', 56 | 'art of a {}.', 57 | 'a sketch of the {}.', 58 | 'a embroidered {}.', 59 | 'a pixelated photo of a {}.', 60 | 'itap of the {}.', 61 | 'a jpeg corrupted photo of the {}.', 62 | 'a good photo of a {}.', 63 | 'a plushie {}.', 64 | 'a photo of the nice {}.', 65 | 'a photo of the small {}.', 66 | 'a photo of the weird {}.', 67 | 'the cartoon {}.', 68 | 'art of the {}.', 69 | 'a drawing of the {}.', 70 | 'a photo of the large {}.', 71 | 'a black and white photo of a {}.', 72 | 'the plushie {}.', 73 | 'a dark photo of a {}.', 74 | 'itap of a {}.', 75 | 'graffiti of the {}.', 76 | 'a toy {}.', 77 | 'itap of my {}.', 78 | 'a photo of a cool {}.', 79 | 'a photo of a small {}.', 80 | 'a tattoo of the {}.', 81 | ] 82 | 83 | part_templates = [ 84 | 'the paw of a {}.', 85 | 'the nose of a {}.', 86 | 'the eye of the {}.', 87 | 'the ears of a {}.', 88 | 'an eye of a {}.', 89 | 'the tongue of a {}.', 90 | 'the fur of the {}.', 91 | 'colorful {} fur.', 92 | 'a snout of a {}.', 93 | 'the teeth of the {}.', 94 | 'the {}s fangs.', 95 | 'a claw of the {}.', 96 | 'the face of the {}', 97 | 'a neck of a {}', 98 | 'the head of the {}', 99 | ] 100 | 101 | imagenet_templates_small = [ 102 | 'a photo of a {}.', 103 | 'a rendering of a {}.', 104 | 'a cropped photo of the {}.', 105 | 'the photo of a {}.', 106 | 'a photo of a clean {}.', 107 | 'a photo of a dirty {}.', 108 | 'a dark photo of the {}.', 109 | 'a photo of my {}.', 110 | 'a photo of the cool {}.', 111 | 'a close-up photo of a {}.', 112 | 'a bright photo of the {}.', 113 | 'a cropped photo of a {}.', 114 | 'a photo of the {}.', 115 | 'a good photo of the {}.', 116 | 'a photo of one {}.', 117 | 'a close-up photo of the {}.', 118 | 'a rendition of the {}.', 119 | 'a photo of the clean {}.', 120 | 'a rendition of a {}.', 121 | 'a photo of a nice {}.', 122 | 'a good photo of a {}.', 123 | 'a photo of the nice {}.', 124 | 'a photo of the small {}.', 125 | 'a photo of the weird {}.', 126 | 'a photo of the large {}.', 127 | 'a photo of a cool {}.', 128 | 'a photo of a small {}.', 129 | ] -------------------------------------------------------------------------------- /datasets/celeba_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from io import BytesIO 7 | import PIL 8 | from PIL import Image 9 | import torchvision.transforms as tfs 10 | import torchvision.utils as tvu 11 | 12 | class MultiResolutionDataset(Dataset): 13 | def __init__(self, path, transform, mark,resolution=256): 14 | self.path = path 15 | self.attr_path = os.path.join(self.path, "CelebAMask-HQ-attribute-anno.txt") 16 | self.img_path = os.path.join(self.path,"CelebA-HQ-img") 17 | self.resolution = resolution 18 | self.transform = transform 19 | self.transform_mask = tfs.Compose([ 20 | tfs.Resize((256,256),interpolation=PIL.Image.NEAREST), 21 | tfs.ToTensor()]) 22 | self.dataset = [] 23 | self.preprocess(mark) 24 | 25 | def preprocess(self,mark): 26 | """Preprocess the CelebA attribute file.""" 27 | lines = [line.rstrip() for line in open(self.attr_path, 'r')] 28 | 29 | lines = lines[2:] 30 | random.seed(1234) 31 | random.shuffle(lines) 32 | if mark == 
0: 33 | lines = lines[4000:5000] 34 | else: 35 | lines = lines[15000:16000] 36 | for i, line in enumerate(lines): 37 | split = line.split() 38 | filename = split[0] 39 | self.dataset.append(filename) 40 | 41 | print('Finished preprocessing the CelebA dataset...') 42 | def __len__(self): 43 | return len(self.dataset) 44 | 45 | def __getitem__(self, index): 46 | filename = self.dataset[index] 47 | image = Image.open(os.path.join(self.img_path, filename)) 48 | id = int(filename[:-4])//2000 49 | maskname = filename[:-4].zfill(5) 50 | mask_root = os.path.join(self.path, "CelebAMask-HQ-mask-anno/{}/".format(id)) 51 | l_eye_path = os.path.join(mask_root, maskname+"_l_eye.png") 52 | r_eye_path = os.path.join(mask_root, maskname+"_r_eye.png") 53 | l_brow_path = os.path.join(mask_root, maskname+"_l_brow.png") 54 | r_brow_path = os.path.join(mask_root, maskname+"_r_brow.png") 55 | l_lip_path = os.path.join(mask_root, maskname+"_l_lip.png") 56 | u_lip_path = os.path.join(mask_root, maskname+"_u_lip.png") 57 | skin_path = os.path.join(mask_root, maskname+"_skin.png") 58 | mouth_path = os.path.join(mask_root, maskname+"_mouth.png") 59 | neck_path = os.path.join(mask_root, maskname+"_neck.png") 60 | if not os.path.exists(l_eye_path) or not os.path.exists(r_eye_path) or not os.path.exists(l_lip_path) or not os.path.exists(u_lip_path): 61 | mask_list = 0 62 | else: 63 | image_l_eye = self.transform_mask(Image.open(l_eye_path)) 64 | image_r_eye = self.transform_mask(Image.open(r_eye_path)) 65 | image_l_lip = self.transform_mask(Image.open(l_lip_path)) 66 | image_u_lip = self.transform_mask(Image.open(u_lip_path)) 67 | image_skin = self.transform_mask(Image.open(skin_path)) 68 | image_lip = torch.clamp(image_u_lip + image_l_lip,0,1) 69 | image_face = image_skin - image_l_eye - image_r_eye - image_lip 70 | image_skin = image_skin - image_l_eye - image_r_eye - image_lip 71 | if os.path.exists(l_brow_path): 72 | image_l_brow = self.transform_mask(Image.open(l_brow_path)) 73 | image_skin = image_skin-image_l_brow 74 | image_face = image_face-image_l_brow 75 | if os.path.exists(r_brow_path): 76 | image_r_brow = self.transform_mask(Image.open(r_brow_path)) 77 | image_skin = image_skin-image_r_brow 78 | image_face = image_face-image_r_brow 79 | if os.path.exists(neck_path): 80 | image_neck = self.transform_mask(Image.open(neck_path)) 81 | image_skin = image_skin+image_neck 82 | if os.path.exists(mouth_path): 83 | image_mouth = self.transform_mask(Image.open(mouth_path)) 84 | image_skin = image_skin-image_mouth 85 | image_face = image_face-image_mouth 86 | mask_list = [image_l_eye,image_r_eye,image_lip,torch.clamp(image_skin,0,1),torch.clamp(image_face,0,1)] 87 | return self.transform(image), mask_list 88 | 89 | 90 | ################################################################################ 91 | 92 | def get_celeba_dataset(data_root, config): 93 | transform = tfs.Compose([tfs.Resize(256),tfs.ToTensor(),tfs.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) 94 | train_dataset = MultiResolutionDataset(data_root, 95 | transform, 0,config.data.image_size) 96 | test_dataset = MultiResolutionDataset(data_root, 97 | transform, 1,config.data.image_size) 98 | 99 | 100 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /models/insight_face/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from 
models.insight_face.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | from models.insight_face.helpers import Conv_block, Linear_block, Depth_Wise, Residual 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class MobileFaceNet(Module): 10 | def __init__(self, embedding_size): 11 | super(MobileFaceNet, self).__init__() 12 | self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 13 | self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 14 | self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128) 15 | self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256) 17 | self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 18 | self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512) 19 | self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 20 | self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 21 | self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)) 22 | self.conv_6_flatten = Flatten() 23 | self.linear = Linear(512, embedding_size, bias=False) 24 | self.bn = BatchNorm1d(embedding_size) 25 | 26 | def forward(self, x): 27 | out = self.conv1(x) 28 | out = self.conv2_dw(out) 29 | out = self.conv_23(out) 30 | out = self.conv_3(out) 31 | out = self.conv_34(out) 32 | out = self.conv_4(out) 33 | out = self.conv_45(out) 34 | out = self.conv_5(out) 35 | out = self.conv_6_sep(out) 36 | out = self.conv_6_dw(out) 37 | out = self.conv_6_flatten(out) 38 | out = self.linear(out) 39 | out = self.bn(out) 40 | return l2_norm(out) 41 | 42 | 43 | 44 | 45 | 46 | 47 | ###################################################################################### 48 | 49 | class Backbone(Module): 50 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 51 | super(Backbone, self).__init__() 52 | assert input_size in [112, 224], "input_size should be 112 or 224" 53 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 54 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 55 | blocks = get_blocks(num_layers) 56 | if mode == 'ir': 57 | unit_module = bottleneck_IR 58 | elif mode == 'ir_se': 59 | unit_module = bottleneck_IR_SE 60 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 61 | BatchNorm2d(64), 62 | PReLU(64)) 63 | if input_size == 112: 64 | self.output_layer = Sequential(BatchNorm2d(512), 65 | Dropout(drop_ratio), 66 | Flatten(), 67 | Linear(512 * 7 * 7, 512), 68 | BatchNorm1d(512, affine=affine)) 69 | else: 70 | self.output_layer = Sequential(BatchNorm2d(512), 71 | Dropout(drop_ratio), 72 | Flatten(), 73 | Linear(512 * 14 * 14, 512), 74 | BatchNorm1d(512, affine=affine)) 75 | 76 | modules = [] 77 | for block in blocks: 78 | for bottleneck in block: 79 | modules.append(unit_module(bottleneck.in_channel, 80 | bottleneck.depth, 81 | bottleneck.stride)) 82 | self.body = Sequential(*modules) 83 | 84 | def forward(self, x): 85 | x = self.input_layer(x) 86 | x = self.body(x) 87 | x = self.output_layer(x) 88 | return l2_norm(x) 89 | 90 | 91 | def IR_50(input_size): 92 | """Constructs a ir-50 model.""" 93 | 
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 94 | return model 95 | 96 | 97 | def IR_101(input_size): 98 | """Constructs a ir-101 model.""" 99 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 100 | return model 101 | 102 | 103 | def IR_152(input_size): 104 | """Constructs a ir-152 model.""" 105 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 106 | return model 107 | 108 | 109 | def IR_SE_50(input_size): 110 | """Constructs a ir_se-50 model.""" 111 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 112 | return model 113 | 114 | 115 | def IR_SE_101(input_size): 116 | """Constructs a ir_se-101 model.""" 117 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 118 | return model 119 | 120 | 121 | def IR_SE_152(input_size): 122 | """Constructs a ir_se-152 model.""" 123 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 124 | return model 125 | -------------------------------------------------------------------------------- /models/improved_ddpm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 
89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /losses/clip_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import clip 5 | from PIL import Image 6 | from utils.text_templates import imagenet_templates 7 | 8 | 9 | class DirectionLoss(torch.nn.Module): 10 | 11 | def __init__(self, loss_type='mse'): 12 | super(DirectionLoss, self).__init__() 13 | 14 | self.loss_type = loss_type 15 | 16 | self.loss_func = { 17 | 'mse': torch.nn.MSELoss, 18 | 'cosine': torch.nn.CosineSimilarity, 19 | 'mae': torch.nn.L1Loss 20 | }[loss_type]() 21 | 22 | def forward(self, x, y): 23 | if self.loss_type == "cosine": 24 | return 1. - self.loss_func(x, y) 25 | 26 | return self.loss_func(x, y) 27 | 28 | class CLIPLoss(torch.nn.Module): 29 | def __init__(self, device, lambda_makeup_direction , lambda_direction , direction_loss_type='cosine', clip_model='ViT-B/32'): 30 | super(CLIPLoss, self).__init__() 31 | 32 | self.device = device 33 | self.lambda_makeup_direction = lambda_makeup_direction 34 | self.lambda_direction = lambda_direction 35 | self.model, clip_preprocess = clip.load(clip_model, device=self.device) 36 | 37 | self.clip_preprocess = clip_preprocess 38 | 39 | self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1]. 
40 | clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions 41 | clip_preprocess.transforms[4:]) # + skip convert PIL to tensor 42 | 43 | self.target_direction = None 44 | 45 | self.direction_loss = DirectionLoss(direction_loss_type) 46 | 47 | self.src_text_features = None 48 | self.target_text_features = None 49 | 50 | 51 | def tokenize(self, strings: list): 52 | return clip.tokenize(strings).to(self.device) 53 | 54 | def encode_text(self, tokens: list) -> torch.Tensor: 55 | return self.model.encode_text(tokens) 56 | 57 | def encode_images(self, images: torch.Tensor) -> torch.Tensor: 58 | images = self.preprocess(images).to(self.device) 59 | return self.model.encode_image(images) 60 | 61 | def get_text_features(self, class_str: str, templates=imagenet_templates, norm: bool = True) -> torch.Tensor: 62 | template_text = self.compose_text_with_templates(class_str, templates) 63 | 64 | tokens = clip.tokenize(template_text).to(self.device) 65 | 66 | text_features = self.encode_text(tokens).detach() 67 | 68 | if norm: 69 | text_features /= text_features.norm(dim=-1, keepdim=True) 70 | 71 | return text_features 72 | 73 | def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor: 74 | image_features = self.encode_images(img) 75 | 76 | if norm: 77 | image_features /= image_features.clone().norm(dim=-1, keepdim=True) 78 | 79 | return image_features 80 | 81 | def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor: 82 | source_features = self.get_text_features(source_class) 83 | target_features = self.get_text_features(target_class) 84 | 85 | text_direction = (target_features - source_features).mean(axis=0, keepdim=True) 86 | text_direction /= text_direction.norm(dim=-1, keepdim=True) 87 | 88 | return text_direction 89 | 90 | def compose_text_with_templates(self, text: str, templates=imagenet_templates) -> list: 91 | return [template.format(text) for template in templates] 92 | 93 | def clip_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor: 94 | 95 | if self.target_direction is None: 96 | self.target_direction = self.compute_text_direction(source_class, target_class) 97 | 98 | src_encoding = self.get_image_features(src_img) 99 | target_encoding = self.get_image_features(target_img) 100 | 101 | edit_direction = (target_encoding - src_encoding) 102 | edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7) 103 | return self.direction_loss(edit_direction, self.target_direction).mean() 104 | 105 | def clip_makeup_directional_loss(self, src_img: torch.Tensor, non_makeup_img: torch.Tensor, output_img: torch.Tensor, makeup_img: torch.Tensor) -> torch.Tensor: 106 | non_makeup_encoding = self.get_image_features(non_makeup_img) 107 | makeup_encoding = self.get_image_features(makeup_img) 108 | src_encoding = self.get_image_features(src_img) 109 | output_encoding = self.get_image_features(output_img) 110 | self.target_direction = (makeup_encoding - non_makeup_encoding) 111 | self.target_direction /= (self.target_direction.clone().norm(dim=-1,keepdim=True) + 1e-7) 112 | edit_direction = (output_encoding - src_encoding) 113 | edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7) 114 | return self.direction_loss(edit_direction, self.target_direction).mean() 115 | 116 | def forward(self, src_img: torch.Tensor, non_makeup_img: torch.Tensor, output_img: torch.Tensor, makeup_img: torch.Tensor, src_txt = None, trg_txt = None): 
117 | clip_loss = 0.0 118 | if self.lambda_makeup_direction == 1: 119 | clip_loss += self.clip_makeup_directional_loss(src_img, non_makeup_img, output_img, makeup_img) 120 | elif self.lambda_direction == 1: 121 | clip_loss += self.clip_directional_loss(src_img, src_txt, output_img, trg_txt) 122 | return clip_loss 123 | -------------------------------------------------------------------------------- /models/insight_face/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | 11 | 12 | class Conv_block(Module): 13 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 14 | super(Conv_block, self).__init__() 15 | self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 16 | self.bn = BatchNorm2d(out_c) 17 | self.prelu = PReLU(out_c) 18 | def forward(self, x): 19 | x = self.conv(x) 20 | x = self.bn(x) 21 | x = self.prelu(x) 22 | return x 23 | 24 | class Linear_block(Module): 25 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 26 | super(Linear_block, self).__init__() 27 | self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 28 | self.bn = BatchNorm2d(out_c) 29 | def forward(self, x): 30 | x = self.conv(x) 31 | x = self.bn(x) 32 | return x 33 | 34 | class Depth_Wise(Module): 35 | def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 36 | super(Depth_Wise, self).__init__() 37 | self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 38 | self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride) 39 | self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 40 | self.residual = residual 41 | def forward(self, x): 42 | if self.residual: 43 | short_cut = x 44 | x = self.conv(x) 45 | x = self.conv_dw(x) 46 | x = self.project(x) 47 | if self.residual: 48 | output = short_cut + x 49 | else: 50 | output = x 51 | return output 52 | 53 | class Residual(Module): 54 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 55 | super(Residual, self).__init__() 56 | modules = [] 57 | for _ in range(num_block): 58 | modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups)) 59 | self.model = Sequential(*modules) 60 | def forward(self, x): 61 | return self.model(x) 62 | 63 | 64 | 65 | 66 | ###################################################################################### 67 | 68 | 69 | class Flatten(Module): 70 | def forward(self, input): 71 | return input.view(input.size(0), -1) 72 | 73 | 74 | def l2_norm(input, axis=1): 75 | norm = torch.norm(input, 2, axis, True) 76 | output = torch.div(input, norm) 77 | return output 78 | 79 | 80 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 81 | """ A named tuple describing a ResNet block. 
""" 82 | 83 | 84 | def get_block(in_channel, depth, num_units, stride=2): 85 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 86 | 87 | 88 | def get_blocks(num_layers): 89 | if num_layers == 50: 90 | blocks = [ 91 | get_block(in_channel=64, depth=64, num_units=3), 92 | get_block(in_channel=64, depth=128, num_units=4), 93 | get_block(in_channel=128, depth=256, num_units=14), 94 | get_block(in_channel=256, depth=512, num_units=3) 95 | ] 96 | elif num_layers == 100: 97 | blocks = [ 98 | get_block(in_channel=64, depth=64, num_units=3), 99 | get_block(in_channel=64, depth=128, num_units=13), 100 | get_block(in_channel=128, depth=256, num_units=30), 101 | get_block(in_channel=256, depth=512, num_units=3) 102 | ] 103 | elif num_layers == 152: 104 | blocks = [ 105 | get_block(in_channel=64, depth=64, num_units=3), 106 | get_block(in_channel=64, depth=128, num_units=8), 107 | get_block(in_channel=128, depth=256, num_units=36), 108 | get_block(in_channel=256, depth=512, num_units=3) 109 | ] 110 | else: 111 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 112 | return blocks 113 | 114 | 115 | class SEModule(Module): 116 | def __init__(self, channels, reduction): 117 | super(SEModule, self).__init__() 118 | self.avg_pool = AdaptiveAvgPool2d(1) 119 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 120 | self.relu = ReLU(inplace=True) 121 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 122 | self.sigmoid = Sigmoid() 123 | 124 | def forward(self, x): 125 | module_input = x 126 | x = self.avg_pool(x) 127 | x = self.fc1(x) 128 | x = self.relu(x) 129 | x = self.fc2(x) 130 | x = self.sigmoid(x) 131 | return module_input * x 132 | 133 | 134 | class bottleneck_IR(Module): 135 | def __init__(self, in_channel, depth, stride): 136 | super(bottleneck_IR, self).__init__() 137 | if in_channel == depth: 138 | self.shortcut_layer = MaxPool2d(1, stride) 139 | else: 140 | self.shortcut_layer = Sequential( 141 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 142 | BatchNorm2d(depth) 143 | ) 144 | self.res_layer = Sequential( 145 | BatchNorm2d(in_channel), 146 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 147 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 148 | ) 149 | 150 | def forward(self, x): 151 | shortcut = self.shortcut_layer(x) 152 | res = self.res_layer(x) 153 | return res + shortcut 154 | 155 | 156 | class bottleneck_IR_SE(Module): 157 | def __init__(self, in_channel, depth, stride): 158 | super(bottleneck_IR_SE, self).__init__() 159 | if in_channel == depth: 160 | self.shortcut_layer = MaxPool2d(1, stride) 161 | else: 162 | self.shortcut_layer = Sequential( 163 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 164 | BatchNorm2d(depth) 165 | ) 166 | self.res_layer = Sequential( 167 | BatchNorm2d(in_channel), 168 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 169 | PReLU(depth), 170 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 171 | BatchNorm2d(depth), 172 | SEModule(depth, 16) 173 | ) 174 | 175 | def forward(self, x): 176 | shortcut = self.shortcut_layer(x) 177 | res = self.res_layer(x) 178 | return res + shortcut 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffAM: Diffusion-based 
Adversarial Makeup Transfer for Facial Privacy Protection (CVPR 2024) 2 | 3 | [![arXiv](https://img.shields.io/badge/paper-cvpr2024-cyan)](https://openaccess.thecvf.com/content/CVPR2024/html/Sun_DiffAM_Diffusion-based_Adversarial_Makeup_Transfer_for_Facial_Privacy_Protection_CVPR_2024_paper.html) [![arXiv](https://img.shields.io/badge/arXiv-2405.09882-red)](https://arxiv.org/abs/2405.09882) 4 | 5 | Official PyTorch implementation of the paper "DiffAM: Diffusion-based Adversarial Makeup Transfer for Facial Privacy Protection". 6 | 7 | ## Abstract 8 | 9 | With the rapid development of face recognition (FR) systems, the privacy of face images on social media is facing severe challenges due to the abuse of unauthorized FR systems. Some studies utilize adversarial attack techniques to defend against malicious FR systems by generating adversarial examples. However, the generated adversarial examples, i.e., the protected face images, tend to suffer from subpar visual quality and low transferability. In this paper, we propose a novel face protection approach, dubbed DiffAM, which leverages the powerful generative ability of diffusion models to generate high-quality protected face images with adversarial makeup transferred from reference images. To be specific, we first introduce a makeup removal module to generate non-makeup images utilizing a fine-tuned diffusion model with guidance of textual prompts in CLIP space. As the inverse process of makeup transfer, makeup removal can make it easier to establish the deterministic relationship between the makeup domain and the non-makeup domain regardless of elaborate text prompts. Then, with this relationship, a CLIP-based makeup loss along with an ensemble attack strategy is introduced to jointly guide the direction of the adversarial makeup domain, achieving the generation of protected face images with natural-looking makeup and high black-box transferability. Extensive experiments demonstrate that DiffAM achieves higher visual quality and attack success rates, with a gain of 12.98% under the black-box setting compared with the state of the art. 10 | 11 | ## Setup 12 | 13 | - ### Get code 14 | 15 | ```shell 16 | git clone https://github.com/HansSunY/DiffAM.git 17 | ``` 18 | 19 | - ### Build environment 20 | 21 | ```shell 22 | cd DiffAM 23 | # use anaconda to build environment 24 | conda create -n diffam python=3.8 25 | conda activate diffam 26 | # install packages 27 | pip install -r requirements.txt 28 | pip install git+https://github.com/openai/CLIP.git 29 | ``` 30 | 31 | ## Pretrained models and datasets 32 | 33 | - The weights required to run DiffAM can be downloaded [here](https://drive.google.com/drive/folders/1L8caY-FVzp9razKMuAt37jCcgYh3fjVU?usp=sharing). 34 | 35 | ```shell 36 | mkdir pretrained 37 | mv celeba_hq.ckpt pretrained/ 38 | mv makeup.pt pretrained/ 39 | mv model_ir_se50.pth pretrained/ 40 | mv shape_predictor_68_face_landmarks.dat pretrained/ 41 | ``` 42 | 43 | - Please download the target FR models, MT-dataset and target images [here](https://drive.google.com/file/d/1IKiWLv99eUbv3llpj-dOegF3O7FWW29J/view?usp=sharing). Unzip the assets.zip file in `DiffAM/assets`. 44 | - Please download the [CelebAMask-HQ](https://drive.google.com/file/d/1badu11NqxGf6qM3PTTooQDJvQbejgbTv/view) dataset and unzip it in `DiffAM/assets/datasets` (a sketch of the extraction step is shown below).
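The extraction step below is only a sketch: the archive names (`assets.zip`, `CelebAMask-HQ.zip`) and their internal layout are assumptions, so adjust the paths to match the files you actually downloaded.

```shell
# Assumed archive names and layout -- adjust if your downloads differ.
cd DiffAM
unzip assets.zip -d assets/                   # target FR models, MT-dataset, pairs, target, test
unzip CelebAMask-HQ.zip -d assets/datasets/   # CelebAMask-HQ face dataset
```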
45 | 46 | The final project should be like this: 47 | 48 | ```shell 49 | DiffAM 50 | └- assets 51 | └- datasets 52 | └- CelebAMask-HQ 53 | └- MT-dataset 54 | └- pairs 55 | └- target 56 | └- test 57 | └- models 58 | └- pretrained 59 | └- celeba_hq.ckpt 60 | └- ... 61 | └- ... 62 | ``` 63 | 64 | ## Quick Start 65 | 66 | ### Makeup removal (Optional) 67 | 68 | - We have included five makeup styles for adversarial makeup transfer in `DiffAM/assets/datasets/pairs`, which contains pairs of makeup and non-makeup images along with their corresponding masks. Therefore, you can directly **skip** this step and try out makeup transfer with the provided styles. 69 | - If you want to fine-tune the pretrained diffusion model for makeup removal and generate more pairs of makeup and non-makeup images, please run the following command: 70 | 71 | ```shell 72 | python main.py --makeup_removal --config MT.yml --exp ./runs/test --do_train 1 --do_test 1 --n_train_img 200 --n_test_img 100 --n_iter 7 --t_0 300 --n_inv_step 40 --n_train_step 6 --n_test_step 40 --lr_clip_finetune 8e-6 --model_path pretrained/makeup.pt 73 | ``` 74 | 75 | You can then remove makeup with the trained model and put the resulting pairs of makeup and non-makeup images, along with their corresponding masks, in `DiffAM/assets/datasets/pairs` for the subsequent adversarial makeup transfer. 76 | 77 | ### Adversarial makeup transfer 78 | 79 | To fine-tune the pretrained diffusion model for adversarial makeup transfer, please run the following command: 80 | 81 | ```shell 82 | python main.py --makeup_transfer --config celeba.yml --exp ./runs/test --do_train 1 --do_test 1 --n_train_img 200 --n_test_img 100 --n_iter 4 --t_0 60 --n_inv_step 20 --n_train_step 6 --n_test_step 6 --lr_clip_finetune 8e-6 --model_path pretrained/celeba_hq.ckpt --target_img 1 --target_model 2 --ref_img 'XMY-060' 83 | ``` 84 | 85 | - `target_img`: Choose the target identity to attack; a total of 4 options are provided (see details in our supplementary materials). 86 | 87 | - `target_model`: Choose the target FR model to attack, including `[IRSE50, IR152, Mobileface, Facenet]`. 88 | - `ref_img`: Choose the provided makeup style to transfer, including `['XMY-060', 'XYH-045', 'XMY-254', 'vRX912', 'vFG137']`. In addition, by generating pairs of makeup and non-makeup images through makeup removal, you can also transfer the makeup style you want. (Save `{ref_name}_m.png`, `{ref_name}_nm.png`, and `{ref_name}_mask.png` to `DiffAM/assets/datasets/pairs`.) 89 | 90 | ### Edit one image 91 | 92 | You can edit a single image for makeup removal or adversarial makeup transfer by running the following commands: 93 | 94 | ```shell 95 | # makeup removal 96 | python main.py --edit_one_image_MR --config MT.yml --exp ./runs/test --n_iter 1 --t_0 300 --n_inv_step 40 --n_train_step 6 --n_test_step 40 --img_path {IMG_PATH} --model_path {MODEL_PATH} 97 | 98 | # adversarial makeup transfer 99 | python main.py --edit_one_image_MT --config celeba.yml --exp ./runs/test --n_iter 1 --t_0 60 --n_inv_step 20 --n_train_step 6 --n_test_step 6 --img_path {IMG_PATH} --model_path {MODEL_PATH} 100 | ``` 101 | 102 | - `img_path`: Path of the image to edit. 103 | - `model_path`: Path of the fine-tuned model.
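If your photo is not already an aligned face crop, one option is to pre-align it with the repository's FFHQ-style helper in `utils/align_utils.py` (`main.py` also exposes an `--align_face` flag) and then pass the aligned result via `--img_path`. The sketch below is illustrative only: the input/output file names and the checkpoint name are placeholders, not files shipped with the repo.

```shell
# Optional: pre-align an arbitrary photo with the repo's FFHQ-style helper
# (utils/align_utils.run_alignment); file names here are placeholders.
python -c "from utils.align_utils import run_alignment; run_alignment('my_photo.jpg', 256).save('my_photo_aligned.png')"

# Then edit the aligned image with a fine-tuned checkpoint (checkpoint name is a placeholder):
python main.py --edit_one_image_MT --config celeba.yml --exp ./runs/test --n_iter 1 --t_0 60 --n_inv_step 20 --n_train_step 6 --n_test_step 6 --img_path my_photo_aligned.png --model_path checkpoint/{FINETUNED_MODEL}.pth
```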
104 | 105 | ## Citation 106 | 107 | ```bibtex 108 | @InProceedings{Sun_2024_CVPR, 109 | author = {Sun, Yuhao and Yu, Lingyun and Xie, Hongtao and Li, Jiaming and Zhang, Yongdong}, 110 | title = {DiffAM: Diffusion-based Adversarial Makeup Transfer for Facial Privacy Protection}, 111 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 112 | month = {June}, 113 | year = {2024}, 114 | pages = {24584-24594} 115 | } 116 | ``` 117 | 118 | ## Acknowledgments 119 | 120 | Our code structure is based on [DiffusionCLIP](https://github.com/gwang-kim/DiffusionCLIP?tab=readme-ov-file) and [AMT-GAN](https://github.com/CGCL-codes/AMT-GAN). 121 | -------------------------------------------------------------------------------- /utils/align_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) 3 | author: lzhbrian (https://lzhbrian.me) 4 | date: 2020.1.5 5 | note: code is heavily borrowed from 6 | https://github.com/NVlabs/ffhq-dataset 7 | http://dlib.net/face_landmark_detection.py.html 8 | 9 | requirements: 10 | apt install cmake 11 | conda install Pillow numpy scipy 12 | pip install dlib 13 | # download face landmark model from: 14 | # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 15 | """ 16 | from argparse import ArgumentParser 17 | import time 18 | import numpy as np 19 | import PIL 20 | import PIL.Image 21 | import os 22 | import scipy 23 | import scipy.ndimage 24 | import dlib 25 | import multiprocessing as mp 26 | import math 27 | 28 | from configs.paths_config import MODEL_PATHS 29 | 30 | SHAPE_PREDICTOR_PATH = MODEL_PATHS["shape_predictor"] 31 | 32 | 33 | def run_alignment(image_path, output_size): 34 | if not os.path.exists("pretrained/shape_predictor_68_face_landmarks.dat"): 35 | print('Downloading files for aligning face image...') 36 | os.system(f'wget -P pretrained/ http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2') 37 | os.system('bzip2 -dk pretrained/shape_predictor_68_face_landmarks.dat.bz2') 38 | print('Done.') 39 | predictor = dlib.shape_predictor("pretrained/shape_predictor_68_face_landmarks.dat") 40 | aligned_image = align_face(filepath=image_path, predictor=predictor, output_size=output_size, transform_size=output_size) 41 | print("Aligned image has shape: {}".format(aligned_image.size)) 42 | return aligned_image 43 | 44 | 45 | def get_landmark(filepath, predictor): 46 | """get landmark with dlib 47 | :return: np.array shape=(68, 2) 48 | """ 49 | detector = dlib.get_frontal_face_detector() 50 | 51 | img = dlib.load_rgb_image(filepath) 52 | dets = detector(img, 1) 53 | 54 | for k, d in enumerate(dets): 55 | shape = predictor(img, d) 56 | 57 | t = list(shape.parts()) 58 | a = [] 59 | for tt in t: 60 | a.append([tt.x, tt.y]) 61 | lm = np.array(a) 62 | return lm 63 | 64 | 65 | def align_face(filepath, predictor, output_size=256, transform_size=256): 66 | """ 67 | :param filepath: str 68 | :return: PIL Image 69 | """ 70 | 71 | lm = get_landmark(filepath, predictor) 72 | 73 | lm_chin = lm[0: 17] # left-right 74 | lm_eyebrow_left = lm[17: 22] # left-right 75 | lm_eyebrow_right = lm[22: 27] # left-right 76 | lm_nose = lm[27: 31] # top-down 77 | lm_nostrils = lm[31: 36] # top-down 78 | lm_eye_left = lm[36: 42] # left-clockwise 79 | lm_eye_right = lm[42: 48] # left-clockwise 80 | lm_mouth_outer = lm[48: 60] # left-clockwise 81 | lm_mouth_inner = lm[60: 68] # left-clockwise 82 | 83 | # 
Calculate auxiliary vectors. 84 | eye_left = np.mean(lm_eye_left, axis=0) 85 | eye_right = np.mean(lm_eye_right, axis=0) 86 | eye_avg = (eye_left + eye_right) * 0.5 87 | eye_to_eye = eye_right - eye_left 88 | mouth_left = lm_mouth_outer[0] 89 | mouth_right = lm_mouth_outer[6] 90 | mouth_avg = (mouth_left + mouth_right) * 0.5 91 | eye_to_mouth = mouth_avg - eye_avg 92 | 93 | # Choose oriented crop rectangle. 94 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 95 | x /= np.hypot(*x) 96 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 97 | y = np.flipud(x) * [-1, 1] 98 | c = eye_avg + eye_to_mouth * 0.1 99 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 100 | qsize = np.hypot(*x) * 2 101 | 102 | # read image 103 | img = PIL.Image.open(filepath) 104 | enable_padding = True 105 | 106 | # Shrink. 107 | shrink = int(np.floor(qsize / output_size * 0.5)) 108 | if shrink > 1: 109 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 110 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 111 | quad /= shrink 112 | qsize /= shrink 113 | 114 | # Crop. 115 | border = max(int(np.rint(qsize * 0.1)), 3) 116 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 117 | int(np.ceil(max(quad[:, 1])))) 118 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 119 | min(crop[3] + border, img.size[1])) 120 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 121 | img = img.crop(crop) 122 | quad -= crop[0:2] 123 | 124 | # Pad. 125 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 126 | int(np.ceil(max(quad[:, 1])))) 127 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 128 | max(pad[3] - img.size[1] + border, 0)) 129 | if enable_padding and max(pad) > border - 4: 130 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 131 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 132 | h, w, _ = img.shape 133 | y, x, _ = np.ogrid[:h, :w, :1] 134 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 135 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 136 | blur = qsize * 0.02 137 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 138 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 139 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 140 | quad += pad[:2] 141 | 142 | # Transform. 143 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 144 | if output_size < transform_size: 145 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 146 | 147 | # Save aligned image. 
148 | return img 149 | 150 | 151 | def chunks(lst, n): 152 | """Yield successive n-sized chunks from lst.""" 153 | for i in range(0, len(lst), n): 154 | yield lst[i:i + n] 155 | 156 | 157 | def extract_on_paths(file_paths): 158 | predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH) 159 | pid = mp.current_process().name 160 | print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths))) 161 | tot_count = len(file_paths) 162 | count = 0 163 | for file_path, res_path in file_paths: 164 | count += 1 165 | if count % 100 == 0: 166 | print('{} done with {}/{}'.format(pid, count, tot_count)) 167 | try: 168 | res = align_face(file_path, predictor) 169 | res = res.convert('RGB') 170 | os.makedirs(os.path.dirname(res_path), exist_ok=True) 171 | res.save(res_path) 172 | except Exception: 173 | continue 174 | print('\tDone!') 175 | 176 | 177 | def parse_args(): 178 | parser = ArgumentParser(add_help=False) 179 | parser.add_argument('--num_threads', type=int, default=1) 180 | parser.add_argument('--root_path', type=str, default='') 181 | args = parser.parse_args() 182 | return args 183 | 184 | 185 | def run(args): 186 | root_path = args.root_path 187 | out_crops_path = root_path + '_crops' 188 | if not os.path.exists(out_crops_path): 189 | os.makedirs(out_crops_path, exist_ok=True) 190 | 191 | file_paths = [] 192 | for root, dirs, files in os.walk(root_path): 193 | for file in files: 194 | file_path = os.path.join(root, file) 195 | fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path)) 196 | res_path = '{}.jpg'.format(os.path.splitext(fname)[0]) 197 | if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path): 198 | continue 199 | file_paths.append((file_path, res_path)) 200 | 201 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 202 | print(len(file_chunks)) 203 | pool = mp.Pool(args.num_threads) 204 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 205 | tic = time.time() 206 | pool.map(extract_on_paths, file_chunks) 207 | toc = time.time() 208 | print('Mischief managed in {}s'.format(toc - tic)) 209 | 210 | 211 | if __name__ == '__main__': 212 | args = parse_args() 213 | run(args) 214 | -------------------------------------------------------------------------------- /utils/image_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torchvision import transforms 4 | import PIL 5 | from PIL import Image 6 | import copy 7 | def get_target_image(target_id): 8 | transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 9 | if target_id == 0: 10 | target_image = Image.open("assets/datasets/target/005869.jpg").convert('RGB') 11 | target_image = (transform(target_image).to("cuda").unsqueeze(0)) 12 | test_image = Image.open("assets/datasets/test/008793.jpg").convert('RGB') 13 | test_image = (transform(test_image).to("cuda").unsqueeze(0)) 14 | target_name = "005869" 15 | elif target_id == 1: 16 | target_image = Image.open("assets/datasets/target/085807.jpg").convert('RGB') 17 | target_image = (transform(target_image).to("cuda").unsqueeze(0)) 18 | test_image = Image.open("assets/datasets/test/047073.jpg").convert('RGB') 19 | test_image = (transform(test_image).to("cuda").unsqueeze(0)) 20 | target_name = "085807" 21 | elif target_id == 2: 22 | target_image = Image.open("assets/datasets/target/116481.jpg").convert('RGB') 23 | target_image 
= (transform(target_image).to("cuda").unsqueeze(0)) 24 | test_image = Image.open("assets/datasets/test/055622.jpg").convert('RGB') 25 | test_image = (transform(test_image).to("cuda").unsqueeze(0)) 26 | target_name = "116481" 27 | else: 28 | target_image = Image.open("assets/datasets/target/169284.jpg").convert('RGB') 29 | target_image = (transform(target_image).to("cuda").unsqueeze(0)) 30 | test_image = Image.open("assets/datasets/test/166607.jpg").convert('RGB') 31 | test_image = (transform(test_image).to("cuda").unsqueeze(0)) 32 | target_name = "169284" 33 | return target_image, test_image, target_name 34 | 35 | def get_ref_image(ref_id): 36 | train_transform = transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),inplace=True)]) 37 | mask_transform = transforms.Compose([transforms.Resize((256,256),interpolation=PIL.Image.NEAREST),ToTensor]) 38 | makeup_image = Image.open('assets/datasets/pairs/'+ref_id+'_m.png') 39 | non_makeup_image = Image.open('assets/datasets/pairs/'+ref_id+'_nm.png') 40 | makeup_mask = Image.open('assets/datasets/pairs/'+ref_id+'_mask.png') 41 | makeup_image = train_transform(makeup_image).to("cuda") 42 | non_makeup_image = train_transform(non_makeup_image).to("cuda") 43 | makeup_mask = mask_transform(makeup_mask) 44 | return makeup_image, non_makeup_image, makeup_mask 45 | 46 | def cal_hist(image): 47 | """ 48 | cal cumulative hist for channel list 49 | """ 50 | hists = [] 51 | for i in range(0, 3): 52 | channel = image[i] 53 | # channel = image[i, :, :] 54 | channel = torch.from_numpy(channel) 55 | # hist, _ = np.histogram(channel, bins=256, range=(0,255)) 56 | hist = torch.histc(channel, bins=256, min=0, max=256) 57 | hist = hist.numpy() 58 | # refHist=hist.view(256,1) 59 | sum = hist.sum() 60 | pdf = [v / sum for v in hist] 61 | for i in range(1, 256): 62 | pdf[i] = pdf[i - 1] + pdf[i] 63 | hists.append(pdf) 64 | return hists 65 | 66 | 67 | def cal_trans(ref, adj): 68 | """ 69 | calculate transfer function 70 | algorithm refering to wiki item: Histogram matching 71 | """ 72 | table = list(range(0, 256)) 73 | for i in list(range(1, 256)): 74 | for j in list(range(1, 256)): 75 | if ref[i] >= adj[j - 1] and ref[i] <= adj[j]: 76 | table[i] = j 77 | break 78 | table[255] = 255 79 | return table 80 | 81 | 82 | def histogram_matching(dstImg, refImg, index): 83 | """ 84 | perform histogram matching 85 | dstImg is transformed to have the same the histogram with refImg's 86 | index[0], index[1]: the index of pixels that need to be transformed in dstImg 87 | index[2], index[3]: the index of pixels that to compute histogram in refImg 88 | """ 89 | index = [x.cpu().numpy() for x in index] 90 | dstImg = dstImg.detach().cpu().numpy() 91 | refImg = refImg.detach().cpu().numpy() 92 | dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)] 93 | ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)] 94 | hist_ref = cal_hist(ref_align) 95 | hist_dst = cal_hist(dst_align) 96 | tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)] 97 | 98 | mid = copy.deepcopy(dst_align) 99 | for i in range(0, 3): 100 | for k in range(0, len(index[0])): 101 | dst_align[i][k] = tables[i][int(mid[i][k])] 102 | 103 | for i in range(0, 3): 104 | dstImg[i, index[0], index[1]] = dst_align[i] 105 | 106 | dstImg = torch.FloatTensor(dstImg).cuda() 107 | return dstImg 108 | 109 | def to_var(x, requires_grad=True): 110 | if torch.cuda.is_available(): 111 | x = x.cuda() 112 | if not requires_grad: 113 | 
return Variable(x, requires_grad=requires_grad) 114 | else: 115 | return Variable(x) 116 | 117 | def de_norm(x): 118 | out = (x + 1) / 2 119 | return out.clamp(0, 1) 120 | 121 | def criterionHis(input_data, target_data,index,criterionL1,mask_src=1, mask_tar=1,): 122 | input_data = (de_norm(input_data) * 255).squeeze() 123 | target_data = (de_norm(target_data) * 255).squeeze() 124 | mask_src = mask_src.expand(1, 3, mask_src.size(2), mask_src.size(2)).squeeze() 125 | mask_tar = mask_tar.expand(1, 3, mask_tar.size(2), mask_tar.size(2)).squeeze() 126 | input_masked = input_data * mask_src 127 | target_masked = target_data * mask_tar 128 | # dstImg = (input_masked.data).cpu().clone() 129 | # refImg = (target_masked.data).cpu().clone() 130 | input_match = histogram_matching(input_masked, target_masked, index) 131 | input_match = to_var(input_match, requires_grad=False) 132 | loss = criterionL1(input_masked, input_match) 133 | return loss,input_match/255 134 | 135 | def ToTensor(pic): 136 | # handle PIL Image 137 | if pic.mode == 'I': 138 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 139 | elif pic.mode == 'I;16': 140 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 141 | else: 142 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 143 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 144 | if pic.mode == 'YCbCr': 145 | nchannel = 3 146 | elif pic.mode == 'I;16': 147 | nchannel = 1 148 | else: 149 | nchannel = len(pic.mode) 150 | img = img.view(pic.size[1], pic.size[0], nchannel) 151 | # put it from HWC to CHW format 152 | # yikes, this transpose takes 80% of the loading time/CPU 153 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 154 | if isinstance(img, torch.ByteTensor): 155 | return img.float() 156 | else: 157 | return img 158 | 159 | def rebound_box(mask_A, mask_B, mask_A_face): 160 | index_tmp = mask_A.nonzero() 161 | x_A_index = index_tmp[:, 1] 162 | y_A_index = index_tmp[:, 2] 163 | index_tmp = mask_B.nonzero() 164 | x_B_index = index_tmp[:, 1] 165 | y_B_index = index_tmp[:, 2] 166 | mask_A_temp = mask_A.copy_(mask_A) 167 | mask_B_temp = mask_B.copy_(mask_B) 168 | mask_A_temp[: ,min(x_A_index)-10:max(x_A_index)+11, min(y_A_index)-10:max(y_A_index)+11] =\ 169 | mask_A_face[: ,min(x_A_index)-10:max(x_A_index)+11, min(y_A_index)-10:max(y_A_index)+11] 170 | mask_B_temp[: ,min(x_B_index)-10:max(x_B_index)+11, min(y_B_index)-10:max(y_B_index)+11] =\ 171 | mask_A_face[: ,min(x_B_index)-10:max(x_B_index)+11, min(y_B_index)-10:max(y_B_index)+11] 172 | mask_A_temp = to_var(mask_A_temp, requires_grad=False) 173 | mask_B_temp = to_var(mask_B_temp, requires_grad=False) 174 | return mask_A_temp, mask_B_temp 175 | 176 | def mask_preprocess(mask_A, mask_B): 177 | index_tmp = mask_A.nonzero() 178 | x_A_index = index_tmp[:, 1] 179 | y_A_index = index_tmp[:, 2] 180 | index_tmp = mask_B.nonzero() 181 | x_B_index = index_tmp[:, 1] 182 | y_B_index = index_tmp[:, 2] 183 | mask_A = to_var(mask_A, requires_grad=False) 184 | mask_B = to_var(mask_B, requires_grad=False) 185 | index = [x_A_index, y_A_index, x_B_index, y_B_index] 186 | index_2 = [x_B_index, y_B_index, x_A_index, y_A_index] 187 | return mask_A, mask_B, index, index_2 -------------------------------------------------------------------------------- /models/improved_ddpm/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 
3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import logging 4 | import yaml 5 | import time 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | from makeup_transfer import DiffAM_MT 11 | from makeup_removal import DiffAM_MR 12 | 13 | def parse_args_and_config(): 14 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 15 | 16 | # Mode 17 | parser.add_argument('--makeup_transfer', action='store_true') 18 | parser.add_argument('--makeup_removal', action='store_true') 19 | parser.add_argument('--edit_one_image_MT', action='store_true') 20 | parser.add_argument('--edit_one_image_MR', action='store_true') 21 | 22 | # Default 23 | parser.add_argument('--config', type=str, required=True, help='Path to the config file') 24 | parser.add_argument('--seed', type=int, default=1234, help='Random seed') 25 | parser.add_argument('--exp', type=str, default='./runs/', help='Path for saving running related data.') 26 | parser.add_argument('--comment', type=str, default='', help='A string for experiment comment') 27 | parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | 
debug | warning | critical') 28 | parser.add_argument('--ni', type=int, default=1, help="No interaction. Suitable for Slurm Job launcher") 29 | parser.add_argument('--align_face', type=int, default=0, help='align face or not') 30 | 31 | # Text 32 | parser.add_argument('--src_txts', type=str, default="face with makeup", action='append', help='Source text') 33 | parser.add_argument('--trg_txts', type=str, default="face without makeup", action='append', help='Target text') 34 | 35 | # Sampling 36 | parser.add_argument('--t_0', type=int, default=60, help='Return step in [0, 1000)') 37 | parser.add_argument('--n_inv_step', type=int, default=20, help='# of steps during generative process for inversion') 38 | parser.add_argument('--n_train_step', type=int, default=6, help='# of steps during generative process for train') 39 | parser.add_argument('--n_test_step', type=int, default=6, help='# of steps during generative process for test') 40 | parser.add_argument('--sample_type', type=str, default='ddim', help='ddpm for Markovian sampling, ddim for non-Markovian sampling') 41 | parser.add_argument('--eta', type=float, default=0.0, help='Controls the variance of the generative process') 42 | 43 | # Train & Test 44 | parser.add_argument('--do_train', type=int, default=1, help='Whether to train or not during CLIP finetuning') 45 | parser.add_argument('--do_test', type=int, default=1, help='Whether to test or not during CLIP finetuning') 46 | parser.add_argument('--save_train_image', type=int, default=1, help='Whether to save training results during CLIP finetuning') 47 | parser.add_argument('--bs_train', type=int, default=1, help='Training batch size during CLIP finetuning') 48 | parser.add_argument('--bs_test', type=int, default=1, help='Test batch size during CLIP finetuning') 49 | parser.add_argument('--n_precomp_img', type=int, default=200, help='# of images to precompute latents') 50 | parser.add_argument('--n_train_img', type=int, default=200, help='# of training images') 51 | parser.add_argument('--n_test_img', type=int, default=100, help='# of test images') 52 | parser.add_argument('--model_path', type=str, default='pretrained/celeba_hq.ckpt', help='Test model path') 53 | parser.add_argument('--img_path', type=str, default=None, help='Image path to test') 54 | parser.add_argument('--deterministic_inv', type=int, default=1, help='Whether to use deterministic inversion during inference') 55 | parser.add_argument('--model_ratio', type=float, default=1, help='Degree of change: noise ratio between the original and finetuned models.') 56 | 57 | 58 | # Loss & Optimization 59 | parser.add_argument('--MT_iter_without_adv', type=int, default=3, help='iters without adv loss') 60 | parser.add_argument('--MT_1_dir_loss_w', type=float, default=0.3, help='Weights of makeup direction loss in MT stage 1') 61 | parser.add_argument('--MT_2_dir_loss_w', type=float, default=0.5, help='Weights of makeup direction loss in MT stage 2') 62 | parser.add_argument('--MT_1_dis_loss_w', type=int, default=1, help='Weights of makeup distance loss') 63 | parser.add_argument('--MT_2_dis_loss_w', type=float, default=1.6, help='Weights of makeup distance loss') 64 | parser.add_argument('--MT_1_l1_loss_w', type=int, default=3, help='Weights of L1 loss in MT stage 1') 65 | parser.add_argument('--MT_2_l1_loss_w', type=int, default=5, help='Weights of L1 loss in MT stage 2') 66 | parser.add_argument('--MT_lpips_loss_w', type=int, default=10, help='Weights of LPIPS loss in MT') 67 | parser.add_argument('--MT_adv_loss_w', type=float, default=0.5, help='Weights 
of adv loss') 68 | 69 | parser.add_argument('--MR_clip_loss_w', type=int, default=5, help='Weights of CLIP loss in MR') 70 | parser.add_argument('--MR_l1_loss_w', type=float, default=2, help='Weights of L1 loss in MR') 71 | parser.add_argument('--MR_id_loss_w', type=float, default=1, help='Weights of ID loss in MR') 72 | parser.add_argument('--MR_lpips_loss_w', type=float, default=5, help='Weights of LPIPS loss in MR') 73 | 74 | parser.add_argument('--clip_model_name', type=str, default='ViT-B/16', help='ViT-B/16, ViT-B/32, RN50x16 etc') 75 | parser.add_argument('--lr_clip_finetune', type=float, default=8e-6, help='Initial learning rate for finetuning') 76 | parser.add_argument('--n_iter', type=int, default=4, help='# of iterations of a generative process with `n_train_img` images') 77 | parser.add_argument('--scheduler', type=int, default=1, help='Whether to increase the learning rate') 78 | parser.add_argument('--sch_gamma', type=float, default=1.3, help='Scheduler gamma') 79 | 80 | # Attack & Makeup 81 | parser.add_argument('--target_img', type=int, default=1, help='Target identities: 0, 1, 2, 3') 82 | parser.add_argument('--target_model', type=int, default=2, help='Target model for black-box attack. 0:ir152, 1:irse50, 2:mobile_face, 3:facenet') 83 | parser.add_argument('--ref_img', type=str, default='XMY-060', help='Reference image') 84 | 85 | args = parser.parse_args() 86 | 87 | # parse config file 88 | with open(os.path.join('configs', args.config), 'r') as f: 89 | config = yaml.safe_load(f) 90 | new_config = dict2namespace(config) 91 | 92 | if args.makeup_transfer: 93 | args.exp = args.exp + f'_MT_{new_config.data.category}_{args.ref_img}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_dis{args.MT_1_dis_loss_w}_dir{args.MT_1_dir_loss_w}_lr{args.lr_clip_finetune}' 94 | elif args.makeup_removal: 95 | args.exp = args.exp + f'_MR_{new_config.data.category}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.MR_id_loss_w}_l1{args.MR_l1_loss_w}_lr{args.lr_clip_finetune}' 96 | elif args.edit_one_image_MT: 97 | args.exp = args.exp + f'_E1_MT_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}' 98 | elif args.edit_one_image_MR: 99 | args.exp = args.exp + f'_E1_MR_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}' 100 | 101 | 102 | level = getattr(logging, args.verbose.upper(), None) 103 | if not isinstance(level, int): 104 | raise ValueError('level {} not supported'.format(args.verbose)) 105 | 106 | handler1 = logging.StreamHandler() 107 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 108 | handler1.setFormatter(formatter) 109 | logger = logging.getLogger() 110 | logger.addHandler(handler1) 111 | logger.setLevel(level) 112 | 113 | os.makedirs(args.exp, exist_ok=True) 114 | os.makedirs('checkpoint', exist_ok=True) 115 | os.makedirs('precomputed', exist_ok=True) 116 | os.makedirs('sample_real', exist_ok=True) 117 | os.makedirs('sample_fake', exist_ok=True) 118 | os.makedirs('sample_real_test', exist_ok=True) 119 | os.makedirs('sample_fake_test', exist_ok=True) 120 | os.makedirs('runs', exist_ok=True) 121 | os.makedirs(args.exp, exist_ok=True) 122 | 123 | args.image_folder = os.path.join(args.exp, 'image_samples') 124 | if not os.path.exists(args.image_folder): 125 | 
os.makedirs(args.image_folder) 126 | else: 127 | overwrite = False 128 | if args.ni: 129 | overwrite = True 130 | else: 131 | response = input("Image folder already exists. Overwrite? (Y/N)") 132 | if response.upper() == 'Y': 133 | overwrite = True 134 | 135 | if overwrite: 136 | # shutil.rmtree(args.image_folder) 137 | os.makedirs(args.image_folder, exist_ok=True) 138 | else: 139 | print("Output image folder exists. Program halted.") 140 | sys.exit(0) 141 | 142 | # add device 143 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 144 | logging.info("Using device: {}".format(device)) 145 | new_config.device = device 146 | 147 | # set random seed 148 | torch.manual_seed(args.seed) 149 | np.random.seed(args.seed) 150 | if torch.cuda.is_available(): 151 | torch.cuda.manual_seed_all(args.seed) 152 | 153 | torch.backends.cudnn.benchmark = True 154 | 155 | return args, new_config 156 | 157 | 158 | def dict2namespace(config): 159 | namespace = argparse.Namespace() 160 | for key, value in config.items(): 161 | if isinstance(value, dict): 162 | new_value = dict2namespace(value) 163 | else: 164 | new_value = value 165 | setattr(namespace, key, new_value) 166 | return namespace 167 | 168 | 169 | def main(): 170 | args, config = parse_args_and_config() 171 | print(">" * 80) 172 | logging.info("Exp instance id = {}".format(os.getpid())) 173 | logging.info("Exp comment = {}".format(args.comment)) 174 | logging.info("Config =") 175 | print("<" * 80) 176 | 177 | start_time = time.time() 178 | try: 179 | if args.makeup_transfer: 180 | runner = DiffAM_MT(args, config) 181 | runner.clip_finetune() 182 | elif args.makeup_removal: 183 | runner = DiffAM_MR(args, config) 184 | runner.clip_finetune() 185 | elif args.edit_one_image_MT: 186 | runner = DiffAM_MT(args, config) 187 | runner.edit_one_image() 188 | elif args.edit_one_image_MR: 189 | runner = DiffAM_MR(args, config) 190 | runner.edit_one_image() 191 | else: 192 | print('Choose one mode!') 193 | raise ValueError 194 | except Exception: 195 | logging.error(traceback.format_exc()) 196 | end_time = time.time() 197 | print("total_time:{}".format(end_time-start_time)) 198 | 199 | return 0 200 | 201 | 202 | if __name__ == '__main__': 203 | sys.exit(main()) 204 | -------------------------------------------------------------------------------- /models/ddpm/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim): 7 | """ 8 | This matches the implementation in Denoising Diffusion Probabilistic Models: 9 | From Fairseq. 10 | Build sinusoidal embeddings. 11 | This matches the implementation in tensor2tensor, but differs slightly 12 | from the description in Section 3.5 of "Attention Is All You Need". 
13 | """ 14 | assert len(timesteps.shape) == 1 15 | 16 | half_dim = embedding_dim // 2 17 | emb = math.log(10000) / (half_dim - 1) 18 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 19 | emb = emb.to(device=timesteps.device) 20 | emb = timesteps.float()[:, None] * emb[None, :] 21 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 22 | if embedding_dim % 2 == 1: # zero pad 23 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 24 | return emb 25 | 26 | 27 | def nonlinearity(x): 28 | # swish 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | def Normalize(in_channels): 33 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 34 | 35 | 36 | class Upsample(nn.Module): 37 | def __init__(self, in_channels, with_conv): 38 | super().__init__() 39 | self.with_conv = with_conv 40 | if self.with_conv: 41 | self.conv = torch.nn.Conv2d(in_channels, 42 | in_channels, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1) 46 | 47 | def forward(self, x): 48 | x = torch.nn.functional.interpolate( 49 | x, scale_factor=2.0, mode="nearest") 50 | if self.with_conv: 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class Downsample(nn.Module): 56 | def __init__(self, in_channels, with_conv): 57 | super().__init__() 58 | self.with_conv = with_conv 59 | if self.with_conv: 60 | # no asymmetric padding in torch conv, must do it ourselves 61 | self.conv = torch.nn.Conv2d(in_channels, 62 | in_channels, 63 | kernel_size=3, 64 | stride=2, 65 | padding=0) 66 | 67 | def forward(self, x): 68 | if self.with_conv: 69 | pad = (0, 1, 0, 1) 70 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 71 | x = self.conv(x) 72 | else: 73 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 74 | return x 75 | 76 | 77 | class ResnetBlock(nn.Module): 78 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 79 | dropout, temb_channels=512): 80 | super().__init__() 81 | self.in_channels = in_channels 82 | out_channels = in_channels if out_channels is None else out_channels 83 | self.out_channels = out_channels 84 | self.use_conv_shortcut = conv_shortcut 85 | 86 | self.norm1 = Normalize(in_channels) 87 | self.conv1 = torch.nn.Conv2d(in_channels, 88 | out_channels, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | self.temb_proj = torch.nn.Linear(temb_channels, 93 | out_channels) 94 | self.norm2 = Normalize(out_channels) 95 | self.dropout = torch.nn.Dropout(dropout) 96 | self.conv2 = torch.nn.Conv2d(out_channels, 97 | out_channels, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1) 101 | if self.in_channels != self.out_channels: 102 | if self.use_conv_shortcut: 103 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 104 | out_channels, 105 | kernel_size=3, 106 | stride=1, 107 | padding=1) 108 | else: 109 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | 115 | def forward(self, x, temb): 116 | h = x 117 | h = self.norm1(h) 118 | h = nonlinearity(h) 119 | h = self.conv1(h) 120 | 121 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 122 | 123 | h = self.norm2(h) 124 | h = nonlinearity(h) 125 | h = self.dropout(h) 126 | h = self.conv2(h) 127 | 128 | if self.in_channels != self.out_channels: 129 | if self.use_conv_shortcut: 130 | x = self.conv_shortcut(x) 131 | else: 132 | x = self.nin_shortcut(x) 133 | 134 | return x + h 135 | 136 | 137 | class AttnBlock(nn.Module): 138 | def __init__(self, in_channels): 139 | super().__init__() 140 | 
self.in_channels = in_channels 141 | 142 | self.norm = Normalize(in_channels) 143 | self.q = torch.nn.Conv2d(in_channels, 144 | in_channels, 145 | kernel_size=1, 146 | stride=1, 147 | padding=0) 148 | self.k = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.v = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.proj_out = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | 164 | def forward(self, x): 165 | h_ = x 166 | h_ = self.norm(h_) 167 | q = self.q(h_) 168 | k = self.k(h_) 169 | v = self.v(h_) 170 | 171 | # compute attention 172 | b, c, h, w = q.shape 173 | q = q.reshape(b, c, h * w) 174 | q = q.permute(0, 2, 1) # b,hw,c 175 | k = k.reshape(b, c, h * w) # b,c,hw 176 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 177 | w_ = w_ * (int(c) ** (-0.5)) 178 | w_ = torch.nn.functional.softmax(w_, dim=2) 179 | 180 | # attend to values 181 | v = v.reshape(b, c, h * w) 182 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 183 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 184 | h_ = torch.bmm(v, w_) 185 | h_ = h_.reshape(b, c, h, w) 186 | 187 | h_ = self.proj_out(h_) 188 | 189 | return x + h_ 190 | 191 | 192 | class DDPM(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | self.config = config 196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) 197 | num_res_blocks = config.model.num_res_blocks 198 | attn_resolutions = config.model.attn_resolutions 199 | dropout = config.model.dropout 200 | in_channels = config.model.in_channels 201 | resolution = config.data.image_size 202 | resamp_with_conv = config.model.resamp_with_conv 203 | 204 | self.ch = ch 205 | self.temb_ch = self.ch * 4 206 | self.num_resolutions = len(ch_mult) 207 | self.num_res_blocks = num_res_blocks 208 | self.resolution = resolution 209 | self.in_channels = in_channels 210 | 211 | # timestep embedding 212 | self.temb = nn.Module() 213 | self.temb.dense = nn.ModuleList([ 214 | torch.nn.Linear(self.ch, 215 | self.temb_ch), 216 | torch.nn.Linear(self.temb_ch, 217 | self.temb_ch), 218 | ]) 219 | 220 | # downsampling 221 | self.conv_in = torch.nn.Conv2d(in_channels, 222 | self.ch, 223 | kernel_size=3, 224 | stride=1, 225 | padding=1) 226 | 227 | curr_res = resolution 228 | in_ch_mult = (1,) + ch_mult 229 | self.down = nn.ModuleList() 230 | block_in = None 231 | for i_level in range(self.num_resolutions): 232 | block = nn.ModuleList() 233 | attn = nn.ModuleList() 234 | block_in = ch * in_ch_mult[i_level] 235 | block_out = ch * ch_mult[i_level] 236 | for i_block in range(self.num_res_blocks): 237 | block.append(ResnetBlock(in_channels=block_in, 238 | out_channels=block_out, 239 | temb_channels=self.temb_ch, 240 | dropout=dropout)) 241 | block_in = block_out 242 | if curr_res in attn_resolutions: 243 | attn.append(AttnBlock(block_in)) 244 | down = nn.Module() 245 | down.block = block 246 | down.attn = attn 247 | if i_level != self.num_resolutions - 1: 248 | down.downsample = Downsample(block_in, resamp_with_conv) 249 | curr_res = curr_res // 2 250 | self.down.append(down) 251 | 252 | # middle 253 | self.mid = nn.Module() 254 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 255 | out_channels=block_in, 256 | temb_channels=self.temb_ch, 257 | dropout=dropout) 258 | self.mid.attn_1 = AttnBlock(block_in) 259 | self.mid.block_2 = 
ResnetBlock(in_channels=block_in, 260 | out_channels=block_in, 261 | temb_channels=self.temb_ch, 262 | dropout=dropout) 263 | 264 | # upsampling 265 | self.up = nn.ModuleList() 266 | for i_level in reversed(range(self.num_resolutions)): 267 | block = nn.ModuleList() 268 | attn = nn.ModuleList() 269 | block_out = ch * ch_mult[i_level] 270 | skip_in = ch * ch_mult[i_level] 271 | for i_block in range(self.num_res_blocks + 1): 272 | if i_block == self.num_res_blocks: 273 | skip_in = ch * in_ch_mult[i_level] 274 | block.append(ResnetBlock(in_channels=block_in + skip_in, 275 | out_channels=block_out, 276 | temb_channels=self.temb_ch, 277 | dropout=dropout)) 278 | block_in = block_out 279 | if curr_res in attn_resolutions: 280 | attn.append(AttnBlock(block_in)) 281 | up = nn.Module() 282 | up.block = block 283 | up.attn = attn 284 | if i_level != 0: 285 | up.upsample = Upsample(block_in, resamp_with_conv) 286 | curr_res = curr_res * 2 287 | self.up.insert(0, up) # prepend to get consistent order 288 | 289 | # end 290 | self.norm_out = Normalize(block_in) 291 | self.conv_out = torch.nn.Conv2d(block_in, 292 | out_ch, 293 | kernel_size=3, 294 | stride=1, 295 | padding=1) 296 | 297 | def forward(self, x, t): 298 | assert x.shape[2] == x.shape[3] == self.resolution 299 | 300 | # timestep embedding 301 | temb = get_timestep_embedding(t, self.ch) 302 | temb = self.temb.dense[0](temb) 303 | temb = nonlinearity(temb) 304 | temb = self.temb.dense[1](temb) 305 | 306 | # downsampling 307 | hs = [self.conv_in(x)] 308 | for i_level in range(self.num_resolutions): 309 | for i_block in range(self.num_res_blocks): 310 | h = self.down[i_level].block[i_block](hs[-1], temb) 311 | if len(self.down[i_level].attn) > 0: 312 | h = self.down[i_level].attn[i_block](h) 313 | hs.append(h) 314 | if i_level != self.num_resolutions - 1: 315 | hs.append(self.down[i_level].downsample(hs[-1])) 316 | 317 | # middle 318 | h = hs[-1] 319 | h = self.mid.block_1(h, temb) 320 | h = self.mid.attn_1(h) 321 | h = self.mid.block_2(h, temb) 322 | 323 | # upsampling 324 | for i_level in reversed(range(self.num_resolutions)): 325 | for i_block in range(self.num_res_blocks + 1): 326 | h = self.up[i_level].block[i_block]( 327 | torch.cat([h, hs.pop()], dim=1), temb) 328 | if len(self.up[i_level].attn) > 0: 329 | h = self.up[i_level].attn[i_block](h) 330 | if i_level != 0: 331 | h = self.up[i_level].upsample(h) 332 | 333 | # end 334 | h = self.norm_out(h) 335 | h = nonlinearity(h) 336 | h = self.conv_out(h) 337 | return h 338 | -------------------------------------------------------------------------------- /models/improved_ddpm/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger based on OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import os.path as osp 9 | import json 10 | import time 11 | import datetime 12 | import tempfile 13 | import warnings 14 | from collections import defaultdict 15 | from contextlib import contextmanager 16 | 17 | DEBUG = 10 18 | INFO = 20 19 | WARN = 30 20 | ERROR = 40 21 | 22 | DISABLED = 50 23 | 24 | 25 | class KVWriter(object): 26 | def writekvs(self, kvs): 27 | raise NotImplementedError 28 | 29 | 30 | class SeqWriter(object): 31 | def writeseq(self, seq): 32 | raise NotImplementedError 33 | 34 | 35 | class HumanOutputFormat(KVWriter, SeqWriter): 36 | def __init__(self, filename_or_file): 37 | if 
isinstance(filename_or_file, str): 38 | self.file = open(filename_or_file, "wt") 39 | self.own_file = True 40 | else: 41 | assert hasattr(filename_or_file, "read"), ( 42 | "expected file or str, got %s" % filename_or_file 43 | ) 44 | self.file = filename_or_file 45 | self.own_file = False 46 | 47 | def writekvs(self, kvs): 48 | # Create strings for printing 49 | key2str = {} 50 | for (key, val) in sorted(kvs.items()): 51 | if hasattr(val, "__float__"): 52 | valstr = "%-8.3g" % val 53 | else: 54 | valstr = str(val) 55 | key2str[self._truncate(key)] = self._truncate(valstr) 56 | 57 | # Find max widths 58 | if len(key2str) == 0: 59 | print("WARNING: tried to write empty key-value dict") 60 | return 61 | else: 62 | keywidth = max(map(len, key2str.keys())) 63 | valwidth = max(map(len, key2str.values())) 64 | 65 | # Write out the data 66 | dashes = "-" * (keywidth + valwidth + 7) 67 | lines = [dashes] 68 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 69 | lines.append( 70 | "| %s%s | %s%s |" 71 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 72 | ) 73 | lines.append(dashes) 74 | self.file.write("\n".join(lines) + "\n") 75 | 76 | # Flush the output to the file 77 | self.file.flush() 78 | 79 | def _truncate(self, s): 80 | maxlen = 30 81 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 82 | 83 | def writeseq(self, seq): 84 | seq = list(seq) 85 | for (i, elem) in enumerate(seq): 86 | self.file.write(elem) 87 | if i < len(seq) - 1: # add space unless this is the last one 88 | self.file.write(" ") 89 | self.file.write("\n") 90 | self.file.flush() 91 | 92 | def close(self): 93 | if self.own_file: 94 | self.file.close() 95 | 96 | 97 | class JSONOutputFormat(KVWriter): 98 | def __init__(self, filename): 99 | self.file = open(filename, "wt") 100 | 101 | def writekvs(self, kvs): 102 | for k, v in sorted(kvs.items()): 103 | if hasattr(v, "dtype"): 104 | kvs[k] = float(v) 105 | self.file.write(json.dumps(kvs) + "\n") 106 | self.file.flush() 107 | 108 | def close(self): 109 | self.file.close() 110 | 111 | 112 | class CSVOutputFormat(KVWriter): 113 | def __init__(self, filename): 114 | self.file = open(filename, "w+t") 115 | self.keys = [] 116 | self.sep = "," 117 | 118 | def writekvs(self, kvs): 119 | # Add our current row to the history 120 | extra_keys = list(kvs.keys() - self.keys) 121 | extra_keys.sort() 122 | if extra_keys: 123 | self.keys.extend(extra_keys) 124 | self.file.seek(0) 125 | lines = self.file.readlines() 126 | self.file.seek(0) 127 | for (i, k) in enumerate(self.keys): 128 | if i > 0: 129 | self.file.write(",") 130 | self.file.write(k) 131 | self.file.write("\n") 132 | for line in lines[1:]: 133 | self.file.write(line[:-1]) 134 | self.file.write(self.sep * len(extra_keys)) 135 | self.file.write("\n") 136 | for (i, k) in enumerate(self.keys): 137 | if i > 0: 138 | self.file.write(",") 139 | v = kvs.get(k) 140 | if v is not None: 141 | self.file.write(str(v)) 142 | self.file.write("\n") 143 | self.file.flush() 144 | 145 | def close(self): 146 | self.file.close() 147 | 148 | 149 | def make_output_format(format, ev_dir, log_suffix=""): 150 | os.makedirs(ev_dir, exist_ok=True) 151 | if format == "stdout": 152 | return HumanOutputFormat(sys.stdout) 153 | elif format == "log": 154 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 155 | elif format == "json": 156 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 157 | elif format == "csv": 158 | return CSVOutputFormat(osp.join(ev_dir, 
"progress%s.csv" % log_suffix)) 159 | else: 160 | raise ValueError("Unknown format specified: %s" % (format,)) 161 | 162 | 163 | # ================================================================ 164 | # API 165 | # ================================================================ 166 | 167 | 168 | def logkv(key, val): 169 | """ 170 | Log a value of some diagnostic 171 | Call this once for each diagnostic quantity, each iteration 172 | If called many times, last value will be used. 173 | """ 174 | get_current().logkv(key, val) 175 | 176 | 177 | def logkv_mean(key, val): 178 | """ 179 | The same as logkv(), but if called many times, values averaged. 180 | """ 181 | get_current().logkv_mean(key, val) 182 | 183 | 184 | def logkvs(d): 185 | """ 186 | Log a dictionary of key-value pairs 187 | """ 188 | for (k, v) in d.items(): 189 | logkv(k, v) 190 | 191 | 192 | def dumpkvs(): 193 | """ 194 | Write all of the diagnostics from the current iteration 195 | """ 196 | return get_current().dumpkvs() 197 | 198 | 199 | def getkvs(): 200 | return get_current().name2val 201 | 202 | 203 | def log(*args, level=INFO): 204 | """ 205 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 206 | """ 207 | get_current().log(*args, level=level) 208 | 209 | 210 | def debug(*args): 211 | log(*args, level=DEBUG) 212 | 213 | 214 | def info(*args): 215 | log(*args, level=INFO) 216 | 217 | 218 | def warn(*args): 219 | log(*args, level=WARN) 220 | 221 | 222 | def error(*args): 223 | log(*args, level=ERROR) 224 | 225 | 226 | def set_level(level): 227 | """ 228 | Set logging threshold on current logger. 229 | """ 230 | get_current().set_level(level) 231 | 232 | 233 | def set_comm(comm): 234 | get_current().set_comm(comm) 235 | 236 | 237 | def get_dir(): 238 | """ 239 | Get directory that log files are being written to. 240 | will be None if there is no output directory (i.e., if you didn't call start) 241 | """ 242 | return get_current().get_dir() 243 | 244 | 245 | record_tabular = logkv 246 | dump_tabular = dumpkvs 247 | 248 | 249 | @contextmanager 250 | def profile_kv(scopename): 251 | logkey = "wait_" + scopename 252 | tstart = time.time() 253 | try: 254 | yield 255 | finally: 256 | get_current().name2val[logkey] += time.time() - tstart 257 | 258 | 259 | def profile(n): 260 | """ 261 | Usage: 262 | @profile("my_func") 263 | def my_func(): code 264 | """ 265 | 266 | def decorator_with_name(func): 267 | def func_wrapper(*args, **kwargs): 268 | with profile_kv(n): 269 | return func(*args, **kwargs) 270 | 271 | return func_wrapper 272 | 273 | return decorator_with_name 274 | 275 | 276 | # ================================================================ 277 | # Backend 278 | # ================================================================ 279 | 280 | 281 | def get_current(): 282 | if Logger.CURRENT is None: 283 | _configure_default_logger() 284 | 285 | return Logger.CURRENT 286 | 287 | 288 | class Logger(object): 289 | DEFAULT = None # A logger with no output files. 
(See right below class definition) 290 | # So that you can still log to the terminal without setting up any output files 291 | CURRENT = None # Current logger being used by the free functions above 292 | 293 | def __init__(self, dir, output_formats, comm=None): 294 | self.name2val = defaultdict(float) # values this iteration 295 | self.name2cnt = defaultdict(int) 296 | self.level = INFO 297 | self.dir = dir 298 | self.output_formats = output_formats 299 | self.comm = comm 300 | 301 | # Logging API, forwarded 302 | # ---------------------------------------- 303 | def logkv(self, key, val): 304 | self.name2val[key] = val 305 | 306 | def logkv_mean(self, key, val): 307 | oldval, cnt = self.name2val[key], self.name2cnt[key] 308 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 309 | self.name2cnt[key] = cnt + 1 310 | 311 | def dumpkvs(self): 312 | if self.comm is None: 313 | d = self.name2val 314 | else: 315 | d = mpi_weighted_mean( 316 | self.comm, 317 | { 318 | name: (val, self.name2cnt.get(name, 1)) 319 | for (name, val) in self.name2val.items() 320 | }, 321 | ) 322 | if self.comm.rank != 0: 323 | d["dummy"] = 1 # so we don't get a warning about empty dict 324 | out = d.copy() # Return the dict for unit testing purposes 325 | for fmt in self.output_formats: 326 | if isinstance(fmt, KVWriter): 327 | fmt.writekvs(d) 328 | self.name2val.clear() 329 | self.name2cnt.clear() 330 | return out 331 | 332 | def log(self, *args, level=INFO): 333 | if self.level <= level: 334 | self._do_log(args) 335 | 336 | # Configuration 337 | # ---------------------------------------- 338 | def set_level(self, level): 339 | self.level = level 340 | 341 | def set_comm(self, comm): 342 | self.comm = comm 343 | 344 | def get_dir(self): 345 | return self.dir 346 | 347 | def close(self): 348 | for fmt in self.output_formats: 349 | fmt.close() 350 | 351 | # Misc 352 | # ---------------------------------------- 353 | def _do_log(self, args): 354 | for fmt in self.output_formats: 355 | if isinstance(fmt, SeqWriter): 356 | fmt.writeseq(map(str, args)) 357 | 358 | 359 | def get_rank_without_mpi_import(): 360 | # check environment variables here instead of importing mpi4py 361 | # to avoid calling MPI_Init() when this module is imported 362 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 363 | if varname in os.environ: 364 | return int(os.environ[varname]) 365 | return 0 366 | 367 | 368 | def mpi_weighted_mean(comm, local_name2valcount): 369 | """ 370 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 371 | Perform a weighted average over dicts that are each on a different node 372 | Input: local_name2valcount: dict mapping key -> (value, count) 373 | Returns: key -> mean 374 | """ 375 | all_name2valcount = comm.gather(local_name2valcount) 376 | if comm.rank == 0: 377 | name2sum = defaultdict(float) 378 | name2count = defaultdict(float) 379 | for n2vc in all_name2valcount: 380 | for (name, (val, count)) in n2vc.items(): 381 | try: 382 | val = float(val) 383 | except ValueError: 384 | if comm.rank == 0: 385 | warnings.warn( 386 | "WARNING: tried to compute mean on non-float {}={}".format( 387 | name, val 388 | ) 389 | ) 390 | else: 391 | name2sum[name] += val * count 392 | name2count[name] += count 393 | return {name: name2sum[name] / name2count[name] for name in name2sum} 394 | else: 395 | return {} 396 | 397 | 398 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 399 | """ 400 | If comm is provided, 
average all numerical stats across that comm 401 | """ 402 | if dir is None: 403 | dir = os.getenv("OPENAI_LOGDIR") 404 | if dir is None: 405 | dir = osp.join( 406 | tempfile.gettempdir(), 407 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 408 | ) 409 | assert isinstance(dir, str) 410 | dir = os.path.expanduser(dir) 411 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 412 | 413 | rank = get_rank_without_mpi_import() 414 | if rank > 0: 415 | log_suffix = log_suffix + "-rank%03i" % rank 416 | 417 | if format_strs is None: 418 | if rank == 0: 419 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 420 | else: 421 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 422 | format_strs = filter(None, format_strs) 423 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 424 | 425 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 426 | if output_formats: 427 | log("Logging to %s" % dir) 428 | 429 | 430 | def _configure_default_logger(): 431 | configure() 432 | Logger.DEFAULT = Logger.CURRENT 433 | 434 | 435 | def reset(): 436 | if Logger.CURRENT is not Logger.DEFAULT: 437 | Logger.CURRENT.close() 438 | Logger.CURRENT = Logger.DEFAULT 439 | log("Reset logger") 440 | 441 | 442 | @contextmanager 443 | def scoped_configure(dir=None, format_strs=None, comm=None): 444 | prevlogger = Logger.CURRENT 445 | configure(dir=dir, format_strs=format_strs, comm=comm) 446 | try: 447 | yield 448 | finally: 449 | Logger.CURRENT.close() 450 | Logger.CURRENT = prevlogger 451 | 452 | -------------------------------------------------------------------------------- /makeup_removal.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | import os 4 | import numpy as np 5 | import cv2 6 | from PIL import Image 7 | import torch 8 | from torch import nn 9 | import torchvision.utils as tvu 10 | import lpips 11 | 12 | from models.ddpm.diffusion import DDPM 13 | from models.improved_ddpm.script_util import i_DDPM 14 | from utils.diffusion_utils import get_beta_schedule, denoising_step 15 | from losses import id_loss 16 | from losses.clip_loss import CLIPLoss 17 | from datasets.data_utils import get_dataset, get_dataloader 18 | from configs.paths_config import DATASET_PATHS, MODEL_PATHS 19 | from utils.align_utils import run_alignment 20 | 21 | 22 | class DiffAM_MR(object): 23 | def __init__(self, args, config, device=None): 24 | self.args = args 25 | self.config = config 26 | if device is None: 27 | device = torch.device( 28 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 29 | self.device = device 30 | 31 | self.model_var_type = config.model.var_type 32 | betas = get_beta_schedule( 33 | beta_start=config.diffusion.beta_start, 34 | beta_end=config.diffusion.beta_end, 35 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps 36 | ) 37 | self.betas = torch.from_numpy(betas).float().to(self.device) 38 | self.num_timesteps = betas.shape[0] 39 | 40 | alphas = 1.0 - betas 41 | alphas_cumprod = np.cumprod(alphas, axis=0) 42 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 43 | posterior_variance = betas * \ 44 | (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 45 | if self.model_var_type == "fixedlarge": 46 | self.logvar = np.log(np.append(posterior_variance[1], betas[1:])) 47 | 48 | elif self.model_var_type == 'fixedsmall': 49 | self.logvar = np.log(np.maximum(posterior_variance, 1e-20)) 50 | 51 | 
self.src_txt = self.args.src_txts # "face with makeup" 52 | self.trg_txt = self.args.trg_txts # "face without makeup" 53 | 54 | def clip_finetune(self): 55 | print(self.args.exp) 56 | print(f' {self.src_txt}') 57 | print(f'-> {self.trg_txt}') 58 | 59 | # ----------- Model -----------# 60 | url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt" 61 | 62 | model = i_DDPM() 63 | if self.args.model_path: 64 | init_ckpt = torch.load(self.args.model_path) 65 | else: 66 | init_ckpt = torch.hub.load_state_dict_from_url( 67 | url, map_location=self.device) 68 | learn_sigma = True 69 | print("Original diffusion Model loaded.") 70 | model.load_state_dict(init_ckpt) 71 | model.to(self.device) 72 | model = torch.nn.DataParallel(model) 73 | 74 | # ----------- Optimizer and Scheduler -----------# 75 | print(f"Setting optimizer with lr={self.args.lr_clip_finetune}") 76 | optim_ft = torch.optim.Adam( 77 | model.parameters(), weight_decay=0, lr=self.args.lr_clip_finetune) 78 | init_opt_ckpt = optim_ft.state_dict() 79 | scheduler_ft = torch.optim.lr_scheduler.StepLR( 80 | optim_ft, step_size=1, gamma=self.args.sch_gamma) 81 | init_sch_ckpt = scheduler_ft.state_dict() 82 | 83 | # ----------- Loss -----------# 84 | print("Loading losses") 85 | clip_loss_func = CLIPLoss( 86 | self.device, 87 | lambda_makeup_direction=0, 88 | lambda_direction=1, 89 | clip_model=self.args.clip_model_name) 90 | id_loss_func = id_loss.IDLoss().to(self.device).eval() 91 | loss_fn_alex = lpips.LPIPS(net='alex').to(self.device) 92 | 93 | # ----------- Precompute Latents -----------# 94 | print("Prepare identity latent") 95 | seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0 96 | seq_inv = [int(s) for s in list(seq_inv)] 97 | seq_inv_next = [-1] + list(seq_inv[:-1]) 98 | 99 | n = self.args.bs_train 100 | img_lat_pairs_dic = {} 101 | for mode in ['train', 'test']: 102 | img_lat_pairs = [] 103 | pairs_path = os.path.join('precomputed/', 104 | f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth') 105 | print(pairs_path) 106 | if os.path.exists(pairs_path): 107 | print(f'{mode} pairs exists') 108 | img_lat_pairs_dic[mode] = torch.load(pairs_path) 109 | for step, (x0, x_id, x_lat, mask) in enumerate(img_lat_pairs_dic[mode]): 110 | tvu.save_image( 111 | (x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png')) 112 | tvu.save_image((x_id + 1) * 0.5, os.path.join(self.args.image_folder, 113 | f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png')) 114 | mask_image = Image.fromarray( 115 | mask.detach().clone().squeeze().cpu().numpy()) 116 | mask_image.convert('RGB').save(os.path.join( 117 | self.args.image_folder, f'{mode}_{step}_0_mask.png')) 118 | if step == self.args.n_precomp_img - 1: 119 | break 120 | continue 121 | else: 122 | train_dataset, test_dataset = get_dataset( 123 | self.config.data.dataset, DATASET_PATHS, self.config) 124 | loader_dic = get_dataloader(train_dataset, test_dataset, bs_train=self.args.bs_train, 125 | num_workers=self.config.data.num_workers) 126 | loader = loader_dic[mode] 127 | 128 | for step, (img, mask) in enumerate(loader): 129 | x0 = img.to(self.config.device) 130 | tvu.save_image( 131 | (x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png')) 132 | mask_image = Image.fromarray( 133 | mask.detach().clone().squeeze().cpu().numpy()) 134 | mask_image.convert('RGB').save(os.path.join( 135 | self.args.image_folder, 
f'{mode}_{step}_0_mask.png')) 136 | x = x0.clone() 137 | model.eval() 138 | with torch.no_grad(): 139 | with tqdm(total=len(seq_inv), desc=f"Inversion process {mode} {step}") as progress_bar: 140 | for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))): 141 | t = (torch.ones(n) * i).to(self.device) 142 | t_prev = (torch.ones(n) * j).to(self.device) 143 | 144 | x = denoising_step(x, t=t, t_next=t_prev, models=model, 145 | logvars=self.logvar, 146 | sampling_type='ddim', 147 | b=self.betas, 148 | eta=0, 149 | learn_sigma=learn_sigma) 150 | 151 | progress_bar.update(1) 152 | x_lat = x.clone() 153 | tvu.save_image((x_lat + 1) * 0.5, os.path.join(self.args.image_folder, 154 | f'{mode}_{step}_1_lat_ninv{self.args.n_inv_step}.png')) 155 | 156 | with tqdm(total=len(seq_inv), desc=f"Generative process {mode} {step}") as progress_bar: 157 | for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))): 158 | t = (torch.ones(n) * i).to(self.device) 159 | t_next = (torch.ones(n) * j).to(self.device) 160 | 161 | x = denoising_step(x, t=t, t_next=t_next, models=model, 162 | logvars=self.logvar, 163 | sampling_type=self.args.sample_type, 164 | b=self.betas, 165 | learn_sigma=learn_sigma) 166 | progress_bar.update(1) 167 | 168 | img_lat_pairs.append( 169 | [x0, x.detach().clone(), x_lat.detach().clone(), mask]) 170 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 171 | f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png')) 172 | if step == self.args.n_precomp_img - 1: 173 | break 174 | 175 | img_lat_pairs_dic[mode] = img_lat_pairs 176 | pairs_path = os.path.join('precomputed/', 177 | f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth') 178 | torch.save(img_lat_pairs, pairs_path) 179 | 180 | # ----------- Finetune Diffusion Models -----------# 181 | print("Start finetuning") 182 | print( 183 | f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}") 184 | if self.args.n_train_step != 0: 185 | seq_train = np.linspace( 186 | 0, 1, self.args.n_train_step) * self.args.t_0 187 | seq_train = [int(s) for s in list(seq_train)] 188 | print('Uniform skip type') 189 | else: 190 | seq_train = list(range(self.args.t_0)) 191 | print('No skip') 192 | seq_train_next = [-1] + list(seq_train[:-1]) 193 | 194 | seq_test = np.linspace(0, 1, self.args.n_test_step) * self.args.t_0 195 | seq_test = [int(s) for s in list(seq_test)] 196 | seq_test_next = [-1] + list(seq_test[:-1]) 197 | 198 | print(f"CHANGE {self.src_txt} TO {self.trg_txt}") 199 | model.module.load_state_dict(init_ckpt) 200 | optim_ft.load_state_dict(init_opt_ckpt) 201 | scheduler_ft.load_state_dict(init_sch_ckpt) 202 | clip_loss_func.target_direction = None 203 | 204 | # ----------- Train -----------# 205 | for it_out in range(self.args.n_iter): 206 | exp_id = os.path.split(self.args.exp)[-1] 207 | save_name = f'checkpoint/{exp_id}_{self.trg_txt.replace(" ", "_")}-{it_out}.pth' 208 | if self.args.do_train: 209 | if os.path.exists(save_name): 210 | print(f'{save_name} already exists.') 211 | model.module.load_state_dict(torch.load(save_name)) 212 | continue 213 | else: 214 | for step, (x0, x_id, x_lat, _) in enumerate(img_lat_pairs_dic['train']): 215 | model.train() 216 | time_in_start = time.time() 217 | 218 | optim_ft.zero_grad() 219 | x = x_lat.clone() 220 | 221 | with tqdm(total=len(seq_train), desc=f"CLIP iteration") as progress_bar: 222 | for t_it, (i, j) in enumerate(zip(reversed(seq_train), reversed(seq_train_next))): 
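# This inner loop replays the full reverse denoising trajectory from the precomputed
# latent x_lat back to an image x with gradients enabled, so the CLIP / ID / L1 / LPIPS
# losses computed on x further below can backpropagate through every denoising step
# before optim_ft.step() updates the finetuned model.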
223 | t = (torch.ones(n) * i).to(self.device) 224 | t_next = (torch.ones(n) * j).to(self.device) 225 | 226 | x = denoising_step(x, t=t, t_next=t_next, models=model, 227 | logvars=self.logvar, 228 | sampling_type=self.args.sample_type, 229 | b=self.betas, 230 | eta=self.args.eta, 231 | learn_sigma=learn_sigma) 232 | 233 | progress_bar.update(1) 234 | tvu.save_image( 235 | (x0+1)/2, './sample_real/sample_{}.png'.format(step)) 236 | tvu.save_image( 237 | (x+1)/2, './sample_fake/sample_{}.png'.format(step)) 238 | 239 | loss_clip = (2 - clip_loss_func(x0, None, 240 | x, None, self.src_txt, self.trg_txt)) / 2 241 | loss_clip = -torch.log(loss_clip) 242 | loss_id = torch.mean(id_loss_func(x0, x)) 243 | loss_l1 = nn.L1Loss()(x0, x) 244 | loss_lpips = loss_fn_alex(x0, x) 245 | loss = self.args.MR_clip_loss_w * loss_clip + self.args.MR_id_loss_w * loss_id + \ 246 | self.args.MR_l1_loss_w * loss_l1 + self.args.MR_lpips_loss_w * loss_lpips 247 | loss.backward() 248 | 249 | optim_ft.step() 250 | print( 251 | f"CLIP {step}-{it_out}: loss_id: {loss_id:.3f}, loss_clip: {loss_clip:.3f}") 252 | 253 | if self.args.save_train_image: 254 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 255 | f'train_{step}_2_clip_{self.trg_txt.replace(" ", "_")}_{it_out}_ngen{self.args.n_train_step}.png')) 256 | time_in_end = time.time() 257 | print( 258 | f"Training for 1 image takes {time_in_end - time_in_start:.4f}s") 259 | if step == self.args.n_train_img - 1: 260 | break 261 | 262 | if isinstance(model, nn.DataParallel): 263 | torch.save(model.module.state_dict(), save_name) 264 | else: 265 | torch.save(model.state_dict(), save_name) 266 | print(f'Model {save_name} is saved.') 267 | scheduler_ft.step() 268 | 269 | # ----------- Eval -----------# 270 | if self.args.do_test: 271 | if not self.args.do_train: 272 | print(save_name) 273 | model.module.load_state_dict(torch.load(save_name)) 274 | 275 | model.eval() 276 | img_lat_pairs = img_lat_pairs_dic[mode] 277 | for step, (x0, x_id, x_lat, _) in enumerate(img_lat_pairs): 278 | with torch.no_grad(): 279 | x = x_lat 280 | with tqdm(total=len(seq_test), desc=f"Eval iteration") as progress_bar: 281 | for i, j in zip(reversed(seq_test), reversed(seq_test_next)): 282 | t = (torch.ones(n) * i).to(self.device) 283 | t_next = (torch.ones(n) * j).to(self.device) 284 | 285 | x = denoising_step(x, t=t, t_next=t_next, models=model, 286 | logvars=self.logvar, 287 | sampling_type=self.args.sample_type, 288 | b=self.betas, 289 | eta=self.args.eta, 290 | learn_sigma=learn_sigma) 291 | 292 | progress_bar.update(1) 293 | tvu.save_image( 294 | (x0+1)/2, './sample_real_test/sample_{}.png'.format(step)) 295 | tvu.save_image( 296 | (x+1)/2, './sample_fake_test/sample_{}.png'.format(step)) 297 | print(f"Eval {step}-{it_out}") 298 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 299 | f'{mode}_{step}_2_clip_{self.trg_txt.replace(" ", "_")}_{it_out}_ngen{self.args.n_test_step}.png')) 300 | if step == self.args.n_test_img - 1: 301 | break 302 | 303 | def edit_one_image(self): 304 | # ----------- Data -----------# 305 | n = self.args.bs_test 306 | img = Image.open(self.args.img_path).convert("RGB") 307 | img = img.resize((self.config.data.image_size, 308 | self.config.data.image_size), Image.ANTIALIAS) 309 | img = np.array(img)/255 310 | img = torch.from_numpy(img).type(torch.FloatTensor).permute( 311 | 2, 0, 1).unsqueeze(dim=0).repeat(n, 1, 1, 1) 312 | img = img.to(self.config.device) 313 | tvu.save_image(img, os.path.join( 314 | self.args.image_folder, 
f'0_orig.png')) 315 | x0 = (img - 0.5) * 2. 316 | 317 | models = [] 318 | model_paths = [None, self.args.model_path] 319 | 320 | for model_path in model_paths: 321 | model_i = i_DDPM() 322 | if model_path: 323 | ckpt = torch.load(model_path) 324 | else: 325 | ckpt = torch.load("pretrained/makeup.pt") 326 | learn_sigma = True 327 | model_i.load_state_dict(ckpt) 328 | model_i.to(self.device) 329 | model_i = torch.nn.DataParallel(model_i) 330 | model_i.eval() 331 | print(f"{model_path} is loaded.") 332 | models.append(model_i) 333 | 334 | with torch.no_grad(): 335 | # ---------------- Invert Image to Latent in case of Deterministic Inversion process -------------------# 336 | if self.args.deterministic_inv: 337 | x_lat_path = os.path.join( 338 | self.args.image_folder, f'x_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}.pth') 339 | if not os.path.exists(x_lat_path): 340 | seq_inv = np.linspace( 341 | 0, 1, self.args.n_inv_step) * self.args.t_0 342 | seq_inv = [int(s) for s in list(seq_inv)] 343 | seq_inv_next = [-1] + list(seq_inv[:-1]) 344 | 345 | x = x0.clone() 346 | with tqdm(total=len(seq_inv), desc=f"Inversion process ") as progress_bar: 347 | for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))): 348 | t = (torch.ones(n) * i).to(self.device) 349 | t_prev = (torch.ones(n) * j).to(self.device) 350 | 351 | x = denoising_step(x, t=t, t_next=t_prev, models=models[0], 352 | logvars=self.logvar, 353 | sampling_type='ddim', 354 | b=self.betas, 355 | eta=0, 356 | learn_sigma=learn_sigma, 357 | ratio=0, 358 | ) 359 | 360 | progress_bar.update(1) 361 | x_lat = x.clone() 362 | torch.save(x_lat, x_lat_path) 363 | else: 364 | print('Latent exists.') 365 | x_lat = torch.load(x_lat_path) 366 | 367 | # ----------- Generative Process -----------# 368 | print(f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}, " 369 | f" Steps: {self.args.n_test_step}/{self.args.t_0}") 370 | if self.args.n_test_step != 0: 371 | seq_test = np.linspace( 372 | 0, 1, self.args.n_test_step) * self.args.t_0 373 | seq_test = [int(s) for s in list(seq_test)] 374 | print('Uniform skip type') 375 | else: 376 | seq_test = list(range(self.args.t_0)) 377 | print('No skip') 378 | seq_test_next = [-1] + list(seq_test[:-1]) 379 | 380 | for it in range(self.args.n_iter): 381 | if self.args.deterministic_inv: 382 | x = x_lat.clone() 383 | else: 384 | e = torch.randn_like(x0) 385 | a = (1 - self.betas).cumprod(dim=0) 386 | x = x0 * a[self.args.t_0 - 1].sqrt() + e * \ 387 | (1.0 - a[self.args.t_0 - 1]).sqrt() 388 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 389 | f'1_lat_ninv{self.args.n_inv_step}.png')) 390 | 391 | with tqdm(total=len(seq_test), desc="Generative process {}".format(it)) as progress_bar: 392 | for i, j in zip(reversed(seq_test), reversed(seq_test_next)): 393 | t = (torch.ones(n) * i).to(self.device) 394 | t_next = (torch.ones(n) * j).to(self.device) 395 | 396 | x = denoising_step(x, t=t, t_next=t_next, models=models, 397 | logvars=self.logvar, 398 | sampling_type=self.args.sample_type, 399 | b=self.betas, 400 | eta=self.args.eta, 401 | learn_sigma=learn_sigma, 402 | ratio=self.args.model_ratio) 403 | 404 | # added intermediate step vis 405 | if (i - 99) % 100 == 0: 406 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 407 | f'2_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_{i}_it{it}.png')) 408 | progress_bar.update(1) 409 | 410 | x0 = x.clone() 411 | if self.args.model_path: 412 | tvu.save_image((x + 1) * 0.5, 
os.path.join(self.args.image_folder, 413 | f"3_gen_t{self.args.t_0}_it{it}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_mrat{self.args.model_ratio}_{self.args.model_path.split('/')[-1].replace('.pth','')}.png")) 414 | else: 415 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 416 | f'3_gen_t{self.args.t_0}_it{it}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_mrat{self.args.model_ratio}.png')) 417 | -------------------------------------------------------------------------------- /models/improved_ddpm/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | 5 | 6 | from abc import abstractmethod 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch as th 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 16 | from .nn import ( 17 | checkpoint, 18 | conv_nd, 19 | linear, 20 | avg_pool_nd, 21 | zero_module, 22 | normalization, 23 | timestep_embedding, 24 | ) 25 | 26 | 27 | class AttentionPool2d(nn.Module): 28 | """ 29 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 30 | """ 31 | 32 | def __init__( 33 | self, 34 | spacial_dim: int, 35 | embed_dim: int, 36 | num_heads_channels: int, 37 | output_dim: int = None, 38 | ): 39 | super().__init__() 40 | self.positional_embedding = nn.Parameter( 41 | th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 42 | ) 43 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 44 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 45 | self.num_heads = embed_dim // num_heads_channels 46 | self.attention = QKVAttention(self.num_heads) 47 | 48 | def forward(self, x): 49 | b, c, *_spatial = x.shape 50 | x = x.reshape(b, c, -1) # NC(HW) 51 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 52 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 53 | x = self.qkv_proj(x) 54 | x = self.attention(x) 55 | x = self.c_proj(x) 56 | return x[:, :, 0] 57 | 58 | 59 | class TimestepBlock(nn.Module): 60 | """ 61 | Any module where forward() takes timestep embeddings as a second argument. 62 | """ 63 | 64 | @abstractmethod 65 | def forward(self, x, emb): 66 | """ 67 | Apply the module to `x` given `emb` timestep embeddings. 68 | """ 69 | 70 | 71 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 72 | """ 73 | A sequential module that passes timestep embeddings to the children that 74 | support it as an extra input. 75 | """ 76 | 77 | def forward(self, x, emb): 78 | for layer in self: 79 | if isinstance(layer, TimestepBlock): 80 | x = layer(x, emb) 81 | else: 82 | x = layer(x) 83 | return x 84 | 85 | 86 | class Upsample(nn.Module): 87 | """ 88 | An upsampling layer with an optional convolution. 89 | 90 | :param channels: channels in the inputs and outputs. 91 | :param use_conv: a bool determining if a convolution is applied. 92 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 93 | upsampling occurs in the inner-two dimensions. 
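    Example (illustrative): Upsample(64, use_conv=True) maps an [N, 64, H, W] tensor to
    [N, 64, 2H, 2W] by nearest-neighbor interpolation followed by a 3x3 convolution;
    with use_conv=False only the interpolation is applied.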
94 | """ 95 | 96 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 97 | super().__init__() 98 | self.channels = channels 99 | self.out_channels = out_channels or channels 100 | self.use_conv = use_conv 101 | self.dims = dims 102 | if use_conv: 103 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 104 | 105 | def forward(self, x): 106 | assert x.shape[1] == self.channels 107 | if self.dims == 3: 108 | x = F.interpolate( 109 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 110 | ) 111 | else: 112 | x = F.interpolate(x, scale_factor=2, mode="nearest") 113 | if self.use_conv: 114 | x = self.conv(x) 115 | return x 116 | 117 | 118 | class Downsample(nn.Module): 119 | """ 120 | A downsampling layer with an optional convolution. 121 | 122 | :param channels: channels in the inputs and outputs. 123 | :param use_conv: a bool determining if a convolution is applied. 124 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 125 | downsampling occurs in the inner-two dimensions. 126 | """ 127 | 128 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 129 | super().__init__() 130 | self.channels = channels 131 | self.out_channels = out_channels or channels 132 | self.use_conv = use_conv 133 | self.dims = dims 134 | stride = 2 if dims != 3 else (1, 2, 2) 135 | if use_conv: 136 | self.op = conv_nd( 137 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 138 | ) 139 | else: 140 | assert self.channels == self.out_channels 141 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 142 | 143 | def forward(self, x): 144 | assert x.shape[1] == self.channels 145 | return self.op(x) 146 | 147 | 148 | class ResBlock(TimestepBlock): 149 | """ 150 | A residual block that can optionally change the number of channels. 151 | 152 | :param channels: the number of input channels. 153 | :param emb_channels: the number of timestep embedding channels. 154 | :param dropout: the rate of dropout. 155 | :param out_channels: if specified, the number of out channels. 156 | :param use_conv: if True and out_channels is specified, use a spatial 157 | convolution instead of a smaller 1x1 convolution to change the 158 | channels in the skip connection. 159 | :param dims: determines if the signal is 1D, 2D, or 3D. 160 | :param use_checkpoint: if True, use gradient checkpointing on this module. 161 | :param up: if True, use this block for upsampling. 162 | :param down: if True, use this block for downsampling. 
163 | """ 164 | 165 | def __init__( 166 | self, 167 | channels, 168 | emb_channels, 169 | dropout, 170 | out_channels=None, 171 | use_conv=False, 172 | use_scale_shift_norm=False, 173 | dims=2, 174 | use_checkpoint=False, 175 | up=False, 176 | down=False, 177 | ): 178 | super().__init__() 179 | self.channels = channels 180 | self.emb_channels = emb_channels 181 | self.dropout = dropout 182 | self.out_channels = out_channels or channels 183 | self.use_conv = use_conv 184 | self.use_checkpoint = use_checkpoint 185 | self.use_scale_shift_norm = use_scale_shift_norm 186 | 187 | self.in_layers = nn.Sequential( 188 | normalization(channels), 189 | nn.SiLU(), 190 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 191 | ) 192 | 193 | self.updown = up or down 194 | 195 | if up: 196 | self.h_upd = Upsample(channels, False, dims) 197 | self.x_upd = Upsample(channels, False, dims) 198 | elif down: 199 | self.h_upd = Downsample(channels, False, dims) 200 | self.x_upd = Downsample(channels, False, dims) 201 | else: 202 | self.h_upd = self.x_upd = nn.Identity() 203 | 204 | self.emb_layers = nn.Sequential( 205 | nn.SiLU(), 206 | linear( 207 | emb_channels, 208 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 209 | ), 210 | ) 211 | self.out_layers = nn.Sequential( 212 | normalization(self.out_channels), 213 | nn.SiLU(), 214 | nn.Dropout(p=dropout), 215 | zero_module( 216 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 217 | ), 218 | ) 219 | 220 | if self.out_channels == channels: 221 | self.skip_connection = nn.Identity() 222 | elif use_conv: 223 | self.skip_connection = conv_nd( 224 | dims, channels, self.out_channels, 3, padding=1 225 | ) 226 | else: 227 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 228 | 229 | def forward(self, x, emb): 230 | """ 231 | Apply the block to a Tensor, conditioned on a timestep embedding. 232 | 233 | :param x: an [N x C x ...] Tensor of features. 234 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 235 | :return: an [N x C x ...] Tensor of outputs. 236 | """ 237 | return checkpoint( 238 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 239 | ) 240 | 241 | def _forward(self, x, emb): 242 | if self.updown: 243 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 244 | h = in_rest(x) 245 | h = self.h_upd(h) 246 | x = self.x_upd(x) 247 | h = in_conv(h) 248 | else: 249 | h = self.in_layers(x) 250 | emb_out = self.emb_layers(emb).type(h.dtype) 251 | while len(emb_out.shape) < len(h.shape): 252 | emb_out = emb_out[..., None] 253 | if self.use_scale_shift_norm: 254 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 255 | scale, shift = th.chunk(emb_out, 2, dim=1) 256 | h = out_norm(h) * (1 + scale) + shift 257 | h = out_rest(h) 258 | else: 259 | h = h + emb_out 260 | h = self.out_layers(h) 261 | return self.skip_connection(x) + h 262 | 263 | 264 | class AttentionBlock(nn.Module): 265 | """ 266 | An attention block that allows spatial positions to attend to each other. 267 | 268 | Originally ported from here, but adapted to the N-d case. 269 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
270 | """ 271 | 272 | def __init__( 273 | self, 274 | channels, 275 | num_heads=1, 276 | num_head_channels=-1, 277 | use_checkpoint=False, 278 | use_new_attention_order=False, 279 | ): 280 | super().__init__() 281 | self.channels = channels 282 | if num_head_channels == -1: 283 | self.num_heads = num_heads 284 | else: 285 | assert ( 286 | channels % num_head_channels == 0 287 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 288 | self.num_heads = channels // num_head_channels 289 | self.use_checkpoint = use_checkpoint 290 | self.norm = normalization(channels) 291 | self.qkv = conv_nd(1, channels, channels * 3, 1) 292 | if use_new_attention_order: 293 | # split qkv before split heads 294 | self.attention = QKVAttention(self.num_heads) 295 | else: 296 | # split heads before split qkv 297 | self.attention = QKVAttentionLegacy(self.num_heads) 298 | 299 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 300 | 301 | def forward(self, x): 302 | return checkpoint(self._forward, (x,), self.parameters(), True) 303 | 304 | def _forward(self, x): 305 | b, c, *spatial = x.shape 306 | x = x.reshape(b, c, -1) 307 | qkv = self.qkv(self.norm(x)) 308 | h = self.attention(qkv) 309 | h = self.proj_out(h) 310 | return (x + h).reshape(b, c, *spatial) 311 | 312 | 313 | def count_flops_attn(model, _x, y): 314 | """ 315 | A counter for the `thop` package to count the operations in an 316 | attention operation. 317 | Meant to be used like: 318 | macs, params = thop.profile( 319 | model, 320 | inputs=(inputs, timestamps), 321 | custom_ops={QKVAttention: QKVAttention.count_flops}, 322 | ) 323 | """ 324 | b, c, *spatial = y[0].shape 325 | num_spatial = int(np.prod(spatial)) 326 | # We perform two matmuls with the same number of ops. 327 | # The first computes the weight matrix, the second computes 328 | # the combination of the value vectors. 329 | matmul_ops = 2 * b * (num_spatial ** 2) * c 330 | model.total_ops += th.DoubleTensor([matmul_ops]) 331 | 332 | 333 | class QKVAttentionLegacy(nn.Module): 334 | """ 335 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 336 | """ 337 | 338 | def __init__(self, n_heads): 339 | super().__init__() 340 | self.n_heads = n_heads 341 | 342 | def forward(self, qkv): 343 | """ 344 | Apply QKV attention. 345 | 346 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 347 | :return: an [N x (H * C) x T] tensor after attention. 348 | """ 349 | bs, width, length = qkv.shape 350 | assert width % (3 * self.n_heads) == 0 351 | ch = width // (3 * self.n_heads) 352 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 353 | scale = 1 / math.sqrt(math.sqrt(ch)) 354 | weight = th.einsum( 355 | "bct,bcs->bts", q * scale, k * scale 356 | ) # More stable with f16 than dividing afterwards 357 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 358 | a = th.einsum("bts,bcs->bct", weight, v) 359 | return a.reshape(bs, -1, length) 360 | 361 | @staticmethod 362 | def count_flops(model, _x, y): 363 | return count_flops_attn(model, _x, y) 364 | 365 | 366 | class QKVAttention(nn.Module): 367 | """ 368 | A module which performs QKV attention and splits in a different order. 369 | """ 370 | 371 | def __init__(self, n_heads): 372 | super().__init__() 373 | self.n_heads = n_heads 374 | 375 | def forward(self, qkv): 376 | """ 377 | Apply QKV attention. 378 | 379 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
380 | :return: an [N x (H * C) x T] tensor after attention. 381 | """ 382 | bs, width, length = qkv.shape 383 | assert width % (3 * self.n_heads) == 0 384 | ch = width // (3 * self.n_heads) 385 | q, k, v = qkv.chunk(3, dim=1) 386 | scale = 1 / math.sqrt(math.sqrt(ch)) 387 | weight = th.einsum( 388 | "bct,bcs->bts", 389 | (q * scale).view(bs * self.n_heads, ch, length), 390 | (k * scale).view(bs * self.n_heads, ch, length), 391 | ) # More stable with f16 than dividing afterwards 392 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 393 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 394 | return a.reshape(bs, -1, length) 395 | 396 | @staticmethod 397 | def count_flops(model, _x, y): 398 | return count_flops_attn(model, _x, y) 399 | 400 | 401 | class UNetModel(nn.Module): 402 | """ 403 | The full UNet model with attention and timestep embedding. 404 | 405 | :param in_channels: channels in the input Tensor. 406 | :param model_channels: base channel count for the model. 407 | :param out_channels: channels in the output Tensor. 408 | :param num_res_blocks: number of residual blocks per downsample. 409 | :param attention_resolutions: a collection of downsample rates at which 410 | attention will take place. May be a set, list, or tuple. 411 | For example, if this contains 4, then at 4x downsampling, attention 412 | will be used. 413 | :param dropout: the dropout probability. 414 | :param channel_mult: channel multiplier for each level of the UNet. 415 | :param conv_resample: if True, use learned convolutions for upsampling and 416 | downsampling. 417 | :param dims: determines if the signal is 1D, 2D, or 3D. 418 | :param num_classes: if specified (as an int), then this model will be 419 | class-conditional with `num_classes` classes. 420 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 421 | :param num_heads: the number of attention heads in each attention layer. 422 | :param num_heads_channels: if specified, ignore num_heads and instead use 423 | a fixed channel width per attention head. 424 | :param num_heads_upsample: works with num_heads to set a different number 425 | of heads for upsampling. Deprecated. 426 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 427 | :param resblock_updown: use residual blocks for up/downsampling. 428 | :param use_new_attention_order: use a different attention pattern for potentially 429 | increased efficiency. 
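    Example (illustrative values only; the configuration this repo actually uses is
    assembled in models/improved_ddpm/script_util.py):
    UNetModel(image_size=256, in_channels=3, model_channels=128, out_channels=6,
    num_res_blocks=2, attention_resolutions=(8, 16, 32), channel_mult=(1, 1, 2, 2, 4, 4),
    num_head_channels=64, use_scale_shift_norm=True, resblock_updown=True) builds a
    256x256 backbone whose 6 output channels carry the noise prediction plus the
    learned-variance parameters (a learn_sigma-style output).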
430 | """ 431 | 432 | def __init__( 433 | self, 434 | image_size, 435 | in_channels, 436 | model_channels, 437 | out_channels, 438 | num_res_blocks, 439 | attention_resolutions, 440 | dropout=0, 441 | channel_mult=(1, 2, 4, 8), 442 | conv_resample=True, 443 | dims=2, 444 | num_classes=None, 445 | use_checkpoint=False, 446 | use_fp16=False, 447 | num_heads=1, 448 | num_head_channels=-1, 449 | num_heads_upsample=-1, 450 | use_scale_shift_norm=False, 451 | resblock_updown=False, 452 | use_new_attention_order=False, 453 | ): 454 | super().__init__() 455 | 456 | if num_heads_upsample == -1: 457 | num_heads_upsample = num_heads 458 | 459 | self.image_size = image_size 460 | self.in_channels = in_channels 461 | self.model_channels = model_channels 462 | self.out_channels = out_channels 463 | self.num_res_blocks = num_res_blocks 464 | self.attention_resolutions = attention_resolutions 465 | self.dropout = dropout 466 | self.channel_mult = channel_mult 467 | self.conv_resample = conv_resample 468 | self.num_classes = num_classes 469 | self.use_checkpoint = use_checkpoint 470 | self.dtype = th.float16 if use_fp16 else th.float32 471 | self.num_heads = num_heads 472 | self.num_head_channels = num_head_channels 473 | self.num_heads_upsample = num_heads_upsample 474 | 475 | time_embed_dim = model_channels * 4 476 | self.time_embed = nn.Sequential( 477 | linear(model_channels, time_embed_dim), 478 | nn.SiLU(), 479 | linear(time_embed_dim, time_embed_dim), 480 | ) 481 | 482 | if self.num_classes is not None: 483 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 484 | 485 | ch = input_ch = int(channel_mult[0] * model_channels) 486 | self.input_blocks = nn.ModuleList( 487 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 488 | ) 489 | self._feature_size = ch 490 | input_block_chans = [ch] 491 | ds = 1 492 | for level, mult in enumerate(channel_mult): 493 | for _ in range(num_res_blocks): 494 | layers = [ 495 | ResBlock( 496 | ch, 497 | time_embed_dim, 498 | dropout, 499 | out_channels=int(mult * model_channels), 500 | dims=dims, 501 | use_checkpoint=use_checkpoint, 502 | use_scale_shift_norm=use_scale_shift_norm, 503 | ) 504 | ] 505 | ch = int(mult * model_channels) 506 | if ds in attention_resolutions: 507 | layers.append( 508 | AttentionBlock( 509 | ch, 510 | use_checkpoint=use_checkpoint, 511 | num_heads=num_heads, 512 | num_head_channels=num_head_channels, 513 | use_new_attention_order=use_new_attention_order, 514 | ) 515 | ) 516 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 517 | self._feature_size += ch 518 | input_block_chans.append(ch) 519 | if level != len(channel_mult) - 1: 520 | out_ch = ch 521 | self.input_blocks.append( 522 | TimestepEmbedSequential( 523 | ResBlock( 524 | ch, 525 | time_embed_dim, 526 | dropout, 527 | out_channels=out_ch, 528 | dims=dims, 529 | use_checkpoint=use_checkpoint, 530 | use_scale_shift_norm=use_scale_shift_norm, 531 | down=True, 532 | ) 533 | if resblock_updown 534 | else Downsample( 535 | ch, conv_resample, dims=dims, out_channels=out_ch 536 | ) 537 | ) 538 | ) 539 | ch = out_ch 540 | input_block_chans.append(ch) 541 | ds *= 2 542 | self._feature_size += ch 543 | 544 | self.middle_block = TimestepEmbedSequential( 545 | ResBlock( 546 | ch, 547 | time_embed_dim, 548 | dropout, 549 | dims=dims, 550 | use_checkpoint=use_checkpoint, 551 | use_scale_shift_norm=use_scale_shift_norm, 552 | ), 553 | AttentionBlock( 554 | ch, 555 | use_checkpoint=use_checkpoint, 556 | num_heads=num_heads, 557 | 
num_head_channels=num_head_channels, 558 | use_new_attention_order=use_new_attention_order, 559 | ), 560 | ResBlock( 561 | ch, 562 | time_embed_dim, 563 | dropout, 564 | dims=dims, 565 | use_checkpoint=use_checkpoint, 566 | use_scale_shift_norm=use_scale_shift_norm, 567 | ), 568 | ) 569 | self._feature_size += ch 570 | 571 | self.output_blocks = nn.ModuleList([]) 572 | for level, mult in list(enumerate(channel_mult))[::-1]: 573 | for i in range(num_res_blocks + 1): 574 | ich = input_block_chans.pop() 575 | layers = [ 576 | ResBlock( 577 | ch + ich, 578 | time_embed_dim, 579 | dropout, 580 | out_channels=int(model_channels * mult), 581 | dims=dims, 582 | use_checkpoint=use_checkpoint, 583 | use_scale_shift_norm=use_scale_shift_norm, 584 | ) 585 | ] 586 | ch = int(model_channels * mult) 587 | if ds in attention_resolutions: 588 | layers.append( 589 | AttentionBlock( 590 | ch, 591 | use_checkpoint=use_checkpoint, 592 | num_heads=num_heads_upsample, 593 | num_head_channels=num_head_channels, 594 | use_new_attention_order=use_new_attention_order, 595 | ) 596 | ) 597 | if level and i == num_res_blocks: 598 | out_ch = ch 599 | layers.append( 600 | ResBlock( 601 | ch, 602 | time_embed_dim, 603 | dropout, 604 | out_channels=out_ch, 605 | dims=dims, 606 | use_checkpoint=use_checkpoint, 607 | use_scale_shift_norm=use_scale_shift_norm, 608 | up=True, 609 | ) 610 | if resblock_updown 611 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 612 | ) 613 | ds //= 2 614 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 615 | self._feature_size += ch 616 | 617 | self.out = nn.Sequential( 618 | normalization(ch), 619 | nn.SiLU(), 620 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 621 | ) 622 | 623 | def convert_to_fp16(self): 624 | """ 625 | Convert the torso of the model to float16. 626 | """ 627 | self.input_blocks.apply(convert_module_to_f16) 628 | self.middle_block.apply(convert_module_to_f16) 629 | self.output_blocks.apply(convert_module_to_f16) 630 | 631 | def convert_to_fp32(self): 632 | """ 633 | Convert the torso of the model to float32. 634 | """ 635 | self.input_blocks.apply(convert_module_to_f32) 636 | self.middle_block.apply(convert_module_to_f32) 637 | self.output_blocks.apply(convert_module_to_f32) 638 | 639 | def forward(self, x, timesteps, y=None, ref_img=None): 640 | """ 641 | Apply the model to an input batch. 642 | 643 | :param x: an [N x C x ...] Tensor of inputs. 644 | :param timesteps: a 1-D batch of timesteps. 645 | :param y: an [N] Tensor of labels, if class-conditional. 646 | :return: an [N x C x ...] Tensor of outputs. 
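        Example (illustrative): for x of shape [N, 3, 256, 256] and timesteps t of shape
        [N], model(x, t) returns an [N, out_channels, 256, 256] tensor; when the model is
        built with twice as many output channels (learn_sigma), the first half is the
        noise prediction and the second half parameterizes the variance.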
647 | """ 648 | # assert (y is not None) == ( 649 | # self.num_classes is not None 650 | # ), "must specify y if and only if the model is class-conditional" 651 | 652 | hs = [] 653 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 654 | 655 | # if self.num_classes is not None: 656 | # assert y.shape == (x.shape[0],) 657 | # emb = emb + self.label_emb(y) 658 | 659 | h = x.type(self.dtype) 660 | for module in self.input_blocks: 661 | h = module(h, emb) 662 | hs.append(h) 663 | h = self.middle_block(h, emb) 664 | for module in self.output_blocks: 665 | h = th.cat([h, hs.pop()], dim=1) 666 | h = module(h, emb) 667 | h = h.type(x.dtype) 668 | return self.out(h) 669 | -------------------------------------------------------------------------------- /makeup_transfer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | import os 4 | import numpy as np 5 | import cv2 6 | import copy 7 | from PIL import Image 8 | import torch 9 | from torch import nn 10 | import torchvision.utils as tvu 11 | from models.ddpm.diffusion import DDPM 12 | from utils.diffusion_utils import get_beta_schedule, denoising_step 13 | from utils.image_processing import * 14 | from utils.model_utils import * 15 | from losses.id_loss import cal_adv_loss 16 | from losses.clip_loss import CLIPLoss 17 | from datasets.data_utils import get_dataset, get_dataloader 18 | from configs.paths_config import DATASET_PATHS, MODEL_PATHS 19 | from utils.align_utils import run_alignment 20 | import torch.nn.functional as F 21 | import time 22 | import lpips 23 | 24 | 25 | class DiffAM_MT(object): 26 | def __init__(self, args, config, device=None): 27 | self.args = args 28 | self.config = config 29 | 30 | if device is None: 31 | device = torch.device( 32 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 33 | self.device = device 34 | 35 | self.model_var_type = config.model.var_type 36 | 37 | self.target_id = args.target_img 38 | self.target_model = args.target_model 39 | self.ref_id = args.ref_img 40 | 41 | self.makeup_image, self.non_makeup_image, self.makeup_mask = get_ref_image( 42 | self.ref_id) 43 | self.target_image, self.test_image, self.target_name = get_target_image( 44 | self.target_id) 45 | self.model_list = get_model_list(self.target_model) 46 | 47 | betas = get_beta_schedule( 48 | beta_start=config.diffusion.beta_start, 49 | beta_end=config.diffusion.beta_end, 50 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps 51 | ) 52 | self.betas = torch.from_numpy(betas).float().to(self.device) 53 | self.num_timesteps = betas.shape[0] 54 | 55 | alphas = 1.0 - betas 56 | alphas_cumprod = np.cumprod(alphas, axis=0) 57 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 58 | posterior_variance = betas * \ 59 | (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 60 | if self.model_var_type == "fixedlarge": 61 | self.logvar = np.log(np.append(posterior_variance[1], betas[1:])) 62 | 63 | elif self.model_var_type == 'fixedsmall': 64 | self.logvar = np.log(np.maximum(posterior_variance, 1e-20)) 65 | 66 | def clip_finetune(self): 67 | print(self.args.exp) 68 | print(f'Transfer makeup style {self.ref_id}') 69 | 70 | # ----------- Model -----------# 71 | url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt" 72 | 73 | model = DDPM(self.config) 74 | if self.args.model_path: 75 | init_ckpt = torch.load(self.args.model_path) 76 | else: 77 | init_ckpt = 
torch.hub.load_state_dict_from_url( 78 | url, map_location=self.device) 79 | learn_sigma = False 80 | print("Original diffusion Model loaded.") 81 | 82 | model.load_state_dict(init_ckpt) 83 | model.to(self.device) 84 | model = torch.nn.DataParallel(model) 85 | 86 | # ----------- Optimizer and Scheduler -----------# 87 | print(f"Setting optimizer with lr={self.args.lr_clip_finetune}") 88 | optim_ft = torch.optim.Adam( 89 | model.parameters(), weight_decay=0, lr=self.args.lr_clip_finetune) 90 | init_opt_ckpt = optim_ft.state_dict() 91 | scheduler_ft = torch.optim.lr_scheduler.StepLR( 92 | optim_ft, step_size=1, gamma=self.args.sch_gamma) 93 | init_sch_ckpt = scheduler_ft.state_dict() 94 | 95 | # ----------- Loss -----------# 96 | print("Loading losses") 97 | clip_loss_func = CLIPLoss( 98 | self.device, 99 | lambda_makeup_direction=1, 100 | lambda_direction=0, 101 | clip_model=self.args.clip_model_name) 102 | loss_fn_alex = lpips.LPIPS(net='alex').to(self.device) 103 | # ----------- Precompute Latents -----------# 104 | print("Prepare identity latent") 105 | seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0 106 | seq_inv = [int(s) for s in list(seq_inv)] 107 | seq_inv_next = [-1] + list(seq_inv[:-1]) 108 | 109 | n = self.args.bs_train 110 | img_lat_pairs_dic = {} 111 | for mode in ['train', 'test']: 112 | img_lat_pairs = [] 113 | pairs_path = os.path.join('precomputed/', 114 | f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth') 115 | print(pairs_path) 116 | if os.path.exists(pairs_path): 117 | print(f'{mode} pairs exists') 118 | img_lat_pairs_dic[mode] = torch.load( 119 | pairs_path, map_location=torch.device('cpu')) 120 | for step, (x0, x_id, x_lat, _) in enumerate(img_lat_pairs_dic[mode]): 121 | tvu.save_image( 122 | (x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png')) 123 | tvu.save_image((x_id + 1) * 0.5, os.path.join(self.args.image_folder, 124 | f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png')) 125 | if step == self.args.n_precomp_img - 1: 126 | break 127 | continue 128 | else: 129 | train_dataset, test_dataset = get_dataset( 130 | self.config.data.dataset, DATASET_PATHS, self.config) 131 | loader_dic = get_dataloader(train_dataset, test_dataset, bs_train=self.args.bs_train, 132 | num_workers=self.config.data.num_workers) 133 | loader = loader_dic[mode] 134 | 135 | for step, (img, mask_list) in enumerate(loader): 136 | if mask_list == 0: 137 | continue 138 | x0 = img.to(self.config.device) 139 | tvu.save_image( 140 | (x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png')) 141 | 142 | x = x0.clone() 143 | model.eval() 144 | with torch.no_grad(): 145 | with tqdm(total=len(seq_inv), desc=f"Inversion process {mode} {step}") as progress_bar: 146 | for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))): 147 | t = (torch.ones(n) * i).to(self.device) 148 | t_prev = (torch.ones(n) * j).to(self.device) 149 | 150 | x = denoising_step(x, t=t, t_next=t_prev, models=model, 151 | logvars=self.logvar, 152 | sampling_type='ddim', 153 | b=self.betas, 154 | eta=0, 155 | learn_sigma=learn_sigma) 156 | 157 | progress_bar.update(1) 158 | x_lat = x.clone() 159 | tvu.save_image((x_lat + 1) * 0.5, os.path.join(self.args.image_folder, 160 | f'{mode}_{step}_1_lat_ninv{self.args.n_inv_step}.png')) 161 | 162 | with tqdm(total=len(seq_inv), desc=f"Generative process {mode} {step}") as progress_bar: 163 | for it, (i, j) in enumerate(zip(reversed((seq_inv)), 
reversed((seq_inv_next)))): 164 | t = (torch.ones(n) * i).to(self.device) 165 | t_next = (torch.ones(n) * j).to(self.device) 166 | 167 | x = denoising_step(x, t=t, t_next=t_next, models=model, 168 | logvars=self.logvar, 169 | sampling_type=self.args.sample_type, 170 | b=self.betas, 171 | learn_sigma=learn_sigma) 172 | progress_bar.update(1) 173 | 174 | img_lat_pairs.append( 175 | [x0, x.detach().clone(), x_lat.detach().clone(), mask_list]) 176 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 177 | f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png')) 178 | if step >= self.args.n_precomp_img - 1: 179 | break 180 | 181 | img_lat_pairs_dic[mode] = img_lat_pairs 182 | pairs_path = os.path.join('precomputed/', 183 | f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth') 184 | torch.save(img_lat_pairs, pairs_path) 185 | # ----------- Finetune Diffusion Models -----------# 186 | print("Start finetuning") 187 | print( 188 | f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}") 189 | if self.args.n_train_step != 0: 190 | seq_train = np.linspace( 191 | 0, 1, self.args.n_train_step) * self.args.t_0 192 | seq_train = [int(s) for s in list(seq_train)] 193 | print('Uniform skip type') 194 | else: 195 | seq_train = list(range(self.args.t_0)) 196 | print('No skip') 197 | seq_train_next = [-1] + list(seq_train[:-1]) 198 | 199 | seq_test = np.linspace(0, 1, self.args.n_test_step) * self.args.t_0 200 | seq_test = [int(s) for s in list(seq_test)] 201 | seq_test_next = [-1] + list(seq_test[:-1]) 202 | 203 | mask_B = self.makeup_mask 204 | mask_B_lip = (mask_B == 7).float() + (mask_B == 9).float() 205 | mask_B_skin = (mask_B == 1).float() + (mask_B == 206 | 6).float() + (mask_B == 13).float() 207 | mask_B_eye_left = (mask_B == 4).float() 208 | mask_B_eye_right = (mask_B == 5).float() 209 | mask_B_face = (mask_B == 1).float() + (mask_B == 6).float() 210 | mask_B_eye_left, mask_B_eye_right = rebound_box( 211 | mask_B_eye_left, mask_B_eye_right, mask_B_face) 212 | 213 | model.module.load_state_dict(init_ckpt) 214 | optim_ft.load_state_dict(init_opt_ckpt) 215 | scheduler_ft.load_state_dict(init_sch_ckpt) 216 | clip_loss_func.target_direction = None 217 | 218 | # ----------- Train -----------# 219 | for it_out in range(self.args.n_iter): 220 | exp_id = os.path.split(self.args.exp)[-1] 221 | save_name = f'checkpoint/{exp_id}_{self.ref_id.replace(" ", "_")}-{it_out}.pth' 222 | if self.args.do_train: 223 | if os.path.exists(save_name): 224 | print(f'{save_name} already exists.') 225 | model.module.load_state_dict(torch.load(save_name)) 226 | continue 227 | else: 228 | for step, (x0, x_id, x_lat, mask_list) in enumerate(img_lat_pairs_dic['train']): 229 | 230 | mask_A_lip = torch.mean(mask_list[2], dim=1) 231 | mask_A_skin = torch.mean(mask_list[3], dim=1) 232 | mask_A_face = torch.mean(mask_list[4], dim=1) 233 | mask_A_eye_left = torch.mean(mask_list[0], dim=1) 234 | mask_A_eye_right = torch.mean(mask_list[1], dim=1) 235 | mask_A_eye_left, mask_A_eye_right = rebound_box( 236 | mask_A_eye_left, mask_A_eye_right, mask_A_face) 237 | mask_A_lip, mask_B_lip, index_A_lip, index_B_lip = mask_preprocess( 238 | mask_A_lip, mask_B_lip) 239 | mask_A_skin, mask_B_skin, index_A_skin, index_B_skin = mask_preprocess( 240 | mask_A_skin, mask_B_skin) 241 | mask_A_eye_left, mask_B_eye_left, index_A_eye_left, index_B_eye_left = mask_preprocess( 242 | mask_A_eye_left, mask_B_eye_left) 243 | mask_A_eye_right, mask_B_eye_right, 
index_A_eye_right, index_B_eye_right = mask_preprocess( 244 | mask_A_eye_right, mask_B_eye_right) 245 | 246 | x0 = x0.to(self.device) 247 | x_id = x_id.to(self.device) 248 | x_lat = x_lat.to(self.device) 249 | model.train() 250 | 251 | optim_ft.zero_grad() 252 | x = x_lat.clone() 253 | with tqdm(total=len(seq_train), desc=f"CLIP iteration") as progress_bar: 254 | for t_it, (i, j) in enumerate(zip(reversed(seq_train), reversed(seq_train_next))): 255 | t = (torch.ones(n) * i).to(self.device) 256 | t_next = (torch.ones(n) * j).to(self.device) 257 | 258 | x = denoising_step(x, t=t, t_next=t_next, models=model, 259 | logvars=self.logvar, 260 | sampling_type=self.args.sample_type, 261 | b=self.betas, 262 | eta=self.args.eta, 263 | learn_sigma=learn_sigma) 264 | 265 | progress_bar.update(1) 266 | 267 | tvu.save_image( 268 | (x0+1)/2, './sample_real/sample_{}.png'.format(step)) 269 | tvu.save_image( 270 | (x+1)/2, './sample_fake/sample_{}.png'.format(step)) 271 | 272 | loss_adv = 0 273 | targeted_loss_list = [] 274 | for model_name in list(self.model_list.keys())[:-1]: 275 | target_loss_A = cal_adv_loss( 276 | x, self.target_image, model_name, self.model_list) 277 | targeted_loss_list.append(target_loss_A) 278 | loss_adv = torch.mean(torch.stack(targeted_loss_list)) 279 | 280 | loss_dis = 0 281 | g_A_lip_loss_his, lip_img = criterionHis(x, self.makeup_image.unsqueeze( 282 | 0), index_A_lip, nn.L1Loss(), mask_A_lip, mask_B_lip) 283 | g_A_skin_loss_his, skin_img = criterionHis(x, self.makeup_image.unsqueeze( 284 | 0), index_A_skin, nn.L1Loss(), mask_A_skin, mask_B_skin) 285 | g_A_eye_left_loss_his, eye_left_img = criterionHis(x, self.makeup_image.unsqueeze( 286 | 0), index_A_eye_left, nn.L1Loss(), mask_A_eye_left, mask_B_eye_left) 287 | g_A_eye_right_loss_his, eye_right_img = criterionHis(x, self.makeup_image.unsqueeze( 288 | 0), index_A_eye_right, nn.L1Loss(), mask_A_eye_right, mask_B_eye_right) 289 | loss_dis = g_A_eye_left_loss_his + g_A_eye_right_loss_his + \ 290 | 0.15 * g_A_skin_loss_his + g_A_lip_loss_his 291 | 292 | loss_dir = (2 - clip_loss_func(x0, self.non_makeup_image.unsqueeze( 293 | 0), x, self.makeup_image.unsqueeze(0))) / 2 294 | loss_dir = -torch.log(loss_dir) 295 | loss_l1 = nn.L1Loss()(x0, x) 296 | loss_lpips = loss_fn_alex(x0, x) 297 | if it_out < self.args.MT_iter_without_adv: 298 | loss = self.args.MT_1_dis_loss_w * loss_dis + self.args.MT_1_dir_loss_w * loss_dir + \ 299 | self.args.MT_lpips_loss_w * loss_lpips + self.args.MT_1_l1_loss_w * loss_l1 300 | else: 301 | loss = self.args.MT_2_dis_loss_w * loss_dis + self.args.MT_2_dir_loss_w * loss_dir + self.args.MT_adv_loss_w * \ 302 | loss_adv + self.args.MT_lpips_loss_w * \ 303 | loss_lpips + self.args.MT_1_l1_loss_w * loss_l1 304 | 305 | loss.backward() 306 | optim_ft.step() 307 | print( 308 | f"CLIP {step}-{it_out}: loss_adv: {loss_adv:.3f}, loss_dir: {loss_dir:.3f}, loss_dis: {loss_dis:.3f}") 309 | 310 | if self.args.save_train_image: 311 | tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder, 312 | f'train_{step}_2_clip_{self.ref_id.replace(" ", "_")}_{it_out}_ngen{self.args.n_train_step}.png')) 313 | if step == self.args.n_train_img - 1: 314 | break 315 | 316 | if isinstance(model, nn.DataParallel): 317 | torch.save(model.module.state_dict(), save_name) 318 | else: 319 | torch.save(model.state_dict(), save_name) 320 | print(f'Model {save_name} is saved.') 321 | scheduler_ft.step() 322 | 323 | # ----------- Eval -----------# 324 | if self.args.do_test: 325 | if not self.args.do_train: 326 | print(save_name) 327 | 
                    model.module.load_state_dict(torch.load(save_name))

                model.eval()
                img_lat_pairs = img_lat_pairs_dic[mode]
                FAR01 = 0
                FAR001 = 0
                FAR0001 = 0
                total = 0
                for step, (x0, x_id, x_lat, _) in enumerate(img_lat_pairs):
                    x0 = x0.to(self.device)
                    x_id = x_id.to(self.device)
                    x_lat = x_lat.to(self.device)
                    with torch.no_grad():
                        x = x_lat
                        with tqdm(total=len(seq_test), desc="Eval iteration") as progress_bar:
                            for i, j in zip(reversed(seq_test), reversed(seq_test_next)):
                                t = (torch.ones(n) * i).to(self.device)
                                t_next = (torch.ones(n) * j).to(self.device)

                                x = denoising_step(x, t=t, t_next=t_next, models=model,
                                                   logvars=self.logvar,
                                                   sampling_type=self.args.sample_type,
                                                   b=self.betas,
                                                   eta=self.args.eta,
                                                   learn_sigma=learn_sigma)

                                progress_bar.update(1)

                        # Cosine-similarity thresholds of each FR model at FAR@0.1, FAR@0.01 and FAR@0.001.
                        th_dict = {'ir152': (0.094632, 0.166788, 0.227922), 'irse50': (0.144840, 0.241045, 0.312703),
                                   'facenet': (0.256587, 0.409131, 0.591191), 'mobile_face': (0.183635, 0.301611, 0.380878)}
                        tvu.save_image(
                            (x0 + 1) / 2, './sample_real_test/sample_{}.png'.format(step))
                        tvu.save_image(
                            (x + 1) / 2, './sample_fake_test/sample_{}.png'.format(step))
                        # Attack success is measured against the held-out (black-box) FR model only.
                        for test_model in list(self.model_list.keys())[-1:]:
                            size = self.model_list[test_model][0]
                            test_model_ = self.model_list[test_model][1]
                            target_embedding = test_model_(
                                F.interpolate(self.test_image, size=size, mode='bilinear'))

                            ae_embedding = test_model_(
                                F.interpolate(x, size=size, mode='bilinear'))
                            cos_simi = torch.cosine_similarity(
                                ae_embedding, target_embedding)

                            if cos_simi.item() > th_dict[test_model][0]:
                                FAR01 += 1
                            if cos_simi.item() > th_dict[test_model][1]:
                                FAR001 += 1
                            if cos_simi.item() > th_dict[test_model][2]:
                                FAR0001 += 1

                            total += 1
                        print(f"Eval {step}-{it_out}")
                        tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                       f'{mode}_{step}_2_clip_{self.ref_id.replace(" ", "_")}_{it_out}_ngen{self.args.n_test_step}.png'))
                        if step == self.args.n_test_img - 1:
                            break
                print("ASR in FAR@0.1: {:.4f}, ASR in FAR@0.01: {:.4f}, ASR in FAR@0.001: {:.4f}".
                      format(FAR01 / total, FAR001 / total, FAR0001 / total))

    def edit_one_image(self):
        # ----------- Data -----------#
        n = self.args.bs_test
        try:
            img = run_alignment(self.args.img_path,
                                output_size=self.config.data.image_size)
        except Exception:
            img = Image.open(self.args.img_path).convert("RGB")

        # Resize with Lanczos resampling to the configured image size.
        img = img.resize((self.config.data.image_size,
                          self.config.data.image_size), Image.LANCZOS)
        img = np.array(img) / 255
        img = torch.from_numpy(img).type(torch.FloatTensor).permute(
            2, 0, 1).unsqueeze(dim=0).repeat(n, 1, 1, 1)
        img = img.to(self.config.device)
        tvu.save_image(img, os.path.join(
            self.args.image_folder, '0_orig.png'))
        x0 = (img - 0.5) * 2.

        models = []

        model_paths = [None, self.args.model_path]

        for model_path in model_paths:
            model_i = DDPM(self.config)
            if model_path:
                ckpt = torch.load(model_path)
            else:
                ckpt = torch.load('pretrained/celeba_hq.ckpt')
            learn_sigma = False
            model_i.load_state_dict(ckpt)
            model_i.to(self.device)
            model_i = torch.nn.DataParallel(model_i)
            model_i.eval()
            print(f"{model_path} is loaded.")
            models.append(model_i)

        with torch.no_grad():
            # ---------------- Invert Image to Latent in case of Deterministic Inversion process -------------------#
            if self.args.deterministic_inv:
                x_lat_path = os.path.join(
                    self.args.image_folder, f'x_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}.pth')
                if not os.path.exists(x_lat_path):
                    seq_inv = np.linspace(
                        0, 1, self.args.n_inv_step) * self.args.t_0
                    seq_inv = [int(s) for s in list(seq_inv)]
                    seq_inv_next = [-1] + list(seq_inv[:-1])

                    x = x0.clone()
                    with tqdm(total=len(seq_inv), desc="Inversion process ") as progress_bar:
                        for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))):
                            t = (torch.ones(n) * i).to(self.device)
                            t_prev = (torch.ones(n) * j).to(self.device)

                            x = denoising_step(x, t=t, t_next=t_prev, models=models[0],
                                               logvars=self.logvar,
                                               sampling_type='ddim',
                                               b=self.betas,
                                               eta=0,
                                               learn_sigma=learn_sigma,
                                               ratio=0,
                                               )

                            progress_bar.update(1)
                    x_lat = x.clone()
                    torch.save(x_lat, x_lat_path)
                else:
                    print('Latent exists.')
                    x_lat = torch.load(x_lat_path)

            # ----------- Generative Process -----------#
            print(f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}, "
                  f" Steps: {self.args.n_test_step}/{self.args.t_0}")
            if self.args.n_test_step != 0:
                seq_test = np.linspace(
                    0, 1, self.args.n_test_step) * self.args.t_0
                seq_test = [int(s) for s in list(seq_test)]
                print('Uniform skip type')
            else:
                seq_test = list(range(self.args.t_0))
                print('No skip')
            seq_test_next = [-1] + list(seq_test[:-1])

            for it in range(self.args.n_iter):
                if self.args.deterministic_inv:
                    x = x_lat.clone()
                else:
                    e = torch.randn_like(x0)
                    a = (1 - self.betas).cumprod(dim=0)
                    x = x0 * a[self.args.t_0 - 1].sqrt() + e * \
                        (1.0 - a[self.args.t_0 - 1]).sqrt()
                tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                               f'1_lat_ninv{self.args.n_inv_step}.png'))

                with tqdm(total=len(seq_test), desc="Generative process {}".format(it)) as progress_bar:
                    for i, j in zip(reversed(seq_test), reversed(seq_test_next)):
                        t = (torch.ones(n) * i).to(self.device)
                        t_next = (torch.ones(n) * j).to(self.device)

                        x = denoising_step(x, t=t, t_next=t_next, models=models,
                                           logvars=self.logvar,
                                           sampling_type=self.args.sample_type,
                                           b=self.betas,
                                           eta=self.args.eta,
                                           learn_sigma=learn_sigma,
                                           ratio=self.args.model_ratio)

                        # added intermediate step vis
                        if (i - 99) % 100 == 0:
                            tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                           f'2_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_{i}_it{it}.png'))
                        progress_bar.update(1)

                x0 = x.clone()
                if self.args.model_path:
                    tvu.save_image((x + 1) * 0.5,
                                   os.path.join(self.args.image_folder,
                                                f"3_gen_t{self.args.t_0}_it{it}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_mrat{self.args.model_ratio}_{self.args.model_path.split('/')[-1].replace('.pth','')}.png"))
                else:
                    tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                   f'3_gen_t{self.args.t_0}_it{it}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_mrat{self.args.model_ratio}.png'))
--------------------------------------------------------------------------------