├── .gitignore ├── LICENSE ├── README.md ├── README_CN.md ├── app_hf_space.py ├── app_lam.py ├── assets └── images │ ├── logo.jpeg │ └── teaser.jpg ├── configs ├── inference │ └── lam-20k-8gpu.yaml ├── stylematte_config.json └── vhap_tracking │ └── base_tracking_config.yaml ├── external ├── human_matting │ ├── __init__.py │ ├── matting_engine.py │ └── stylematte.py ├── landmark_detection │ ├── FaceBoxesV2 │ │ ├── __init__.py │ │ ├── detector.py │ │ ├── faceboxes_detector.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── box_utils.py │ │ │ ├── build.py │ │ │ ├── build │ │ │ └── temp.linux-x86_64-cpython-310 │ │ │ │ └── nms │ │ │ │ └── cpu_nms.o │ │ │ ├── config.py │ │ │ ├── faceboxes.py │ │ │ ├── make.sh │ │ │ ├── nms │ │ │ ├── __init__.py │ │ │ ├── cpu_nms.c │ │ │ ├── cpu_nms.py │ │ │ ├── cpu_nms.pyx │ │ │ ├── gpu_nms.hpp │ │ │ ├── gpu_nms.pyx │ │ │ ├── nms_kernel.cu │ │ │ └── py_cpu_nms.py │ │ │ ├── nms_wrapper.py │ │ │ ├── prior_box.py │ │ │ └── timer.py │ ├── README.md │ ├── conf │ │ ├── __init__.py │ │ ├── alignment.py │ │ └── base.py │ ├── config.json │ ├── data_processor │ │ ├── CheckFaceKeyPoint.py │ │ ├── align.py │ │ └── process_pcd.py │ ├── evaluate.py │ ├── infer_folder.py │ ├── infer_image.py │ ├── infer_video.py │ ├── lib │ │ ├── __init__.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ │ └── coord_conv.py │ │ │ └── stackedHGNetV1.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── alignmentDataset.py │ │ │ ├── augmentation.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── decoder_default.py │ │ │ └── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── encoder_default.py │ │ ├── loss │ │ │ ├── __init__.py │ │ │ ├── awingLoss.py │ │ │ ├── smoothL1Loss.py │ │ │ ├── starLoss.py │ │ │ ├── starLoss_v2.py │ │ │ └── wingLoss.py │ │ ├── metric │ │ │ ├── __init__.py │ │ │ ├── accuracy.py │ │ │ ├── fr_and_auc.py │ │ │ ├── nme.py │ │ │ └── params.py │ │ ├── utility.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── dist_utils.py │ │ │ ├── meter.py │ │ │ ├── time_utils.py │ │ │ └── vis_utils.py │ ├── requirements.txt │ ├── tester.py │ └── tools │ │ ├── analysis_motivation.py │ │ ├── infinite_loop.py │ │ ├── infinite_loop_gpu.py │ │ ├── split_wflw.py │ │ └── testtime_pca.py └── vgghead_detector │ ├── VGGDetector.py │ ├── __init__.py │ ├── utils_lmks_detector.py │ └── utils_vgghead.py ├── lam ├── __init__.py ├── datasets │ ├── __init__.py │ ├── base.py │ ├── cam_utils.py │ ├── mixer.py │ └── video_head.py ├── launch.py ├── losses │ ├── __init__.py │ ├── perceptual.py │ ├── pixelwise.py │ └── tvloss.py ├── models │ ├── __init__.py │ ├── block.py │ ├── discriminator.py │ ├── encoders │ │ ├── __init__.py │ │ ├── dinov2 │ │ │ ├── __init__.py │ │ │ ├── hub │ │ │ │ ├── __init__.py │ │ │ │ ├── backbones.py │ │ │ │ ├── classifiers.py │ │ │ │ ├── depth │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── decode_heads.py │ │ │ │ │ ├── encoder_decoder.py │ │ │ │ │ └── ops.py │ │ │ │ ├── depthers.py │ │ │ │ └── utils.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── block.py │ │ │ │ ├── dino_head.py │ │ │ │ ├── drop_path.py │ │ │ │ ├── layer_scale.py │ │ │ │ ├── mlp.py │ │ │ │ ├── patch_embed.py │ │ │ │ └── swiglu_ffn.py │ │ │ └── models │ │ │ │ ├── __init__.py │ │ │ │ └── vision_transformer.py │ │ ├── dinov2_fusion_wrapper.py │ │ └── dpt_util │ │ │ ├── __init__.py │ │ │ ├── blocks.py │ │ │ └── transform.py │ ├── modeling_lam.py │ ├── modulate.py │ ├── rendering │ │ ├── __init__.py │ │ ├── flame_model │ │ │ ├── flame.py │ │ │ ├── flame_arkit.py │ │ │ └── lbs.py │ │ ├── 
gaussian_model.py │ │ ├── gs_renderer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── math_utils.py │ │ │ ├── mesh_utils.py │ │ │ ├── point_utils.py │ │ │ ├── renderer.py │ │ │ ├── sh_utils.py │ │ │ ├── typing.py │ │ │ ├── utils.py │ │ │ ├── uv_utils.py │ │ │ └── vis_utils.py │ ├── transformer.py │ └── transformer_dit.py ├── runners │ ├── __init__.py │ ├── abstract.py │ └── infer │ │ ├── __init__.py │ │ ├── base_inferrer.py │ │ ├── head_utils.py │ │ ├── lam.py │ │ └── utils.py └── utils │ ├── __init__.py │ ├── compile.py │ ├── ffmpeg_utils.py │ ├── gen_id_json.py │ ├── gen_json.py │ ├── hf_hub.py │ ├── logging.py │ ├── preprocess.py │ ├── profiler.py │ ├── proxy.py │ ├── registry.py │ ├── scheduler.py │ └── video.py ├── requirements.txt ├── scripts ├── convert_hf.py ├── exp │ ├── run_4gpu.sh │ ├── run_8gpu.sh │ └── run_debug.sh ├── inference.sh ├── install │ ├── WINDOWS_INSTALL.md │ ├── install_cu118.sh │ └── install_cu121.sh └── upload_hub.py ├── tools ├── AVATAR_EXPORT_GUIDE.md ├── __init__.py ├── convertFBX2GLB.py ├── flame_tracking_single_image.py ├── generateARKITGLBWithBlender.py ├── generateGLBWithBlender_v2.py ├── generateVertexIndices.py └── install_fbx_sdk.sh └── vhap ├── combine_nerf_datasets.py ├── config ├── base.py └── nersemble.py ├── data ├── image_folder_dataset.py ├── nerf_dataset.py ├── nersemble_dataset.py └── video_dataset.py ├── export_as_nerf_dataset.py ├── flame_editor.py ├── flame_viewer.py ├── generate_flame_uvmask.py ├── model ├── flame.py ├── lbs.py └── tracker.py ├── track.py ├── track_nersemble.py └── util ├── camera.py ├── landmark_detector_fa.py ├── landmark_detector_star.py ├── log.py ├── mesh.py ├── render_nvdiffrast.py ├── render_uvmap.py ├── vector_ops.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | wheels 2 | __pycache__/ 3 | build/ 4 | *.so 5 | assets/sample_input/ 6 | assets/sample_motion/ 7 | configs/vhap_tracking/ 8 | exps/ 9 | pretrain_model/ 10 | pretrained_models/ 11 | model_zoo/ 12 | tracking_output/ -------------------------------------------------------------------------------- /assets/images/logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/assets/images/logo.jpeg -------------------------------------------------------------------------------- /assets/images/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/assets/images/teaser.jpg -------------------------------------------------------------------------------- /configs/inference/lam-20k-8gpu.yaml: -------------------------------------------------------------------------------- 1 | 2 | experiment: 3 | type: lam 4 | seed: 42 5 | parent: lam 6 | child: lam_20k 7 | model: 8 | # image encoder 9 | encoder_type: "dinov2_fusion" 10 | encoder_model_name: "dinov2_vitl14_reg" 11 | encoder_feat_dim: 1024 12 | encoder_freeze: false 13 | 14 | # points embeddings 15 | latent_query_points_type: "e2e_flame" 16 | pcl_dim: 1024 17 | 18 | # transformer 19 | transformer_type: "sd3_cond" 20 | transformer_heads: 16 21 | transformer_dim: 1024 22 | transformer_layers: 10 23 | tf_grad_ckpt: true 24 | encoder_grad_ckpt: true 25 | 26 | # for gs renderer 27 | human_model_path: "./model_zoo/human_parametric_models" 28 | flame_subdivide_num: 1 29 | flame_type: "flame" 30 | gs_query_dim: 1024 31 | 
gs_use_rgb: True 32 | gs_sh: 3 33 | gs_mlp_network_config: 34 | n_neurons: 512 35 | n_hidden_layers: 2 36 | activation: silu 37 | gs_xyz_offset_max_step: 0.2 38 | gs_clip_scaling: 0.01 39 | scale_sphere: false 40 | 41 | expr_param_dim: 10 42 | shape_param_dim: 10 43 | add_teeth: false 44 | 45 | fix_opacity: false 46 | fix_rotation: false 47 | 48 | has_disc: false 49 | 50 | teeth_bs_flag: false 51 | oral_mesh_flag: false 52 | 53 | dataset: 54 | subsets: 55 | - name: video_head 56 | root_dirs: "./train_data/vfhq_vhap_nooffset/export" 57 | meta_path: 58 | train: "./train_data/vfhq_vhap_nooffset/label/valid_id_train_list.json" 59 | val: "./train_data/vfhq_vhap_nooffset/label/valid_id_val_list.json" 60 | sample_rate: 1.0 61 | sample_side_views: 7 62 | sample_aug_views: 0 63 | source_image_res: 512 64 | render_image: 65 | low: 512 66 | high: 512 67 | region: null 68 | num_train_workers: 4 69 | num_val_workers: 2 70 | pin_mem: true 71 | repeat_num: 1 72 | gaga_track_type: "vfhq" 73 | 74 | train: 75 | mixed_precision: bf16 # REPLACE THIS BASED ON GPU TYPE 76 | find_unused_parameters: false 77 | loss: 78 | pixel_weight: 0.0 79 | pixel_loss_fn: "mse" 80 | crop_face_weight: 0. 81 | crop_mouth_weight: 0. 82 | crop_eye_weight: 0. 83 | masked_pixel_weight: 1.0 84 | perceptual_weight: 1.0 85 | tv_weight: -1 86 | mask_weight: 0:1.0:0.5:10000 87 | offset_reg_weight: 0.1 88 | optim: 89 | lr: 4e-4 90 | weight_decay: 0.05 91 | beta1: 0.9 92 | beta2: 0.95 93 | clip_grad_norm: 1.0 94 | scheduler: 95 | type: cosine 96 | warmup_real_iters: 3000 97 | batch_size: 4 # REPLACE THIS (PER GPU) 98 | accum_steps: 1 # REPLACE THIS 99 | epochs: 100 # REPLACE THIS 100 | debug_global_steps: null 101 | resume: "" 102 | 103 | val: 104 | batch_size: 2 105 | global_step_period: 500 106 | debug_batches: 10 107 | 108 | saver: 109 | auto_resume: true 110 | load_model: null 111 | checkpoint_root: ./exps/checkpoints 112 | checkpoint_global_steps: 500 113 | checkpoint_keep_level: 5 114 | 115 | logger: 116 | stream_level: WARNING 117 | log_level: INFO 118 | log_root: ./exps/logs 119 | tracker_root: ./exps/trackers 120 | enable_profiler: false 121 | trackers: 122 | - tensorboard 123 | image_monitor: 124 | train_global_steps: 500 125 | samples_per_log: 4 126 | 127 | compile: 128 | suppress_errors: true 129 | print_specializations: true 130 | disable: true 131 | -------------------------------------------------------------------------------- /external/human_matting/__init__.py: -------------------------------------------------------------------------------- 1 | from .matting_engine import StyleMatteEngine 2 | -------------------------------------------------------------------------------- /external/human_matting/matting_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import inspect 4 | import warnings 5 | import torchvision 6 | from .stylematte import StyleMatte 7 | 8 | class StyleMatteEngine(torch.nn.Module): 9 | def __init__(self, device='cpu',human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt'): 10 | super().__init__() 11 | self._device = device 12 | self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 13 | self._init_models(human_matting_path) 14 | 15 | def _init_models(self,_ckpt_path): 16 | # load dict 17 | state_dict = torch.load(_ckpt_path, map_location='cpu') 18 | # build model 19 | model = StyleMatte() 20 | model.load_state_dict(state_dict) 21 | self.model = 
model.to(self._device).eval() 22 | 23 | @torch.no_grad() 24 | def forward(self, input_image, return_type='matting', background_rgb=1.0): 25 | if not hasattr(self, 'model'): 26 | self._init_models() 27 | if input_image.max() > 2.0: 28 | warnings.warn('Image should be normalized to [0, 1].') 29 | _, ori_h, ori_w = input_image.shape 30 | input_image = input_image.to(self._device).float() 31 | image = input_image.clone() 32 | # resize 33 | if max(ori_h, ori_w) > 1024: 34 | scale = 1024.0 / max(ori_h, ori_w) 35 | resized_h, resized_w = int(ori_h * scale), int(ori_w * scale) 36 | image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True) 37 | else: 38 | resized_h, resized_w = ori_h, ori_w 39 | # padding 40 | if resized_h % 8 != 0 or resized_w % 8 != 0: 41 | image = torchvision.transforms.functional.pad(image, ((8-resized_w % 8)%8, (8-resized_h % 8)%8, 0, 0, ), padding_mode='reflect') 42 | # normalize and forwarding 43 | image = self.normalize(image)[None] 44 | predict = self.model(image)[0] 45 | # undo padding 46 | predict = predict[:, -resized_h:, -resized_w:] 47 | # undo resize 48 | if resized_h != ori_h or resized_w != ori_w: 49 | predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True) 50 | 51 | if return_type == 'alpha': 52 | return predict[0] 53 | elif return_type == 'matting': 54 | predict = predict.expand(3, -1, -1) 55 | matting_image = input_image.clone() 56 | background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb 57 | matting_image = matting_image * predict + (1-predict) * background_rgb 58 | return matting_image, predict[0] 59 | elif return_type == 'all': 60 | predict = predict.expand(3, -1, -1) 61 | background_rgb = input_image.new_ones(input_image.shape) * background_rgb 62 | foreground_image = input_image * predict + (1-predict) * background_rgb 63 | background_image = input_image * (1-predict) + predict * background_rgb 64 | return foreground_image, background_image 65 | else: 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/__init__.py: -------------------------------------------------------------------------------- 1 | from . import detector 2 | from . 
import faceboxes_detector -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/detector.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | class Detector(object): 4 | def __init__(self, model_arch, model_weights): 5 | self.model_arch = model_arch 6 | self.model_weights = model_weights 7 | 8 | def detect(self, image, thresh): 9 | raise NotImplementedError 10 | 11 | def crop(self, image, detections): 12 | crops = [] 13 | for det in detections: 14 | xmin = max(det[2], 0) 15 | ymin = max(det[3], 0) 16 | width = det[4] 17 | height = det[5] 18 | xmax = min(xmin+width, image.shape[1]) 19 | ymax = min(ymin+height, image.shape[0]) 20 | cut = image[ymin:ymax, xmin:xmax,:] 21 | crops.append(cut) 22 | 23 | return crops 24 | 25 | def draw(self, image, detections, im_scale=None): 26 | if im_scale is not None: 27 | image = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR) 28 | detections = [[det[0],det[1],int(det[2]*im_scale),int(det[3]*im_scale),int(det[4]*im_scale),int(det[5]*im_scale)] for det in detections] 29 | 30 | for det in detections: 31 | xmin = det[2] 32 | ymin = det[3] 33 | width = det[4] 34 | height = det[5] 35 | xmax = xmin + width 36 | ymax = ymin + height 37 | cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2) 38 | 39 | return image 40 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/faceboxes_detector.py: -------------------------------------------------------------------------------- 1 | from .detector import Detector 2 | import cv2, os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from .utils.config import cfg 7 | from .utils.prior_box import PriorBox 8 | from .utils.nms_wrapper import nms 9 | from .utils.faceboxes import FaceBoxesV2 10 | from .utils.box_utils import decode 11 | import time 12 | 13 | class FaceBoxesDetector(Detector): 14 | def __init__(self, model_arch, model_weights, use_gpu, device): 15 | super().__init__(model_arch, model_weights) 16 | self.name = 'FaceBoxesDetector' 17 | self.net = FaceBoxesV2(phase='test', size=None, num_classes=2) # initialize detector 18 | self.use_gpu = use_gpu 19 | self.device = device 20 | 21 | state_dict = torch.load(self.model_weights, map_location=self.device) 22 | # create new OrderedDict that does not contain `module.` 23 | from collections import OrderedDict 24 | new_state_dict = OrderedDict() 25 | for k, v in state_dict.items(): 26 | name = k[7:] # remove `module.` 27 | new_state_dict[name] = v 28 | # load params 29 | self.net.load_state_dict(new_state_dict) 30 | self.net = self.net.to(self.device) 31 | self.net.eval() 32 | 33 | 34 | def detect(self, image, thresh=0.6, im_scale=None): 35 | # auto resize for large images 36 | if im_scale is None: 37 | height, width, _ = image.shape 38 | if min(height, width) > 600: 39 | im_scale = 600. 
/ min(height, width) 40 | else: 41 | im_scale = 1 42 | image_scale = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR) 43 | 44 | scale = torch.Tensor([image_scale.shape[1], image_scale.shape[0], image_scale.shape[1], image_scale.shape[0]]) 45 | image_scale = torch.from_numpy(image_scale.transpose(2,0,1)).to(self.device).int() 46 | mean_tmp = torch.IntTensor([104, 117, 123]).to(self.device) 47 | mean_tmp = mean_tmp.unsqueeze(1).unsqueeze(2) 48 | image_scale -= mean_tmp 49 | image_scale = image_scale.float().unsqueeze(0) 50 | scale = scale.to(self.device) 51 | 52 | with torch.no_grad(): 53 | out = self.net(image_scale) 54 | #priorbox = PriorBox(cfg, out[2], (image_scale.size()[2], image_scale.size()[3]), phase='test') 55 | priorbox = PriorBox(cfg, image_size=(image_scale.size()[2], image_scale.size()[3])) 56 | priors = priorbox.forward() 57 | priors = priors.to(self.device) 58 | loc, conf = out 59 | prior_data = priors.data 60 | boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) 61 | boxes = boxes * scale 62 | boxes = boxes.cpu().numpy() 63 | scores = conf.data.cpu().numpy()[:, 1] 64 | 65 | # ignore low scores 66 | inds = np.where(scores > thresh)[0] 67 | boxes = boxes[inds] 68 | scores = scores[inds] 69 | 70 | # keep top-K before NMS 71 | order = scores.argsort()[::-1][:5000] 72 | boxes = boxes[order] 73 | scores = scores[order] 74 | 75 | # do NMS 76 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 77 | keep = nms(dets, 0.3) 78 | dets = dets[keep, :] 79 | 80 | dets = dets[:750, :] 81 | detections_scale = [] 82 | for i in range(dets.shape[0]): 83 | xmin = int(dets[i][0]) 84 | ymin = int(dets[i][1]) 85 | xmax = int(dets[i][2]) 86 | ymax = int(dets[i][3]) 87 | score = dets[i][4] 88 | width = xmax - xmin 89 | height = ymax - ymin 90 | detections_scale.append(['face', score, xmin, ymin, width, height]) 91 | 92 | # adapt bboxes to the original image size 93 | if len(detections_scale) > 0: 94 | detections_scale = [[det[0],det[1],int(det[2]/im_scale),int(det[3]/im_scale),int(det[4]/im_scale),int(det[5]/im_scale)] for det in detections_scale] 95 | 96 | return detections_scale, im_scale 97 | 98 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/__init__.py -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/build.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # -------------------------------------------------------- 4 | # Fast R-CNN 5 | # Copyright (c) 2015 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ross Girshick 8 | # -------------------------------------------------------- 9 | 10 | import os 11 | from os.path import join as pjoin 12 | import numpy as np 13 | from distutils.core import setup 14 | from distutils.extension import Extension 15 | from Cython.Distutils import build_ext 16 | 17 | 18 | def find_in_path(name, path): 19 | "Find a file in a search path" 20 | # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/ 21 | for dir in path.split(os.pathsep): 22 | binpath = pjoin(dir, name) 23 | if 
os.path.exists(binpath): 24 | return os.path.abspath(binpath) 25 | return None 26 | 27 | 28 | # Obtain the numpy include directory. This logic works across numpy versions. 29 | try: 30 | numpy_include = np.get_include() 31 | except AttributeError: 32 | numpy_include = np.get_numpy_include() 33 | 34 | 35 | # run the customize_compiler 36 | class custom_build_ext(build_ext): 37 | def build_extensions(self): 38 | # customize_compiler_for_nvcc(self.compiler) 39 | build_ext.build_extensions(self) 40 | 41 | 42 | ext_modules = [ 43 | Extension( 44 | "nms.cpu_nms", 45 | ["nms/cpu_nms.pyx"], 46 | # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]}, 47 | extra_compile_args=["-Wno-cpp", "-Wno-unused-function"], 48 | include_dirs=[numpy_include] 49 | ) 50 | ] 51 | 52 | setup( 53 | name='mot_utils', 54 | ext_modules=ext_modules, 55 | # inject our custom trigger 56 | cmdclass={'build_ext': custom_build_ext}, 57 | ) 58 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/build/temp.linux-x86_64-cpython-310/nms/cpu_nms.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/build/temp.linux-x86_64-cpython-310/nms/cpu_nms.o -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | cfg = { 4 | 'name': 'FaceBoxes', 5 | #'min_dim': 1024, 6 | #'feature_maps': [[32, 32], [16, 16], [8, 8]], 7 | # 'aspect_ratios': [[1], [1], [1]], 8 | 'min_sizes': [[32, 64, 128], [256], [512]], 9 | 'steps': [32, 64, 128], 10 | 'variance': [0.1, 0.2], 11 | 'clip': False, 12 | 'loc_weight': 2.0, 13 | 'gpu_train': True 14 | } 15 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 build.py build_ext --inplace 3 | 4 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp: -------------------------------------------------------------------------------- 1 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num, 2 | int boxes_dim, float nms_overlap_thresh, int device_id); 3 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx: 
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Faster R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | 11 | assert sizeof(int) == sizeof(np.int32_t) 12 | 13 | cdef extern from "gpu_nms.hpp": 14 | void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int) 15 | 16 | def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh, 17 | np.int32_t device_id=0): 18 | cdef int boxes_num = dets.shape[0] 19 | cdef int boxes_dim = dets.shape[1] 20 | cdef int num_out 21 | cdef np.ndarray[np.int32_t, ndim=1] \ 22 | keep = np.zeros(boxes_num, dtype=np.int32) 23 | cdef np.ndarray[np.float32_t, ndim=1] \ 24 | scores = dets[:, 4] 25 | cdef np.ndarray[np.int_t, ndim=1] \ 26 | order = scores.argsort()[::-1] 27 | cdef np.ndarray[np.float32_t, ndim=2] \ 28 | sorted_dets = dets[order, :] 29 | _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id) 30 | keep = keep[:num_out] 31 | return list(order[keep]) 32 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | 10 | def py_cpu_nms(dets, thresh): 11 | """Pure Python NMS baseline.""" 12 | x1 = dets[:, 0] 13 | y1 = dets[:, 1] 14 | x2 = dets[:, 2] 15 | y2 = dets[:, 3] 16 | scores = dets[:, 4] 17 | 18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 19 | order = scores.argsort()[::-1] 20 | 21 | keep = [] 22 | while order.size > 0: 23 | i = order[0] 24 | keep.append(i) 25 | xx1 = np.maximum(x1[i], x1[order[1:]]) 26 | yy1 = np.maximum(y1[i], y1[order[1:]]) 27 | xx2 = np.minimum(x2[i], x2[order[1:]]) 28 | yy2 = np.minimum(y2[i], y2[order[1:]]) 29 | 30 | w = np.maximum(0.0, xx2 - xx1 + 1) 31 | h = np.maximum(0.0, yy2 - yy1 + 1) 32 | inter = w * h 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | 35 | inds = np.where(ovr <= thresh)[0] 36 | order = order[inds + 1] 37 | 38 | return keep 39 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | from .nms.cpu_nms import cpu_nms, cpu_soft_nms 9 | 10 | def nms(dets, thresh): 11 | """Dispatch to either CPU or GPU NMS implementations.""" 12 | 13 | if dets.shape[0] == 0: 14 | return [] 15 | return cpu_nms(dets, thresh) 16 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/prior_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
itertools import product as product 3 | import numpy as np 4 | from math import ceil 5 | 6 | 7 | class PriorBox(object): 8 | def __init__(self, cfg, image_size=None, phase='train'): 9 | super(PriorBox, self).__init__() 10 | #self.aspect_ratios = cfg['aspect_ratios'] 11 | self.min_sizes = cfg['min_sizes'] 12 | self.steps = cfg['steps'] 13 | self.clip = cfg['clip'] 14 | self.image_size = image_size 15 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] 16 | 17 | def forward(self): 18 | anchors = [] 19 | for k, f in enumerate(self.feature_maps): 20 | min_sizes = self.min_sizes[k] 21 | for i, j in product(range(f[0]), range(f[1])): 22 | for min_size in min_sizes: 23 | s_kx = min_size / self.image_size[1] 24 | s_ky = min_size / self.image_size[0] 25 | if min_size == 32: 26 | dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]] 27 | dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]] 28 | for cy, cx in product(dense_cy, dense_cx): 29 | anchors += [cx, cy, s_kx, s_ky] 30 | elif min_size == 64: 31 | dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.5]] 32 | dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.5]] 33 | for cy, cx in product(dense_cy, dense_cx): 34 | anchors += [cx, cy, s_kx, s_ky] 35 | else: 36 | cx = (j + 0.5) * self.steps[k] / self.image_size[1] 37 | cy = (i + 0.5) * self.steps[k] / self.image_size[0] 38 | anchors += [cx, cy, s_kx, s_ky] 39 | # back to torch land 40 | output = torch.Tensor(anchors).view(-1, 4) 41 | if self.clip: 42 | output.clamp_(max=1, min=0) 43 | return output 44 | -------------------------------------------------------------------------------- /external/landmark_detection/FaceBoxesV2/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | 10 | 11 | class Timer(object): 12 | """A simple timer.""" 13 | def __init__(self): 14 | self.total_time = 0. 15 | self.calls = 0 16 | self.start_time = 0. 17 | self.diff = 0. 18 | self.average_time = 0. 19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.total_time += self.diff 28 | self.calls += 1 29 | self.average_time = self.total_time / self.calls 30 | if average: 31 | return self.average_time 32 | else: 33 | return self.diff 34 | 35 | def clear(self): 36 | self.total_time = 0. 37 | self.calls = 0 38 | self.start_time = 0. 39 | self.diff = 0. 40 | self.average_time = 0. 41 | -------------------------------------------------------------------------------- /external/landmark_detection/README.md: -------------------------------------------------------------------------------- 1 | # STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection. 
2 | 3 | Paper Link: [arxiv](https://arxiv.org/abs/2306.02763) | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_STAR_Loss_Reducing_Semantic_Ambiguity_in_Facial_Landmark_Detection_CVPR_2023_paper.pdf) 4 | 5 | 6 | - Pytorch implementation of **S**elf-adap**T**ive **A**mbiguity **R**eduction (**STAR**) loss. 7 | - STAR loss is a self-adaptive anisotropic direction loss, which can be used in heatmap regression-based methods for facial landmark detection. 8 | - Specifically, we find that semantic ambiguity results in the anisotropic predicted distribution, which inspires us to use predicted distribution to represent semantic ambiguity. So, we use PCA to indicate the character of the predicted distribution and indirectly formulate the direction and intensity of semantic ambiguity. Based on this, STAR loss adaptively suppresses the prediction error in the ambiguity direction to mitigate the impact of ambiguity annotation in training. More details can be found in our paper. 9 |
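The following is a minimal, self-contained sketch of that idea — **not** the implementation shipped in `lib/loss/starLoss.py` (the function name `star_like_loss`, the per-landmark formulation, and the exact eigenvalue weighting are assumptions made for clarity, and it assumes a recent PyTorch with `torch.linalg`). The predicted heatmap is treated as a 2-D distribution, PCA of its spatial covariance gives the ambiguity direction, and the coordinate error is re-weighted so that the component along the high-variance (ambiguous) direction contributes less.

```python
import torch

def star_like_loss(heatmap, pred, target, eps=1e-6):
    """Toy STAR-style loss for a single landmark.

    heatmap: (h, w) predicted heatmap; pred / target: (2,) coordinates (x, y)
    expressed on the heatmap grid.
    """
    h, w = heatmap.shape
    ys, xs = torch.meshgrid(torch.arange(h, dtype=torch.float32),
                            torch.arange(w, dtype=torch.float32), indexing="ij")
    p = heatmap.clamp(min=0)
    p = p / p.sum().clamp(min=eps)                      # normalize to a distribution
    mu = torch.stack([(xs * p).sum(), (ys * p).sum()])  # distribution mean (x, y)
    d = torch.stack([xs.flatten() - mu[0],
                     ys.flatten() - mu[1]])             # (2, h*w) centered coordinates
    cov = (d * p.flatten()) @ d.t()                     # 2x2 spatial covariance
    evals, evecs = torch.linalg.eigh(cov)               # PCA: largest eigenvalue = ambiguity direction
    err = evecs.t() @ (pred - target).to(cov)           # error projected onto the principal axes
    # more ambiguity (larger eigenvalue) -> smaller weight on that error component
    return (err.abs() / evals.clamp(min=eps).sqrt()).sum()
```

The released loss (`starLoss.py` / `starLoss_v2.py`) additionally restricts the eigenvalues so the network cannot shrink the loss simply by predicting a flatter distribution, and combines the re-weighting with a configurable distance function (the repository includes smooth-L1 and Wing variants) instead of the plain absolute error used above.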

10 | 11 |

12 | 13 | 14 | 15 | 16 | ## Dependencies 17 | 18 | * python==3.7.3 19 | * PyTorch=1.6.0 20 | * requirements.txt 21 | 22 | ## Dataset Preparation 23 | 24 | - Step1: Download the raw images from [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/#dataset), [300W](https://ibug.doc.ic.ac.uk/resources/300-W/), and [WFLW](https://wywu.github.io/projects/LAB/WFLW.html). 25 | - Step2: We follow the data preprocess in [ADNet](https://openaccess.thecvf.com/content/ICCV2021/papers/Huang_ADNet_Leveraging_Error-Bias_Towards_Normal_Direction_in_Face_Alignment_ICCV_2021_paper.pdf), and the metadata can be download from [the corresponding repository](https://github.com/huangyangyu/ADNet). 26 | - Step3: Make them look like this: 27 | ```script 28 | # the dataset directory: 29 | |-- ${image_dir} 30 | |-- WFLW 31 | | -- WFLW_images 32 | |-- 300W 33 | | -- afw 34 | | -- helen 35 | | -- ibug 36 | | -- lfpw 37 | |-- COFW 38 | | -- train 39 | | -- test 40 | |-- ${annot_dir} 41 | |-- WFLW 42 | |-- train.tsv, test.tsv 43 | |-- 300W 44 | |-- train.tsv, test.tsv 45 | |--COFW 46 | |-- train.tsv, test.tsv 47 | ``` 48 | 49 | ## Usage 50 | * Work directory: set the ${ckpt_dir} in ./conf/alignment.py. 51 | * Pretrained model: 52 | 53 | | Dataset | Model | 54 | |:-----------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------| 55 | | WFLW | [google](https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view?usp=sharing) / [baidu](https://pan.baidu.com/s/10vvI-ovs3x9NrdmpnXK6sg?pwd=u0yu) | 56 | | 300W | [google](https://drive.google.com/file/d/1Fiu3hjjkQRdKsWE9IgyNPdiJSz9_MzA5/view?usp=sharing) / [baidu](https://pan.baidu.com/s/1bjUhLq1zS1XSl1nX78fU7A?pwd=yb2s) | 57 | | COFW | [google](https://drive.google.com/file/d/1NFcZ9jzql_jnn3ulaSzUlyhS05HWB9n_/view?usp=drive_link) / [baidu](https://pan.baidu.com/s/1XO6hDZ8siJLTgFcpyu1Tzw?pwd=m57n) | 58 | 59 | 60 | ### Training 61 | ```shell 62 | python main.py --mode=train --device_ids=0,1,2,3 \ 63 | --image_dir=${image_dir} --annot_dir=${annot_dir} \ 64 | --data_definition={WFLW, 300W, COFW} 65 | ``` 66 | 67 | ### Testing 68 | ```shell 69 | python main.py --mode=test --device_ids=0 \ 70 | --image_dir=${image_dir} --annot_dir=${annot_dir} \ 71 | --data_definition={WFLW, 300W, COFW} \ 72 | --pretrained_weight=${model_path} \ 73 | ``` 74 | 75 | ### Evaluation 76 | ```shell 77 | python evaluate.py --device_ids=0 \ 78 | --model_path=${model_path} --metadata_path=${metadata_path} \ 79 | --image_dir=${image_dir} --data_definition={WFLW, 300W, COFW} \ 80 | ``` 81 | 82 | To test on your own image, the following code could be considered: 83 | ```shell 84 | python demo.py 85 | ``` 86 | 87 | 88 | ## Results 89 | The models trained by STAR Loss achieved **SOTA** performance in all of COFW, 300W and WFLW datasets. 90 | 91 |

92 | 93 |

94 | 95 | ## BibTeX Citation 96 | Please consider citing our papers in your publications if the project helps your research. BibTeX reference is as follows. 97 | ``` 98 | @inproceedings{Zhou_2023_CVPR, 99 | author = {Zhou, Zhenglin and Li, Huaxia and Liu, Hong and Wang, Nanyang and Yu, Gang and Ji, Rongrong}, 100 | title = {STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection}, 101 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 102 | month = {June}, 103 | year = {2023}, 104 | pages = {15475-15484} 105 | } 106 | ``` 107 | 108 | ## Acknowledgments 109 | This repository is built on top of [ADNet](https://github.com/huangyangyu/ADNet). 110 | Thanks for this strong baseline. 111 | -------------------------------------------------------------------------------- /external/landmark_detection/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .alignment import Alignment -------------------------------------------------------------------------------- /external/landmark_detection/conf/base.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import logging 3 | import os.path as osp 4 | from argparse import Namespace 5 | # from tensorboardX import SummaryWriter 6 | 7 | class Base: 8 | """ 9 | Base configure file, which contains the basic training parameters and should be inherited by other attribute configure file. 10 | """ 11 | 12 | def __init__(self, config_name, ckpt_dir='./', image_dir='./', annot_dir='./'): 13 | self.type = config_name 14 | self.id = str(uuid.uuid4()) 15 | self.note = "" 16 | 17 | self.ckpt_dir = ckpt_dir 18 | self.image_dir = image_dir 19 | self.annot_dir = annot_dir 20 | 21 | self.loader_type = "alignment" 22 | self.loss_func = "STARLoss" 23 | 24 | # train 25 | self.batch_size = 128 26 | self.val_batch_size = 1 27 | self.test_batch_size = 32 28 | self.channels = 3 29 | self.width = 256 30 | self.height = 256 31 | 32 | # mean values in r, g, b channel. 
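# (these defaults match the common (pixel - mean) * scale input normalization, with scale = 1/128)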
33 | self.means = (127, 127, 127) 34 | self.scale = 0.0078125 35 | 36 | self.display_iteration = 100 37 | self.milestones = [50, 80] 38 | self.max_epoch = 100 39 | 40 | self.net = "stackedHGnet_v1" 41 | self.nstack = 4 42 | 43 | # ["adam", "sgd"] 44 | self.optimizer = "adam" 45 | self.learn_rate = 0.1 46 | self.momentum = 0.01 # caffe: 0.99 47 | self.weight_decay = 0.0 48 | self.nesterov = False 49 | self.scheduler = "MultiStepLR" 50 | self.gamma = 0.1 51 | 52 | self.loss_weights = [1.0] 53 | self.criterions = ["SoftmaxWithLoss"] 54 | self.metrics = ["Accuracy"] 55 | self.key_metric_index = 0 56 | self.classes_num = [1000] 57 | self.label_num = len(self.classes_num) 58 | 59 | # model 60 | self.ema = False 61 | self.use_AAM = True 62 | 63 | # visualization 64 | self.writer = None 65 | 66 | # log file 67 | self.logger = None 68 | 69 | def init_instance(self): 70 | # self.writer = SummaryWriter(logdir=self.log_dir, comment=self.type) 71 | log_formatter = logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s") 72 | root_logger = logging.getLogger() 73 | file_handler = logging.FileHandler(osp.join(self.log_dir, "log.txt")) 74 | file_handler.setFormatter(log_formatter) 75 | file_handler.setLevel(logging.NOTSET) 76 | root_logger.addHandler(file_handler) 77 | console_handler = logging.StreamHandler() 78 | console_handler.setFormatter(log_formatter) 79 | console_handler.setLevel(logging.NOTSET) 80 | root_logger.addHandler(console_handler) 81 | root_logger.setLevel(logging.NOTSET) 82 | self.logger = root_logger 83 | 84 | def __del__(self): 85 | # tensorboard --logdir self.log_dir 86 | if self.writer is not None: 87 | # self.writer.export_scalars_to_json(self.log_dir + "visual.json") 88 | self.writer.close() 89 | 90 | def init_from_args(self, args: Namespace): 91 | args_vars = vars(args) 92 | for key, value in args_vars.items(): 93 | if hasattr(self, key) and value is not None: 94 | setattr(self, key, value) 95 | -------------------------------------------------------------------------------- /external/landmark_detection/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "Token":"bpt4JPotFA6bpdknR9ZDCw", 3 | "business_flag": "shadow_cv_face", 4 | "model_local_file_path": "/apdcephfs_cq3/share_1134483/charlinzhou/Documents/awesome-tools/jizhi/", 5 | "host_num": 1, 6 | "host_gpu_num": 1, 7 | "GPUName": "V100", 8 | "is_elasticity": true, 9 | "enable_evicted_pulled_up": true, 10 | "task_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1", 11 | "task_flag": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1", 12 | "model_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1", 13 | "image_full_name": "mirrors.tencent.com/haroldzcli/py36-pytorch1.7.1-torchvision0.8.2-cuda10.1-cudnn7.6", 14 | "start_cmd": "./start_slpt.sh /apdcephfs_cq3/share_1134483/charlinzhou/Documents/SLPT_Training train.py --loss_func=star --bb_init --eigen_box --dist_func=align_smoothl1" 15 | } 16 | -------------------------------------------------------------------------------- /external/landmark_detection/data_processor/CheckFaceKeyPoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | selected_indices_old = [ 8 | 2311, 9 | 2416, 10 | 2437, 11 | 2460, 12 | 2495, 13 | 2518, 14 | 2520, 15 | 2627, 16 | 4285, 17 | 4315, 18 | 6223, 19 | 6457, 20 | 6597, 21 | 6642, 22 | 6974, 23 | 7054, 24 | 7064, 25 | 7182, 26 | 7303, 27 | 7334, 28 | 
7351, 29 | 7368, 30 | 7374, 31 | 7493, 32 | 7503, 33 | 7626, 34 | 8443, 35 | 8562, 36 | 8597, 37 | 8701, 38 | 8817, 39 | 8953, 40 | 11213, 41 | 11261, 42 | 11317, 43 | 11384, 44 | 11600, 45 | 11755, 46 | 11852, 47 | 11891, 48 | 11945, 49 | 12010, 50 | 12354, 51 | 12534, 52 | 12736, 53 | 12880, 54 | 12892, 55 | 13004, 56 | 13323, 57 | 13371, 58 | 13534, 59 | 13575, 60 | 14874, 61 | 14949, 62 | 14977, 63 | 15052, 64 | 15076, 65 | 15291, 66 | 15620, 67 | 15758, 68 | 16309, 69 | 16325, 70 | 16348, 71 | 16390, 72 | 16489, 73 | 16665, 74 | 16891, 75 | 17147, 76 | 17183, 77 | 17488, 78 | 17549, 79 | 17657, 80 | 17932, 81 | 19661, 82 | 20162, 83 | 20200, 84 | 20238, 85 | 20286, 86 | 20432, 87 | 20834, 88 | 20954, 89 | 21015, 90 | 21036, 91 | 21117, 92 | 21299, 93 | 21611, 94 | 21632, 95 | 21649, 96 | 22722, 97 | 22759, 98 | 22873, 99 | 23028, 100 | 23033, 101 | 23082, 102 | 23187, 103 | 23232, 104 | 23302, 105 | 23413, 106 | 23430, 107 | 23446, 108 | 23457, 109 | 23548, 110 | 23636, 111 | 32060, 112 | 32245, 113 | ] 114 | 115 | selected_indices = list() 116 | with open('/home/gyalex/Desktop/face_anno.txt', 'r') as f: 117 | lines = f.readlines() 118 | for line in lines: 119 | hh = line.strip().split() 120 | if len(hh) > 0: 121 | pid = hh[0].find('.') 122 | if pid != -1: 123 | s = hh[0][pid+1:len(hh[0])] 124 | print(s) 125 | selected_indices.append(int(s)) 126 | 127 | f.close() 128 | 129 | dir = '/media/gyalex/Data/face_ldk_dataset/MHC_LightingPreset_Portrait_RT_0_19/MHC_LightingPreset_Portrait_RT_seq_000015' 130 | 131 | for idx in range(500): 132 | img = os.path.join(dir, "view_1/MHC_LightingPreset_Portrait_RT_seq_000015_FinalImage_" + str(idx).zfill(4) + ".jpeg") 133 | lmd = os.path.join(dir, "mesh/mesh_screen" + str(idx+5).zfill(7) + ".npy") 134 | 135 | img = cv2.imread(img) 136 | # c = 511 / 2 137 | # lmd = np.load(lmd) * c + c 138 | # lmd[:, 1] = 511 - lmd[:, 1] 139 | lmd = np.load(lmd)[selected_indices] 140 | for i in range(lmd.shape[0]): 141 | p = lmd[i] 142 | x, y = round(float(p[0])), round(float(p[1])) 143 | print(p) 144 | cv2.circle(img, (x, y), 2, (0, 0, 255), -1) 145 | 146 | cv2.imshow('win', img) 147 | cv2.waitKey(0) -------------------------------------------------------------------------------- /external/landmark_detection/lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import get_encoder, get_decoder 2 | from .dataset import AlignmentDataset, Augmentation 3 | from .backbone import StackedHGNetV1 4 | from .metric import NME, Accuracy 5 | from .utils import time_print, time_string, time_for_file, time_string_short 6 | from .utils import convert_secs2time, convert_size2str 7 | 8 | from .utility import get_dataloader, get_config, get_net, get_criterions 9 | from .utility import get_optimizer, get_scheduler 10 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .stackedHGNetV1 import StackedHGNetV1 2 | 3 | __all__ = [ 4 | "StackedHGNetV1", 5 | ] -------------------------------------------------------------------------------- /external/landmark_detection/lib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import get_encoder 2 | from .decoder import get_decoder 3 | from .augmentation import Augmentation 4 | from .alignmentDataset import AlignmentDataset 5 | 6 | __all__ = [ 7 | 
"Augmentation", 8 | "AlignmentDataset", 9 | "get_encoder", 10 | "get_decoder" 11 | ] 12 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/dataset/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder_default import decoder_default 2 | 3 | def get_decoder(decoder_type='default'): 4 | if decoder_type == 'default': 5 | decoder = decoder_default() 6 | else: 7 | raise NotImplementedError 8 | return decoder -------------------------------------------------------------------------------- /external/landmark_detection/lib/dataset/decoder/decoder_default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class decoder_default: 5 | def __init__(self, weight=1, use_weight_map=False): 6 | self.weight = weight 7 | self.use_weight_map = use_weight_map 8 | 9 | def _make_grid(self, h, w): 10 | yy, xx = torch.meshgrid( 11 | torch.arange(h).float() / (h - 1) * 2 - 1, 12 | torch.arange(w).float() / (w - 1) * 2 - 1) 13 | return yy, xx 14 | 15 | def get_coords_from_heatmap(self, heatmap): 16 | """ 17 | inputs: 18 | - heatmap: batch x npoints x h x w 19 | 20 | outputs: 21 | - coords: batch x npoints x 2 (x,y), [-1, +1] 22 | - radius_sq: batch x npoints 23 | """ 24 | batch, npoints, h, w = heatmap.shape 25 | if self.use_weight_map: 26 | heatmap = heatmap * self.weight 27 | 28 | yy, xx = self._make_grid(h, w) 29 | yy = yy.view(1, 1, h, w).to(heatmap) 30 | xx = xx.view(1, 1, h, w).to(heatmap) 31 | 32 | heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6) 33 | 34 | yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum # batch x npoints 35 | xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum # batch x npoints 36 | coords = torch.stack([xx_coord, yy_coord], dim=-1) 37 | 38 | return coords 39 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/dataset/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder_default import encoder_default 2 | 3 | def get_encoder(image_height, image_width, scale=0.25, sigma=1.5, encoder_type='default'): 4 | if encoder_type == 'default': 5 | encoder = encoder_default(image_height, image_width, scale, sigma) 6 | else: 7 | raise NotImplementedError 8 | return encoder 9 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/dataset/encoder/encoder_default.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | class encoder_default: 9 | def __init__(self, image_height, image_width, scale=0.25, sigma=1.5): 10 | self.image_height = image_height 11 | self.image_width = image_width 12 | self.scale = scale 13 | self.sigma = sigma 14 | 15 | def generate_heatmap(self, points): 16 | # points = (num_pts, 2) 17 | h, w = self.image_height, self.image_width 18 | pointmaps = [] 19 | for i in range(len(points)): 20 | pointmap = np.zeros([h, w], dtype=np.float32) 21 | # align_corners: False. 
22 | point = copy.deepcopy(points[i]) 23 | point[0] = max(0, min(w - 1, point[0])) 24 | point[1] = max(0, min(h - 1, point[1])) 25 | pointmap = self._circle(pointmap, point, sigma=self.sigma) 26 | 27 | pointmaps.append(pointmap) 28 | pointmaps = np.stack(pointmaps, axis=0) / 255.0 29 | pointmaps = torch.from_numpy(pointmaps).float().unsqueeze(0) 30 | pointmaps = F.interpolate(pointmaps, size=(int(w * self.scale), int(h * self.scale)), mode='bilinear', 31 | align_corners=False).squeeze() 32 | return pointmaps 33 | 34 | def _circle(self, img, pt, sigma=1.0, label_type='Gaussian'): 35 | # Check that any part of the gaussian is in-bounds 36 | tmp_size = sigma * 3 37 | ul = [int(pt[0] - tmp_size), int(pt[1] - tmp_size)] 38 | br = [int(pt[0] + tmp_size + 1), int(pt[1] + tmp_size + 1)] 39 | if (ul[0] > img.shape[1] - 1 or ul[1] > img.shape[0] - 1 or 40 | br[0] - 1 < 0 or br[1] - 1 < 0): 41 | # If not, just return the image as is 42 | return img 43 | 44 | # Generate gaussian 45 | size = 2 * tmp_size + 1 46 | x = np.arange(0, size, 1, np.float32) 47 | y = x[:, np.newaxis] 48 | x0 = y0 = size // 2 49 | # The gaussian is not normalized, we want the center value to equal 1 50 | if label_type == 'Gaussian': 51 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 52 | else: 53 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) 54 | 55 | # Usable gaussian range 56 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] 57 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 58 | # Image range 59 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 60 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 61 | 62 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = 255 * g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 63 | return img 64 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .awingLoss import AWingLoss 2 | from .smoothL1Loss import SmoothL1Loss 3 | from .wingLoss import WingLoss 4 | from .starLoss import STARLoss 5 | from .starLoss_v2 import STARLoss_v2 6 | 7 | __all__ = [ 8 | "AWingLoss", 9 | "SmoothL1Loss", 10 | "WingLoss", 11 | "STARLoss", 12 | 13 | "STARLoss_v2", 14 | ] 15 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/loss/awingLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AWingLoss(nn.Module): 7 | def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1, use_weight_map=True): 8 | super(AWingLoss, self).__init__() 9 | self.omega = omega 10 | self.theta = theta 11 | self.epsilon = epsilon 12 | self.alpha = alpha 13 | self.use_weight_map = use_weight_map 14 | 15 | def __repr__(self): 16 | return "AWingLoss()" 17 | 18 | def generate_weight_map(self, heatmap, k_size=3, w=10): 19 | dilate = F.max_pool2d(heatmap, kernel_size=k_size, stride=1, padding=1) 20 | weight_map = torch.where(dilate < 0.2, torch.zeros_like(heatmap), torch.ones_like(heatmap)) 21 | return w * weight_map + 1 22 | 23 | def forward(self, output, groundtruth): 24 | """ 25 | input: b x n x h x w 26 | output: b x n x h x w => 1 27 | """ 28 | delta = (output - groundtruth).abs() 29 | A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth))) * (self.alpha - groundtruth) * \ 30 | (torch.pow(self.theta / 
self.epsilon, self.alpha - groundtruth - 1)) * (1 / self.epsilon) 31 | C = self.theta * A - self.omega * \ 32 | torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth)) 33 | loss = torch.where(delta < self.theta, 34 | self.omega * torch.log(1 + torch.pow(delta / self.epsilon, self.alpha - groundtruth)), 35 | (A * delta - C)) 36 | if self.use_weight_map: 37 | weight = self.generate_weight_map(groundtruth) 38 | loss = loss * weight 39 | return loss.mean() 40 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/loss/smoothL1Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SmoothL1Loss(nn.Module): 6 | def __init__(self, scale=0.01): 7 | super(SmoothL1Loss, self).__init__() 8 | self.scale = scale 9 | self.EPSILON = 1e-10 10 | 11 | def __repr__(self): 12 | return "SmoothL1Loss()" 13 | 14 | def forward(self, output: torch.Tensor, groundtruth: torch.Tensor, reduction='mean'): 15 | """ 16 | input: b x n x 2 17 | output: b x n x 1 => 1 18 | """ 19 | if output.dim() == 4: 20 | shape = output.shape 21 | groundtruth = groundtruth.reshape(shape[0], shape[1], 1, shape[3]) 22 | 23 | delta_2 = (output - groundtruth).pow(2).sum(dim=-1, keepdim=False) 24 | delta = delta_2.clamp(min=1e-6).sqrt() 25 | # delta = torch.sqrt(delta_2 + self.EPSILON) 26 | loss = torch.where( \ 27 | delta_2 < self.scale * self.scale, \ 28 | 0.5 / self.scale * delta_2, \ 29 | delta - 0.5 * self.scale) 30 | 31 | if reduction == 'mean': 32 | loss = loss.mean() 33 | elif reduction == 'sum': 34 | loss = loss.sum() 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/loss/wingLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | from torch import nn 6 | 7 | 8 | # torch.log and math.log is e based 9 | class WingLoss(nn.Module): 10 | def __init__(self, omega=0.01, epsilon=2): 11 | super(WingLoss, self).__init__() 12 | self.omega = omega 13 | self.epsilon = epsilon 14 | 15 | def forward(self, pred, target): 16 | y = target 17 | y_hat = pred 18 | delta_2 = (y - y_hat).pow(2).sum(dim=-1, keepdim=False) 19 | # delta = delta_2.sqrt() 20 | delta = delta_2.clamp(min=1e-6).sqrt() 21 | C = self.omega - self.omega * math.log(1 + self.omega / self.epsilon) 22 | loss = torch.where( 23 | delta < self.omega, 24 | self.omega * torch.log(1 + delta / self.epsilon), 25 | delta - C 26 | ) 27 | return loss.mean() 28 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .nme import NME 2 | from .accuracy import Accuracy 3 | from .fr_and_auc import FR_AUC 4 | from .params import count_parameters_in_MB 5 | 6 | __all__ = [ 7 | "NME", 8 | "Accuracy", 9 | "FR_AUC", 10 | 'count_parameters_in_MB', 11 | ] 12 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/metric/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class Accuracy: 5 | def __init__(self): 6 | pass 7 | 8 | def __repr__(self): 9 | return "Accuracy()" 10 | 11 | def test(self, label_pd, label_gt, 
ignore_label=-1): 12 | correct_cnt = 0 13 | total_cnt = 0 14 | with torch.no_grad(): 15 | label_pd = F.softmax(label_pd, dim=1) 16 | label_pd = torch.max(label_pd, 1)[1] 17 | label_gt = label_gt.long() 18 | c = (label_pd == label_gt) 19 | correct_cnt = torch.sum(c).item() 20 | total_cnt = c.size(0) - torch.sum(label_gt==ignore_label).item() 21 | return correct_cnt, total_cnt 22 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/metric/fr_and_auc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.integrate import simps 3 | 4 | 5 | class FR_AUC: 6 | def __init__(self, data_definition): 7 | self.data_definition = data_definition 8 | if data_definition == '300W': 9 | self.thresh = 0.05 10 | else: 11 | self.thresh = 0.1 12 | 13 | def __repr__(self): 14 | return "FR_AUC()" 15 | 16 | def test(self, nmes, thres=None, step=0.0001): 17 | if thres is None: 18 | thres = self.thresh 19 | 20 | num_data = len(nmes) 21 | xs = np.arange(0, thres + step, step) 22 | ys = np.array([np.count_nonzero(nmes <= x) for x in xs]) / float(num_data) 23 | fr = 1.0 - ys[-1] 24 | auc = simps(ys, x=xs) / thres 25 | return [round(fr, 4), round(auc, 6)] 26 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/metric/nme.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class NME: 5 | def __init__(self, nme_left_index, nme_right_index): 6 | self.nme_left_index = nme_left_index 7 | self.nme_right_index = nme_right_index 8 | 9 | def __repr__(self): 10 | return "NME()" 11 | 12 | def get_norm_distance(self, landmarks): 13 | assert isinstance(self.nme_right_index, list), 'the nme_right_index is not list.' 14 | assert isinstance(self.nme_left_index, list), 'the nme_left, index is not list.' 
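# normalization distance = distance between the mean of the right-index landmarks and the mean
# of the left-index landmarks (the inter-pupil/inter-ocular distance when these are eye indices)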
15 | right_pupil = landmarks[self.nme_right_index, :].mean(0) 16 | left_pupil = landmarks[self.nme_left_index, :].mean(0) 17 | norm_distance = np.linalg.norm(right_pupil - left_pupil) 18 | return norm_distance 19 | 20 | def test(self, label_pd, label_gt): 21 | nme_list = [] 22 | label_pd = label_pd.data.cpu().numpy() 23 | label_gt = label_gt.data.cpu().numpy() 24 | 25 | for i in range(label_gt.shape[0]): 26 | landmarks_gt = label_gt[i] 27 | landmarks_pv = label_pd[i] 28 | if isinstance(self.nme_right_index, list): 29 | norm_distance = self.get_norm_distance(landmarks_gt) 30 | elif isinstance(self.nme_right_index, int): 31 | norm_distance = np.linalg.norm(landmarks_gt[self.nme_left_index] - landmarks_gt[self.nme_right_index]) 32 | else: 33 | raise NotImplementedError 34 | landmarks_delta = landmarks_pv - landmarks_gt 35 | nme = (np.linalg.norm(landmarks_delta, axis=1) / norm_distance).mean() 36 | nme_list.append(nme) 37 | # sum_nme += nme 38 | # total_cnt += 1 39 | return nme_list 40 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/metric/params.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def count_parameters_in_MB(model): 4 | if isinstance(model, nn.Module): 5 | return sum(v.numel() for v in model.parameters()) / 1e6 6 | else: 7 | return sum(v.numel() for v in model) / 1e6 -------------------------------------------------------------------------------- /external/landmark_detection/lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .meter import AverageMeter 2 | from .time_utils import time_print, time_string, time_string_short, time_for_file 3 | from .time_utils import convert_secs2time, convert_size2str 4 | from .vis_utils import plot_points 5 | 6 | __all__ = [ 7 | "AverageMeter", 8 | "time_print", 9 | "time_string", 10 | "time_string_short", 11 | "time_for_file", 12 | "convert_size2str", 13 | "convert_secs2time", 14 | 15 | "plot_points", 16 | ] 17 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0.0 9 | self.avg = 0.0 10 | self.sum = 0.0 11 | self.count = 0.0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | def __repr__(self): 20 | return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__)) -------------------------------------------------------------------------------- /external/landmark_detection/lib/utils/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import time, sys 8 | import numpy as np 9 | 10 | 11 | def time_for_file(): 12 | ISOTIMEFORMAT = '%d-%h-at-%H-%M-%S' 13 | return '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) 14 | 15 | 16 | def time_string(): 17 | ISOTIMEFORMAT = '%Y-%m-%d %X' 18 | string = '[{}]'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) 19 | return string 20 | 21 | 22 | def time_string_short(): 23 | ISOTIMEFORMAT = '%Y%m%d' 24 | string = '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) 25 | return string 26 | 27 | 28 | def time_print(string, is_print=True): 29 | if (is_print): 30 | print('{} : {}'.format(time_string(), string)) 31 | 32 | 33 | def convert_size2str(torch_size): 34 | dims = len(torch_size) 35 | string = '[' 36 | for idim in range(dims): 37 | string = string + ' {}'.format(torch_size[idim]) 38 | return string + ']' 39 | 40 | 41 | def convert_secs2time(epoch_time, return_str=False): 42 | need_hour = int(epoch_time / 3600) 43 | need_mins = int((epoch_time - 3600 * need_hour) / 60) 44 | need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) 45 | if return_str: 46 | str = '[Time Left: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) 47 | return str 48 | else: 49 | return need_hour, need_mins, need_secs 50 | -------------------------------------------------------------------------------- /external/landmark_detection/lib/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import numbers 4 | 5 | 6 | def plot_points(vis, points, radius=1, color=(255, 255, 0), shift=4, indexes=0, is_index=False): 7 | if isinstance(points, list): 8 | num_point = len(points) 9 | elif isinstance(points, np.ndarray): 10 | num_point = points.shape[0] 11 | else: 12 | raise NotImplementedError 13 | if isinstance(radius, numbers.Number): 14 | radius = np.zeros((num_point)) + radius 15 | 16 | if isinstance(indexes, numbers.Number): 17 | indexes = [indexes + i for i in range(num_point)] 18 | elif isinstance(indexes, list): 19 | pass 20 | else: 21 | raise NotImplementedError 22 | 23 | factor = (1 << shift) 24 | for (index, p, s) in zip(indexes, points, radius): 25 | cv2.circle(vis, (int(p[0] * factor + 0.5), int(p[1] * factor + 0.5)), 26 | int(s * factor), color, 1, cv2.LINE_AA, shift=shift) 27 | if is_index: 28 | vis = cv2.putText(vis, str(index), (int(p[0]), int(p[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.2, 29 | (255, 255, 255), 1) 30 | 31 | return vis 32 | -------------------------------------------------------------------------------- /external/landmark_detection/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | torch==1.6.0 3 | torchvision==0.7.0 4 | python-gflags==3.1.2 5 | pandas==0.24.2 6 | pillow==6.0.0 7 | numpy==1.16.4 8 | opencv-python==4.1.0.25 9 | imageio==2.5.0 10 | imgaug==0.2.9 11 | lmdb==0.98 12 | lxml==4.5.0 13 | tensorboard==2.4.1 14 | protobuf==3.20 15 | tensorboardX==1.8 16 | # pyarrow==0.17.1 17 | # wandb==0.10.25 18 | # https://pytorch.org/get-started/previous-versions/ 19 | # pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 20 | -------------------------------------------------------------------------------- /external/landmark_detection/tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from lib import utility 4 | 5 | 6 | def test(args): 7 | # conf 8 |
config = utility.get_config(args) 9 | config.device_id = args.device_ids[0] 10 | 11 | # set environment 12 | utility.set_environment(config) 13 | config.init_instance() 14 | if config.logger is not None: 15 | config.logger.info("Loaded configure file %s: %s" % (args.config_name, config.id)) 16 | config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()])) 17 | 18 | # model 19 | net = utility.get_net(config) 20 | model_path = os.path.join(config.model_dir, 21 | "train.pkl") if args.pretrained_weight is None else args.pretrained_weight 22 | if args.device_ids == [-1]: 23 | checkpoint = torch.load(model_path, map_location="cpu") 24 | else: 25 | checkpoint = torch.load(model_path) 26 | 27 | net.load_state_dict(checkpoint["net"]) 28 | 29 | if config.logger is not None: 30 | config.logger.info("Loaded network") 31 | # config.logger.info('Net flops: {} G, params: {} MB'.format(flops/1e9, params/1e6)) 32 | 33 | # data - test 34 | test_loader = utility.get_dataloader(config, "test") 35 | 36 | if config.logger is not None: 37 | config.logger.info("Loaded data from {:}".format(config.test_tsv_file)) 38 | 39 | # inference 40 | result, metrics = utility.forward(config, test_loader, net) 41 | if config.logger is not None: 42 | config.logger.info("Finished inference") 43 | 44 | # output 45 | for k, metric in enumerate(metrics): 46 | if config.logger is not None and len(metric) != 0: 47 | config.logger.info( 48 | "Tested {} dataset, the Size is {}, Metric: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format( 49 | config.type, len(test_loader.dataset), metric[0], metric[1], metric[2])) 50 | -------------------------------------------------------------------------------- /external/landmark_detection/tools/infinite_loop.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | while True: 4 | time.sleep(1) 5 | -------------------------------------------------------------------------------- /external/landmark_detection/tools/infinite_loop_gpu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import time 5 | import torch 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(description='inf') 9 | parser.add_argument('--gpu', default='1', type=str, help='index of gpu to use') 10 | args = parser.parse_args() 11 | 12 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 13 | 14 | n = 1000 15 | 16 | x = torch.zeros(4, n, n).cuda() 17 | rest_time = 0.0000000000001 18 | while True: 19 | y = x * x 20 | time.sleep(rest_time) 21 | y1 = x * x 22 | -------------------------------------------------------------------------------- /external/landmark_detection/tools/split_wflw.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path as osp 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | tsv_file = '/apdcephfs/share_1134483/charlinzhou/datas/ADNet/WFLW/test.tsv' 8 | save_folder = '/apdcephfs/share_1134483/charlinzhou/datas/ADNet/_WFLW/' 9 | 10 | save_tags = ['largepose', 'expression', 'illumination', 'makeup', 'occlusion', 'blur'] 11 | save_tags = ['test_{}_metadata.tsv'.format(t) for t in save_tags] 12 | save_files = [osp.join(save_folder, t) for t in save_tags] 13 | save_files = [open(f, 'w', newline='') for f in save_files] 14 | 15 | landmark_num = 98 16 | items = pd.read_csv(tsv_file, sep="\t") 17 | 18 | items_num = len(items) 19 | for index in tqdm(range(items_num)): 20 
| image_path = items.iloc[index, 0] 21 | landmarks_5pts = items.iloc[index, 1] 22 | # landmarks_5pts = np.array(list(map(float, landmarks_5pts.split(","))), dtype=np.float32).reshape(5, 2) 23 | landmarks_target = items.iloc[index, 2] 24 | # landmarks_target = np.array(list(map(float, landmarks_target.split(","))), dtype=np.float32).reshape(landmark_num, 2) 25 | scale = items.iloc[index, 3] 26 | center_w, center_h = items.iloc[index, 4], items.iloc[index, 5] 27 | if len(items.iloc[index]) > 6: 28 | tags = np.array(list(map(lambda x: int(float(x)), items.iloc[index, 6].split(",")))) 29 | else: 30 | tags = np.array([]) 31 | assert len(tags) == 6, '{} v.s. 6'.format(len(tags)) 32 | for k, tag in enumerate(tags): 33 | if tag == 1: 34 | save_file = save_files[k] 35 | tsv_w = csv.writer(save_file, delimiter='\t') 36 | tsv_w.writerow([image_path, landmarks_5pts, landmarks_target, scale, center_w, center_h]) 37 | 38 | print('Done!') 39 | -------------------------------------------------------------------------------- /external/landmark_detection/tools/testtime_pca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | def get_channel_sum(input): 7 | temp = torch.sum(input, dim=3) 8 | output = torch.sum(temp, dim=2) 9 | return output 10 | 11 | 12 | def expand_two_dimensions_at_end(input, dim1, dim2): 13 | input = input.unsqueeze(-1).unsqueeze(-1) 14 | input = input.expand(-1, -1, dim1, dim2) 15 | return input 16 | 17 | 18 | class TestTimePCA(nn.Module): 19 | def __init__(self): 20 | super(TestTimePCA, self).__init__() 21 | 22 | def _make_grid(self, h, w): 23 | yy, xx = torch.meshgrid( 24 | torch.arange(h).float() / (h - 1) * 2 - 1, 25 | torch.arange(w).float() / (w - 1) * 2 - 1) 26 | return yy, xx 27 | 28 | def weighted_mean(self, heatmap): 29 | batch, npoints, h, w = heatmap.shape 30 | 31 | yy, xx = self._make_grid(h, w) 32 | yy = yy.view(1, 1, h, w).to(heatmap) 33 | xx = xx.view(1, 1, h, w).to(heatmap) 34 | 35 | yy_coord = (yy * heatmap).sum([2, 3]) # batch x npoints 36 | xx_coord = (xx * heatmap).sum([2, 3]) # batch x npoints 37 | coords = torch.stack([xx_coord, yy_coord], dim=-1) 38 | return coords 39 | 40 | def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5): 41 | batch_size, num_points, height, width = htp.shape 42 | 43 | yv, xv = self._make_grid(height, width) 44 | xv = Variable(xv) 45 | yv = Variable(yv) 46 | 47 | if htp.is_cuda: 48 | xv = xv.cuda() 49 | yv = yv.cuda() 50 | 51 | xmean = means[:, :, 0] 52 | xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height, 53 | width) # [batch_size, 68, 64, 64] 54 | ymean = means[:, :, 1] 55 | yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height, 56 | width) # [batch_size, 68, 64, 64] 57 | wt_xv_minus_mean = xv_minus_mean 58 | wt_yv_minus_mean = yv_minus_mean 59 | 60 | wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096] 61 | wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096] 62 | wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096] 63 | wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096] 64 | vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1) # [batch_size*68, 2, 4096] 
65 | 66 | htp_vec = htp.view(batch_size * num_points, 1, height * width) 67 | htp_vec = htp_vec.expand(-1, 2, -1) 68 | 69 | covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2)) # [batch_size*68, 2, 2] 70 | covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image) # [batch_size, 68, 2, 2] 71 | 72 | V_1 = htp.sum([2, 3]) + EPSILON # [batch_size, 68] 73 | V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON # [batch_size, 68] 74 | 75 | denominator = V_1 - (V_2 / V_1) 76 | covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image) 77 | 78 | return covariance 79 | 80 | def forward(self, heatmap, groudtruth): 81 | 82 | batch, npoints, h, w = heatmap.shape 83 | 84 | heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6) 85 | heatmap = heatmap / heatmap_sum.view(batch, npoints, 1, 1) 86 | 87 | # means [batch_size, 68, 2] 88 | means = self.weighted_mean(heatmap) 89 | 90 | # covars [batch_size, 68, 2, 2] 91 | covars = self.unbiased_weighted_covariance(heatmap, means) 92 | 93 | # eigenvalues [batch_size * 68, 2] , eigenvectors [batch_size * 68, 2, 2] 94 | covars = covars.view(batch * npoints, 2, 2).cpu() 95 | evalues, evectors = covars.symeig(eigenvectors=True) 96 | evalues = evalues.view(batch, npoints, 2) 97 | evectors = evectors.view(batch, npoints, 2, 2) 98 | means = means.cpu() 99 | 100 | results = [dict() for _ in range(batch)] 101 | for i in range(batch): 102 | results[i]['pred'] = means[i].numpy().tolist() 103 | results[i]['gt'] = groudtruth[i].cpu().numpy().tolist() 104 | results[i]['evalues'] = evalues[i].numpy().tolist() 105 | results[i]['evectors'] = evectors[i].numpy().tolist() 106 | 107 | return results 108 | -------------------------------------------------------------------------------- /external/vgghead_detector/VGGDetector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Xuangeng Chu (xg.chu@outlook.com) 3 | # Modified based on code from Orest Kupyn (University of Oxford). 
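A short sketch of running the TestTimePCA module above on dummy heatmaps. Note that Tensor.symeig only exists in older PyTorch (such as the torch==1.6.0 pinned in this module's requirements); on PyTorch 2.x that call would need to be ported to torch.linalg.eigh.

import torch
from tools.testtime_pca import TestTimePCA  # assuming external/landmark_detection is on PYTHONPATH

batch, npoints, h, w = 2, 68, 64, 64
heatmaps = torch.rand(batch, npoints, h, w)           # raw per-landmark heatmaps
gt_landmarks = torch.rand(batch, npoints, 2) * 2 - 1  # same [-1, 1] grid as the soft-argmax

results = TestTimePCA()(heatmaps, gt_landmarks)
# Each entry holds the soft-argmax mean, the ground truth and the 2x2 covariance eigendecomposition.
print(results[0].keys())  # dict_keys(['pred', 'gt', 'evalues', 'evectors'])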
4 | 5 | import os 6 | import torch 7 | import numpy as np 8 | import torchvision 9 | 10 | from .utils_vgghead import nms 11 | from .utils_lmks_detector import LmksDetector 12 | 13 | class VGGHeadDetector(torch.nn.Module): 14 | def __init__(self, device, 15 | vggheadmodel_path=None): 16 | super().__init__() 17 | self.image_size = 640 18 | self._device = device 19 | self.vggheadmodel_path = vggheadmodel_path 20 | self._init_models() 21 | 22 | def _init_models(self,): 23 | # vgg_heads_l 24 | self.model = torch.load(self.vggheadmodel_path, map_location='cpu') 25 | self.model.to(self._device).eval() 26 | 27 | @torch.no_grad() 28 | def forward(self, image_tensor, image_key, conf_threshold=0.5): 29 | if not hasattr(self, 'model'): 30 | self._init_models() 31 | image_tensor = image_tensor.to(self._device).float() 32 | image, padding, scale = self._preprocess(image_tensor) 33 | bbox, scores, flame_params = self.model(image) 34 | bbox, vgg_results = self._postprocess(bbox, scores, flame_params, conf_threshold) 35 | 36 | if bbox is None: 37 | print('VGGHeadDetector: No face detected: {}!'.format(image_key)) 38 | return None, None, None 39 | vgg_results['normalize'] = {'padding': padding, 'scale': scale} 40 | 41 | # bbox 42 | bbox = bbox.clip(0, self.image_size) 43 | bbox[[0, 2]] -= padding[0]; bbox[[1, 3]] -= padding[1]; bbox /= scale 44 | bbox = bbox.clip(0, self.image_size / scale) 45 | 46 | return vgg_results, bbox, None 47 | 48 | def _preprocess(self, image): 49 | _, h, w = image.shape 50 | if h > w: 51 | new_h, new_w = self.image_size, int(w * self.image_size / h) 52 | else: 53 | new_h, new_w = int(h * self.image_size / w), self.image_size 54 | scale = self.image_size / max(h, w) 55 | image = torchvision.transforms.functional.resize(image, (new_h, new_w), antialias=True) 56 | pad_w = self.image_size - image.shape[2] 57 | pad_h = self.image_size - image.shape[1] 58 | image = torchvision.transforms.functional.pad(image, (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2), fill=127) 59 | image = image.unsqueeze(0).float() / 255.0 60 | return image, np.array([pad_w // 2, pad_h // 2]), scale 61 | 62 | def _postprocess(self, bbox, scores, flame_params, conf_threshold): 63 | # flame_params = {"shape": 300, "exp": 100, "rotation": 6, "jaw": 3, "translation": 3, "scale": 1} 64 | bbox, scores, flame_params = nms(bbox, scores, flame_params, confidence_threshold=conf_threshold) 65 | if bbox.shape[0] == 0: 66 | return None, None 67 | max_idx = ((bbox[:, 3] - bbox[:, 1]) * (bbox[:, 2] - bbox[:, 0])).argmax().long() 68 | bbox, flame_params = bbox[max_idx], flame_params[max_idx] 69 | if bbox[0] < 5 and bbox[1] < 5 and bbox[2] > 635 and bbox[3] > 635: 70 | return None, None 71 | # flame 72 | posecode = torch.cat([flame_params.new_zeros(3), flame_params[400:403]]) 73 | vgg_results = { 74 | 'rotation_6d': flame_params[403:409], 'translation': flame_params[409:412], 'scale': flame_params[412:], 75 | 'shapecode': flame_params[:300], 'expcode': flame_params[300:400], 'posecode': posecode, 76 | } 77 | return bbox, vgg_results 78 | -------------------------------------------------------------------------------- /external/vgghead_detector/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Xuangeng Chu (xg.chu@outlook.com) 3 | 4 | from .VGGDetector import VGGHeadDetector 5 | from .utils_vgghead import reproject_vertices 6 | -------------------------------------------------------------------------------- 
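A minimal sketch of calling the VGGHeadDetector above on one RGB image, assuming the repository root is on PYTHONPATH; the checkpoint path below is a placeholder for wherever the vgg_heads model has been downloaded.

import torch
from external.vgghead_detector import VGGHeadDetector

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detector = VGGHeadDetector(device, vggheadmodel_path='./model_zoo/vgghead/vgg_heads_l.pt')  # placeholder path

image = torch.randint(0, 256, (3, 512, 384), dtype=torch.uint8)  # CHW, 0-255
vgg_results, bbox, _ = detector(image, image_key='demo_frame')
if bbox is not None:
    print('bbox (x1, y1, x2, y2):', bbox.tolist())
    print('FLAME shape code:', vgg_results['shapecode'].shape)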
/external/vgghead_detector/utils_vgghead.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Xuangeng Chu (xg.chu@outlook.com) 3 | # Modified based on code from Orest Kupyn (University of Oxford). 4 | 5 | import torch 6 | import torchvision 7 | 8 | def reproject_vertices(flame_model, vgg_results): 9 | # flame_model = FLAMEModel(n_shape=300, n_exp=100, scale=1.0) 10 | vertices, _ = flame_model( 11 | shape_params=vgg_results['shapecode'], 12 | expression_params=vgg_results['expcode'], 13 | pose_params=vgg_results['posecode'], 14 | verts_sclae=1.0 15 | ) 16 | vertices[:, :, 2] += 0.05 # MESH_OFFSET_Z 17 | vgg_landmarks3d = flame_model._vertices2landmarks(vertices) 18 | vgg_transform_results = vgg_results['transform'] 19 | rotation_mat = rot_mat_from_6dof(vgg_transform_results['rotation_6d']).type(vertices.dtype) 20 | translation = vgg_transform_results['translation'][:, None, :] 21 | scale = torch.clamp(vgg_transform_results['scale'][:, None], 1e-8) 22 | rot_vertices = vertices.clone() 23 | rot_vertices = torch.matmul(rotation_mat.unsqueeze(1), rot_vertices.unsqueeze(-1))[..., 0] 24 | vgg_landmarks3d = torch.matmul(rotation_mat.unsqueeze(1), vgg_landmarks3d.unsqueeze(-1))[..., 0] 25 | proj_vertices = (rot_vertices * scale) + translation 26 | vgg_landmarks3d = (vgg_landmarks3d * scale) + translation 27 | 28 | trans_padding, trans_scale = vgg_results['normalize']['padding'], vgg_results['normalize']['scale'] 29 | proj_vertices[:, :, 0] -= trans_padding[:, 0, None] 30 | proj_vertices[:, :, 1] -= trans_padding[:, 1, None] 31 | proj_vertices = proj_vertices / trans_scale[:, None, None] 32 | vgg_landmarks3d[:, :, 0] -= trans_padding[:, 0, None] 33 | vgg_landmarks3d[:, :, 1] -= trans_padding[:, 1, None] 34 | vgg_landmarks3d = vgg_landmarks3d / trans_scale[:, None, None] 35 | return proj_vertices.float()[..., :2], vgg_landmarks3d.float()[..., :2] 36 | 37 | 38 | def rot_mat_from_6dof(v: torch.Tensor) -> torch.Tensor: 39 | assert v.shape[-1] == 6 40 | v = v.view(-1, 6) 41 | vx, vy = v[..., :3].clone(), v[..., 3:].clone() 42 | 43 | b1 = torch.nn.functional.normalize(vx, dim=-1) 44 | b3 = torch.nn.functional.normalize(torch.cross(b1, vy, dim=-1), dim=-1) 45 | b2 = -torch.cross(b1, b3, dim=1) 46 | return torch.stack((b1, b2, b3), dim=-1) 47 | 48 | 49 | def nms(boxes_xyxy, scores, flame_params, 50 | confidence_threshold: float = 0.5, iou_threshold: float = 0.5, 51 | top_k: int = 1000, keep_top_k: int = 100 52 | ): 53 | for pred_bboxes_xyxy, pred_bboxes_conf, pred_flame_params in zip( 54 | boxes_xyxy.detach().float(), 55 | scores.detach().float(), 56 | flame_params.detach().float(), 57 | ): 58 | pred_bboxes_conf = pred_bboxes_conf.squeeze(-1) # [Anchors] 59 | conf_mask = pred_bboxes_conf >= confidence_threshold 60 | 61 | pred_bboxes_conf = pred_bboxes_conf[conf_mask] 62 | pred_bboxes_xyxy = pred_bboxes_xyxy[conf_mask] 63 | pred_flame_params = pred_flame_params[conf_mask] 64 | 65 | # Filter all predictions by self.nms_top_k 66 | if pred_bboxes_conf.size(0) > top_k: 67 | topk_candidates = torch.topk(pred_bboxes_conf, k=top_k, largest=True, sorted=True) 68 | pred_bboxes_conf = pred_bboxes_conf[topk_candidates.indices] 69 | pred_bboxes_xyxy = pred_bboxes_xyxy[topk_candidates.indices] 70 | pred_flame_params = pred_flame_params[topk_candidates.indices] 71 | 72 | # NMS 73 | idx_to_keep = torchvision.ops.boxes.nms(boxes=pred_bboxes_xyxy, scores=pred_bboxes_conf, iou_threshold=iou_threshold) 74 | 75 | final_bboxes = 
pred_bboxes_xyxy[idx_to_keep][: keep_top_k] # [Instances, 4] 76 | final_scores = pred_bboxes_conf[idx_to_keep][: keep_top_k] # [Instances, 1] 77 | final_params = pred_flame_params[idx_to_keep][: keep_top_k] # [Instances, Flame Params] 78 | return final_bboxes, final_scores, final_params 79 | -------------------------------------------------------------------------------- /lam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/lam/__init__.py -------------------------------------------------------------------------------- /lam/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .mixer import MixerDataset 17 | -------------------------------------------------------------------------------- /lam/datasets/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
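A small sketch of the batched nms helper in utils_vgghead.py above; the anchor count and FLAME parameter width are assumptions, and only the first batch element is processed before returning.

import torch
from external.vgghead_detector.utils_vgghead import nms

anchors, flame_dim = 1000, 413                       # assumed widths
xy = torch.rand(1, anchors, 2) * 500                 # top-left corners in the 640x640 detector frame
wh = torch.rand(1, anchors, 2) * 100 + 1
boxes = torch.cat([xy, xy + wh], dim=-1)             # valid xyxy boxes
scores = torch.rand(1, anchors, 1)
flame_params = torch.rand(1, anchors, flame_dim)

final_boxes, final_scores, final_params = nms(boxes, scores, flame_params, confidence_threshold=0.5)
print(final_boxes.shape, final_scores.shape, final_params.shape)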
14 | 15 | 16 | from abc import ABC, abstractmethod 17 | import traceback 18 | import json 19 | import numpy as np 20 | import torch 21 | from PIL import Image 22 | from typing import Optional, Union 23 | from megfile import smart_open, smart_path_join, smart_exists 24 | 25 | 26 | class BaseDataset(torch.utils.data.Dataset, ABC): 27 | def __init__(self, root_dirs: str, meta_path: Optional[Union[list, str]]): 28 | super().__init__() 29 | self.root_dirs = root_dirs 30 | self.uids = self._load_uids(meta_path) 31 | 32 | def __len__(self): 33 | return len(self.uids) 34 | 35 | @abstractmethod 36 | def inner_get_item(self, idx): 37 | pass 38 | 39 | def __getitem__(self, idx): 40 | try: 41 | return self.inner_get_item(idx) 42 | except Exception as e: 43 | traceback.print_exc() 44 | print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}") 45 | # raise e 46 | return self.__getitem__((idx + 1) % self.__len__()) 47 | 48 | @staticmethod 49 | def _load_uids(meta_path: Optional[Union[list, str]]): 50 | # meta_path is a json file 51 | if isinstance(meta_path, str): 52 | with open(meta_path, 'r') as f: 53 | uids = json.load(f) 54 | else: 55 | uids_lst = [] 56 | max_total = 0 57 | for pth, weight in meta_path: 58 | with open(pth, 'r') as f: 59 | uids = json.load(f) 60 | max_total = max(len(uids) / weight, max_total) 61 | uids_lst.append([uids, weight, pth]) 62 | merged_uids = [] 63 | for uids, weight, pth in uids_lst: 64 | repeat = 1 65 | if len(uids) < int(weight * max_total): 66 | repeat = int(weight * max_total) // len(uids) 67 | cur_uids = uids * repeat 68 | merged_uids += cur_uids 69 | print("Data Path:", pth, "Repeat:", repeat, "Final Length:", len(cur_uids)) 70 | uids = merged_uids 71 | print("Total UIDs:", len(uids)) 72 | return uids 73 | 74 | @staticmethod 75 | def _load_rgba_image(file_path, bg_color: float = 1.0): 76 | ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled ''' 77 | rgba = np.array(Image.open(smart_open(file_path, 'rb'))) 78 | rgba = torch.from_numpy(rgba).float() / 255.0 79 | rgba = rgba.permute(2, 0, 1).unsqueeze(0) 80 | rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :]) 81 | rgba[:, :3, ...] * rgba[:, 3:, ...] + (1 - rgba[:, 3:, ...]) 82 | return rgb 83 | 84 | @staticmethod 85 | def _locate_datadir(root_dirs, uid, locator: str): 86 | for root_dir in root_dirs: 87 | datadir = smart_path_join(root_dir, uid, locator) 88 | if smart_exists(datadir): 89 | return root_dir 90 | raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}") 91 | -------------------------------------------------------------------------------- /lam/datasets/mixer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
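An illustrative sketch of subclassing the BaseDataset above; the directory layout, file names and returned dict are hypothetical.

from lam.datasets.base import BaseDataset

class ToyHeadDataset(BaseDataset):
    """Toy subclass: loads one background-blended RGB image per uid."""

    def inner_get_item(self, idx):
        uid = self.uids[idx]
        root_dir = self._locate_datadir(self.root_dirs, uid, locator='images')
        image = self._load_rgba_image(f'{root_dir}/{uid}/images/000000.png', bg_color=1.0)
        return {'uid': uid, 'image': image}

# meta_path points at a JSON list of uids, e.g. ["subject_000", "subject_001", ...]
dataset = ToyHeadDataset(root_dirs=['./data/toy_heads'], meta_path='./data/toy_heads/train_uids.json')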
14 | 15 | 16 | import math 17 | from functools import partial 18 | import torch 19 | 20 | __all__ = ['MixerDataset'] 21 | 22 | 23 | class MixerDataset(torch.utils.data.Dataset): 24 | 25 | def __init__(self, 26 | split: str, 27 | subsets: dict, 28 | **dataset_kwargs, 29 | ): 30 | subsets = [e for e in subsets if e["meta_path"][split] is not None] 31 | self.subsets = [ 32 | self._dataset_fn(subset, split)(**dataset_kwargs) 33 | for subset in subsets 34 | ] 35 | self.virtual_lens = [ 36 | math.ceil(subset_config['sample_rate'] * len(subset_obj)) 37 | for subset_config, subset_obj in zip(subsets, self.subsets) 38 | ] 39 | 40 | @staticmethod 41 | def _dataset_fn(subset_config: dict, split: str): 42 | name = subset_config['name'] 43 | 44 | dataset_cls = None 45 | if name == "exavatar": 46 | from .exavatar import ExAvatarDataset 47 | dataset_cls = ExAvatarDataset 48 | elif name == "humman": 49 | from .humman import HuMManDataset 50 | dataset_cls = HuMManDataset 51 | elif name == "humman_ori": 52 | from .humman_ori import HuMManOriDataset 53 | dataset_cls = HuMManOriDataset 54 | elif name == "static_human": 55 | from .static_human import StaticHumanDataset 56 | dataset_cls = StaticHumanDataset 57 | elif name == "singleview_human": 58 | from .singleview_human import SingleViewHumanDataset 59 | dataset_cls = SingleViewHumanDataset 60 | elif name == "singleview_square_human": 61 | from .singleview_square_human import SingleViewSquareHumanDataset 62 | dataset_cls = SingleViewSquareHumanDataset 63 | elif name == "bedlam": 64 | from .bedlam import BedlamDataset 65 | dataset_cls = BedlamDataset 66 | elif name == "dna_human": 67 | from .dna import DNAHumanDataset 68 | dataset_cls = DNAHumanDataset 69 | elif name == "video_human": 70 | from .video_human import VideoHumanDataset 71 | dataset_cls = VideoHumanDataset 72 | elif name == "video_head": 73 | from .video_head import VideoHeadDataset 74 | dataset_cls = VideoHeadDataset 75 | elif name == "video_head_gagtrack": 76 | from .video_head_gagtrack import VideoHeadGagDataset 77 | dataset_cls = VideoHeadGagDataset 78 | elif name == "objaverse": 79 | from .objaverse import ObjaverseDataset 80 | dataset_cls = ObjaverseDataset 81 | # elif name == 'mvimgnet': 82 | # from .mvimgnet import MVImgNetDataset 83 | # dataset_cls = MVImgNetDataset 84 | else: 85 | raise NotImplementedError(f"Dataset {name} not implemented") 86 | print("==="*16*3, "\nUse dataset loader:", name, "\n"+"==="*3*16) 87 | 88 | return partial( 89 | dataset_cls, 90 | root_dirs=subset_config['root_dirs'], 91 | meta_path=subset_config['meta_path'][split], 92 | ) 93 | 94 | def __len__(self): 95 | return sum(self.virtual_lens) 96 | 97 | def __getitem__(self, idx): 98 | subset_idx = 0 99 | virtual_idx = idx 100 | while virtual_idx >= self.virtual_lens[subset_idx]: 101 | virtual_idx -= self.virtual_lens[subset_idx] 102 | subset_idx += 1 103 | real_idx = virtual_idx % len(self.subsets[subset_idx]) 104 | return self.subsets[subset_idx][real_idx] 105 | -------------------------------------------------------------------------------- /lam/launch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
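A sketch of the subset configuration the MixerDataset above expects; the paths are placeholders, and any extra keyword arguments required by the concrete VideoHeadDataset loader are omitted here.

from lam.datasets import MixerDataset

subsets = [
    {
        'name': 'video_head',            # resolved to VideoHeadDataset by _dataset_fn
        'root_dirs': ['./data/video_head'],
        'meta_path': {'train': './data/video_head/train_uids.json',
                      'val': './data/video_head/val_uids.json'},
        'sample_rate': 1.0,              # virtual length multiplier for this subset
    },
]

# dataset_kwargs (e.g. view counts, resolutions) are forwarded to every subset loader.
train_set = MixerDataset(split='train', subsets=subsets)
print(len(train_set))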
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import argparse 17 | 18 | from lam.runners import REGISTRY_RUNNERS 19 | 20 | 21 | def main(): 22 | 23 | parser = argparse.ArgumentParser(description='lam launcher') 24 | parser.add_argument('runner', type=str, help='Runner to launch') 25 | args, unknown = parser.parse_known_args() 26 | 27 | if args.runner not in REGISTRY_RUNNERS: 28 | raise ValueError('Runner {} not found'.format(args.runner)) 29 | 30 | RunnerClass = REGISTRY_RUNNERS[args.runner] 31 | with RunnerClass() as runner: 32 | runner.run() 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /lam/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .pixelwise import * 17 | from .perceptual import * 18 | from .tvloss import * 19 | -------------------------------------------------------------------------------- /lam/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | from einops import rearrange 19 | 20 | __all__ = ['LPIPSLoss'] 21 | 22 | 23 | class LPIPSLoss(nn.Module): 24 | """ 25 | Compute LPIPS loss between two images. 
26 | """ 27 | 28 | def __init__(self, device, prefech: bool = False): 29 | super().__init__() 30 | self.device = device 31 | self.cached_models = {} 32 | if prefech: 33 | self.prefetch_models() 34 | 35 | def _get_model(self, model_name: str): 36 | if model_name not in self.cached_models: 37 | import warnings 38 | with warnings.catch_warnings(): 39 | warnings.filterwarnings('ignore', category=UserWarning) 40 | import lpips 41 | _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device) 42 | _model = torch.compile(_model) 43 | self.cached_models[model_name] = _model 44 | return self.cached_models[model_name] 45 | 46 | def prefetch_models(self): 47 | _model_names = ['alex', 'vgg'] 48 | for model_name in _model_names: 49 | self._get_model(model_name) 50 | 51 | def forward(self, x, y, is_training: bool = True, conf_sigma=None, only_sym_conf=False): 52 | """ 53 | Assume images are 0-1 scaled and channel first. 54 | 55 | Args: 56 | x: [N, M, C, H, W] 57 | y: [N, M, C, H, W] 58 | is_training: whether to use VGG or AlexNet. 59 | 60 | Returns: 61 | Mean-reduced LPIPS loss across batch. 62 | """ 63 | model_name = 'vgg' if is_training else 'alex' 64 | loss_fn = self._get_model(model_name) 65 | EPS = 1e-7 66 | if len(x.shape) == 5: 67 | N, M, C, H, W = x.shape 68 | x = x.reshape(N*M, C, H, W) 69 | y = y.reshape(N*M, C, H, W) 70 | image_loss = loss_fn(x, y, normalize=True) 71 | image_loss = image_loss.mean(dim=[1, 2, 3]) 72 | batch_loss = image_loss.reshape(N, M).mean(dim=1) 73 | all_loss = batch_loss.mean() 74 | else: 75 | image_loss = loss_fn(x, y, normalize=True) 76 | if conf_sigma is not None: 77 | image_loss = image_loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log() 78 | image_loss = image_loss.mean(dim=[1, 2, 3]) 79 | all_loss = image_loss.mean() 80 | return all_loss 81 | -------------------------------------------------------------------------------- /lam/losses/pixelwise.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | from einops import rearrange 19 | 20 | __all__ = ['PixelLoss'] 21 | 22 | 23 | class PixelLoss(nn.Module): 24 | """ 25 | Pixel-wise loss between two images. 26 | """ 27 | 28 | def __init__(self, option: str = 'mse'): 29 | super().__init__() 30 | self.loss_fn = self._build_from_option(option) 31 | 32 | @staticmethod 33 | def _build_from_option(option: str, reduction: str = 'none'): 34 | if option == 'mse': 35 | return nn.MSELoss(reduction=reduction) 36 | elif option == 'l1': 37 | return nn.L1Loss(reduction=reduction) 38 | else: 39 | raise NotImplementedError(f'Unknown pixel loss option: {option}') 40 | 41 | @torch.compile 42 | def forward(self, x, y, conf_sigma=None, only_sym_conf=False): 43 | """ 44 | Assume images are channel first. 
45 | 46 | Args: 47 | x: [N, M, C, H, W] 48 | y: [N, M, C, H, W] 49 | 50 | Returns: 51 | Mean-reduced pixel loss across batch. 52 | """ 53 | N, M, C, H, W = x.shape 54 | x = rearrange(x, "n m c h w -> (n m) c h w") 55 | y = rearrange(y, "n m c h w -> (n m) c h w") 56 | image_loss = self.loss_fn(x, y) 57 | 58 | image_loss = image_loss.mean(dim=[1, 2, 3]) 59 | batch_loss = image_loss.reshape(N, M).mean(dim=1) 60 | all_loss = batch_loss.mean() 61 | return all_loss 62 | -------------------------------------------------------------------------------- /lam/losses/tvloss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | __all__ = ['TVLoss'] 20 | 21 | 22 | class TVLoss(nn.Module): 23 | """ 24 | Total variance loss. 25 | """ 26 | 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def numel_excluding_first_dim(self, x): 31 | return x.numel() // x.shape[0] 32 | 33 | @torch.compile 34 | def forward(self, x): 35 | """ 36 | Assume batched and channel first with inner sizes. 37 | 38 | Args: 39 | x: [N, M, C, H, W] 40 | 41 | Returns: 42 | Mean-reduced TV loss with element-level scaling. 43 | """ 44 | N, M, C, H, W = x.shape 45 | x = x.reshape(N*M, C, H, W) 46 | diff_i = x[..., 1:, :] - x[..., :-1, :] 47 | diff_j = x[..., :, 1:] - x[..., :, :-1] 48 | div_i = self.numel_excluding_first_dim(diff_i) 49 | div_j = self.numel_excluding_first_dim(diff_j) 50 | tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i 51 | tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j 52 | tv = tv_i + tv_j 53 | batch_tv = tv.reshape(N, M).mean(dim=1) 54 | all_tv = batch_tv.mean() 55 | return all_tv 56 | -------------------------------------------------------------------------------- /lam/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
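A short sketch combining the three loss modules above on batched multi-view renders in [N, M, C, H, W] layout; the loss weights are arbitrary examples.

import torch
from lam.losses import PixelLoss, LPIPSLoss, TVLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pred = torch.rand(2, 4, 3, 256, 256, device=device)    # batch of 2, 4 views each, 0-1 RGB
target = torch.rand(2, 4, 3, 256, 256, device=device)

pixel = PixelLoss(option='l1')(pred, target)
lpips = LPIPSLoss(device)(pred, target, is_training=True)  # VGG backbone while training
tv = TVLoss()(pred)

total = pixel + 0.2 * lpips + 5e-4 * tv                 # arbitrary example weights
print(total.item())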
14 | 15 | 16 | from .modeling_lam import ModelLAM 17 | 18 | 19 | model_dict = { 20 | 'lam': ModelLAM, 21 | } 22 | -------------------------------------------------------------------------------- /lam/models/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported from Paella 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.models.modeling_utils import ModelMixin 10 | 11 | import functools 12 | # import torch.nn as nn 13 | from taming.modules.util import ActNorm 14 | 15 | 16 | # Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py 17 | class Discriminator(ModelMixin, ConfigMixin): 18 | @register_to_config 19 | def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6): 20 | super().__init__() 21 | d = max(depth - 3, 3) 22 | layers = [ 23 | nn.utils.spectral_norm( 24 | nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1) 25 | ), 26 | nn.LeakyReLU(0.2), 27 | ] 28 | for i in range(depth - 1): 29 | c_in = hidden_channels // (2 ** max((d - i), 0)) 30 | c_out = hidden_channels // (2 ** max((d - 1 - i), 0)) 31 | layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) 32 | layers.append(nn.InstanceNorm2d(c_out)) 33 | layers.append(nn.LeakyReLU(0.2)) 34 | self.encoder = nn.Sequential(*layers) 35 | self.shuffle = nn.Conv2d( 36 | (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1 37 | ) 38 | # self.logits = nn.Sigmoid() 39 | 40 | 41 | def forward(self, x, cond=None): 42 | x = self.encoder(x) 43 | if cond is not None: 44 | cond = cond.view( 45 | cond.size(0), 46 | cond.size(1), 47 | 1, 48 | 1, 49 | ).expand(-1, -1, x.size(-2), x.size(-1)) 50 | x = torch.cat([x, cond], dim=1) 51 | x = self.shuffle(x) 52 | # x = self.logits(x) 53 | return x 54 | 55 | 56 | 57 | 58 | def weights_init(m): 59 | classname = m.__class__.__name__ 60 | if classname.find('Conv') != -1: 61 | nn.init.normal_(m.weight.data, 0.0, 0.02) 62 | elif classname.find('BatchNorm') != -1: 63 | nn.init.normal_(m.weight.data, 1.0, 0.02) 64 | nn.init.constant_(m.bias.data, 0) 65 | 66 | 67 | class NLayerDiscriminator(nn.Module): 68 | """Defines a PatchGAN discriminator as in Pix2Pix 69 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 70 | """ 71 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 72 | """Construct a PatchGAN discriminator 73 | Parameters: 74 | input_nc (int) -- the number of channels in input images 75 | ndf (int) -- the number of filters in the last conv layer 76 | n_layers (int) -- the number of conv layers in the discriminator 77 | norm_layer -- normalization layer 78 | """ 79 | super(NLayerDiscriminator, self).__init__() 80 | if not use_actnorm: 81 | # norm_layer = nn.BatchNorm2d 82 | norm_layer = nn.InstanceNorm2d 83 | else: 84 | norm_layer = ActNorm 85 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 86 | # use_bias = norm_layer.func != nn.BatchNorm2d 87 | use_bias = norm_layer.func != nn.InstanceNorm2d 88 | else: 89 | # use_bias = norm_layer != nn.BatchNorm2d 90 | use_bias = norm_layer != nn.InstanceNorm2d 91 | 92 | kw = 4 93 | padw = 1 94 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)] 95 
| nf_mult = 1 96 | nf_mult_prev = 1 97 | for n in range(1, n_layers): # gradually increase the number of filters 98 | nf_mult_prev = nf_mult 99 | nf_mult = min(2 ** n, 8) 100 | sequence += [ 101 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 102 | norm_layer(ndf * nf_mult), 103 | nn.LeakyReLU(0.2, False) 104 | ] 105 | 106 | nf_mult_prev = nf_mult 107 | nf_mult = min(2 ** n_layers, 8) 108 | sequence += [ 109 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 110 | norm_layer(ndf * nf_mult), 111 | nn.LeakyReLU(0.2, False) 112 | ] 113 | 114 | sequence += [ 115 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 116 | self.main = nn.Sequential(*sequence) 117 | 118 | def forward(self, input): 119 | """Standard forward.""" 120 | return self.main(input) 121 | -------------------------------------------------------------------------------- /lam/models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Empty 16 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Empty 16 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/hub/depth/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
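A minimal sketch of the PatchGAN NLayerDiscriminator above producing a patch-level real/fake map (this assumes the taming and diffusers dependencies imported by the module are installed).

import torch
from lam.models.discriminator import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
disc.apply(weights_init)                 # normal init for the conv layers, as in Pix2Pix

fake = torch.rand(4, 3, 256, 256)
logits = disc(fake)                      # [4, 1, 30, 30] patch map for 256x256 inputs
print(logits.shape)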
5 | 6 | from .decode_heads import BNHead, DPTHead 7 | from .encoder_decoder import DepthEncoderDecoder 8 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/hub/depth/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import warnings 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ( 18 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) 19 | and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1) 21 | ): 22 | warnings.warn( 23 | f"When align_corners={align_corners}, " 24 | "the output would more aligned if " 25 | f"input size {(input_h, input_w)} is `x+1` and " 26 | f"out size {(output_h, output_w)} is `nx+1`" 27 | ) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/hub/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # ****************************************************************************** 7 | # Code modified by Zexin He in 2023-2024. 
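A quick sketch of the CenterPadding helper above, padding an image so its height and width become multiples of the ViT patch size.

import torch
from lam.models.encoders.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)         # DINOv2 ViT-L/14 patch size
x = torch.rand(1, 3, 250, 333)
y = pad(x)
print(y.shape)                           # torch.Size([1, 3, 252, 336])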
8 | # Modifications are marked with clearly visible comments 9 | # licensed under the Apache License, Version 2.0. 10 | # ****************************************************************************** 11 | 12 | from .dino_head import DINOHead 13 | from .mlp import Mlp 14 | from .patch_embed import PatchEmbed 15 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 16 | # ********** Modified by Zexin He in 2023-2024 ********** 17 | # Avoid using nested tensor for now, deprecating usage of NestedTensorBlock 18 | from .block import Block, BlockWithModulation 19 | # ******************************************************** 20 | from .attention import MemEffAttention 21 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | 
return x 90 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
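A small sketch of the Attention block above on a token sequence; MemEffAttention behaves the same when xFormers is unavailable and no attention bias is passed.

import torch
from lam.models.encoders.dinov2.layers.attention import Attention, MemEffAttention

tokens = torch.rand(2, 257, 1024)        # [batch, tokens, dim], e.g. 256 patch tokens + 1 cls token
attn = Attention(dim=1024, num_heads=16, qkv_bias=True)
out = attn(tokens)                       # output keeps the input shape
print(out.shape)                         # torch.Size([2, 257, 1024])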
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
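A short sketch of how the LayerScale and DropPath modules above are typically combined on a residual branch, as in the DINOv2 transformer blocks; the branch itself is an arbitrary example.

import torch
from torch import nn
from lam.models.encoders.dinov2.layers.drop_path import DropPath
from lam.models.encoders.dinov2.layers.layer_scale import LayerScale

dim = 1024
branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())
ls = LayerScale(dim, init_values=1e-5)
dp = DropPath(drop_prob=0.1)             # stochastic depth on the residual branch

x = torch.rand(2, 257, dim)
x = x + dp(ls(branch(x)))                # scale the branch output, then randomly drop it per sample
print(x.shape)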
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
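# Usage sketch for the Mlp and PatchEmbed layers above -- shape bookkeeping for a
# square input; 224 px, patch size 14 and embed dim 1024 are illustrative
# assumptions (they mirror typical dinov2_vitl14 settings but are not read from a
# config here).
import torch
from lam.models.encoders.dinov2.layers.mlp import Mlp
from lam.models.encoders.dinov2.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=1024)
tokens = embed(torch.randn(2, 3, 224, 224))          # Conv2d proj -> flatten -> [B, HW, C]
assert tokens.shape == (2, (224 // 14) ** 2, 1024)   # 16 x 16 = 256 patch tokens

ffn = Mlp(in_features=1024, hidden_features=4096)    # fc1 -> GELU -> drop -> fc2 -> drop
assert ffn(tokens).shape == tokens.shape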
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /lam/models/encoders/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from . 
import vision_transformer as vits 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def build_model(args, only_teacher=False, img_size=224): 15 | args.arch = args.arch.removesuffix("_memeff") 16 | if "vit" in args.arch: 17 | vit_kwargs = dict( 18 | img_size=img_size, 19 | patch_size=args.patch_size, 20 | init_values=args.layerscale, 21 | ffn_layer=args.ffn_layer, 22 | block_chunks=args.block_chunks, 23 | qkv_bias=args.qkv_bias, 24 | proj_bias=args.proj_bias, 25 | ffn_bias=args.ffn_bias, 26 | num_register_tokens=args.num_register_tokens, 27 | interpolate_offset=args.interpolate_offset, 28 | interpolate_antialias=args.interpolate_antialias, 29 | ) 30 | teacher = vits.__dict__[args.arch](**vit_kwargs) 31 | if only_teacher: 32 | return teacher, teacher.embed_dim 33 | student = vits.__dict__[args.arch]( 34 | **vit_kwargs, 35 | drop_path_rate=args.drop_path_rate, 36 | drop_path_uniform=args.drop_path_uniform, 37 | ) 38 | embed_dim = student.embed_dim 39 | return student, teacher, embed_dim 40 | 41 | 42 | def build_model_from_cfg(cfg, only_teacher=False): 43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 44 | -------------------------------------------------------------------------------- /lam/models/encoders/dpt_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/lam/models/encoders/dpt_util/__init__.py -------------------------------------------------------------------------------- /lam/models/encoders/dpt_util/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 5 | scratch = nn.Module() 6 | 7 | out_shape1 = out_shape 8 | out_shape2 = out_shape 9 | out_shape3 = out_shape 10 | if len(in_shape) >= 4: 11 | out_shape4 = out_shape 12 | 13 | if expand: 14 | out_shape1 = out_shape 15 | out_shape2 = out_shape * 2 16 | out_shape3 = out_shape * 4 17 | if len(in_shape) >= 4: 18 | out_shape4 = out_shape * 8 19 | 20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 23 | if len(in_shape) >= 4: 24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 25 | 26 | return scratch 27 | 28 | 29 | class ResidualConvUnit(nn.Module): 30 | """Residual convolution module. 31 | """ 32 | 33 | def __init__(self, features, activation, bn): 34 | """Init. 35 | 36 | Args: 37 | features (int): number of features 38 | """ 39 | super().__init__() 40 | 41 | self.bn = bn 42 | 43 | self.groups=1 44 | 45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 46 | 47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 48 | 49 | if self.bn == True: 50 | self.bn1 = nn.BatchNorm2d(features) 51 | self.bn2 = nn.BatchNorm2d(features) 52 | 53 | self.activation = activation 54 | 55 | self.skip_add = nn.quantized.FloatFunctional() 56 | 57 | def forward(self, x): 58 | """Forward pass. 
59 | 60 | Args: 61 | x (tensor): input 62 | 63 | Returns: 64 | tensor: output 65 | """ 66 | 67 | out = self.activation(x) 68 | out = self.conv1(out) 69 | if self.bn == True: 70 | out = self.bn1(out) 71 | 72 | out = self.activation(out) 73 | out = self.conv2(out) 74 | if self.bn == True: 75 | out = self.bn2(out) 76 | 77 | if self.groups > 1: 78 | out = self.conv_merge(out) 79 | 80 | return self.skip_add.add(out, x) 81 | 82 | 83 | class FeatureFusionBlock(nn.Module): 84 | """Feature fusion block. 85 | """ 86 | 87 | def __init__( 88 | self, 89 | features, 90 | activation, 91 | deconv=False, 92 | bn=False, 93 | expand=False, 94 | align_corners=True, 95 | size=None, 96 | use_conv1=True 97 | ): 98 | """Init. 99 | 100 | Args: 101 | features (int): number of features 102 | """ 103 | super(FeatureFusionBlock, self).__init__() 104 | 105 | self.deconv = deconv 106 | self.align_corners = align_corners 107 | 108 | self.groups=1 109 | 110 | self.expand = expand 111 | out_features = features 112 | if self.expand == True: 113 | out_features = features // 2 114 | 115 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 116 | 117 | if use_conv1: 118 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 119 | self.skip_add = nn.quantized.FloatFunctional() 120 | 121 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 122 | 123 | 124 | self.size=size 125 | 126 | def forward(self, *xs, size=None, scale_factor=2): 127 | """Forward pass. 128 | 129 | Returns: 130 | tensor: output 131 | """ 132 | output = xs[0] 133 | 134 | if len(xs) == 2: 135 | res = self.resConfUnit1(xs[1]) 136 | output = self.skip_add.add(output, res) 137 | 138 | output = self.resConfUnit2(output) 139 | 140 | if (size is None) and (self.size is None): 141 | modifier = {"scale_factor": scale_factor} 142 | elif size is None: 143 | modifier = {"size": self.size} 144 | else: 145 | modifier = {"size": size} 146 | 147 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 148 | 149 | output = self.out_conv(output) 150 | 151 | return output 152 | -------------------------------------------------------------------------------- /lam/models/modulate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | class ModLN(nn.Module): 21 | """ 22 | Modulation with adaLN. 
23 | 24 | References: 25 | DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101 26 | """ 27 | def __init__(self, inner_dim: int, mod_dim: int, eps: float): 28 | super().__init__() 29 | self.norm = nn.LayerNorm(inner_dim, eps=eps) 30 | self.mlp = nn.Sequential( 31 | nn.SiLU(), 32 | nn.Linear(mod_dim, inner_dim * 2), 33 | ) 34 | 35 | @staticmethod 36 | def modulate(x, shift, scale): 37 | # x: [N, L, D] 38 | # shift, scale: [N, D] 39 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 40 | 41 | def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: 42 | shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D] 43 | return self.modulate(self.norm(x), shift, scale) # [N, L, D] 44 | -------------------------------------------------------------------------------- /lam/models/rendering/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Empty 16 | -------------------------------------------------------------------------------- /lam/models/rendering/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | -------------------------------------------------------------------------------- /lam/models/rendering/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Petr Kellnhofer 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | 25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Left-multiplies MxM @ NxM. Returns NxM. 28 | """ 29 | res = torch.matmul(vectors4, matrix.T) 30 | return res 31 | 32 | 33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Normalize vector lengths. 36 | """ 37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 38 | 39 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 40 | """ 41 | Dot product of two tensors. 42 | """ 43 | return (x * y).sum(-1) 44 | 45 | 46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): 47 | """ 48 | Author: Petr Kellnhofer 49 | Intersects rays with the [-1, 1] NDC volume. 50 | Returns min and max distance of entry. 51 | Returns -1 for no intersection. 52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection 53 | """ 54 | o_shape = rays_o.shape 55 | rays_o = rays_o.detach().reshape(-1, 3) 56 | rays_d = rays_d.detach().reshape(-1, 3) 57 | 58 | 59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] 60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] 61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) 62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) 63 | 64 | # Precompute inverse for stability. 65 | invdir = 1 / rays_d 66 | sign = (invdir < 0).long() 67 | 68 | # Intersect with YZ plane. 69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 71 | 72 | # Intersect with XZ plane. 73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 75 | 76 | # Resolve parallel rays. 77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False 78 | 79 | # Use the shortest intersection. 80 | tmin = torch.max(tmin, tymin) 81 | tmax = torch.min(tmax, tymax) 82 | 83 | # Intersect with XY plane. 84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 86 | 87 | # Resolve parallel rays. 88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False 89 | 90 | # Use the shortest intersection. 91 | tmin = torch.max(tmin, tzmin) 92 | tmax = torch.min(tmax, tzmax) 93 | 94 | # Mark invalid. 95 | tmin[torch.logical_not(is_valid)] = -1 96 | tmax[torch.logical_not(is_valid)] = -2 97 | 98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) 99 | 100 | 101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): 102 | """ 103 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. 104 | Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. 
105 | """ 106 | # create a tensor of 'num' steps from 0 to 1 107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) 108 | 109 | # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings 110 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript 111 | # "cannot statically infer the expected size of a list in this contex", hence the code below 112 | for i in range(start.ndim): 113 | steps = steps.unsqueeze(-1) 114 | 115 | # the output starts at 'start' and increments until 'stop' in each dimension 116 | out = start[None] + steps * (stop - start)[None] 117 | 118 | return out 119 | -------------------------------------------------------------------------------- /lam/models/rendering/utils/point_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os, cv2 6 | import matplotlib.pyplot as plt 7 | import math 8 | 9 | def depths_to_points(view, depthmap): 10 | c2w = (view.world_view_transform.T).inverse() 11 | if hasattr(view, "image_width"): 12 | W, H = view.image_width, view.image_height 13 | else: 14 | W, H = view.width, view.height 15 | ndc2pix = torch.tensor([ 16 | [W / 2, 0, 0, (W) / 2], 17 | [0, H / 2, 0, (H) / 2], 18 | [0, 0, 0, 1]]).float().cuda().T 19 | projection_matrix = c2w.T @ view.full_proj_transform 20 | intrins = (projection_matrix @ ndc2pix)[:3,:3].T 21 | 22 | grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy') 23 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3) 24 | rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T 25 | rays_o = c2w[:3,3] 26 | points = depthmap.reshape(-1, 1) * rays_d + rays_o 27 | return points 28 | 29 | def depth_to_normal(view, depth): 30 | """ 31 | view: view camera 32 | depth: depthmap 33 | """ 34 | points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3) 35 | output = torch.zeros_like(points) 36 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 37 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 38 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 39 | output[1:-1, 1:-1, :] = normal_map 40 | return output -------------------------------------------------------------------------------- /lam/models/rendering/utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /lam/models/rendering/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for 
the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | -------------------------------------------------------------------------------- /lam/models/rendering/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.cuda.amp import custom_bwd, custom_fwd 5 | 6 | from lam.models.rendering.utils.typing import * 7 | 8 | def get_activation(name): 9 | if name is None: 10 | return lambda x: x 11 | name = name.lower() 12 | if name == "none": 13 | return lambda x: x 14 | elif name == "lin2srgb": 15 | return lambda x: torch.where( 16 | x > 0.0031308, 17 | torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, 18 | 12.92 * x, 19 | ).clamp(0.0, 1.0) 20 | elif name == "exp": 21 | return lambda x: torch.exp(x) 22 | elif name == "shifted_exp": 23 | return lambda x: torch.exp(x - 1.0) 24 | elif name == "trunc_exp": 25 | return trunc_exp 26 | elif name == "shifted_trunc_exp": 27 | return lambda x: trunc_exp(x - 1.0) 28 | elif name == "sigmoid": 29 | return lambda x: torch.sigmoid(x) 30 | elif name == "tanh": 31 | return lambda x: torch.tanh(x) 32 | elif name == "shifted_softplus": 33 | return lambda x: F.softplus(x - 1.0) 34 | elif name == "scale_-11_01": 35 | return lambda x: x * 0.5 + 0.5 36 | else: 37 | try: 38 | return getattr(F, name) 39 | except AttributeError: 40 | raise ValueError(f"Unknown activation function: {name}") 41 | 42 | class MLP(nn.Module): 43 | def __init__( 44 | self, 45 | dim_in: int, 46 | dim_out: int, 47 | n_neurons: int, 48 | n_hidden_layers: int, 49 | activation: str = "relu", 50 | output_activation: Optional[str] = None, 51 | bias: bool = True, 52 | ): 53 | super().__init__() 54 | layers = [ 55 | self.make_linear( 56 | dim_in, n_neurons, is_first=True, is_last=False, bias=bias 57 | ), 58 | self.make_activation(activation), 59 | ] 60 | for i in range(n_hidden_layers - 1): 61 | layers += [ 62 | self.make_linear( 63 | n_neurons, n_neurons, is_first=False, is_last=False, bias=bias 64 | ), 65 | self.make_activation(activation), 66 | ] 67 | layers += [ 68 | self.make_linear( 69 | n_neurons, dim_out, is_first=False, is_last=True, bias=bias 70 | ) 71 | ] 72 | self.layers = nn.Sequential(*layers) 73 | self.output_activation = get_activation(output_activation) 74 | 75 | def forward(self, x): 76 | x = self.layers(x) 77 | x = self.output_activation(x) 78 | 
return x 79 | 80 | def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True): 81 | layer = nn.Linear(dim_in, dim_out, bias=bias) 82 | return layer 83 | 84 | def make_activation(self, activation): 85 | if activation == "relu": 86 | return nn.ReLU(inplace=True) 87 | elif activation == "silu": 88 | return nn.SiLU(inplace=True) 89 | else: 90 | raise NotImplementedError 91 | 92 | 93 | class _TruncExp(Function): # pylint: disable=abstract-method 94 | # Implementation from torch-ngp: 95 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 96 | @staticmethod 97 | @custom_fwd(cast_inputs=torch.float32) 98 | def forward(ctx, x): # pylint: disable=arguments-differ 99 | ctx.save_for_backward(x) 100 | return torch.exp(x) 101 | 102 | @staticmethod 103 | @custom_bwd 104 | def backward(ctx, g): # pylint: disable=arguments-differ 105 | x = ctx.saved_tensors[0] 106 | return g * torch.exp(torch.clamp(x, max=15)) 107 | 108 | 109 | trunc_exp = _TruncExp.apply -------------------------------------------------------------------------------- /lam/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from lam.utils.registry import Registry 17 | 18 | REGISTRY_RUNNERS = Registry() 19 | 20 | from .infer import * 21 | -------------------------------------------------------------------------------- /lam/runners/abstract.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | 19 | class Runner(ABC): 20 | """Abstract runner class""" 21 | 22 | def __init__(self): 23 | pass 24 | 25 | @abstractmethod 26 | def run(self): 27 | pass 28 | -------------------------------------------------------------------------------- /lam/runners/infer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
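# Usage sketch for get_activation / MLP / trunc_exp from
# lam/models/rendering/utils/utils.py above (the dims and the attribute layout are
# illustrative assumptions). Note: get_activation's "shifted_softplus" branch and
# its getattr(F, name) fallback refer to torch.nn.functional as F, which that
# module never imports; the named activations used below do not hit those branches.
import torch
from lam.models.rendering.utils.utils import MLP, get_activation, trunc_exp

mlp = MLP(dim_in=1024, dim_out=14, n_neurons=64, n_hidden_layers=2, activation="silu")
feat = torch.randn(4096, 1024)                    # e.g. one feature vector per query point
raw = mlp(feat)                                   # [4096, 14] unconstrained outputs

opacity = get_activation("sigmoid")(raw[:, :1])   # squash to (0, 1)
scales = trunc_exp(raw[:, 1:4])                   # positive values, clamped backward pass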
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .lam import LAMInferrer 16 | -------------------------------------------------------------------------------- /lam/runners/infer/base_inferrer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from abc import abstractmethod 18 | from accelerate import Accelerator 19 | from accelerate.logging import get_logger 20 | 21 | from lam.runners.abstract import Runner 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | class Inferrer(Runner): 28 | 29 | EXP_TYPE: str = None 30 | 31 | def __init__(self): 32 | super().__init__() 33 | 34 | torch._dynamo.config.disable = True 35 | self.accelerator = Accelerator() 36 | 37 | self.model : torch.nn.Module = None 38 | 39 | def __enter__(self): 40 | return self 41 | 42 | def __exit__(self, exc_type, exc_val, exc_tb): 43 | pass 44 | 45 | @property 46 | def device(self): 47 | return self.accelerator.device 48 | 49 | @abstractmethod 50 | def _build_model(self, cfg): 51 | pass 52 | 53 | @abstractmethod 54 | def infer_single(self, *args, **kwargs): 55 | pass 56 | 57 | @abstractmethod 58 | def infer(self): 59 | pass 60 | 61 | def run(self): 62 | self.infer() 63 | -------------------------------------------------------------------------------- /lam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Empty 16 | -------------------------------------------------------------------------------- /lam/utils/compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
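# Usage sketch for the Inferrer base class above -- a minimal subclass wiring the
# abstract hooks together; the class name and the Identity "model" are stand-in
# assumptions, not part of this repo.
import torch
from lam.runners.infer.base_inferrer import Inferrer

class EchoInferrer(Inferrer):
    EXP_TYPE = "echo"

    def _build_model(self, cfg):
        self.model = torch.nn.Identity().to(self.device)

    def infer_single(self, x):
        with torch.no_grad():
            return self.model(x)

    def infer(self):
        self._build_model(cfg=None)
        print(self.infer_single(torch.randn(1, 3)).shape)

with EchoInferrer() as runner:     # __enter__ / __exit__ are provided by Inferrer
    runner.run()                   # Runner.run() dispatches to infer()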
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from accelerate.logging import get_logger 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | def configure_dynamo(config: dict): 23 | try: 24 | import torch._dynamo 25 | logger.debug(f'Configuring torch._dynamo.config with {config}') 26 | for k, v in config.items(): 27 | if v is None: 28 | logger.debug(f'Skipping torch._dynamo.config.{k} with None') 29 | continue 30 | if hasattr(torch._dynamo.config, k): 31 | logger.warning(f'Overriding torch._dynamo.config.{k} from {getattr(torch._dynamo.config, k)} to {v}') 32 | setattr(torch._dynamo.config, k, v) 33 | except ImportError: 34 | logger.debug('torch._dynamo not found, skipping') 35 | pass 36 | -------------------------------------------------------------------------------- /lam/utils/ffmpeg_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import torch 4 | import numpy as np 5 | import imageio 6 | import cv2 7 | import imageio.v3 as iio 8 | 9 | VIDEO_TYPE_LIST = {'.avi','.mp4','.gif','.AVI','.MP4','.GIF'} 10 | 11 | def encodeffmpeg(inputs, frame_rate, output, format="png"): 12 | """output: need video_name""" 13 | assert ( 14 | os.path.splitext(output)[-1] in VIDEO_TYPE_LIST 15 | ), "output is the format of video, e.g., mp4" 16 | assert os.path.isdir(inputs), "input dir is NOT file format" 17 | 18 | inputs = inputs[:-1] if inputs[-1] == "/" else inputs 19 | 20 | output = os.path.abspath(output) 21 | 22 | cmd = ( 23 | f"ffmpeg -r {frame_rate} -pattern_type glob -i '{inputs}/*.{format}' " 24 | + f'-vcodec libx264 -crf 10 -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" ' 25 | + f"-pix_fmt yuv420p {output} > /dev/null 2>&1" 26 | ) 27 | 28 | print(cmd) 29 | 30 | output_dir = os.path.dirname(output) 31 | if os.path.exists(output): 32 | os.remove(output) 33 | os.makedirs(output_dir, exist_ok=True) 34 | 35 | print("encoding imgs to video.....") 36 | os.system(cmd) 37 | print("video done!") 38 | 39 | def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False, bitrate="2M"): 40 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 41 | frames = [] 42 | for i in range(images.shape[0]): 43 | if isinstance(images, torch.Tensor): 44 | frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) 45 | assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ 46 | f"Frame shape mismatch: {frame.shape} vs {images.shape}" 47 | assert frame.min() >= 0 and frame.max() <= 255, \ 48 | f"Frame value out of range: {frame.min()} ~ {frame.max()}" 49 | else: 50 | frame = images[i] 51 | width, height = frame.shape[1], frame.shape[0] 52 | # reshape to limit the export time 53 | # if width > 1200 or height > 1200 or images.shape[0] > 200: 54 | # frames.append(cv2.resize(frame, (width // 2, height // 2))) 55 | # else: 56 | frames.append(frame) 57 | # limit the frames directly @NOTE huggingface only! 
58 | frames = frames[:200] 59 | 60 | frames = np.stack(frames) 61 | 62 | print("start saving {} using imageio.v3 .".format(output_path)) 63 | iio.imwrite(output_path,frames,fps=fps,codec="libx264",pixelformat="yuv420p",bitrate=bitrate, macro_block_size=32) 64 | print("saved {} using imageio.v3 .".format(output_path)) -------------------------------------------------------------------------------- /lam/utils/gen_id_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import sys 4 | import os 5 | 6 | data_root = sys.argv[1] 7 | save_path = sys.argv[2] 8 | 9 | all_hid_list = [] 10 | for hid in os.listdir(data_root): 11 | if hid.startswith("p"): 12 | hid = os.path.join(data_root, hid) 13 | all_hid_list.append(hid.replace(data_root + "/", "")) 14 | 15 | print(f"len:{len(all_hid_list)}") 16 | print(all_hid_list[:3]) 17 | with open(save_path, 'w') as fp: 18 | json.dump(all_hid_list, fp, indent=4) -------------------------------------------------------------------------------- /lam/utils/gen_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import sys 4 | import os 5 | 6 | data_root = sys.argv[1] 7 | save_path = sys.argv[2] 8 | 9 | all_img_list = [] 10 | for hid in os.listdir(data_root): 11 | all_view_imgs_dir = os.path.join(data_root, hid, "kinect_color") 12 | if not os.path.exists(all_view_imgs_dir): 13 | continue 14 | 15 | for view_id in os.listdir(all_view_imgs_dir): 16 | imgs_dir = os.path.join(all_view_imgs_dir, view_id) 17 | for img_path in glob.glob(os.path.join(imgs_dir, "*.png")): 18 | all_img_list.append(img_path.replace(data_root + "/", "")) 19 | 20 | print(f"len:{len(all_img_list)}") 21 | print(all_img_list[:3]) 22 | with open(save_path, 'w') as fp: 23 | json.dump(all_img_list, fp, indent=4) -------------------------------------------------------------------------------- /lam/utils/hf_hub.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch.nn as nn 17 | from huggingface_hub import PyTorchModelHubMixin 18 | 19 | 20 | def wrap_model_hub(model_cls: nn.Module): 21 | class HfModel(model_cls, PyTorchModelHubMixin): 22 | def __init__(self, config: dict): 23 | super().__init__(**config) 24 | self.config = config 25 | return HfModel 26 | -------------------------------------------------------------------------------- /lam/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
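# Usage sketch for wrap_model_hub above -- wrapping a plain nn.Module so it gains
# save_pretrained / from_pretrained; the tiny module, config dict and output path
# are assumptions (scripts/convert_hf.py applies the same pattern to the LAM
# model classes).
import torch.nn as nn
from lam.utils.hf_hub import wrap_model_hub

class TinyModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

HfTiny = wrap_model_hub(TinyModel)
model = HfTiny({"dim": 8})                       # config dict is kept on the instance
model.save_pretrained("./exps/releases/tiny", config=model.config)
reloaded = HfTiny.from_pretrained("./exps/releases/tiny")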
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import logging 18 | from tqdm.auto import tqdm 19 | 20 | 21 | class TqdmStreamHandler(logging.StreamHandler): 22 | def emit(self, record): 23 | tqdm.write(self.format(record)) 24 | 25 | 26 | def configure_logger(stream_level, log_level, file_path = None): 27 | _stream_level = stream_level.upper() 28 | _log_level = log_level.upper() 29 | _project_level = _log_level 30 | 31 | _formatter = logging.Formatter("[%(asctime)s] %(name)s: [%(levelname)s] %(message)s") 32 | 33 | _stream_handler = TqdmStreamHandler() 34 | _stream_handler.setLevel(_stream_level) 35 | _stream_handler.setFormatter(_formatter) 36 | 37 | if file_path is not None: 38 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 39 | _file_handler = logging.FileHandler(file_path) 40 | _file_handler.setLevel(_log_level) 41 | _file_handler.setFormatter(_formatter) 42 | 43 | _project_logger = logging.getLogger(__name__.split('.')[0]) 44 | _project_logger.setLevel(_project_level) 45 | _project_logger.addHandler(_stream_handler) 46 | if file_path is not None: 47 | _project_logger.addHandler(_file_handler) 48 | -------------------------------------------------------------------------------- /lam/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import rembg 18 | import cv2 19 | 20 | 21 | class Preprocessor: 22 | 23 | """ 24 | Preprocessing under cv2 conventions. 
25 | """ 26 | 27 | def __init__(self): 28 | self.rembg_session = rembg.new_session( 29 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 30 | ) 31 | 32 | def preprocess(self, image_path: str, save_path: str, rmbg: bool = True, recenter: bool = True, size: int = 512, border_ratio: float = 0.2): 33 | image = self.step_load_to_size(image_path=image_path, size=size*2) 34 | if rmbg: 35 | image = self.step_rembg(image_in=image) 36 | else: 37 | image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA) 38 | if recenter: 39 | image = self.step_recenter(image_in=image, border_ratio=border_ratio, square_size=size) 40 | else: 41 | image = cv2.resize( 42 | src=image, 43 | dsize=(size, size), 44 | interpolation=cv2.INTER_AREA, 45 | ) 46 | return cv2.imwrite(save_path, image) 47 | 48 | def step_rembg(self, image_in: np.ndarray) -> np.ndarray: 49 | image_out = rembg.remove( 50 | data=image_in, 51 | session=self.rembg_session, 52 | ) 53 | return image_out 54 | 55 | def step_recenter(self, image_in: np.ndarray, border_ratio: float, square_size: int) -> np.ndarray: 56 | assert image_in.shape[-1] == 4, "Image to recenter must be RGBA" 57 | mask = image_in[..., -1] > 0 58 | ijs = np.nonzero(mask) 59 | # find bbox 60 | i_min, i_max = ijs[0].min(), ijs[0].max() 61 | j_min, j_max = ijs[1].min(), ijs[1].max() 62 | bbox_height, bbox_width = i_max - i_min, j_max - j_min 63 | # recenter and resize 64 | desired_size = int(square_size * (1 - border_ratio)) 65 | scale = desired_size / max(bbox_height, bbox_width) 66 | desired_height, desired_width = int(bbox_height * scale), int(bbox_width * scale) 67 | desired_i_min, desired_j_min = (square_size - desired_height) // 2, (square_size - desired_width) // 2 68 | desired_i_max, desired_j_max = desired_i_min + desired_height, desired_j_min + desired_width 69 | # create new image 70 | image_out = np.zeros((square_size, square_size, 4), dtype=np.uint8) 71 | image_out[desired_i_min:desired_i_max, desired_j_min:desired_j_max] = cv2.resize( 72 | src=image_in[i_min:i_max, j_min:j_max], 73 | dsize=(desired_width, desired_height), 74 | interpolation=cv2.INTER_AREA, 75 | ) 76 | return image_out 77 | 78 | def step_load_to_size(self, image_path: str, size: int) -> np.ndarray: 79 | image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) 80 | height, width = image.shape[:2] 81 | scale = size / max(height, width) 82 | height, width = int(height * scale), int(width * scale) 83 | image_out = cv2.resize( 84 | src=image, 85 | dsize=(width, height), 86 | interpolation=cv2.INTER_AREA, 87 | ) 88 | return image_out 89 | -------------------------------------------------------------------------------- /lam/utils/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
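# Usage sketch for the Preprocessor above -- background removal plus recentering of
# a single portrait into a 512x512 RGBA crop; both paths are hypothetical.
from lam.utils.preprocess import Preprocessor

pre = Preprocessor()                       # creates a rembg session (CUDA if available)
pre.preprocess(
    image_path="./assets/sample_input/portrait.jpg",    # hypothetical input image
    save_path="./exps/preprocessed/portrait_rgba.png",  # hypothetical output path
    rmbg=True,                             # strip the background via rembg
    recenter=True,                         # center the subject with a 20% border
    size=512,
    border_ratio=0.2,
)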
14 | 15 | 16 | from torch.profiler import profile 17 | 18 | 19 | class DummyProfiler(profile): 20 | def __init__(self): 21 | pass 22 | 23 | def __enter__(self): 24 | return self 25 | 26 | def __exit__(self, *args): 27 | pass 28 | 29 | def step(self): 30 | pass 31 | -------------------------------------------------------------------------------- /lam/utils/proxy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | NO_PROXY = "lam_NO_DATA_PROXY" in os.environ 19 | 20 | def no_proxy(func): 21 | """Decorator to disable proxy but then restore after the function call.""" 22 | def wrapper(*args, **kwargs): 23 | # http_proxy, https_proxy, HTTP_PROXY, HTTPS_PROXY, all_proxy 24 | http_proxy = os.environ.get('http_proxy') 25 | https_proxy = os.environ.get('https_proxy') 26 | HTTP_PROXY = os.environ.get('HTTP_PROXY') 27 | HTTPS_PROXY = os.environ.get('HTTPS_PROXY') 28 | all_proxy = os.environ.get('all_proxy') 29 | os.environ['http_proxy'] = '' 30 | os.environ['https_proxy'] = '' 31 | os.environ['HTTP_PROXY'] = '' 32 | os.environ['HTTPS_PROXY'] = '' 33 | os.environ['all_proxy'] = '' 34 | try: 35 | return func(*args, **kwargs) 36 | finally: 37 | os.environ['http_proxy'] = http_proxy 38 | os.environ['https_proxy'] = https_proxy 39 | os.environ['HTTP_PROXY'] = HTTP_PROXY 40 | os.environ['HTTPS_PROXY'] = HTTPS_PROXY 41 | os.environ['all_proxy'] = all_proxy 42 | if NO_PROXY: 43 | return wrapper 44 | else: 45 | return func 46 | -------------------------------------------------------------------------------- /lam/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
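# Usage sketch for DummyProfiler above -- the usual swap-in so a training or
# inference loop can call profiler.step() unconditionally (the `enabled` flag and
# the loop body are assumptions).
from torch.profiler import profile, ProfilerActivity
from lam.utils.profiler import DummyProfiler

enabled = False
profiler = (
    profile(activities=[ProfilerActivity.CPU], record_shapes=True)
    if enabled
    else DummyProfiler()
)
with profiler:
    for _ in range(3):
        ...                      # the actual step would go here
        profiler.step()          # no-op for DummyProfiler, real step otherwise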
14 | 15 | 16 | class Registry: 17 | """Registry class""" 18 | 19 | def __init__(self): 20 | self._registry = {} 21 | 22 | def register(self, name): 23 | """Register a module""" 24 | def decorator(cls): 25 | assert name not in self._registry, 'Module {} already registered'.format(name) 26 | self._registry[name] = cls 27 | return cls 28 | return decorator 29 | 30 | def __getitem__(self, name): 31 | """Get a module""" 32 | return self._registry[name] 33 | 34 | def __contains__(self, name): 35 | return name in self._registry 36 | -------------------------------------------------------------------------------- /lam/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import math 17 | from torch.optim.lr_scheduler import LRScheduler 18 | from accelerate.logging import get_logger 19 | 20 | 21 | logger = get_logger(__name__) 22 | 23 | 24 | class CosineWarmupScheduler(LRScheduler): 25 | def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): 26 | self.warmup_iters = warmup_iters 27 | self.max_iters = max_iters 28 | self.initial_lr = initial_lr 29 | super().__init__(optimizer, last_iter) 30 | 31 | def get_lr(self): 32 | logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}") 33 | if self._step_count <= self.warmup_iters: 34 | return [ 35 | self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters 36 | for base_lr in self.base_lrs] 37 | else: 38 | cos_iter = self._step_count - self.warmup_iters 39 | cos_max_iter = self.max_iters - self.warmup_iters 40 | cos_theta = cos_iter / cos_max_iter * math.pi 41 | cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] 42 | return cos_lr 43 | -------------------------------------------------------------------------------- /lam/utils/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
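# Usage sketch for the Registry above -- the same pattern REGISTRY_RUNNERS uses in
# lam/runners/__init__.py; the registry instance and the registered name are
# assumptions.
from lam.utils.registry import Registry

RUNNERS = Registry()

@RUNNERS.register("infer.dummy")
class DummyRunner:
    def run(self):
        print("running")

assert "infer.dummy" in RUNNERS
RUNNERS["infer.dummy"]().run()       # look up the class by name and instantiate it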
14 | 15 | 16 | import os 17 | import numpy as np 18 | import torch 19 | 20 | def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False): 21 | import imageio 22 | # images: torch.tensor (T, C, H, W), 0-1 or numpy: (T, H, W, 3) 0-255 23 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 24 | frames = [] 25 | for i in range(images.shape[0]): 26 | if isinstance(images, torch.Tensor): 27 | frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) 28 | assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ 29 | f"Frame shape mismatch: {frame.shape} vs {images.shape}" 30 | assert frame.min() >= 0 and frame.max() <= 255, \ 31 | f"Frame value out of range: {frame.min()} ~ {frame.max()}" 32 | else: 33 | frame = images[i] 34 | frames.append(frame) 35 | frames = np.stack(frames) 36 | if gradio_codec: 37 | imageio.mimwrite(output_path, frames, fps=fps, quality=10) 38 | else: 39 | # imageio.mimwrite(output_path, frames, fps=fps, codec='mpeg4', quality=10) 40 | imageio.mimwrite(output_path, frames, fps=fps, quality=10) 41 | 42 | if verbose: 43 | print(f"Using gradio codec option {gradio_codec}") 44 | print(f"Saved video to {output_path}") 45 | 46 | 47 | def save_images2video(img_lst, v_pth, fps): 48 | import moviepy.editor as mpy 49 | # Convert the list of NumPy arrays to a list of ImageClip objects 50 | clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] # 0.1 seconds per frame 51 | 52 | # Concatenate the ImageClips into a single VideoClip 53 | video = mpy.concatenate_videoclips(clips, method="compose") 54 | 55 | # Write the VideoClip to a file 56 | video.write_videofile(v_pth, fps=fps) # write at the fps requested by the caller 57 | print("save video to:", v_pth) 58 | 59 | 60 | if __name__ == "__main__": 61 | from glob import glob 62 | clip_name = "clip1" 63 | ptn = f"./assets/sample_motion/export/{clip_name}/images/*.png" 64 | images_pths = glob(ptn) 65 | import cv2 66 | import numpy as np 67 | images = [cv2.imread(pth) for pth in images_pths] 68 | save_images2video(images, f"./assets/sample_motion/export/{clip_name}/video.mp4", 25) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pymcubes==0.1.6 2 | omegaconf==2.3.0 3 | moviepy==1.0.3 4 | pandas==2.2.3 5 | transformers==4.41.2 6 | numpy==1.23.0 7 | trimesh==4.4.9 8 | opencv_python_headless==4.11.0.86 9 | tensorflow==2.12.0 10 | face-detection-tflite==0.6.0 11 | scikit-image==0.20.0 12 | jaxlib==0.4.30 13 | gradio==3.44.3 14 | huggingface_hub==0.23.2 15 | Cython 16 | accelerate 17 | tyro 18 | einops 19 | diffusers 20 | plyfile 21 | jaxtyping 22 | typeguard 23 | chumpy 24 | loguru 25 | ninja 26 | git+https://github.com/facebookresearch/pytorch3d.git 27 | git+https://github.com/ashawkey/diff-gaussian-rasterization/ 28 | nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling 29 | git+https://github.com/camenduru/simple-knn/ -------------------------------------------------------------------------------- /scripts/convert_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import pdb 17 | import sys 18 | import traceback 19 | from tempfile import TemporaryDirectory 20 | 21 | import safetensors 22 | import torch.nn as nn 23 | from accelerate import Accelerator 24 | from megfile import ( 25 | smart_copy, 26 | smart_exists, 27 | smart_listdir, 28 | smart_makedirs, 29 | smart_path_join, 30 | ) 31 | from omegaconf import OmegaConf 32 | 33 | sys.path.append(".") 34 | 35 | from lam.models import model_dict 36 | from lam.utils.hf_hub import wrap_model_hub 37 | from lam.utils.proxy import no_proxy 38 | 39 | 40 | @no_proxy 41 | def auto_load_model(cfg, model: nn.Module) -> int: 42 | 43 | ckpt_root = smart_path_join( 44 | cfg.saver.checkpoint_root, 45 | cfg.experiment.parent, 46 | cfg.experiment.child, 47 | ) 48 | if not smart_exists(ckpt_root): 49 | raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}") 50 | ckpt_dirs = smart_listdir(ckpt_root) 51 | if len(ckpt_dirs) == 0: 52 | raise FileNotFoundError(f"No checkpoint found in {ckpt_root}") 53 | ckpt_dirs.sort() 54 | 55 | load_step = ( 56 | f"{cfg.convert.global_step}" 57 | if cfg.convert.global_step is not None 58 | else ckpt_dirs[-1] 59 | ) 60 | load_model_path = smart_path_join(ckpt_root, load_step, "model.safetensors") 61 | 62 | if load_model_path.startswith("s3"): 63 | tmpdir = TemporaryDirectory() 64 | tmp_model_path = smart_path_join(tmpdir.name, f"tmp.safetensors") 65 | smart_copy(load_model_path, tmp_model_path) 66 | load_model_path = tmp_model_path 67 | 68 | print(f"Loading from {load_model_path}") 69 | try: 70 | safetensors.torch.load_model(model, load_model_path, strict=True) 71 | except: 72 | traceback.print_exc() 73 | safetensors.torch.load_model(model, load_model_path, strict=False) 74 | 75 | return int(load_step) 76 | 77 | 78 | if __name__ == "__main__": 79 | 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--config", type=str, default="./assets/config.yaml") 82 | args, unknown = parser.parse_known_args() 83 | cfg = OmegaConf.load(args.config) 84 | cli_cfg = OmegaConf.from_cli(unknown) 85 | cfg = OmegaConf.merge(cfg, cli_cfg) 86 | 87 | """ 88 | [cfg.convert] 89 | global_step: int 90 | save_dir: str 91 | """ 92 | 93 | accelerator = Accelerator() 94 | 95 | # hf_model_cls = wrap_model_hub(model_dict[cfg.experiment.type]) 96 | hf_model_cls = wrap_model_hub(model_dict["human_lrm_sapdino_bh_sd3_5"]) 97 | 98 | hf_model = hf_model_cls(OmegaConf.to_container(cfg.model)) 99 | loaded_step = auto_load_model(cfg, hf_model) 100 | dump_path = smart_path_join( 101 | f"./exps/releases", 102 | cfg.experiment.parent, 103 | cfg.experiment.child, 104 | f"step_{loaded_step:06d}", 105 | ) 106 | print(f"Saving locally to {dump_path}") 107 | smart_makedirs(dump_path, exist_ok=True) 108 | hf_model.save_pretrained( 109 | save_directory=dump_path, 110 | config=hf_model.config, 111 | ) 112 | -------------------------------------------------------------------------------- /scripts/exp/run_4gpu.sh: -------------------------------------------------------------------------------- 1 | ACC_CONFIG="./configs/accelerate-train-4gpu.yaml" 2 | 
TRAIN_CONFIG="./configs/train-sample-human.yaml" 3 | 4 | if [ -n "$1" ]; then 5 | TRAIN_CONFIG=$1 6 | else 7 | TRAIN_CONFIG="./configs/train-sample-human.yaml" 8 | fi 9 | 10 | if [ -n "$2" ]; then 11 | MAIN_PORT=$2 12 | else 13 | MAIN_PORT=12345 14 | fi 15 | 16 | accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m openlrm.launch train.human_lrm --config $TRAIN_CONFIG -------------------------------------------------------------------------------- /scripts/exp/run_8gpu.sh: -------------------------------------------------------------------------------- 1 | ACC_CONFIG="./configs/accelerate-train.yaml" 2 | TRAIN_CONFIG="./configs/train-sample-human.yaml" 3 | 4 | if [ -n "$1" ]; then 5 | TRAIN_CONFIG=$1 6 | else 7 | TRAIN_CONFIG="./configs/train-sample-human.yaml" 8 | fi 9 | 10 | if [ -n "$2" ]; then 11 | MAIN_PORT=$2 12 | else 13 | MAIN_PORT=12345 14 | fi 15 | 16 | accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m openlrm.launch train.human_lrm --config $TRAIN_CONFIG -------------------------------------------------------------------------------- /scripts/exp/run_debug.sh: -------------------------------------------------------------------------------- 1 | ACC_CONFIG="./configs/accelerate-train-1gpu.yaml" 2 | 3 | if [ -n "$1" ]; then 4 | TRAIN_CONFIG=$1 5 | else 6 | TRAIN_CONFIG="./configs/train-sample-human.yaml" 7 | fi 8 | 9 | if [ -n "$2" ]; then 10 | MAIN_PORT=$2 11 | else 12 | MAIN_PORT=12345 13 | fi 14 | 15 | accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m openlrm.launch train.human_lrm --config $TRAIN_CONFIG -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | # step1. 
set TRAIN_CONFIG path to config file 2 | 3 | TRAIN_CONFIG="configs/inference/lam-20k-8gpu.yaml" 4 | MODEL_NAME="model_zoo/lam_models/releases/lam/lam-20k/step_045500/" 5 | IMAGE_INPUT="assets/sample_input/status.png" 6 | MOTION_SEQS_DIR="assets/sample_motion/export/Look_In_My_Eyes/" 7 | 8 | 9 | TRAIN_CONFIG=${1:-$TRAIN_CONFIG} 10 | MODEL_NAME=${2:-$MODEL_NAME} 11 | IMAGE_INPUT=${3:-$IMAGE_INPUT} 12 | MOTION_SEQS_DIR=${4:-$MOTION_SEQS_DIR} 13 | 14 | echo "TRAIN_CONFIG: $TRAIN_CONFIG" 15 | echo "IMAGE_INPUT: $IMAGE_INPUT" 16 | echo "MODEL_NAME: $MODEL_NAME" 17 | echo "MOTION_SEQS_DIR: $MOTION_SEQS_DIR" 18 | 19 | 20 | MOTION_IMG_DIR=null 21 | SAVE_PLY=false 22 | SAVE_IMG=false 23 | VIS_MOTION=false 24 | MOTION_IMG_NEED_MASK=true 25 | RENDER_FPS=30 26 | MOTION_VIDEO_READ_FPS=30 27 | EXPORT_VIDEO=true 28 | CROSS_ID=false 29 | TEST_SAMPLE=false 30 | GAGA_TRACK_TYPE="" 31 | EXPORT_MESH=false 32 | device=0 33 | nodes=0 34 | 35 | export PYTHONPATH=$PYTHONPATH:$PWD 36 | 37 | 38 | CUDA_VISIBLE_DEVICES=$device python -m lam.launch infer.lam --config $TRAIN_CONFIG \ 39 | model_name=$MODEL_NAME image_input=$IMAGE_INPUT \ 40 | export_video=$EXPORT_VIDEO export_mesh=$EXPORT_MESH \ 41 | motion_seqs_dir=$MOTION_SEQS_DIR motion_img_dir=$MOTION_IMG_DIR \ 42 | vis_motion=$VIS_MOTION motion_img_need_mask=$MOTION_IMG_NEED_MASK \ 43 | render_fps=$RENDER_FPS motion_video_read_fps=$MOTION_VIDEO_READ_FPS \ 44 | save_ply=$SAVE_PLY save_img=$SAVE_IMG \ 45 | gaga_track_type=$GAGA_TRACK_TYPE cross_id=$CROSS_ID \ 46 | test_sample=$TEST_SAMPLE rank=$device nodes=$nodes 47 | -------------------------------------------------------------------------------- /scripts/install/WINDOWS_INSTALL.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## Windows Installation Guide 5 | 6 | ### Base software 7 | 8 | - Python 3.10 9 | - Nvidia Cuda Toolkit 11.8 (You can also change to others) 10 | - Visual Studio 2019: 2022 will cause some compilation errors on cuda operators. Download it from [techspot](https://www.techspot.com/downloads/7241-visual-studio-2019.html) 11 | 12 | 13 | 14 | ### Install Dependencies 15 | 16 | Note: we use the "x64 Native Tools" command prompt from Visual Studio for compilation, not PowerShell or cmd. It offers the MSVC environment needed to compile the Python packages.
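Once that prompt is open, you can optionally check that the MSVC and CUDA toolchains are on the PATH before compiling anything — a minimal sanity check; `cl` is the MSVC compiler driver and `nvcc` ships with the CUDA Toolkit, so adjust if your setup differs:

```bash
where cl        # should print the path to MSVC's cl.exe
nvcc --version  # should report the CUDA 11.8 toolkit
```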
17 | 18 | Open "x64 Native Tools" terminal and install dependencies: 19 | 20 | - Prepare environment: 21 | We recommend using venv (or conda) to create a python environment: 22 | ```bash 23 | python -m venv lam_env 24 | lam_env\Scripts\activate 25 | git clone https://github.com/aigc3d/LAM.git 26 | cd LAM && git checkout feat/windows 27 | ``` 28 | 29 | - Install torch 2.3.0 and xformers 30 | ```bash 31 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118 32 | pip install -U xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118 33 | ``` 34 | 35 | - Install python packages which do not need compilation: 36 | ```bash 37 | # pip install -r requirements.txt without the last 4 lines (the compiled packages are installed below; requires a POSIX-like shell such as Git Bash): 38 | head -n $(($(wc -l < requirements.txt) - 4)) requirements.txt | pip install -r /dev/stdin 39 | ``` 40 | 41 | - Install packages which need compilation 42 | ```bash 43 | # Install Pytorch3d, which follows: 44 | # https://blog.csdn.net/m0_70229101/article/details/127196699 45 | # https://blog.csdn.net/qq_61247019/article/details/139927752 46 | set DISTUTILS_USE_SDK=1 47 | set PYTORCH3D_NO_NINJA=1 48 | git clone https://github.com/facebookresearch/pytorch3d.git 49 | cd pytorch3d 50 | # modify setup.py 51 | # add "-DWIN32_LEAN_AND_MEAN" in nvcc_args 52 | python setup.py install 53 | 54 | # Install other packages 55 | pip install git+https://github.com/ashawkey/diff-gaussian-rasterization/ 56 | pip install nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling 57 | pip install git+https://github.com/camenduru/simple-knn/ 58 | 59 | cd external/landmark_detection/FaceBoxesV2/utils/ 60 | python build.py build_ext --inplace 61 | cd ../../../../ 62 | ``` 63 | 64 | 65 | ### Run 66 | 67 | ```bash 68 | python app_lam.py 69 | ``` -------------------------------------------------------------------------------- /scripts/install/install_cu118.sh: -------------------------------------------------------------------------------- 1 | # install torch 2.3.0 2 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118 3 | pip install -U xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118 4 | 5 | # install dependencies 6 | pip install -r requirements.txt 7 | 8 | # === If you fail to install some modules due to network connection, you can also try the following: === 9 | # git clone https://github.com/facebookresearch/pytorch3d.git 10 | # pip install ./pytorch3d 11 | # git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization 12 | # pip install ./diff-gaussian-rasterization 13 | # git clone https://github.com/camenduru/simple-knn.git 14 | # pip install ./simple-knn 15 | 16 | cd external/landmark_detection/FaceBoxesV2/utils/ 17 | sh make.sh 18 | cd ../../../../ 19 | -------------------------------------------------------------------------------- /scripts/install/install_cu121.sh: -------------------------------------------------------------------------------- 1 | # install torch 2.3.0 2 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121 3 | pip install -U xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu121 4 | 5 | # install dependencies 6 | pip install -r requirements.txt 7 | 8 | # === If you fail to install some modules due to network connection, you can also try the following: === 9 | # git clone https://github.com/facebookresearch/pytorch3d.git 10 | # pip install ./pytorch3d 11
| # git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization 12 | # pip install ./diff-gaussian-rasterization 13 | # git clone https://github.com/camenduru/simple-knn.git 14 | # pip install ./simple-knn 15 | 16 | cd external/landmark_detection/FaceBoxesV2/utils/ 17 | sh make.sh 18 | cd ../../../../ 19 | -------------------------------------------------------------------------------- /scripts/upload_hub.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import sys 17 | 18 | sys.path.append(".") 19 | 20 | import argparse 21 | 22 | from accelerate import Accelerator 23 | 24 | from lam.models import model_dict 25 | from lam.utils.hf_hub import wrap_model_hub 26 | 27 | if __name__ == "__main__": 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--model_type", type=str, required=True) 31 | parser.add_argument("--local_ckpt", type=str, required=True) 32 | parser.add_argument("--repo_id", type=str, required=True) 33 | args, unknown = parser.parse_known_args() 34 | 35 | accelerator = Accelerator() 36 | 37 | hf_model_cls = wrap_model_hub(model_dict[args.model_type]) 38 | hf_model = hf_model_cls.from_pretrained(args.local_ckpt) 39 | hf_model.push_to_hub( 40 | repo_id=args.repo_id, 41 | config=hf_model.config, 42 | private=True, 43 | ) 44 | -------------------------------------------------------------------------------- /tools/AVATAR_EXPORT_GUIDE.md: -------------------------------------------------------------------------------- 1 | ## Export Chatting Avatar Guide 2 | ### 🛠️ Environment Setup 3 | #### Prerequisites 4 | ``` 5 | Python FBX SDK 2020.2+ 6 | Blender (version > 4.0.0) 7 | ``` 8 | #### Step1: download and install python fbx-sdk and other requirements 9 | ```bash 10 | # FBX SDK: https://www.autodesk.com/developer-network/platform-technologies/fbx-sdk-2020-2 11 | # Download FBX SDK installation package, example for Linux 12 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/fbxsdk_linux.tar 13 | tar -xf fbxsdk_linux.tar 14 | sh tools/install_fbx_sdk.sh 15 | 16 | # Or download and install the FBX-SDK Wheel built by us 17 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/fbx-2020.3.4-cp310-cp310-manylinux1_x86_64.whl 18 | pip install fbx-2020.3.4-cp310-cp310-manylinux1_x86_64.whl 19 | 20 | # Install other requirements 21 | pip install pathlib 22 | pip install patool 23 | ``` 24 | #### Step2: download blender 25 | ```bash 26 | # Download latest Blender (>=4.0.0) 27 | # Choose appropriate version from: https://www.blender.org/download/ 28 | # Example for Linux 29 | wget https://download.blender.org/release/Blender4.0/blender-4.0.2-linux-x64.tar.xz 30 | tar -xvf blender-4.0.2-linux-x64.tar.xz -C ~/software/ 31 | ``` 32 | #### Step3: download chatting avatar template file 33 | ```bash 34 | # 
Download and extract sample files 35 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/sample_oac.tar 36 | tar -xf sample_oac.tar -C assets/ 37 | ``` 38 | 39 | ### Gradio Run 40 | ```bash 41 | # Example path for Blender executable 42 | python app_lam.py --blender_path ~/software/blender-4.0.2-linux-x64/blender 43 | ``` -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/tools/__init__.py -------------------------------------------------------------------------------- /tools/convertFBX2GLB.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 3 | 4 | Blender FBX to GLB Converter 5 | Converts 3D models from FBX to glTF Binary (GLB) format with optimized settings. 6 | Requires Blender to run in background mode. 7 | """ 8 | 9 | import bpy 10 | import sys 11 | from pathlib import Path 12 | 13 | def clean_scene(): 14 | """Clear all objects and data from the current Blender scene""" 15 | bpy.ops.object.select_all(action='SELECT') 16 | bpy.ops.object.delete() 17 | for collection in [bpy.data.meshes, bpy.data.materials, bpy.data.textures]: 18 | for item in collection: 19 | collection.remove(item) 20 | 21 | 22 | def main(): 23 | try: 24 | # Parse command line arguments after "--" 25 | argv = sys.argv[sys.argv.index("--") + 1:] 26 | input_fbx = Path(argv[0]) 27 | output_glb = Path(argv[1]) 28 | 29 | # Validate input file 30 | if not input_fbx.exists(): 31 | raise FileNotFoundError(f"Input FBX file not found: {input_fbx}") 32 | 33 | # Prepare scene 34 | clean_scene() 35 | 36 | # Import FBX with default settings 37 | print(f"Importing {input_fbx}...") 38 | bpy.ops.import_scene.fbx(filepath=str(input_fbx)) 39 | 40 | # Export optimized GLB 41 | print(f"Exporting to {output_glb}...") 42 | bpy.ops.export_scene.gltf( 43 | filepath=str(output_glb), 44 | export_format='GLB', # Binary format 45 | export_skins=True, # Keep skinning data 46 | export_texcoords=False, # Reduce file size 47 | export_normals=False, # Reduce file size 48 | export_colors=False, # Reduce file size 49 | ) 50 | 51 | print("Conversion completed successfully") 52 | 53 | except Exception as e: 54 | print(f"Error: {str(e)}") 55 | sys.exit(1) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /tools/generateVertexIndices.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors. 3 | 4 | Blender Vertex Order Exporter 5 | Imports an OBJ mesh with Blender and exports its vertex indices, sorted by Z coordinate, to a JSON file. 6 | Requires Blender to run in background mode.
7 | """ 8 | 9 | import bpy 10 | import sys 11 | import os 12 | import json 13 | from pathlib import Path 14 | 15 | def import_obj(filepath): 16 | if not os.path.exists(filepath): 17 | raise FileNotFoundError(f"文件不存在:{filepath}") 18 | bpy.ops.wm.obj_import(filepath=filepath) 19 | print(f"成功导入:{filepath}") 20 | 21 | 22 | def clean_scene(): 23 | """Clear all objects and data from the current Blender scene""" 24 | bpy.ops.object.select_all(action='SELECT') 25 | bpy.ops.object.delete() 26 | for collection in [bpy.data.meshes, bpy.data.materials, bpy.data.textures]: 27 | for item in collection: 28 | collection.remove(item) 29 | 30 | def apply_rotation(obj): 31 | obj.rotation_euler = (1.5708, 0, 0) 32 | bpy.context.view_layer.update() 33 | obj.select_set(True) 34 | bpy.context.view_layer.objects.active = obj 35 | bpy.ops.object.transform_apply(location=False, rotation=True, scale=False) # 应用旋转 36 | print(f"Applied 90-degree rotation to object: {obj.name}") 37 | 38 | def main(): 39 | try: 40 | # Parse command line arguments after "--" 41 | argv = sys.argv[sys.argv.index("--") + 1:] 42 | input_mesh = Path(argv[0]) 43 | output_vertex_order_file = argv[1] 44 | 45 | # Validate input file 46 | if not input_mesh.exists(): 47 | raise FileNotFoundError(f"Input FBX file not found: {input_mesh}") 48 | 49 | # Prepare scene 50 | clean_scene() 51 | 52 | # Import FBX with default settings 53 | print(f"Importing {input_mesh}...") 54 | import_obj(str(input_mesh)) 55 | base_obj = bpy.context.view_layer.objects.active 56 | 57 | apply_rotation(base_obj) 58 | 59 | bpy.context.view_layer.objects.active = base_obj 60 | base_obj.select_set(True) 61 | bpy.ops.object.mode_set(mode='OBJECT') 62 | 63 | base_objects = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH'] 64 | if len(base_objects) != 1: 65 | raise ValueError("Scene should contain exactly one base mesh object.") 66 | base_obj = base_objects[0] 67 | 68 | vertices = [(i, v.co.z) for i, v in enumerate(base_obj.data.vertices)] 69 | 70 | sorted_vertices = sorted(vertices, key=lambda x: x[1]) # 按 Z 坐标从小到大排序 71 | sorted_vertex_indices = [idx for idx, z in sorted_vertices] 72 | 73 | with open(str(output_vertex_order_file), "w") as f: 74 | json.dump(sorted_vertex_indices, f, indent=4) # 保存为 JSON 数组 75 | print(f"Exported vertex order to: {str(output_vertex_order_file)}") 76 | 77 | 78 | except Exception as e: 79 | print(f"Error: {str(e)}") 80 | sys.exit(1) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /tools/install_fbx_sdk.sh: -------------------------------------------------------------------------------- 1 | cd fbxsdk_linux 2 | chmod +x fbx202034_fbxsdk_linux fbx202034_fbxpythonbindings_linux 3 | mkdir -p ./python_binding ./python_binding/fbx_sdk 4 | yes yes | ./fbx202034_fbxpythonbindings_linux ./python_binding 5 | yes yes | ./fbx202034_fbxsdk_linux ./python_binding/fbx_sdk 6 | cd ./python_binding 7 | export FBXSDK_ROOT=./fbx_sdk 8 | pip install . 9 | cd - -------------------------------------------------------------------------------- /vhap/config/nersemble.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 
4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from typing import Optional, Literal 11 | from dataclasses import dataclass 12 | import tyro 13 | 14 | from vhap.config.base import ( 15 | StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig, 16 | DataConfig, LossWeightConfig, BaseTrackingConfig, 17 | ) 18 | from vhap.util.log import get_logger 19 | logger = get_logger(__name__) 20 | 21 | 22 | @dataclass() 23 | class NersembleDataConfig(DataConfig): 24 | _target: str = "vhap.data.nersemble_dataset.NeRSembleDataset" 25 | calibrated: bool = True 26 | image_size_during_calibration: Optional[tuple[int, int]] = (3208, 2200) 27 | """(height, width). Will be used to convert principal points when the image size is not included in the camera parameters.""" 28 | background_color: Optional[Literal['white', 'black']] = None 29 | landmark_source: Optional[Literal["face-alignment", 'star']] = "star" 30 | 31 | subject: str = "" 32 | """Subject ID. Such as 018, 218, 251, 253""" 33 | use_color_correction: bool = True 34 | """Whether to use color correction to harmonize the color of the input images.""" 35 | 36 | @dataclass() 37 | class NersembleLossWeightConfig(LossWeightConfig): 38 | landmark: Optional[float] = 3. # should not be lower, to avoid collapse 39 | always_enable_jawline_landmarks: bool = False # allow disable_jawline_landmarks in StageConfig to work 40 | reg_expr: float = 1e-2 # for best expressiveness 41 | reg_tex_tv: Optional[float] = 1e5 # 10x of the base value 42 | 43 | @dataclass() 44 | class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig): 45 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset") 46 | 47 | align_texture_except: tuple[str, ...] = ("boundary",) 48 | align_boundary_except: tuple[str, ...] = ("boundary",) 49 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the 50 | alignment constraints for better alignment in the face region. 51 | """ 52 | 53 | @dataclass() 54 | class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig): 55 | align_texture_except: tuple[str, ...] = ("boundary",) 56 | align_boundary_except: tuple[str, ...] = ("boundary",) 57 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the 58 | alignment constraints for better alignment in the face region.
59 | """ 60 | 61 | @dataclass() 62 | class NersemblePipelineConfig(PipelineConfig): 63 | rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig 64 | rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig 65 | 66 | @dataclass() 67 | class NersembleTrackingConfig(BaseTrackingConfig): 68 | data: NersembleDataConfig 69 | w: NersembleLossWeightConfig 70 | pipeline: NersemblePipelineConfig 71 | 72 | def get_occluded(self): 73 | occluded_table = { 74 | '018': ('neck_lower',), 75 | '218': ('neck_lower',), 76 | '251': ('neck_lower', 'boundary'), 77 | '253': ('neck_lower',), 78 | } 79 | if self.data.subject in occluded_table: 80 | logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}") 81 | self.model.occluded = occluded_table[self.data.subject] 82 | 83 | 84 | if __name__ == "__main__": 85 | config = tyro.cli(NersembleTrackingConfig) 86 | print(tyro.to_yaml(config)) -------------------------------------------------------------------------------- /vhap/data/image_folder_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | import numpy as np 4 | import PIL.Image as Image 5 | from torch.utils.data import Dataset 6 | from vhap.util.log import get_logger 7 | 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | class ImageFolderDataset(Dataset): 13 | def __init__( 14 | self, 15 | image_folder: Path, 16 | background_folder: Optional[Path]=None, 17 | background_fname2camId=lambda x: x, 18 | image_fname2camId=lambda x: x, 19 | ): 20 | """ 21 | Args: 22 | root_folder: Path to dataset with the following directory layout 23 | / 24 | |---xx.jpg 25 | |---... 26 | """ 27 | super().__init__() 28 | self.image_fname2camId = image_fname2camId 29 | self.background_foler = background_folder 30 | 31 | logger.info(f"Initializing dataset from folder {image_folder}") 32 | 33 | self.image_paths = sorted(list(image_folder.glob('*.jpg'))) 34 | 35 | if background_folder is not None: 36 | self.backgrounds = {} 37 | background_paths = sorted(list((image_folder / background_folder).glob('*.jpg'))) 38 | 39 | for background_path in background_paths: 40 | bg = np.array(Image.open(background_path)) 41 | cam_id = background_fname2camId(background_path.name) 42 | self.backgrounds[cam_id] = bg 43 | 44 | def __len__(self): 45 | return len(self.image_paths) 46 | 47 | def __getitem__(self, i): 48 | image_path = self.image_paths[i] 49 | cam_id = self.image_fname2camId(image_path.name) 50 | rgb = np.array(Image.open(image_path)) 51 | item = { 52 | "rgb": rgb, 53 | 'image_path': str(image_path), 54 | } 55 | 56 | if self.background_foler is not None: 57 | item['background'] = self.backgrounds[cam_id] 58 | 59 | return item 60 | 61 | 62 | if __name__ == "__main__": 63 | from tqdm import tqdm 64 | from torch.utils.data import DataLoader 65 | 66 | dataset = ImageFolderDataset( 67 | image_folder='./xx', 68 | img_to_tensor=True, 69 | ) 70 | 71 | print(len(dataset)) 72 | 73 | sample = dataset[0] 74 | print(sample.keys()) 75 | print(sample["rgb"].shape) 76 | 77 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) 78 | for item in tqdm(dataloader): 79 | pass 80 | -------------------------------------------------------------------------------- /vhap/generate_flame_uvmask.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and 
proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from typing import Literal 11 | import tyro 12 | import numpy as np 13 | from PIL import Image 14 | from pathlib import Path 15 | import torch 16 | import nvdiffrast.torch as dr 17 | from vhap.util.render_uvmap import render_uvmap_vtex 18 | from vhap.model.flame import FlameHead 19 | 20 | 21 | FLAME_UV_MASK_FOLDER = "asset/flame/uv_masks" 22 | FLAME_UV_MASK_NPZ = "asset/flame/uv_masks.npz" 23 | 24 | 25 | def main( 26 | use_opengl: bool = False, 27 | device: Literal['cuda', 'cpu'] = 'cuda', 28 | ): 29 | n_shape = 300 30 | n_expr = 100 31 | print("Initializing FLAME model") 32 | flame_model = FlameHead(n_shape, n_expr, add_teeth=True) 33 | 34 | flame_model = FlameHead( 35 | n_shape, 36 | n_expr, 37 | add_teeth=True, 38 | ).cuda() 39 | 40 | faces = flame_model.faces.int().cuda() 41 | verts_uv = flame_model.verts_uvs.cuda() 42 | # verts_uv[:, 1] = 1 - verts_uv[:, 1] 43 | faces_uv = flame_model.textures_idx.int().cuda() 44 | col_idx = faces_uv 45 | 46 | # Rasterizer context 47 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext() 48 | 49 | h, w = 2048, 2048 50 | resolution = (h, w) 51 | 52 | if not Path(FLAME_UV_MASK_FOLDER).exists(): 53 | Path(FLAME_UV_MASK_FOLDER).mkdir(parents=True) 54 | 55 | # alpha_maps = {} 56 | masks = {} 57 | for region, vt_mask in flame_model.mask.vt: 58 | v_color = torch.zeros(verts_uv.shape[0], 1).to(device) # alpha channel 59 | v_color[vt_mask] = 1 60 | 61 | alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)[0] 62 | alpha = alpha.flip(0) 63 | # alpha_maps[region] = alpha.cpu().numpy() 64 | mask = (alpha > 0.5) # to avoid overlap between hair and face 65 | mask = mask.squeeze(-1).cpu().numpy() 66 | masks[region] = mask # (h, w) 67 | 68 | print(f"Saving uv mask for {region}...") 69 | # rgba = mask.expand(-1, -1, 4) # (h, w, 4) 70 | # rgb = torch.ones_like(mask).expand(-1, -1, 3) # (h, w, 3) 71 | # rgba = torch.cat([rgb, mask], dim=-1).cpu().numpy() # (h, w, 4) 72 | img = mask 73 | img = Image.fromarray((img * 255).astype(np.uint8)) 74 | img.save(Path(FLAME_UV_MASK_FOLDER) / f"{region}.png") 75 | 76 | print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}") 77 | np.savez_compressed(FLAME_UV_MASK_NPZ, **masks) 78 | 79 | 80 | if __name__ == "__main__": 81 | tyro.cli(main) -------------------------------------------------------------------------------- /vhap/track.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | import tyro 11 | 12 | from vhap.config.base import BaseTrackingConfig 13 | from vhap.model.tracker import GlobalTracker 14 | 15 | 16 | if __name__ == "__main__": 17 | tyro.extras.set_accent_color("bright_yellow") 18 | cfg = tyro.cli(BaseTrackingConfig) 19 | 20 | tracker = GlobalTracker(cfg) 21 | tracker.optimize() 22 | -------------------------------------------------------------------------------- /vhap/track_nersemble.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import tyro 11 | 12 | from vhap.config.nersemble import NersembleTrackingConfig 13 | from vhap.model.tracker import GlobalTracker 14 | 15 | 16 | if __name__ == "__main__": 17 | tyro.extras.set_accent_color("bright_yellow") 18 | cfg = tyro.cli(NersembleTrackingConfig) 19 | 20 | tracker = GlobalTracker(cfg) 21 | tracker.optimize() 22 | -------------------------------------------------------------------------------- /vhap/util/log.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import logging 11 | import sys 12 | from datetime import datetime 13 | import atexit 14 | from pathlib import Path 15 | 16 | 17 | def _colored(msg, color): 18 | colors = {'red': '\033[91m', 'green': '\033[92m', 'yellow': '\033[93m', 'normal': '\033[0m'} 19 | return colors[color] + msg + colors["normal"] 20 | 21 | 22 | class ColorFormatter(logging.Formatter): 23 | """ 24 | Class to make command line log entries more appealing 25 | Inspired by https://github.com/facebookresearch/detectron2 26 | """ 27 | 28 | def formatMessage(self, record): 29 | """ 30 | Print warnings yellow and errors red 31 | :param record: 32 | :return: 33 | """ 34 | log = super().formatMessage(record) 35 | if record.levelno == logging.WARNING: 36 | prefix = _colored("WARNING", "yellow") 37 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 38 | prefix = _colored("ERROR", "red") 39 | else: 40 | return log 41 | return prefix + " " + log 42 | 43 | 44 | def get_logger(name, level=logging.DEBUG, root=False, log_dir=None): 45 | """ 46 | Replaces the standard library logging.getLogger call in order to make some configuration 47 | for all loggers. 
48 | :param name: pass the __name__ variable 49 | :param level: the desired log level 50 | :param root: call only once in the program 51 | :param log_dir: if root is set to True, this defines the directory where a log file is going 52 | to be created that contains all logging output 53 | :return: the logger object 54 | """ 55 | logger = logging.getLogger(name) 56 | logger.setLevel(level) 57 | 58 | if root: 59 | # create handler for console 60 | console_handler = logging.StreamHandler(sys.stdout) 61 | console_handler.setLevel(level) 62 | formatter = ColorFormatter(_colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S") 64 | console_handler.setFormatter(formatter) 65 | logger.addHandler(console_handler) 66 | logger.propagate = False # otherwise root logger prints things again 67 | 68 | if log_dir is not None: 69 | # add handler to log to a file 70 | log_dir = Path(log_dir) 71 | if not log_dir.exists(): 72 | logger.info(f"Logging directory {log_dir} does not exist and will be created") 73 | log_dir.mkdir(parents=True) 74 | timestamp = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") 75 | log_file = log_dir / f"{timestamp}.log" 76 | 77 | # open stream and make sure it will be closed 78 | stream = log_file.open(mode="w") 79 | atexit.register(stream.close) 80 | 81 | formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s", 82 | datefmt="%m/%d %H:%M:%S") 83 | file_handler = logging.StreamHandler(stream) 84 | file_handler.setLevel(level) 85 | file_handler.setFormatter(formatter) 86 | logger.addHandler(file_handler) 87 | 88 | return logger 89 | -------------------------------------------------------------------------------- /vhap/util/mesh.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | import torch 11 | 12 | 13 | def get_mtl_content(tex_fname): 14 | return f'newmtl Material\nmap_Kd {tex_fname}\n' 15 | 16 | def get_obj_content(vertices, faces, uv_coordinates=None, uv_indices=None, mtl_fname=None): 17 | obj = ('# Generated with multi-view-head-tracker\n') 18 | 19 | if mtl_fname is not None: 20 | obj += f'mtllib {mtl_fname}\n' 21 | obj += 'usemtl Material\n' 22 | 23 | # Write the vertices 24 | for vertex in vertices: 25 | obj += f"v {vertex[0]} {vertex[1]} {vertex[2]}\n" 26 | 27 | # Write the UV coordinates 28 | if uv_coordinates is not None: 29 | for uv in uv_coordinates: 30 | obj += f"vt {uv[0]} {uv[1]}\n" 31 | 32 | # Write the faces with UV indices 33 | if uv_indices is not None: 34 | for face, uv_indices in zip(faces, uv_indices): 35 | obj += f"f {face[0]+1}/{uv_indices[0]+1} {face[1]+1}/{uv_indices[1]+1} {face[2]+1}/{uv_indices[2]+1}\n" 36 | else: 37 | for face in faces: 38 | obj += f"f {face[0]+1} {face[1]+1} {face[2]+1}\n" 39 | return obj 40 | 41 | def normalize_image_points(u, v, resolution): 42 | """ 43 | normalizes u, v coordinates from [0 ,image_size] to [-1, 1] 44 | :param u: 45 | :param v: 46 | :param resolution: 47 | :return: 48 | """ 49 | u = 2 * (u - resolution[1] / 2.0) / resolution[1] 50 | v = 2 * (v - resolution[0] / 2.0) / resolution[0] 51 | return u, v 52 | 53 | 54 | def face_vertices(vertices, faces): 55 | """ 56 | :param vertices: [batch size, number of vertices, 3] 57 | :param faces: [batch size, number of faces, 3] 58 | :return: [batch size, number of faces, 3, 3] 59 | """ 60 | assert vertices.ndimension() == 3 61 | assert faces.ndimension() == 3 62 | assert vertices.shape[0] == faces.shape[0] 63 | assert vertices.shape[2] == 3 64 | assert faces.shape[2] == 3 65 | 66 | bs, nv = vertices.shape[:2] 67 | bs, nf = faces.shape[:2] 68 | device = vertices.device 69 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] 70 | vertices = vertices.reshape((bs * nv, 3)) 71 | # pytorch only supports long and byte tensors for indexing 72 | return vertices[faces.long()] 73 | 74 | -------------------------------------------------------------------------------- /vhap/util/render_uvmap.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | import tyro 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import nvdiffrast.torch as dr 15 | 16 | from vhap.model.flame import FlameHead 17 | 18 | 19 | FLAME_TEX_PATH = "asset/flame/FLAME_texture.npz" 20 | 21 | 22 | def transform_vt(vt): 23 | """Transform uv vertices to clip space""" 24 | xy = vt * 2 - 1 25 | w = torch.ones([1, vt.shape[-2], 1]).to(vt) 26 | z = -w # In the clip spcae of OpenGL, the camera looks at -z 27 | xyzw = torch.cat([xy[None, :, :], z, w], axis=-1) 28 | return xyzw 29 | 30 | def render_uvmap_vtex(glctx, pos, pos_idx, v_color, col_idx, resolution): 31 | """Render uv map with vertex color""" 32 | pos_clip = transform_vt(pos) 33 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution) 34 | 35 | color, _ = dr.interpolate(v_color, rast_out, col_idx) 36 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 37 | return color 38 | 39 | def render_uvmap_texmap(glctx, pos, pos_idx, verts_uv, faces_uv, tex, resolution, enable_mip=True, max_mip_level=None): 40 | """Render uv map with texture map""" 41 | pos_clip = transform_vt(pos) 42 | rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, resolution) 43 | 44 | if enable_mip: 45 | texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all') 46 | color = dr.texture(tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level) 47 | else: 48 | texc, _ = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv) 49 | color = dr.texture(tex[None, ...], texc, filter_mode='linear') 50 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 51 | return color 52 | 53 | 54 | def main( 55 | use_texmap: bool = False, 56 | use_opengl: bool = False, 57 | ): 58 | n_shape = 300 59 | n_expr = 100 60 | print("Initialization FLAME model") 61 | flame_model = FlameHead(n_shape, n_expr) 62 | 63 | verts_uv = flame_model.verts_uvs.cuda() 64 | verts_uv[:, 1] = 1 - verts_uv[:, 1] 65 | faces_uv = flame_model.textures_idx.int().cuda() 66 | 67 | # Rasterizer context 68 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext() 69 | 70 | h, w = 512, 512 71 | resolution = (h, w) 72 | 73 | if use_texmap: 74 | tex = torch.from_numpy(np.load(FLAME_TEX_PATH)['mean']).cuda().float().flip(dims=[-1]) / 255 75 | rgb = render_uvmap_texmap(glctx, verts_uv, faces_uv, verts_uv, faces_uv, tex, resolution, enable_mip=True) 76 | else: 77 | v_color = torch.ones(verts_uv.shape[0], 3).to(verts_uv) 78 | col_idx = faces_uv 79 | rgb = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution) 80 | 81 | plt.imshow(rgb[0, :, :].cpu()) 82 | plt.show() 83 | 84 | 85 | if __name__ == "__main__": 86 | tyro.cli(main) 87 | -------------------------------------------------------------------------------- /vhap/util/vector_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 5 | return torch.sum(x*y, -1, keepdim=True) 6 | 7 | def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: 8 | return 2*dot(x, n)*n - x 9 | 10 | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: 11 | return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN 12 | 13 | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: 14 | return x / length(x, eps) 15 | 16 | def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: 17 
| return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) 18 | -------------------------------------------------------------------------------- /vhap/util/visualization.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import matplotlib.pyplot as plt 11 | import torch 12 | from torchvision.utils import draw_bounding_boxes, draw_keypoints 13 | 14 | 15 | connectivity_face = ( 16 | [(i, i + 1) for i in list(range(0, 16))] 17 | + [(i, i + 1) for i in list(range(17, 21))] 18 | + [(i, i + 1) for i in list(range(22, 26))] 19 | + [(i, i + 1) for i in list(range(27, 30))] 20 | + [(i, i + 1) for i in list(range(31, 35))] 21 | + [(i, i + 1) for i in list(range(36, 41))] 22 | + [(36, 41)] 23 | + [(i, i + 1) for i in list(range(42, 47))] 24 | + [(42, 47)] 25 | + [(i, i + 1) for i in list(range(48, 59))] 26 | + [(48, 59)] 27 | + [(i, i + 1) for i in list(range(60, 67))] 28 | + [(60, 67)] 29 | ) 30 | 31 | 32 | def plot_landmarks_2d( 33 | img: torch.tensor, 34 | lmks: torch.tensor, 35 | connectivity=None, 36 | colors="white", 37 | unit=1, 38 | input_float=False, 39 | ): 40 | if input_float: 41 | img = (img * 255).byte() 42 | 43 | img = draw_keypoints( 44 | img, 45 | lmks, 46 | connectivity=connectivity, 47 | colors=colors, 48 | radius=2 * unit, 49 | width=2 * unit, 50 | ) 51 | 52 | if input_float: 53 | img = img.float() / 255 54 | return img 55 | 56 | 57 | def blend(a, b, w): 58 | return (a * w + b * (1 - w)).byte() 59 | 60 | 61 | if __name__ == "__main__": 62 | from argparse import ArgumentParser 63 | from torch.utils.data import DataLoader 64 | from matplotlib import pyplot as plt 65 | 66 | from vhap.data.nersemble_dataset import NeRSembleDataset 67 | 68 | parser = ArgumentParser() 69 | parser.add_argument("--root_folder", type=str, required=True) 70 | parser.add_argument("--subject", type=str, required=True) 71 | parser.add_argument("--sequence", type=str, required=True) 72 | parser.add_argument("--division", default=None) 73 | parser.add_argument("--subset", default=None) 74 | parser.add_argument("--scale_factor", type=float, default=1.0) 75 | parser.add_argument("--blend_weight", type=float, default=0.6) 76 | args = parser.parse_args() 77 | 78 | dataset = NeRSembleDataset( 79 | root_folder=args.root_folder, 80 | subject=args.subject, 81 | sequence=args.sequence, 82 | division=args.division, 83 | subset=args.subset, 84 | n_downsample_rgb=2, 85 | scale_factor=args.scale_factor, 86 | use_landmark=True, 87 | ) 88 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 89 | 90 | for item in dataloader: 91 | unit = int(item["scale_factor"][0] * 3) + 1 92 | 93 | rgb = item["rgb"][0].permute(2, 0, 1) 94 | vis = rgb 95 | 96 | if "bbox_2d" in item: 97 | bbox = item["bbox_2d"][0][:4] 98 | tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit) 99 | vis = blend(tmp, vis, args.blend_weight) 100 | 101 | if "lmk2d" in item: 102 | face_landmark = item["lmk2d"][0][:, :2] 103 | tmp = plot_landmarks_2d( 104 | vis, 105 | face_landmark[None, ...], 106 | connectivity=connectivity_face, 107 | colors="white", 108 | unit=unit, 109 | ) 110 | 
vis = blend(tmp, vis, args.blend_weight) 111 | 112 | if "lmk2d_iris" in item: 113 | iris_landmark = item["lmk2d_iris"][0][:, :2] 114 | tmp = plot_landmarks_2d( 115 | vis, 116 | iris_landmark[None, ...], 117 | colors="blue", 118 | unit=unit, 119 | ) 120 | vis = blend(tmp, vis, args.blend_weight) 121 | 122 | vis = vis.permute(1, 2, 0).numpy() 123 | plt.imshow(vis) 124 | plt.draw() 125 | while not plt.waitforbuttonpress(timeout=-1): 126 | pass 127 | --------------------------------------------------------------------------------