├── pics ├── pipeline.png └── teaser.png ├── .gitignore ├── dataset └── __pycache__ │ ├── ABAWData.cpython-38.pyc │ ├── ABAWData.cpython-39.pyc │ ├── ABAWData_seq.cpython-38.pyc │ ├── ABAWData_seq.cpython-39.pyc │ ├── asymmDataset.cpython-38.pyc │ └── asymmDataset.cpython-39.pyc ├── exp_emb_code ├── utils │ ├── __pycache__ │ │ ├── misc.cpython-37.pyc │ │ └── metrics.cpython-37.pyc │ ├── metrics.py │ └── misc.py ├── configs │ └── mae_train_expemb.yaml ├── model │ ├── mae_pipeline.py │ └── models_vit.py ├── dataset.py └── train.py ├── requirements.txt ├── configs_rig2img.yaml ├── configs_emb2rig_multi.yaml ├── mae-main ├── util │ ├── lr_sched.py │ ├── crop.py │ ├── lars.py │ ├── datasets.py │ ├── lr_decay.py │ ├── pos_embed.py │ └── misc.py ├── CONTRIBUTING.md ├── PRETRAIN.md ├── models_vit.py ├── engine_pretrain.py ├── CODE_OF_CONDUCT.md ├── submitit_pretrain.py ├── submitit_finetune.py ├── submitit_linprobe.py ├── engine_finetune.py ├── README.md ├── FINETUNE.md ├── main_pretrain.py └── models_mae.py ├── models ├── load_emb_model.py ├── mae_pipeline.py ├── pipeline5.py ├── gan_loss.py ├── models_vit.py ├── CascadeNet.py ├── DCGAN.py ├── Emoca_ExprNet.py ├── facenet2.py └── discriminator.py ├── utils ├── tools.py └── common.py ├── choose_character.py ├── README.md └── train_rig2img.py /pics/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/pics/pipeline.png -------------------------------------------------------------------------------- /pics/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/pics/teaser.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | core 3 | checkpoints/not_used/ 4 | checkpoints 5 | *.onnx 6 | tmp 7 | data 8 | expr-transfer-tiangong 9 | -------------------------------------------------------------------------------- /dataset/__pycache__/ABAWData.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/dataset/__pycache__/ABAWData.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ABAWData.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/dataset/__pycache__/ABAWData.cpython-39.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ABAWData_seq.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/dataset/__pycache__/ABAWData_seq.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ABAWData_seq.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/dataset/__pycache__/ABAWData_seq.cpython-39.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/asymmDataset.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/dataset/__pycache__/asymmDataset.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/asymmDataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/dataset/__pycache__/asymmDataset.cpython-39.pyc -------------------------------------------------------------------------------- /exp_emb_code/utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/exp_emb_code/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /exp_emb_code/utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/free_avatar/HEAD/exp_emb_code/utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.28.1 2 | matplotlib==3.7.5 3 | numpy==1.23.5 4 | opencv_contrib_python==4.8.0.76 5 | opencv_python==4.9.0.80 6 | opencv_python_headless==4.8.1.78 7 | Pillow==10.4.0 8 | PyYAML==6.0.1 9 | setuptools==45.2.0 10 | skimage==0.0 11 | timm==0.9.2 12 | torch==2.1.0 13 | torchvision==0.16.0 14 | tqdm==4.65.0 -------------------------------------------------------------------------------- /configs_rig2img.yaml: -------------------------------------------------------------------------------- 1 | character: L36_230_61 2 | use_multichar: True 3 | save_root: /project/qiuf/expr-capture 4 | emb_backbone: mae_emb 5 | faceware_ratio: 0.0 6 | batch_size: 128 7 | lr: 0.00001 8 | lr_D: 0.000001 9 | patience: 30 10 | save_step: 10 11 | # loss weight 12 | weight_rig: 10 13 | weight_img: 1 14 | weight_mouth: 10 15 | weight_emb: 100 16 | weight_D: 0.01 17 | weight_symm: 1 18 | id_embedding_dim: 16 19 | pretrained: '' 20 | mode: 'train' 21 | train_step_per_epoch: 500 22 | eval_step_per_epoch: 100 23 | seed: 101010 -------------------------------------------------------------------------------- /exp_emb_code/configs/mae_train_expemb.yaml: -------------------------------------------------------------------------------- 1 | train_csv: /path/to/train.csv 2 | train_img_path: /path/to/train_image 3 | val_csv: /path/to/val.csv 4 | val_img_path: /path/to/val_image 5 | 6 | use_dp: False 7 | device: [0] 8 | 9 | log_dir: "./log/vit_base_16_exp_emb" 10 | checkpoint_dir: "./checkpoints/vit_base_16_exp_emb" 11 | 12 | print_freq: 20 13 | accum_iter: 100 14 | batch_size: 64 15 | num_workers: 8 16 | 17 | num_epochs: 100 18 | save_epoch: 1 19 | mae_pretrain_checkpoints: /path/to/pretrained_mae 20 | resume: null 21 | 22 | lr: 0.002 23 | momentum: 0.9 24 | weight_decay: 0.00005 25 | optim: "SGD" 26 | -------------------------------------------------------------------------------- /configs_emb2rig_multi.yaml: -------------------------------------------------------------------------------- 1 | character: L36_230_61, L36_233 #, L36_233 2 | use_multichar: True 3 | CHARACTER_NAME: ['L36_233', 'L36_230', 'L36_230_61'] 4 | save_root: /project/qiuf/expr-capture 5 | # 
mae_emb_cartoon2,mae_dissymm_v5,mae_affectnet, mae_emb, repvit, mobilevit 6 | emb_backbone: mobilevit_relu 7 | mode: 'train' 8 | # DATA 9 | faceware_ratio: 0.1 10 | only_render: False 11 | # training 12 | batch_size: 16 13 | lr: 0.00001 14 | lr_D: 0.000001 15 | patience: 30 16 | save_step: 10 17 | # loss weight 18 | weight_rig: 10 #10 19 | weight_img: 1 20 | weight_mouth: 0 21 | weight_emb: 0.2 #100 22 | weight_D: 0 #0.0001 23 | weight_symm: 0.2 24 | id_embedding_dim: 16 25 | pretrained: '' # 20240614-181949_210, 20240625-191424_90 26 | train_step_per_epoch: 500 27 | eval_step_per_epoch: 100 28 | seed: 101010 -------------------------------------------------------------------------------- /mae-main/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /exp_emb_code/model/mae_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | import torch 7 | import model.models_vit as models_vit 8 | 9 | 10 | 11 | class Pipeline(nn.Module): 12 | def __init__(self,config): 13 | super(Pipeline,self).__init__() 14 | self.config = config 15 | model_name = 'vit_base_patch16' 16 | num_classes = 16 17 | ckpt_path = config["mae_pretrain_checkpoints"] 18 | model = getattr(models_vit, model_name)( 19 | global_pool=True, 20 | num_classes=num_classes, 21 | drop_path_rate=0.1, 22 | img_size=224, 23 | ) 24 | print(f"Load pre-trained checkpoint from: {ckpt_path}") 25 | checkpoint = torch.load(ckpt_path, map_location='cpu') 26 | checkpoint_model = checkpoint['model'] 27 | model.load_state_dict(checkpoint_model, strict=False) 28 | self.main = model 29 | 30 | def forward(self, x): 31 | x = self.main(x) 32 | x = F.normalize(x, dim=1) 33 | return x 34 | 35 | def forward_fea(self,x): 36 | x,fea = self.main(x,True) 37 | return fea 38 | -------------------------------------------------------------------------------- /models/load_emb_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.transforms as transforms 3 | from models.mae_pipeline import Pipeline_mae 4 | from models.Emoca_ExprNet import ExpressionLossNet as EmoNet 5 | 6 | def load_emb_model(backbone, opt=None): 7 | 8 | if backbone == 'emonet': 9 | model_emb = EmoNet().cuda() 10 | resize = transforms.Compose([transforms.Resize([224,224], antialias=True)]) 11 | n_rig = 2048 12 | if opt: 13 | opt.exp_dim = 2048 14 | 15 | elif backbone == 'mae_emb': 16 | n_rig = 768 17 | mean = [0.49895147219604985, 
0.4104390648367995, 0.3656147590417074] 18 | std = [0.2970847084907291, 0.2699003075660314, 0.2652599579468044] 19 | resize = transforms.Compose([transforms.Resize([224,224], antialias=True), 20 | transforms.Normalize(mean, std)]) 21 | model_emb = Pipeline_mae() 22 | ckpt_mae = torch.load('/data/Workspace/Rig2Face/ckpt/epoch_90_acc_0.8736.pth') 23 | ckpt_mae = {key.replace('module.', ''):ckpt_mae[key] for key in ckpt_mae.keys()} 24 | model_emb.load_state_dict(ckpt_mae) 25 | if opt: 26 | opt.exp_dim = n_rig 27 | 28 | return model_emb, n_rig, resize -------------------------------------------------------------------------------- /mae-main/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mae 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to mae, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /mae-main/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /exp_emb_code/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import math 4 | 5 | 6 | 7 | def triplet_prediction_accuracy(distances1,distances2,distances3,types=None,mode="triplet"): 8 | # distances1: anc and pos 9 | # distances2: anc and neg 10 | # distances3: pos and neg 11 | N = len(distances1) 12 | distances1 = np.array(distances1) 13 | distances2 = np.array(distances2) 14 | distances3 = np.array(distances3) 15 | 16 | 17 | c1 = distances2-distances1 18 | c2 = distances3-distances1 19 | n = 0 20 | if types==None: 21 | for i in range(N): 22 | if c1[i] > 0 and c2[i] > 0: 23 | n+=1 24 | acc = n/N 25 | return acc 26 | else: 27 | s1,s2,s3,N1,N2,N3=0,0,0,0,0,0 28 | for i in range(len(c1)): 29 | if types[i] == "ONE_CLASS_TRIPLET": 30 | N1 += 1 31 | elif types[i] == "TWO_CLASS_TRIPLET": 32 | N2 += 1 33 | elif types[i] == "THREE_CLASS_TRIPLET": 34 | N3 += 1 35 | if c1[i] > 0 and c2[i] > 0: 36 | n+=1 37 | if types[i]=="ONE_CLASS_TRIPLET": 38 | s1+=1 39 | elif types[i]=="TWO_CLASS_TRIPLET": 40 | s2+=1 41 | elif types[i]=="THREE_CLASS_TRIPLET": 42 | s3+=1 43 | acc = n/N 44 | acc1 = s1/N1 45 | acc2 = s2/N2 46 | acc3 = s3/N3 47 | return acc,acc1,acc2,acc3 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /models/mae_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | import models.models_vit as models_vit 6 | 7 | 8 | class Emb_mae(nn.Module): 9 | def __init__(self): 10 | super(Emb_mae,self).__init__() 11 | 12 | state_dict = torch.load("/data/Workspace/Rig2Face/ckpt/emb_regress_loss_0.0088.pth") 13 | new_dict = {} 14 | for k,v in state_dict.items(): 15 | if "module." 
in k: 16 | new_k = k.split("module.")[1] 17 | new_dict[new_k] = v 18 | if len(new_dict)>0: 19 | state_dict = new_dict 20 | 21 | emb_net = Pipeline_mae().cuda() 22 | emb_net.load_state_dict(state_dict) 23 | emb_net.eval() 24 | for params in emb_net.parameters(): 25 | params.requires_grad = False 26 | self.emb_net = emb_net 27 | # self.transform = build_transform(False) 28 | 29 | def forward(self, img_tensor): 30 | # feature_mae = self.transform(img_tensor) 31 | emb = self.emb_net(img_tensor) 32 | return emb, emb 33 | 34 | class Pipeline_mae(nn.Module): 35 | def __init__(self): 36 | super(Pipeline_mae,self).__init__() 37 | model_name = 'vit_base_patch16' 38 | num_classes = 16 39 | model = getattr(models_vit, model_name)( 40 | global_pool=True, 41 | num_classes=num_classes, 42 | drop_path_rate=0.1, 43 | img_size=224, 44 | ) 45 | self.main = model 46 | 47 | def forward(self,x): 48 | x,fea = self.main(x,True) 49 | x = F.normalize(x, dim=1) 50 | return fea,x 51 | 52 | -------------------------------------------------------------------------------- /mae-main/PRETRAIN.md: -------------------------------------------------------------------------------- 1 | ## Pre-training MAE 2 | 3 | To pre-train ViT-Large (recommended default) with **multi-node distributed training**, run the following on 8 nodes with 8 GPUs each: 4 | ``` 5 | python submitit_pretrain.py \ 6 | --job_dir ${JOB_DIR} \ 7 | --nodes 8 \ 8 | --use_volta32 \ 9 | --batch_size 64 \ 10 | --model mae_vit_large_patch16 \ 11 | --norm_pix_loss \ 12 | --mask_ratio 0.75 \ 13 | --epochs 800 \ 14 | --warmup_epochs 40 \ 15 | --blr 1.5e-4 --weight_decay 0.05 \ 16 | --data_path ${IMAGENET_DIR} 17 | ``` 18 | - Here the effective batch size is 64 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 4096. If memory or # gpus is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 8 (gpus per node) * `accum_iter`. 19 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 20 | - Here we use `--norm_pix_loss` as the target for better representation learning. To train a baseline model (e.g., for visualization), use pixel-based construction and turn off `--norm_pix_loss`. 21 | - The exact same hyper-parameters and configs (initialization, augmentation, etc.) are used as our TF/TPU implementation. In our sanity checks, this PT/GPU re-implementation can reproduce the TF/TPU results within reasonable random variation. We get 85.5% [fine-tuning](FINETUNE.md) accuracy by pre-training ViT-Large for 800 epochs (85.4% in paper Table 1d with TF/TPU). 22 | - Training time is ~42h in 64 V100 GPUs (800 epochs). 23 | 24 | To train ViT-Base or ViT-Huge, set `--model mae_vit_base_patch16` or `--model mae_vit_huge_patch14`. 
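 If a SLURM/submitit cluster is not available, the same recipe can be approximated on a single node with 8 GPUs by calling `main_pretrain.py` directly. The command below is only a sketch: it assumes PyTorch's distributed launcher and uses `--accum_iter 8` so that the effective batch size stays at 4096 (64 per gpu * 8 gpus * 8 accumulation steps); adjust it to your own setup.
```
python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
    --batch_size 64 \
    --accum_iter 8 \
    --model mae_vit_large_patch16 \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --data_path ${IMAGENET_DIR}
```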
25 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import glob 5 | from tqdm import tqdm 6 | 7 | 8 | def read_DJ01_ctrls(path): 9 | with open(path, 'r') as f: 10 | lines = f.readlines() 11 | lines = [list(map(float,l.strip(' \n').split(' '))) for l in lines] 12 | return np.array(lines).squeeze() 13 | 14 | def load_rigs_to_cache(root, n_rig, version_old=None): 15 | n_folders = len(os.listdir(root)) 16 | pkl_path = root + f'{n_folders}.pkl' 17 | print('loading rig pkl:', pkl_path) 18 | rigs = {} 19 | 20 | load_rig = read_DJ01_ctrls 21 | print(root) 22 | 23 | if version_old: 24 | pkl_path_old = root + f'{version_old}.pkl' 25 | with open(pkl_path_old, 'rb') as f: 26 | rigs = pickle.load(f) 27 | 28 | if os.path.exists(pkl_path): 29 | with open(pkl_path, 'rb') as f: 30 | rigs = pickle.load(f) 31 | else: 32 | print('=> loading rigs to cache..') 33 | # folders = ['linjie_expr_test3'] 34 | # files_name = [] 35 | # for fold in folders: 36 | # files_name += [y for x in os.walk(os.path.join(root, fold)) for y in glob.glob(os.path.join(x[0], '*.txt'))] 37 | 38 | files_name = [y for x in os.walk(root) for y in glob.glob(os.path.join(x[0], '*.txt'))] 39 | files_name = [fn.replace(root, '').strip('/') for fn in files_name] 40 | files_name = [fn for fn in files_name if fn not in rigs.keys()] 41 | # rigs = {} 42 | for fname in tqdm(files_name, total=len(files_name)): 43 | try: 44 | rig = load_rig(os.path.join(root, fname)) 45 | except: 46 | print(fname) 47 | exit() 48 | 49 | rigs[fname] = rig 50 | with open(pkl_path, 'wb' ) as f: 51 | pickle.dump(rigs, f) 52 | return rigs -------------------------------------------------------------------------------- /mae-main/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /mae-main/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 
224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /choose_character.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def character_choice(character): 4 | print(f'=> Choose character: {character}') 5 | 6 | if character.lower() == 'l36_233': 7 | img_postfix = '.jpg' 8 | n_rig=61 9 | data_path = '/project/qiuf/DJ01/L36/images' 10 | mouth_left, mouth_right, mouth_top, mouth_bottom = map(int, np.array([150, 350, 350, 450]) * 256 / 512.) 11 | eye_left, eye_right, eye_top, eye_bottom = map(int, np.array([106, 404, 161, 266]) * 256 / 512.) 12 | ckpt_img2rig = None 13 | ckpt_rig2img = '/project/qiuf/expr-capture/ckpt/rig2img_20240211-045412.pt' 14 | elif character.lower() == 'l36_234': 15 | img_postfix = '.jpg' 16 | n_rig = 61 17 | data_path = '/project/qiuf/DJ01/L36_234/images' 18 | mouth_left, mouth_right, mouth_top, mouth_bottom = map(int, np.array([178, 360, 363, 451]) * 256 / 512.) 19 | eye_left, eye_right, eye_top, eye_bottom = map(int, np.array([119, 415, 136, 250]) * 256 / 512.) 20 | ckpt_img2rig = None 21 | ckpt_rig2img = '/project/qiuf/expr-capture/ckpt/rig2img_20240324-184156.pt' 22 | elif character.lower() == 'l36_230_61': 23 | img_postfix = '.png' 24 | n_rig = 67 25 | data_path = '/project/qiuf/DJ01/L36_230_61/images' 26 | mouth_left, mouth_right, mouth_top, mouth_bottom = map(int, np.array([194, 295, 312, 367]) * 256 / 512.) 27 | eye_left, eye_right, eye_top, eye_bottom = map(int, np.array([151, 337, 185, 261]) * 256 / 512.) 28 | ckpt_img2rig = '/project/qiuf/expr-capture/ckpt/img2rig_20240522-153804.pt' 29 | ckpt_rig2img = '/project/qiuf/expr-capture/ckpt/rig2img_20240425-180631.pt' 30 | else: 31 | raise NotImplementedError 32 | 33 | mouth_crop = np.zeros((3,256,256)) 34 | mouth_crop[:, mouth_top:mouth_bottom, mouth_left:mouth_right] = 1 35 | mouth_crop[:, eye_top:eye_bottom, eye_left:eye_right] = 1 36 | return { 37 | 'data_path': data_path, 38 | 'mouth_crop':mouth_crop, 39 | 'n_rig':n_rig, 40 | 'ckpt_img2rig': ckpt_img2rig, 41 | 'ckpt_rig2img': ckpt_rig2img, 42 | 'img_postfix': img_postfix 43 | } 44 | -------------------------------------------------------------------------------- /models/pipeline5.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from models.facenet2 import InceptionResnetV1 4 | import torch 5 | 6 | class Pipeline(nn.Module): 7 | """ 8 | DLN model without high-order here. Pretrained weights can be found in checkpoints directory. 9 | """ 10 | def __init__(self, out_dim=512): 11 | super(Pipeline,self).__init__() 12 | self.faceNet = InceptionResnetV1(pretrained="vggface2").eval() 13 | self.R_net = InceptionResnetV1(pretrained="vggface2") 14 | self.BN1 = nn.BatchNorm1d(512) 15 | self.BN2 = nn.BatchNorm1d(512) 16 | self.dropout = nn.Dropout(0.5) 17 | self.linear2 = nn.Linear(512,16,bias=False) 18 | self.out_dim = out_dim 19 | 20 | if out_dim == 1: 21 | self.last_D = nn.Linear(512,1,bias=True) 22 | 23 | def forward(self, x , c=1): 24 | """ 25 | Calculate expression embeddings or logits given a batch of input image tensors. 26 | :param x: Batch of image tensors representing faces. 27 | :return: Batch of embedding vectors or multinomial logits. 
28 | """ 29 | with torch.no_grad(): 30 | id_feature_ = self.faceNet(x) 31 | id_feature = torch.sigmoid(id_feature_) 32 | x = self.R_net(x) 33 | x = torch.sigmoid(x) 34 | x = x-id_feature 35 | emb_16 = self.linear2(x) 36 | emb_16 = F.normalize(emb_16, dim=1) 37 | 38 | if self.out_dim == 1: 39 | x = self.last_D(x) 40 | if self.out_dim == 512 + 512: 41 | x = torch.cat([id_feature, x], dim=1) 42 | return x, emb_16 43 | 44 | def forward2(self, x): 45 | with torch.no_grad(): 46 | id_feature = self.faceNet(x) 47 | id_feature = torch.sigmoid(id_feature) 48 | x = self.R_net(x) 49 | x = torch.sigmoid(x) 50 | x = x - id_feature 51 | global_feature = x 52 | x = self.linear2(x) 53 | x = F.normalize(x, dim=1) 54 | return global_feature, x 55 | 56 | 57 | if __name__ == '__main__': 58 | net = Pipeline().cuda() 59 | x = torch.rand([16,3,224,224]).cuda() 60 | res= net(x) 61 | print(res.shape) 62 | print(sum(param.numel() for param in net.parameters())) -------------------------------------------------------------------------------- /mae-main/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 48 | outcome = self.fc_norm(x) 49 | else: 50 | x = self.norm(x) 51 | outcome = x[:, 0] 52 | 53 | return outcome 54 | 55 | 56 | def vit_base_patch16(**kwargs): 57 | model = VisionTransformer( 58 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 59 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 60 | return model 61 | 62 | 63 | def vit_large_patch16(**kwargs): 64 | model = VisionTransformer( 65 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 67 | return model 68 | 69 | 70 | def vit_huge_patch14(**kwargs): 71 | model = VisionTransformer( 72 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return 
model -------------------------------------------------------------------------------- /mae-main/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import yaml 3 | import torch 4 | import random 5 | from torch import nn 6 | import pickle 7 | import torch.nn.init as init 8 | import numpy as np 9 | from types import SimpleNamespace 10 | import os 11 | 12 | def load_pickle(path): 13 | with open(path, 'rb') as f: 14 | return pickle.load(f) 15 | 16 | def setup_seed(seed): 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | np.random.seed(seed) 20 | random.seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | def init_weights(m): 24 | if isinstance(m, nn.Linear): 25 | init.xavier_uniform_(m.weight) 26 | if m.bias is not None: 27 | init.zeros_(m.bias) 28 | # 为其他类型的层添加初始化,例如卷积层 29 | elif isinstance(m, nn.Conv2d): 30 | init.xavier_uniform_(m.weight) 31 | if 
m.bias is not None: 32 | init.zeros_(m.bias) 33 | elif isinstance(m, nn.GRU): 34 | for param in m.parameters(): 35 | if len(param.shape) >= 2: # 对于权重矩阵 36 | init.xavier_uniform_(param.data) 37 | else: # 对于偏置项 38 | init.zeros_(param.data) 39 | 40 | def tensors_to_cuda(data): 41 | """ 42 | Recursively move all tensors in the input data to CUDA, if CUDA is available. 43 | 44 | :param data: A dictionary which may contain other dictionaries or torch.Tensors. 45 | :return: Same structure as input with all tensors moved to CUDA. 46 | """ 47 | if torch.cuda.is_available(): 48 | if isinstance(data, dict): 49 | # Recursively apply to dictionary elements 50 | return {key: tensors_to_cuda(value) for key, value in data.items()} 51 | elif isinstance(data, torch.Tensor): 52 | # Move tensor to CUDA 53 | return data.to('cuda') 54 | else: 55 | # If data is not a dictionary or tensor, return it as is 56 | return data 57 | else: 58 | # If CUDA is not available, return the data unchanged 59 | return data 60 | 61 | 62 | def parse_args_from_yaml(yaml_path): 63 | with open(yaml_path, 'r') as file: 64 | config = yaml.safe_load(file) 65 | config = SimpleNamespace(**config) 66 | 67 | return config 68 | 69 | def imgs2video(image_folder, video_name='', fps=30): 70 | if not video_name: 71 | video_name = image_folder+'.mp4' 72 | 73 | images = [img for img in os.listdir(image_folder) if img.split('.')[-1] in ['png', 'jpg', 'jpeg']] 74 | images.sort() 75 | 76 | frames = [] 77 | for image_file in images: 78 | frames.append(imageio.imread(os.path.join(image_folder, image_file))) 79 | imageio.mimsave(video_name, frames, 'FFMPEG', fps=fps) 80 | 81 | print('output video: ', video_name) 82 | return True 83 | 84 | def count_parameters(model, verbose=False): 85 | total_params = sum(p.numel() for p in model.parameters()) 86 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 87 | if verbose: 88 | for name, param in model.named_parameters(): 89 | print(f'Layer: {name} | Parameters: {param.numel()}') 90 | return {"Total": total_params, "Trainable": trainable_params, "Non-Trainable": total_params - trainable_params} -------------------------------------------------------------------------------- /mae-main/engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None): 26 | model.train(True) 27 | metric_logger = misc.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 20 31 | 32 | accum_iter = args.accum_iter 33 | 34 | optimizer.zero_grad() 35 | 36 | if log_writer is not None: 37 | print('log_dir: {}'.format(log_writer.log_dir)) 38 | 39 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 40 | 41 | # we use a per iteration (instead of per epoch) lr scheduler 42 | if data_iter_step % accum_iter == 0: 43 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | 47 | with torch.cuda.amp.autocast(): 48 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio) 49 | 50 | loss_value = loss.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | sys.exit(1) 55 | 56 | loss /= accum_iter 57 | loss_scaler(loss, optimizer, parameters=model.parameters(), 58 | update_grad=(data_iter_step + 1) % accum_iter == 0) 59 | if (data_iter_step + 1) % accum_iter == 0: 60 | optimizer.zero_grad() 61 | 62 | torch.cuda.synchronize() 63 | 64 | metric_logger.update(loss=loss_value) 65 | 66 | lr = optimizer.param_groups[0]["lr"] 67 | metric_logger.update(lr=lr) 68 | 69 | loss_value_reduce = misc.all_reduce_mean(loss_value) 70 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 71 | """ We use epoch_1000x as the x-axis in tensorboard. 72 | This calibrates different curves when batch size changes. 
73 | """ 74 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 75 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 76 | log_writer.add_scalar('lr', lr, epoch_1000x) 77 | 78 | 79 | # gather the stats from all processes 80 | metric_logger.synchronize_between_processes() 81 | print("Averaged stats:", metric_logger) 82 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /exp_emb_code/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import os 5 | import torch.utils.data as data 6 | from PIL import Image 7 | from PIL import ImageFile 8 | import torchvision.transforms.transforms as transforms 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | from tqdm import tqdm 11 | from timm.data import create_transform 12 | import PIL 13 | import pickle 14 | 15 | 16 | class FecData(data.dataset.Dataset): 17 | def __init__(self, csv_file, img_path, transform=None): 18 | self.transform = transform 19 | 20 | self.csv_file = csv_file 21 | self.img_path = img_path 22 | 23 | self.data_anc = [] 24 | self.data_pos = [] 25 | self.data_neg = [] 26 | self.type = [] 27 | 28 | self.pd_data = pd.read_csv(self.csv_file) 29 | self.data = self.pd_data.to_dict("list") 30 | anc, pos, neg, tys = self.data["anchor"],self.data["positive"],self.data["negative"], self.data["type"] 31 | self.data_anc = [os.path.join(self.img_path, k) for k in anc] 32 | self.data_pos = [os.path.join(self.img_path, k) for k in pos] 33 | self.data_neg = [os.path.join(self.img_path, k) for k in neg] 34 | self.type = tys 35 | 36 | 37 | def __len__(self): 38 | return 100 39 | #return len(self.data_anc) 40 | 41 | def __getitem__(self, index): 42 | type = self.type[index] 43 | anc_list = self.data_anc[index] 44 | pos_list = self.data_pos[index] 45 | neg_list = self.data_neg[index] 46 | 47 | anc_img = Image.open(anc_list).convert('RGB') 48 | pos_img = Image.open(pos_list).convert('RGB') 49 | neg_img = Image.open(neg_list).convert('RGB') 50 | 51 | if self.transform is not None: 52 | anc_img = self.transform(anc_img) 53 | pos_img = self.transform(pos_img) 54 | neg_img = self.transform(neg_img) 55 | 56 | dict = { 57 | "name" : anc_list, 58 | "anc":anc_img, 59 | "pos":pos_img, 60 | "neg":neg_img, 61 | "type":type 62 | } 63 | 64 | return dict 65 | 66 | 67 | def build_transform(is_train): 68 | mean = [0.49895147219604985,0.4104390648367995,0.3656147590417074] 69 | std = [0.2970847084907291,0.2699003075660314,0.2652599579468044] 70 | input_size = 224 71 | if is_train: 72 | transform = create_transform( 73 | input_size=224, 74 | is_training=True, 75 | scale=(0.08,1.0), 76 | ratio=(7/8,8/7), 77 | color_jitter=None, 78 | auto_augment='rand-m9-mstd0.5-inc1', 79 | interpolation='bicubic', 80 | re_prob=0.25, 81 | re_mode='pixel', 82 | re_count=1, 83 | mean=mean, 84 | std=std, 85 | ) 86 | return transform 87 | 88 | if input_size <= 224: 89 | crop_pct = 224 / 256 90 | else: 91 | crop_pct = 1.0 92 | size = int(input_size / crop_pct) 93 | t = [ 94 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), 95 | transforms.CenterCrop(224), 96 | transforms.ToTensor(), 97 | transforms.Normalize(mean, std), 98 | ] 99 | return transforms.Compose(t) 100 | 101 | def build_dataset(config,mode): 102 | train_transform = build_transform(True) 103 | val_transform = build_transform(False) 104 | 105 | dataset = None 106 | if mode == 
"train": 107 | dataset = FecData(config["train_csv"],config["train_img_path"],train_transform) 108 | elif mode == "val": 109 | dataset = FecData(config["val_csv"],config["val_img_path"],train_transform) 110 | return dataset -------------------------------------------------------------------------------- /mae-main/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. 
The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /models/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class GANLoss(nn.Module): 6 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 7 | tensor=torch.FloatTensor, opt=None): 8 | super(GANLoss, self).__init__() 9 | self.real_label = target_real_label 10 | self.fake_label = target_fake_label 11 | self.real_label_tensor = None 12 | self.fake_label_tensor = None 13 | self.zero_tensor = None 14 | self.Tensor = tensor 15 | self.gan_mode = gan_mode 16 | self.opt = opt 17 | if gan_mode == 'ls': 18 | pass 19 | elif gan_mode == 'original': 20 | pass 21 | elif gan_mode == 'w': 22 | pass 23 | elif gan_mode == 'hinge': 24 | pass 25 | else: 26 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 27 | 28 | def get_target_tensor(self, input, target_is_real): 29 | if target_is_real: 30 | if self.real_label_tensor is None: 31 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label).type_as(input) 32 | self.real_label_tensor.requires_grad_(False) 33 | return self.real_label_tensor.expand_as(input) 34 | else: 35 | if self.fake_label_tensor is None: 36 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label).type_as(input) 37 | self.fake_label_tensor.requires_grad_(False) 38 | return self.fake_label_tensor.expand_as(input) 39 | 40 | def get_zero_tensor(self, input): 41 | if self.zero_tensor is None: 42 | self.zero_tensor = self.Tensor(1).fill_(0).type_as(input) 43 | self.zero_tensor.requires_grad_(False) 44 | return self.zero_tensor.expand_as(input) 45 | 46 | def loss(self, input, target_is_real, for_discriminator=True): 47 | if self.gan_mode == 'original': # cross entropy loss 48 | target_tensor = self.get_target_tensor(input, target_is_real) 49 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 50 | return loss 51 | elif self.gan_mode == 'ls': 52 | target_tensor = self.get_target_tensor(input, target_is_real) 53 | return F.mse_loss(input, target_tensor) 54 | elif self.gan_mode == 'hinge': 55 | if for_discriminator: 56 | if target_is_real: 57 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 58 | loss = -torch.mean(minval) 59 | else: 60 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 61 | loss = -torch.mean(minval) 62 | else: 63 | assert target_is_real, "The generator's hinge loss must be aiming for real" 64 | loss = -torch.mean(input) 65 | return loss 66 | else: 67 | # wgan 68 | if target_is_real: 69 | return -input.mean() 70 | else: 71 | return input.mean() 72 | 73 | def __call__(self, input, target_is_real, 
for_discriminator=True): 74 | # computing loss is a bit complicated because |input| may not be 75 | # a tensor, but list of tensors in case of multiscale discriminator 76 | if isinstance(input, list): 77 | loss = 0 78 | for pred_i in input: 79 | if isinstance(pred_i, list): 80 | pred_i = pred_i[-1] 81 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 82 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 83 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 84 | loss += new_loss 85 | return loss / len(input) 86 | else: 87 | return self.loss(input, target_is_real, for_discriminator) 88 | 89 | def d_logistic_loss(real_pred, fake_pred): 90 | real_loss = F.softplus(-real_pred) 91 | fake_loss = F.softplus(fake_pred) 92 | 93 | return real_loss.mean(), fake_loss.mean() 94 | 95 | def g_nonsaturating_loss(fake_pred): 96 | loss = F.softplus(-fake_pred).mean() 97 | 98 | return loss -------------------------------------------------------------------------------- /mae-main/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FreeAvatar:Robust 3D Facial Animation Transfer by Learning an Expression Foundation Model 2 | 3 | ![示例图片](./pics/teaser.png) 4 | 5 | ## Abstract 6 | > Video-driven 3D facial animation transfer aims to drive avatars to reproduce the expressions of actors. Existing methods have achieved remarkable results by constraining both geometric and perceptual consistency. However, geometric constraints (like those designed on facial landmarks) are insufficient to capture subtle emotions, while expression features trained on classification tasks lack fine granularity for complex emotions. To address this, we propose FreeAvatar, a robust facial animation transfer method that relies solely on our learned expression representation. Specifically, FreeAvatar consists of two main components: the expression foundation model and the facial animation transfer model. In the first component, we initially construct a facial feature space through a face reconstruction task and then optimize the expression feature space by exploring the similarities among different expressions. Benefiting from training on the amounts of unlabeled facial images and re-collected expression comparison dataset, our model adapts freely and effectively to any in-the-wild input facial images. 
In the facial animation transfer component, we propose a novel Expression-driven Multi-avatar Animator, which first maps expressive semantics to the facial control parameters of 3D avatars and then imposes perceptual constraints between the input and output images to maintain expression consistency. To make the entire process differentiable, we employ a trained neural renderer to translate rig parameters into corresponding images. Furthermore, unlike previous methods that require separate decoders for each avatar, we propose a dynamic identity injection module that allows for the joint training of multiple avatars within a single network. The comparisons show that our method achieves prominent performance even without introducing any geometric constraints, highlighting the robustness of our FreeAvatar.
7 |
8 |
9 | ## Requirements
10 | - imageio==2.28.1
11 | - matplotlib==3.7.5
12 | - numpy==1.23.5
13 | - opencv_contrib_python==4.8.0.76
14 | - opencv_python==4.9.0.80
15 | - opencv_python_headless==4.8.1.78
16 | - Pillow==10.4.0
17 | - PyYAML==6.0.1
18 | - setuptools==45.2.0
19 | - skimage==0.0
20 | - timm==0.9.2
21 | - torch==2.1.0
22 | - torchvision==0.16.0
23 | - tqdm==4.65.0
24 |
25 | This implementation has only been tested on the following setup:
26 |
27 | - System: Ubuntu 18.04
28 | - GPU: A30
29 | - CUDA Version: 12.0
30 | - CUDA Driver Version: 525.78.01
31 |
32 | We also used the MAE PyTorch/GPU implementation for pre-training our facial expression foundation model. For more dependencies, please refer to https://github.com/facebookresearch/mae.
33 |
34 |
35 | ## Method
36 | ![Pipeline](pics/pipeline.png)
37 | Our FreeAvatar includes three components:
38 | - Facial Feature Space Construction
39 | - Expression Feature Space Optimization
40 | - Expression-driven Multi-avatar Animator
41 |
42 | ### Facial Feature Space Construction
43 |
44 | We used the ViT-B model of MAE for pre-training the facial reconstruction task. For more training details, please refer to https://github.com/facebookresearch/mae.
45 |
46 |
47 | ### Expression Feature Space Optimization
48 |
49 | - Fine-tuning the ViT encoder on expression comparison triplets.
50 | ```
51 | python train.py
52 | ```
53 |
54 |
55 | ### Expression-driven Multi-avatar Animator
56 |
57 | - Training the Neural Renderer.
58 | ```
59 | python train_rig2img.py
60 | ```
61 |
62 |
63 | - Training the Rig Parameter Decoder.
64 | ```
65 | python train_emb2rig_multichar.py
66 | ```
67 | The data paths and training parameters can be modified in the configuration file `configs_emb2rig_multi.yaml`. To perform testing, specify `pretrained` and set `mode: 'test'`.
68 |
69 | ## Citation
70 | If you use this project for your research, please consider citing:
71 |
72 | ## Contact
73 | If you have any questions, please contact:
74 | - Feng Qiu (qiufeng@corp.netease.com)
75 | - Wei Zhang (zhangwei05@corp.netease.com)
76 | - Lincheng Li (lilincheng@corp.netease.com)
77 |
78 | ## Acknowledgement
79 |
80 | Some functions and scripts in this implementation are based on external sources. We thank the authors for their excellent work. Here are some great resources we benefited from:
81 |
82 | - [MAE_pytorch](https://github.com/facebookresearch/mae) for the expression foundation training code.
83 | - [Nvdiffrast](https://github.com/NVlabs/nvdiffrast) for differentiable rendering.
84 | - [Fuxi Youling Platform](https://fuxi.163.com/) for data collection and annotation.
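
For quick reference, the snippet below is an illustrative sketch (not part of the original codebase) of how the ViT-B/16 expression encoder defined in `models/models_vit.py` might be used to extract a 768-d expression feature from a single face image. The checkpoint and image paths are placeholders, and depending on how the checkpoint was saved you may need the loading logic in `models/load_emb_model.py` instead.

```python
# Illustrative only: extract an expression feature with the ViT-B/16 encoder
# from models/models_vit.py. Paths below are hypothetical placeholders.
import torch
from PIL import Image

from models.models_vit import vit_base_patch16, build_transform

model = vit_base_patch16(global_pool=True)           # ViT-B/16 with global average pooling
ckpt = torch.load("checkpoints/exp_emb_vit_base.pth", map_location="cpu")  # placeholder path
state = ckpt.get("state_dict", ckpt)                 # tolerate either a raw or wrapped state dict
model.load_state_dict(state, strict=False)           # head/decoder keys may be missing
model.eval()

transform = build_transform(is_train=False)          # resize, center-crop, normalize (repo defaults)
img = transform(Image.open("face.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    emb, _ = model(img)                              # forward() returns (embedding, embedding)
print(emb.shape)                                     # torch.Size([1, 768])
```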
-------------------------------------------------------------------------------- /exp_emb_code/model/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | 16 | import timm.models.vision_transformer 17 | #import model.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | self.pretrained=True 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :] # without cls token (N, L=14*14, D=768=16*16*3) 48 | x = x.mean(dim=1) # global average pooling (N, D=768) 49 | outcome = self.fc_norm(x) # Layer Normalization (N, D=768) 50 | else: 51 | x = self.norm(x) 52 | outcome = x[:, 0] 53 | 54 | return outcome 55 | 56 | # borrow from timm 57 | def forward(self, x, ret_feature=False): 58 | x = self.forward_features(x) 59 | feature = x 60 | if getattr(self, 'head_dist', None) is not None: 61 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 62 | if self.training and not torch.jit.is_scripting(): 63 | # during inference, return the average of both classifier predictions 64 | return x, x_dist 65 | else: 66 | return (x + x_dist) / 2 67 | else: 68 | x = self.head(x) 69 | # return 70 | if ret_feature: 71 | return x, feature 72 | else: 73 | return x 74 | 75 | 76 | # setup model archs 77 | VIT_KWARGS_BASE = dict(mlp_ratio=4, qkv_bias=True, 78 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6)) 79 | 80 | VIT_KWARGS_PRESETS = { 81 | 'micro': dict(patch_size=16, embed_dim=96, depth=12, num_heads=2), 82 | 'mini': dict(patch_size=16, embed_dim=128, depth=12, num_heads=2), 83 | 'tiny_d6': dict(patch_size=16, embed_dim=192, depth=6, num_heads=3), 84 | 'tiny': dict(patch_size=16, embed_dim=192, depth=12, num_heads=3), 85 | 'small': dict(patch_size=16, embed_dim=384, depth=12, num_heads=6), 86 | 'base': dict(patch_size=16, embed_dim=768, depth=12, num_heads=12), 87 | 'large': dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16), 88 | 'huge': dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16), 89 | 'giant': dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11), 90 | 'gigantic': dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64/13), 91 | } 92 | 93 | def 
create_vit_model(preset=None, creator=None, **kwargs): 94 | preset = 'base' if preset is None else preset.lower() 95 | all_kwargs = dict() 96 | all_kwargs.update(VIT_KWARGS_BASE) 97 | all_kwargs.update(VIT_KWARGS_PRESETS[preset]) 98 | all_kwargs.update(kwargs) 99 | if creator is None: 100 | creator = VisionTransformer 101 | return creator(**all_kwargs) 102 | 103 | vit_micro_patch16 = partial(create_vit_model, preset='micro') 104 | vit_mini_patch16 = partial(create_vit_model, preset='mini') 105 | vit_tiny_d6_patch16 = partial(create_vit_model, preset='tiny_d6') 106 | vit_tiny_patch16 = partial(create_vit_model, preset='tiny') 107 | vit_small_patch16 = partial(create_vit_model, preset='small') 108 | vit_base_patch16 = partial(create_vit_model, preset='base') 109 | vit_large_patch16 = partial(create_vit_model, preset='large') 110 | vit_huge_patch14 = partial(create_vit_model, preset='huge') 111 | vit_giant_patch14 = partial(create_vit_model, preset='giant') 112 | vit_gigantic_patch14 = partial(create_vit_model, preset='gigantic') 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /mae-main/submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_pretrain as trainer 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | trainer_parser = trainer.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_pretrain as trainer 57 | 58 | self._setup_gpu_args() 59 | trainer.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, # max is 60 * 72 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mae-main/submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_finetune as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | 
gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mae-main/submitit_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_linprobe as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_linprobe as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mae-main/engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=False, 68 | update_grad=(data_iter_step + 1) % accum_iter == 0) 69 | if (data_iter_step + 1) % accum_iter == 0: 70 | optimizer.zero_grad() 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | min_lr = 10. 76 | max_lr = 0. 77 | for group in optimizer.param_groups: 78 | min_lr = min(min_lr, group["lr"]) 79 | max_lr = max(max_lr, group["lr"]) 80 | 81 | metric_logger.update(lr=max_lr) 82 | 83 | loss_value_reduce = misc.all_reduce_mean(loss_value) 84 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 85 | """ We use epoch_1000x as the x-axis in tensorboard. 86 | This calibrates different curves when batch size changes. 
87 | """ 88 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 89 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 90 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate(data_loader, model, device): 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | metric_logger = misc.MetricLogger(delimiter=" ") 103 | header = 'Test:' 104 | 105 | # switch to evaluation mode 106 | model.eval() 107 | 108 | for batch in metric_logger.log_every(data_loader, 10, header): 109 | images = batch[0] 110 | target = batch[-1] 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | # compute output 115 | with torch.cuda.amp.autocast(): 116 | output = model(images) 117 | loss = criterion(output, target) 118 | 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | 121 | batch_size = images.shape[0] 122 | metric_logger.update(loss=loss.item()) 123 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 124 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 125 | # gather the stats from all processes 126 | metric_logger.synchronize_between_processes() 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /mae-main/README.md: -------------------------------------------------------------------------------- 1 | ## Masked Autoencoders: A PyTorch Implementation 2 | 3 |

4 | 5 |

6 | 7 | 8 | This is a PyTorch/GPU re-implementation of the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377): 9 | ``` 10 | @Article{MaskedAutoencoders2021, 11 | author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross Girshick}, 12 | journal = {arXiv:2111.06377}, 13 | title = {Masked Autoencoders Are Scalable Vision Learners}, 14 | year = {2021}, 15 | } 16 | ``` 17 | 18 | * The original implementation was in TensorFlow+TPU. This re-implementation is in PyTorch+GPU. 19 | 20 | * This repo is a modification on the [DeiT repo](https://github.com/facebookresearch/deit). Installation and preparation follow that repo. 21 | 22 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 23 | 24 | ### Catalog 25 | 26 | - [x] Visualization demo 27 | - [x] Pre-trained checkpoints + fine-tuning code 28 | - [x] Pre-training code 29 | 30 | ### Visualization demo 31 | 32 | Run our interactive visualization demo using [Colab notebook](https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb) (no GPU needed): 33 |

34 | 35 |

36 | 37 | ### Fine-tuning with pre-trained checkpoints 38 | 39 | The following table provides the pre-trained checkpoints used in the paper, converted from TF/TPU to PT/GPU: 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 |
|                        | ViT-Base | ViT-Large | ViT-Huge |
|------------------------|----------|-----------|----------|
| pre-trained checkpoint | download | download  | download |
| md5                    | 8cad7c   | b8b06e    | 9bdbb0   |
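
The following is a rough, hand-written sketch of how one of these pre-trained encoders could be loaded into the classification ViT before fine-tuning; the authoritative logic lives in `main_finetune.py`, and the local checkpoint filename, the `vit_base_patch16` constructor in `models_vit.py`, and the `'model'` key used by the released checkpoints are assumptions here.

```python
# Illustrative sketch only -- see main_finetune.py for the actual loading code.
import torch

import models_vit
from util.pos_embed import interpolate_pos_embed

model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)

ckpt = torch.load("mae_pretrain_vit_base.pth", map_location="cpu")   # assumed local path
state = ckpt.get("model", ckpt)                   # released checkpoints keep weights under 'model'
interpolate_pos_embed(model, state)               # resize pos_embed if the patch grid changed
msg = model.load_state_dict(state, strict=False)  # classifier head keys are expected to be missing
print(msg.missing_keys)
```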
59 | 60 | The fine-tuning instruction is in [FINETUNE.md](FINETUNE.md). 61 | 62 | By fine-tuning these pre-trained models, we rank #1 in these classification tasks (detailed in the paper): 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 |
|   | ViT-B | ViT-L | ViT-H | ViT-H448 | prev best |
|---|-------|-------|-------|----------|-----------|
| ImageNet-1K (no external data) | 83.6 | 85.9 | 86.9 | 87.8 | 87.1 |
| *following are evaluation of the same model weights (fine-tuned in original ImageNet-1K):* | | | | | |
| ImageNet-Corruption (error rate) | 51.7 | 41.8 | 33.8 | 36.8 | 42.5 |
| ImageNet-Adversarial | 35.9 | 57.1 | 68.2 | 76.7 | 35.8 |
| ImageNet-Rendition | 48.3 | 59.9 | 64.4 | 66.5 | 48.7 |
| ImageNet-Sketch | 34.5 | 45.3 | 49.6 | 50.9 | 36.0 |
| *following are transfer learning by fine-tuning the pre-trained MAE on the target dataset:* | | | | | |
| iNaturalists 2017 | 70.5 | 75.7 | 79.3 | 83.4 | 75.4 |
| iNaturalists 2018 | 75.4 | 80.1 | 83.0 | 86.8 | 81.2 |
| iNaturalists 2019 | 80.5 | 83.4 | 85.7 | 88.3 | 84.1 |
| Places205 | 63.9 | 65.8 | 65.9 | 66.8 | 66.0 |
| Places365 | 57.9 | 59.4 | 59.8 | 60.3 | 58.0 |
149 | 150 | ### Pre-training 151 | 152 | The pre-training instruction is in [PRETRAIN.md](PRETRAIN.md). 153 | 154 | ### License 155 | 156 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 157 | -------------------------------------------------------------------------------- /models/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | import torch 14 | import timm.models.vision_transformer 15 | import torchvision.transforms.transforms as transforms 16 | from timm.data import create_transform 17 | import PIL 18 | 19 | 20 | 21 | def build_transform(is_train): 22 | # mean = IMAGENET_DEFAULT_MEAN 23 | # std = IMAGENET_DEFAULT_STD 24 | mean = [0.49895147219604985,0.4104390648367995,0.3656147590417074] 25 | std = [0.2970847084907291,0.2699003075660314,0.2652599579468044] 26 | # train transform 27 | input_size = 224 28 | if is_train: 29 | # this should always dispatch to transforms_imagenet_train 30 | transform = create_transform( 31 | input_size=224, 32 | is_training=True, 33 | scale=(0.08,1.0), 34 | ratio=(7/8,8/7), 35 | color_jitter=None, 36 | auto_augment='rand-m9-mstd0.5-inc1', 37 | interpolation='bicubic', 38 | re_prob=0.25, 39 | re_mode='pixel', 40 | re_count=1, 41 | mean=mean, 42 | std=std, 43 | ) 44 | return transform 45 | 46 | # eval transform 47 | if input_size <= 224: 48 | crop_pct = 224 / 256 49 | else: 50 | crop_pct = 1.0 51 | size = int(input_size / crop_pct) 52 | t = [ 53 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), 54 | # transforms.Resize(size, interpolation=InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 
224 images 55 | transforms.CenterCrop(224), 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean, std), 58 | ] 59 | return transforms.Compose(t) 60 | 61 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 62 | """ Vision Transformer with support for global average pooling 63 | """ 64 | def __init__(self, global_pool=False, **kwargs): 65 | super(VisionTransformer, self).__init__(**kwargs) 66 | 67 | self.global_pool = global_pool 68 | if self.global_pool: 69 | norm_layer = kwargs['norm_layer'] 70 | embed_dim = kwargs['embed_dim'] 71 | self.fc_norm = norm_layer(embed_dim) 72 | 73 | del self.norm # remove the original norm 74 | 75 | def forward_features(self, x): 76 | B = x.shape[0] 77 | x = self.patch_embed(x) 78 | 79 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 80 | x = torch.cat((cls_tokens, x), dim=1) 81 | x = x + self.pos_embed 82 | x = self.pos_drop(x) 83 | 84 | for blk in self.blocks: 85 | x = blk(x) 86 | 87 | if self.global_pool: 88 | x = x[:, 1:, :] # without cls token (N, L=14*14, D=768=16*16*3) 89 | x = x.mean(dim=1) # global average pooling (N, D=768) 90 | outcome = self.fc_norm(x) # Layer Normalization (N, D=768) 91 | else: 92 | x = self.norm(x) 93 | outcome = x[:, 0] 94 | 95 | return outcome 96 | 97 | # borrow from timm 98 | def forward2(self, x, ret_feature=False): 99 | x = self.forward_features(x) 100 | feature = x 101 | if getattr(self, 'head_dist', None) is not None: 102 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 103 | if self.training and not torch.jit.is_scripting(): 104 | # during inference, return the average of both classifier predictions 105 | return x, x_dist 106 | else: 107 | return (x + x_dist) / 2 108 | else: 109 | x = self.head(x) 110 | # return 111 | if ret_feature: 112 | return x, feature 113 | else: 114 | return x 115 | 116 | def forward(self, x, ret_feature=False): 117 | B = x.shape[0] 118 | x = self.patch_embed(x) 119 | 120 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 121 | x = torch.cat((cls_tokens, x), dim=1) 122 | x = x + self.pos_embed 123 | x = self.pos_drop(x) 124 | 125 | for blk in self.blocks: 126 | x = blk(x) 127 | 128 | if self.global_pool: 129 | x = x[:, 1:, :] # without cls token (N, L=14*14, D=768=16*16*3) 130 | x = x.mean(dim=1) # global average pooling (N, D=768) 131 | outcome = x 132 | else: 133 | x = self.norm(x) 134 | outcome = x[:, 0] 135 | 136 | return outcome, outcome 137 | 138 | 139 | # setup model archs 140 | VIT_KWARGS_BASE = dict(mlp_ratio=4, qkv_bias=True, 141 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6)) 142 | 143 | VIT_KWARGS_PRESETS = { 144 | 'micro': dict(patch_size=16, embed_dim=96, depth=12, num_heads=2), 145 | 'mini': dict(patch_size=16, embed_dim=128, depth=12, num_heads=2), 146 | 'tiny_d6': dict(patch_size=16, embed_dim=192, depth=6, num_heads=3), 147 | 'tiny': dict(patch_size=16, embed_dim=192, depth=12, num_heads=3), 148 | 'small': dict(patch_size=16, embed_dim=384, depth=12, num_heads=6), 149 | 'base': dict(patch_size=16, embed_dim=768, depth=12, num_heads=12), 150 | 'large': dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16), 151 | 'huge': dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16), 152 | 'giant': dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11), 153 | 'gigantic': dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64/13), 154 | } 155 | 156 | def create_vit_model(preset=None, 
creator=None, **kwargs):
157 | preset = 'base' if preset is None else preset.lower()
158 | all_kwargs = dict()
159 | all_kwargs.update(VIT_KWARGS_BASE)
160 | all_kwargs.update(VIT_KWARGS_PRESETS[preset])
161 | all_kwargs.update(kwargs)
162 | if creator is None:
163 | creator = VisionTransformer
164 | return creator(**all_kwargs)
165 |
166 | vit_micro_patch16 = partial(create_vit_model, preset='micro')
167 | vit_mini_patch16 = partial(create_vit_model, preset='mini')
168 | vit_tiny_d6_patch16 = partial(create_vit_model, preset='tiny_d6')
169 | vit_tiny_patch16 = partial(create_vit_model, preset='tiny')
170 | vit_small_patch16 = partial(create_vit_model, preset='small')
171 | vit_base_patch16 = partial(create_vit_model, preset='base')
172 | vit_large_patch16 = partial(create_vit_model, preset='large')
173 | vit_huge_patch14 = partial(create_vit_model, preset='huge')
174 | vit_giant_patch14 = partial(create_vit_model, preset='giant')
175 | vit_gigantic_patch14 = partial(create_vit_model, preset='gigantic')
176 |
-------------------------------------------------------------------------------- /exp_emb_code/train.py: --------------------------------------------------------------------------------
1 | from dataset import build_dataset
2 | from model.mae_pipeline import Pipeline
3 | from torch import optim
4 | from utils.metrics import triplet_prediction_accuracy
5 | from torch.utils.tensorboard import SummaryWriter
6 | from torch.nn.modules.distance import PairwiseDistance
7 | from tqdm import tqdm
8 | import yaml
9 | import argparse
10 | import utils.misc as misc
11 | from torch.utils.data import DataLoader
12 | import torch
13 | import torch.nn as nn
14 | import math
15 | import os
16 | import sys  # needed by sys.exit() in train_one_epoch
17 | def read_yaml_to_dict(yaml_path):
18 | with open(yaml_path) as file:
19 | dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
20 | return dict_value
21 |
22 |
23 | def compute_loss(anc_fea, pos_fea, neg_fea, type):
24 | criterion1 = nn.TripletMarginLoss(margin=0.1)
25 | criterion2 = nn.TripletMarginLoss(margin=0.2)
26 | criterion3 = nn.TripletMarginLoss(margin=0.1)
27 | l2_dist = PairwiseDistance(2)
28 | loss = 0
29 | for i in range(len(type)):
30 | anc_ = anc_fea[i].unsqueeze(0)
31 | pos_ = pos_fea[i].unsqueeze(0)
32 | neg_ = neg_fea[i].unsqueeze(0)
33 |
34 | if type[i] == "ONE_CLASS_TRIPLET":
35 | loss += criterion1(anc_, pos_, neg_) + criterion1(pos_, anc_, neg_)
36 | else:
37 | loss += criterion2(anc_, pos_, neg_) + criterion2(pos_, anc_, neg_)
38 |
39 | loss = loss/anc_fea.shape[0]
40 | dists1 = l2_dist.forward(anc_fea, pos_fea).data.cpu().numpy()
41 | dists2 = l2_dist.forward(anc_fea, neg_fea).data.cpu().numpy()
42 | dists3 = l2_dist.forward(pos_fea, neg_fea).data.cpu().numpy()
43 | return loss,dists1,dists2,dists3
44 |
45 |
46 |
47 | def main(config):
48 | trainset = build_dataset(config, "train")
49 | valset = build_dataset(config, "val")
50 | trainloader = DataLoader(trainset, batch_size=config["batch_size"],num_workers=config["num_workers"],shuffle=True)
51 | valloader = DataLoader(valset, batch_size=config["batch_size"],num_workers=config["num_workers"],shuffle=True)
52 |
53 | model = Pipeline(config).cuda()
54 |
55 | if config["resume"] != None:
56 | state_dict = torch.load(config["resume"])
57 | model.load_state_dict(state_dict)
58 |
59 | if config["use_dp"] == True:
60 | gpus = config["device"]
61 | model = torch.nn.DataParallel(model, device_ids=gpus)
62 |
63 | num_epochs = config["num_epochs"]
64 |
65 | if config["optim"] == "SGD":
66 | optimizer = optim.SGD(filter(lambda p:
p.requires_grad, model.parameters()), lr=config["lr"], momentum=config["momentum"], 67 | weight_decay=config["weight_decay"]) 68 | elif config["optim"] == "AdamW": 69 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=config["lr"], betas=(0.9,0.999), 70 | weight_decay=config["weight_decay"]) 71 | 72 | metric_logger = misc.MetricLogger(delimiter=" ") 73 | test_metric_logger = misc.MetricLogger(delimiter=" ") 74 | log_writer = SummaryWriter(log_dir=config["log_dir"]) 75 | 76 | os.makedirs(config["log_dir"],exist_ok=True) 77 | os.makedirs(config["checkpoint_dir"],exist_ok=True) 78 | 79 | 80 | for epoch in range(num_epochs): 81 | train_one_epoch(epoch, model, trainloader, metric_logger, log_writer, optimizer, config) 82 | evaluate(epoch, model, valloader, test_metric_logger, config) 83 | 84 | 85 | 86 | def train_one_epoch(epoch, model, data_loader, metric_logger, log_writer, optimizer, config): 87 | model.train(True) 88 | print_freq = config["print_freq"] 89 | accum_iter = config["accum_iter"] 90 | header = 'Training Epoch: [{}]'.format(epoch) 91 | t = enumerate(metric_logger.log_every(data_loader, print_freq, header)) 92 | for step, samples in t: 93 | anc_img, pos_img, neg_img, anc_list, type = samples["anc"],samples["pos"],samples["neg"],samples["name"],samples["type"] 94 | anc_img, pos_img, neg_img = anc_img.cuda(), pos_img.cuda(), neg_img.cuda() 95 | model.zero_grad() 96 | vec = torch.cat((anc_img, pos_img, neg_img), dim=0) 97 | emb = model.forward(vec) 98 | ll = int(emb.shape[0] / 3) 99 | anc_fea, pos_fea, neg_fea = torch.split(emb, ll, dim=0) 100 | 101 | loss,dists1,dists2,dists3 = compute_loss(anc_fea, pos_fea, neg_fea, type) 102 | 103 | loss_value = loss.item() 104 | if not math.isfinite(loss_value): 105 | print("Loss is {}, stopping training".format(loss_value)) 106 | sys.exit(1) 107 | 108 | loss.backward() 109 | optimizer.step() 110 | 111 | metric_logger.update(loss=loss_value) 112 | metric_logger.update(lr=optimizer.state_dict()['param_groups'][0]['lr']) 113 | 114 | loss_value_reduce = misc.all_reduce_mean(loss_value) 115 | 116 | 117 | if (step + 1) % accum_iter == 0: 118 | iter = epoch * len(data_loader) + step + 1 119 | log_writer.add_scalar("loss", loss_value, iter) 120 | 121 | print("Averaged stats:", metric_logger) 122 | 123 | 124 | def evaluate(epoch, model, data_loader, test_metric_logger, config): 125 | model.eval() 126 | print_freq = config["print_freq"] 127 | header = 'Validation Epoch: [{}]'.format(epoch) 128 | acc_logger = misc.Triplet_Logger(os.path.join(config["log_dir"], "test_log.json")) 129 | 130 | t = enumerate(test_metric_logger.log_every(data_loader, print_freq, header)) 131 | 132 | for step, samples in t: 133 | anc_img, pos_img, neg_img, anc_list, type = samples["anc"],samples["pos"],samples["neg"],samples["name"],samples["type"] 134 | anc_img, pos_img, neg_img = anc_img.cuda(), pos_img.cuda(), neg_img.cuda() 135 | 136 | vec = torch.cat((anc_img, pos_img, neg_img), dim=0) 137 | with torch.no_grad(): 138 | emb = model.forward(vec) 139 | ll = int(emb.shape[0] / 3) 140 | anc_fea, pos_fea, neg_fea = torch.split(emb, ll, dim=0) 141 | 142 | loss,dists1,dists2,dists3 = compute_loss(anc_fea, pos_fea, neg_fea, type) 143 | loss_value = loss.item() 144 | 145 | test_metric_logger.update(loss=loss_value) 146 | acc_logger.update(dists1,dists2,dists3,loss,type) 147 | 148 | avg_loss, res = acc_logger.summary() 149 | print(res) 150 | test_metric_logger.meters['loss_avg'].update(avg_loss, n=1) 151 | 
test_metric_logger.meters['overall_accuracy'].update(res[0], n=1) 152 | 153 | if len(res)>1: 154 | test_metric_logger.meters['CLASS1_1_accuracy'].update(res[1], n=1) 155 | test_metric_logger.meters['CLASS1_2_accuracy'].update(res[2], n=1) 156 | test_metric_logger.meters['CLASS1_3_accuracy'].update(res[3], n=1) 157 | 158 | print('* Overall Accuracy: {overall_accuracy.avg:.3f} loss {loss_avg.global_avg:.3f}' 159 | .format(overall_accuracy = test_metric_logger.overall_accuracy, loss_avg = test_metric_logger.meters["loss_avg"])) 160 | 161 | if epoch % config["save_epoch"] == 0: 162 | save_path = os.path.join(config["checkpoint_dir"], "epoch_" + str(epoch) + "_acc_" + str(res[0]) + ".pth") 163 | torch.save(model.state_dict(),save_path) 164 | 165 | return avg_loss, res 166 | 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument("--config", default="configs/mae_train_expemb.yaml") 172 | args = parser.parse_args() 173 | yml_path = args.config 174 | config = read_yaml_to_dict(yml_path) 175 | 176 | main(config) -------------------------------------------------------------------------------- /mae-main/FINETUNE.md: -------------------------------------------------------------------------------- 1 | ## Fine-tuning Pre-trained MAE for Classification 2 | 3 | ### Evaluation 4 | 5 | As a sanity check, run evaluation using our ImageNet **fine-tuned** models: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
|   | ViT-Base | ViT-Large | ViT-Huge |
|---|----------|-----------|----------|
| fine-tuned checkpoint | download | download | download |
| md5 | 1b25e9 | 51f550 | 2541f2 |
| reference ImageNet accuracy | 83.664 | 85.952 | 86.928 |
31 | 32 | Evaluate ViT-Base in a single GPU (`${IMAGENET_DIR}` is a directory containing `{train, val}` sets of ImageNet): 33 | ``` 34 | python main_finetune.py --eval --resume mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR} 35 | ``` 36 | This should give: 37 | ``` 38 | * Acc@1 83.664 Acc@5 96.530 loss 0.731 39 | ``` 40 | 41 | Evaluate ViT-Large: 42 | ``` 43 | python main_finetune.py --eval --resume mae_finetuned_vit_large.pth --model vit_large_patch16 --batch_size 16 --data_path ${IMAGENET_DIR} 44 | ``` 45 | This should give: 46 | ``` 47 | * Acc@1 85.952 Acc@5 97.570 loss 0.646 48 | ``` 49 | 50 | Evaluate ViT-Huge: 51 | ``` 52 | python main_finetune.py --eval --resume mae_finetuned_vit_huge.pth --model vit_huge_patch14 --batch_size 16 --data_path ${IMAGENET_DIR} 53 | ``` 54 | This should give: 55 | ``` 56 | * Acc@1 86.928 Acc@5 98.088 loss 0.584 57 | ``` 58 | 59 | ### Fine-tuning 60 | 61 | Get our pre-trained checkpoints from [here](https://github.com/fairinternal/mae/#pre-trained-checkpoints). 62 | 63 | To fine-tune with **multi-node distributed training**, run the following on 4 nodes with 8 GPUs each: 64 | ``` 65 | python submitit_finetune.py \ 66 | --job_dir ${JOB_DIR} \ 67 | --nodes 4 \ 68 | --batch_size 32 \ 69 | --model vit_base_patch16 \ 70 | --finetune ${PRETRAIN_CHKPT} \ 71 | --epochs 100 \ 72 | --blr 5e-4 --layer_decay 0.65 \ 73 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 74 | --dist_eval --data_path ${IMAGENET_DIR} 75 | ``` 76 | - Install submitit (`pip install submitit`) first. 77 | - Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 1024. 78 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 79 | - We have run 4 trials with different random seeds. The resutls are 83.63, 83.66, 83.52, 83.46 (mean 83.57 and std 0.08). 80 | - Training time is ~7h11m in 32 V100 GPUs. 81 | 82 | Script for ViT-Large: 83 | ``` 84 | python submitit_finetune.py \ 85 | --job_dir ${JOB_DIR} \ 86 | --nodes 4 --use_volta32 \ 87 | --batch_size 32 \ 88 | --model vit_large_patch16 \ 89 | --finetune ${PRETRAIN_CHKPT} \ 90 | --epochs 50 \ 91 | --blr 1e-3 --layer_decay 0.75 \ 92 | --weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 93 | --dist_eval --data_path ${IMAGENET_DIR} 94 | ``` 95 | - We have run 4 trials with different random seeds. The resutls are 85.95, 85.87, 85.76, 85.88 (mean 85.87 and std 0.07). 96 | - Training time is ~8h52m in 32 V100 GPUs. 97 | 98 | Script for ViT-Huge: 99 | ``` 100 | python submitit_finetune.py \ 101 | --job_dir ${JOB_DIR} \ 102 | --nodes 8 --use_volta32 \ 103 | --batch_size 16 \ 104 | --model vit_huge_patch14 \ 105 | --finetune ${PRETRAIN_CHKPT} \ 106 | --epochs 50 \ 107 | --blr 1e-3 --layer_decay 0.75 \ 108 | --weight_decay 0.05 --drop_path 0.3 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 109 | --dist_eval --data_path ${IMAGENET_DIR} 110 | ``` 111 | - Training time is ~13h9m in 64 V100 GPUs. 
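
As a quick worked example of the linear scaling rule quoted above (illustrative only, not part of the original instructions), the multi-node ViT-Base recipe resolves to:

```python
# Effective batch size and actual lr for the multi-node ViT-Base recipe above.
batch_size_per_gpu = 32
nodes, gpus_per_node = 4, 8
blr = 5e-4                                                   # base learning rate

eff_batch_size = batch_size_per_gpu * nodes * gpus_per_node  # 32 * 4 * 8 = 1024
lr = blr * eff_batch_size / 256                              # 5e-4 * 1024 / 256 = 0.002
print(eff_batch_size, lr)                                    # 1024 0.002
```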
112 | 113 | To fine-tune our pre-trained ViT-Base with **single-node training**, run the following on 1 node with 8 GPUs: 114 | ``` 115 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 116 | --accum_iter 4 \ 117 | --batch_size 32 \ 118 | --model vit_base_patch16 \ 119 | --finetune ${PRETRAIN_CHKPT} \ 120 | --epochs 100 \ 121 | --blr 5e-4 --layer_decay 0.65 \ 122 | --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \ 123 | --dist_eval --data_path ${IMAGENET_DIR} 124 | ``` 125 | - Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`accum_iter`) * 8 (gpus) = 1024. `--accum_iter 4` simulates 4 nodes. 126 | 127 | #### Notes 128 | 129 | - The [pre-trained models we provide](https://github.com/fairinternal/mae/#pre-trained-checkpoints) are trained with *normalized* pixels `--norm_pix_loss` (1600 epochs, Table 3 in paper). The fine-tuning hyper-parameters are slightly different from the default baseline using *unnormalized* pixels. 130 | 131 | - The original MAE implementation was in TensorFlow+TPU with no explicit mixed precision. This re-implementation is in PyTorch+GPU with automatic mixed precision (`torch.cuda.amp`). We have observed different numerical behavior between the two platforms. In this repo, we use `--global_pool` for fine-tuning; using `--cls_token` performs similarly, but there is a chance of producing NaN when fine-tuning ViT-Huge in GPUs. We did not observe this issue in TPUs. Turning off amp could solve this issue, but is slower. 132 | 133 | - Here we use RandErase following DeiT: `--reprob 0.25`. Its effect is smaller than random variance. 134 | 135 | ### Linear Probing 136 | 137 | Run the following on 4 nodes with 8 GPUs each: 138 | ``` 139 | python submitit_linprobe.py \ 140 | --job_dir ${JOB_DIR} \ 141 | --nodes 4 \ 142 | --batch_size 512 \ 143 | --model vit_base_patch16 --cls_token \ 144 | --finetune ${PRETRAIN_CHKPT} \ 145 | --epochs 90 \ 146 | --blr 0.1 \ 147 | --weight_decay 0.0 \ 148 | --dist_eval --data_path ${IMAGENET_DIR} 149 | ``` 150 | - Here the effective batch size is 512 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 16384. 151 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 152 | - Training time is ~2h20m for 90 epochs in 32 V100 GPUs. 153 | - To run single-node training, follow the instruction in fine-tuning. 154 | 155 | To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. It is sufficient to train 50 epochs `--epochs 50`. 156 | 157 | This PT/GPU code produces *better* results for ViT-L/H (see the table below). This is likely caused by the system difference between TF and PT. 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 |
|   | ViT-Base | ViT-Large | ViT-Huge |
|---|----------|-----------|----------|
| paper (TF/TPU) | 68.0 | 75.8 | 76.6 |
| this repo (PT/GPU) | 67.8 | 76.0 | 77.2 |
178 | -------------------------------------------------------------------------------- /train_rig2img.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | def Train(epoch, loader, model): 4 | lr = optimizer.param_groups[0]['lr'] 5 | print(f"*** Epoch {epoch}, lr:{lr:.5f}, timestamp:{timestamp}") 6 | loss_sum = 0.0 7 | model.train() 8 | train_step = min(args.train_step_per_epoch, len(loader)) 9 | b = '{l_bar}{bar:20}{r_bar}{bar:10b}' 10 | pbar = tqdm(enumerate(loader), bar_format=b, total=train_step) 11 | logger = [] 12 | time0 = time.time() 13 | for i, data in pbar: 14 | if i>train_step: 15 | break 16 | optimizer.zero_grad() 17 | loss = dict() 18 | 19 | targets = data['img'].cuda().float() 20 | rigs = data['rigs'].cuda().float() 21 | assert (data['has_rig'] == 1).all() 22 | outputs = model(rigs.reshape(-1, configs_character['n_rig'], 1, 1)) 23 | loss['image'] = criterion_l1(outputs, targets) * args.weight_img 24 | loss['mouth'] = criterion_l1(outputs*mouth_crop, targets*mouth_crop) * args.weight_mouth 25 | 26 | loss_value = sum([v for k, v in loss.items()]) 27 | 28 | loss_sum += loss_value.item() 29 | loss_value.backward(retain_graph=True) 30 | optimizer.step() 31 | scheduler.step() 32 | 33 | 34 | writer.add_scalars(f'train/loss', loss, epoch * train_step + i) 35 | writer.add_scalar(f'train/loss_total', loss_value.item(), epoch * train_step + i) 36 | 37 | _loss_str = str({k: "{:.4f}".format(v/(i+1)) for k, v in loss.items()}) 38 | _log = f"Epoch {epoch}({timestamp}) (lr:{optimizer.param_groups[0]['lr']:05f}): [{i}/{len(train_dataloader)}] loss_G:{_loss_str}" 39 | logger.append(_log+'\n') 40 | pbar.set_description(_log) 41 | 42 | writer.add_images(f'train/img', torch.cat([outputs, targets], dim=-2)[::4], epoch * train_step + i) 43 | avg_loss = loss_sum / train_step 44 | _log = "==> [Train] Epoch {} ({}), training loss={}".format(epoch, timestamp, avg_loss) 45 | print(_log) 46 | with open(os.path.join(log_save_path, f'{task}_{timestamp}.log'), "a+") as log_file: 47 | log_file.writelines(logger) 48 | if epoch % args.save_step == 0: 49 | torch.save({'state_dict': model.state_dict()}, model_path.replace('.pt', f'_{epoch}.pt')) 50 | return avg_loss 51 | 52 | def Eval(epoch, loader, model, best_score): 53 | loss_sum = 0.0 54 | model.eval() 55 | eval_step = min(args.eval_step_per_epoch, len(loader)) 56 | b = '{l_bar}{bar:20}{r_bar}{bar:10b}' 57 | pbar = tqdm(enumerate(loader), bar_format=b, total=eval_step) 58 | logger = [] 59 | time0 = time.time() 60 | for i, data in pbar: 61 | if i>eval_step: 62 | break 63 | loss = dict() 64 | 65 | targets = data['img'].cuda().float() 66 | rigs = data['rigs'].cuda().float() 67 | assert (data['has_rig'] == 1).all() 68 | with torch.no_grad(): 69 | outputs = model(rigs.reshape(-1, configs_character['n_rig'], 1, 1)) 70 | loss['image'] = criterion_l1(outputs, targets) * args.weight_img 71 | loss['mouth'] = criterion_l1(outputs*mouth_crop, targets*mouth_crop) * args.weight_mouth 72 | 73 | loss_value = sum([v for k, v in loss.items()]) 74 | 75 | loss_sum += loss_value.item() 76 | 77 | writer.add_scalars(f'train/loss', loss, epoch * eval_step + i) 78 | writer.add_scalar(f'train/loss_total', loss_value.item(), epoch * eval_step + i) 79 | 80 | _loss_str = str({k: "{:.4f}".format(v/(i+1)) for k, v in loss.items()}) 81 | _log = f"Epoch {epoch}({timestamp}) (lr:{optimizer.param_groups[0]['lr']:05f}): [{i}/{len(train_dataloader)}] loss_G:{_loss_str}" 82 | logger.append(_log+'\n') 83 | pbar.set_description(_log) 
84 | 85 | writer.add_images(f'train/img', torch.cat([outputs, targets], dim=-2)[::4], epoch * eval_step + i) 86 | avg_loss = loss_sum / eval_step 87 | _log = "==> [Eval] Epoch {} ({}), training loss={}".format(epoch, timestamp, avg_loss) 88 | 89 | if avg_loss < best_score: 90 | patience_cur = args.patience 91 | best_score = avg_loss 92 | torch.save({'state_dict': model.state_dict()}, model_path) 93 | _log += '\n Found new best model!\n' 94 | else: 95 | patience_cur -= 1 96 | 97 | print(_log) 98 | with open(os.path.join(log_save_path, f'{task}_{timestamp}.log'), "a+") as log_file: 99 | log_file.writelines(logger) 100 | return avg_loss 101 | 102 | if __name__ == '__main__': 103 | import time 104 | import os 105 | import torch 106 | from choose_character import character_choice 107 | from utils.common import parse_args_from_yaml, setup_seed, init_weights 108 | from models.DCGAN import Generator 109 | import torchvision.transforms as transforms 110 | import torch.nn as nn 111 | from dataset.ABAWData import ABAWDataset2 112 | from torch.utils.data import DataLoader 113 | from torch.optim import lr_scheduler 114 | from torch.utils.tensorboard import SummaryWriter 115 | task = 'rig2img' 116 | args = parse_args_from_yaml(f'configs_{task}.yaml') 117 | setup_seed(args.seed) 118 | timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 119 | os.system("git add .") 120 | os.system("git commit -m" + timestamp) 121 | os.system("git push") 122 | 123 | configs_character = character_choice(args.character) 124 | mouth_crop = torch.tensor(configs_character['mouth_crop']).cuda().float() 125 | 126 | model_path = os.path.join(args.save_root,'ckpt', f"{task}_{timestamp}.pt") 127 | params = {'nz': configs_character['n_rig'], 'ngf': 64*2, 'nc': 3} 128 | model = Generator(params) 129 | model = model.cuda() 130 | 131 | if args.pretrained: 132 | ckpt_pretrained = os.path.join(args.save_root, 'ckpt', f"{task}_{args.pretrained}.pt") 133 | checkpoint = torch.load(ckpt_pretrained) 134 | model.load_state_dict(checkpoint['state_dict']) 135 | print("load pretrained model {}".format(ckpt_pretrained)) 136 | else: 137 | model.apply(init_weights) 138 | print("Model initialized") 139 | transform = transforms.Compose([ 140 | transforms.Resize([256, 256]), 141 | transforms.ToTensor()]) 142 | 143 | criterion_l1 = nn.L1Loss() 144 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.0, 0.99)) 145 | scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 500, 2, 1e-6) 146 | 147 | train_dataset = ABAWDataset2(root_path=configs_character['data_path'],character=args.character, only_render=True, 148 | data_split='train', transform=transform, return_rigs=True, n_rigs=configs_character['n_rig']) 149 | test_dataset = ABAWDataset2(root_path=configs_character['data_path'],character=args.character,only_render=True, 150 | data_split='test', transform=transform, return_rigs=True, n_rigs=configs_character['n_rig']) 151 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,drop_last=True,num_workers=12) 152 | val_dataloader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False,drop_last=True, num_workers=12) 153 | 154 | ck_save_path = f'{args.save_root}/ckpt' 155 | pred_save_path = f'{args.save_root}/test' 156 | log_save_path = f'{args.save_root}/logs' 157 | tensorboard_path = f'{args.save_root}/tensorboard/{timestamp}' 158 | 159 | os.makedirs(ck_save_path,exist_ok=True) 160 | os.makedirs(pred_save_path, exist_ok=True) 
161 | os.makedirs(tensorboard_path, exist_ok=True) 162 | os.makedirs(log_save_path, exist_ok=True) 163 | 164 | writer = SummaryWriter(log_dir=tensorboard_path) 165 | 166 | patience_cur = args.patience 167 | best_score = float('inf') 168 | 169 | 170 | for epoch in range(500000000): 171 | avg_loss = Train(epoch, train_dataloader, model) 172 | avg_loss_eval = Eval(epoch, val_dataloader, model, best_score) 173 | -------------------------------------------------------------------------------- /exp_emb_code/utils/misc.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import datetime 3 | import os 4 | import time 5 | from collections import defaultdict, deque 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch._six import inf 11 | from utils.metrics import triplet_prediction_accuracy 12 | import json 13 | import os 14 | import pickle 15 | 16 | class SmoothedValue(object): 17 | def __init__(self, window_size=20, fmt=None): 18 | if fmt is None: 19 | fmt = "{median:.6f} ({global_avg:.6f})" 20 | self.deque = deque(maxlen=window_size) 21 | self.total = 0.0 22 | self.count = 0 23 | self.fmt = fmt 24 | 25 | def update(self, value, n=1): 26 | self.deque.append(value) 27 | self.count += n 28 | self.total += value * n 29 | 30 | def synchronize_between_processes(self): 31 | if not is_dist_avail_and_initialized(): 32 | return 33 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 34 | dist.barrier() 35 | dist.all_reduce(t) 36 | t = t.tolist() 37 | self.count = int(t[0]) 38 | self.total = t[1] 39 | 40 | @property 41 | def median(self): 42 | d = torch.tensor(list(self.deque)) 43 | return d.median().item() 44 | 45 | @property 46 | def avg(self): 47 | d = torch.tensor(list(self.deque), dtype=torch.float32) 48 | return d.mean().item() 49 | 50 | @property 51 | def global_avg(self): 52 | return self.total / self.count 53 | 54 | @property 55 | def max(self): 56 | return max(self.deque) 57 | 58 | @property 59 | def value(self): 60 | return self.deque[-1] 61 | 62 | def __str__(self): 63 | return self.fmt.format( 64 | median=self.median, 65 | avg=self.avg, 66 | global_avg=self.global_avg, 67 | max=self.max, 68 | value=self.value) 69 | 70 | 71 | class MetricLogger(object): 72 | def __init__(self, delimiter="\t"): 73 | self.meters = defaultdict(SmoothedValue) 74 | self.delimiter = delimiter 75 | 76 | def update(self, **kwargs): 77 | for k, v in kwargs.items(): 78 | if v is None: 79 | continue 80 | if isinstance(v, torch.Tensor): 81 | v = v.item() 82 | assert isinstance(v, (float, int)) 83 | self.meters[k].update(v) 84 | 85 | def __getattr__(self, attr): 86 | if attr in self.meters: 87 | return self.meters[attr] 88 | if attr in self.__dict__: 89 | return self.__dict__[attr] 90 | raise AttributeError("'{}' object has no attribute '{}'".format( 91 | type(self).__name__, attr)) 92 | 93 | def __str__(self): 94 | loss_str = [] 95 | for name, meter in self.meters.items(): 96 | loss_str.append( 97 | "{}: {}".format(name, str(meter)) 98 | ) 99 | return self.delimiter.join(loss_str) 100 | 101 | def synchronize_between_processes(self): 102 | for meter in self.meters.values(): 103 | meter.synchronize_between_processes() 104 | 105 | def add_meter(self, name, meter): 106 | self.meters[name] = meter 107 | 108 | 109 | def log_every(self, iterable, print_freq, header=None): 110 | i = 0 111 | if not header: 112 | header = '' 113 | start_time = time.time() 114 | end = time.time() 115 | iter_time = 
SmoothedValue(fmt='{avg:.4f}') 116 | data_time = SmoothedValue(fmt='{avg:.4f}') 117 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 118 | log_msg = [ 119 | header, 120 | '[{0' + space_fmt + '}/{1}]', 121 | 'eta: {eta}', 122 | '{meters}', 123 | 'time: {time}', 124 | 'data: {data}' 125 | ] 126 | if torch.cuda.is_available(): 127 | log_msg.append('max mem: {memory:.0f}') 128 | log_msg = self.delimiter.join(log_msg) 129 | MB = 1024.0 * 1024.0 130 | for obj in iterable: 131 | data_time.update(time.time() - end) 132 | yield obj 133 | iter_time.update(time.time() - end) 134 | if i % print_freq == 0 or i == len(iterable) - 1: 135 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 136 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 137 | if torch.cuda.is_available(): 138 | print(log_msg.format( 139 | i, len(iterable), eta=eta_string, 140 | meters=str(self), 141 | time=str(iter_time), data=str(data_time), 142 | memory=torch.cuda.max_memory_allocated() / MB)) 143 | else: 144 | 145 | print(log_msg.format( 146 | i, len(iterable), eta=eta_string, 147 | meters=str(self), 148 | time=str(iter_time), data=str(data_time))) 149 | i += 1 150 | end = time.time() 151 | total_time = time.time() - start_time 152 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 153 | 154 | print('{} Total time: {} ({:.4f} s / it)'.format( 155 | header, total_time_str, total_time / len(iterable))) 156 | 157 | 158 | 159 | class Triplet_Logger: 160 | def __init__(self, save_log_file=None): 161 | self.distance1 = [] 162 | self.distance2 = [] 163 | self.distance3 = [] 164 | self.losses = [] 165 | self.types = [] 166 | self.save_log_file = save_log_file 167 | 168 | 169 | def update(self, dis1, dis2, dis3, loss, type): 170 | self.distance1.extend(dis1) 171 | self.distance2.extend(dis2) 172 | self.distance3.extend(dis3) 173 | self.losses.append(loss) 174 | self.types.extend(type) 175 | 176 | def dict_to_str(self,dict): 177 | logs = "" 178 | for k,v in dict.items(): 179 | logs += k + ": " +str(v) + " " 180 | return logs 181 | 182 | def write_log(self, avg_loss, res): 183 | if len(res)==1: 184 | write_json_dict = {'time':time.asctime(), 'Avg losses': avg_loss.item(), 'Overall prediction Accuracy': res[0]} 185 | else: 186 | write_json_dict = {'time':time.asctime(), 'Avg losses': avg_loss.item(), 'Overall prediction Accuracy': res[0], \ 187 | 'Class 1 Acc': res[1], 'Class 2 Acc': res[2], 'Class 3 Acc': res[3]} 188 | logs = self.dict_to_str(write_json_dict) 189 | with open(self.save_log_file, mode="a") as f: 190 | f.write(logs + "\n") 191 | 192 | 193 | def summary(self): 194 | res = [] 195 | avg_loss = sum(self.losses)/len(self.losses) 196 | 197 | acc, acc1, acc2, acc3 = triplet_prediction_accuracy(self.distance1,self.distance2,self.distance3,self.types,"triplet") 198 | res = [acc, acc1, acc2, acc3] 199 | 200 | if self.save_log_file!=None: 201 | self.write_log(avg_loss,res) 202 | return avg_loss, res 203 | 204 | 205 | 206 | 207 | class Emb_Logger: 208 | def __init__(self,save_path=None): 209 | self.emb_dict = {} 210 | self.save_path = save_path 211 | 212 | def update(self, names , embs): 213 | for i in range(len(names)): 214 | self.emb_dict[names[i]] = embs[i] 215 | 216 | def summary(self): 217 | 218 | with open(self.save_path, mode="wb") as f: 219 | pickle.dump(self.emb_dict,f) 220 | 221 | def is_dist_avail_and_initialized(): 222 | if not dist.is_available(): 223 | return False 224 | if not dist.is_initialized(): 225 | return False 226 | return True 227 | 228 | 229 | def get_world_size(): 230 
| if not is_dist_avail_and_initialized(): 231 | return 1 232 | return dist.get_world_size() 233 | 234 | def all_reduce_mean(x): 235 | world_size = get_world_size() 236 | if world_size > 1: 237 | x_reduce = torch.tensor(x).cuda() 238 | dist.all_reduce(x_reduce) 239 | x_reduce /= world_size 240 | return x_reduce.item() 241 | else: 242 | return x -------------------------------------------------------------------------------- /models/CascadeNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fully-connected residual network as a single deep learner. 3 | Convert 2d to 3d pose. 4 | From: https://github.com/Nicholasli1995/EvoSkeleton/blob/master/libs/model/model.py 5 | """ 6 | 7 | import torch.nn as nn 8 | import torch 9 | 10 | 11 | class ResidualBlock(nn.Module): 12 | """ 13 | A residual block. 14 | """ 15 | 16 | def __init__(self, linear_size, p_dropout=0.5, kaiming=True, leaky=False, activation=True): 17 | super(ResidualBlock, self).__init__() 18 | self.l_size = linear_size 19 | self.activation = activation 20 | if leaky: 21 | self.relu = nn.LeakyReLU(inplace=True) 22 | else: 23 | self.relu = nn.ReLU(inplace=True) 24 | self.dropout = nn.Dropout(p_dropout) 25 | 26 | self.w1 = nn.Linear(self.l_size, self.l_size) 27 | self.batch_norm1 = nn.BatchNorm1d(self.l_size) 28 | 29 | self.w2 = nn.Linear(self.l_size, self.l_size) 30 | self.batch_norm2 = nn.BatchNorm1d(self.l_size) 31 | 32 | if kaiming: 33 | self.w1.weight.data = nn.init.kaiming_normal_(self.w1.weight.data) 34 | self.w2.weight.data = nn.init.kaiming_normal_(self.w2.weight.data) 35 | 36 | def forward(self, x): 37 | y = self.w1(x) 38 | y = self.batch_norm1(y) 39 | if self.activation: 40 | y = self.relu(y) 41 | y = self.dropout(y) 42 | 43 | y = self.w2(y) 44 | y = self.batch_norm2(y) 45 | if self.activation: 46 | y = self.relu(y) 47 | y = self.dropout(y) 48 | 49 | out = x + y 50 | 51 | return out 52 | 53 | 54 | class FCModel(nn.Module): 55 | def __init__(self, 56 | stage_id=1, 57 | linear_size=1024, 58 | num_blocks=2, 59 | p_dropout=0.5, 60 | norm_twoD=False, 61 | kaiming=True, 62 | refine_3d=False, 63 | leaky=False, 64 | dm=False, 65 | input_size=32, 66 | output_size=64, 67 | activation=True, 68 | use_multichar=False, 69 | id_embedding_dim=16): 70 | """ 71 | Fully-connected network. 
72 | """ 73 | super(FCModel, self).__init__() 74 | if use_multichar: 75 | self.embedding_layer = nn.Embedding(10, embedding_dim=id_embedding_dim) 76 | input_size += id_embedding_dim 77 | self.activation = activation 78 | self.linear_size = linear_size 79 | self.p_dropout = p_dropout 80 | self.num_blocks = num_blocks 81 | self.stage_id = stage_id 82 | self.refine_3d = refine_3d 83 | self.leaky = leaky 84 | self.dm = dm 85 | self.input_size = input_size 86 | if self.stage_id > 1 and self.refine_3d: 87 | self.input_size += 16 * 3 88 | # 3d joints 89 | self.output_size = output_size 90 | 91 | # process input to linear size 92 | self.w1 = nn.Linear(self.input_size, self.linear_size) 93 | self.batch_norm1 = nn.BatchNorm1d(self.linear_size) 94 | 95 | self.res_blocks = [] 96 | for l in range(num_blocks): 97 | self.res_blocks.append(ResidualBlock(self.linear_size, 98 | self.p_dropout, 99 | leaky=self.leaky, 100 | activation=activation)) 101 | self.res_blocks = nn.ModuleList(self.res_blocks) 102 | 103 | # output 104 | 105 | 106 | self.w2 = nn.Linear(self.linear_size, self.output_size) 107 | if self.leaky: 108 | self.relu = nn.LeakyReLU(inplace=True) 109 | else: 110 | self.relu = nn.ReLU(inplace=True) 111 | self.dropout = nn.Dropout(self.p_dropout) 112 | 113 | if kaiming: 114 | self.w1.weight.data = nn.init.kaiming_normal_(self.w1.weight.data) 115 | self.w2.weight.data = nn.init.kaiming_normal_(self.w2.weight.data) 116 | self.out_activation = nn.Sigmoid() 117 | self.use_multichar=use_multichar 118 | 119 | 120 | def forward(self, x, id_index=0): 121 | if self.use_multichar: 122 | input_feature = self.embedding_layer(id_index.long()) 123 | x = torch.cat([input_feature, x], dim=1) 124 | y = self.get_representation(x) 125 | y = self.w2(y) 126 | y = self.out_activation(y) 127 | return y 128 | 129 | def get_representation(self, x): 130 | # get the latent representation of an input vector 131 | # first layer 132 | y = self.w1(x) 133 | y = self.batch_norm1(y) 134 | if self.activation: 135 | y = self.relu(y) 136 | y = self.dropout(y) 137 | 138 | # residual blocks 139 | for i in range(self.num_blocks): 140 | y = self.res_blocks[i](y) 141 | 142 | return y 143 | 144 | 145 | def get_model(stage_id, 146 | refine_3d=False, 147 | norm_twoD=False, 148 | num_blocks=2, 149 | input_size=32, 150 | output_size=64, 151 | linear_size=1024, 152 | dropout=0.5, 153 | leaky=False, 154 | activation=True, 155 | use_multichar=False, 156 | id_embedding_dim=16 157 | ): 158 | model = FCModel(stage_id=stage_id, 159 | refine_3d=refine_3d, 160 | norm_twoD=norm_twoD, 161 | num_blocks=num_blocks, 162 | input_size=input_size, 163 | output_size=output_size, 164 | linear_size=linear_size, 165 | p_dropout=dropout, 166 | leaky=leaky, 167 | activation=activation, 168 | 169 | use_multichar=use_multichar, 170 | id_embedding_dim=id_embedding_dim 171 | ) 172 | return model 173 | 174 | 175 | def prepare_optim(model, opt): 176 | """ 177 | Prepare optimizer. 
178 | """ 179 | params = [p for p in model.parameters() if p.requires_grad] 180 | if opt.optim_type == 'adam': 181 | optimizer = torch.optim.Adam(params, 182 | lr=opt.lr, 183 | weight_decay=opt.weight_decay 184 | ) 185 | elif opt.optim_type == 'sgd': 186 | optimizer = torch.optim.SGD(params, 187 | lr=opt.lr, 188 | momentum=opt.momentum, 189 | weight_decay=opt.weight_decay 190 | ) 191 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 192 | milestones=opt.milestones, 193 | gamma=opt.gamma) 194 | return optimizer, scheduler 195 | 196 | 197 | def get_cascade(): 198 | """ 199 | Get an empty cascade. 200 | """ 201 | return nn.ModuleList([]) 202 | 203 | class FC(nn.Module): 204 | def __init__(self, input_size=32, output_size=64, kaiming=True): 205 | """ 206 | Fully-connected network. 207 | """ 208 | super(FC, self).__init__() 209 | 210 | self.w1 = nn.Linear(input_size, output_size) 211 | 212 | if kaiming: 213 | self.w1.weight.data = nn.init.kaiming_normal_(self.w1.weight.data) 214 | 215 | def forward(self, x): 216 | out = self.w1(x) 217 | return out 218 | 219 | if __name__ == '__main__': 220 | import os 221 | 222 | # cascade = get_cascade() 223 | # for stage_id in range(2): 224 | # cascade.append(get_model(stage_id + 1, refine_3d=False, 225 | # norm_twoD=False, 226 | # num_blocks=2, 227 | # input_size=16, 228 | # output_size=139, 229 | # linear_size=1024, 230 | # dropout=0.5, 231 | # leaky=False 232 | # )) 233 | # cascade.eval() 234 | n_rig = 61 235 | exp_dim = 16 236 | model_path_root = '/data/Workspace/Rig2Face/ckpt' 237 | model = get_model(1, refine_3d=False, 238 | norm_twoD=False, 239 | num_blocks=4, #2, 240 | input_size=n_rig, 241 | output_size=exp_dim, 242 | linear_size=1024, #1024, 243 | dropout=0.0, 244 | leaky=False 245 | ) 246 | checkpoint = torch.load(os.path.join(model_path_root, 'model_model_20221129-130512.pt')) 247 | model.load_state_dict(checkpoint['state_dict']) 248 | print("load model {}".format(os.path.join(model_path_root,f'model_model_20221129-130512.pt'))) 249 | 250 | rigs = torch.randn((8, n_rig)) 251 | rigs = torch.clip(rigs, min=0, max=1) 252 | outputs = model(rigs) 253 | print(outputs) 254 | 255 | -------------------------------------------------------------------------------- /mae-main/main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | 27 | assert timm.__version__ == "0.3.2" # version check 28 | import timm.optim.optim_factory as optim_factory 29 | 30 | import util.misc as misc 31 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 32 | 33 | import models_mae 34 | 35 | from engine_pretrain import train_one_epoch 36 | 37 | 38 | def get_args_parser(): 39 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 40 | parser.add_argument('--batch_size', default=64, type=int, 41 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 42 | parser.add_argument('--epochs', default=400, type=int) 43 | parser.add_argument('--accum_iter', default=1, type=int, 44 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 45 | 46 | # Model parameters 47 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 48 | help='Name of model to train') 49 | 50 | parser.add_argument('--input_size', default=224, type=int, 51 | help='images input size') 52 | 53 | parser.add_argument('--mask_ratio', default=0.75, type=float, 54 | help='Masking ratio (percentage of removed patches).') 55 | 56 | parser.add_argument('--norm_pix_loss', action='store_true', 57 | help='Use (per-patch) normalized pixels as targets for computing loss') 58 | parser.set_defaults(norm_pix_loss=False) 59 | 60 | # Optimizer parameters 61 | parser.add_argument('--weight_decay', type=float, default=0.05, 62 | help='weight decay (default: 0.05)') 63 | 64 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 65 | help='learning rate (absolute lr)') 66 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 67 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 68 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 69 | help='lower lr bound for cyclic schedulers that hit 0') 70 | 71 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 72 | help='epochs to warmup LR') 73 | 74 | # Dataset parameters 75 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 76 | help='dataset path') 77 | 78 | parser.add_argument('--output_dir', default='./output_dir', 79 | help='path where to save, empty for no saving') 80 | parser.add_argument('--log_dir', default='./output_dir', 81 | help='path where to tensorboard log') 82 | parser.add_argument('--device', default='cuda', 83 | help='device to use for training / testing') 84 | parser.add_argument('--seed', default=0, type=int) 85 | parser.add_argument('--resume', default='', 86 | help='resume from checkpoint') 87 | 88 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 89 | help='start epoch') 90 | parser.add_argument('--num_workers', default=10, type=int) 91 | parser.add_argument('--pin_mem', 
action='store_true', 92 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 93 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 94 | parser.set_defaults(pin_mem=True) 95 | 96 | # distributed training parameters 97 | parser.add_argument('--world_size', default=1, type=int, 98 | help='number of distributed processes') 99 | parser.add_argument('--local_rank', default=-1, type=int) 100 | parser.add_argument('--dist_on_itp', action='store_true') 101 | parser.add_argument('--dist_url', default='env://', 102 | help='url used to set up distributed training') 103 | 104 | return parser 105 | 106 | 107 | def main(args): 108 | misc.init_distributed_mode(args) 109 | 110 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 111 | print("{}".format(args).replace(', ', ',\n')) 112 | 113 | device = torch.device(args.device) 114 | 115 | # fix the seed for reproducibility 116 | seed = args.seed + misc.get_rank() 117 | torch.manual_seed(seed) 118 | np.random.seed(seed) 119 | 120 | cudnn.benchmark = True 121 | 122 | # simple augmentation 123 | transform_train = transforms.Compose([ 124 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 125 | transforms.RandomHorizontalFlip(), 126 | transforms.ToTensor(), 127 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 128 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 129 | print(dataset_train) 130 | 131 | if True: # args.distributed: 132 | num_tasks = misc.get_world_size() 133 | global_rank = misc.get_rank() 134 | sampler_train = torch.utils.data.DistributedSampler( 135 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 136 | ) 137 | print("Sampler_train = %s" % str(sampler_train)) 138 | else: 139 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 140 | 141 | if global_rank == 0 and args.log_dir is not None: 142 | os.makedirs(args.log_dir, exist_ok=True) 143 | log_writer = SummaryWriter(log_dir=args.log_dir) 144 | else: 145 | log_writer = None 146 | 147 | data_loader_train = torch.utils.data.DataLoader( 148 | dataset_train, sampler=sampler_train, 149 | batch_size=args.batch_size, 150 | num_workers=args.num_workers, 151 | pin_memory=args.pin_mem, 152 | drop_last=True, 153 | ) 154 | 155 | # define the model 156 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) 157 | 158 | model.to(device) 159 | 160 | model_without_ddp = model 161 | print("Model = %s" % str(model_without_ddp)) 162 | 163 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 164 | 165 | if args.lr is None: # only base_lr is specified 166 | args.lr = args.blr * eff_batch_size / 256 167 | 168 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 169 | print("actual lr: %.2e" % args.lr) 170 | 171 | print("accumulate grad iterations: %d" % args.accum_iter) 172 | print("effective batch size: %d" % eff_batch_size) 173 | 174 | if args.distributed: 175 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 176 | model_without_ddp = model.module 177 | 178 | # following timm: set wd as 0 for bias and norm layers 179 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 180 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 181 | print(optimizer) 182 | loss_scaler = NativeScaler() 183 | 184 | 
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 185 | 186 | print(f"Start training for {args.epochs} epochs") 187 | start_time = time.time() 188 | for epoch in range(args.start_epoch, args.epochs): 189 | if args.distributed: 190 | data_loader_train.sampler.set_epoch(epoch) 191 | train_stats = train_one_epoch( 192 | model, data_loader_train, 193 | optimizer, device, epoch, loss_scaler, 194 | log_writer=log_writer, 195 | args=args 196 | ) 197 | if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs): 198 | misc.save_model( 199 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 200 | loss_scaler=loss_scaler, epoch=epoch) 201 | 202 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 203 | 'epoch': epoch,} 204 | 205 | if args.output_dir and misc.is_main_process(): 206 | if log_writer is not None: 207 | log_writer.flush() 208 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 209 | f.write(json.dumps(log_stats) + "\n") 210 | 211 | total_time = time.time() - start_time 212 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 213 | print('Training time {}'.format(total_time_str)) 214 | 215 | 216 | if __name__ == '__main__': 217 | args = get_args_parser() 218 | args = args.parse_args() 219 | if args.output_dir: 220 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 221 | main(args) 222 | -------------------------------------------------------------------------------- /models/DCGAN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: aaronlai 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | try: 8 | from datasets.process_ctls_emb import load_ctr_from_txt 9 | from datasets.RigData import read_image,get_statis 10 | except: 11 | pass 12 | ''' 13 | DCGAN From: https://github.com/AaronYALai/Generative_Adversarial_Networks_PyTorch 14 | ''' 15 | 16 | class DCGAN_Discriminator(nn.Module): 17 | def __init__(self, featmap_dim=512, n_channel=1): 18 | super(DCGAN_Discriminator, self).__init__() 19 | self.featmap_dim = featmap_dim 20 | self.conv1 = nn.Conv2d(n_channel, int(featmap_dim / 4), 5, 21 | stride=4, padding=2) 22 | 23 | self.conv2 = nn.Conv2d(int(featmap_dim / 4), int(featmap_dim / 2), 5, 24 | stride=2, padding=2) 25 | self.BN2 = nn.BatchNorm2d(int(featmap_dim / 2)) 26 | 27 | self.conv3 = nn.Conv2d(int(featmap_dim / 2), featmap_dim, 5, 28 | stride=2, padding=2) 29 | self.BN3 = nn.BatchNorm2d(featmap_dim) 30 | 31 | self.conv4 = nn.Conv2d(featmap_dim, featmap_dim, 5, 32 | stride=2, padding=2) 33 | self.BN4 = nn.BatchNorm2d(featmap_dim) 34 | 35 | self.conv5 = nn.Conv2d(featmap_dim, featmap_dim, 5, 36 | stride=2, padding=2) 37 | self.BN5 = nn.BatchNorm2d(featmap_dim) 38 | 39 | self.fc = nn.Linear(featmap_dim * 4 * 4, 1) 40 | 41 | def forward(self, x): 42 | """ 43 | Strided convulation layers, 44 | Batch Normalization after convulation but not at input layer, 45 | LeakyReLU activation function with slope 0.2. 
46 | """ 47 | x = F.leaky_relu(self.conv1(x), negative_slope=0.2) 48 | x = F.leaky_relu(self.BN2(self.conv2(x)), negative_slope=0.2) 49 | x = F.leaky_relu(self.BN3(self.conv3(x)), negative_slope=0.2) 50 | x = F.leaky_relu(self.BN4(self.conv4(x)), negative_slope=0.2) 51 | x = F.leaky_relu(self.BN5(self.conv5(x)), negative_slope=0.2) 52 | x = x.view(-1, self.featmap_dim * 4 * 4) 53 | x = F.sigmoid(self.fc(x)) 54 | return x 55 | 56 | 57 | class DCGAN_Generator(nn.Module): 58 | 59 | def __init__(self, featmap_dim=1024, n_channel=1, noise_dim=100): 60 | super(DCGAN_Generator, self).__init__() 61 | self.featmap_dim = featmap_dim 62 | self.fc1 = nn.Linear(noise_dim, 4 * 4 * featmap_dim) 63 | self.conv1 = nn.ConvTranspose2d(featmap_dim, (featmap_dim / 2), 5, 64 | stride=2, padding=2) 65 | 66 | self.BN1 = nn.BatchNorm2d(featmap_dim / 2) 67 | self.conv2 = nn.ConvTranspose2d(featmap_dim / 2, featmap_dim / 4, 6, 68 | stride=2, padding=2) 69 | 70 | self.BN2 = nn.BatchNorm2d(featmap_dim / 4) 71 | self.conv3 = nn.ConvTranspose2d(featmap_dim / 4, n_channel, 6, 72 | stride=2, padding=2) 73 | 74 | def forward(self, x): 75 | """ 76 | Project noise to featureMap * width * height, 77 | Batch Normalization after convulation but not at output layer, 78 | ReLU activation function. 79 | """ 80 | x = self.fc1(x) 81 | x = x.view(-1, self.featmap_dim, 4, 4) 82 | x = F.relu(self.BN1(self.conv1(x))) 83 | x = F.relu(self.BN2(self.conv2(x))) 84 | x = F.tanh(self.conv3(x)) 85 | 86 | return x 87 | 88 | # Define the Generator Network 89 | class Generator(nn.Module): 90 | ''' 91 | From https://github.com/Natsu6767/DCGAN-PyTorch/blob/master/dcgan.py 92 | ''' 93 | def __init__(self, params, activation='sigmoid', convert_norm=False): 94 | super().__init__() 95 | 96 | # Input is the latent vector Z. 
97 | # self.fc1 = nn.Linear(139, 64*8) 98 | # self.fc2 = nn.Linear(64*8, 64*4) 99 | # TODO params['nz'] / 64*4 100 | self.params = params 101 | self.activation = activation 102 | self.tconv1 = nn.ConvTranspose2d(params['nz'], params['ngf']*8, 103 | kernel_size=4, stride=1, padding=0, bias=False) 104 | self.bn1 = nn.BatchNorm2d(params['ngf']*8) 105 | # Input Dimension: (ngf*8) x 4 x 4 106 | self.tconv2 = nn.ConvTranspose2d(params['ngf']*8, params['ngf']*4, 107 | 4, 2, 1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(params['ngf']*4) 109 | 110 | # Input Dimension: (ngf*4) x 8 x 8 111 | self.tconv3 = nn.ConvTranspose2d(params['ngf']*4, params['ngf']*2, 112 | 4, 2, 1, bias=False) 113 | self.bn3 = nn.BatchNorm2d(params['ngf']*2) 114 | 115 | # Input Dimension: (ngf*2) x 16 x 16 116 | self.tconv4 = nn.ConvTranspose2d(params['ngf']*2, params['ngf'], 117 | 4, 2, 1, bias=False) 118 | self.bn4 = nn.BatchNorm2d(params['ngf']) 119 | 120 | # Input Dimension: (ngf) * 32 * 32 121 | self.tconv5 = nn.ConvTranspose2d(params['ngf'], params['ngf'], 122 | 4, 2, 1, bias=False) 123 | self.bn5 = nn.BatchNorm2d(params['ngf']) 124 | 125 | 126 | self.tconv6 = nn.ConvTranspose2d(params['ngf'], params['ngf'], 127 | 4, 2, 1, bias=False) 128 | self.bn6 = nn.BatchNorm2d(params['ngf']) 129 | 130 | # self.tconv7 = nn.ConvTranspose2d(params['ngf'], params['ngf'], 131 | # 4, 2, 1, bias=False) 132 | # self.bn7 = nn.BatchNorm2d(params['ngf']) 133 | 134 | 135 | self.tconv8 = nn.ConvTranspose2d(params['ngf'], params['nc'], 136 | 4, 2, 1, bias=False) 137 | self.upsample = nn.Upsample(scale_factor=2) 138 | #Output Dimension: (nc) x 64 x 64 139 | # self.x_neu, self.out_nue = self.get_neutral() 140 | # self.mean, self.std = np.load('./results/data_statis.npy') 141 | # self.mean = torch.tensor(self.mean).cuda().reshape(1,-1,1,1).float() 142 | # self.std = torch.tensor(self.std).cuda().reshape(1,-1,1,1).float() 143 | 144 | # self.convert_norm = convert_norm 145 | 146 | def get_neutral(self): 147 | img_path = '/project/ard/3DFacialExpression/Images/ZHEN_v3_all/Images/sample_3591_.0000.jpg' 148 | rig_path = '/project/ard/3DFacialExpression/Ctrls/ZHEN/sample_3591_CtrlRigs.txt' 149 | mean, std = get_statis() 150 | img = read_image(img_path, mode='rgb', size=256) 151 | rigs,_,_,_,_ = load_ctr_from_txt(rig_path, do_flip=False) 152 | img = np.array(img)/255. 
153 | rigs = (rigs - mean) / std 154 | return torch.tensor(rigs).cuda().float().reshape(1,139,1,1), torch.tensor(img.transpose(2,0,1)).cuda().float().unsqueeze(0) 155 | 156 | def forward(self, x): 157 | # x = F.dropout(self.fc1(x.view(-1,139))) 158 | # x = F.dropout(self.fc2(x)).view(-1,64*4,1,1) 159 | # x = x - self.x_neu 160 | # 161 | # if self.convert_norm: 162 | # x = ((x * 2 - 1) - self.mean)/self.std 163 | 164 | x1 = F.relu(self.bn1(self.tconv1(x))) 165 | x2 = F.relu(self.bn2(self.tconv2(x1))) 166 | x3 = F.relu(self.bn3(self.tconv3(x2))) 167 | x4 = F.relu(self.bn4(self.tconv4(x3))) 168 | x5 = F.relu(self.bn5(self.tconv5(x4))) 169 | out = F.relu(self.bn6(self.tconv6(x5))) 170 | # x6 = self.upsample(x5) + x6 171 | # x = F.relu(self.bn7(self.tconv7(x))) 172 | if self.params['nc'] !=3: 173 | out = x4 174 | 175 | if self.activation == 'tanh': 176 | x7 = F.tanh(self.tconv8(out)) 177 | elif self.activation =='sigmoid': 178 | x7 = F.sigmoid(self.tconv8(out)) 179 | else: 180 | raise NotImplementedError 181 | # x7 = F.sigmoid(self.tconv8(x6) - self.out_nue) 182 | return x7 183 | 184 | 185 | class Generator_rectangle(nn.Module): 186 | ''' 187 | From https://github.com/Natsu6767/DCGAN-PyTorch/blob/master/dcgan.py 188 | ''' 189 | def __init__(self, params, activation='sigmoid', convert_norm=False): 190 | super().__init__() 191 | 192 | # Input is the latent vector Z. 193 | # self.fc1 = nn.Linear(139, 64*8) 194 | # self.fc2 = nn.Linear(64*8, 64*4) 195 | # TODO params['nz'] / 64*4 196 | self.params = params 197 | self.activation = activation 198 | self.tconv1 = nn.ConvTranspose2d(params['nz'], params['ngf']*8, 199 | kernel_size=4, stride=1, padding=0, bias=False) 200 | self.bn1 = nn.BatchNorm2d(params['ngf']*8) 201 | # Input Dimension: (ngf*8) x 4 x 4 202 | self.tconv2 = nn.ConvTranspose2d(params['ngf']*8, params['ngf']*4, 203 | 4, 2, 1, bias=False) 204 | self.bn2 = nn.BatchNorm2d(params['ngf']*4) 205 | 206 | # Input Dimension: (ngf*4) x 8 x 8 207 | self.tconv3 = nn.ConvTranspose2d(params['ngf']*4, params['ngf']*2, 208 | 4, 2, 1, bias=False) 209 | self.bn3 = nn.BatchNorm2d(params['ngf']*2) 210 | 211 | # Input Dimension: (ngf*2) x 16 x 16 212 | self.tconv4 = nn.ConvTranspose2d(params['ngf']*2, params['ngf'], 213 | 4, 2, 1, bias=False) 214 | self.bn4 = nn.BatchNorm2d(params['ngf']) 215 | 216 | # Input Dimension: (ngf) * 32 * 32 217 | self.tconv5 = nn.ConvTranspose2d(params['ngf'], params['ngf'], 218 | 4, 2, 1, bias=False) 219 | self.bn5 = nn.BatchNorm2d(params['ngf']) 220 | 221 | 222 | self.tconv8 = nn.ConvTranspose2d(params['ngf'], params['nc'], 223 | 4, 2, 1, bias=False) 224 | self.pooling = nn.MaxPool2d((4,1),(4,1)) 225 | 226 | def forward(self, x): 227 | x1 = F.relu(self.bn1(self.tconv1(x))) 228 | x2 = F.relu(self.bn2(self.tconv2(x1))) 229 | x3 = F.relu(self.bn3(self.tconv3(x2))) 230 | x4 = F.relu(self.bn4(self.tconv4(x3))) 231 | x5 = F.relu(self.bn5(self.tconv5(x4))) 232 | 233 | if self.params['nc'] !=3: 234 | out = x4 235 | 236 | out = self.pooling(x5) 237 | if self.activation == 'tanh': 238 | x7 = F.tanh(self.tconv8(out)) 239 | elif self.activation =='sigmoid': 240 | x7 = F.sigmoid(self.tconv8(out)) 241 | else: 242 | raise NotImplementedError 243 | return x7 244 | 245 | 246 | if __name__ == '__main__': 247 | params = {'nz':4, 'ngf':64, 'nc':3} 248 | model = Generator_rectangle(params) 249 | inputs = torch.randn((8, 4,1,1)) 250 | outputs = model(inputs) 251 | 252 | # inputs = torch.randn((8, 3,256,256)) 253 | # model = DCGAN_Discriminator(n_channel=3) 254 | # out = model(inputs) 255 | # print(1) 
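A minimal usage sketch of the Generator above, assuming a character with n_rig = 61 controls (the value is only illustrative; the real count comes from choose_character.character_choice). train_rig2img.py builds the renderer with params = {'nz': n_rig, 'ngf': 64*2, 'nc': 3}, and the first transposed convolution expects a [batch, nz, 1, 1] input, which the stack upsamples 1x1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256, so the sigmoid output is a [batch, 3, 256, 256] image in [0, 1].

import torch
from models.DCGAN import Generator

n_rig = 61                                       # illustrative control count (assumption)
params = {'nz': n_rig, 'ngf': 64 * 2, 'nc': 3}   # same keys used in train_rig2img.py
netG = Generator(params)                         # default activation='sigmoid'

rigs = torch.rand(8, n_rig, 1, 1)                # rig controls, assumed to lie in [0, 1]
imgs = netG(rigs)                                # -> torch.Size([8, 3, 256, 256]), values in [0, 1]
print(imgs.shape)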
-------------------------------------------------------------------------------- /mae-main/models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # MAE encoder specifics 33 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 38 | 39 | self.blocks = nn.ModuleList([ 40 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 41 | for i in range(depth)]) 42 | self.norm = norm_layer(embed_dim) 43 | # -------------------------------------------------------------------------- 44 | 45 | # -------------------------------------------------------------------------- 46 | # MAE decoder specifics 47 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 48 | 49 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 50 | 51 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 52 | 53 | self.decoder_blocks = nn.ModuleList([ 54 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 55 | for i in range(decoder_depth)]) 56 | 57 | self.decoder_norm = norm_layer(decoder_embed_dim) 58 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 59 | # -------------------------------------------------------------------------- 60 | 61 | self.norm_pix_loss = norm_pix_loss 62 | 63 | self.initialize_weights() 64 | 65 | def initialize_weights(self): 66 | # initialization 67 | # initialize (and freeze) pos_embed by sin-cos embedding 68 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 69 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 70 | 71 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 72 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 73 | 74 | # 
initialize patch_embed like nn.Linear (instead of nn.Conv2d) 75 | w = self.patch_embed.proj.weight.data 76 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 77 | 78 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 79 | torch.nn.init.normal_(self.cls_token, std=.02) 80 | torch.nn.init.normal_(self.mask_token, std=.02) 81 | 82 | # initialize nn.Linear and nn.LayerNorm 83 | self.apply(self._init_weights) 84 | 85 | def _init_weights(self, m): 86 | if isinstance(m, nn.Linear): 87 | # we use xavier_uniform following official JAX ViT: 88 | torch.nn.init.xavier_uniform_(m.weight) 89 | if isinstance(m, nn.Linear) and m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.LayerNorm): 92 | nn.init.constant_(m.bias, 0) 93 | nn.init.constant_(m.weight, 1.0) 94 | 95 | def patchify(self, imgs): 96 | """ 97 | imgs: (N, 3, H, W) 98 | x: (N, L, patch_size**2 *3) 99 | """ 100 | p = self.patch_embed.patch_size[0] 101 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 102 | 103 | h = w = imgs.shape[2] // p 104 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 105 | x = torch.einsum('nchpwq->nhwpqc', x) 106 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 107 | return x 108 | 109 | def unpatchify(self, x): 110 | """ 111 | x: (N, L, patch_size**2 *3) 112 | imgs: (N, 3, H, W) 113 | """ 114 | p = self.patch_embed.patch_size[0] 115 | h = w = int(x.shape[1]**.5) 116 | assert h * w == x.shape[1] 117 | 118 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 119 | x = torch.einsum('nhwpqc->nchpwq', x) 120 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 121 | return imgs 122 | 123 | def random_masking(self, x, mask_ratio): 124 | """ 125 | Perform per-sample random masking by per-sample shuffling. 126 | Per-sample shuffling is done by argsort random noise. 
127 | x: [N, L, D], sequence 128 | """ 129 | N, L, D = x.shape # batch, length, dim 130 | len_keep = int(L * (1 - mask_ratio)) 131 | 132 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 133 | 134 | # sort noise for each sample 135 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 136 | ids_restore = torch.argsort(ids_shuffle, dim=1) 137 | 138 | # keep the first subset 139 | ids_keep = ids_shuffle[:, :len_keep] 140 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 141 | 142 | # generate the binary mask: 0 is keep, 1 is remove 143 | mask = torch.ones([N, L], device=x.device) 144 | mask[:, :len_keep] = 0 145 | # unshuffle to get the binary mask 146 | mask = torch.gather(mask, dim=1, index=ids_restore) 147 | 148 | return x_masked, mask, ids_restore 149 | 150 | def forward_encoder(self, x, mask_ratio): 151 | # embed patches 152 | x = self.patch_embed(x) 153 | 154 | # add pos embed w/o cls token 155 | x = x + self.pos_embed[:, 1:, :] 156 | 157 | # masking: length -> length * mask_ratio 158 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 159 | 160 | # append cls token 161 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 162 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | x = self.norm(x) 169 | 170 | return x, mask, ids_restore 171 | 172 | def forward_decoder(self, x, ids_restore): 173 | # embed tokens 174 | x = self.decoder_embed(x) 175 | 176 | # append mask tokens to sequence 177 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 178 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 179 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 180 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 181 | 182 | # add pos embed 183 | x = x + self.decoder_pos_embed 184 | 185 | # apply Transformer blocks 186 | for blk in self.decoder_blocks: 187 | x = blk(x) 188 | x = self.decoder_norm(x) 189 | 190 | # predictor projection 191 | x = self.decoder_pred(x) 192 | 193 | # remove cls token 194 | x = x[:, 1:, :] 195 | 196 | return x 197 | 198 | def forward_loss(self, imgs, pred, mask): 199 | """ 200 | imgs: [N, 3, H, W] 201 | pred: [N, L, p*p*3] 202 | mask: [N, L], 0 is keep, 1 is remove, 203 | """ 204 | target = self.patchify(imgs) 205 | if self.norm_pix_loss: 206 | mean = target.mean(dim=-1, keepdim=True) 207 | var = target.var(dim=-1, keepdim=True) 208 | target = (target - mean) / (var + 1.e-6)**.5 209 | 210 | loss = (pred - target) ** 2 211 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 212 | 213 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 214 | return loss 215 | 216 | def forward(self, imgs, mask_ratio=0.75): 217 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 218 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 219 | loss = self.forward_loss(imgs, pred, mask) 220 | return loss, pred, mask 221 | 222 | 223 | def mae_vit_base_patch16_dec512d8b(**kwargs): 224 | model = MaskedAutoencoderViT( 225 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 226 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 227 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 228 | return model 229 | 230 | 231 | def mae_vit_large_patch16_dec512d8b(**kwargs): 232 | model = 
MaskedAutoencoderViT( 233 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 234 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 235 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 236 | return model 237 | 238 | 239 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 240 | model = MaskedAutoencoderViT( 241 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 242 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 243 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 244 | return model 245 | 246 | 247 | # set recommended archs 248 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 249 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 250 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 251 | -------------------------------------------------------------------------------- /models/Emoca_ExprNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Soubhik Sanyal 3 | Copyright (c) 2019, Soubhik Sanyal 4 | All rights reserved. 5 | Loads different resnet models 6 | """ 7 | ''' 8 | file: Resnet.py 9 | date: 2018_05_02 10 | author: zhangxiong(1025679612@qq.com) 11 | mark: copied from pytorch source code 12 | ''' 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch 17 | from torch.nn.parameter import Parameter 18 | import torch.optim as optim 19 | import numpy as np 20 | import math 21 | import torchvision 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layers, num_classes=1000): 25 | self.inplanes = 64 26 | super(ResNet, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | self.layer1 = self._make_layer(block, 64, layers[0]) 33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 36 | self.avgpool = nn.AvgPool2d(7, stride=1) 37 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample)) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | 70 | x = self.layer1(x) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x1 = self.layer4(x) 74 | 75 | x2 = self.avgpool(x1) 76 | x2 = x2.view(x2.size(0), -1) 77 | # x = self.fc(x) 78 | ## x2: [bz, 2048] for shape 79 | ## x1: [bz, 2048, 7, 7] for texture 80 | return x2 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 93 | self.bn3 = nn.BatchNorm2d(planes * 4) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | def conv3x3(in_planes, out_planes, stride=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 123 | padding=1, bias=False) 124 | 125 | class BasicBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BasicBlock, self).__init__() 130 | self.conv1 = conv3x3(inplanes, planes, stride) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.conv2 = conv3x3(planes, planes) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | def copy_parameter_from_resnet(model, resnet_dict): 157 | cur_state_dict = model.state_dict() 158 | # import ipdb; ipdb.set_trace() 159 | for name, param in list(resnet_dict.items())[0:None]: 160 | if name not in cur_state_dict: 161 | # print(name, ' not available in reconstructed resnet') 162 | continue 163 | if isinstance(param, 
Parameter): 164 | param = param.data 165 | try: 166 | cur_state_dict[name].copy_(param) 167 | except: 168 | # print(name, ' is inconsistent!') 169 | continue 170 | # print('copy resnet state dict finished!') 171 | # import ipdb; ipdb.set_trace() 172 | 173 | 174 | def load_ResNet50Model(): 175 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 176 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = True).state_dict()) 177 | return model 178 | 179 | def load_ResNet101Model(): 180 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 181 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict()) 182 | return model 183 | 184 | def load_ResNet152Model(): 185 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 186 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict()) 187 | return model 188 | 189 | # model.load_state_dict(checkpoint['model_state_dict']) 190 | 191 | 192 | ######## Unet 193 | 194 | class DoubleConv(nn.Module): 195 | """(convolution => [BN] => ReLU) * 2""" 196 | 197 | def __init__(self, in_channels, out_channels): 198 | super().__init__() 199 | self.double_conv = nn.Sequential( 200 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 201 | nn.BatchNorm2d(out_channels), 202 | nn.ReLU(inplace=True), 203 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 204 | nn.BatchNorm2d(out_channels), 205 | nn.ReLU(inplace=True) 206 | ) 207 | 208 | def forward(self, x): 209 | return self.double_conv(x) 210 | 211 | 212 | class Down(nn.Module): 213 | """Downscaling with maxpool then double conv""" 214 | 215 | def __init__(self, in_channels, out_channels): 216 | super().__init__() 217 | self.maxpool_conv = nn.Sequential( 218 | nn.MaxPool2d(2), 219 | DoubleConv(in_channels, out_channels) 220 | ) 221 | 222 | def forward(self, x): 223 | return self.maxpool_conv(x) 224 | 225 | 226 | class Up(nn.Module): 227 | """Upscaling then double conv""" 228 | 229 | def __init__(self, in_channels, out_channels, bilinear=True): 230 | super().__init__() 231 | 232 | # if bilinear, use the normal convolutions to reduce the number of channels 233 | if bilinear: 234 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 235 | else: 236 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 237 | 238 | self.conv = DoubleConv(in_channels, out_channels) 239 | 240 | def forward(self, x1, x2): 241 | x1 = self.up(x1) 242 | # input is CHW 243 | diffY = x2.size()[2] - x1.size()[2] 244 | diffX = x2.size()[3] - x1.size()[3] 245 | 246 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 247 | diffY // 2, diffY - diffY // 2]) 248 | # if you have padding issues, see 249 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 250 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 251 | x = torch.cat([x2, x1], dim=1) 252 | return self.conv(x) 253 | 254 | 255 | class OutConv(nn.Module): 256 | def __init__(self, in_channels, out_channels): 257 | super(OutConv, self).__init__() 258 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 259 | 260 | def forward(self, x): 261 | return self.conv(x) 262 | 263 | class ExpressionLossNet(nn.Module): 264 | """ Code borrowed from EMOCA https://github.com/radekd91/emoca """ 265 | def __init__(self): 266 | super(ExpressionLossNet, self).__init__() 267 | path_ckpt = 
"/data/Workspace/emoca/gdl_apps/EmotionRecognition/checkpoints/ResNet50/checkpoints/deca-epoch=01-val_loss_total/dataloader_idx_0=1.27607644.ckpt" 268 | self.backbone = load_ResNet50Model().eval() #out: 2048 269 | ckpt = torch.load(path_ckpt)['state_dict'] 270 | ckpt ={key.replace('backbone.', ''):ckpt[key] for key in ckpt} 271 | 272 | self.backbone.load_state_dict(ckpt, strict=False) 273 | self.linear = nn.Sequential( 274 | nn.Linear(2048, 10)) 275 | 276 | def forward(self, inputs): 277 | with torch.no_grad(): 278 | features = self.backbone(inputs) 279 | out = self.linear(features) 280 | return features, out 281 | 282 | def forward2(self, inputs): 283 | with torch.no_grad(): 284 | features = self.backbone(inputs) 285 | return features 286 | 287 | if __name__ == '__main__': 288 | inputs_d = torch.zeros((2,3,224,224)) 289 | # backbone = load_ResNet50Model() 290 | backbone = ExpressionLossNet() 291 | torch.load("/data/Workspace/emoca/gdl_apps/EmotionRecognition/checkpoints/ResNet50/checkpoints/deca-epoch=01-val_loss_total/dataloader_idx_0=1.27607644.ckpt")['state_dict'] -------------------------------------------------------------------------------- /mae-main/util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | 
header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return 
norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | print("Resume checkpoint %s" % args.resume) 324 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 325 | optimizer.load_state_dict(checkpoint['optimizer']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if 'scaler' in checkpoint: 328 | loss_scaler.load_state_dict(checkpoint['scaler']) 329 | print("With optim & sched!") 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | return x_reduce.item() 339 | else: 340 | return x -------------------------------------------------------------------------------- /models/facenet2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.hub import download_url_to_file # used by load_weights() below to fetch pretrained weights 6 | 7 | 8 | class BasicConv2d(nn.Module): 9 | 10 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 11 | super().__init__() 12 | self.conv = nn.Conv2d( 13 | in_planes, out_planes, 14 | kernel_size=kernel_size, stride=stride, 15 | padding=padding, bias=False 16 | ) # verify bias false 17 | self.bn = nn.BatchNorm2d( 18 | out_planes, 19 | eps=0.001, # value found in tensorflow 20 | momentum=0.1, # default pytorch value 21 | affine=True 22 | ) 23 | self.relu = nn.ReLU(inplace=False) 24 | 25 | def forward(self, x): 26 | x = self.conv(x) 27 | x = self.bn(x) 28 | x = self.relu(x) 29 | return x 30 | 31 | 32 | class Block35(nn.Module): 33 |
34 | def __init__(self, scale=1.0): 35 | super().__init__() 36 | 37 | self.scale = scale 38 | 39 | self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1) 40 | 41 | self.branch1 = nn.Sequential( 42 | BasicConv2d(256, 32, kernel_size=1, stride=1), 43 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 44 | ) 45 | 46 | self.branch2 = nn.Sequential( 47 | BasicConv2d(256, 32, kernel_size=1, stride=1), 48 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1), 49 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 50 | ) 51 | 52 | self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1) 53 | self.relu = nn.ReLU(inplace=False) 54 | 55 | def forward(self, x): 56 | x0 = self.branch0(x) 57 | x1 = self.branch1(x) 58 | x2 = self.branch2(x) 59 | out = torch.cat((x0, x1, x2), 1) 60 | out = self.conv2d(out) 61 | out = out * self.scale + x 62 | out = self.relu(out) 63 | return out 64 | 65 | 66 | class Block17(nn.Module): 67 | 68 | def __init__(self, scale=1.0): 69 | super().__init__() 70 | 71 | self.scale = scale 72 | 73 | self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1) 74 | 75 | self.branch1 = nn.Sequential( 76 | BasicConv2d(896, 128, kernel_size=1, stride=1), 77 | BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)), 78 | BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0)) 79 | ) 80 | 81 | self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1) 82 | self.relu = nn.ReLU(inplace=False) 83 | 84 | def forward(self, x): 85 | x0 = self.branch0(x) 86 | x1 = self.branch1(x) 87 | out = torch.cat((x0, x1), 1) 88 | out = self.conv2d(out) 89 | out = out * self.scale + x 90 | out = self.relu(out) 91 | return out 92 | 93 | 94 | class Block8(nn.Module): 95 | 96 | def __init__(self, scale=1.0, noReLU=False): 97 | super().__init__() 98 | 99 | self.scale = scale 100 | self.noReLU = noReLU 101 | 102 | self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1) 103 | 104 | self.branch1 = nn.Sequential( 105 | BasicConv2d(1792, 192, kernel_size=1, stride=1), 106 | BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)), 107 | BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0)) 108 | ) 109 | 110 | self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1) 111 | if not self.noReLU: 112 | self.relu = nn.ReLU(inplace=False) 113 | 114 | def forward(self, x): 115 | x0 = self.branch0(x) 116 | x1 = self.branch1(x) 117 | out = torch.cat((x0, x1), 1) 118 | out = self.conv2d(out) 119 | out = out * self.scale + x 120 | if not self.noReLU: 121 | out = self.relu(out) 122 | return out 123 | 124 | 125 | class Mixed_6a(nn.Module): 126 | 127 | def __init__(self): 128 | super().__init__() 129 | 130 | self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2) 131 | 132 | self.branch1 = nn.Sequential( 133 | BasicConv2d(256, 192, kernel_size=1, stride=1), 134 | BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1), 135 | BasicConv2d(192, 256, kernel_size=3, stride=2) 136 | ) 137 | 138 | self.branch2 = nn.MaxPool2d(3, stride=2) 139 | 140 | def forward(self, x): 141 | x0 = self.branch0(x) 142 | x1 = self.branch1(x) 143 | x2 = self.branch2(x) 144 | out = torch.cat((x0, x1, x2), 1) 145 | return out 146 | 147 | 148 | class Mixed_7a(nn.Module): 149 | 150 | def __init__(self): 151 | super().__init__() 152 | 153 | self.branch0 = nn.Sequential( 154 | BasicConv2d(896, 256, kernel_size=1, stride=1), 155 | BasicConv2d(256, 384, kernel_size=3, stride=2) 156 | ) 157 | 158 | self.branch1 = nn.Sequential( 159 | BasicConv2d(896, 256, kernel_size=1, stride=1), 160 
| BasicConv2d(256, 256, kernel_size=3, stride=2) 161 | ) 162 | 163 | self.branch2 = nn.Sequential( 164 | BasicConv2d(896, 256, kernel_size=1, stride=1), 165 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 166 | BasicConv2d(256, 256, kernel_size=3, stride=2) 167 | ) 168 | 169 | self.branch3 = nn.MaxPool2d(3, stride=2) 170 | 171 | def forward(self, x): 172 | x0 = self.branch0(x) 173 | x1 = self.branch1(x) 174 | x2 = self.branch2(x) 175 | x3 = self.branch3(x) 176 | out = torch.cat((x0, x1, x2, x3), 1) 177 | return out 178 | 179 | 180 | class InceptionResnetV1(nn.Module): 181 | """Inception Resnet V1 model with optional loading of pretrained weights. 182 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface 183 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if 184 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than 185 | redownloading. 186 | Keyword Arguments: 187 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'. 188 | (default: {None}) 189 | classify {bool} -- Whether the model should output classification probabilities or feature 190 | embeddings. (default: {False}) 191 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not 192 | equal to that used for the pretrained model, the final linear layer will be randomly 193 | initialized. (default: {None}) 194 | dropout_prob {float} -- Dropout probability. (default: {0.6}) 195 | """ 196 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None,is_R = False): 197 | super().__init__() 198 | 199 | # Set simple attributes 200 | self.pretrained = pretrained 201 | self.classify = classify 202 | self.num_classes = num_classes 203 | 204 | if pretrained == 'vggface2': 205 | tmp_classes = 8631 206 | elif pretrained == 'casia-webface': 207 | tmp_classes = 10575 208 | elif pretrained is None and self.classify and self.num_classes is None: 209 | raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified') 210 | 211 | 212 | # Define layers 213 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 214 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 215 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 216 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 217 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 218 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 219 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2) 220 | self.repeat_1 = nn.Sequential( 221 | Block35(scale=0.17), 222 | Block35(scale=0.17), 223 | Block35(scale=0.17), 224 | Block35(scale=0.17), 225 | Block35(scale=0.17), 226 | ) 227 | self.mixed_6a = Mixed_6a() 228 | self.repeat_2 = nn.Sequential( 229 | Block17(scale=0.10), 230 | Block17(scale=0.10), 231 | Block17(scale=0.10), 232 | Block17(scale=0.10), 233 | Block17(scale=0.10), 234 | Block17(scale=0.10), 235 | Block17(scale=0.10), 236 | Block17(scale=0.10), 237 | Block17(scale=0.10), 238 | Block17(scale=0.10), 239 | ) 240 | self.mixed_7a = Mixed_7a() 241 | self.repeat_3 = nn.Sequential( 242 | Block8(scale=0.20), 243 | Block8(scale=0.20), 244 | Block8(scale=0.20), 245 | Block8(scale=0.20), 246 | Block8(scale=0.20), 247 | ) 248 | self.block8 = Block8(noReLU=True) 249 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1) 250 | self.dropout = nn.Dropout(dropout_prob) 251 | 
self.last_linear = nn.Linear(1792, 512, bias=False) 252 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) 253 | self.grads = [] 254 | self.grads2 = [] 255 | self.is_R = is_R 256 | self.save_grad_flag = False 257 | if pretrained is not None: 258 | self.logits = nn.Linear(512, tmp_classes) 259 | load_weights(self, pretrained) 260 | 261 | if self.classify and self.num_classes is not None: 262 | self.logits = nn.Linear(512, self.num_classes) 263 | self.device = torch.device('cuda:0') 264 | if device is not None: 265 | self.device = device 266 | self.to(device) 267 | 268 | def forward(self, x, return_skip = False): 269 | """Calculate embeddings or logits given a batch of input image tensors. 270 | Arguments: 271 | x {torch.tensor} -- Batch of image tensors representing faces. 272 | Returns: 273 | torch.tensor -- Batch of embedding vectors or multinomial logits. 274 | """ 275 | x = self.conv2d_1a(x) 276 | x = self.conv2d_2a(x) 277 | # if self.save_grad_flag == True and len(self.grads)<300: 278 | # x.register_hook(lambda grad: self.grads.append(grad.view(grad.shape[0],grad.shape[1],-1).mean(-1).mean(0).detach().cpu().numpy())) 279 | # if self.save_grad_flag == True and len(self.grads)>0: 280 | # print(self.grads[0].shape) 281 | x = self.conv2d_2b(x) 282 | s1 = x 283 | x = self.maxpool_3a(x) 284 | x = self.conv2d_3b(x) 285 | s2 = x 286 | x = self.conv2d_4a(x) 287 | x = self.conv2d_4b(x) 288 | x = self.repeat_1(x) 289 | s3 = x 290 | 291 | x = self.mixed_6a(x) 292 | x = self.repeat_2(x) 293 | s4 = x 294 | x = self.mixed_7a(x) 295 | x = self.repeat_3(x) 296 | x = self.block8(x) 297 | if self.save_grad_flag == True and len(self.grads)<300: 298 | x.register_hook(lambda grad: self.grads2.append(grad.view(grad.shape[0],grad.shape[1],-1).mean(0).mean(-1).detach().cpu().numpy())) 299 | if self.save_grad_flag == True and len(self.grads)>0: 300 | print(self.grads2[0].shape) 301 | x = self.avgpool_1a(x) 302 | x = self.dropout(x) 303 | x = self.last_linear(x.view(x.shape[0], -1)) 304 | # if x.shape[0]>1: 305 | x = self.last_bn(x) 306 | if self.is_R == False: 307 | x= F.normalize(x, p=2, dim=1) 308 | return x 309 | 310 | 311 | def load_weights(mdl, name): 312 | """Download pretrained state_dict and load into model. 313 | Arguments: 314 | mdl {torch.nn.Module} -- Pytorch model. 315 | name {str} -- Name of dataset that was used to generate pretrained state_dict. 316 | Raises: 317 | ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'. 
318 | """ 319 | if name == 'vggface2': 320 | path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt' 321 | elif name == 'casia-webface': 322 | path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt' 323 | else: 324 | raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"') 325 | 326 | model_dir = os.path.join(get_torch_home(), 'checkpoints1') 327 | os.makedirs(model_dir, exist_ok=True) 328 | 329 | cached_file = os.path.join(model_dir, os.path.basename(path)) 330 | print(path) 331 | print(cached_file) 332 | if not os.path.exists(cached_file): 333 | download_url_to_file(path, cached_file) 334 | 335 | state_dict = torch.load(cached_file) 336 | mdl.load_state_dict(state_dict) 337 | 338 | 339 | def get_torch_home(): 340 | torch_home = os.path.expanduser( 341 | os.getenv( 342 | 'TORCH_HOME', 343 | os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch') 344 | ) 345 | ) 346 | return torch_home 347 | 348 | # facenet = InceptionResnetV1(pretrained="vggface2").cuda() 349 | # a=facenet.parameters() 350 | # load_weights(facenet,'vggface2') 351 | # #x = torch.rand([16,3,224,224]).cuda() 352 | # import cv2 353 | # from PIL import Image 354 | # from torchvision import transforms 355 | # import numpy as np 356 | # f = open("../face.txt",'w') 357 | # im = cv2.imread("../15660_1.jpg") 358 | # im = Image.fromarray(im,mode='RGB') 359 | # x = transforms.ToTensor()(im).unsqueeze(0).cuda() 360 | # res = facenet(x).cpu().detach().numpy() 361 | # mean = np.mean(res) 362 | # print(mean) 363 | # res = res-mean 364 | # c = np.dot(res.T,res) 365 | # for i in range(512): 366 | # for j in range(512): 367 | # f.write(str(c[i][j])+",") 368 | # f.write("\n") 369 | # f.close() 370 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | # import utils.util as util 6 | import torch.nn.utils.spectral_norm as spectral_norm 7 | import argparse 8 | 9 | def get_norm_layer(norm_type='instance'): 10 | # helper function to get # output channels of the previous layer 11 | def get_out_channel(layer): 12 | if hasattr(layer, 'out_channels'): 13 | return getattr(layer, 'out_channels') 14 | return layer.weight.size(0) 15 | 16 | # this function will be returned 17 | def add_norm_layer(layer): 18 | nonlocal norm_type 19 | if norm_type.startswith('spectral'): 20 | layer = spectral_norm(layer) 21 | subnorm_type = norm_type[len('spectral'):] 22 | else: 23 | subnorm_type = norm_type 24 | 25 | if subnorm_type == 'none' or len(subnorm_type) == 0: 26 | return layer 27 | 28 | # remove bias in the previous layer, which is meaningless 29 | # since it has no effect after normalization 30 | if getattr(layer, 'bias', None) is not None: 31 | delattr(layer, 'bias') 32 | layer.register_parameter('bias', None) 33 | 34 | if subnorm_type == 'batch': 35 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 36 | # elif subnorm_type == 'sync_batch': 37 | # norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 38 | elif subnorm_type == 'instance': 39 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 40 | else: 41 | raise ValueError( 42 | 'normalization layer %s is not recognized' % subnorm_type) 43 | 44 | return nn.Sequential(layer, norm_layer) 45 | 46 
| return add_norm_layer 47 | 48 | class MultiscaleDiscriminator(nn.Module): 49 | 50 | def __init__(self, opt): 51 | super().__init__() 52 | self.opt = opt 53 | 54 | for i in range(opt.num_D): 55 | subnetD = self.create_single_discriminator(opt) 56 | self.add_module('discriminator_%d' % i, subnetD) 57 | 58 | # self.init_weights(opt.init_type, opt.init_variance) 59 | 60 | def create_single_discriminator(self, opt): 61 | netD = NLayerDiscriminator(opt) 62 | return netD 63 | 64 | def downsample(self, input): 65 | return F.avg_pool2d(input, kernel_size=3, 66 | stride=2, padding=[1, 1], 67 | count_include_pad=False) 68 | 69 | # Returns list of lists of discriminator outputs. 70 | # The final result is of size opt.num_D x opt.n_layers_D 71 | def forward(self, input): 72 | result = [] 73 | get_intermediate_features = self.opt.fm_power > 0 74 | for name, D in self.named_children(): 75 | out = D(input) 76 | if not get_intermediate_features: 77 | out = [out] 78 | result.append(out) 79 | input = self.downsample(input) 80 | return result 81 | 82 | 83 | # Defines the PatchGAN discriminator with the specified arguments. 84 | class NLayerDiscriminator(nn.Module): 85 | 86 | def __init__(self, opt): 87 | super().__init__() 88 | self.opt = opt 89 | 90 | kw = 4 91 | padw = int(np.ceil((kw - 1.0) / 2)) 92 | nf = opt.ndf 93 | input_nc = opt.input_channel 94 | 95 | norm_layer = get_norm_layer(opt.norm_D) 96 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 97 | nn.LeakyReLU(0.2, False)]] 98 | 99 | for n in range(1, opt.n_layers_D): 100 | nf_prev = nf 101 | nf = min(nf * 2, opt.max_conv_dim) 102 | stride = 1 if n == opt.n_layers_D - 1 else 2 103 | sequence += [[ 104 | norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, 105 | stride=stride, padding=padw)), 106 | nn.LeakyReLU(0.2, False) 107 | ]] 108 | 109 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, 110 | stride=1, padding=padw)]] 111 | 112 | # We divide the layers into groups to extract intermediate layer outputs 113 | for n in range(len(sequence)): 114 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 115 | 116 | def forward(self, input): 117 | results = [input] 118 | for submodel in self.children(): 119 | intermediate_output = submodel(results[-1]) 120 | results.append(intermediate_output) 121 | get_intermediate_features = self.opt.fm_power > 0 122 | if get_intermediate_features: 123 | return results[1:] 124 | else: 125 | return results[-1] 126 | 127 | 128 | 129 | 130 | def get_parser(): 131 | parser = argparse.ArgumentParser(description='Options for training SwapYou') 132 | 133 | ###1. dataset 134 | parser.add_argument('--dataroot', type=str, default='data/ZHEN_cleaned_rig_train.pkl') 135 | parser.add_argument('--val_dataroot', type=str, default='data/ZHEN_cleaned_rig_train.pkl') 136 | ###2. 
model arch 137 | parser.add_argument('--input_channel', type=int, default=3, help='# of input image channels') 138 | parser.add_argument('--input_size', type=int, default=256, help='# of input image w&h') 139 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 140 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 141 | parser.add_argument('--mask_df', type=int, default=16, help='# of discrim filters in first conv layer') 142 | 143 | parser.add_argument('--data_view', type=str, default=None) 144 | parser.add_argument('--n_layers_D', type=int, default=4, help='num of layers of D') 145 | parser.add_argument('--num_D', type=int, default=3, help='# of discrim') 146 | parser.add_argument('--type_D', type=str, default='patch', help='type of discrim') 147 | parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization') 148 | parser.add_argument('--sn_G', action='store_true', help='spectral norm for generator') 149 | parser.add_argument('--sn_D', action='store_true', help='spectral norm for discriminator') 150 | parser.add_argument('--epoch', type=int, default=20, help='# total epochs') 151 | parser.add_argument('--decay_epoch', type=int, default=10, help='#epoch to decay learning rate') 152 | parser.add_argument('--decay_factor', type=float, default=0.9, help='#decay_factor to decay learning rate') 153 | parser.add_argument('--id_th', type=float, default=0.0) 154 | parser.add_argument('--id_power', type=float, default=-1.0) 155 | parser.add_argument('--vgg_power', type=float, default=0.0) 156 | parser.add_argument('--ssim_power', type=float, default=0.0) 157 | parser.add_argument('--mouth_power', type=float, default=0.0) 158 | parser.add_argument('--eyes_power', type=float, default=0.0) 159 | parser.add_argument('--sobel_power', type=float, default=0.0) 160 | parser.add_argument('--face_power', type=float, default=0.0) 161 | 162 | parser.add_argument('--id_ch', type=int, default=1024, help="dimension of the id") 163 | parser.add_argument('--content_ch', type=int, default=8, help="dimension of the content") 164 | parser.add_argument('--background_ch', type=int, default=1, help="dimension of the background") 165 | 166 | parser.add_argument('--bottleneck_nums', type=int, default=2, help='bottleneck_nums size') 167 | parser.add_argument('--max_conv_dim', type=int, default=512, help='max_conv_dim') 168 | parser.add_argument('--bottle_dim', type=int, default=16, help='bottle_dim') 169 | parser.add_argument('--exp_dim', type=int, default=16, help='exp_dim') 170 | 171 | parser.add_argument('--id_num', type=int, default=4, help='the number of id') 172 | 173 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 174 | parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution') 175 | 176 | parser.add_argument('--num_upsampling_layers', default=5,type=int, 177 | help="If 'more', adds upsampling layer between the two middle resnet blocks.
If 'most', also add one more upsampling + resnet layer at the end of the generator") 178 | 179 | parser.add_argument('--skip_module', action='store_true', help='using skip_module') 180 | parser.add_argument('--skip_module_list', nargs='*', type=int, help="skip_module_list for details") 181 | parser.add_argument('--no_adain_list', nargs='*', type=int, help="skip_module_list for details") 182 | parser.add_argument('--attention_type', type=str, default='spatial', help='(spatial|channel|both)') 183 | 184 | parser.add_argument('--new_D', action='store_true', help='using multi domain discriminator') 185 | parser.add_argument('--G_v2', action='store_true', help='using generator v2') 186 | 187 | 188 | parser.add_argument('--exp_injection', action='store_true', help='inject expression embedding') 189 | parser.add_argument('--residual_learning', action='store_true', help='use residual learning') 190 | parser.add_argument('--exp_reduce', action='store_true', help='reduce the expression embedding') 191 | parser.add_argument('--merge_norm', type=str, default='batchnorm', help='norm type of merge module') 192 | 193 | parser.add_argument('--exp_injection_type', type=str, default='adain', help='how to inject the expression embedding') 194 | 195 | 196 | parser.add_argument('--pretrained_model', type=str, default=None, help='path to a pretrained model checkpoint') 197 | 198 | parser.add_argument('--id_injection_models', type=str, default='facenet#arcfacev2') 199 | parser.add_argument('--id_constrain_models', type=str, default='facenet#5+arcfacev2#10') 200 | parser.add_argument('--id_fusion_method', type=str, default='concat', help='(concat|add)') 201 | ###3. log 202 | parser.add_argument('--name', type=str, default='SwapYou', help='name of the experiment.') 203 | parser.add_argument('--arch', type=str, default='dcgan', help='model architecture') 204 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints_wild', help='models are saved here') 205 | parser.add_argument('--display_freq', type=int, default=50, help='frequency of showing training results on screen') 206 | parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console') 207 | parser.add_argument('--save_step_freq', type=int, default=20000, help='frequency of steps to save model') 208 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of epochs to save model') 209 | 210 | parser.add_argument('--resume_dir', type=str, default=None, help='directory to resume training from') 211 | 212 | ###4. pretrained checkpoints 213 | parser.add_argument('--arcfacev2_model_path', type=str, default='/data/input_models/arcfacev2_resnet50_IR_checkpoint.pth') 214 | parser.add_argument('--facenet_model_path', type=str, default='/data/input_models/20180402-114759-vggface2.pt') 215 | parser.add_argument('--cosface_model_path', type=str, default='/data/input_models/cosface_ACC99.28.pth') 216 | parser.add_argument('--arcface_irse_model_path', type=str, default='/data/input_models/model_ir_se50.pth') 217 | parser.add_argument('--arcface_insight_model_path', type=str, default='/data/input_models/Glint360k_r100_backbone.pth') 218 | parser.add_argument('--exp_model_path', type=str, default='/data/input_models/minus_pipeline_DISFA_FEC_best_aug_857.pth') 219 | 220 | ###5.
loss 221 | parser.add_argument('--cycle_power', type=float, default=0.0, help='weight for cycle loss') 222 | parser.add_argument('--cycle_start_epoch', type=int, default=1) 223 | parser.add_argument('--exp_power', type=float, default=10.0, help='weight for exp loss') 224 | parser.add_argument('--neg_exp_power', type=float, default=0.0, help='weight for neg_exp loss') 225 | parser.add_argument('--gan_power', type=float, default=1.0, help='weight for gan loss') 226 | parser.add_argument('--neg_id_power', type=float, default=0.0, help='weight for neg_id loss') 227 | parser.add_argument('--neg_id_exp_power', type=float, default=0.0, help='weight for neg_id_exp loss') 228 | parser.add_argument('--rec_power', type=float, default=10.0, help='weight for rec loss') 229 | parser.add_argument('--fm_power', type=float, default=0.0, help='weight for fm loss') 230 | parser.add_argument('--tv_power', type=float, default=0.0, help='weight for total_variation_loss') 231 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 232 | parser.add_argument('--mask_power', type=float, default=10.0, help='weight for mask loss') 233 | parser.add_argument('--rig_power', type=float, default=10.0, help='weight for rig loss') 234 | parser.add_argument('--symmetry_power', type=float, default=0.0, help='weight for symmetry loss') 235 | 236 | ###6. training parameters 237 | parser.add_argument('--seed', type=int, default=666666, help='seed') 238 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size') 239 | parser.add_argument('--num_workers', type=int, default=4, help='num_workers') 240 | parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam') 241 | parser.add_argument('--beta2', type=float, default=0.99, help='momentum term of adam') 242 | parser.add_argument('--lr_g', type=float, default=0.0001, help='initial learning rate for G') 243 | parser.add_argument('--lr_d', type=float, default=0.0001, help='initial learning rate for D') 244 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 245 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 246 | parser.add_argument('--gpus', type=str, default=None) 247 | parser.add_argument('--same_prob', type=float, default=0.0, help='same prob') 248 | 249 | args = parser.parse_args() 250 | 251 | return args 252 | --------------------------------------------------------------------------------
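For orientation, a minimal usage sketch of the multi-scale PatchGAN discriminator defined in models/discriminator.py (this sketch is not a file from the repository): it builds the default options with get_parser(), instantiates MultiscaleDiscriminator, and inspects the per-scale output shapes on a dummy batch. It assumes the repository root is on PYTHONPATH and that the script is launched without extra command-line flags, since get_parser() parses sys.argv.

import torch

from models.discriminator import MultiscaleDiscriminator, get_parser

if __name__ == '__main__':
    opt = get_parser()                    # defaults: num_D=3, n_layers_D=4, ndf=64, norm_D='spectralinstance'
    netD = MultiscaleDiscriminator(opt)   # one NLayerDiscriminator per scale

    # dummy image batch at the configured resolution (3 x 256 x 256 with the defaults)
    images = torch.randn(2, opt.input_channel, opt.input_size, opt.input_size)
    with torch.no_grad():
        outputs = netD(images)            # the input is average-pooled between scales

    # with the default fm_power=0.0, each scale yields a single patch-logit map
    for i, out in enumerate(outputs):
        print('scale %d:' % i, [tuple(o.shape) for o in out])

As the forward methods above show, setting fm_power > 0 makes each discriminator return its intermediate feature maps as well, which is the hook for a feature-matching loss.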