├── .gitignore
├── LICENSE
├── README.md
├── README_CN.md
├── app_hf_space.py
├── app_lam.py
├── assets
└── images
│ ├── logo.jpeg
│ └── teaser.jpg
├── configs
├── inference
│ └── lam-20k-8gpu.yaml
├── stylematte_config.json
└── vhap_tracking
│ └── base_tracking_config.yaml
├── external
├── human_matting
│ ├── __init__.py
│ ├── matting_engine.py
│ └── stylematte.py
├── landmark_detection
│ ├── FaceBoxesV2
│ │ ├── __init__.py
│ │ ├── detector.py
│ │ ├── faceboxes_detector.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── box_utils.py
│ │ │ ├── build.py
│ │ │ ├── build
│ │ │ └── temp.linux-x86_64-cpython-310
│ │ │ │ └── nms
│ │ │ │ └── cpu_nms.o
│ │ │ ├── config.py
│ │ │ ├── faceboxes.py
│ │ │ ├── make.sh
│ │ │ ├── nms
│ │ │ ├── __init__.py
│ │ │ ├── cpu_nms.c
│ │ │ ├── cpu_nms.py
│ │ │ ├── cpu_nms.pyx
│ │ │ ├── gpu_nms.hpp
│ │ │ ├── gpu_nms.pyx
│ │ │ ├── nms_kernel.cu
│ │ │ └── py_cpu_nms.py
│ │ │ ├── nms_wrapper.py
│ │ │ ├── prior_box.py
│ │ │ └── timer.py
│ ├── README.md
│ ├── conf
│ │ ├── __init__.py
│ │ ├── alignment.py
│ │ └── base.py
│ ├── config.json
│ ├── data_processor
│ │ ├── CheckFaceKeyPoint.py
│ │ ├── align.py
│ │ └── process_pcd.py
│ ├── evaluate.py
│ ├── infer_folder.py
│ ├── infer_image.py
│ ├── infer_video.py
│ ├── lib
│ │ ├── __init__.py
│ │ ├── backbone
│ │ │ ├── __init__.py
│ │ │ ├── core
│ │ │ │ └── coord_conv.py
│ │ │ └── stackedHGNetV1.py
│ │ ├── dataset
│ │ │ ├── __init__.py
│ │ │ ├── alignmentDataset.py
│ │ │ ├── augmentation.py
│ │ │ ├── decoder
│ │ │ │ ├── __init__.py
│ │ │ │ └── decoder_default.py
│ │ │ └── encoder
│ │ │ │ ├── __init__.py
│ │ │ │ └── encoder_default.py
│ │ ├── loss
│ │ │ ├── __init__.py
│ │ │ ├── awingLoss.py
│ │ │ ├── smoothL1Loss.py
│ │ │ ├── starLoss.py
│ │ │ ├── starLoss_v2.py
│ │ │ └── wingLoss.py
│ │ ├── metric
│ │ │ ├── __init__.py
│ │ │ ├── accuracy.py
│ │ │ ├── fr_and_auc.py
│ │ │ ├── nme.py
│ │ │ └── params.py
│ │ ├── utility.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── dist_utils.py
│ │ │ ├── meter.py
│ │ │ ├── time_utils.py
│ │ │ └── vis_utils.py
│ ├── requirements.txt
│ ├── tester.py
│ └── tools
│ │ ├── analysis_motivation.py
│ │ ├── infinite_loop.py
│ │ ├── infinite_loop_gpu.py
│ │ ├── split_wflw.py
│ │ └── testtime_pca.py
└── vgghead_detector
│ ├── VGGDetector.py
│ ├── __init__.py
│ ├── utils_lmks_detector.py
│ └── utils_vgghead.py
├── lam
├── __init__.py
├── datasets
│ ├── __init__.py
│ ├── base.py
│ ├── cam_utils.py
│ ├── mixer.py
│ └── video_head.py
├── launch.py
├── losses
│ ├── __init__.py
│ ├── perceptual.py
│ ├── pixelwise.py
│ └── tvloss.py
├── models
│ ├── __init__.py
│ ├── block.py
│ ├── discriminator.py
│ ├── encoders
│ │ ├── __init__.py
│ │ ├── dinov2
│ │ │ ├── __init__.py
│ │ │ ├── hub
│ │ │ │ ├── __init__.py
│ │ │ │ ├── backbones.py
│ │ │ │ ├── classifiers.py
│ │ │ │ ├── depth
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── decode_heads.py
│ │ │ │ │ ├── encoder_decoder.py
│ │ │ │ │ └── ops.py
│ │ │ │ ├── depthers.py
│ │ │ │ └── utils.py
│ │ │ ├── layers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── block.py
│ │ │ │ ├── dino_head.py
│ │ │ │ ├── drop_path.py
│ │ │ │ ├── layer_scale.py
│ │ │ │ ├── mlp.py
│ │ │ │ ├── patch_embed.py
│ │ │ │ └── swiglu_ffn.py
│ │ │ └── models
│ │ │ │ ├── __init__.py
│ │ │ │ └── vision_transformer.py
│ │ ├── dinov2_fusion_wrapper.py
│ │ └── dpt_util
│ │ │ ├── __init__.py
│ │ │ ├── blocks.py
│ │ │ └── transform.py
│ ├── modeling_lam.py
│ ├── modulate.py
│ ├── rendering
│ │ ├── __init__.py
│ │ ├── flame_model
│ │ │ ├── flame.py
│ │ │ ├── flame_arkit.py
│ │ │ └── lbs.py
│ │ ├── gaussian_model.py
│ │ ├── gs_renderer.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── math_utils.py
│ │ │ ├── mesh_utils.py
│ │ │ ├── point_utils.py
│ │ │ ├── renderer.py
│ │ │ ├── sh_utils.py
│ │ │ ├── typing.py
│ │ │ ├── utils.py
│ │ │ ├── uv_utils.py
│ │ │ └── vis_utils.py
│ ├── transformer.py
│ └── transformer_dit.py
├── runners
│ ├── __init__.py
│ ├── abstract.py
│ └── infer
│ │ ├── __init__.py
│ │ ├── base_inferrer.py
│ │ ├── head_utils.py
│ │ ├── lam.py
│ │ └── utils.py
└── utils
│ ├── __init__.py
│ ├── compile.py
│ ├── ffmpeg_utils.py
│ ├── gen_id_json.py
│ ├── gen_json.py
│ ├── hf_hub.py
│ ├── logging.py
│ ├── preprocess.py
│ ├── profiler.py
│ ├── proxy.py
│ ├── registry.py
│ ├── scheduler.py
│ └── video.py
├── requirements.txt
├── scripts
├── convert_hf.py
├── exp
│ ├── run_4gpu.sh
│ ├── run_8gpu.sh
│ └── run_debug.sh
├── inference.sh
├── install
│ ├── WINDOWS_INSTALL.md
│ ├── install_cu118.sh
│ └── install_cu121.sh
└── upload_hub.py
├── tools
├── AVATAR_EXPORT_GUIDE.md
├── __init__.py
├── convertFBX2GLB.py
├── flame_tracking_single_image.py
├── generateARKITGLBWithBlender.py
├── generateGLBWithBlender_v2.py
├── generateVertexIndices.py
└── install_fbx_sdk.sh
└── vhap
├── combine_nerf_datasets.py
├── config
├── base.py
└── nersemble.py
├── data
├── image_folder_dataset.py
├── nerf_dataset.py
├── nersemble_dataset.py
└── video_dataset.py
├── export_as_nerf_dataset.py
├── flame_editor.py
├── flame_viewer.py
├── generate_flame_uvmask.py
├── model
├── flame.py
├── lbs.py
└── tracker.py
├── track.py
├── track_nersemble.py
└── util
├── camera.py
├── landmark_detector_fa.py
├── landmark_detector_star.py
├── log.py
├── mesh.py
├── render_nvdiffrast.py
├── render_uvmap.py
├── vector_ops.py
└── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wheels
2 | __pycache__/
3 | build/
4 | *.so
5 | assets/sample_input/
6 | assets/sample_motion/
7 | configs/vhap_tracking/
8 | exps/
9 | pretrain_model/
10 | pretrained_models/
11 | model_zoo/
12 | tracking_output/
--------------------------------------------------------------------------------
/assets/images/logo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/assets/images/logo.jpeg
--------------------------------------------------------------------------------
/assets/images/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/assets/images/teaser.jpg
--------------------------------------------------------------------------------
/configs/inference/lam-20k-8gpu.yaml:
--------------------------------------------------------------------------------
1 | 
2 | experiment:
3 |     type: lam
4 |     seed: 42
5 |     parent: lam
6 |     child: lam_20k
7 | model:
8 |     # image encoder
9 |     encoder_type: "dinov2_fusion"
10 |     encoder_model_name: "dinov2_vitl14_reg"
11 |     encoder_feat_dim: 1024
12 |     encoder_freeze: false
13 | 
14 |     # points embeddings
15 |     latent_query_points_type: "e2e_flame"
16 |     pcl_dim: 1024
17 | 
18 |     # transformer
19 |     transformer_type: "sd3_cond"
20 |     transformer_heads: 16
21 |     transformer_dim: 1024
22 |     transformer_layers: 10
23 |     tf_grad_ckpt: true
24 |     encoder_grad_ckpt: true
25 | 
26 |     # for gs renderer
27 |     human_model_path: "./model_zoo/human_parametric_models"
28 |     flame_subdivide_num: 1
29 |     flame_type: "flame"
30 |     gs_query_dim: 1024
31 |     gs_use_rgb: True
32 |     gs_sh: 3
33 |     gs_mlp_network_config:
34 |         n_neurons: 512
35 |         n_hidden_layers: 2
36 |         activation: silu
37 |     gs_xyz_offset_max_step: 0.2
38 |     gs_clip_scaling: 0.01
39 |     scale_sphere: false
40 | 
41 |     expr_param_dim: 10
42 |     shape_param_dim: 10
43 |     add_teeth: false
44 | 
45 |     fix_opacity: false
46 |     fix_rotation: false
47 | 
48 |     has_disc: false
49 | 
50 |     teeth_bs_flag: false
51 |     oral_mesh_flag: false
52 | 
53 | dataset:
54 |     subsets:
55 |         - name: video_head
56 |           root_dirs: "./train_data/vfhq_vhap_nooffset/export"
57 |           meta_path:
58 |               train: "./train_data/vfhq_vhap_nooffset/label/valid_id_train_list.json"
59 |               val: "./train_data/vfhq_vhap_nooffset/label/valid_id_val_list.json"
60 |           sample_rate: 1.0
61 |     sample_side_views: 7
62 |     sample_aug_views: 0
63 |     source_image_res: 512
64 |     render_image:
65 |         low: 512
66 |         high: 512
67 |         region: null
68 |     num_train_workers: 4
69 |     num_val_workers: 2
70 |     pin_mem: true
71 |     repeat_num: 1
72 |     gaga_track_type: "vfhq"
73 | 
74 | train:
75 |     mixed_precision: bf16 # REPLACE THIS BASED ON GPU TYPE
76 |     find_unused_parameters: false
77 |     loss:
78 |         pixel_weight: 0.0
79 |         pixel_loss_fn: "mse"
80 |         crop_face_weight: 0.
81 |         crop_mouth_weight: 0.
82 |         crop_eye_weight: 0.
83 |         masked_pixel_weight: 1.0
84 |         perceptual_weight: 1.0
85 |         tv_weight: -1
86 |         mask_weight: 0:1.0:0.5:10000
87 |         offset_reg_weight: 0.1
88 |     optim:
89 |         lr: 4e-4
90 |         weight_decay: 0.05
91 |         beta1: 0.9
92 |         beta2: 0.95
93 |         clip_grad_norm: 1.0
94 |     scheduler:
95 |         type: cosine
96 |         warmup_real_iters: 3000
97 |     batch_size: 4 # REPLACE THIS (PER GPU)
98 |     accum_steps: 1 # REPLACE THIS
99 |     epochs: 100 # REPLACE THIS
100 |     debug_global_steps: null
101 |     resume: ""
102 | 
103 | val:
104 |     batch_size: 2
105 |     global_step_period: 500
106 |     debug_batches: 10
107 | 
108 | saver:
109 |     auto_resume: true
110 |     load_model: null
111 |     checkpoint_root: ./exps/checkpoints
112 |     checkpoint_global_steps: 500
113 |     checkpoint_keep_level: 5
114 | 
115 | logger:
116 |     stream_level: WARNING
117 |     log_level: INFO
118 |     log_root: ./exps/logs
119 |     tracker_root: ./exps/trackers
120 |     enable_profiler: false
121 |     trackers:
122 |         - tensorboard
123 |     image_monitor:
124 |         train_global_steps: 500
125 |         samples_per_log: 4
126 | 
127 | compile:
128 |     suppress_errors: true
129 |     print_specializations: true
130 |     disable: true
131 | 
--------------------------------------------------------------------------------
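For quick inspection, a minimal sketch that reads a few fields from this config; it assumes OmegaConf is installed (any YAML loader would do) and that the repository root is the working directory:

```python
from omegaconf import OmegaConf

# Hypothetical quick inspection of the training/inference config above.
cfg = OmegaConf.load('configs/inference/lam-20k-8gpu.yaml')
print(cfg.experiment.child)             # lam_20k
print(cfg.model.encoder_model_name)     # dinov2_vitl14_reg
print(cfg.train.scheduler.type)         # cosine
```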
/external/human_matting/__init__.py:
--------------------------------------------------------------------------------
1 | from .matting_engine import StyleMatteEngine
2 |
--------------------------------------------------------------------------------
/external/human_matting/matting_engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import inspect
4 | import warnings
5 | import torchvision
6 | from .stylematte import StyleMatte
7 |
8 | class StyleMatteEngine(torch.nn.Module):
9 | def __init__(self, device='cpu',human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt'):
10 | super().__init__()
11 | self._device = device
12 | self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13 | self._init_models(human_matting_path)
14 |
15 | def _init_models(self, _ckpt_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt'):  # default mirrors __init__ so the lazy call in forward() works
16 | # load dict
17 | state_dict = torch.load(_ckpt_path, map_location='cpu')
18 | # build model
19 | model = StyleMatte()
20 | model.load_state_dict(state_dict)
21 | self.model = model.to(self._device).eval()
22 |
23 | @torch.no_grad()
24 | def forward(self, input_image, return_type='matting', background_rgb=1.0):
25 | if not hasattr(self, 'model'):
26 | self._init_models()
27 | if input_image.max() > 2.0:
28 | warnings.warn('Image should be normalized to [0, 1].')
29 | _, ori_h, ori_w = input_image.shape
30 | input_image = input_image.to(self._device).float()
31 | image = input_image.clone()
32 | # resize
33 | if max(ori_h, ori_w) > 1024:
34 | scale = 1024.0 / max(ori_h, ori_w)
35 | resized_h, resized_w = int(ori_h * scale), int(ori_w * scale)
36 | image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True)
37 | else:
38 | resized_h, resized_w = ori_h, ori_w
39 | # padding
40 | if resized_h % 8 != 0 or resized_w % 8 != 0:
41 | image = torchvision.transforms.functional.pad(image, ((8-resized_w % 8)%8, (8-resized_h % 8)%8, 0, 0, ), padding_mode='reflect')
42 | # normalize and forwarding
43 | image = self.normalize(image)[None]
44 | predict = self.model(image)[0]
45 | # undo padding
46 | predict = predict[:, -resized_h:, -resized_w:]
47 | # undo resize
48 | if resized_h != ori_h or resized_w != ori_w:
49 | predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True)
50 |
51 | if return_type == 'alpha':
52 | return predict[0]
53 | elif return_type == 'matting':
54 | predict = predict.expand(3, -1, -1)
55 | matting_image = input_image.clone()
56 | background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb
57 | matting_image = matting_image * predict + (1-predict) * background_rgb
58 | return matting_image, predict[0]
59 | elif return_type == 'all':
60 | predict = predict.expand(3, -1, -1)
61 | background_rgb = input_image.new_ones(input_image.shape) * background_rgb
62 | foreground_image = input_image * predict + (1-predict) * background_rgb
63 | background_image = input_image * (1-predict) + predict * background_rgb
64 | return foreground_image, background_image
65 | else:
66 | raise NotImplementedError
67 |
--------------------------------------------------------------------------------
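A minimal usage sketch for the engine above, assuming the repository root is on `PYTHONPATH`; the image path is a placeholder and the checkpoint path simply repeats the constructor default, so adjust both to your layout. The engine expects a CHW float tensor in [0, 1]:

```python
import torchvision
from external.human_matting import StyleMatteEngine

engine = StyleMatteEngine(
    device='cuda',
    human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt',
)
# read_image returns a uint8 CHW tensor; scale it to [0, 1] as the engine expects.
image = torchvision.io.read_image('portrait.png', mode=torchvision.io.ImageReadMode.RGB).float() / 255.0
matted, alpha = engine(image, return_type='matting', background_rgb=1.0)
torchvision.utils.save_image(matted, 'portrait_matted.png')
```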
/external/landmark_detection/FaceBoxesV2/__init__.py:
--------------------------------------------------------------------------------
1 | from . import detector
2 | from . import faceboxes_detector
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/detector.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | class Detector(object):
4 | def __init__(self, model_arch, model_weights):
5 | self.model_arch = model_arch
6 | self.model_weights = model_weights
7 |
8 | def detect(self, image, thresh):
9 | raise NotImplementedError
10 |
11 | def crop(self, image, detections):
12 | crops = []
13 | for det in detections:
14 | xmin = max(det[2], 0)
15 | ymin = max(det[3], 0)
16 | width = det[4]
17 | height = det[5]
18 | xmax = min(xmin+width, image.shape[1])
19 | ymax = min(ymin+height, image.shape[0])
20 | cut = image[ymin:ymax, xmin:xmax,:]
21 | crops.append(cut)
22 |
23 | return crops
24 |
25 | def draw(self, image, detections, im_scale=None):
26 | if im_scale is not None:
27 | image = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
28 | detections = [[det[0],det[1],int(det[2]*im_scale),int(det[3]*im_scale),int(det[4]*im_scale),int(det[5]*im_scale)] for det in detections]
29 |
30 | for det in detections:
31 | xmin = det[2]
32 | ymin = det[3]
33 | width = det[4]
34 | height = det[5]
35 | xmax = xmin + width
36 | ymax = ymin + height
37 | cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
38 |
39 | return image
40 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/faceboxes_detector.py:
--------------------------------------------------------------------------------
1 | from .detector import Detector
2 | import cv2, os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from .utils.config import cfg
7 | from .utils.prior_box import PriorBox
8 | from .utils.nms_wrapper import nms
9 | from .utils.faceboxes import FaceBoxesV2
10 | from .utils.box_utils import decode
11 | import time
12 |
13 | class FaceBoxesDetector(Detector):
14 | def __init__(self, model_arch, model_weights, use_gpu, device):
15 | super().__init__(model_arch, model_weights)
16 | self.name = 'FaceBoxesDetector'
17 | self.net = FaceBoxesV2(phase='test', size=None, num_classes=2) # initialize detector
18 | self.use_gpu = use_gpu
19 | self.device = device
20 |
21 | state_dict = torch.load(self.model_weights, map_location=self.device)
22 | # create new OrderedDict that does not contain `module.`
23 | from collections import OrderedDict
24 | new_state_dict = OrderedDict()
25 | for k, v in state_dict.items():
26 | name = k[7:] # remove `module.`
27 | new_state_dict[name] = v
28 | # load params
29 | self.net.load_state_dict(new_state_dict)
30 | self.net = self.net.to(self.device)
31 | self.net.eval()
32 |
33 |
34 | def detect(self, image, thresh=0.6, im_scale=None):
35 | # auto resize for large images
36 | if im_scale is None:
37 | height, width, _ = image.shape
38 | if min(height, width) > 600:
39 | im_scale = 600. / min(height, width)
40 | else:
41 | im_scale = 1
42 | image_scale = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
43 |
44 | scale = torch.Tensor([image_scale.shape[1], image_scale.shape[0], image_scale.shape[1], image_scale.shape[0]])
45 | image_scale = torch.from_numpy(image_scale.transpose(2,0,1)).to(self.device).int()
46 | mean_tmp = torch.IntTensor([104, 117, 123]).to(self.device)
47 | mean_tmp = mean_tmp.unsqueeze(1).unsqueeze(2)
48 | image_scale -= mean_tmp
49 | image_scale = image_scale.float().unsqueeze(0)
50 | scale = scale.to(self.device)
51 |
52 | with torch.no_grad():
53 | out = self.net(image_scale)
54 | #priorbox = PriorBox(cfg, out[2], (image_scale.size()[2], image_scale.size()[3]), phase='test')
55 | priorbox = PriorBox(cfg, image_size=(image_scale.size()[2], image_scale.size()[3]))
56 | priors = priorbox.forward()
57 | priors = priors.to(self.device)
58 | loc, conf = out
59 | prior_data = priors.data
60 | boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
61 | boxes = boxes * scale
62 | boxes = boxes.cpu().numpy()
63 | scores = conf.data.cpu().numpy()[:, 1]
64 |
65 | # ignore low scores
66 | inds = np.where(scores > thresh)[0]
67 | boxes = boxes[inds]
68 | scores = scores[inds]
69 |
70 | # keep top-K before NMS
71 | order = scores.argsort()[::-1][:5000]
72 | boxes = boxes[order]
73 | scores = scores[order]
74 |
75 | # do NMS
76 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
77 | keep = nms(dets, 0.3)
78 | dets = dets[keep, :]
79 |
80 | dets = dets[:750, :]
81 | detections_scale = []
82 | for i in range(dets.shape[0]):
83 | xmin = int(dets[i][0])
84 | ymin = int(dets[i][1])
85 | xmax = int(dets[i][2])
86 | ymax = int(dets[i][3])
87 | score = dets[i][4]
88 | width = xmax - xmin
89 | height = ymax - ymin
90 | detections_scale.append(['face', score, xmin, ymin, width, height])
91 |
92 | # adapt bboxes to the original image size
93 | if len(detections_scale) > 0:
94 | detections_scale = [[det[0],det[1],int(det[2]/im_scale),int(det[3]/im_scale),int(det[4]/im_scale),int(det[5]/im_scale)] for det in detections_scale]
95 |
96 | return detections_scale, im_scale
97 |
98 |
--------------------------------------------------------------------------------
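A usage sketch for the detector above. The checkpoint path is an assumption (point it at wherever your FaceBoxesV2 weights live), the repository root is assumed to be on `PYTHONPATH`, and the `nms` import requires the Cython extension built via `utils/make.sh`:

```python
import cv2
from external.landmark_detection.FaceBoxesV2.faceboxes_detector import FaceBoxesDetector

# Hypothetical weight path; replace with your FaceBoxesV2 checkpoint.
detector = FaceBoxesDetector('FaceBoxes', './model_zoo/flame_tracking_models/FaceBoxesV2.pth',
                             use_gpu=True, device='cuda:0')
image = cv2.imread('portrait.png')                      # BGR, HWC, uint8
detections, im_scale = detector.detect(image, thresh=0.6)
for label, score, xmin, ymin, width, height in detections:
    print(f'{label} {score:.3f} box=({xmin}, {ymin}, {width}, {height})')
```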
/external/landmark_detection/FaceBoxesV2/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/__init__.py
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/build.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | # --------------------------------------------------------
4 | # Fast R-CNN
5 | # Copyright (c) 2015 Microsoft
6 | # Licensed under The MIT License [see LICENSE for details]
7 | # Written by Ross Girshick
8 | # --------------------------------------------------------
9 |
10 | import os
11 | from os.path import join as pjoin
12 | import numpy as np
13 | from distutils.core import setup
14 | from distutils.extension import Extension
15 | from Cython.Distutils import build_ext
16 |
17 |
18 | def find_in_path(name, path):
19 | "Find a file in a search path"
20 | # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
21 | for dir in path.split(os.pathsep):
22 | binpath = pjoin(dir, name)
23 | if os.path.exists(binpath):
24 | return os.path.abspath(binpath)
25 | return None
26 |
27 |
28 | # Obtain the numpy include directory. This logic works across numpy versions.
29 | try:
30 | numpy_include = np.get_include()
31 | except AttributeError:
32 | numpy_include = np.get_numpy_include()
33 |
34 |
35 | # run the customize_compiler
36 | class custom_build_ext(build_ext):
37 | def build_extensions(self):
38 | # customize_compiler_for_nvcc(self.compiler)
39 | build_ext.build_extensions(self)
40 |
41 |
42 | ext_modules = [
43 | Extension(
44 | "nms.cpu_nms",
45 | ["nms/cpu_nms.pyx"],
46 | # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
47 | extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
48 | include_dirs=[numpy_include]
49 | )
50 | ]
51 |
52 | setup(
53 | name='mot_utils',
54 | ext_modules=ext_modules,
55 | # inject our custom trigger
56 | cmdclass={'build_ext': custom_build_ext},
57 | )
58 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/build/temp.linux-x86_64-cpython-310/nms/cpu_nms.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/build/temp.linux-x86_64-cpython-310/nms/cpu_nms.o
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/config.py:
--------------------------------------------------------------------------------
1 | # config.py
2 |
3 | cfg = {
4 | 'name': 'FaceBoxes',
5 | #'min_dim': 1024,
6 | #'feature_maps': [[32, 32], [16, 16], [8, 8]],
7 | # 'aspect_ratios': [[1], [1], [1]],
8 | 'min_sizes': [[32, 64, 128], [256], [512]],
9 | 'steps': [32, 64, 128],
10 | 'variance': [0.1, 0.2],
11 | 'clip': False,
12 | 'loc_weight': 2.0,
13 | 'gpu_train': True
14 | }
15 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python3 build.py build_ext --inplace
3 |
4 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp:
--------------------------------------------------------------------------------
1 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
2 | int boxes_dim, float nms_overlap_thresh, int device_id);
3 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Faster R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 | cimport numpy as np
10 |
11 | assert sizeof(int) == sizeof(np.int32_t)
12 |
13 | cdef extern from "gpu_nms.hpp":
14 | void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
15 |
16 | def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
17 | np.int32_t device_id=0):
18 | cdef int boxes_num = dets.shape[0]
19 | cdef int boxes_dim = dets.shape[1]
20 | cdef int num_out
21 | cdef np.ndarray[np.int32_t, ndim=1] \
22 | keep = np.zeros(boxes_num, dtype=np.int32)
23 | cdef np.ndarray[np.float32_t, ndim=1] \
24 | scores = dets[:, 4]
25 | cdef np.ndarray[np.int_t, ndim=1] \
26 | order = scores.argsort()[::-1]
27 | cdef np.ndarray[np.float32_t, ndim=2] \
28 | sorted_dets = dets[order, :]
29 | _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
30 | keep = keep[:num_out]
31 | return list(order[keep])
32 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 |
10 | def py_cpu_nms(dets, thresh):
11 | """Pure Python NMS baseline."""
12 | x1 = dets[:, 0]
13 | y1 = dets[:, 1]
14 | x2 = dets[:, 2]
15 | y2 = dets[:, 3]
16 | scores = dets[:, 4]
17 |
18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
19 | order = scores.argsort()[::-1]
20 |
21 | keep = []
22 | while order.size > 0:
23 | i = order[0]
24 | keep.append(i)
25 | xx1 = np.maximum(x1[i], x1[order[1:]])
26 | yy1 = np.maximum(y1[i], y1[order[1:]])
27 | xx2 = np.minimum(x2[i], x2[order[1:]])
28 | yy2 = np.minimum(y2[i], y2[order[1:]])
29 |
30 | w = np.maximum(0.0, xx2 - xx1 + 1)
31 | h = np.maximum(0.0, yy2 - yy1 + 1)
32 | inter = w * h
33 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
34 |
35 | inds = np.where(ovr <= thresh)[0]
36 | order = order[inds + 1]
37 |
38 | return keep
39 |
--------------------------------------------------------------------------------
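A tiny worked example of the pure-Python NMS above, with `py_cpu_nms` from that file in scope (e.g. run in the same module): the second box overlaps the first with IoU of roughly 0.86, so it is suppressed at a 0.3 threshold, while the disjoint third box survives.

```python
import numpy as np

# Each row is [x1, y1, x2, y2, score].
dets = np.array([
    [10, 10, 60, 60, 0.9],
    [12, 12, 62, 62, 0.8],      # IoU ~ 0.86 with the first box -> suppressed
    [100, 100, 150, 150, 0.7],  # disjoint -> kept
], dtype=np.float32)
print(py_cpu_nms(dets, thresh=0.3))   # indices kept: 0 and 2
```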
/external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | from .nms.cpu_nms import cpu_nms, cpu_soft_nms
9 |
10 | def nms(dets, thresh):
11 | """Dispatch to either CPU or GPU NMS implementations."""
12 |
13 | if dets.shape[0] == 0:
14 | return []
15 | return cpu_nms(dets, thresh)
16 |
--------------------------------------------------------------------------------
/external/landmark_detection/FaceBoxesV2/utils/prior_box.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from itertools import product as product
3 | import numpy as np
4 | from math import ceil
5 |
6 |
7 | class PriorBox(object):
8 | def __init__(self, cfg, image_size=None, phase='train'):
9 | super(PriorBox, self).__init__()
10 | #self.aspect_ratios = cfg['aspect_ratios']
11 | self.min_sizes = cfg['min_sizes']
12 | self.steps = cfg['steps']
13 | self.clip = cfg['clip']
14 | self.image_size = image_size
15 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
16 |
17 | def forward(self):
18 | anchors = []
19 | for k, f in enumerate(self.feature_maps):
20 | min_sizes = self.min_sizes[k]
21 | for i, j in product(range(f[0]), range(f[1])):
22 | for min_size in min_sizes:
23 | s_kx = min_size / self.image_size[1]
24 | s_ky = min_size / self.image_size[0]
25 | if min_size == 32:
26 | dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]]
27 | dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]]
28 | for cy, cx in product(dense_cy, dense_cx):
29 | anchors += [cx, cy, s_kx, s_ky]
30 | elif min_size == 64:
31 | dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.5]]
32 | dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.5]]
33 | for cy, cx in product(dense_cy, dense_cx):
34 | anchors += [cx, cy, s_kx, s_ky]
35 | else:
36 | cx = (j + 0.5) * self.steps[k] / self.image_size[1]
37 | cy = (i + 0.5) * self.steps[k] / self.image_size[0]
38 | anchors += [cx, cy, s_kx, s_ky]
39 | # back to torch land
40 | output = torch.Tensor(anchors).view(-1, 4)
41 | if self.clip:
42 | output.clamp_(max=1, min=0)
43 | return output
44 |
--------------------------------------------------------------------------------
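A quick sanity check of the anchor layout above, using the `cfg` from utils/config.py. Note these imports also execute FaceBoxesV2/__init__.py, so the compiled NMS extension from `utils/make.sh` must be present; the 600x800 input size is only an example. The 32-pixel prior is densified 4x4 per cell and the 64-pixel prior 2x2, so the stride-32 map contributes 21 anchors per cell and the other two maps one per cell.

```python
from math import ceil
from external.landmark_detection.FaceBoxesV2.utils.config import cfg
from external.landmark_detection.FaceBoxesV2.utils.prior_box import PriorBox

h, w = 600, 800
anchors = PriorBox(cfg, image_size=(h, w)).forward()

# Stride-32 map: 19 x 25 cells with 16 + 4 + 1 anchors each;
# the stride-64 and stride-128 maps contribute one anchor per cell.
expected = ceil(h / 32) * ceil(w / 32) * 21 + ceil(h / 64) * ceil(w / 64) + ceil(h / 128) * ceil(w / 128)
print(anchors.shape, expected)   # torch.Size([10140, 4]) 10140
```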
/external/landmark_detection/FaceBoxesV2/utils/timer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import time
9 |
10 |
11 | class Timer(object):
12 | """A simple timer."""
13 | def __init__(self):
14 | self.total_time = 0.
15 | self.calls = 0
16 | self.start_time = 0.
17 | self.diff = 0.
18 | self.average_time = 0.
19 |
20 | def tic(self):
21 | # using time.time instead of time.clock because time.clock
22 | # does not normalize for multithreading
23 | self.start_time = time.time()
24 |
25 | def toc(self, average=True):
26 | self.diff = time.time() - self.start_time
27 | self.total_time += self.diff
28 | self.calls += 1
29 | self.average_time = self.total_time / self.calls
30 | if average:
31 | return self.average_time
32 | else:
33 | return self.diff
34 |
35 | def clear(self):
36 | self.total_time = 0.
37 | self.calls = 0
38 | self.start_time = 0.
39 | self.diff = 0.
40 | self.average_time = 0.
41 |
--------------------------------------------------------------------------------
/external/landmark_detection/README.md:
--------------------------------------------------------------------------------
1 | # STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection.
2 |
3 | Paper Link: [arxiv](https://arxiv.org/abs/2306.02763) | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_STAR_Loss_Reducing_Semantic_Ambiguity_in_Facial_Landmark_Detection_CVPR_2023_paper.pdf)
4 |
5 |
6 | - PyTorch implementation of **S**elf-adap**T**ive **A**mbiguity **R**eduction (**STAR**) loss.
7 | - STAR loss is a self-adaptive, anisotropic direction loss for heatmap regression-based facial landmark detection.
8 | - Specifically, we observe that semantic ambiguity produces an anisotropic predicted distribution, which motivates us to use the predicted distribution itself to represent that ambiguity. We apply PCA to characterize the predicted distribution and thereby estimate the direction and intensity of the semantic ambiguity. Based on this, STAR loss adaptively suppresses the prediction error along the ambiguity direction, mitigating the impact of ambiguous annotations during training. More details can be found in our paper.
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Dependencies
17 |
18 | * Python == 3.7.3
19 | * PyTorch == 1.6.0
20 | * requirements.txt
21 |
22 | ## Dataset Preparation
23 |
24 | - Step1: Download the raw images from [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/#dataset), [300W](https://ibug.doc.ic.ac.uk/resources/300-W/), and [WFLW](https://wywu.github.io/projects/LAB/WFLW.html).
25 | - Step2: We follow the data preprocessing in [ADNet](https://openaccess.thecvf.com/content/ICCV2021/papers/Huang_ADNet_Leveraging_Error-Bias_Towards_Normal_Direction_in_Face_Alignment_ICCV_2021_paper.pdf), and the metadata can be downloaded from [the corresponding repository](https://github.com/huangyangyu/ADNet).
26 | - Step3: Make them look like this:
27 | ```script
28 | # the dataset directory:
29 | |-- ${image_dir}
30 | |-- WFLW
31 | | -- WFLW_images
32 | |-- 300W
33 | | -- afw
34 | | -- helen
35 | | -- ibug
36 | | -- lfpw
37 | |-- COFW
38 | | -- train
39 | | -- test
40 | |-- ${annot_dir}
41 | |-- WFLW
42 | |-- train.tsv, test.tsv
43 | |-- 300W
44 | |-- train.tsv, test.tsv
45 | |--COFW
46 | |-- train.tsv, test.tsv
47 | ```
48 |
49 | ## Usage
50 | * Work directory: set the ${ckpt_dir} in ./conf/alignment.py.
51 | * Pretrained model:
52 |
53 | | Dataset | Model |
54 | |:-----------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|
55 | | WFLW | [google](https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view?usp=sharing) / [baidu](https://pan.baidu.com/s/10vvI-ovs3x9NrdmpnXK6sg?pwd=u0yu) |
56 | | 300W | [google](https://drive.google.com/file/d/1Fiu3hjjkQRdKsWE9IgyNPdiJSz9_MzA5/view?usp=sharing) / [baidu](https://pan.baidu.com/s/1bjUhLq1zS1XSl1nX78fU7A?pwd=yb2s) |
57 | | COFW | [google](https://drive.google.com/file/d/1NFcZ9jzql_jnn3ulaSzUlyhS05HWB9n_/view?usp=drive_link) / [baidu](https://pan.baidu.com/s/1XO6hDZ8siJLTgFcpyu1Tzw?pwd=m57n) |
58 |
59 |
60 | ### Training
61 | ```shell
62 | python main.py --mode=train --device_ids=0,1,2,3 \
63 | --image_dir=${image_dir} --annot_dir=${annot_dir} \
64 | --data_definition={WFLW, 300W, COFW}
65 | ```
66 |
67 | ### Testing
68 | ```shell
69 | python main.py --mode=test --device_ids=0 \
70 | --image_dir=${image_dir} --annot_dir=${annot_dir} \
71 | --data_definition={WFLW, 300W, COFW} \
72 | --pretrained_weight=${model_path}
73 | ```
74 |
75 | ### Evaluation
76 | ```shell
77 | python evaluate.py --device_ids=0 \
78 | --model_path=${model_path} --metadata_path=${metadata_path} \
79 | --image_dir=${image_dir} --data_definition={WFLW, 300W, COFW}
80 | ```
81 |
82 | To test on your own images, run:
83 | ```shell
84 | python demo.py
85 | ```
86 |
87 |
88 | ## Results
89 | Models trained with STAR Loss achieve **SOTA** performance on the COFW, 300W, and WFLW datasets.
90 |
91 |
92 |
93 |
94 |
95 | ## BibTeX Citation
96 | Please consider citing our paper in your publications if this project helps your research. The BibTeX reference is as follows.
97 | ```
98 | @inproceedings{Zhou_2023_CVPR,
99 | author = {Zhou, Zhenglin and Li, Huaxia and Liu, Hong and Wang, Nanyang and Yu, Gang and Ji, Rongrong},
100 | title = {STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection},
101 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
102 | month = {June},
103 | year = {2023},
104 | pages = {15475-15484}
105 | }
106 | ```
107 |
108 | ## Acknowledgments
109 | This repository is built on top of [ADNet](https://github.com/huangyangyu/ADNet).
110 | Thanks to its authors for providing a strong baseline.
111 |
--------------------------------------------------------------------------------
/external/landmark_detection/conf/__init__.py:
--------------------------------------------------------------------------------
1 | from .alignment import Alignment
--------------------------------------------------------------------------------
/external/landmark_detection/conf/base.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import logging
3 | import os.path as osp
4 | from argparse import Namespace
5 | # from tensorboardX import SummaryWriter
6 |
7 | class Base:
8 | """
9 | Base configure file, which contains the basic training parameters and should be inherited by other attribute configure file.
10 | """
11 |
12 | def __init__(self, config_name, ckpt_dir='./', image_dir='./', annot_dir='./'):
13 | self.type = config_name
14 | self.id = str(uuid.uuid4())
15 | self.note = ""
16 |
17 | self.ckpt_dir = ckpt_dir
18 | self.image_dir = image_dir
19 | self.annot_dir = annot_dir
20 |
21 | self.loader_type = "alignment"
22 | self.loss_func = "STARLoss"
23 |
24 | # train
25 | self.batch_size = 128
26 | self.val_batch_size = 1
27 | self.test_batch_size = 32
28 | self.channels = 3
29 | self.width = 256
30 | self.height = 256
31 |
32 | # mean values in r, g, b channel.
33 | self.means = (127, 127, 127)
34 | self.scale = 0.0078125
35 |
36 | self.display_iteration = 100
37 | self.milestones = [50, 80]
38 | self.max_epoch = 100
39 |
40 | self.net = "stackedHGnet_v1"
41 | self.nstack = 4
42 |
43 | # ["adam", "sgd"]
44 | self.optimizer = "adam"
45 | self.learn_rate = 0.1
46 | self.momentum = 0.01 # caffe: 0.99
47 | self.weight_decay = 0.0
48 | self.nesterov = False
49 | self.scheduler = "MultiStepLR"
50 | self.gamma = 0.1
51 |
52 | self.loss_weights = [1.0]
53 | self.criterions = ["SoftmaxWithLoss"]
54 | self.metrics = ["Accuracy"]
55 | self.key_metric_index = 0
56 | self.classes_num = [1000]
57 | self.label_num = len(self.classes_num)
58 |
59 | # model
60 | self.ema = False
61 | self.use_AAM = True
62 |
63 | # visualization
64 | self.writer = None
65 |
66 | # log file
67 | self.logger = None
68 |
69 | def init_instance(self):
70 | # self.writer = SummaryWriter(logdir=self.log_dir, comment=self.type)
71 | log_formatter = logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s")
72 | root_logger = logging.getLogger()
73 | file_handler = logging.FileHandler(osp.join(self.log_dir, "log.txt"))
74 | file_handler.setFormatter(log_formatter)
75 | file_handler.setLevel(logging.NOTSET)
76 | root_logger.addHandler(file_handler)
77 | console_handler = logging.StreamHandler()
78 | console_handler.setFormatter(log_formatter)
79 | console_handler.setLevel(logging.NOTSET)
80 | root_logger.addHandler(console_handler)
81 | root_logger.setLevel(logging.NOTSET)
82 | self.logger = root_logger
83 |
84 | def __del__(self):
85 | # tensorboard --logdir self.log_dir
86 | if self.writer is not None:
87 | # self.writer.export_scalars_to_json(self.log_dir + "visual.json")
88 | self.writer.close()
89 |
90 | def init_from_args(self, args: Namespace):
91 | args_vars = vars(args)
92 | for key, value in args_vars.items():
93 | if hasattr(self, key) and value is not None:
94 | setattr(self, key, value)
95 |
--------------------------------------------------------------------------------
/external/landmark_detection/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "Token":"bpt4JPotFA6bpdknR9ZDCw",
3 | "business_flag": "shadow_cv_face",
4 | "model_local_file_path": "/apdcephfs_cq3/share_1134483/charlinzhou/Documents/awesome-tools/jizhi/",
5 | "host_num": 1,
6 | "host_gpu_num": 1,
7 | "GPUName": "V100",
8 | "is_elasticity": true,
9 | "enable_evicted_pulled_up": true,
10 | "task_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
11 | "task_flag": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
12 | "model_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
13 | "image_full_name": "mirrors.tencent.com/haroldzcli/py36-pytorch1.7.1-torchvision0.8.2-cuda10.1-cudnn7.6",
14 | "start_cmd": "./start_slpt.sh /apdcephfs_cq3/share_1134483/charlinzhou/Documents/SLPT_Training train.py --loss_func=star --bb_init --eigen_box --dist_func=align_smoothl1"
15 | }
16 |
--------------------------------------------------------------------------------
/external/landmark_detection/data_processor/CheckFaceKeyPoint.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 |
7 | selected_indices_old = [
8 | 2311,
9 | 2416,
10 | 2437,
11 | 2460,
12 | 2495,
13 | 2518,
14 | 2520,
15 | 2627,
16 | 4285,
17 | 4315,
18 | 6223,
19 | 6457,
20 | 6597,
21 | 6642,
22 | 6974,
23 | 7054,
24 | 7064,
25 | 7182,
26 | 7303,
27 | 7334,
28 | 7351,
29 | 7368,
30 | 7374,
31 | 7493,
32 | 7503,
33 | 7626,
34 | 8443,
35 | 8562,
36 | 8597,
37 | 8701,
38 | 8817,
39 | 8953,
40 | 11213,
41 | 11261,
42 | 11317,
43 | 11384,
44 | 11600,
45 | 11755,
46 | 11852,
47 | 11891,
48 | 11945,
49 | 12010,
50 | 12354,
51 | 12534,
52 | 12736,
53 | 12880,
54 | 12892,
55 | 13004,
56 | 13323,
57 | 13371,
58 | 13534,
59 | 13575,
60 | 14874,
61 | 14949,
62 | 14977,
63 | 15052,
64 | 15076,
65 | 15291,
66 | 15620,
67 | 15758,
68 | 16309,
69 | 16325,
70 | 16348,
71 | 16390,
72 | 16489,
73 | 16665,
74 | 16891,
75 | 17147,
76 | 17183,
77 | 17488,
78 | 17549,
79 | 17657,
80 | 17932,
81 | 19661,
82 | 20162,
83 | 20200,
84 | 20238,
85 | 20286,
86 | 20432,
87 | 20834,
88 | 20954,
89 | 21015,
90 | 21036,
91 | 21117,
92 | 21299,
93 | 21611,
94 | 21632,
95 | 21649,
96 | 22722,
97 | 22759,
98 | 22873,
99 | 23028,
100 | 23033,
101 | 23082,
102 | 23187,
103 | 23232,
104 | 23302,
105 | 23413,
106 | 23430,
107 | 23446,
108 | 23457,
109 | 23548,
110 | 23636,
111 | 32060,
112 | 32245,
113 | ]
114 |
115 | selected_indices = list()
116 | with open('/home/gyalex/Desktop/face_anno.txt', 'r') as f:
117 | lines = f.readlines()
118 | for line in lines:
119 | hh = line.strip().split()
120 | if len(hh) > 0:
121 | pid = hh[0].find('.')
122 | if pid != -1:
123 | s = hh[0][pid+1:len(hh[0])]
124 | print(s)
125 | selected_indices.append(int(s))
126 |
127 | f.close()
128 |
129 | dir = '/media/gyalex/Data/face_ldk_dataset/MHC_LightingPreset_Portrait_RT_0_19/MHC_LightingPreset_Portrait_RT_seq_000015'
130 |
131 | for idx in range(500):
132 | img = os.path.join(dir, "view_1/MHC_LightingPreset_Portrait_RT_seq_000015_FinalImage_" + str(idx).zfill(4) + ".jpeg")
133 | lmd = os.path.join(dir, "mesh/mesh_screen" + str(idx+5).zfill(7) + ".npy")
134 |
135 | img = cv2.imread(img)
136 | # c = 511 / 2
137 | # lmd = np.load(lmd) * c + c
138 | # lmd[:, 1] = 511 - lmd[:, 1]
139 | lmd = np.load(lmd)[selected_indices]
140 | for i in range(lmd.shape[0]):
141 | p = lmd[i]
142 | x, y = round(float(p[0])), round(float(p[1]))
143 | print(p)
144 | cv2.circle(img, (x, y), 2, (0, 0, 255), -1)
145 |
146 | cv2.imshow('win', img)
147 | cv2.waitKey(0)
--------------------------------------------------------------------------------
/external/landmark_detection/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import get_encoder, get_decoder
2 | from .dataset import AlignmentDataset, Augmentation
3 | from .backbone import StackedHGNetV1
4 | from .metric import NME, Accuracy
5 | from .utils import time_print, time_string, time_for_file, time_string_short
6 | from .utils import convert_secs2time, convert_size2str
7 |
8 | from .utility import get_dataloader, get_config, get_net, get_criterions
9 | from .utility import get_optimizer, get_scheduler
10 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .stackedHGNetV1 import StackedHGNetV1
2 |
3 | __all__ = [
4 | "StackedHGNetV1",
5 | ]
--------------------------------------------------------------------------------
/external/landmark_detection/lib/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import get_encoder
2 | from .decoder import get_decoder
3 | from .augmentation import Augmentation
4 | from .alignmentDataset import AlignmentDataset
5 |
6 | __all__ = [
7 | "Augmentation",
8 | "AlignmentDataset",
9 | "get_encoder",
10 | "get_decoder"
11 | ]
12 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/dataset/decoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder_default import decoder_default
2 |
3 | def get_decoder(decoder_type='default'):
4 | if decoder_type == 'default':
5 | decoder = decoder_default()
6 | else:
7 | raise NotImplementedError
8 | return decoder
--------------------------------------------------------------------------------
/external/landmark_detection/lib/dataset/decoder/decoder_default.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class decoder_default:
5 | def __init__(self, weight=1, use_weight_map=False):
6 | self.weight = weight
7 | self.use_weight_map = use_weight_map
8 |
9 | def _make_grid(self, h, w):
10 | yy, xx = torch.meshgrid(
11 | torch.arange(h).float() / (h - 1) * 2 - 1,
12 | torch.arange(w).float() / (w - 1) * 2 - 1)
13 | return yy, xx
14 |
15 | def get_coords_from_heatmap(self, heatmap):
16 | """
17 | inputs:
18 | - heatmap: batch x npoints x h x w
19 |
20 | outputs:
21 | - coords: batch x npoints x 2 (x, y), normalized to [-1, +1]
22 | (a soft-argmax over the heatmap; only the coordinates are returned)
23 | """
24 | batch, npoints, h, w = heatmap.shape
25 | if self.use_weight_map:
26 | heatmap = heatmap * self.weight
27 |
28 | yy, xx = self._make_grid(h, w)
29 | yy = yy.view(1, 1, h, w).to(heatmap)
30 | xx = xx.view(1, 1, h, w).to(heatmap)
31 |
32 | heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
33 |
34 | yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum # batch x npoints
35 | xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum # batch x npoints
36 | coords = torch.stack([xx_coord, yy_coord], dim=-1)
37 |
38 | return coords
39 |
--------------------------------------------------------------------------------
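A small check of the soft-argmax decode above, assuming the repository root is on `PYTHONPATH` and the landmark-detection requirements are installed (the import also runs `lib/__init__.py`): a single unit peak at row 1, column 3 of a 5x5 heatmap decodes to (x, y) = (0.5, -0.5) on the [-1, +1] grid.

```python
import torch
from external.landmark_detection.lib.dataset.decoder.decoder_default import decoder_default

decoder = decoder_default()
heatmap = torch.zeros(1, 1, 5, 5)        # batch x npoints x h x w
heatmap[0, 0, 1, 3] = 1.0                # one peak at row 1, column 3
print(decoder.get_coords_from_heatmap(heatmap))   # tensor([[[ 0.5000, -0.5000]]])
```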
/external/landmark_detection/lib/dataset/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder_default import encoder_default
2 |
3 | def get_encoder(image_height, image_width, scale=0.25, sigma=1.5, encoder_type='default'):
4 | if encoder_type == 'default':
5 | encoder = encoder_default(image_height, image_width, scale, sigma)
6 | else:
7 | raise NotImplementedError
8 | return encoder
9 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/dataset/encoder/encoder_default.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | class encoder_default:
9 | def __init__(self, image_height, image_width, scale=0.25, sigma=1.5):
10 | self.image_height = image_height
11 | self.image_width = image_width
12 | self.scale = scale
13 | self.sigma = sigma
14 |
15 | def generate_heatmap(self, points):
16 | # points = (num_pts, 2)
17 | h, w = self.image_height, self.image_width
18 | pointmaps = []
19 | for i in range(len(points)):
20 | pointmap = np.zeros([h, w], dtype=np.float32)
21 | # align_corners: False.
22 | point = copy.deepcopy(points[i])
23 | point[0] = max(0, min(w - 1, point[0]))
24 | point[1] = max(0, min(h - 1, point[1]))
25 | pointmap = self._circle(pointmap, point, sigma=self.sigma)
26 |
27 | pointmaps.append(pointmap)
28 | pointmaps = np.stack(pointmaps, axis=0) / 255.0
29 | pointmaps = torch.from_numpy(pointmaps).float().unsqueeze(0)
30 | pointmaps = F.interpolate(pointmaps, size=(int(h * self.scale), int(w * self.scale)), mode='bilinear',  # size is (out_H, out_W)
31 | align_corners=False).squeeze()
32 | return pointmaps
33 |
34 | def _circle(self, img, pt, sigma=1.0, label_type='Gaussian'):
35 | # Check that any part of the gaussian is in-bounds
36 | tmp_size = sigma * 3
37 | ul = [int(pt[0] - tmp_size), int(pt[1] - tmp_size)]
38 | br = [int(pt[0] + tmp_size + 1), int(pt[1] + tmp_size + 1)]
39 | if (ul[0] > img.shape[1] - 1 or ul[1] > img.shape[0] - 1 or
40 | br[0] - 1 < 0 or br[1] - 1 < 0):
41 | # If not, just return the image as is
42 | return img
43 |
44 | # Generate gaussian
45 | size = 2 * tmp_size + 1
46 | x = np.arange(0, size, 1, np.float32)
47 | y = x[:, np.newaxis]
48 | x0 = y0 = size // 2
49 | # The gaussian is not normalized, we want the center value to equal 1
50 | if label_type == 'Gaussian':
51 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
52 | else:
53 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
54 |
55 | # Usable gaussian range
56 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
57 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
58 | # Image range
59 | img_x = max(0, ul[0]), min(br[0], img.shape[1])
60 | img_y = max(0, ul[1]), min(br[1], img.shape[0])
61 |
62 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = 255 * g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
63 | return img
64 |
--------------------------------------------------------------------------------
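A shape check of the heatmap encoder above (same import assumptions as the decoder example): with scale=0.25, a 256x256 image and two landmarks produce two 64x64 Gaussian heatmaps. The point coordinates are arbitrary sample values.

```python
import numpy as np
from external.landmark_detection.lib.dataset.encoder.encoder_default import encoder_default

encoder = encoder_default(image_height=256, image_width=256, scale=0.25, sigma=1.5)
points = np.array([[128.0, 64.0], [30.5, 200.0]], dtype=np.float32)   # (num_pts, 2) in pixel (x, y)
heatmaps = encoder.generate_heatmap(points)
print(heatmaps.shape)   # torch.Size([2, 64, 64])
```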
/external/landmark_detection/lib/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .awingLoss import AWingLoss
2 | from .smoothL1Loss import SmoothL1Loss
3 | from .wingLoss import WingLoss
4 | from .starLoss import STARLoss
5 | from .starLoss_v2 import STARLoss_v2
6 |
7 | __all__ = [
8 | "AWingLoss",
9 | "SmoothL1Loss",
10 | "WingLoss",
11 | "STARLoss",
12 |
13 | "STARLoss_v2",
14 | ]
15 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/loss/awingLoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AWingLoss(nn.Module):
7 | def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1, use_weight_map=True):
8 | super(AWingLoss, self).__init__()
9 | self.omega = omega
10 | self.theta = theta
11 | self.epsilon = epsilon
12 | self.alpha = alpha
13 | self.use_weight_map = use_weight_map
14 |
15 | def __repr__(self):
16 | return "AWingLoss()"
17 |
18 | def generate_weight_map(self, heatmap, k_size=3, w=10):
19 | dilate = F.max_pool2d(heatmap, kernel_size=k_size, stride=1, padding=1)
20 | weight_map = torch.where(dilate < 0.2, torch.zeros_like(heatmap), torch.ones_like(heatmap))
21 | return w * weight_map + 1
22 |
23 | def forward(self, output, groundtruth):
24 | """
25 | input: b x n x h x w
26 | output: b x n x h x w => 1
27 | """
28 | delta = (output - groundtruth).abs()
29 | A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth))) * (self.alpha - groundtruth) * \
30 | (torch.pow(self.theta / self.epsilon, self.alpha - groundtruth - 1)) * (1 / self.epsilon)
31 | C = self.theta * A - self.omega * \
32 | torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - groundtruth))
33 | loss = torch.where(delta < self.theta,
34 | self.omega * torch.log(1 + torch.pow(delta / self.epsilon, self.alpha - groundtruth)),
35 | (A * delta - C))
36 | if self.use_weight_map:
37 | weight = self.generate_weight_map(groundtruth)
38 | loss = loss * weight
39 | return loss.mean()
40 |
--------------------------------------------------------------------------------
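A quick shape sketch for the loss above: prediction and ground truth are batches of per-landmark heatmaps, and the call reduces them to a scalar, with the dilated weight map emphasizing pixels near each ground-truth peak. The tensor sizes and random inputs are illustrative only; the import assumes the repository root is on `PYTHONPATH`.

```python
import torch
from external.landmark_detection.lib.loss.awingLoss import AWingLoss

criterion = AWingLoss(use_weight_map=True)
pred = torch.rand(2, 68, 64, 64)      # batch x landmarks x h x w
target = torch.rand(2, 68, 64, 64)
loss = criterion(pred, target)
print(loss.shape, loss.item() > 0)    # torch.Size([]) True
```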
/external/landmark_detection/lib/loss/smoothL1Loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SmoothL1Loss(nn.Module):
6 | def __init__(self, scale=0.01):
7 | super(SmoothL1Loss, self).__init__()
8 | self.scale = scale
9 | self.EPSILON = 1e-10
10 |
11 | def __repr__(self):
12 | return "SmoothL1Loss()"
13 |
14 | def forward(self, output: torch.Tensor, groundtruth: torch.Tensor, reduction='mean'):
15 | """
16 | input: b x n x 2
17 | output: b x n x 1 => 1
18 | """
19 | if output.dim() == 4:
20 | shape = output.shape
21 | groundtruth = groundtruth.reshape(shape[0], shape[1], 1, shape[3])
22 |
23 | delta_2 = (output - groundtruth).pow(2).sum(dim=-1, keepdim=False)
24 | delta = delta_2.clamp(min=1e-6).sqrt()
25 | # delta = torch.sqrt(delta_2 + self.EPSILON)
26 | loss = torch.where( \
27 | delta_2 < self.scale * self.scale, \
28 | 0.5 / self.scale * delta_2, \
29 | delta - 0.5 * self.scale)
30 |
31 | if reduction == 'mean':
32 | loss = loss.mean()
33 | elif reduction == 'sum':
34 | loss = loss.sum()
35 |
36 | return loss
37 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/loss/wingLoss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 |
7 |
8 | # torch.log and math.log is e based
9 | class WingLoss(nn.Module):
10 | def __init__(self, omega=0.01, epsilon=2):
11 | super(WingLoss, self).__init__()
12 | self.omega = omega
13 | self.epsilon = epsilon
14 |
15 | def forward(self, pred, target):
16 | y = target
17 | y_hat = pred
18 | delta_2 = (y - y_hat).pow(2).sum(dim=-1, keepdim=False)
19 | # delta = delta_2.sqrt()
20 | delta = delta_2.clamp(min=1e-6).sqrt()
21 | C = self.omega - self.omega * math.log(1 + self.omega / self.epsilon)
22 | loss = torch.where(
23 | delta < self.omega,
24 | self.omega * torch.log(1 + delta / self.epsilon),
25 | delta - C
26 | )
27 | return loss.mean()
28 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/metric/__init__.py:
--------------------------------------------------------------------------------
1 | from .nme import NME
2 | from .accuracy import Accuracy
3 | from .fr_and_auc import FR_AUC
4 | from .params import count_parameters_in_MB
5 |
6 | __all__ = [
7 | "NME",
8 | "Accuracy",
9 | "FR_AUC",
10 | 'count_parameters_in_MB',
11 | ]
12 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/metric/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | class Accuracy:
5 | def __init__(self):
6 | pass
7 |
8 | def __repr__(self):
9 | return "Accuracy()"
10 |
11 | def test(self, label_pd, label_gt, ignore_label=-1):
12 | correct_cnt = 0
13 | total_cnt = 0
14 | with torch.no_grad():
15 | label_pd = F.softmax(label_pd, dim=1)
16 | label_pd = torch.max(label_pd, 1)[1]
17 | label_gt = label_gt.long()
18 | c = (label_pd == label_gt)
19 | correct_cnt = torch.sum(c).item()
20 | total_cnt = c.size(0) - torch.sum(label_gt==ignore_label).item()
21 | return correct_cnt, total_cnt
22 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/metric/fr_and_auc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.integrate import simps  # renamed to scipy.integrate.simpson in newer SciPy releases
3 |
4 |
5 | class FR_AUC:
6 | def __init__(self, data_definition):
7 | self.data_definition = data_definition
8 | if data_definition == '300W':
9 | self.thresh = 0.05
10 | else:
11 | self.thresh = 0.1
12 |
13 | def __repr__(self):
14 | return "FR_AUC()"
15 |
16 | def test(self, nmes, thres=None, step=0.0001):
17 | if thres is None:
18 | thres = self.thresh
19 |
20 | num_data = len(nmes)
21 | xs = np.arange(0, thres + step, step)
22 | ys = np.array([np.count_nonzero(nmes <= x) for x in xs]) / float(num_data)
23 | fr = 1.0 - ys[-1]
24 | auc = simps(ys, x=xs) / thres
25 | return [round(fr, 4), round(auc, 6)]
26 |
--------------------------------------------------------------------------------
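A worked example of the metric above (same import assumptions as earlier): with the non-300W threshold of 0.1, one of the four NMEs exceeds the threshold, so the failure rate is 0.25, and the AUC integrates the cumulative error curve up to that threshold. The NME values are made up for illustration.

```python
import numpy as np
from external.landmark_detection.lib.metric.fr_and_auc import FR_AUC

metric = FR_AUC(data_definition='WFLW')     # thresh = 0.1
nmes = np.array([0.02, 0.05, 0.08, 0.15])
fr, auc = metric.test(nmes)
print(fr, auc)                              # 0.25 and an AUC between 0 and 1
```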
/external/landmark_detection/lib/metric/nme.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | class NME:
5 | def __init__(self, nme_left_index, nme_right_index):
6 | self.nme_left_index = nme_left_index
7 | self.nme_right_index = nme_right_index
8 |
9 | def __repr__(self):
10 | return "NME()"
11 |
12 | def get_norm_distance(self, landmarks):
13 | assert isinstance(self.nme_right_index, list), 'the nme_right_index is not list.'
14 | assert isinstance(self.nme_left_index, list), 'the nme_left, index is not list.'
15 | right_pupil = landmarks[self.nme_right_index, :].mean(0)
16 | left_pupil = landmarks[self.nme_left_index, :].mean(0)
17 | norm_distance = np.linalg.norm(right_pupil - left_pupil)
18 | return norm_distance
19 |
20 | def test(self, label_pd, label_gt):
21 | nme_list = []
22 | label_pd = label_pd.data.cpu().numpy()
23 | label_gt = label_gt.data.cpu().numpy()
24 |
25 | for i in range(label_gt.shape[0]):
26 | landmarks_gt = label_gt[i]
27 | landmarks_pv = label_pd[i]
28 | if isinstance(self.nme_right_index, list):
29 | norm_distance = self.get_norm_distance(landmarks_gt)
30 | elif isinstance(self.nme_right_index, int):
31 | norm_distance = np.linalg.norm(landmarks_gt[self.nme_left_index] - landmarks_gt[self.nme_right_index])
32 | else:
33 | raise NotImplementedError
34 | landmarks_delta = landmarks_pv - landmarks_gt
35 | nme = (np.linalg.norm(landmarks_delta, axis=1) / norm_distance).mean()
36 | nme_list.append(nme)
37 | # sum_nme += nme
38 | # total_cnt += 1
39 | return nme_list
40 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/metric/params.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def count_parameters_in_MB(model):
4 | if isinstance(model, nn.Module):
5 | return sum(v.numel() for v in model.parameters()) / 1e6
6 | else:
7 | return sum(v.numel() for v in model) / 1e6
--------------------------------------------------------------------------------
/external/landmark_detection/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .meter import AverageMeter
2 | from .time_utils import time_print, time_string, time_string_short, time_for_file
3 | from .time_utils import convert_secs2time, convert_size2str
4 | from .vis_utils import plot_points
5 |
6 | __all__ = [
7 | "AverageMeter",
8 | "time_print",
9 | "time_string",
10 | "time_string_short",
11 | "time_for_file",
12 | "convert_size2str",
13 | "convert_secs2time",
14 |
15 | "plot_points",
16 | ]
17 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/utils/meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 |
4 | def __init__(self):
5 | self.reset()
6 |
7 | def reset(self):
8 | self.val = 0.0
9 | self.avg = 0.0
10 | self.sum = 0.0
11 | self.count = 0.0
12 |
13 | def update(self, val, n=1):
14 | self.val = val
15 | self.sum += val
16 | self.count += n
17 | self.avg = self.sum / self.count
18 |
19 | def __repr__(self):
20 | return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
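
As written, update(val, n) adds val once to the running sum while count grows by n, so a per-sample average requires passing the batch total as val and the batch size as n (or always calling with n=1). A short usage sketch, assuming external/landmark_detection is on sys.path:

    from lib.utils import AverageMeter                # re-exported by lib/utils/__init__.py

    meter = AverageMeter()
    for batch_losses in ([1, 3], [2, 2, 2]):          # illustrative per-sample losses
        meter.update(sum(batch_losses), n=len(batch_losses))
    print(meter)                                      # AverageMeter(val=6, avg=2.0, count=5.0)
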
--------------------------------------------------------------------------------
/external/landmark_detection/lib/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import time, sys
8 | import numpy as np
9 |
10 |
11 | def time_for_file():
12 | ISOTIMEFORMAT = '%d-%h-at-%H-%M-%S'
13 | return '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
14 |
15 |
16 | def time_string():
17 | ISOTIMEFORMAT = '%Y-%m-%d %X'
18 | string = '[{}]'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
19 | return string
20 |
21 |
22 | def time_string_short():
23 | ISOTIMEFORMAT = '%Y%m%d'
24 | string = '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
25 | return string
26 |
27 |
28 | def time_print(string, is_print=True):
29 | if (is_print):
30 | print('{} : {}'.format(time_string(), string))
31 |
32 |
33 | def convert_size2str(torch_size):
34 | dims = len(torch_size)
35 | string = '['
36 | for idim in range(dims):
37 | string = string + ' {}'.format(torch_size[idim])
38 | return string + ']'
39 |
40 |
41 | def convert_secs2time(epoch_time, return_str=False):
42 | need_hour = int(epoch_time / 3600)
43 | need_mins = int((epoch_time - 3600 * need_hour) / 60)
44 | need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
45 | if return_str:
46 |         time_str = '[Time Left: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
47 |         return time_str
48 | else:
49 | return need_hour, need_mins, need_secs
50 |
--------------------------------------------------------------------------------
/external/landmark_detection/lib/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import numbers
4 |
5 |
6 | def plot_points(vis, points, radius=1, color=(255, 255, 0), shift=4, indexes=0, is_index=False):
7 | if isinstance(points, list):
8 | num_point = len(points)
9 |     elif isinstance(points, np.ndarray):
10 | num_point = points.shape[0]
11 | else:
12 | raise NotImplementedError
13 | if isinstance(radius, numbers.Number):
14 | radius = np.zeros((num_point)) + radius
15 |
16 | if isinstance(indexes, numbers.Number):
17 | indexes = [indexes + i for i in range(num_point)]
18 | elif isinstance(indexes, list):
19 | pass
20 | else:
21 | raise NotImplementedError
22 |
23 | factor = (1 << shift)
24 | for (index, p, s) in zip(indexes, points, radius):
25 | cv2.circle(vis, (int(p[0] * factor + 0.5), int(p[1] * factor + 0.5)),
26 | int(s * factor), color, 1, cv2.LINE_AA, shift=shift)
27 | if is_index:
28 | vis = cv2.putText(vis, str(index), (int(p[0]), int(p[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.2,
29 | (255, 255, 255), 1)
30 |
31 | return vis
32 |
--------------------------------------------------------------------------------
/external/landmark_detection/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | torch==1.6.0
3 | torchvision==0.7.0
4 | python-gflags==3.1.2
5 | pandas==0.24.2
6 | pillow==6.0.0
7 | numpy==1.16.4
8 | opencv-python==4.1.0.25
9 | imageio==2.5.0
10 | imgaug==0.2.9
11 | lmdb==0.98
12 | lxml==4.5.0
13 | tensorboard==2.4.1
14 | protobuf==3.20
15 | tensorboardX==1.8
16 | # pyarrow==0.17.1
17 | # wandb==0.10.25
18 | # https://pytorch.org/get-started/previous-versions/
19 | # pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
20 |
--------------------------------------------------------------------------------
/external/landmark_detection/tester.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from lib import utility
4 |
5 |
6 | def test(args):
7 | # conf
8 | config = utility.get_config(args)
9 | config.device_id = args.device_ids[0]
10 |
11 | # set environment
12 | utility.set_environment(config)
13 | config.init_instance()
14 | if config.logger is not None:
15 | config.logger.info("Loaded configure file %s: %s" % (args.config_name, config.id))
16 | config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()]))
17 |
18 | # model
19 | net = utility.get_net(config)
20 | model_path = os.path.join(config.model_dir,
21 | "train.pkl") if args.pretrained_weight is None else args.pretrained_weight
22 | if args.device_ids == [-1]:
23 | checkpoint = torch.load(model_path, map_location="cpu")
24 | else:
25 | checkpoint = torch.load(model_path)
26 |
27 | net.load_state_dict(checkpoint["net"])
28 |
29 | if config.logger is not None:
30 | config.logger.info("Loaded network")
31 | # config.logger.info('Net flops: {} G, params: {} MB'.format(flops/1e9, params/1e6))
32 |
33 | # data - test
34 | test_loader = utility.get_dataloader(config, "test")
35 |
36 | if config.logger is not None:
37 | config.logger.info("Loaded data from {:}".format(config.test_tsv_file))
38 |
39 | # inference
40 | result, metrics = utility.forward(config, test_loader, net)
41 | if config.logger is not None:
42 | config.logger.info("Finished inference")
43 |
44 | # output
45 | for k, metric in enumerate(metrics):
46 | if config.logger is not None and len(metric) != 0:
47 | config.logger.info(
48 | "Tested {} dataset, the Size is {}, Metric: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
49 | config.type, len(test_loader.dataset), metric[0], metric[1], metric[2]))
50 |
--------------------------------------------------------------------------------
/external/landmark_detection/tools/infinite_loop.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | while True:
4 | time.sleep(1)
5 |
--------------------------------------------------------------------------------
/external/landmark_detection/tools/infinite_loop_gpu.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import time
5 | import torch
6 | import argparse
7 |
8 | parser = argparse.ArgumentParser(description='inf')
9 | parser.add_argument('--gpu', default='1', type=str, help='index of gpu to use')
10 | args = parser.parse_args()
11 |
12 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
13 |
14 | n = 1000
15 |
16 | x = torch.zeros(4, n, n).cuda()
17 | rest_time = 0.0000000000001
18 | while True:
19 | y = x * x
20 | time.sleep(rest_time)
21 | y1 = x * x
22 |
--------------------------------------------------------------------------------
/external/landmark_detection/tools/split_wflw.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os.path as osp
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 |
7 | tsv_file = '/apdcephfs/share_1134483/charlinzhou/datas/ADNet/WFLW/test.tsv'
8 | save_folder = '/apdcephfs/share_1134483/charlinzhou/datas/ADNet/_WFLW/'
9 |
10 | save_tags = ['largepose', 'expression', 'illumination', 'makeup', 'occlusion', 'blur']
11 | save_tags = ['test_{}_metadata.tsv'.format(t) for t in save_tags]
12 | save_files = [osp.join(save_folder, t) for t in save_tags]
13 | save_files = [open(f, 'w', newline='') for f in save_files]
14 |
15 | landmark_num = 98
16 | items = pd.read_csv(tsv_file, sep="\t")
17 |
18 | items_num = len(items)
19 | for index in tqdm(range(items_num)):
20 | image_path = items.iloc[index, 0]
21 | landmarks_5pts = items.iloc[index, 1]
22 | # landmarks_5pts = np.array(list(map(float, landmarks_5pts.split(","))), dtype=np.float32).reshape(5, 2)
23 | landmarks_target = items.iloc[index, 2]
24 | # landmarks_target = np.array(list(map(float, landmarks_target.split(","))), dtype=np.float32).reshape(landmark_num, 2)
25 | scale = items.iloc[index, 3]
26 | center_w, center_h = items.iloc[index, 4], items.iloc[index, 5]
27 | if len(items.iloc[index]) > 6:
28 | tags = np.array(list(map(lambda x: int(float(x)), items.iloc[index, 6].split(","))))
29 | else:
30 | tags = np.array([])
31 | assert len(tags) == 6, '{} v.s. 6'.format(len(tags))
32 | for k, tag in enumerate(tags):
33 | if tag == 1:
34 | save_file = save_files[k]
35 | tsv_w = csv.writer(save_file, delimiter='\t')
36 | tsv_w.writerow([image_path, landmarks_5pts, landmarks_target, scale, center_w, center_h])
37 |
38 | print('Done!')
39 |
--------------------------------------------------------------------------------
/external/landmark_detection/tools/testtime_pca.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 |
6 | def get_channel_sum(input):
7 | temp = torch.sum(input, dim=3)
8 | output = torch.sum(temp, dim=2)
9 | return output
10 |
11 |
12 | def expand_two_dimensions_at_end(input, dim1, dim2):
13 | input = input.unsqueeze(-1).unsqueeze(-1)
14 | input = input.expand(-1, -1, dim1, dim2)
15 | return input
16 |
17 |
18 | class TestTimePCA(nn.Module):
19 | def __init__(self):
20 | super(TestTimePCA, self).__init__()
21 |
22 | def _make_grid(self, h, w):
23 | yy, xx = torch.meshgrid(
24 | torch.arange(h).float() / (h - 1) * 2 - 1,
25 | torch.arange(w).float() / (w - 1) * 2 - 1)
26 | return yy, xx
27 |
28 | def weighted_mean(self, heatmap):
29 | batch, npoints, h, w = heatmap.shape
30 |
31 | yy, xx = self._make_grid(h, w)
32 | yy = yy.view(1, 1, h, w).to(heatmap)
33 | xx = xx.view(1, 1, h, w).to(heatmap)
34 |
35 | yy_coord = (yy * heatmap).sum([2, 3]) # batch x npoints
36 | xx_coord = (xx * heatmap).sum([2, 3]) # batch x npoints
37 | coords = torch.stack([xx_coord, yy_coord], dim=-1)
38 | return coords
39 |
40 | def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
41 | batch_size, num_points, height, width = htp.shape
42 |
43 | yv, xv = self._make_grid(height, width)
44 | xv = Variable(xv)
45 | yv = Variable(yv)
46 |
47 | if htp.is_cuda:
48 | xv = xv.cuda()
49 | yv = yv.cuda()
50 |
51 | xmean = means[:, :, 0]
52 | xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height,
53 | width) # [batch_size, 68, 64, 64]
54 | ymean = means[:, :, 1]
55 | yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height,
56 | width) # [batch_size, 68, 64, 64]
57 | wt_xv_minus_mean = xv_minus_mean
58 | wt_yv_minus_mean = yv_minus_mean
59 |
60 | wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
61 | wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
62 | wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width) # [batch_size*68, 4096]
63 | wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1, height * width) # [batch_size*68, 1, 4096]
64 | vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1) # [batch_size*68, 2, 4096]
65 |
66 | htp_vec = htp.view(batch_size * num_points, 1, height * width)
67 | htp_vec = htp_vec.expand(-1, 2, -1)
68 |
69 | covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2)) # [batch_size*68, 2, 2]
70 | covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image) # [batch_size, 68, 2, 2]
71 |
72 | V_1 = htp.sum([2, 3]) + EPSILON # [batch_size, 68]
73 | V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON # [batch_size, 68]
74 |
75 | denominator = V_1 - (V_2 / V_1)
76 | covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)
77 |
78 | return covariance
79 |
80 |     def forward(self, heatmap, groundtruth):
81 |
82 | batch, npoints, h, w = heatmap.shape
83 |
84 | heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
85 | heatmap = heatmap / heatmap_sum.view(batch, npoints, 1, 1)
86 |
87 | # means [batch_size, 68, 2]
88 | means = self.weighted_mean(heatmap)
89 |
90 | # covars [batch_size, 68, 2, 2]
91 | covars = self.unbiased_weighted_covariance(heatmap, means)
92 |
93 | # eigenvalues [batch_size * 68, 2] , eigenvectors [batch_size * 68, 2, 2]
94 | covars = covars.view(batch * npoints, 2, 2).cpu()
95 | evalues, evectors = covars.symeig(eigenvectors=True)
96 | evalues = evalues.view(batch, npoints, 2)
97 | evectors = evectors.view(batch, npoints, 2, 2)
98 | means = means.cpu()
99 |
100 | results = [dict() for _ in range(batch)]
101 | for i in range(batch):
102 | results[i]['pred'] = means[i].numpy().tolist()
103 |             results[i]['gt'] = groundtruth[i].cpu().numpy().tolist()
104 | results[i]['evalues'] = evalues[i].numpy().tolist()
105 | results[i]['evectors'] = evectors[i].numpy().tolist()
106 |
107 | return results
108 |
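
The covars.symeig(eigenvectors=True) call matches the torch==1.6.0 pin in this sub-project's requirements; Tensor.symeig was deprecated and later removed from PyTorch. On newer releases the same decomposition of the 2x2 covariance matrices is available via torch.linalg.eigh, which likewise returns eigenvalues in ascending order with eigenvectors in the columns. A minimal sketch (not part of the original file):

    import torch

    covars = torch.randn(8, 2, 2)
    covars = covars @ covars.transpose(1, 2)         # symmetric positive semi-definite 2x2 blocks
    evalues, evectors = torch.linalg.eigh(covars)    # drop-in for covars.symeig(eigenvectors=True)
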
--------------------------------------------------------------------------------
/external/vgghead_detector/VGGDetector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
3 | # Modified based on code from Orest Kupyn (University of Oxford).
4 |
5 | import os
6 | import torch
7 | import numpy as np
8 | import torchvision
9 |
10 | from .utils_vgghead import nms
11 | from .utils_lmks_detector import LmksDetector
12 |
13 | class VGGHeadDetector(torch.nn.Module):
14 | def __init__(self, device,
15 | vggheadmodel_path=None):
16 | super().__init__()
17 | self.image_size = 640
18 | self._device = device
19 | self.vggheadmodel_path = vggheadmodel_path
20 | self._init_models()
21 |
22 | def _init_models(self,):
23 | # vgg_heads_l
24 | self.model = torch.load(self.vggheadmodel_path, map_location='cpu')
25 | self.model.to(self._device).eval()
26 |
27 | @torch.no_grad()
28 | def forward(self, image_tensor, image_key, conf_threshold=0.5):
29 | if not hasattr(self, 'model'):
30 | self._init_models()
31 | image_tensor = image_tensor.to(self._device).float()
32 | image, padding, scale = self._preprocess(image_tensor)
33 | bbox, scores, flame_params = self.model(image)
34 | bbox, vgg_results = self._postprocess(bbox, scores, flame_params, conf_threshold)
35 |
36 | if bbox is None:
37 | print('VGGHeadDetector: No face detected: {}!'.format(image_key))
38 | return None, None, None
39 | vgg_results['normalize'] = {'padding': padding, 'scale': scale}
40 |
41 | # bbox
42 | bbox = bbox.clip(0, self.image_size)
43 | bbox[[0, 2]] -= padding[0]; bbox[[1, 3]] -= padding[1]; bbox /= scale
44 | bbox = bbox.clip(0, self.image_size / scale)
45 |
46 | return vgg_results, bbox, None
47 |
48 | def _preprocess(self, image):
49 | _, h, w = image.shape
50 | if h > w:
51 | new_h, new_w = self.image_size, int(w * self.image_size / h)
52 | else:
53 | new_h, new_w = int(h * self.image_size / w), self.image_size
54 | scale = self.image_size / max(h, w)
55 | image = torchvision.transforms.functional.resize(image, (new_h, new_w), antialias=True)
56 | pad_w = self.image_size - image.shape[2]
57 | pad_h = self.image_size - image.shape[1]
58 | image = torchvision.transforms.functional.pad(image, (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2), fill=127)
59 | image = image.unsqueeze(0).float() / 255.0
60 | return image, np.array([pad_w // 2, pad_h // 2]), scale
61 |
62 | def _postprocess(self, bbox, scores, flame_params, conf_threshold):
63 | # flame_params = {"shape": 300, "exp": 100, "rotation": 6, "jaw": 3, "translation": 3, "scale": 1}
64 | bbox, scores, flame_params = nms(bbox, scores, flame_params, confidence_threshold=conf_threshold)
65 | if bbox.shape[0] == 0:
66 | return None, None
67 | max_idx = ((bbox[:, 3] - bbox[:, 1]) * (bbox[:, 2] - bbox[:, 0])).argmax().long()
68 | bbox, flame_params = bbox[max_idx], flame_params[max_idx]
69 | if bbox[0] < 5 and bbox[1] < 5 and bbox[2] > 635 and bbox[3] > 635:
70 | return None, None
71 | # flame
72 | posecode = torch.cat([flame_params.new_zeros(3), flame_params[400:403]])
73 | vgg_results = {
74 | 'rotation_6d': flame_params[403:409], 'translation': flame_params[409:412], 'scale': flame_params[412:],
75 | 'shapecode': flame_params[:300], 'expcode': flame_params[300:400], 'posecode': posecode,
76 | }
77 | return bbox, vgg_results
78 |
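
A minimal driver sketch for the detector above; the checkpoint path and image file are placeholders, the vgg_heads_l weights must be obtained separately, and the import assumes the repository root is on sys.path:

    import torch
    import torchvision

    from external.vgghead_detector import VGGHeadDetector

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    detector = VGGHeadDetector(device, vggheadmodel_path='path/to/vgghead_model.pt')  # placeholder path

    # forward() expects an un-normalised [C, H, W] tensor in the 0-255 range
    # (it resizes, pads and divides by 255 internally) plus a key used for logging.
    image = torchvision.io.read_image('example.jpg')          # placeholder image
    vgg_results, bbox, _ = detector(image, image_key='example')
    if bbox is not None:
        print('FLAME shape code:', vgg_results['shapecode'].shape, 'bbox:', bbox.tolist())
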
--------------------------------------------------------------------------------
/external/vgghead_detector/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
3 |
4 | from .VGGDetector import VGGHeadDetector
5 | from .utils_vgghead import reproject_vertices
6 |
--------------------------------------------------------------------------------
/external/vgghead_detector/utils_vgghead.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
3 | # Modified based on code from Orest Kupyn (University of Oxford).
4 |
5 | import torch
6 | import torchvision
7 |
8 | def reproject_vertices(flame_model, vgg_results):
9 | # flame_model = FLAMEModel(n_shape=300, n_exp=100, scale=1.0)
10 | vertices, _ = flame_model(
11 | shape_params=vgg_results['shapecode'],
12 | expression_params=vgg_results['expcode'],
13 | pose_params=vgg_results['posecode'],
14 | verts_sclae=1.0
15 | )
16 | vertices[:, :, 2] += 0.05 # MESH_OFFSET_Z
17 | vgg_landmarks3d = flame_model._vertices2landmarks(vertices)
18 | vgg_transform_results = vgg_results['transform']
19 | rotation_mat = rot_mat_from_6dof(vgg_transform_results['rotation_6d']).type(vertices.dtype)
20 | translation = vgg_transform_results['translation'][:, None, :]
21 | scale = torch.clamp(vgg_transform_results['scale'][:, None], 1e-8)
22 | rot_vertices = vertices.clone()
23 | rot_vertices = torch.matmul(rotation_mat.unsqueeze(1), rot_vertices.unsqueeze(-1))[..., 0]
24 | vgg_landmarks3d = torch.matmul(rotation_mat.unsqueeze(1), vgg_landmarks3d.unsqueeze(-1))[..., 0]
25 | proj_vertices = (rot_vertices * scale) + translation
26 | vgg_landmarks3d = (vgg_landmarks3d * scale) + translation
27 |
28 | trans_padding, trans_scale = vgg_results['normalize']['padding'], vgg_results['normalize']['scale']
29 | proj_vertices[:, :, 0] -= trans_padding[:, 0, None]
30 | proj_vertices[:, :, 1] -= trans_padding[:, 1, None]
31 | proj_vertices = proj_vertices / trans_scale[:, None, None]
32 | vgg_landmarks3d[:, :, 0] -= trans_padding[:, 0, None]
33 | vgg_landmarks3d[:, :, 1] -= trans_padding[:, 1, None]
34 | vgg_landmarks3d = vgg_landmarks3d / trans_scale[:, None, None]
35 | return proj_vertices.float()[..., :2], vgg_landmarks3d.float()[..., :2]
36 |
37 |
38 | def rot_mat_from_6dof(v: torch.Tensor) -> torch.Tensor:
39 | assert v.shape[-1] == 6
40 | v = v.view(-1, 6)
41 | vx, vy = v[..., :3].clone(), v[..., 3:].clone()
42 |
43 | b1 = torch.nn.functional.normalize(vx, dim=-1)
44 | b3 = torch.nn.functional.normalize(torch.cross(b1, vy, dim=-1), dim=-1)
45 | b2 = -torch.cross(b1, b3, dim=1)
46 | return torch.stack((b1, b2, b3), dim=-1)
47 |
48 |
49 | def nms(boxes_xyxy, scores, flame_params,
50 | confidence_threshold: float = 0.5, iou_threshold: float = 0.5,
51 | top_k: int = 1000, keep_top_k: int = 100
52 | ):
53 | for pred_bboxes_xyxy, pred_bboxes_conf, pred_flame_params in zip(
54 | boxes_xyxy.detach().float(),
55 | scores.detach().float(),
56 | flame_params.detach().float(),
57 | ):
58 | pred_bboxes_conf = pred_bboxes_conf.squeeze(-1) # [Anchors]
59 | conf_mask = pred_bboxes_conf >= confidence_threshold
60 |
61 | pred_bboxes_conf = pred_bboxes_conf[conf_mask]
62 | pred_bboxes_xyxy = pred_bboxes_xyxy[conf_mask]
63 | pred_flame_params = pred_flame_params[conf_mask]
64 |
65 | # Filter all predictions by self.nms_top_k
66 | if pred_bboxes_conf.size(0) > top_k:
67 | topk_candidates = torch.topk(pred_bboxes_conf, k=top_k, largest=True, sorted=True)
68 | pred_bboxes_conf = pred_bboxes_conf[topk_candidates.indices]
69 | pred_bboxes_xyxy = pred_bboxes_xyxy[topk_candidates.indices]
70 | pred_flame_params = pred_flame_params[topk_candidates.indices]
71 |
72 | # NMS
73 | idx_to_keep = torchvision.ops.boxes.nms(boxes=pred_bboxes_xyxy, scores=pred_bboxes_conf, iou_threshold=iou_threshold)
74 |
75 | final_bboxes = pred_bboxes_xyxy[idx_to_keep][: keep_top_k] # [Instances, 4]
76 | final_scores = pred_bboxes_conf[idx_to_keep][: keep_top_k] # [Instances, 1]
77 | final_params = pred_flame_params[idx_to_keep][: keep_top_k] # [Instances, Flame Params]
78 | return final_bboxes, final_scores, final_params
79 |
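
rot_mat_from_6dof above is the standard two-vector (Gram-Schmidt) 6-DoF rotation parameterisation, so its output should always be a proper rotation matrix. A quick sanity sketch, assuming the repository root is on sys.path:

    import torch

    from external.vgghead_detector.utils_vgghead import rot_mat_from_6dof

    v = torch.randn(4, 6)                               # random 6-DoF parameters
    R = rot_mat_from_6dof(v)                            # [4, 3, 3]
    eye = torch.eye(3).expand(4, 3, 3)
    print(torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5))   # True: orthonormal columns
    print(torch.det(R))                                 # all close to +1 (proper rotations)
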
--------------------------------------------------------------------------------
/lam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/lam/__init__.py
--------------------------------------------------------------------------------
/lam/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .mixer import MixerDataset
17 |
--------------------------------------------------------------------------------
/lam/datasets/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from abc import ABC, abstractmethod
17 | import traceback
18 | import json
19 | import numpy as np
20 | import torch
21 | from PIL import Image
22 | from typing import Optional, Union
23 | from megfile import smart_open, smart_path_join, smart_exists
24 |
25 |
26 | class BaseDataset(torch.utils.data.Dataset, ABC):
27 | def __init__(self, root_dirs: str, meta_path: Optional[Union[list, str]]):
28 | super().__init__()
29 | self.root_dirs = root_dirs
30 | self.uids = self._load_uids(meta_path)
31 |
32 | def __len__(self):
33 | return len(self.uids)
34 |
35 | @abstractmethod
36 | def inner_get_item(self, idx):
37 | pass
38 |
39 | def __getitem__(self, idx):
40 | try:
41 | return self.inner_get_item(idx)
42 | except Exception as e:
43 | traceback.print_exc()
44 | print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
45 | # raise e
46 | return self.__getitem__((idx + 1) % self.__len__())
47 |
48 | @staticmethod
49 | def _load_uids(meta_path: Optional[Union[list, str]]):
50 | # meta_path is a json file
51 | if isinstance(meta_path, str):
52 | with open(meta_path, 'r') as f:
53 | uids = json.load(f)
54 | else:
55 | uids_lst = []
56 | max_total = 0
57 | for pth, weight in meta_path:
58 | with open(pth, 'r') as f:
59 | uids = json.load(f)
60 | max_total = max(len(uids) / weight, max_total)
61 | uids_lst.append([uids, weight, pth])
62 | merged_uids = []
63 | for uids, weight, pth in uids_lst:
64 | repeat = 1
65 | if len(uids) < int(weight * max_total):
66 | repeat = int(weight * max_total) // len(uids)
67 | cur_uids = uids * repeat
68 | merged_uids += cur_uids
69 | print("Data Path:", pth, "Repeat:", repeat, "Final Length:", len(cur_uids))
70 | uids = merged_uids
71 | print("Total UIDs:", len(uids))
72 | return uids
73 |
74 | @staticmethod
75 | def _load_rgba_image(file_path, bg_color: float = 1.0):
76 |         ''' Load an RGBA image and alpha-blend it onto a constant background color; returns a 0-1 scaled RGB tensor of shape [1, 3, H, W]. '''
77 | rgba = np.array(Image.open(smart_open(file_path, 'rb')))
78 | rgba = torch.from_numpy(rgba).float() / 255.0
79 | rgba = rgba.permute(2, 0, 1).unsqueeze(0)
80 | rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :])
81 |
82 | return rgb
83 |
84 | @staticmethod
85 | def _locate_datadir(root_dirs, uid, locator: str):
86 | for root_dir in root_dirs:
87 | datadir = smart_path_join(root_dir, uid, locator)
88 | if smart_exists(datadir):
89 | return root_dir
90 | raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
91 |
--------------------------------------------------------------------------------
/lam/datasets/mixer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import math
17 | from functools import partial
18 | import torch
19 |
20 | __all__ = ['MixerDataset']
21 |
22 |
23 | class MixerDataset(torch.utils.data.Dataset):
24 |
25 | def __init__(self,
26 | split: str,
27 |                  subsets: list,
28 | **dataset_kwargs,
29 | ):
30 | subsets = [e for e in subsets if e["meta_path"][split] is not None]
31 | self.subsets = [
32 | self._dataset_fn(subset, split)(**dataset_kwargs)
33 | for subset in subsets
34 | ]
35 | self.virtual_lens = [
36 | math.ceil(subset_config['sample_rate'] * len(subset_obj))
37 | for subset_config, subset_obj in zip(subsets, self.subsets)
38 | ]
39 |
40 | @staticmethod
41 | def _dataset_fn(subset_config: dict, split: str):
42 | name = subset_config['name']
43 |
44 | dataset_cls = None
45 | if name == "exavatar":
46 | from .exavatar import ExAvatarDataset
47 | dataset_cls = ExAvatarDataset
48 | elif name == "humman":
49 | from .humman import HuMManDataset
50 | dataset_cls = HuMManDataset
51 | elif name == "humman_ori":
52 | from .humman_ori import HuMManOriDataset
53 | dataset_cls = HuMManOriDataset
54 | elif name == "static_human":
55 | from .static_human import StaticHumanDataset
56 | dataset_cls = StaticHumanDataset
57 | elif name == "singleview_human":
58 | from .singleview_human import SingleViewHumanDataset
59 | dataset_cls = SingleViewHumanDataset
60 | elif name == "singleview_square_human":
61 | from .singleview_square_human import SingleViewSquareHumanDataset
62 | dataset_cls = SingleViewSquareHumanDataset
63 | elif name == "bedlam":
64 | from .bedlam import BedlamDataset
65 | dataset_cls = BedlamDataset
66 | elif name == "dna_human":
67 | from .dna import DNAHumanDataset
68 | dataset_cls = DNAHumanDataset
69 | elif name == "video_human":
70 | from .video_human import VideoHumanDataset
71 | dataset_cls = VideoHumanDataset
72 | elif name == "video_head":
73 | from .video_head import VideoHeadDataset
74 | dataset_cls = VideoHeadDataset
75 | elif name == "video_head_gagtrack":
76 | from .video_head_gagtrack import VideoHeadGagDataset
77 | dataset_cls = VideoHeadGagDataset
78 | elif name == "objaverse":
79 | from .objaverse import ObjaverseDataset
80 | dataset_cls = ObjaverseDataset
81 | # elif name == 'mvimgnet':
82 | # from .mvimgnet import MVImgNetDataset
83 | # dataset_cls = MVImgNetDataset
84 | else:
85 | raise NotImplementedError(f"Dataset {name} not implemented")
86 | print("==="*16*3, "\nUse dataset loader:", name, "\n"+"==="*3*16)
87 |
88 | return partial(
89 | dataset_cls,
90 | root_dirs=subset_config['root_dirs'],
91 | meta_path=subset_config['meta_path'][split],
92 | )
93 |
94 | def __len__(self):
95 | return sum(self.virtual_lens)
96 |
97 | def __getitem__(self, idx):
98 | subset_idx = 0
99 | virtual_idx = idx
100 | while virtual_idx >= self.virtual_lens[subset_idx]:
101 | virtual_idx -= self.virtual_lens[subset_idx]
102 | subset_idx += 1
103 | real_idx = virtual_idx % len(self.subsets[subset_idx])
104 | return self.subsets[subset_idx][real_idx]
105 |
--------------------------------------------------------------------------------
/lam/launch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import argparse
17 |
18 | from lam.runners import REGISTRY_RUNNERS
19 |
20 |
21 | def main():
22 |
23 | parser = argparse.ArgumentParser(description='lam launcher')
24 | parser.add_argument('runner', type=str, help='Runner to launch')
25 | args, unknown = parser.parse_known_args()
26 |
27 | if args.runner not in REGISTRY_RUNNERS:
28 | raise ValueError('Runner {} not found'.format(args.runner))
29 |
30 | RunnerClass = REGISTRY_RUNNERS[args.runner]
31 | with RunnerClass() as runner:
32 | runner.run()
33 |
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
/lam/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .pixelwise import *
17 | from .perceptual import *
18 | from .tvloss import *
19 |
--------------------------------------------------------------------------------
/lam/losses/perceptual.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 | from einops import rearrange
19 |
20 | __all__ = ['LPIPSLoss']
21 |
22 |
23 | class LPIPSLoss(nn.Module):
24 | """
25 | Compute LPIPS loss between two images.
26 | """
27 |
28 | def __init__(self, device, prefech: bool = False):
29 | super().__init__()
30 | self.device = device
31 | self.cached_models = {}
32 | if prefech:
33 | self.prefetch_models()
34 |
35 | def _get_model(self, model_name: str):
36 | if model_name not in self.cached_models:
37 | import warnings
38 | with warnings.catch_warnings():
39 | warnings.filterwarnings('ignore', category=UserWarning)
40 | import lpips
41 | _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
42 | _model = torch.compile(_model)
43 | self.cached_models[model_name] = _model
44 | return self.cached_models[model_name]
45 |
46 | def prefetch_models(self):
47 | _model_names = ['alex', 'vgg']
48 | for model_name in _model_names:
49 | self._get_model(model_name)
50 |
51 | def forward(self, x, y, is_training: bool = True, conf_sigma=None, only_sym_conf=False):
52 | """
53 | Assume images are 0-1 scaled and channel first.
54 |
55 | Args:
56 | x: [N, M, C, H, W]
57 | y: [N, M, C, H, W]
58 | is_training: whether to use VGG or AlexNet.
59 |
60 | Returns:
61 | Mean-reduced LPIPS loss across batch.
62 | """
63 | model_name = 'vgg' if is_training else 'alex'
64 | loss_fn = self._get_model(model_name)
65 | EPS = 1e-7
66 | if len(x.shape) == 5:
67 | N, M, C, H, W = x.shape
68 | x = x.reshape(N*M, C, H, W)
69 | y = y.reshape(N*M, C, H, W)
70 | image_loss = loss_fn(x, y, normalize=True)
71 | image_loss = image_loss.mean(dim=[1, 2, 3])
72 | batch_loss = image_loss.reshape(N, M).mean(dim=1)
73 | all_loss = batch_loss.mean()
74 | else:
75 | image_loss = loss_fn(x, y, normalize=True)
76 | if conf_sigma is not None:
77 |                 image_loss = image_loss / (2 * conf_sigma ** 2 + EPS) + (conf_sigma + EPS).log()
78 | image_loss = image_loss.mean(dim=[1, 2, 3])
79 | all_loss = image_loss.mean()
80 | return all_loss
81 |
--------------------------------------------------------------------------------
/lam/losses/pixelwise.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 | from einops import rearrange
19 |
20 | __all__ = ['PixelLoss']
21 |
22 |
23 | class PixelLoss(nn.Module):
24 | """
25 | Pixel-wise loss between two images.
26 | """
27 |
28 | def __init__(self, option: str = 'mse'):
29 | super().__init__()
30 | self.loss_fn = self._build_from_option(option)
31 |
32 | @staticmethod
33 | def _build_from_option(option: str, reduction: str = 'none'):
34 | if option == 'mse':
35 | return nn.MSELoss(reduction=reduction)
36 | elif option == 'l1':
37 | return nn.L1Loss(reduction=reduction)
38 | else:
39 | raise NotImplementedError(f'Unknown pixel loss option: {option}')
40 |
41 | @torch.compile
42 | def forward(self, x, y, conf_sigma=None, only_sym_conf=False):
43 | """
44 | Assume images are channel first.
45 |
46 | Args:
47 | x: [N, M, C, H, W]
48 | y: [N, M, C, H, W]
49 |
50 | Returns:
51 | Mean-reduced pixel loss across batch.
52 | """
53 | N, M, C, H, W = x.shape
54 | x = rearrange(x, "n m c h w -> (n m) c h w")
55 | y = rearrange(y, "n m c h w -> (n m) c h w")
56 | image_loss = self.loss_fn(x, y)
57 |
58 | image_loss = image_loss.mean(dim=[1, 2, 3])
59 | batch_loss = image_loss.reshape(N, M).mean(dim=1)
60 | all_loss = batch_loss.mean()
61 | return all_loss
62 |
--------------------------------------------------------------------------------
/lam/losses/tvloss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | __all__ = ['TVLoss']
20 |
21 |
22 | class TVLoss(nn.Module):
23 | """
24 |     Total variation loss.
25 | """
26 |
27 | def __init__(self):
28 | super().__init__()
29 |
30 | def numel_excluding_first_dim(self, x):
31 | return x.numel() // x.shape[0]
32 |
33 | @torch.compile
34 | def forward(self, x):
35 | """
36 | Assume batched and channel first with inner sizes.
37 |
38 | Args:
39 | x: [N, M, C, H, W]
40 |
41 | Returns:
42 | Mean-reduced TV loss with element-level scaling.
43 | """
44 | N, M, C, H, W = x.shape
45 | x = x.reshape(N*M, C, H, W)
46 | diff_i = x[..., 1:, :] - x[..., :-1, :]
47 | diff_j = x[..., :, 1:] - x[..., :, :-1]
48 | div_i = self.numel_excluding_first_dim(diff_i)
49 | div_j = self.numel_excluding_first_dim(diff_j)
50 | tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
51 | tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
52 | tv = tv_i + tv_j
53 | batch_tv = tv.reshape(N, M).mean(dim=1)
54 | all_tv = batch_tv.mean()
55 | return all_tv
56 |
--------------------------------------------------------------------------------
/lam/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .modeling_lam import ModelLAM
17 |
18 |
19 | model_dict = {
20 | 'lam': ModelLAM,
21 | }
22 |
--------------------------------------------------------------------------------
/lam/models/discriminator.py:
--------------------------------------------------------------------------------
1 | """
2 | Ported from Paella
3 | """
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.models.modeling_utils import ModelMixin
10 |
11 | import functools
12 | # import torch.nn as nn
13 | from taming.modules.util import ActNorm
14 |
15 |
16 | # Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
17 | class Discriminator(ModelMixin, ConfigMixin):
18 | @register_to_config
19 | def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
20 | super().__init__()
21 | d = max(depth - 3, 3)
22 | layers = [
23 | nn.utils.spectral_norm(
24 | nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
25 | ),
26 | nn.LeakyReLU(0.2),
27 | ]
28 | for i in range(depth - 1):
29 | c_in = hidden_channels // (2 ** max((d - i), 0))
30 | c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
31 | layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
32 | layers.append(nn.InstanceNorm2d(c_out))
33 | layers.append(nn.LeakyReLU(0.2))
34 | self.encoder = nn.Sequential(*layers)
35 | self.shuffle = nn.Conv2d(
36 | (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
37 | )
38 | # self.logits = nn.Sigmoid()
39 |
40 |
41 | def forward(self, x, cond=None):
42 | x = self.encoder(x)
43 | if cond is not None:
44 | cond = cond.view(
45 | cond.size(0),
46 | cond.size(1),
47 | 1,
48 | 1,
49 | ).expand(-1, -1, x.size(-2), x.size(-1))
50 | x = torch.cat([x, cond], dim=1)
51 | x = self.shuffle(x)
52 | # x = self.logits(x)
53 | return x
54 |
55 |
56 |
57 |
58 | def weights_init(m):
59 | classname = m.__class__.__name__
60 | if classname.find('Conv') != -1:
61 | nn.init.normal_(m.weight.data, 0.0, 0.02)
62 | elif classname.find('BatchNorm') != -1:
63 | nn.init.normal_(m.weight.data, 1.0, 0.02)
64 | nn.init.constant_(m.bias.data, 0)
65 |
66 |
67 | class NLayerDiscriminator(nn.Module):
68 | """Defines a PatchGAN discriminator as in Pix2Pix
69 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
70 | """
71 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
72 | """Construct a PatchGAN discriminator
73 | Parameters:
74 | input_nc (int) -- the number of channels in input images
75 | ndf (int) -- the number of filters in the last conv layer
76 | n_layers (int) -- the number of conv layers in the discriminator
77 | norm_layer -- normalization layer
78 | """
79 | super(NLayerDiscriminator, self).__init__()
80 | if not use_actnorm:
81 | # norm_layer = nn.BatchNorm2d
82 | norm_layer = nn.InstanceNorm2d
83 | else:
84 | norm_layer = ActNorm
85 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
86 | # use_bias = norm_layer.func != nn.BatchNorm2d
87 | use_bias = norm_layer.func != nn.InstanceNorm2d
88 | else:
89 | # use_bias = norm_layer != nn.BatchNorm2d
90 | use_bias = norm_layer != nn.InstanceNorm2d
91 |
92 | kw = 4
93 | padw = 1
94 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]
95 | nf_mult = 1
96 | nf_mult_prev = 1
97 | for n in range(1, n_layers): # gradually increase the number of filters
98 | nf_mult_prev = nf_mult
99 | nf_mult = min(2 ** n, 8)
100 | sequence += [
101 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
102 | norm_layer(ndf * nf_mult),
103 | nn.LeakyReLU(0.2, False)
104 | ]
105 |
106 | nf_mult_prev = nf_mult
107 | nf_mult = min(2 ** n_layers, 8)
108 | sequence += [
109 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
110 | norm_layer(ndf * nf_mult),
111 | nn.LeakyReLU(0.2, False)
112 | ]
113 |
114 | sequence += [
115 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
116 | self.main = nn.Sequential(*sequence)
117 |
118 | def forward(self, input):
119 | """Standard forward."""
120 | return self.main(input)
121 |
--------------------------------------------------------------------------------
/lam/models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Empty
16 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Empty
16 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/hub/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/hub/depth/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | from .decode_heads import BNHead, DPTHead
7 | from .encoder_decoder import DepthEncoderDecoder
8 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/hub/depth/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import warnings
7 |
8 | import torch.nn.functional as F
9 |
10 |
11 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
12 | if warning:
13 | if size is not None and align_corners:
14 | input_h, input_w = tuple(int(x) for x in input.shape[2:])
15 | output_h, output_w = tuple(int(x) for x in size)
16 |             if output_h > input_h or output_w > input_w:
17 | if (
18 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
19 | and (output_h - 1) % (input_h - 1)
20 | and (output_w - 1) % (input_w - 1)
21 | ):
22 | warnings.warn(
23 | f"When align_corners={align_corners}, "
24 |                         "the output would be more aligned if "
25 | f"input size {(input_h, input_w)} is `x+1` and "
26 | f"out size {(output_h, output_w)} is `nx+1`"
27 | )
28 | return F.interpolate(input, size, scale_factor, mode, align_corners)
29 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/hub/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import itertools
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15 |
16 |
17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18 | compact_arch_name = arch_name.replace("_", "")[:4]
19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21 |
22 |
23 | class CenterPadding(nn.Module):
24 | def __init__(self, multiple):
25 | super().__init__()
26 | self.multiple = multiple
27 |
28 | def _get_pad(self, size):
29 | new_size = math.ceil(size / self.multiple) * self.multiple
30 | pad_size = new_size - size
31 | pad_size_left = pad_size // 2
32 | pad_size_right = pad_size - pad_size_left
33 | return pad_size_left, pad_size_right
34 |
35 | @torch.inference_mode()
36 | def forward(self, x):
37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38 | output = F.pad(x, pads)
39 | return output
40 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # ******************************************************************************
7 | # Code modified by Zexin He in 2023-2024.
8 | # Modifications are marked with clearly visible comments
9 | # licensed under the Apache License, Version 2.0.
10 | # ******************************************************************************
11 |
12 | from .dino_head import DINOHead
13 | from .mlp import Mlp
14 | from .patch_embed import PatchEmbed
15 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
16 | # ********** Modified by Zexin He in 2023-2024 **********
17 | # Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
18 | from .block import Block, BlockWithModulation
19 | # ********************************************************
20 | from .attention import MemEffAttention
21 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | import logging
11 | import os
12 | import warnings
13 |
14 | from torch import Tensor
15 | from torch import nn
16 |
17 |
18 | logger = logging.getLogger("dinov2")
19 |
20 |
21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22 | try:
23 | if XFORMERS_ENABLED:
24 | from xformers.ops import memory_efficient_attention, unbind
25 |
26 | XFORMERS_AVAILABLE = True
27 | warnings.warn("xFormers is available (Attention)")
28 | else:
29 | warnings.warn("xFormers is disabled (Attention)")
30 | raise ImportError
31 | except ImportError:
32 | XFORMERS_AVAILABLE = False
33 | warnings.warn("xFormers is not available (Attention)")
34 |
35 |
36 | class Attention(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int = 8,
41 | qkv_bias: bool = False,
42 | proj_bias: bool = True,
43 | attn_drop: float = 0.0,
44 | proj_drop: float = 0.0,
45 | ) -> None:
46 | super().__init__()
47 | self.num_heads = num_heads
48 | head_dim = dim // num_heads
49 | self.scale = head_dim**-0.5
50 |
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.attn_drop = nn.Dropout(attn_drop)
53 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
54 | self.proj_drop = nn.Dropout(proj_drop)
55 |
56 | def forward(self, x: Tensor) -> Tensor:
57 | B, N, C = x.shape
58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59 |
60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61 | attn = q @ k.transpose(-2, -1)
62 |
63 | attn = attn.softmax(dim=-1)
64 | attn = self.attn_drop(attn)
65 |
66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67 | x = self.proj(x)
68 | x = self.proj_drop(x)
69 | return x
70 |
71 |
72 | class MemEffAttention(Attention):
73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74 | if not XFORMERS_AVAILABLE:
75 | if attn_bias is not None:
76 | raise AssertionError("xFormers is required for using nested tensors")
77 | return super().forward(x)
78 |
79 | B, N, C = x.shape
80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81 |
82 | q, k, v = unbind(qkv, 2)
83 |
84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85 | x = x.reshape([B, N, C])
86 |
87 | x = self.proj(x)
88 | x = self.proj_drop(x)
89 | return x
90 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.init import trunc_normal_
9 | from torch.nn.utils import weight_norm
10 |
11 |
12 | class DINOHead(nn.Module):
13 | def __init__(
14 | self,
15 | in_dim,
16 | out_dim,
17 | use_bn=False,
18 | nlayers=3,
19 | hidden_dim=2048,
20 | bottleneck_dim=256,
21 | mlp_bias=True,
22 | ):
23 | super().__init__()
24 | nlayers = max(nlayers, 1)
25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26 | self.apply(self._init_weights)
27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28 | self.last_layer.weight_g.data.fill_(1)
29 |
30 | def _init_weights(self, m):
31 | if isinstance(m, nn.Linear):
32 | trunc_normal_(m.weight, std=0.02)
33 | if isinstance(m, nn.Linear) and m.bias is not None:
34 | nn.init.constant_(m.bias, 0)
35 |
36 | def forward(self, x):
37 | x = self.mlp(x)
38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40 | x = self.last_layer(x)
41 | return x
42 |
43 |
44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45 | if nlayers == 1:
46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47 | else:
48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49 | if use_bn:
50 | layers.append(nn.BatchNorm1d(hidden_dim))
51 | layers.append(nn.GELU())
52 | for _ in range(nlayers - 2):
53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54 | if use_bn:
55 | layers.append(nn.BatchNorm1d(hidden_dim))
56 | layers.append(nn.GELU())
57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58 | return nn.Sequential(*layers)
59 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9 |
10 |
11 | from torch import nn
12 |
13 |
14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15 | if drop_prob == 0.0 or not training:
16 | return x
17 | keep_prob = 1 - drop_prob
18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20 | if keep_prob > 0.0:
21 | random_tensor.div_(keep_prob)
22 | output = x * random_tensor
23 | return output
24 |
25 |
26 | class DropPath(nn.Module):
27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28 |
29 | def __init__(self, drop_prob=None):
30 | super(DropPath, self).__init__()
31 | self.drop_prob = drop_prob
32 |
33 | def forward(self, x):
34 | return drop_path(x, self.drop_prob, self.training)
35 |
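A short sketch of the stochastic-depth behaviour implemented above: during training, whole samples of a residual branch are zeroed with probability drop_prob and the survivors are rescaled by 1 / keep_prob; at inference the module is an identity:

import torch
from lam.models.encoders.dinov2.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.ones(4, 3, 8)
dp.train()
y = dp(x)                       # each of the 4 samples is dropped with probability 0.2; kept samples are scaled by 1/0.8
dp.eval()
assert torch.equal(dp(x), x)    # identity at inference time
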
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7 |
8 | from typing import Union
9 |
10 | import torch
11 | from torch import Tensor
12 | from torch import nn
13 |
14 |
15 | class LayerScale(nn.Module):
16 | def __init__(
17 | self,
18 | dim: int,
19 | init_values: Union[float, Tensor] = 1e-5,
20 | inplace: bool = False,
21 | ) -> None:
22 | super().__init__()
23 | self.inplace = inplace
24 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
25 |
26 | def forward(self, x: Tensor) -> Tensor:
27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
28 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9 |
10 |
11 | from typing import Callable, Optional
12 |
13 | from torch import Tensor, nn
14 |
15 |
16 | class Mlp(nn.Module):
17 | def __init__(
18 | self,
19 | in_features: int,
20 | hidden_features: Optional[int] = None,
21 | out_features: Optional[int] = None,
22 | act_layer: Callable[..., nn.Module] = nn.GELU,
23 | drop: float = 0.0,
24 | bias: bool = True,
25 | ) -> None:
26 | super().__init__()
27 | out_features = out_features or in_features
28 | hidden_features = hidden_features or in_features
29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30 | self.act = act_layer()
31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32 | self.drop = nn.Dropout(drop)
33 |
34 | def forward(self, x: Tensor) -> Tensor:
35 | x = self.fc1(x)
36 | x = self.act(x)
37 | x = self.drop(x)
38 | x = self.fc2(x)
39 | x = self.drop(x)
40 | return x
41 |
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9 |
10 | from typing import Callable, Optional, Tuple, Union
11 |
12 | from torch import Tensor
13 | import torch.nn as nn
14 |
15 |
16 | def make_2tuple(x):
17 | if isinstance(x, tuple):
18 | assert len(x) == 2
19 | return x
20 |
21 | assert isinstance(x, int)
22 | return (x, x)
23 |
24 |
25 | class PatchEmbed(nn.Module):
26 | """
27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28 |
29 | Args:
30 | img_size: Image size.
31 | patch_size: Patch token size.
32 | in_chans: Number of input image channels.
33 | embed_dim: Number of linear projection output channels.
34 | norm_layer: Normalization layer.
35 | """
36 |
37 | def __init__(
38 | self,
39 | img_size: Union[int, Tuple[int, int]] = 224,
40 | patch_size: Union[int, Tuple[int, int]] = 16,
41 | in_chans: int = 3,
42 | embed_dim: int = 768,
43 | norm_layer: Optional[Callable] = None,
44 | flatten_embedding: bool = True,
45 | ) -> None:
46 | super().__init__()
47 |
48 | image_HW = make_2tuple(img_size)
49 | patch_HW = make_2tuple(patch_size)
50 | patch_grid_size = (
51 | image_HW[0] // patch_HW[0],
52 | image_HW[1] // patch_HW[1],
53 | )
54 |
55 | self.img_size = image_HW
56 | self.patch_size = patch_HW
57 | self.patches_resolution = patch_grid_size
58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59 |
60 | self.in_chans = in_chans
61 | self.embed_dim = embed_dim
62 |
63 | self.flatten_embedding = flatten_embedding
64 |
65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | _, _, H, W = x.shape
70 | patch_H, patch_W = self.patch_size
71 |
72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74 |
75 | x = self.proj(x) # B C H W
76 | H, W = x.size(2), x.size(3)
77 | x = x.flatten(2).transpose(1, 2) # B HW C
78 | x = self.norm(x)
79 | if not self.flatten_embedding:
80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81 | return x
82 |
83 | def flops(self) -> float:
84 | Ho, Wo = self.patches_resolution
85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86 | if self.norm is not None:
87 | flops += Ho * Wo * self.embed_dim
88 | return flops
89 |
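A quick shape check for the PatchEmbed above: a 224x224 RGB image with 16x16 patches produces (224/16)^2 = 196 tokens of width embed_dim:

import torch
from lam.models.encoders.dinov2.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(2, 3, 224, 224)
tokens = embed(img)               # conv projection, then flatten to (B, HW, C)
print(tokens.shape)               # torch.Size([2, 196, 768])
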
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 | from typing import Callable, Optional
8 | import warnings
9 |
10 | from torch import Tensor, nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class SwiGLUFFN(nn.Module):
15 | def __init__(
16 | self,
17 | in_features: int,
18 | hidden_features: Optional[int] = None,
19 | out_features: Optional[int] = None,
20 | act_layer: Callable[..., nn.Module] = None,
21 | drop: float = 0.0,
22 | bias: bool = True,
23 | ) -> None:
24 | super().__init__()
25 | out_features = out_features or in_features
26 | hidden_features = hidden_features or in_features
27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29 |
30 | def forward(self, x: Tensor) -> Tensor:
31 | x12 = self.w12(x)
32 | x1, x2 = x12.chunk(2, dim=-1)
33 | hidden = F.silu(x1) * x2
34 | return self.w3(hidden)
35 |
36 |
37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38 | try:
39 | if XFORMERS_ENABLED:
40 | from xformers.ops import SwiGLU
41 |
42 | XFORMERS_AVAILABLE = True
43 | warnings.warn("xFormers is available (SwiGLU)")
44 | else:
45 | warnings.warn("xFormers is disabled (SwiGLU)")
46 | raise ImportError
47 | except ImportError:
48 | SwiGLU = SwiGLUFFN
49 | XFORMERS_AVAILABLE = False
50 |
51 | warnings.warn("xFormers is not available (SwiGLU)")
52 |
53 |
54 | class SwiGLUFFNFused(SwiGLU):
55 | def __init__(
56 | self,
57 | in_features: int,
58 | hidden_features: Optional[int] = None,
59 | out_features: Optional[int] = None,
60 | act_layer: Callable[..., nn.Module] = None,
61 | drop: float = 0.0,
62 | bias: bool = True,
63 | ) -> None:
64 | out_features = out_features or in_features
65 | hidden_features = hidden_features or in_features
66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67 | super().__init__(
68 | in_features=in_features,
69 | hidden_features=hidden_features,
70 | out_features=out_features,
71 | bias=bias,
72 | )
73 |
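A worked example of the hidden-width rule in SwiGLUFFNFused above: starting from a ViT-style hidden size of 4 * 768 = 3072, int(3072 * 2 / 3) = 2048, which is already a multiple of 8, so the FFN uses 2048 hidden units. The sketch below runs on the pure-PyTorch SwiGLUFFN fallback; with xFormers installed the fused SwiGLU module is used instead (and may require a GPU):

import torch
from lam.models.encoders.dinov2.layers.swiglu_ffn import SwiGLUFFNFused

ffn = SwiGLUFFNFused(in_features=768, hidden_features=3072)   # internally reduced to 2048
x = torch.randn(2, 197, 768)
print(ffn(x).shape)                                           # torch.Size([2, 197, 768])
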
--------------------------------------------------------------------------------
/lam/models/encoders/dinov2/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import logging
7 |
8 | from . import vision_transformer as vits
9 |
10 |
11 | logger = logging.getLogger("dinov2")
12 |
13 |
14 | def build_model(args, only_teacher=False, img_size=224):
15 | args.arch = args.arch.removesuffix("_memeff")
16 | if "vit" in args.arch:
17 | vit_kwargs = dict(
18 | img_size=img_size,
19 | patch_size=args.patch_size,
20 | init_values=args.layerscale,
21 | ffn_layer=args.ffn_layer,
22 | block_chunks=args.block_chunks,
23 | qkv_bias=args.qkv_bias,
24 | proj_bias=args.proj_bias,
25 | ffn_bias=args.ffn_bias,
26 | num_register_tokens=args.num_register_tokens,
27 | interpolate_offset=args.interpolate_offset,
28 | interpolate_antialias=args.interpolate_antialias,
29 | )
30 | teacher = vits.__dict__[args.arch](**vit_kwargs)
31 | if only_teacher:
32 | return teacher, teacher.embed_dim
33 | student = vits.__dict__[args.arch](
34 | **vit_kwargs,
35 | drop_path_rate=args.drop_path_rate,
36 | drop_path_uniform=args.drop_path_uniform,
37 | )
38 | embed_dim = student.embed_dim
39 | return student, teacher, embed_dim
40 |
41 |
42 | def build_model_from_cfg(cfg, only_teacher=False):
43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
44 |
--------------------------------------------------------------------------------
/lam/models/encoders/dpt_util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/lam/models/encoders/dpt_util/__init__.py
--------------------------------------------------------------------------------
/lam/models/encoders/dpt_util/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5 | scratch = nn.Module()
6 |
7 | out_shape1 = out_shape
8 | out_shape2 = out_shape
9 | out_shape3 = out_shape
10 | if len(in_shape) >= 4:
11 | out_shape4 = out_shape
12 |
13 | if expand:
14 | out_shape1 = out_shape
15 | out_shape2 = out_shape * 2
16 | out_shape3 = out_shape * 4
17 | if len(in_shape) >= 4:
18 | out_shape4 = out_shape * 8
19 |
20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23 | if len(in_shape) >= 4:
24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25 |
26 | return scratch
27 |
28 |
29 | class ResidualConvUnit(nn.Module):
30 | """Residual convolution module.
31 | """
32 |
33 | def __init__(self, features, activation, bn):
34 | """Init.
35 |
36 | Args:
37 | features (int): number of features
38 | """
39 | super().__init__()
40 |
41 | self.bn = bn
42 |
43 | self.groups=1
44 |
45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46 |
47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48 |
49 | if self.bn == True:
50 | self.bn1 = nn.BatchNorm2d(features)
51 | self.bn2 = nn.BatchNorm2d(features)
52 |
53 | self.activation = activation
54 |
55 | self.skip_add = nn.quantized.FloatFunctional()
56 |
57 | def forward(self, x):
58 | """Forward pass.
59 |
60 | Args:
61 | x (tensor): input
62 |
63 | Returns:
64 | tensor: output
65 | """
66 |
67 | out = self.activation(x)
68 | out = self.conv1(out)
69 | if self.bn == True:
70 | out = self.bn1(out)
71 |
72 | out = self.activation(out)
73 | out = self.conv2(out)
74 | if self.bn == True:
75 | out = self.bn2(out)
76 |
77 | if self.groups > 1:
78 | out = self.conv_merge(out)
79 |
80 | return self.skip_add.add(out, x)
81 |
82 |
83 | class FeatureFusionBlock(nn.Module):
84 | """Feature fusion block.
85 | """
86 |
87 | def __init__(
88 | self,
89 | features,
90 | activation,
91 | deconv=False,
92 | bn=False,
93 | expand=False,
94 | align_corners=True,
95 | size=None,
96 | use_conv1=True
97 | ):
98 | """Init.
99 |
100 | Args:
101 | features (int): number of features
102 | """
103 | super(FeatureFusionBlock, self).__init__()
104 |
105 | self.deconv = deconv
106 | self.align_corners = align_corners
107 |
108 | self.groups=1
109 |
110 | self.expand = expand
111 | out_features = features
112 | if self.expand == True:
113 | out_features = features // 2
114 |
115 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
116 |
117 | if use_conv1:
118 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
119 | self.skip_add = nn.quantized.FloatFunctional()
120 |
121 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
122 |
123 |
124 | self.size=size
125 |
126 | def forward(self, *xs, size=None, scale_factor=2):
127 | """Forward pass.
128 |
129 | Returns:
130 | tensor: output
131 | """
132 | output = xs[0]
133 |
134 | if len(xs) == 2:
135 | res = self.resConfUnit1(xs[1])
136 | output = self.skip_add.add(output, res)
137 |
138 | output = self.resConfUnit2(output)
139 |
140 | if (size is None) and (self.size is None):
141 | modifier = {"scale_factor": scale_factor}
142 | elif size is None:
143 | modifier = {"size": self.size}
144 | else:
145 | modifier = {"size": size}
146 |
147 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
148 |
149 | output = self.out_conv(output)
150 |
151 | return output
152 |
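A small sketch of a DPT-style decoder step using the helpers above (channel sizes are illustrative): _make_scratch projects backbone features to a common width, and FeatureFusionBlock residual-fuses a coarse map with a skip connection while upsampling by 2x:

import torch
import torch.nn as nn
from lam.models.encoders.dpt_util.blocks import FeatureFusionBlock, _make_scratch

scratch = _make_scratch([256, 512, 1024], out_shape=256)
fuse = FeatureFusionBlock(features=256, activation=nn.ReLU(False), bn=False)
skip = scratch.layer2_rn(torch.randn(1, 512, 16, 16))   # project a backbone stage to 256 channels
coarse = torch.randn(1, 256, 16, 16)                    # e.g. the output of a deeper fusion stage
out = fuse(coarse, skip)                                # residual fusion, bilinear upsample, 1x1 conv
print(out.shape)                                        # torch.Size([1, 256, 32, 32])
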
--------------------------------------------------------------------------------
/lam/models/modulate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 |
20 | class ModLN(nn.Module):
21 | """
22 | Modulation with adaLN.
23 |
24 | References:
25 | DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
26 | """
27 | def __init__(self, inner_dim: int, mod_dim: int, eps: float):
28 | super().__init__()
29 | self.norm = nn.LayerNorm(inner_dim, eps=eps)
30 | self.mlp = nn.Sequential(
31 | nn.SiLU(),
32 | nn.Linear(mod_dim, inner_dim * 2),
33 | )
34 |
35 | @staticmethod
36 | def modulate(x, shift, scale):
37 | # x: [N, L, D]
38 | # shift, scale: [N, D]
39 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
40 |
41 | def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
42 | shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D]
43 | return self.modulate(self.norm(x), shift, scale) # [N, L, D]
44 |
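A minimal sketch of the adaLN modulation above: a conditioning vector predicts a per-channel shift and scale that are applied to the LayerNorm-ed tokens (dimensions illustrative):

import torch
from lam.models.modulate import ModLN

mod_ln = ModLN(inner_dim=512, mod_dim=256, eps=1e-6)
tokens = torch.randn(2, 1024, 512)    # [N, L, D]
cond = torch.randn(2, 256)            # [N, mod_dim]
out = mod_ln(tokens, cond)            # norm(x) * (1 + scale) + shift
print(out.shape)                      # torch.Size([2, 1024, 512])
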
--------------------------------------------------------------------------------
/lam/models/rendering/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Empty
16 |
--------------------------------------------------------------------------------
/lam/models/rendering/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
--------------------------------------------------------------------------------
/lam/models/rendering/utils/math_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2022 Petr Kellnhofer
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 |
25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
26 | """
27 | Left-multiplies MxM @ NxM. Returns NxM.
28 | """
29 | res = torch.matmul(vectors4, matrix.T)
30 | return res
31 |
32 |
33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
34 | """
35 | Normalize vector lengths.
36 | """
37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
38 |
39 | def torch_dot(x: torch.Tensor, y: torch.Tensor):
40 | """
41 | Dot product of two tensors.
42 | """
43 | return (x * y).sum(-1)
44 |
45 |
46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
47 | """
48 | Author: Petr Kellnhofer
49 | Intersects rays with the [-1, 1] NDC volume.
 50 |     Returns the min and max intersection distances along each ray.
 51 |     Rays that miss the box are marked with tmin = -1 and tmax = -2.
52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
53 | """
54 | o_shape = rays_o.shape
55 | rays_o = rays_o.detach().reshape(-1, 3)
56 | rays_d = rays_d.detach().reshape(-1, 3)
57 |
58 |
59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
63 |
64 | # Precompute inverse for stability.
65 | invdir = 1 / rays_d
66 | sign = (invdir < 0).long()
67 |
68 | # Intersect with YZ plane.
69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
71 |
72 | # Intersect with XZ plane.
73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
75 |
76 | # Resolve parallel rays.
77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
78 |
79 | # Use the shortest intersection.
80 | tmin = torch.max(tmin, tymin)
81 | tmax = torch.min(tmax, tymax)
82 |
83 | # Intersect with XY plane.
84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
86 |
87 | # Resolve parallel rays.
88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
89 |
90 | # Use the shortest intersection.
91 | tmin = torch.max(tmin, tzmin)
92 | tmax = torch.min(tmax, tzmax)
93 |
94 | # Mark invalid.
95 | tmin[torch.logical_not(is_valid)] = -1
96 | tmax[torch.logical_not(is_valid)] = -2
97 |
98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
99 |
100 |
101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
102 | """
103 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
104 |     Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
105 | """
106 | # create a tensor of 'num' steps from 0 to 1
107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
108 |
109 |     # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
110 |     # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
111 |     #   "cannot statically infer the expected size of a list in this context", hence the code below
112 | for i in range(start.ndim):
113 | steps = steps.unsqueeze(-1)
114 |
115 | # the output starts at 'start' and increments until 'stop' in each dimension
116 | out = start[None] + steps * (stop - start)[None]
117 |
118 | return out
119 |
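A small sanity check for get_ray_limits_box and linspace above: a ray starting at z = -2 and pointing along +z enters the side-2 box at t = 1 and leaves at t = 3, and linspace then produces evenly spaced sample depths per ray:

import torch
from lam.models.rendering.utils.math_utils import get_ray_limits_box, linspace

rays_o = torch.tensor([[0.0, 0.0, -2.0]])
rays_d = torch.tensor([[0.0, 0.0, 1.0]])
t_min, t_max = get_ray_limits_box(rays_o, rays_d, box_side_length=2.0)
print(t_min.item(), t_max.item())        # 1.0 3.0
depths = linspace(t_min, t_max, num=8)   # shape (8, 1, 1): 8 depths for the single ray
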
--------------------------------------------------------------------------------
/lam/models/rendering/utils/point_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, cv2
6 | import matplotlib.pyplot as plt
7 | import math
8 |
9 | def depths_to_points(view, depthmap):
10 | c2w = (view.world_view_transform.T).inverse()
11 | if hasattr(view, "image_width"):
12 | W, H = view.image_width, view.image_height
13 | else:
14 | W, H = view.width, view.height
15 | ndc2pix = torch.tensor([
16 | [W / 2, 0, 0, (W) / 2],
17 | [0, H / 2, 0, (H) / 2],
18 | [0, 0, 0, 1]]).float().cuda().T
19 | projection_matrix = c2w.T @ view.full_proj_transform
20 | intrins = (projection_matrix @ ndc2pix)[:3,:3].T
21 |
22 | grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy')
23 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3)
24 | rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T
25 | rays_o = c2w[:3,3]
26 | points = depthmap.reshape(-1, 1) * rays_d + rays_o
27 | return points
28 |
29 | def depth_to_normal(view, depth):
30 | """
31 | view: view camera
32 | depth: depthmap
33 | """
34 | points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3)
35 | output = torch.zeros_like(points)
36 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0)
37 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1)
38 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
39 | output[1:-1, 1:-1, :] = normal_map
40 | return output
--------------------------------------------------------------------------------
/lam/models/rendering/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
 64 |        deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
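
A degree-0 example for the helpers above: the DC coefficient alone reproduces a constant color, and RGB2SH / SH2RGB invert each other around the 0.5 offset:

import torch
from lam.models.rendering.utils.sh_utils import eval_sh, RGB2SH, SH2RGB

rgb = torch.tensor([0.2, 0.5, 0.8])
sh_dc = RGB2SH(rgb)                        # (rgb - 0.5) / C0
print(SH2RGB(sh_dc))                       # tensor([0.2000, 0.5000, 0.8000])

sh = sh_dc.reshape(3, 1)                   # [..., C, (deg + 1) ** 2] with deg = 0
dirs = torch.tensor([[0.0, 0.0, 1.0]])     # viewing direction (unused at degree 0)
print(eval_sh(0, sh, dirs) + 0.5)          # recovers the same color
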
--------------------------------------------------------------------------------
/lam/models/rendering/utils/typing.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains type annotations for the project, using
3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
5 |
6 | Two types of typing checking can be used:
7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
9 | """
10 |
11 | # Basic types
12 | from typing import (
13 | Any,
14 | Callable,
15 | Dict,
16 | Iterable,
17 | List,
18 | Literal,
19 | NamedTuple,
20 | NewType,
21 | Optional,
22 | Sized,
23 | Tuple,
24 | Type,
25 | TypeVar,
26 | Union,
27 | )
28 |
29 | # Tensor dtype
30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
32 |
33 | # Config type
34 | from omegaconf import DictConfig
35 |
36 | # PyTorch Tensor type
37 | from torch import Tensor
38 |
39 | # Runtime type checking decorator
40 | from typeguard import typechecked as typechecker
41 |
--------------------------------------------------------------------------------
/lam/models/rendering/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F  # used by get_activation ("shifted_softplus" and the getattr fallback)
  4 | from torch.autograd import Function
  5 | from torch.cuda.amp import custom_bwd, custom_fwd
6 | from lam.models.rendering.utils.typing import *
7 |
8 | def get_activation(name):
9 | if name is None:
10 | return lambda x: x
11 | name = name.lower()
12 | if name == "none":
13 | return lambda x: x
14 | elif name == "lin2srgb":
15 | return lambda x: torch.where(
16 | x > 0.0031308,
17 | torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
18 | 12.92 * x,
19 | ).clamp(0.0, 1.0)
20 | elif name == "exp":
21 | return lambda x: torch.exp(x)
22 | elif name == "shifted_exp":
23 | return lambda x: torch.exp(x - 1.0)
24 | elif name == "trunc_exp":
25 | return trunc_exp
26 | elif name == "shifted_trunc_exp":
27 | return lambda x: trunc_exp(x - 1.0)
28 | elif name == "sigmoid":
29 | return lambda x: torch.sigmoid(x)
30 | elif name == "tanh":
31 | return lambda x: torch.tanh(x)
32 | elif name == "shifted_softplus":
33 | return lambda x: F.softplus(x - 1.0)
34 | elif name == "scale_-11_01":
35 | return lambda x: x * 0.5 + 0.5
36 | else:
37 | try:
38 | return getattr(F, name)
39 | except AttributeError:
40 | raise ValueError(f"Unknown activation function: {name}")
41 |
42 | class MLP(nn.Module):
43 | def __init__(
44 | self,
45 | dim_in: int,
46 | dim_out: int,
47 | n_neurons: int,
48 | n_hidden_layers: int,
49 | activation: str = "relu",
50 | output_activation: Optional[str] = None,
51 | bias: bool = True,
52 | ):
53 | super().__init__()
54 | layers = [
55 | self.make_linear(
56 | dim_in, n_neurons, is_first=True, is_last=False, bias=bias
57 | ),
58 | self.make_activation(activation),
59 | ]
60 | for i in range(n_hidden_layers - 1):
61 | layers += [
62 | self.make_linear(
63 | n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
64 | ),
65 | self.make_activation(activation),
66 | ]
67 | layers += [
68 | self.make_linear(
69 | n_neurons, dim_out, is_first=False, is_last=True, bias=bias
70 | )
71 | ]
72 | self.layers = nn.Sequential(*layers)
73 | self.output_activation = get_activation(output_activation)
74 |
75 | def forward(self, x):
76 | x = self.layers(x)
77 | x = self.output_activation(x)
78 | return x
79 |
80 | def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
81 | layer = nn.Linear(dim_in, dim_out, bias=bias)
82 | return layer
83 |
84 | def make_activation(self, activation):
85 | if activation == "relu":
86 | return nn.ReLU(inplace=True)
87 | elif activation == "silu":
88 | return nn.SiLU(inplace=True)
89 | else:
90 | raise NotImplementedError
91 |
92 |
93 | class _TruncExp(Function): # pylint: disable=abstract-method
94 | # Implementation from torch-ngp:
95 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
96 | @staticmethod
97 | @custom_fwd(cast_inputs=torch.float32)
98 | def forward(ctx, x): # pylint: disable=arguments-differ
99 | ctx.save_for_backward(x)
100 | return torch.exp(x)
101 |
102 | @staticmethod
103 | @custom_bwd
104 | def backward(ctx, g): # pylint: disable=arguments-differ
105 | x = ctx.saved_tensors[0]
106 | return g * torch.exp(torch.clamp(x, max=15))
107 |
108 |
109 | trunc_exp = _TruncExp.apply
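
A short sketch of the helpers above: a small ReLU MLP with a sigmoid output head, plus trunc_exp, which matches exp in the forward pass but clamps the gradient to avoid overflow (sizes are illustrative):

import torch
from lam.models.rendering.utils.utils import MLP, get_activation

mlp = MLP(dim_in=32, dim_out=3, n_neurons=64, n_hidden_layers=3,
          activation="relu", output_activation="sigmoid")
x = torch.randn(4096, 32)
print(mlp(x).shape)                                        # torch.Size([4096, 3])
print(get_activation("trunc_exp")(torch.tensor([1.0])))    # tensor([2.7183])
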
--------------------------------------------------------------------------------
/lam/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from lam.utils.registry import Registry
17 |
18 | REGISTRY_RUNNERS = Registry()
19 |
20 | from .infer import *
21 |
--------------------------------------------------------------------------------
/lam/runners/abstract.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from abc import ABC, abstractmethod
17 |
18 |
19 | class Runner(ABC):
20 | """Abstract runner class"""
21 |
22 | def __init__(self):
23 | pass
24 |
25 | @abstractmethod
26 | def run(self):
27 | pass
28 |
--------------------------------------------------------------------------------
/lam/runners/infer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .lam import LAMInferrer
16 |
--------------------------------------------------------------------------------
/lam/runners/infer/base_inferrer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | from abc import abstractmethod
18 | from accelerate import Accelerator
19 | from accelerate.logging import get_logger
20 |
21 | from lam.runners.abstract import Runner
22 |
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 | class Inferrer(Runner):
28 |
29 | EXP_TYPE: str = None
30 |
31 | def __init__(self):
32 | super().__init__()
33 |
34 | torch._dynamo.config.disable = True
35 | self.accelerator = Accelerator()
36 |
37 | self.model : torch.nn.Module = None
38 |
39 | def __enter__(self):
40 | return self
41 |
42 | def __exit__(self, exc_type, exc_val, exc_tb):
43 | pass
44 |
45 | @property
46 | def device(self):
47 | return self.accelerator.device
48 |
49 | @abstractmethod
50 | def _build_model(self, cfg):
51 | pass
52 |
53 | @abstractmethod
54 | def infer_single(self, *args, **kwargs):
55 | pass
56 |
57 | @abstractmethod
58 | def infer(self):
59 | pass
60 |
61 | def run(self):
62 | self.infer()
63 |
--------------------------------------------------------------------------------
/lam/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Empty
16 |
--------------------------------------------------------------------------------
/lam/utils/compile.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from accelerate.logging import get_logger
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | def configure_dynamo(config: dict):
23 | try:
24 | import torch._dynamo
25 | logger.debug(f'Configuring torch._dynamo.config with {config}')
26 | for k, v in config.items():
27 | if v is None:
28 | logger.debug(f'Skipping torch._dynamo.config.{k} with None')
29 | continue
30 | if hasattr(torch._dynamo.config, k):
31 | logger.warning(f'Overriding torch._dynamo.config.{k} from {getattr(torch._dynamo.config, k)} to {v}')
32 | setattr(torch._dynamo.config, k, v)
33 | except ImportError:
34 | logger.debug('torch._dynamo not found, skipping')
35 | pass
36 |
--------------------------------------------------------------------------------
/lam/utils/ffmpeg_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import torch
4 | import numpy as np
5 | import imageio
6 | import cv2
7 | import imageio.v3 as iio
8 |
9 | VIDEO_TYPE_LIST = {'.avi','.mp4','.gif','.AVI','.MP4','.GIF'}
10 |
11 | def encodeffmpeg(inputs, frame_rate, output, format="png"):
12 | """output: need video_name"""
13 | assert (
14 | os.path.splitext(output)[-1] in VIDEO_TYPE_LIST
15 | ), "output is the format of video, e.g., mp4"
16 | assert os.path.isdir(inputs), "input dir is NOT file format"
17 |
18 | inputs = inputs[:-1] if inputs[-1] == "/" else inputs
19 |
20 | output = os.path.abspath(output)
21 |
22 | cmd = (
23 | f"ffmpeg -r {frame_rate} -pattern_type glob -i '{inputs}/*.{format}' "
24 | + f'-vcodec libx264 -crf 10 -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" '
25 | + f"-pix_fmt yuv420p {output} > /dev/null 2>&1"
26 | )
27 |
28 | print(cmd)
29 |
30 | output_dir = os.path.dirname(output)
31 | if os.path.exists(output):
32 | os.remove(output)
33 | os.makedirs(output_dir, exist_ok=True)
34 |
35 | print("encoding imgs to video.....")
36 | os.system(cmd)
37 | print("video done!")
38 |
39 | def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False, bitrate="2M"):
40 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
41 | frames = []
42 | for i in range(images.shape[0]):
43 | if isinstance(images, torch.Tensor):
44 | frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
45 | assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
46 | f"Frame shape mismatch: {frame.shape} vs {images.shape}"
47 | assert frame.min() >= 0 and frame.max() <= 255, \
48 | f"Frame value out of range: {frame.min()} ~ {frame.max()}"
49 | else:
50 | frame = images[i]
51 | width, height = frame.shape[1], frame.shape[0]
52 | # reshape to limit the export time
53 | # if width > 1200 or height > 1200 or images.shape[0] > 200:
54 | # frames.append(cv2.resize(frame, (width // 2, height // 2)))
55 | # else:
56 | frames.append(frame)
57 | # limit the frames directly @NOTE huggingface only!
58 | frames = frames[:200]
59 |
60 | frames = np.stack(frames)
61 |
62 | print("start saving {} using imageio.v3 .".format(output_path))
63 | iio.imwrite(output_path,frames,fps=fps,codec="libx264",pixelformat="yuv420p",bitrate=bitrate, macro_block_size=32)
64 | print("saved {} using imageio.v3 .".format(output_path))
--------------------------------------------------------------------------------
/lam/utils/gen_id_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import glob
3 | import sys
4 | import os
5 |
6 | data_root = sys.argv[1]
7 | save_path = sys.argv[2]
8 |
9 | all_hid_list = []
10 | for hid in os.listdir(data_root):
11 | if hid.startswith("p"):
12 | hid = os.path.join(data_root, hid)
13 | all_hid_list.append(hid.replace(data_root + "/", ""))
14 |
15 | print(f"len:{len(all_hid_list)}")
16 | print(all_hid_list[:3])
17 | with open(save_path, 'w') as fp:
18 | json.dump(all_hid_list, fp, indent=4)
--------------------------------------------------------------------------------
/lam/utils/gen_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import glob
3 | import sys
4 | import os
5 |
6 | data_root = sys.argv[1]
7 | save_path = sys.argv[2]
8 |
9 | all_img_list = []
10 | for hid in os.listdir(data_root):
11 | all_view_imgs_dir = os.path.join(data_root, hid, "kinect_color")
12 | if not os.path.exists(all_view_imgs_dir):
13 | continue
14 |
15 | for view_id in os.listdir(all_view_imgs_dir):
16 | imgs_dir = os.path.join(all_view_imgs_dir, view_id)
17 | for img_path in glob.glob(os.path.join(imgs_dir, "*.png")):
18 | all_img_list.append(img_path.replace(data_root + "/", ""))
19 |
20 | print(f"len:{len(all_img_list)}")
21 | print(all_img_list[:3])
22 | with open(save_path, 'w') as fp:
23 | json.dump(all_img_list, fp, indent=4)
--------------------------------------------------------------------------------
/lam/utils/hf_hub.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch.nn as nn
17 | from huggingface_hub import PyTorchModelHubMixin
18 |
19 |
20 | def wrap_model_hub(model_cls: nn.Module):
21 | class HfModel(model_cls, PyTorchModelHubMixin):
22 | def __init__(self, config: dict):
23 | super().__init__(**config)
24 | self.config = config
25 | return HfModel
26 |
--------------------------------------------------------------------------------
/lam/utils/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import logging
18 | from tqdm.auto import tqdm
19 |
20 |
21 | class TqdmStreamHandler(logging.StreamHandler):
22 | def emit(self, record):
23 | tqdm.write(self.format(record))
24 |
25 |
26 | def configure_logger(stream_level, log_level, file_path = None):
27 | _stream_level = stream_level.upper()
28 | _log_level = log_level.upper()
29 | _project_level = _log_level
30 |
31 | _formatter = logging.Formatter("[%(asctime)s] %(name)s: [%(levelname)s] %(message)s")
32 |
33 | _stream_handler = TqdmStreamHandler()
34 | _stream_handler.setLevel(_stream_level)
35 | _stream_handler.setFormatter(_formatter)
36 |
37 | if file_path is not None:
38 | os.makedirs(os.path.dirname(file_path), exist_ok=True)
39 | _file_handler = logging.FileHandler(file_path)
40 | _file_handler.setLevel(_log_level)
41 | _file_handler.setFormatter(_formatter)
42 |
43 | _project_logger = logging.getLogger(__name__.split('.')[0])
44 | _project_logger.setLevel(_project_level)
45 | _project_logger.addHandler(_stream_handler)
46 | if file_path is not None:
47 | _project_logger.addHandler(_file_handler)
48 |
--------------------------------------------------------------------------------
/lam/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import numpy as np
17 | import rembg
18 | import cv2
19 |
20 |
21 | class Preprocessor:
22 |
23 | """
24 | Preprocessing under cv2 conventions.
25 | """
26 |
27 | def __init__(self):
28 | self.rembg_session = rembg.new_session(
29 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
30 | )
31 |
32 | def preprocess(self, image_path: str, save_path: str, rmbg: bool = True, recenter: bool = True, size: int = 512, border_ratio: float = 0.2):
33 | image = self.step_load_to_size(image_path=image_path, size=size*2)
34 | if rmbg:
35 | image = self.step_rembg(image_in=image)
36 | else:
37 | image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
38 | if recenter:
39 | image = self.step_recenter(image_in=image, border_ratio=border_ratio, square_size=size)
40 | else:
41 | image = cv2.resize(
42 | src=image,
43 | dsize=(size, size),
44 | interpolation=cv2.INTER_AREA,
45 | )
46 | return cv2.imwrite(save_path, image)
47 |
48 | def step_rembg(self, image_in: np.ndarray) -> np.ndarray:
49 | image_out = rembg.remove(
50 | data=image_in,
51 | session=self.rembg_session,
52 | )
53 | return image_out
54 |
55 | def step_recenter(self, image_in: np.ndarray, border_ratio: float, square_size: int) -> np.ndarray:
56 | assert image_in.shape[-1] == 4, "Image to recenter must be RGBA"
57 | mask = image_in[..., -1] > 0
58 | ijs = np.nonzero(mask)
59 | # find bbox
60 | i_min, i_max = ijs[0].min(), ijs[0].max()
61 | j_min, j_max = ijs[1].min(), ijs[1].max()
62 | bbox_height, bbox_width = i_max - i_min, j_max - j_min
63 | # recenter and resize
64 | desired_size = int(square_size * (1 - border_ratio))
65 | scale = desired_size / max(bbox_height, bbox_width)
66 | desired_height, desired_width = int(bbox_height * scale), int(bbox_width * scale)
67 | desired_i_min, desired_j_min = (square_size - desired_height) // 2, (square_size - desired_width) // 2
68 | desired_i_max, desired_j_max = desired_i_min + desired_height, desired_j_min + desired_width
69 | # create new image
70 | image_out = np.zeros((square_size, square_size, 4), dtype=np.uint8)
71 | image_out[desired_i_min:desired_i_max, desired_j_min:desired_j_max] = cv2.resize(
72 | src=image_in[i_min:i_max, j_min:j_max],
73 | dsize=(desired_width, desired_height),
74 | interpolation=cv2.INTER_AREA,
75 | )
76 | return image_out
77 |
78 | def step_load_to_size(self, image_path: str, size: int) -> np.ndarray:
79 | image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
80 | height, width = image.shape[:2]
81 | scale = size / max(height, width)
82 | height, width = int(height * scale), int(width * scale)
83 | image_out = cv2.resize(
84 | src=image,
85 | dsize=(width, height),
86 | interpolation=cv2.INTER_AREA,
87 | )
88 | return image_out
89 |
--------------------------------------------------------------------------------
/lam/utils/profiler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from torch.profiler import profile
17 |
18 |
19 | class DummyProfiler(profile):
20 | def __init__(self):
21 | pass
22 |
23 | def __enter__(self):
24 | return self
25 |
26 | def __exit__(self, *args):
27 | pass
28 |
29 | def step(self):
30 | pass
31 |
--------------------------------------------------------------------------------
/lam/utils/proxy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 |
18 | NO_PROXY = "lam_NO_DATA_PROXY" in os.environ
19 |
20 | def no_proxy(func):
21 | """Decorator to disable proxy but then restore after the function call."""
22 | def wrapper(*args, **kwargs):
 23 |         # http_proxy, https_proxy, HTTP_PROXY, HTTPS_PROXY, all_proxy (default to '' so unset values restore cleanly)
 24 |         http_proxy = os.environ.get('http_proxy', '')
 25 |         https_proxy = os.environ.get('https_proxy', '')
 26 |         HTTP_PROXY = os.environ.get('HTTP_PROXY', '')
 27 |         HTTPS_PROXY = os.environ.get('HTTPS_PROXY', '')
 28 |         all_proxy = os.environ.get('all_proxy', '')
29 | os.environ['http_proxy'] = ''
30 | os.environ['https_proxy'] = ''
31 | os.environ['HTTP_PROXY'] = ''
32 | os.environ['HTTPS_PROXY'] = ''
33 | os.environ['all_proxy'] = ''
34 | try:
35 | return func(*args, **kwargs)
36 | finally:
37 | os.environ['http_proxy'] = http_proxy
38 | os.environ['https_proxy'] = https_proxy
39 | os.environ['HTTP_PROXY'] = HTTP_PROXY
40 | os.environ['HTTPS_PROXY'] = HTTPS_PROXY
41 | os.environ['all_proxy'] = all_proxy
42 | if NO_PROXY:
43 | return wrapper
44 | else:
45 | return func
46 |
--------------------------------------------------------------------------------
/lam/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | class Registry:
17 | """Registry class"""
18 |
19 | def __init__(self):
20 | self._registry = {}
21 |
22 | def register(self, name):
23 | """Register a module"""
24 | def decorator(cls):
25 | assert name not in self._registry, 'Module {} already registered'.format(name)
26 | self._registry[name] = cls
27 | return cls
28 | return decorator
29 |
30 | def __getitem__(self, name):
31 | """Get a module"""
32 | return self._registry[name]
33 |
34 | def __contains__(self, name):
35 | return name in self._registry
36 |
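A usage sketch mirroring how REGISTRY_RUNNERS is used in lam/runners/__init__.py: classes register themselves by name and are later looked up from a config string (the names below are made up for illustration):

from lam.utils.registry import Registry

RUNNERS = Registry()

@RUNNERS.register("dummy")
class DummyRunner:
    def run(self):
        return "ok"

assert "dummy" in RUNNERS
print(RUNNERS["dummy"]().run())   # ok
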
--------------------------------------------------------------------------------
/lam/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import math
17 | from torch.optim.lr_scheduler import LRScheduler
18 | from accelerate.logging import get_logger
19 |
20 |
21 | logger = get_logger(__name__)
22 |
23 |
24 | class CosineWarmupScheduler(LRScheduler):
25 | def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1):
26 | self.warmup_iters = warmup_iters
27 | self.max_iters = max_iters
28 | self.initial_lr = initial_lr
29 | super().__init__(optimizer, last_iter)
30 |
31 | def get_lr(self):
32 | logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}")
33 | if self._step_count <= self.warmup_iters:
34 | return [
35 | self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters
36 | for base_lr in self.base_lrs]
37 | else:
38 | cos_iter = self._step_count - self.warmup_iters
39 | cos_max_iter = self.max_iters - self.warmup_iters
40 | cos_theta = cos_iter / cos_max_iter * math.pi
41 | cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs]
42 | return cos_lr
43 |
--------------------------------------------------------------------------------
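A sketch of driving `CosineWarmupScheduler` above with a dummy optimizer (the hyper-parameters are illustrative): the learning rate ramps linearly from `initial_lr` to the optimizer's base learning rate over `warmup_iters` steps, then follows a half-cosine decay toward zero at `max_iters`.

```python
import torch

from lam.utils.scheduler import CosineWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = CosineWarmupScheduler(optimizer, warmup_iters=100, max_iters=1000)

for _ in range(1000):
    optimizer.step()
    scheduler.step()  # linear warmup for the first 100 steps, cosine decay afterwards
```
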
/lam/utils/video.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import numpy as np
18 | import torch
19 |
20 | def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False):
21 | import imageio
22 | # images: torch.tensor (T, C, H, W), 0-1 or numpy: (T, H, W, 3) 0-255
23 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
24 | frames = []
25 | for i in range(images.shape[0]):
26 | if isinstance(images, torch.Tensor):
27 | frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
28 | assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
29 | f"Frame shape mismatch: {frame.shape} vs {images.shape}"
30 | assert frame.min() >= 0 and frame.max() <= 255, \
31 | f"Frame value out of range: {frame.min()} ~ {frame.max()}"
32 | else:
33 | frame = images[i]
34 | frames.append(frame)
35 | frames = np.stack(frames)
36 | if gradio_codec:
37 | imageio.mimwrite(output_path, frames, fps=fps, quality=10)
38 | else:
39 | # imageio.mimwrite(output_path, frames, fps=fps, codec='mpeg4', quality=10)
40 | imageio.mimwrite(output_path, frames, fps=fps, quality=10)
41 |
42 | if verbose:
43 | print(f"Using gradio codec option {gradio_codec}")
44 | print(f"Saved video to {output_path}")
45 |
46 |
47 | def save_images2video(img_lst, v_pth, fps):
48 | import moviepy.editor as mpy
49 | # Convert the list of NumPy arrays to a list of ImageClip objects
50 | clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] # 0.1 seconds per frame
51 |
52 | # Concatenate the ImageClips into a single VideoClip
53 | video = mpy.concatenate_videoclips(clips, method="compose")
54 |
55 | # Write the VideoClip to a file
56 |     video.write_videofile(v_pth, fps=fps)  # write at the caller-provided fps
57 | print("save video to:", v_pth)
58 |
59 |
60 | if __name__ == "__main__":
61 | from glob import glob
62 | clip_name = "clip1"
63 | ptn = f"./assets/sample_motion/export/{clip_name}/images/*.png"
64 | images_pths = glob(ptn)
65 | import cv2
66 | import numpy as np
67 | images = [cv2.imread(pth) for pth in images_pths]
68 |     save_images2video(images, f"./assets/sample_motion/export/{clip_name}/video.mp4", 25)
--------------------------------------------------------------------------------
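A quick sketch of `images_to_video` above with a random tensor input (the output path and fps are arbitrary); tensor input is expected as `(T, C, H, W)` with values in `[0, 1]`.

```python
import torch

from lam.utils.video import images_to_video

frames = torch.rand(24, 3, 256, 256)  # (T, C, H, W), float in [0, 1]
images_to_video(frames, "./outputs/demo.mp4", fps=24, gradio_codec=False, verbose=True)
```
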
/requirements.txt:
--------------------------------------------------------------------------------
1 | pymcubes==0.1.6
2 | omegaconf==2.3.0
3 | moviepy==1.0.3
4 | pandas==2.2.3
5 | transformers==4.41.2
6 | numpy==1.23.0
7 | trimesh==4.4.9
8 | opencv_python_headless==4.11.0.86
9 | tensorflow==2.12.0
10 | face-detection-tflite==0.6.0
11 | scikit-image==0.20.0
12 | jaxlib==0.4.30
13 | gradio==3.44.3
14 | huggingface_hub==0.23.2
15 | Cython
16 | accelerate
17 | tyro
18 | einops
19 | diffusers
20 | plyfile
21 | jaxtyping
22 | typeguard
23 | chumpy
24 | loguru
25 | ninja
26 | git+https://github.com/facebookresearch/pytorch3d.git
27 | git+https://github.com/ashawkey/diff-gaussian-rasterization/
28 | nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling
29 | git+https://github.com/camenduru/simple-knn/
--------------------------------------------------------------------------------
/scripts/convert_hf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, Zexin He
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import pdb
17 | import sys
18 | import traceback
19 | from tempfile import TemporaryDirectory
20 |
21 | import safetensors
22 | import torch.nn as nn
23 | from accelerate import Accelerator
24 | from megfile import (
25 | smart_copy,
26 | smart_exists,
27 | smart_listdir,
28 | smart_makedirs,
29 | smart_path_join,
30 | )
31 | from omegaconf import OmegaConf
32 |
33 | sys.path.append(".")
34 |
35 | from lam.models import model_dict
36 | from lam.utils.hf_hub import wrap_model_hub
37 | from lam.utils.proxy import no_proxy
38 |
39 |
40 | @no_proxy
41 | def auto_load_model(cfg, model: nn.Module) -> int:
42 |
43 | ckpt_root = smart_path_join(
44 | cfg.saver.checkpoint_root,
45 | cfg.experiment.parent,
46 | cfg.experiment.child,
47 | )
48 | if not smart_exists(ckpt_root):
49 | raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")
50 | ckpt_dirs = smart_listdir(ckpt_root)
51 | if len(ckpt_dirs) == 0:
52 | raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")
53 | ckpt_dirs.sort()
54 |
55 | load_step = (
56 | f"{cfg.convert.global_step}"
57 | if cfg.convert.global_step is not None
58 | else ckpt_dirs[-1]
59 | )
60 | load_model_path = smart_path_join(ckpt_root, load_step, "model.safetensors")
61 |
62 | if load_model_path.startswith("s3"):
63 | tmpdir = TemporaryDirectory()
64 | tmp_model_path = smart_path_join(tmpdir.name, f"tmp.safetensors")
65 | smart_copy(load_model_path, tmp_model_path)
66 | load_model_path = tmp_model_path
67 |
68 | print(f"Loading from {load_model_path}")
69 | try:
70 | safetensors.torch.load_model(model, load_model_path, strict=True)
71 | except:
72 | traceback.print_exc()
73 | safetensors.torch.load_model(model, load_model_path, strict=False)
74 |
75 | return int(load_step)
76 |
77 |
78 | if __name__ == "__main__":
79 |
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument("--config", type=str, default="./assets/config.yaml")
82 | args, unknown = parser.parse_known_args()
83 | cfg = OmegaConf.load(args.config)
84 | cli_cfg = OmegaConf.from_cli(unknown)
85 | cfg = OmegaConf.merge(cfg, cli_cfg)
86 |
87 | """
88 | [cfg.convert]
89 | global_step: int
90 | save_dir: str
91 | """
92 |
93 | accelerator = Accelerator()
94 |
95 | # hf_model_cls = wrap_model_hub(model_dict[cfg.experiment.type])
96 | hf_model_cls = wrap_model_hub(model_dict["human_lrm_sapdino_bh_sd3_5"])
97 |
98 | hf_model = hf_model_cls(OmegaConf.to_container(cfg.model))
99 | loaded_step = auto_load_model(cfg, hf_model)
100 | dump_path = smart_path_join(
101 | f"./exps/releases",
102 | cfg.experiment.parent,
103 | cfg.experiment.child,
104 | f"step_{loaded_step:06d}",
105 | )
106 | print(f"Saving locally to {dump_path}")
107 | smart_makedirs(dump_path, exist_ok=True)
108 | hf_model.save_pretrained(
109 | save_directory=dump_path,
110 | config=hf_model.config,
111 | )
112 |
--------------------------------------------------------------------------------
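An illustrative invocation of the conversion script above. The config path is the script's default; the OmegaConf-style overrides (`convert.global_step`, `experiment.parent`, `experiment.child`) are keys the script reads, but the values below are placeholders for your own experiment layout.

```bash
python scripts/convert_hf.py --config ./assets/config.yaml \
    convert.global_step=45500 \
    experiment.parent=lam experiment.child=lam-20k
```
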
/scripts/exp/run_4gpu.sh:
--------------------------------------------------------------------------------
1 | ACC_CONFIG="./configs/accelerate-train-4gpu.yaml"
2 | TRAIN_CONFIG="./configs/train-sample-human.yaml"
3 |
4 | if [ -n "$1" ]; then
5 | TRAIN_CONFIG=$1
6 | else
7 | TRAIN_CONFIG="./configs/train-sample-human.yaml"
8 | fi
9 |
10 | if [ -n "$2" ]; then
11 | MAIN_PORT=$2
12 | else
13 | MAIN_PORT=12345
14 | fi
15 |
16 | accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m lam.launch train.human_lrm --config $TRAIN_CONFIG
--------------------------------------------------------------------------------
/scripts/exp/run_8gpu.sh:
--------------------------------------------------------------------------------
1 | ACC_CONFIG="./configs/accelerate-train.yaml"
2 | TRAIN_CONFIG="./configs/train-sample-human.yaml"
3 |
4 | if [ -n "$1" ]; then
5 | TRAIN_CONFIG=$1
6 | else
7 | TRAIN_CONFIG="./configs/train-sample-human.yaml"
8 | fi
9 |
10 | if [ -n "$2" ]; then
11 | MAIN_PORT=$2
12 | else
13 | MAIN_PORT=12345
14 | fi
15 |
16 | accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m lam.launch train.human_lrm --config $TRAIN_CONFIG
--------------------------------------------------------------------------------
/scripts/exp/run_debug.sh:
--------------------------------------------------------------------------------
1 | ACC_CONFIG="./configs/accelerate-train-1gpu.yaml"
2 |
3 | if [ -n "$1" ]; then
4 | TRAIN_CONFIG=$1
5 | else
6 | TRAIN_CONFIG="./configs/train-sample-human.yaml"
7 | fi
8 |
9 | if [ -n "$2" ]; then
10 | MAIN_PORT=$2
11 | else
12 | MAIN_PORT=12345
13 | fi
14 |
15 | accelerate launch --config_file $ACC_CONFIG --main_process_port=$MAIN_PORT -m lam.launch train.human_lrm --config $TRAIN_CONFIG
--------------------------------------------------------------------------------
/scripts/inference.sh:
--------------------------------------------------------------------------------
1 | # step1. set TRAIN_CONFIG path to config file
2 |
3 | TRAIN_CONFIG="configs/inference/lam-20k-8gpu.yaml"
4 | MODEL_NAME="model_zoo/lam_models/releases/lam/lam-20k/step_045500/"
5 | IMAGE_INPUT="assets/sample_input/status.png"
6 | MOTION_SEQS_DIR="assets/sample_motion/export/Look_In_My_Eyes/"
7 |
8 |
9 | TRAIN_CONFIG=${1:-$TRAIN_CONFIG}
10 | MODEL_NAME=${2:-$MODEL_NAME}
11 | IMAGE_INPUT=${3:-$IMAGE_INPUT}
12 | MOTION_SEQS_DIR=${4:-$MOTION_SEQS_DIR}
13 |
14 | echo "TRAIN_CONFIG: $TRAIN_CONFIG"
15 | echo "IMAGE_INPUT: $IMAGE_INPUT"
16 | echo "MODEL_NAME: $MODEL_NAME"
17 | echo "MOTION_SEQS_DIR: $MOTION_SEQS_DIR"
18 |
19 |
20 | MOTION_IMG_DIR=null
21 | SAVE_PLY=false
22 | SAVE_IMG=false
23 | VIS_MOTION=false
24 | MOTION_IMG_NEED_MASK=true
25 | RENDER_FPS=30
26 | MOTION_VIDEO_READ_FPS=30
27 | EXPORT_VIDEO=true
28 | CROSS_ID=false
29 | TEST_SAMPLE=false
30 | GAGA_TRACK_TYPE=""
31 | EXPORT_MESH=false
32 | device=0
33 | nodes=0
34 |
35 | export PYTHONPATH=$PYTHONPATH:$PWD
36 |
37 |
38 | CUDA_VISIBLE_DEVICES=$device python -m lam.launch infer.lam --config $TRAIN_CONFIG \
39 | model_name=$MODEL_NAME image_input=$IMAGE_INPUT \
40 | export_video=$EXPORT_VIDEO export_mesh=$EXPORT_MESH \
41 | motion_seqs_dir=$MOTION_SEQS_DIR motion_img_dir=$MOTION_IMG_DIR \
42 | vis_motion=$VIS_MOTION motion_img_need_mask=$MOTION_IMG_NEED_MASK \
43 | render_fps=$RENDER_FPS motion_video_read_fps=$MOTION_VIDEO_READ_FPS \
44 | save_ply=$SAVE_PLY save_img=$SAVE_IMG \
45 | gaga_track_type=$GAGA_TRACK_TYPE cross_id=$CROSS_ID \
46 | test_sample=$TEST_SAMPLE rank=$device nodes=$nodes
47 |
--------------------------------------------------------------------------------
/scripts/install/WINDOWS_INSTALL.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## Windows Installation Guide
5 |
6 | ### Base software
7 |
8 | - Python 3.10
9 | - NVIDIA CUDA Toolkit 11.8 (other versions can also work)
10 | - Visual Studio 2019: Visual Studio 2022 causes compilation errors on the CUDA operators. Download it from [techspot](https://www.techspot.com/downloads/7241-visual-studio-2019.html)
11 |
12 |
13 |
14 | ### Install Dependencies
15 |
16 | Note: we use the "x64 Native Tools" command prompt from Visual Studio for compilation, rather than PowerShell or cmd, because it provides the MSVC environment needed to build the Python packages.
17 |
18 | Open "x64 Native Tools" terminal and install dependencies:
19 |
20 | - Prepare the environment:
21 | We recommend using venv (or conda) to create a Python environment:
22 | ```bash
23 | python -m venv lam_env
24 | lam_env\Scripts\activate
25 | git clone https://github.com/aigc3d/LAM.git
26 | cd LAM && git checkout feat/windows
27 | ```
28 |
29 | - Install torch 2.3.0 and xformers
30 | ```bash
31 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
32 | pip install -U xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118
33 | ```
34 |
35 | - Install the Python packages that do not need compilation:
36 | ```bash
37 | # pip install -r requirements.txt, skipping the last 4 lines (the packages built from source below):
38 | head -n $(( $(wc -l < requirements.txt) - 4 )) requirements.txt | pip install -r /dev/stdin
39 | ```
40 |
41 | - Install packages which need compilation
42 | ```bash
43 | # Install Pytorch3d, which follows:
44 | # https://blog.csdn.net/m0_70229101/article/details/127196699
45 | # https://blog.csdn.net/qq_61247019/article/details/139927752
46 | set DISTUTILS_USE_SDK=1
47 | set PYTORCH3D_NO_NINJA=1
48 | git clone https://github.com/facebookresearch/pytorch3d.git
49 | cd pytorch3d
50 | # modify setup.py
51 | # add "-DWIN32_LEAN_AND_MEAN" in nvcc_args
52 | python setup.py install
53 |
54 | # Install other packages
55 | pip install git+https://github.com/ashawkey/diff-gaussian-rasterization/
56 | pip install nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling
57 | pip install git+https://github.com/camenduru/simple-knn/
58 |
59 | cd external/landmark_detection/FaceBoxesV2/utils/
60 | python3 build.py build_ext --inplace
61 | cd ../../../../
62 | ```
63 |
64 |
65 | ### Run
66 |
67 | ```bash
68 | python app_lam.py
69 | ```
--------------------------------------------------------------------------------
/scripts/install/install_cu118.sh:
--------------------------------------------------------------------------------
1 | # install torch 2.3.0
2 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
3 | pip install -U xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118
4 |
5 | # install dependencies
6 | pip install -r requirements.txt
7 |
8 | # === If you fail to install some modules due to network issues, you can also try the following: ===
9 | # git clone https://github.com/facebookresearch/pytorch3d.git
10 | # pip install ./pytorch3d
11 | # git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization
12 | # pip install ./diff-gaussian-rasterization
13 | # git clone https://github.com/camenduru/simple-knn.git
14 | # pip install ./simple-knn
15 |
16 | cd external/landmark_detection/FaceBoxesV2/utils/
17 | sh make.sh
18 | cd ../../../../
19 |
--------------------------------------------------------------------------------
/scripts/install/install_cu121.sh:
--------------------------------------------------------------------------------
1 | # install torch 2.3.0
2 | pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
3 | pip install -U xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu121
4 |
5 | # install dependencies
6 | pip install -r requirements.txt
7 |
8 | # === If you fail to install some modules due to network issues, you can also try the following: ===
9 | # git clone https://github.com/facebookresearch/pytorch3d.git
10 | # pip install ./pytorch3d
11 | # git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization
12 | # pip install ./diff-gaussian-rasterization
13 | # git clone https://github.com/camenduru/simple-knn.git
14 | # pip install ./simple-knn
15 |
16 | cd external/landmark_detection/FaceBoxesV2/utils/
17 | sh make.sh
18 | cd ../../../../
19 |
--------------------------------------------------------------------------------
/scripts/upload_hub.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import sys
17 |
18 | sys.path.append(".")
19 |
20 | import argparse
21 |
22 | from accelerate import Accelerator
23 |
24 | from lam.models import model_dict
25 | from lam.utils.hf_hub import wrap_model_hub
26 |
27 | if __name__ == "__main__":
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("--model_type", type=str, required=True)
31 | parser.add_argument("--local_ckpt", type=str, required=True)
32 | parser.add_argument("--repo_id", type=str, required=True)
33 | args, unknown = parser.parse_known_args()
34 |
35 | accelerator = Accelerator()
36 |
37 | hf_model_cls = wrap_model_hub(model_dict[args.model_type])
38 | hf_model = hf_model_cls.from_pretrained(args.local_ckpt)
39 | hf_model.push_to_hub(
40 | repo_id=args.repo_id,
41 | config=hf_model.config,
42 | private=True,
43 | )
44 |
--------------------------------------------------------------------------------
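An illustrative invocation of the upload script above: `--model_type` must be a key of `lam.models.model_dict` (the value below is the one hard-coded in `scripts/convert_hf.py`), `--local_ckpt` a directory produced by that conversion script, and the repo id a placeholder.

```bash
python scripts/upload_hub.py \
    --model_type human_lrm_sapdino_bh_sd3_5 \
    --local_ckpt ./exps/releases/lam/lam-20k/step_045500 \
    --repo_id <your-hf-username>/lam-20k
```
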
/tools/AVATAR_EXPORT_GUIDE.md:
--------------------------------------------------------------------------------
1 | ## Export Chatting Avatar Guide
2 | ### 🛠️ Environment Setup
3 | #### Prerequisites
4 | ```
5 | Python FBX SDK 2020.2+
6 | Blender (version > 4.0.0)
7 | ```
8 | #### Step 1: Download and install the Python FBX SDK and other requirements
9 | ```bash
10 | # FBX SDK: https://www.autodesk.com/developer-network/platform-technologies/fbx-sdk-2020-2
11 | # Download FBX SDK installation package, example for Linux
12 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/fbxsdk_linux.tar
13 | tar -xf fbxsdk_linux.tar
14 | sh tools/install_fbx_sdk.sh
15 |
16 | # Or download and install the FBX-SDK Wheel built by us
17 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/fbx-2020.3.4-cp310-cp310-manylinux1_x86_64.whl
18 | pip install fbx-2020.3.4-cp310-cp310-manylinux1_x86_64.whl
19 |
20 | # Install other requirements
21 | pip install pathlib
22 | pip install patool
23 | ```
24 | #### Step 2: Download Blender
25 | ```bash
26 | # Download latest Blender (>=4.0.0)
27 | # Choose appropriate version from: https://www.blender.org/download/
28 | # Example for Linux
29 | wget https://download.blender.org/release/Blender4.0/blender-4.0.2-linux-x64.tar.xz
30 | tar -xvf blender-4.0.2-linux-x64.tar.xz -C ~/software/
31 | ```
32 | #### Step 3: Download the chatting avatar template file
33 | ```bash
34 | # Download and extract sample files
35 | wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/sample_oac.tar
36 | tar -xf sample_oac.tar -C assets/
37 | ```
38 |
39 | ### Gradio Run
40 | ```bash
41 | # Example path for Blender executable
42 | python app_lam.py --blender_path ~/software/blender-4.0.2-linux-x64/blender
43 | ```
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc3d/LAM/39bf05821d2bb821db1f9c41acf995c960a4a1e2/tools/__init__.py
--------------------------------------------------------------------------------
/tools/convertFBX2GLB.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
3 |
4 | Blender FBX to GLB Converter
5 | Converts 3D models from FBX to glTF Binary (GLB) format with optimized settings.
6 | Requires Blender to run in background mode.
7 | """
8 |
9 | import bpy
10 | import sys
11 | from pathlib import Path
12 |
13 | def clean_scene():
14 | """Clear all objects and data from the current Blender scene"""
15 | bpy.ops.object.select_all(action='SELECT')
16 | bpy.ops.object.delete()
17 | for collection in [bpy.data.meshes, bpy.data.materials, bpy.data.textures]:
18 | for item in collection:
19 | collection.remove(item)
20 |
21 |
22 | def main():
23 | try:
24 | # Parse command line arguments after "--"
25 | argv = sys.argv[sys.argv.index("--") + 1:]
26 | input_fbx = Path(argv[0])
27 | output_glb = Path(argv[1])
28 |
29 | # Validate input file
30 | if not input_fbx.exists():
31 | raise FileNotFoundError(f"Input FBX file not found: {input_fbx}")
32 |
33 | # Prepare scene
34 | clean_scene()
35 |
36 | # Import FBX with default settings
37 | print(f"Importing {input_fbx}...")
38 | bpy.ops.import_scene.fbx(filepath=str(input_fbx))
39 |
40 | # Export optimized GLB
41 | print(f"Exporting to {output_glb}...")
42 | bpy.ops.export_scene.gltf(
43 | filepath=str(output_glb),
44 | export_format='GLB', # Binary format
45 | export_skins=True, # Keep skinning data
46 | export_texcoords=False, # Reduce file size
47 | export_normals=False, # Reduce file size
48 | export_colors=False, # Reduce file size
49 | )
50 |
51 | print("Conversion completed successfully")
52 |
53 | except Exception as e:
54 | print(f"Error: {str(e)}")
55 | sys.exit(1)
56 |
57 |
58 | if __name__ == "__main__":
59 | main()
60 |
--------------------------------------------------------------------------------
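The script above must be run inside Blender in background mode; everything after `--` is forwarded to the script. The Blender path below matches the example install location from the avatar export guide, and the file names are placeholders.

```bash
~/software/blender-4.0.2-linux-x64/blender --background \
    --python tools/convertFBX2GLB.py -- input_model.fbx output_model.glb
```
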
/tools/generateVertexIndices.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
3 |
4 | Blender Vertex Order Exporter
5 | Imports a mesh (OBJ), applies a 90-degree X rotation, and exports the vertex indices sorted by Z coordinate as a JSON array.
6 | Requires Blender to run in background mode.
7 | """
8 |
9 | import bpy
10 | import sys
11 | import os
12 | import json
13 | from pathlib import Path
14 |
15 | def import_obj(filepath):
16 | if not os.path.exists(filepath):
17 |         raise FileNotFoundError(f"File does not exist: {filepath}")
18 |     bpy.ops.wm.obj_import(filepath=filepath)
19 |     print(f"Successfully imported: {filepath}")
20 |
21 |
22 | def clean_scene():
23 | """Clear all objects and data from the current Blender scene"""
24 | bpy.ops.object.select_all(action='SELECT')
25 | bpy.ops.object.delete()
26 | for collection in [bpy.data.meshes, bpy.data.materials, bpy.data.textures]:
27 | for item in collection:
28 | collection.remove(item)
29 |
30 | def apply_rotation(obj):
31 | obj.rotation_euler = (1.5708, 0, 0)
32 | bpy.context.view_layer.update()
33 | obj.select_set(True)
34 | bpy.context.view_layer.objects.active = obj
35 |     bpy.ops.object.transform_apply(location=False, rotation=True, scale=False)  # apply the rotation
36 | print(f"Applied 90-degree rotation to object: {obj.name}")
37 |
38 | def main():
39 | try:
40 | # Parse command line arguments after "--"
41 | argv = sys.argv[sys.argv.index("--") + 1:]
42 | input_mesh = Path(argv[0])
43 | output_vertex_order_file = argv[1]
44 |
45 | # Validate input file
46 | if not input_mesh.exists():
47 |             raise FileNotFoundError(f"Input mesh file not found: {input_mesh}")
48 |
49 | # Prepare scene
50 | clean_scene()
51 |
52 |         # Import the mesh with default settings
53 | print(f"Importing {input_mesh}...")
54 | import_obj(str(input_mesh))
55 | base_obj = bpy.context.view_layer.objects.active
56 |
57 | apply_rotation(base_obj)
58 |
59 | bpy.context.view_layer.objects.active = base_obj
60 | base_obj.select_set(True)
61 | bpy.ops.object.mode_set(mode='OBJECT')
62 |
63 | base_objects = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
64 | if len(base_objects) != 1:
65 | raise ValueError("Scene should contain exactly one base mesh object.")
66 | base_obj = base_objects[0]
67 |
68 | vertices = [(i, v.co.z) for i, v in enumerate(base_obj.data.vertices)]
69 |
70 |         sorted_vertices = sorted(vertices, key=lambda x: x[1])  # sort by Z coordinate, ascending
71 | sorted_vertex_indices = [idx for idx, z in sorted_vertices]
72 |
73 | with open(str(output_vertex_order_file), "w") as f:
74 |             json.dump(sorted_vertex_indices, f, indent=4)  # save as a JSON array
75 | print(f"Exported vertex order to: {str(output_vertex_order_file)}")
76 |
77 |
78 | except Exception as e:
79 | print(f"Error: {str(e)}")
80 | sys.exit(1)
81 |
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
--------------------------------------------------------------------------------
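As with the converter above, this script runs inside Blender in background mode: it imports an OBJ, applies the rotation, and writes the Z-sorted vertex indices to a JSON file. The paths below are placeholders.

```bash
~/software/blender-4.0.2-linux-x64/blender --background \
    --python tools/generateVertexIndices.py -- head_mesh.obj vertex_order.json
```
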
/tools/install_fbx_sdk.sh:
--------------------------------------------------------------------------------
1 | cd fbxsdk_linux
2 | chmod +x fbx202034_fbxsdk_linux fbx202034_fbxpythonbindings_linux
3 | mkdir -p ./python_binding ./python_binding/fbx_sdk
4 | yes yes | ./fbx202034_fbxpythonbindings_linux ./python_binding
5 | yes yes | ./fbx202034_fbxsdk_linux ./python_binding/fbx_sdk
6 | cd ./python_binding
7 | export FBXSDK_ROOT=./fbx_sdk
8 | pip install .
9 | cd -
--------------------------------------------------------------------------------
/vhap/config/nersemble.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from typing import Optional, Literal
11 | from dataclasses import dataclass
12 | import tyro
13 |
14 | from vhap.config.base import (
15 | StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig,
16 | DataConfig, LossWeightConfig, BaseTrackingConfig,
17 | )
18 | from vhap.util.log import get_logger
19 | logger = get_logger(__name__)
20 |
21 |
22 | @dataclass()
23 | class NersembleDataConfig(DataConfig):
24 | _target: str = "vhap.data.nersemble_dataset.NeRSembleDataset"
25 | calibrated: bool = True
26 | image_size_during_calibration: Optional[tuple[int, int]] = (3208, 2200)
27 |     """(height, width). Will be used to convert principal points when the image size is not included in the camera parameters."""
28 | background_color: Optional[Literal['white', 'black']] = None
29 | landmark_source: Optional[Literal["face-alignment", 'star']] = "star"
30 |
31 | subject: str = ""
32 |     """Subject ID, such as 018, 218, 251, 253."""
33 | use_color_correction: bool = True
34 | """Whether to use color correction to harmonize the color of the input images."""
35 |
36 | @dataclass()
37 | class NersembleLossWeightConfig(LossWeightConfig):
38 |     landmark: Optional[float] = 3.  # should not be lowered, to avoid collapse
39 |     always_enable_jawline_landmarks: bool = False  # allow disable_jawline_landmarks in StageConfig to take effect
40 |     reg_expr: float = 1e-2  # for best expressiveness
41 |     reg_tex_tv: Optional[float] = 1e5  # 10x of the base value
42 |
43 | @dataclass()
44 | class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig):
45 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset")
46 |
47 | align_texture_except: tuple[str, ...] = ("boundary",)
48 | align_boundary_except: tuple[str, ...] = ("boundary",)
49 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the
50 | alignment constraints for better alignment in the face region.
51 | """
52 |
53 | @dataclass()
54 | class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig):
55 | align_texture_except: tuple[str, ...] = ("boundary",)
56 | align_boundary_except: tuple[str, ...] = ("boundary",)
57 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the
58 | alignment constraints for better alignment in the face region.
59 | """
60 |
61 | @dataclass()
62 | class NersemblePipelineConfig(PipelineConfig):
63 | rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig
64 | rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig
65 |
66 | @dataclass()
67 | class NersembleTrackingConfig(BaseTrackingConfig):
68 | data: NersembleDataConfig
69 | w: NersembleLossWeightConfig
70 | pipeline: NersemblePipelineConfig
71 |
72 | def get_occluded(self):
73 | occluded_table = {
74 | '018': ('neck_lower',),
75 | '218': ('neck_lower',),
76 | '251': ('neck_lower', 'boundary'),
77 | '253': ('neck_lower',),
78 | }
79 | if self.data.subject in occluded_table:
80 | logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}")
81 | self.model.occluded = occluded_table[self.data.subject]
82 |
83 |
84 | if __name__ == "__main__":
85 | config = tyro.cli(NersembleTrackingConfig)
86 | print(tyro.to_yaml(config))
--------------------------------------------------------------------------------
/vhap/data/image_folder_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 | import numpy as np
4 | import PIL.Image as Image
5 | from torch.utils.data import Dataset
6 | from vhap.util.log import get_logger
7 |
8 |
9 | logger = get_logger(__name__)
10 |
11 |
12 | class ImageFolderDataset(Dataset):
13 | def __init__(
14 | self,
15 | image_folder: Path,
16 | background_folder: Optional[Path]=None,
17 | background_fname2camId=lambda x: x,
18 | image_fname2camId=lambda x: x,
19 | ):
20 | """
21 | Args:
22 |             image_folder: Path to the image folder with the following directory layout
23 | /
24 | |---xx.jpg
25 | |---...
26 | """
27 | super().__init__()
28 | self.image_fname2camId = image_fname2camId
29 |         self.background_folder = background_folder
30 |
31 | logger.info(f"Initializing dataset from folder {image_folder}")
32 |
33 | self.image_paths = sorted(list(image_folder.glob('*.jpg')))
34 |
35 | if background_folder is not None:
36 | self.backgrounds = {}
37 | background_paths = sorted(list((image_folder / background_folder).glob('*.jpg')))
38 |
39 | for background_path in background_paths:
40 | bg = np.array(Image.open(background_path))
41 | cam_id = background_fname2camId(background_path.name)
42 | self.backgrounds[cam_id] = bg
43 |
44 | def __len__(self):
45 | return len(self.image_paths)
46 |
47 | def __getitem__(self, i):
48 | image_path = self.image_paths[i]
49 | cam_id = self.image_fname2camId(image_path.name)
50 | rgb = np.array(Image.open(image_path))
51 | item = {
52 | "rgb": rgb,
53 | 'image_path': str(image_path),
54 | }
55 |
56 |         if self.background_folder is not None:
57 | item['background'] = self.backgrounds[cam_id]
58 |
59 | return item
60 |
61 |
62 | if __name__ == "__main__":
63 | from tqdm import tqdm
64 | from torch.utils.data import DataLoader
65 |
66 |     dataset = ImageFolderDataset(
67 |         # NOTE: point this at a real folder of *.jpg images; a pathlib.Path is expected
68 |         image_folder=Path('./xx'),
69 |     )
70 |
71 | print(len(dataset))
72 |
73 | sample = dataset[0]
74 | print(sample.keys())
75 | print(sample["rgb"].shape)
76 |
77 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
78 | for item in tqdm(dataloader):
79 | pass
80 |
--------------------------------------------------------------------------------
/vhap/generate_flame_uvmask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from typing import Literal
11 | import tyro
12 | import numpy as np
13 | from PIL import Image
14 | from pathlib import Path
15 | import torch
16 | import nvdiffrast.torch as dr
17 | from vhap.util.render_uvmap import render_uvmap_vtex
18 | from vhap.model.flame import FlameHead
19 |
20 |
21 | FLAME_UV_MASK_FOLDER = "asset/flame/uv_masks"
22 | FLAME_UV_MASK_NPZ = "asset/flame/uv_masks.npz"
23 |
24 |
25 | def main(
26 | use_opengl: bool = False,
27 | device: Literal['cuda', 'cpu'] = 'cuda',
28 | ):
29 | n_shape = 300
30 | n_expr = 100
31 | print("Initializing FLAME model")
32 |     # Build the FLAME head model once and move it to the GPU.
33 |     flame_model = FlameHead(
34 |         n_shape,
35 |         n_expr,
36 |         add_teeth=True,
37 |     ).cuda()
38 | 
39 |
40 | faces = flame_model.faces.int().cuda()
41 | verts_uv = flame_model.verts_uvs.cuda()
42 | # verts_uv[:, 1] = 1 - verts_uv[:, 1]
43 | faces_uv = flame_model.textures_idx.int().cuda()
44 | col_idx = faces_uv
45 |
46 | # Rasterizer context
47 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
48 |
49 | h, w = 2048, 2048
50 | resolution = (h, w)
51 |
52 | if not Path(FLAME_UV_MASK_FOLDER).exists():
53 | Path(FLAME_UV_MASK_FOLDER).mkdir(parents=True)
54 |
55 | # alpha_maps = {}
56 | masks = {}
57 | for region, vt_mask in flame_model.mask.vt:
58 | v_color = torch.zeros(verts_uv.shape[0], 1).to(device) # alpha channel
59 | v_color[vt_mask] = 1
60 |
61 | alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)[0]
62 | alpha = alpha.flip(0)
63 | # alpha_maps[region] = alpha.cpu().numpy()
64 | mask = (alpha > 0.5) # to avoid overlap between hair and face
65 | mask = mask.squeeze(-1).cpu().numpy()
66 | masks[region] = mask # (h, w)
67 |
68 | print(f"Saving uv mask for {region}...")
69 | # rgba = mask.expand(-1, -1, 4) # (h, w, 4)
70 | # rgb = torch.ones_like(mask).expand(-1, -1, 3) # (h, w, 3)
71 | # rgba = torch.cat([rgb, mask], dim=-1).cpu().numpy() # (h, w, 4)
72 | img = mask
73 | img = Image.fromarray((img * 255).astype(np.uint8))
74 | img.save(Path(FLAME_UV_MASK_FOLDER) / f"{region}.png")
75 |
76 | print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}")
77 | np.savez_compressed(FLAME_UV_MASK_NPZ, **masks)
78 |
79 |
80 | if __name__ == "__main__":
81 | tyro.cli(main)
--------------------------------------------------------------------------------
/vhap/track.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 |
12 | from vhap.config.base import BaseTrackingConfig
13 | from vhap.model.tracker import GlobalTracker
14 |
15 |
16 | if __name__ == "__main__":
17 | tyro.extras.set_accent_color("bright_yellow")
18 | cfg = tyro.cli(BaseTrackingConfig)
19 |
20 | tracker = GlobalTracker(cfg)
21 | tracker.optimize()
22 |
--------------------------------------------------------------------------------
/vhap/track_nersemble.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 |
12 | from vhap.config.nersemble import NersembleTrackingConfig
13 | from vhap.model.tracker import GlobalTracker
14 |
15 |
16 | if __name__ == "__main__":
17 | tyro.extras.set_accent_color("bright_yellow")
18 | cfg = tyro.cli(NersembleTrackingConfig)
19 |
20 | tracker = GlobalTracker(cfg)
21 | tracker.optimize()
22 |
--------------------------------------------------------------------------------
/vhap/util/log.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import logging
11 | import sys
12 | from datetime import datetime
13 | import atexit
14 | from pathlib import Path
15 |
16 |
17 | def _colored(msg, color):
18 | colors = {'red': '\033[91m', 'green': '\033[92m', 'yellow': '\033[93m', 'normal': '\033[0m'}
19 | return colors[color] + msg + colors["normal"]
20 |
21 |
22 | class ColorFormatter(logging.Formatter):
23 | """
24 | Class to make command line log entries more appealing
25 | Inspired by https://github.com/facebookresearch/detectron2
26 | """
27 |
28 | def formatMessage(self, record):
29 | """
30 | Print warnings yellow and errors red
31 | :param record:
32 | :return:
33 | """
34 | log = super().formatMessage(record)
35 | if record.levelno == logging.WARNING:
36 | prefix = _colored("WARNING", "yellow")
37 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
38 | prefix = _colored("ERROR", "red")
39 | else:
40 | return log
41 | return prefix + " " + log
42 |
43 |
44 | def get_logger(name, level=logging.DEBUG, root=False, log_dir=None):
45 | """
46 | Replaces the standard library logging.getLogger call in order to make some configuration
47 | for all loggers.
48 | :param name: pass the __name__ variable
49 | :param level: the desired log level
50 | :param root: call only once in the program
51 | :param log_dir: if root is set to True, this defines the directory where a log file is going
52 | to be created that contains all logging output
53 | :return: the logger object
54 | """
55 | logger = logging.getLogger(name)
56 | logger.setLevel(level)
57 |
58 | if root:
59 | # create handler for console
60 | console_handler = logging.StreamHandler(sys.stdout)
61 | console_handler.setLevel(level)
62 | formatter = ColorFormatter(_colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
63 | datefmt="%m/%d %H:%M:%S")
64 | console_handler.setFormatter(formatter)
65 | logger.addHandler(console_handler)
66 | logger.propagate = False # otherwise root logger prints things again
67 |
68 | if log_dir is not None:
69 | # add handler to log to a file
70 | log_dir = Path(log_dir)
71 | if not log_dir.exists():
72 | logger.info(f"Logging directory {log_dir} does not exist and will be created")
73 | log_dir.mkdir(parents=True)
74 | timestamp = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
75 | log_file = log_dir / f"{timestamp}.log"
76 |
77 | # open stream and make sure it will be closed
78 | stream = log_file.open(mode="w")
79 | atexit.register(stream.close)
80 |
81 | formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
82 | datefmt="%m/%d %H:%M:%S")
83 | file_handler = logging.StreamHandler(stream)
84 | file_handler.setLevel(level)
85 | file_handler.setFormatter(formatter)
86 | logger.addHandler(file_handler)
87 |
88 | return logger
89 |
--------------------------------------------------------------------------------
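A short sketch of the intended usage of `get_logger` above (the directory and logger names are illustrative): configure a package-level logger once with `root=True`, then obtain plain child loggers elsewhere, which propagate to the handlers attached to the root-configured one.

```python
from vhap.util.log import get_logger

# configure once, near the program entry point: console handler plus a timestamped log file
get_logger("vhap", root=True, log_dir="./logs")

# in library modules this is typically get_logger(__name__); any "vhap.*" name
# propagates to the handlers attached above
logger = get_logger("vhap.demo")
logger.info("tracking started")
```
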
/vhap/util/mesh.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import torch
11 |
12 |
13 | def get_mtl_content(tex_fname):
14 | return f'newmtl Material\nmap_Kd {tex_fname}\n'
15 |
16 | def get_obj_content(vertices, faces, uv_coordinates=None, uv_indices=None, mtl_fname=None):
17 | obj = ('# Generated with multi-view-head-tracker\n')
18 |
19 | if mtl_fname is not None:
20 | obj += f'mtllib {mtl_fname}\n'
21 | obj += 'usemtl Material\n'
22 |
23 | # Write the vertices
24 | for vertex in vertices:
25 | obj += f"v {vertex[0]} {vertex[1]} {vertex[2]}\n"
26 |
27 | # Write the UV coordinates
28 | if uv_coordinates is not None:
29 | for uv in uv_coordinates:
30 | obj += f"vt {uv[0]} {uv[1]}\n"
31 |
32 | # Write the faces with UV indices
33 | if uv_indices is not None:
34 |         for face, uv_idx in zip(faces, uv_indices):
35 |             obj += f"f {face[0]+1}/{uv_idx[0]+1} {face[1]+1}/{uv_idx[1]+1} {face[2]+1}/{uv_idx[2]+1}\n"
36 | else:
37 | for face in faces:
38 | obj += f"f {face[0]+1} {face[1]+1} {face[2]+1}\n"
39 | return obj
40 |
41 | def normalize_image_points(u, v, resolution):
42 | """
43 | normalizes u, v coordinates from [0 ,image_size] to [-1, 1]
44 | :param u:
45 | :param v:
46 | :param resolution:
47 | :return:
48 | """
49 | u = 2 * (u - resolution[1] / 2.0) / resolution[1]
50 | v = 2 * (v - resolution[0] / 2.0) / resolution[0]
51 | return u, v
52 |
53 |
54 | def face_vertices(vertices, faces):
55 | """
56 | :param vertices: [batch size, number of vertices, 3]
57 | :param faces: [batch size, number of faces, 3]
58 | :return: [batch size, number of faces, 3, 3]
59 | """
60 | assert vertices.ndimension() == 3
61 | assert faces.ndimension() == 3
62 | assert vertices.shape[0] == faces.shape[0]
63 | assert vertices.shape[2] == 3
64 | assert faces.shape[2] == 3
65 |
66 | bs, nv = vertices.shape[:2]
67 | bs, nf = faces.shape[:2]
68 | device = vertices.device
69 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
70 | vertices = vertices.reshape((bs * nv, 3))
71 | # pytorch only supports long and byte tensors for indexing
72 | return vertices[faces.long()]
73 |
74 |
--------------------------------------------------------------------------------
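A tiny example of `get_obj_content` above on a single made-up triangle, written straight to disk.

```python
from vhap.util.mesh import get_obj_content

vertices = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
faces = [(0, 1, 2)]  # zero-based indices; the helper converts them to 1-based OBJ indices

with open("triangle.obj", "w") as f:
    f.write(get_obj_content(vertices, faces))
```
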
/vhap/util/render_uvmap.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | import nvdiffrast.torch as dr
15 |
16 | from vhap.model.flame import FlameHead
17 |
18 |
19 | FLAME_TEX_PATH = "asset/flame/FLAME_texture.npz"
20 |
21 |
22 | def transform_vt(vt):
23 | """Transform uv vertices to clip space"""
24 | xy = vt * 2 - 1
25 | w = torch.ones([1, vt.shape[-2], 1]).to(vt)
26 |     z = -w  # In the clip space of OpenGL, the camera looks at -z
27 | xyzw = torch.cat([xy[None, :, :], z, w], axis=-1)
28 | return xyzw
29 |
30 | def render_uvmap_vtex(glctx, pos, pos_idx, v_color, col_idx, resolution):
31 | """Render uv map with vertex color"""
32 | pos_clip = transform_vt(pos)
33 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
34 |
35 | color, _ = dr.interpolate(v_color, rast_out, col_idx)
36 | color = dr.antialias(color, rast_out, pos_clip, pos_idx)
37 | return color
38 |
39 | def render_uvmap_texmap(glctx, pos, pos_idx, verts_uv, faces_uv, tex, resolution, enable_mip=True, max_mip_level=None):
40 | """Render uv map with texture map"""
41 | pos_clip = transform_vt(pos)
42 | rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
43 |
44 | if enable_mip:
45 | texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all')
46 | color = dr.texture(tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level)
47 | else:
48 | texc, _ = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv)
49 | color = dr.texture(tex[None, ...], texc, filter_mode='linear')
50 | color = dr.antialias(color, rast_out, pos_clip, pos_idx)
51 | return color
52 |
53 |
54 | def main(
55 | use_texmap: bool = False,
56 | use_opengl: bool = False,
57 | ):
58 | n_shape = 300
59 | n_expr = 100
60 |     print("Initializing FLAME model")
61 | flame_model = FlameHead(n_shape, n_expr)
62 |
63 | verts_uv = flame_model.verts_uvs.cuda()
64 | verts_uv[:, 1] = 1 - verts_uv[:, 1]
65 | faces_uv = flame_model.textures_idx.int().cuda()
66 |
67 | # Rasterizer context
68 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
69 |
70 | h, w = 512, 512
71 | resolution = (h, w)
72 |
73 | if use_texmap:
74 | tex = torch.from_numpy(np.load(FLAME_TEX_PATH)['mean']).cuda().float().flip(dims=[-1]) / 255
75 | rgb = render_uvmap_texmap(glctx, verts_uv, faces_uv, verts_uv, faces_uv, tex, resolution, enable_mip=True)
76 | else:
77 | v_color = torch.ones(verts_uv.shape[0], 3).to(verts_uv)
78 | col_idx = faces_uv
79 | rgb = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)
80 |
81 | plt.imshow(rgb[0, :, :].cpu())
82 | plt.show()
83 |
84 |
85 | if __name__ == "__main__":
86 | tyro.cli(main)
87 |
--------------------------------------------------------------------------------
/vhap/util/vector_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5 | return torch.sum(x*y, -1, keepdim=True)
6 |
7 | def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
8 | return 2*dot(x, n)*n - x
9 |
10 | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
11 | return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
12 |
13 | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
14 | return x / length(x, eps)
15 |
16 | def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
17 | return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)
18 |
--------------------------------------------------------------------------------
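A quick numerical check of the helpers above on batched 3-vectors (the values are arbitrary).

```python
import torch

from vhap.util.vector_ops import dot, reflect, safe_normalize

n = safe_normalize(torch.tensor([[0.0, 0.0, 2.0]]))  # unit normal along +z
v = torch.tensor([[1.0, 0.0, 1.0]])

print(dot(v, n))      # tensor([[1.]]) -- component of v along n
print(reflect(v, n))  # tensor([[-1., 0., 1.]]) -- v mirrored about n
```
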
/vhap/util/visualization.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import matplotlib.pyplot as plt
11 | import torch
12 | from torchvision.utils import draw_bounding_boxes, draw_keypoints
13 |
14 |
15 | connectivity_face = (
16 | [(i, i + 1) for i in list(range(0, 16))]
17 | + [(i, i + 1) for i in list(range(17, 21))]
18 | + [(i, i + 1) for i in list(range(22, 26))]
19 | + [(i, i + 1) for i in list(range(27, 30))]
20 | + [(i, i + 1) for i in list(range(31, 35))]
21 | + [(i, i + 1) for i in list(range(36, 41))]
22 | + [(36, 41)]
23 | + [(i, i + 1) for i in list(range(42, 47))]
24 | + [(42, 47)]
25 | + [(i, i + 1) for i in list(range(48, 59))]
26 | + [(48, 59)]
27 | + [(i, i + 1) for i in list(range(60, 67))]
28 | + [(60, 67)]
29 | )
30 |
31 |
32 | def plot_landmarks_2d(
33 | img: torch.tensor,
34 | lmks: torch.tensor,
35 | connectivity=None,
36 | colors="white",
37 | unit=1,
38 | input_float=False,
39 | ):
40 | if input_float:
41 | img = (img * 255).byte()
42 |
43 | img = draw_keypoints(
44 | img,
45 | lmks,
46 | connectivity=connectivity,
47 | colors=colors,
48 | radius=2 * unit,
49 | width=2 * unit,
50 | )
51 |
52 | if input_float:
53 | img = img.float() / 255
54 | return img
55 |
56 |
57 | def blend(a, b, w):
58 | return (a * w + b * (1 - w)).byte()
59 |
60 |
61 | if __name__ == "__main__":
62 | from argparse import ArgumentParser
63 | from torch.utils.data import DataLoader
64 | from matplotlib import pyplot as plt
65 |
66 | from vhap.data.nersemble_dataset import NeRSembleDataset
67 |
68 | parser = ArgumentParser()
69 | parser.add_argument("--root_folder", type=str, required=True)
70 | parser.add_argument("--subject", type=str, required=True)
71 | parser.add_argument("--sequence", type=str, required=True)
72 | parser.add_argument("--division", default=None)
73 | parser.add_argument("--subset", default=None)
74 | parser.add_argument("--scale_factor", type=float, default=1.0)
75 | parser.add_argument("--blend_weight", type=float, default=0.6)
76 | args = parser.parse_args()
77 |
78 | dataset = NeRSembleDataset(
79 | root_folder=args.root_folder,
80 | subject=args.subject,
81 | sequence=args.sequence,
82 | division=args.division,
83 | subset=args.subset,
84 | n_downsample_rgb=2,
85 | scale_factor=args.scale_factor,
86 | use_landmark=True,
87 | )
88 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
89 |
90 | for item in dataloader:
91 | unit = int(item["scale_factor"][0] * 3) + 1
92 |
93 | rgb = item["rgb"][0].permute(2, 0, 1)
94 | vis = rgb
95 |
96 | if "bbox_2d" in item:
97 | bbox = item["bbox_2d"][0][:4]
98 | tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
99 | vis = blend(tmp, vis, args.blend_weight)
100 |
101 | if "lmk2d" in item:
102 | face_landmark = item["lmk2d"][0][:, :2]
103 | tmp = plot_landmarks_2d(
104 | vis,
105 | face_landmark[None, ...],
106 | connectivity=connectivity_face,
107 | colors="white",
108 | unit=unit,
109 | )
110 | vis = blend(tmp, vis, args.blend_weight)
111 |
112 | if "lmk2d_iris" in item:
113 | iris_landmark = item["lmk2d_iris"][0][:, :2]
114 | tmp = plot_landmarks_2d(
115 | vis,
116 | iris_landmark[None, ...],
117 | colors="blue",
118 | unit=unit,
119 | )
120 | vis = blend(tmp, vis, args.blend_weight)
121 |
122 | vis = vis.permute(1, 2, 0).numpy()
123 | plt.imshow(vis)
124 | plt.draw()
125 | while not plt.waitforbuttonpress(timeout=-1):
126 | pass
127 |
--------------------------------------------------------------------------------