├── test ├── __init__.py ├── test_camera_distribution.py ├── test_patching.py ├── test_semantic_pooling.py ├── test_specular.py └── test_raycast_rgbd.py ├── util ├── __init__.py ├── timer.py ├── filesystem_logger.py ├── df_metrics.py ├── misc.py └── camera.py ├── evaluation └── __init__.py ├── model ├── eg3d │ └── __init__.py ├── pigan │ ├── __init__.py │ ├── fid_evaluation.py │ └── discriminators.py ├── raycast_rgbd │ ├── __init__.py │ ├── setup.py │ ├── cudaUtil.h │ └── raycast_rgbd_cuda.cpp ├── loss.py ├── uv │ └── __init__.py ├── stylegan2 │ └── __init__.py ├── __init__.py ├── differentiable_renderer.py ├── differentiable_renderer_light.py ├── discriminator.py └── styleganvox │ ├── __init__.py │ └── generator.py ├── data_processing ├── __init__.py ├── create_uv_charts.py └── create_uv_charts_car.py ├── .gitmodules ├── requirements.txt ├── .gitignore ├── config ├── stylegan2.yaml └── stylegan2_car.yaml ├── trainer ├── __init__.py └── train_autoencoder.py ├── dataset ├── distance_field.py ├── mesh_uniform.py ├── mesh_real_sdfgrid.py ├── mesh_real_volume.py ├── mesh_real.py ├── mesh_cube.py ├── mesh_real_pigan.py └── meshcar_real_sdfgrid.py └── README.md /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/eg3d/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/pigan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_processing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/raycast_rgbd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/ChamferDistancePytorch"] 2 | path = external/ChamferDistancePytorch 3 | url = git@github.com:nihalsid/ChamferDistancePytorch.git 4 | -------------------------------------------------------------------------------- /util/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer(object): 5 | def __init__(self, name): 6 | self.name = name 7 | 8 | def __enter__(self): 9 | self.tstart = time.time() 10 | 11 | def __exit__(self, _, value, traceback): 12 | print('[%s] Elapsed: %s' % (self.name, time.time() - self.tstart)) 13 | -------------------------------------------------------------------------------- /model/raycast_rgbd/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | 
name='raycast_rgbd_cuda', 6 | ext_modules=[ 7 | CUDAExtension('raycast_rgbd_cuda', [ 8 | 'raycast_rgbd_cuda.cpp', 9 | 'raycast_rgbd_cuda_kernel.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ballpark~=1.4.0 2 | argparse~=1.4.0 3 | pyyaml 4 | typing~=3.6.4 5 | omegaconf~=2.0.6 6 | numpy~=1.19.4 7 | torchvision~=0.11.1 8 | pillow~=8.4.0 9 | wandb~=0.12.2 10 | hydra~=2.5 11 | pytorch-lightning 12 | tqdm~=4.48.2 13 | torchmetrics~=0.6.0 14 | pyrender 15 | scipy~=1.7.1 16 | setuptools~=49.6.0 17 | matplotlib 18 | opencv-python~=3.4.9.33 19 | scikit-image~=0.19.1 20 | ninja 21 | imageio 22 | imageio-ffmpeg 23 | hydra-core 24 | torch-ema==0.2 25 | clean-fid -------------------------------------------------------------------------------- /model/raycast_rgbd/cudaUtil.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef _CUDA_UTIL_ 4 | #define _CUDA_UTIL_ 5 | 6 | #undef max 7 | #undef min 8 | 9 | #include "cutil_inline_runtime.h" 10 | 11 | // Enable run time assertion checking in kernel code 12 | #define cudaAssert(condition) if (!(condition)) { printf("ASSERT: %s %s\n", #condition, __FILE__); } 13 | //#define cudaAssert(condition) 14 | 15 | #if defined(__CUDA_ARCH__) 16 | #define __CONDITIONAL_UNROLL__ #pragma unroll 17 | #else 18 | #define __CONDITIONAL_UNROLL__ 19 | #endif 20 | 21 | // math helpers 22 | #include "cutil_math.h" 23 | 24 | #ifndef sint 25 | typedef signed int sint; 26 | #endif 27 | 28 | #ifndef uint 29 | typedef unsigned int uint; 30 | #endif 31 | 32 | #ifndef slong 33 | typedef signed long slong; 34 | #endif 35 | 36 | #ifndef ulong 37 | typedef unsigned long ulong; 38 | #endif 39 | 40 | #ifndef uchar 41 | typedef unsigned char uchar; 42 | #endif 43 | 44 | #ifndef schar 45 | typedef signed char schar; 46 | #endif 47 | 48 | #ifndef MINF 49 | #define MINF __int_as_float(0xff800000) 50 | #endif 51 | 52 | #ifndef PINF 53 | #define PINF __int_as_float(0x7f800000) 54 | #endif 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | 5 | 6 | def compute_gradient_penalty(x, d): 7 | gradients = torch.autograd.grad(outputs=[d.sum()], inputs=[x], create_graph=True, only_inputs=True)[0] 8 | r1_penalty = gradients.square().sum([1, 2, 3]).mean() 9 | return r1_penalty / 2 10 | 11 | 12 | class PathLengthPenalty(nn.Module): 13 | 14 | def __init__(self, pl_decay, pl_batch_shrink): 15 | super().__init__() 16 | self.pl_batch_shrink = pl_batch_shrink 17 | self.pl_decay = pl_decay 18 | self.pl_mean = nn.Parameter(torch.zeros([1]), requires_grad=False) 19 | 20 | def forward(self, fake, w): 21 | pl_noise = torch.randn_like(fake) / np.sqrt(fake.shape[2] * fake.shape[3]) 22 | pl_grads = torch.autograd.grad(outputs=[(fake * pl_noise).sum()], inputs=[w], create_graph=True, only_inputs=True)[0] 23 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() 24 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) 25 | self.pl_mean.copy_(pl_mean.detach()) 26 | pl_penalty = (pl_lengths - pl_mean).square() 27 | return pl_penalty.mean() 28 | -------------------------------------------------------------------------------- 
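The R1 gradient penalty and path-length regularizer defined in model/loss.py above are applied lazily in the training loop (see lazy_gradient_penalty_interval and lazy_path_penalty_interval in config/stylegan2.yaml). A minimal usage sketch follows; the generator/discriminator objects, their mapping/synthesis split, and the 16/4-step intervals hard-coded here are illustrative assumptions, not code from the repository's trainer:

import torch
from model.loss import compute_gradient_penalty, PathLengthPenalty

path_length_penalty = PathLengthPenalty(pl_decay=0.01, pl_batch_shrink=2)  # illustrative hyperparameters

def regularization_step(step, generator, discriminator, real_images, z):
    # R1 penalty: gradient of D's output w.r.t. the real images, every 16th step.
    if step % 16 == 0:
        real_images = real_images.detach().requires_grad_(True)
        d_real = discriminator(real_images)
        gp = compute_gradient_penalty(real_images, d_real)
        (gp * 16).backward()  # scale by the lazy interval

    # Path-length penalty: gradient of the generated image w.r.t. the latent w, every 4th step.
    # w is expected with shape (batch, num_ws, latent_dim), matching pl_grads.square().sum(2).mean(1).
    if step % 4 == 0:
        w = generator.mapping(z)        # assumed API
        fake = generator.synthesis(w)   # assumed API
        plp = path_length_penalty(fake, w)
        (plp * 4).backward()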
/util/filesystem_logger.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from pathlib import Path 4 | from typing import Dict, Optional, Union 5 | 6 | from omegaconf import OmegaConf 7 | from pytorch_lightning.loggers import LightningLoggerBase 8 | from pytorch_lightning.loggers.base import rank_zero_experiment, DummyExperiment 9 | 10 | 11 | class FilesystemLogger(LightningLoggerBase): 12 | 13 | @property 14 | def version(self) -> Union[int, str]: 15 | return 0 16 | 17 | @property 18 | def name(self) -> str: 19 | return "fslogger" 20 | 21 | # noinspection PyMethodOverriding 22 | def log_hyperparams(self, params: argparse.Namespace): 23 | pass 24 | 25 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): 26 | pass 27 | 28 | def __init__(self, experiment_config, **_kwargs): 29 | super().__init__() 30 | self.experiment_config = experiment_config 31 | self._experiment = None 32 | # noinspection PyStatementEffect 33 | self.experiment 34 | 35 | @property 36 | @rank_zero_experiment 37 | def experiment(self): 38 | if self._experiment is None: 39 | self._experiment = DummyExperiment() 40 | experiment_dir = Path("runs", self.experiment_config["experiment"]) 41 | experiment_dir.mkdir(exist_ok=True, parents=True) 42 | 43 | src_folders = ['config', 'data/splits', 'model', 'tests', 'trainer', 'util', 'data_processing', 'dataset'] 44 | sources = [] 45 | for src in src_folders: 46 | sources.extend(list(Path(".").glob(f'{src}/**/*'))) 47 | 48 | files_to_copy = [x for x in sources if x.suffix in [".py", ".pyx", ".txt", ".so", ".pyd", ".h", ".cu", ".c", '.cpp', ".html"] and x.parts[0] != "runs" and x.parts[0] != "wandb"] 49 | 50 | for f in files_to_copy: 51 | Path(experiment_dir, "code", f).parents[0].mkdir(parents=True, exist_ok=True) 52 | shutil.copyfile(f, Path(experiment_dir, "code", f)) 53 | 54 | Path(experiment_dir, "config.yaml").write_text(OmegaConf.to_yaml(self.experiment_config)) 55 | 56 | return self._experiment 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | slurm/ 3 | !data/splits 4 | wandb/ 5 | lightning_logs/ 6 | runs/ 7 | runs 8 | slurm/inference/ 9 | slurm/groups/ 10 | *.pyd 11 | data_processing/sdf_gen/bin 12 | data_processing/sdf_gen/build 13 | 14 | # Editors 15 | .vscode/ 16 | .idea/ 17 | 18 | # Vagrant 19 | .vagrant/ 20 | 21 | # Mac/OSX 22 | .DS_Store 23 | 24 | # Windows 25 | Thumbs.db 26 | 27 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | 33 | *.out 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .nox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | .hypothesis/ 77 | .pytest_cache/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | 88 | # Flask stuff: 89 | instance/ 90 | .webassets-cache 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | 98 | # PyBuilder 99 | target/ 100 | 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | 104 | # IPython 105 | profile_default/ 106 | ipython_config.py 107 | 108 | # pyenv 109 | .python-version 110 | 111 | # celery beat schedule file 112 | celerybeat-schedule 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | -------------------------------------------------------------------------------- /test/test_camera_distribution.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import numpy as np 4 | import random 5 | from util.camera import spherical_coord_to_cam 6 | import math 7 | from scipy.spatial.transform import Rotation 8 | 9 | pairmeta_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/metadata/pairs.json") 10 | image_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/exemplars") 11 | 12 | 13 | def meta_to_pair(c): 14 | return f'shape{c["shape_id"]:05d}_rank{(c["rank"] - 1):02d}_pair{c["id"]}' 15 | 16 | 17 | def load_pair_meta_views(image_path, pairmeta_path): 18 | dataset_images = [x.stem for x in image_path.iterdir()] 19 | loaded_json = json.loads(Path(pairmeta_path).read_text()) 20 | ret_dict = {} 21 | for k in loaded_json.keys(): 22 | if meta_to_pair(loaded_json[k]) in dataset_images: 23 | ret_dict[meta_to_pair(loaded_json[k])] = loaded_json[k] 24 | return ret_dict 25 | 26 | 27 | views_photoshape = load_pair_meta_views(image_path, pairmeta_path) 28 | view_keys = sorted(list(views_photoshape.keys())) 29 | 30 | 31 | def test_camera_distribution(): 32 | positions = [] 33 | print(len(view_keys)) 34 | for vk in view_keys: 35 | c_v = views_photoshape[vk] 36 | noise_azimuth = (random.random() - 0.5) * 0.075 37 | noise_elevation = (random.random() - 0.5) * 0.075 38 | perspective_cam = spherical_coord_to_cam(c_v['fov'], c_v['azimuth'] + noise_azimuth, c_v['elevation'] + noise_elevation) 39 | positions.append(perspective_cam.position) 40 | Path("camera_distribution.obj").write_text("\n".join([f"v {p[0]} {p[1]} {p[2]}" for p in positions])) 41 | 42 | 43 | def test_camera_distribution_sdf(): 44 | positions = [] 45 | print(len(view_keys)) 46 | for vk in view_keys: 47 | c_v = views_photoshape[vk] 48 | y_angle = c_v['azimuth'] * 180 / math.pi 49 | x_angle = 90 - c_v['elevation'] * 180 / math.pi 50 | z_angle = 180 # random.random() * 90 51 | camera_rot = np.eye(4) 52 | camera_rot[:3, :3] = Rotation.from_euler('x', x_angle, degrees=True).as_matrix() @ Rotation.from_euler('z', z_angle, degrees=True).as_matrix() @ Rotation.from_euler('y', 
y_angle, degrees=True).as_matrix() 53 | camera_translation = np.eye(4) 54 | camera_translation[:3, 3] = np.array([0, 0, 1.75]) 55 | camera_pose = np.linalg.inv(camera_translation @ camera_rot) 56 | noise = np.array([random.random(), random.random(), random.random()]) * 0.12 * 0 57 | positions.append(camera_pose[:3, 3] + noise) 58 | Path("camera_distribution_sdf.obj").write_text("\n".join([f"v {p[0]} {p[1]} {p[2]}" for p in positions])) 59 | 60 | 61 | if __name__ == '__main__': 62 | test_camera_distribution() 63 | -------------------------------------------------------------------------------- /test/test_patching.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | from torchvision.utils import save_image 5 | from tqdm import tqdm 6 | import torch 7 | 8 | from dataset.mesh_real_features_patch import FaceGraphMeshDataset 9 | from dataset import GraphDataLoader, to_device, to_vertex_colors_scatter 10 | from model.differentiable_renderer import DifferentiableRenderer 11 | from trainer.train_stylegan_real_feature_patch import StyleGAN2Trainer 12 | 13 | 14 | @hydra.main(config_path='../config', config_name='stylegan2') 15 | def test_patches(config): 16 | dataset = FaceGraphMeshDataset(config) 17 | dataloader = GraphDataLoader(dataset, batch_size=config.batch_size, num_workers=0) 18 | render_helper = DifferentiableRenderer(config.image_size, 'bounds', config.colorspace, num_channels=4).cuda() 19 | Path("runs/patches").mkdir(exist_ok=True) 20 | for batch_idx, batch in enumerate(tqdm(dataloader)): 21 | batch = to_device(batch, torch.device("cuda:0")) 22 | batch['real_hres'] = dataset.get_color_bg_real_hres(batch) 23 | first_views = list(range(0, batch['real_hres'].shape[0], config.views_per_sample)) 24 | patches = StyleGAN2Trainer.extract_patches_from_tensor(batch['real_hres'][first_views], batch['mask_hres'][first_views, 0, :, :], config.num_patch_per_view * config.views_per_sample, config.patch_size) 25 | patches = patches.reshape((config.batch_size * config.views_per_sample * config.num_patch_per_view, 3, config.patch_size, config.patch_size)) 26 | rendered_color_gt = render_helper.render(batch['vertices'], batch['indices'], to_vertex_colors_scatter(batch["y"], batch), batch["ranges"].cpu(), batch['bg'], resolution=config.image_size_hres).permute((0, 3, 1, 2)) 27 | rend_patches = StyleGAN2Trainer.extract_patches_from_tensor(rendered_color_gt[:, :3, :, :], 1 - rendered_color_gt[:, 3, :, :], config.num_patch_per_view, config.patch_size) 28 | rend_patches = rend_patches.reshape((config.batch_size * config.views_per_sample * config.num_patch_per_view, 3, config.patch_size, config.patch_size)) 29 | patches = dataset.cspace_convert_back(patches) 30 | batch['real_hres'] = dataset.cspace_convert_back(batch['real_hres']) 31 | rendered_color_gt = dataset.cspace_convert_back(rendered_color_gt) 32 | rend_patches = dataset.cspace_convert_back(rend_patches) 33 | save_image(torch.cat([batch['real_hres'], rendered_color_gt[:, :3, :, :]], dim=0), f"runs/patches/{batch_idx:04d}_view.png", nrow=config.batch_size, value_range=(-1, 1), normalize=True) 34 | save_image(rendered_color_gt[:, 3, :, :].unsqueeze(1), f"runs/patches/{batch_idx:04d}_mask.png", nrow=config.batch_size, value_range=(0, 1), normalize=True) 35 | save_image(torch.cat([patches, rend_patches], dim=0), f"runs/patches/{batch_idx:04d}_patch.png", nrow=config.batch_size, value_range=(-1, 1), normalize=True) 36 | 37 | 38 | if __name__ == '__main__': 39 | test_patches() 
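The patch extraction in test_patching.py goes through StyleGAN2Trainer.extract_patches_from_tensor, whose implementation is not included in this section; judging from the call sites it crops num_patch_per_view patches of size patch_size from the real images and renderings, guided by a foreground mask. A rough stand-in sketch of a random-crop helper is given below; the function name and the uniform sampling are illustrative, and the mask-guided sampling suggested by the real helper's signature is omitted:

import torch

def extract_random_patches(images, num_patches, patch_size):
    # images: (B, C, H, W); returns (B * num_patches, C, patch_size, patch_size)
    b, c, h, w = images.shape
    patches = []
    for i in range(b):
        for _ in range(num_patches):
            top = torch.randint(0, h - patch_size + 1, (1,)).item()
            left = torch.randint(0, w - patch_size + 1, (1,)).item()
            patches.append(images[i, :, top:top + patch_size, left:left + patch_size])
    return torch.stack(patches, dim=0)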
40 | -------------------------------------------------------------------------------- /test/test_semantic_pooling.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | import torch_scatter 5 | import trimesh 6 | from tqdm import tqdm 7 | import torch 8 | 9 | from dataset import GraphDataLoader, to_device 10 | import numpy as np 11 | import json 12 | from model.graph import pool 13 | 14 | mesh_directory = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/shapenet-chairs-manifold-highres/") 15 | 16 | 17 | def export_semantics_at_level(semantics, batch, num_faces): 18 | 19 | for bid in range(len(batch['name'])): 20 | selections = json.loads((mesh_directory / batch['name'][bid] / "selection.json").read_text()) 21 | mesh_file = f"quad_{num_faces:05d}_{selections[str(num_faces)]:03d}.obj" 22 | mesh = trimesh.load(mesh_directory / batch['name'][bid] / mesh_file, process=False) 23 | vertex_colors = torch.zeros(mesh.vertices.shape).to(semantics.device) 24 | torch_scatter.scatter_mean(semantics[num_faces * bid: num_faces * (bid + 1), :].unsqueeze(1).expand(-1, 4, -1).reshape(-1, 3), 25 | torch.from_numpy(mesh.faces).to(semantics.device).reshape(-1).long(), dim=0, out=vertex_colors) 26 | out_mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, vertex_colors=vertex_colors.cpu().numpy(), process=False) 27 | out_mesh.export(f"runs/semantics/{batch['name'][bid]}_{mesh_file}") 28 | 29 | 30 | @hydra.main(config_path='../config', config_name='stylegan2') 31 | def test_dataloader(config): 32 | from dataset.mesh_real_2features import FaceGraphMeshDataset 33 | hex_to_rgb = lambda x: [int(x[i:i + 2], 16) for i in (1, 3, 5)] 34 | distinct_colors = ['#ff0000', '#ffff00', '#c71585', '#00fa9a', '#0000ff', '#1e90ff', '#ffdab9'] 35 | dataset = FaceGraphMeshDataset(config) 36 | dataloader = GraphDataLoader(dataset, batch_size=4, num_workers=0) 37 | Path("runs/semantics").mkdir(exist_ok=True) 38 | for batch_idx, batch in enumerate(tqdm(dataloader)): 39 | batch = to_device(batch, torch.device("cuda:0")) 40 | x = batch['x_1'] 41 | for level in range(5): 42 | colored_semantics = torch.from_numpy(np.array([hex_to_rgb(distinct_colors[label]) for label in torch.argmax(x, dim=1).cpu().numpy().tolist()])).cuda().float() 43 | export_semantics_at_level(colored_semantics, batch, config.num_faces[level]) 44 | x = pool(x, batch['graph_data']['node_counts'][level], batch['graph_data']['pool_maps'][level], batch['graph_data']['lateral_maps'][level], pool_op='mean') 45 | break 46 | 47 | 48 | @hydra.main(config_path='../config', config_name='stylegan2') 49 | def test_semantics_from_data(config): 50 | from dataset.mesh_real_features_patch_spool import FaceGraphMeshDataset 51 | hex_to_rgb = lambda x: [int(x[i:i + 2], 16) for i in (1, 3, 5)] 52 | distinct_colors = ['#ff0000', '#ffff00', '#c71585', '#00fa9a', '#0000ff', '#1e90ff', '#ffdab9'] 53 | dataset = FaceGraphMeshDataset(config) 54 | dataloader = GraphDataLoader(dataset, batch_size=4, num_workers=0) 55 | Path("runs/semantics").mkdir(exist_ok=True) 56 | for batch_idx, batch in enumerate(tqdm(dataloader)): 57 | batch = to_device(batch, torch.device("cuda:0")) 58 | for level in range(5): 59 | x = batch['graph_data']['semantic_maps'][level] 60 | colored_semantics = torch.from_numpy(np.array([hex_to_rgb(distinct_colors[label]) for label in x.squeeze().cpu().numpy().tolist()])).cuda().float() 61 | export_semantics_at_level(colored_semantics, batch, config.num_faces[level]) 62 | break 
63 | 64 | 65 | if __name__ == '__main__': 66 | # test_dataloader() 67 | test_semantics_from_data() 68 | -------------------------------------------------------------------------------- /model/pigan/fid_evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains code for logging approximate FID scores during training. 3 | If you want to output ground-truth images from the training dataset, you can 4 | run this file as a script. 5 | """ 6 | 7 | import os 8 | import shutil 9 | import torch 10 | import copy 11 | import gc 12 | import argparse 13 | 14 | from torchvision.utils import save_image 15 | from pytorch_fid import fid_score 16 | from tqdm import tqdm 17 | 18 | from dataset import GraphDataLoader, to_device 19 | from dataset.mesh_real_pigan import SDFGridDataset 20 | from util.misc import EasyDict 21 | 22 | 23 | def output_real_images(dataloader, num_imgs, real_dir): 24 | img_counter = 0 25 | batch_size = dataloader.batch_size 26 | dataloader = iter(dataloader) 27 | for i in range(num_imgs//batch_size): 28 | real_imgs = next(dataloader)['real'] 29 | for img in real_imgs: 30 | save_image(img, os.path.join(real_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1)) 31 | img_counter += 1 32 | 33 | 34 | def setup_evaluation(config, generated_dir, target_size=128, num_imgs=5000): 35 | # Only make real images if they haven't been made yet 36 | real_dir = os.path.join('runs/EvalImages', '_real_images_' + str(target_size)) 37 | if not os.path.exists(real_dir): 38 | os.makedirs(real_dir) 39 | dataloader = GraphDataLoader(SDFGridDataset(EasyDict(config))) 40 | print('outputting real images...') 41 | output_real_images(dataloader, num_imgs, real_dir) 42 | print('...done') 43 | 44 | if generated_dir is not None: 45 | os.makedirs(generated_dir, exist_ok=True) 46 | return real_dir 47 | 48 | 49 | def output_images(render, generator, encoder, dataloader, rank, output_dir, device, num_imgs=2048): 50 | generator.eval() 51 | img_counter = rank 52 | 53 | if rank == 0: pbar = tqdm("generating images", total = num_imgs) 54 | with torch.no_grad(): 55 | while img_counter < num_imgs: 56 | torch.cuda.empty_cache() 57 | gc.collect() 58 | print('=====') 59 | for obj in gc.get_objects(): 60 | try: 61 | if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): 62 | print(type(obj), obj.size()) 63 | except: 64 | pass 65 | batch = next(dataloader) 66 | batch = to_device(batch, device) 67 | z = torch.randn((batch['real'].shape[0], generator.z_dim), device=device) 68 | faces = batch['faces'] 69 | shape = encoder(batch['sdf_x'])[4].mean((2, 3, 4)) 70 | generated_imgs = render(generator(faces, z, shape), batch, 128).cpu() 71 | 72 | for img in generated_imgs: 73 | save_image(img, os.path.join(output_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1)) 74 | if rank == 0: pbar.update(1) 75 | if rank == 0: pbar.close() 76 | 77 | 78 | def calculate_fid(dataset_name, generated_dir, target_size=256): 79 | real_dir = os.path.join('runs/EvalImages', '_real_images_' + str(target_size)) 80 | fid = fid_score.calculate_fid_given_paths([real_dir, generated_dir], 128, 'cuda', 2048) 81 | torch.cuda.empty_cache() 82 | 83 | return fid 84 | 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--dataset', type=str, default='CelebA') 89 | parser.add_argument('--img_size', type=int, default=128) 90 | parser.add_argument('--num_imgs', type=int, default=8000) 91 | 92 | opt = parser.parse_args() 93 | 
94 | real_images_dir = setup_evaluation(opt.dataset, None, target_size=opt.img_size, num_imgs=opt.num_imgs) -------------------------------------------------------------------------------- /config/stylegan2.yaml: -------------------------------------------------------------------------------- 1 | #dataset_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape/CubeTexturesForGraphQuad_gan_FC_processed 2 | #mesh_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape-model/CubeTexturesForGraphQuad 3 | #mesh_resolution: 128 4 | #dataset_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape/CubeTexturePlaneQuad32_FC_processed 5 | #mesh_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape-model/CubeTexturePlaneQuad32 6 | #mesh_resolution: 32 7 | #dataset_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape/CubeTexturesForGraphQuad32_FC_processed 8 | #mesh_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape-model/CubeTexturesForGraphQuad32 9 | #mesh_resolution: 64 10 | 11 | dataset_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/shapenet-chairs-manifold-highres-part_processed_color 12 | mesh_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/shapenet-chairs-manifold-highres 13 | pairmeta_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/metadata/pairs.json 14 | df_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/shapenet-chairs-manifold 15 | image_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/exemplars 16 | mask_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/exemplars_mask 17 | condition_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/shapenet-chairs-manifold_autoencoder 18 | stat_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/shapenet-chairs-manifold-highres_stat.pt 19 | uv_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_map 20 | silhoutte_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_mask 21 | normals_path: /cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_normals 22 | mesh_resolution: 64 23 | 24 | experiment: fast_dev 25 | seed: null 26 | save_epoch: 1 27 | sanity_steps: 1 28 | max_epoch: 2000 29 | scheduler: null 30 | val_check_percent: 1.0 31 | val_check_interval: 1 32 | resume: null 33 | 34 | num_mapping_layers: 5 35 | lr_g: 0.002 36 | lr_d: 0.00235 37 | lr_e: 0.0001 38 | lazy_gradient_penalty_interval: 16 39 | lazy_path_penalty_after: 0 40 | lazy_path_penalty_interval: 4 41 | latent_dim: 512 42 | condition_dim: 512 43 | lambda_gp: 1 44 | lambda_plp: 2 45 | lambda_patch: 1 46 | ada_start_p: 0. 
#to disable set to -1 47 | ada_target: 0.6 48 | ada_interval: 4 49 | ada_fixed: False 50 | g_channel_base: 16384 51 | d_channel_base: 16384 52 | d_channel_max: 512 53 | features: normal 54 | p_synthetic: 0.0 55 | #normal 56 | #position 57 | #position+normal 58 | #normal+laplacian 59 | #normal+ff1+ff2 60 | #normal+curvature 61 | #normal+laplacian+ff1+ff2+curvature 62 | 63 | random_bg: white 64 | colorspace: rgb 65 | 66 | image_size: 128 67 | render_size: 68 | num_patch_per_view: 4 69 | patch_size: 64 70 | image_size_hres: 512 71 | erode: True 72 | camera_noise: 0.0 73 | resume_ema: null 74 | 75 | num_faces: [24576, 6144, 1536, 384, 96, 24] 76 | #image_size: 32 77 | #num_faces: [1024, 256, 64, 16] 78 | #image_size: 64 79 | #num_faces: [1536, 384, 96, 24] 80 | num_eval_images: 2048 81 | num_vis_images: 1024 82 | num_vis_meshes: 64 83 | batch_size: 16 84 | views_per_sample: 2 85 | random_views: False 86 | num_workers: 8 87 | optimize_lights: False 88 | optimize_shininess: False 89 | 90 | conv_aggregation: max 91 | g_channel_max: 512 92 | enc_conv: face 93 | 94 | df_trunc: 0.3125 95 | df_size: 0.0625 96 | df_mean: 0.2839 97 | df_std: 0.0686 98 | 99 | shape_id: shape02344_rank02_pair183269 100 | epoch_steps: 320 101 | 102 | mbstd_on: 0 103 | 104 | create_new_resume: False 105 | 106 | wandb_main: False 107 | suffix: '' 108 | 109 | preload: False 110 | 111 | hydra: 112 | output_subdir: null # Disable saving of config files. We'll do that ourselves. 113 | run: 114 | dir: . 115 | -------------------------------------------------------------------------------- /config/stylegan2_car.yaml: -------------------------------------------------------------------------------- 1 | #dataset_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape/CubeTexturesForGraphQuad_gan_FC_processed 2 | #mesh_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape-model/CubeTexturesForGraphQuad 3 | #mesh_resolution: 128 4 | #dataset_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape/CubeTexturePlaneQuad32_FC_processed 5 | #mesh_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape-model/CubeTexturePlaneQuad32 6 | #mesh_resolution: 32 7 | #dataset_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape/CubeTexturesForGraphQuad32_FC_processed 8 | #mesh_path: /cluster/gimli/ysiddiqui/CADTextures/SingleShape-model/CubeTexturesForGraphQuad32 9 | #mesh_resolution: 64 10 | 11 | dataset_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/manifold_combined_processed 12 | mesh_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/manifold_combined 13 | # dataset_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/manifold_combined_processed 14 | # mesh_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/manifold_combined 15 | image_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/exemplars_highres 16 | mask_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/exemplars_highres_mask 17 | uv_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_map_first 18 | silhoutte_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_mask 19 | normals_path: /cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_normals 20 | 21 | experiment: fast_dev 22 | seed: null 23 | save_epoch: 1 24 | sanity_steps: 1 25 | max_epoch: 5000 26 | scheduler: null 27 | val_check_percent: 1.0 28 | val_check_interval: 1 29 | resume: null 30 | 31 | num_mapping_layers: 5 32 | lr_g: 0.002 33 | lr_d: 0.00235 34 | lr_e: 0.0001 35 | lazy_gradient_penalty_interval: 16 36 | lazy_path_penalty_after: 0 37 | lazy_path_penalty_interval: 4 38 | latent_dim: 512 39 | condition_dim: 
512 40 | lambda_gp: 1 41 | lambda_plp: 2 42 | lambda_patch: 1 43 | ada_start_p: 0. #to disable set to -1 44 | ada_target: 0.6 45 | ada_interval: 4 46 | ada_fixed: False 47 | g_channel_base: 16384 48 | d_channel_base: 16384 49 | features: normal 50 | p_synthetic: 0.0 51 | 52 | random_bg: grayscale 53 | colorspace: rgb 54 | 55 | image_size: 128 56 | render_size: null 57 | 58 | #progressive_switch: [10e3, 70e3, 180e3] 59 | #alpha_switch: [5e3, 10e3, 10e3] 60 | 61 | progressive_switch: [10e3, 40e3, 80e3, 180e3] 62 | alpha_switch: [5e3, 10e3, 10e3, 10e3] 63 | progressive_start_res: 6 64 | 65 | num_patch_per_view: 4 66 | patch_size: 64 67 | image_size_hres: 512 68 | erode: True 69 | camera_noise: 0.0 70 | 71 | num_faces: [24576, 6144, 1536, 384, 96, 24] 72 | #image_size: 32 73 | #num_faces: [1024, 256, 64, 16] 74 | #image_size: 64 75 | #num_faces: [1536, 384, 96, 24] 76 | num_eval_images: 2048 77 | num_vis_images: 1024 78 | num_vis_meshes: 64 79 | batch_size: 16 80 | views_per_sample: 2 81 | random_views: False 82 | num_workers: 8 83 | optimize_lights: False 84 | optimize_shininess: False 85 | 86 | conv_aggregation: max 87 | g_channel_max: 512 88 | enc_conv: face 89 | 90 | df_trunc: 0.3125 91 | df_size: 0.0625 92 | df_mean: 0.2839 93 | df_std: 0.0686 94 | 95 | shape_id: shape02344_rank02_pair183269 96 | epoch_steps: 320 97 | 98 | mbstd_on: 0 99 | 100 | create_new_resume: False 101 | 102 | wandb_main: False 103 | suffix: '' 104 | resume_ema: null 105 | 106 | preload: False 107 | 108 | prog_resume_ema: "/cluster_HDD/gondor/ysiddiqui/stylegan2-ada-3d-texture/runs/23020923_StyleGAN23D-CompCars_bigdtwin-clip_fg3bgg-lrd1g14-v8m8-1K_128/checkpoints/ema_000027474.pth" 109 | prog_resume: "/cluster_HDD/gondor/ysiddiqui/stylegan2-ada-3d-texture/runs/23020923_StyleGAN23D-CompCars_bigdtwin-clip_fg3bgg-lrd1g14-v8m8-1K_128/checkpoints/_epoch=174.ckpt" 110 | 111 | hydra: 112 | output_subdir: null # Disable saving of config files. We'll do that ourselves. 113 | run: 114 | dir: . 
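Both YAML files are consumed through Hydra, as in the @hydra.main(config_path='../config', config_name='stylegan2') decorators used by the tests above, and any key can be overridden from the command line. A minimal sketch, assuming a hypothetical show_config.py placed at the repository root so that config_path='config' resolves:

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path='config', config_name='stylegan2_car')
def show_config(config: DictConfig):
    # Hydra composes the YAML into a DictConfig; FilesystemLogger writes the same
    # structure to runs/<experiment>/config.yaml via OmegaConf.to_yaml.
    print(OmegaConf.to_yaml(config))
    print(config.batch_size, config.image_size, config.num_faces)

if __name__ == '__main__':
    show_config()

Invoked, for example, as: python show_config.py batch_size=4 views_per_sample=1 experiment=debug_run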
115 | -------------------------------------------------------------------------------- /util/df_metrics.py: -------------------------------------------------------------------------------- 1 | from torchmetrics.metric import Metric 2 | import torch 3 | from external.ChamferDistancePytorch.chamfer3D import dist_chamfer_3D 4 | 5 | 6 | class IoU(Metric): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.add_state("iou_sum", default=torch.tensor(0).float(), dist_reduce_fx="sum") 11 | self.add_state("total", default=torch.tensor(0).float(), dist_reduce_fx="sum") 12 | 13 | # noinspection PyMethodOverriding 14 | def update(self, preds: torch.Tensor, target: torch.Tensor): 15 | intersection = (preds & target).sum(-1).sum(-1).sum(-1).squeeze(1) 16 | union = (preds | target).sum(-1).sum(-1).sum(-1).squeeze(1) 17 | valid_mask = union > 0 18 | intersection = intersection[valid_mask] 19 | union = union[valid_mask] 20 | if union.sum() > 0: 21 | self.iou_sum += (intersection / (union + 1e-5)).sum() 22 | self.total += intersection.shape[0] 23 | 24 | def compute(self): 25 | return self.iou_sum.float() / self.total 26 | 27 | 28 | class Chamfer3D(Metric): 29 | 30 | def __init__(self, *args, **kwargs): 31 | super().__init__(*args, **kwargs) 32 | self.cham_loss = dist_chamfer_3D.chamfer_3DDist() 33 | self.add_state("cd_sum", default=torch.tensor(0).float(), dist_reduce_fx="sum") 34 | self.add_state("total", default=torch.tensor(0).float(), dist_reduce_fx="sum") 35 | 36 | # noinspection PyMethodOverriding 37 | def update(self, preds: torch.Tensor, target: torch.Tensor): 38 | cd = torch.tensor(0).float().to(device=preds.device) 39 | preds = preds.squeeze(1) 40 | target = target.squeeze(1) 41 | valid_chamf = 0 42 | for ip in range(preds.shape[0]): 43 | points_pred = torch.nonzero(preds[ip], as_tuple=False).unsqueeze(0).float() 44 | points_target = torch.nonzero(target[ip], as_tuple=False).unsqueeze(0).float() 45 | if points_pred.shape[0] > 0 and points_target.shape[0] > 0: 46 | dist1, dist2, _, _ = self.cham_loss(points_target, points_pred) 47 | cd_ = (torch.mean(dist1)) + (torch.mean(dist2)) 48 | if not torch.isnan(cd_): 49 | cd += cd_ 50 | valid_chamf += 1 51 | self.cd_sum += cd 52 | self.total += valid_chamf 53 | 54 | def compute(self): 55 | return self.cd_sum.float() / self.total 56 | 57 | 58 | class Precision(Metric): 59 | 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(*args, **kwargs) 62 | self.add_state("precision_sum", default=torch.tensor(0).float(), dist_reduce_fx="sum") 63 | self.add_state("total", default=torch.tensor(0).float(), dist_reduce_fx="sum") 64 | 65 | # noinspection PyMethodOverriding 66 | def update(self, preds: torch.Tensor, target: torch.Tensor): 67 | intersection = (preds & target).sum(-1).sum(-1).sum(-1).squeeze(1) 68 | self.precision_sum += (intersection / (preds.sum(-1).sum(-1).sum(-1).squeeze(1) + 1e-5)).sum() 69 | self.total += intersection.shape[0] 70 | 71 | def compute(self): 72 | return self.precision_sum.float() / self.total 73 | 74 | 75 | class Recall(Metric): 76 | 77 | def __init__(self, *args, **kwargs): 78 | super().__init__(*args, **kwargs) 79 | self.add_state("recall_sum", default=torch.tensor(0).float(), dist_reduce_fx="sum") 80 | self.add_state("total", default=torch.tensor(0).float(), dist_reduce_fx="sum") 81 | 82 | # noinspection PyMethodOverriding 83 | def update(self, preds: torch.Tensor, target: torch.Tensor): 84 | intersection = (preds & target).sum(-1).sum(-1).sum(-1).squeeze(1) 85 | self.recall_sum += 
(intersection / (target.sum(-1).sum(-1).sum(-1).squeeze(1) + 1e-5)).sum() 86 | self.total += intersection.shape[0] 87 | 88 | def compute(self): 89 | return self.recall_sum.float() / self.total 90 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import sys 4 | import traceback 5 | from pathlib import Path 6 | from random import randint 7 | import datetime 8 | 9 | import torch 10 | import wandb 11 | from pytorch_lightning import seed_everything, Trainer 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | from pytorch_lightning.loggers import WandbLogger 14 | from pytorch_lightning.plugins import DDPPlugin 15 | 16 | from util.filesystem_logger import FilesystemLogger 17 | 18 | 19 | def print_traceback_handler(sig, frame): 20 | print(f'Received signal {sig}') 21 | bt = ''.join(traceback.format_stack()) 22 | print(f'Requested stack trace:\n{bt}') 23 | 24 | 25 | def quit_handler(sig, frame): 26 | print(f'Received signal {sig}, quitting.') 27 | sys.exit(1) 28 | 29 | 30 | def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler): 31 | print(f'Setting signal {sig} handler {handler}') 32 | signal.signal(sig, handler) 33 | 34 | 35 | def register_quit_signal_handlers(sig=signal.SIGUSR2, handler=quit_handler): 36 | print(f'Setting signal {sig} handler {handler}') 37 | signal.signal(sig, handler) 38 | 39 | 40 | def generate_experiment_name(name, config): 41 | if config.resume is not None and not config.create_new_resume: 42 | experiment = Path(config.resume).parents[1].name 43 | os.environ['experiment'] = experiment 44 | elif not os.environ.get('experiment'): 45 | experiment = f"{datetime.datetime.now().strftime('%d%m%H%M')}_{name}_{config.experiment}" 46 | os.environ['experiment'] = experiment 47 | else: 48 | experiment = os.environ['experiment'] 49 | return experiment 50 | 51 | 52 | def create_trainer(name, config): 53 | if not config.wandb_main and config.suffix == '': 54 | config.suffix = '-dev' 55 | config.experiment = generate_experiment_name(name, config) 56 | if config.val_check_interval > 1: 57 | config.val_check_interval = int(config.val_check_interval) 58 | if config.seed is None: 59 | config.seed = randint(0, 999) 60 | 61 | seed_everything(config.seed) 62 | 63 | register_debug_signal_handlers() 64 | register_quit_signal_handlers() 65 | 66 | # noinspection PyUnusedLocal 67 | filesystem_logger = FilesystemLogger(config) 68 | logger = WandbLogger(project=f'{name}{config.suffix}', 69 | name=config.experiment, 70 | id=config.experiment, 71 | settings=wandb.Settings(start_method='thread')) 72 | checkpoint_callback = ModelCheckpoint(dirpath=(Path("runs") / config.experiment / "checkpoints"), 73 | filename='_{epoch}', 74 | save_top_k=-1, 75 | verbose=False, 76 | every_n_epochs=config.save_epoch) 77 | 78 | gpu_count = torch.cuda.device_count() 79 | 80 | if gpu_count > 1: 81 | 82 | # config.val_check_interval *= gpu_count 83 | trainer = Trainer(gpus=-1, 84 | accelerator='ddp', 85 | plugins=DDPPlugin(find_unused_parameters=True), 86 | num_sanity_val_steps=config.sanity_steps, 87 | max_epochs=config.max_epoch, 88 | limit_val_batches=config.val_check_percent, 89 | callbacks=[checkpoint_callback], 90 | val_check_interval=float(min(config.val_check_interval, 1)), 91 | check_val_every_n_epoch=max(1, config.val_check_interval), 92 | resume_from_checkpoint=config.resume, 93 | logger=logger, 94 | 
benchmark=True) 95 | else: 96 | trainer = Trainer(gpus=[0], 97 | num_sanity_val_steps=config.sanity_steps, 98 | max_epochs=config.max_epoch, 99 | limit_val_batches=config.val_check_percent, 100 | callbacks=[checkpoint_callback], 101 | val_check_interval=float(min(config.val_check_interval, 1)), 102 | check_val_every_n_epoch=max(1, config.val_check_interval), 103 | resume_from_checkpoint=config.resume, 104 | logger=logger, 105 | benchmark=True) 106 | return trainer 107 | -------------------------------------------------------------------------------- /test/test_specular.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | from torchvision.utils import save_image 5 | from tqdm import tqdm 6 | import torch 7 | import numpy as np 8 | 9 | from dataset.mesh_real_features import FaceGraphMeshDataset 10 | from dataset import GraphDataLoader, to_device, to_vertex_colors_scatter, to_vertex_shininess_scatter 11 | from model.differentiable_renderer_light import DifferentiableRenderer 12 | from trainer.train_stylegan_real_feature_light import get_light_directions, sample_light_directions 13 | 14 | 15 | @hydra.main(config_path='../config', config_name='stylegan2') 16 | def test_specular(config): 17 | dataset = FaceGraphMeshDataset(config) 18 | dataloader = GraphDataLoader(dataset, batch_size=4, num_workers=0) 19 | render_helper = DifferentiableRenderer(config.image_size, 'standard', config.colorspace).cuda() 20 | Path("runs/images_light").mkdir(exist_ok=True) 21 | ctr = 0 22 | for batch_idx, batch in enumerate(tqdm(dataloader)): 23 | batch = to_device(batch, torch.device("cuda:0")) 24 | lightdir = torch.randn(size=[3], device=batch['vertices'].device) 25 | lightdir /= lightdir.norm() + 1e-8 26 | rendered_color_gt = render_helper.render(batch['vertices'], batch['indices'], to_vertex_colors_scatter(batch["y"][:, :3], batch), batch["ranges"].cpu(), batch['bg']).permute((0, 3, 1, 2)) 27 | rendered_specular = render_helper.render_light(lightdir, batch['view_vector'], batch['vertices'], batch['indices'], to_vertex_colors_scatter(batch["x"], batch)[:, :3], batch["ranges"].cpu()).permute((0, 3, 1, 2)).expand(-1, 3, -1, 28 | -1) 29 | rendered_specular = torch.max(torch.zeros_like(rendered_specular), rendered_specular) ** 30 30 | batch['real'] = dataset.cspace_convert_back(batch['real']) 31 | rendered_color_gt = dataset.cspace_convert_back(rendered_color_gt) 32 | save_image(torch.cat([rendered_color_gt, rendered_specular]), f"runs/images_light/test_view_{batch_idx:04d}.png", nrow=4, value_range=(-1, 1), normalize=True) 33 | ctr += 1 34 | if ctr == 5: 35 | break 36 | 37 | 38 | @hydra.main(config_path='../config', config_name='stylegan2') 39 | def test_specular_bounds(config): 40 | dataset = FaceGraphMeshDataset(config) 41 | dataloader = GraphDataLoader(dataset, batch_size=4, num_workers=0) 42 | render_helper = DifferentiableRenderer(config.image_size, 'bounds', config.colorspace).cuda() 43 | Path("runs/images_light").mkdir(exist_ok=True) 44 | ctr = 0 45 | shininess = 28 46 | for batch_idx, batch in enumerate(tqdm(dataloader)): 47 | batch = to_device(batch, torch.device("cuda:0")) 48 | lightdirs = get_light_directions(3, torch.device("cuda:0")) 49 | # lightdirs = sample_light_directions(lightdirs) 50 | rendered_color_diffuse, rendered_color_spec = render_helper.render(batch['vertices'], batch['indices'], 51 | to_vertex_colors_scatter(batch["y"][:, :3], batch), 52 | batch['normals'], 53 | to_vertex_shininess_scatter(batch["y"][:, 
3:4], batch), 54 | lightdirs, batch['view_vector'], shininess, 55 | batch["ranges"].cpu(), batch['bg']) 56 | rendered_color_diffuse = rendered_color_diffuse.permute((0, 3, 1, 2)) 57 | rendered_color_spec = rendered_color_spec.permute((0, 3, 1, 2)).expand(-1, 3, -1, -1) 58 | rendered_color = rendered_color_diffuse + rendered_color_spec 59 | batch['real'] = dataset.cspace_convert_back(batch['real']) 60 | rendered_color = dataset.cspace_convert_back(rendered_color) 61 | save_image(torch.cat([rendered_color, rendered_color_diffuse, rendered_color_spec]), f"runs/images_light/test_view_{batch_idx:04d}.png", nrow=4, value_range=(-1, 1), normalize=True) 62 | ctr += 1 63 | if ctr == 10: 64 | break 65 | 66 | 67 | if __name__ == '__main__': 68 | # ld = get_light_directions(3, torch.device("cpu")) 69 | # all_lights = [] 70 | # for _sid in range(640): 71 | # ld = sample_light_directions(ld) 72 | # all_lights.append(ld) 73 | # ld = torch.cat(all_lights, dim=0) 74 | # Path(f"lights.obj").write_text("\n".join([f"v {ld[i][0]} {ld[i][1]} {ld[i][2]}\n" for i in range(len(ld))])) 75 | test_specular_bounds() 76 | -------------------------------------------------------------------------------- /model/raycast_rgbd/raycast_rgbd_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void raycast_rgbd_cuda_forward( 8 | torch::Tensor sparse_mapping, 9 | torch::Tensor locs, 10 | torch::Tensor vals_sdf, 11 | torch::Tensor vals_color, 12 | torch::Tensor vals_normals, 13 | torch::Tensor viewMatrixInv, 14 | torch::Tensor imageColor, 15 | torch::Tensor imageDepth, 16 | torch::Tensor imageNormal, 17 | torch::Tensor mapping2dto3d, 18 | torch::Tensor mapping3dto2d_num, 19 | torch::Tensor intrinsicParams, 20 | torch::Tensor opts); 21 | 22 | void raycast_rgbd_cuda_backward( 23 | torch::Tensor grad_color, 24 | torch::Tensor grad_depth, 25 | torch::Tensor grad_normal, 26 | torch::Tensor sparse_mapping, 27 | torch::Tensor mapping3dto2d, 28 | torch::Tensor mapping3dto2d_num, 29 | torch::Tensor dims, 30 | torch::Tensor d_color, 31 | torch::Tensor d_depth, 32 | torch::Tensor d_normals); 33 | 34 | void raycast_occ_cuda_forward( 35 | at::Tensor occ3d, 36 | at::Tensor occ2d, 37 | at::Tensor viewMatrixInv, 38 | at::Tensor intrinsicParams, 39 | at::Tensor opts); 40 | 41 | void construct_dense_sparse_mapping_cuda( 42 | torch::Tensor locs, 43 | torch::Tensor sparse_mapping); 44 | 45 | 46 | // C++ interface 47 | 48 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
49 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 50 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 51 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 52 | 53 | void raycast_color_forward( 54 | torch::Tensor sparse_mapping, 55 | torch::Tensor locs, 56 | torch::Tensor vals_sdf, 57 | torch::Tensor vals_color, 58 | torch::Tensor vals_normals, 59 | torch::Tensor viewMatrixInv, 60 | torch::Tensor imageColor, 61 | torch::Tensor imageDepth, 62 | torch::Tensor imageNormal, 63 | torch::Tensor mapping3dto2d, 64 | torch::Tensor mapping3dto2d_num, 65 | torch::Tensor intrinsicParams, 66 | torch::Tensor opts) { 67 | CHECK_INPUT(sparse_mapping); 68 | CHECK_INPUT(locs); 69 | CHECK_INPUT(vals_sdf); 70 | CHECK_INPUT(vals_color); 71 | CHECK_INPUT(vals_normals); 72 | CHECK_INPUT(viewMatrixInv); 73 | CHECK_INPUT(imageColor); 74 | CHECK_INPUT(imageDepth); 75 | CHECK_INPUT(imageNormal); 76 | CHECK_INPUT(mapping3dto2d); 77 | CHECK_INPUT(mapping3dto2d_num); 78 | CHECK_INPUT(intrinsicParams); 79 | 80 | return raycast_rgbd_cuda_forward(sparse_mapping, locs, vals_sdf, vals_color, vals_normals, viewMatrixInv, imageColor, imageDepth, imageNormal, mapping3dto2d, mapping3dto2d_num, intrinsicParams, opts); 81 | } 82 | 83 | void construct_dense_sparse_mapping( 84 | torch::Tensor locs, 85 | torch::Tensor sparse_mapping) { 86 | CHECK_INPUT(locs); 87 | CHECK_INPUT(sparse_mapping); 88 | 89 | return construct_dense_sparse_mapping_cuda(locs, sparse_mapping); 90 | } 91 | 92 | void raycast_color_backward( 93 | torch::Tensor grad_color, 94 | torch::Tensor grad_depth, 95 | torch::Tensor grad_normal, 96 | torch::Tensor sparse_mapping, 97 | torch::Tensor mapping3dto2d, 98 | torch::Tensor mapping3dto2d_num, 99 | torch::Tensor dims, 100 | torch::Tensor d_color, 101 | torch::Tensor d_depth, 102 | torch::Tensor d_normals) { 103 | CHECK_INPUT(grad_color); 104 | CHECK_INPUT(grad_depth); 105 | CHECK_INPUT(grad_normal); 106 | CHECK_INPUT(sparse_mapping); 107 | CHECK_INPUT(mapping3dto2d); 108 | CHECK_INPUT(mapping3dto2d_num); 109 | CHECK_INPUT(d_color); 110 | CHECK_INPUT(d_depth); 111 | CHECK_INPUT(d_normals); 112 | 113 | return raycast_rgbd_cuda_backward( 114 | grad_color, 115 | grad_depth, 116 | grad_normal, 117 | sparse_mapping, 118 | mapping3dto2d, 119 | mapping3dto2d_num, 120 | dims, 121 | d_color, 122 | d_depth, 123 | d_normals); 124 | } 125 | 126 | void raycast_occ_forward( 127 | at::Tensor occ3d, 128 | at::Tensor occ2d, 129 | at::Tensor viewMatrixInv, 130 | at::Tensor intrinsicParams, 131 | at::Tensor opts) { 132 | CHECK_INPUT(occ3d); 133 | CHECK_INPUT(occ2d); 134 | CHECK_INPUT(viewMatrixInv); 135 | CHECK_INPUT(intrinsicParams); 136 | return raycast_occ_cuda_forward(occ3d, occ2d, viewMatrixInv, intrinsicParams, opts); 137 | } 138 | 139 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 140 | m.def("forward", &raycast_color_forward, "raycast_color forward (CUDA)"); 141 | m.def("backward", &raycast_color_backward, "raycast_color backward (CUDA)"); 142 | m.def("construct_dense_sparse_mapping", &construct_dense_sparse_mapping, "construct mapping from dense to sparse (CUDA)"); 143 | m.def("raycast_occ", &raycast_occ_forward, "raycast_color 3d occupancy grid (CUDA)"); 144 | } 145 | -------------------------------------------------------------------------------- /dataset/distance_field.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from pathlib import Path 4 | import marching_cubes 
as mc 5 | import trimesh 6 | import hydra 7 | 8 | 9 | class DistanceFieldDataset(torch.utils.data.Dataset): 10 | 11 | def __init__(self, config, interval, overfit=False): 12 | self.items = [x / "032.npy" for x in sorted(list(Path(config.df_path).iterdir()))][interval[0]: interval[1]] 13 | if overfit: 14 | self.items = self.items * 160 15 | self.trunc = config.df_trunc 16 | self.vox_size = config.df_size 17 | self.mean = config.df_mean 18 | self.std = config.df_std 19 | self.weight_occupied = 8 20 | 21 | def __len__(self): 22 | return len(self.items) 23 | 24 | def __getitem__(self, idx): 25 | selected_item = self.items[idx] 26 | return { 27 | 'df': torch.from_numpy(np.load(str(selected_item))).unsqueeze(0), 28 | 'name': selected_item.parent.name.split('.')[0] 29 | } 30 | 31 | def augment_batch_data(self, batch): 32 | weights = torch.ones_like(batch['df']) * (1 + (batch['df'] < self.trunc).float() * (self.weight_occupied - 1)).float() 33 | empty = (batch['df'] >= self.trunc) 34 | batch['weights'] = weights 35 | batch['empty'] = empty 36 | 37 | def normalize(self, df): 38 | return (df - self.mean) / self.std 39 | 40 | def visualize_as_mesh(self, grid, output_path): 41 | vertices, triangles = mc.marching_cubes(grid, self.vox_size) 42 | mc.export_obj(vertices, triangles, output_path) 43 | 44 | def visualize_sdf_as_voxels(self, sdf, output_path): 45 | from util.misc import to_point_list 46 | point_list = to_point_list(sdf <= self.vox_size) 47 | if point_list.shape[0] > 0: 48 | base_mesh = trimesh.voxel.ops.multibox(centers=point_list, pitch=1) 49 | base_mesh.export(output_path) 50 | 51 | @staticmethod 52 | def visualize_occupancy_as_voxels(occupancy, output_path): 53 | from util.misc import to_point_list 54 | point_list = to_point_list(occupancy) 55 | if point_list.shape[0] > 0: 56 | base_mesh = trimesh.voxel.ops.multibox(centers=point_list, pitch=1) 57 | base_mesh.export(output_path) 58 | 59 | @staticmethod 60 | def visualize_float_grid(grid, ignore_val, minval, maxval, output_path): 61 | from matplotlib import cm 62 | jetmap = cm.get_cmap('jet') 63 | norm_grid = (grid - minval) / (maxval - minval) 64 | f = open(output_path, "w") 65 | for x in range(grid.shape[0]): 66 | for y in range(grid.shape[1]): 67 | for z in range(grid.shape[2]): 68 | if grid[x, y, z] > ignore_val: 69 | c = (np.array(jetmap(norm_grid[x, y, z])) * 255).astype(np.uint8) 70 | f.write('v %f %f %f %f %f %f\n' % (x + 0.5, y + 0.5, z + 0.5, c[0], c[1], c[2])) 71 | f.close() 72 | 73 | 74 | @hydra.main(config_path='../config', config_name='stylegan2') 75 | def test_distancefield_dataset(config): 76 | from tqdm import tqdm 77 | from dataset import to_device 78 | dataset = DistanceFieldDataset(config=config, limit=4) 79 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=False, pin_memory=True, num_workers=0) 80 | for batch_idx, batch in enumerate(tqdm(dataloader)): 81 | batch = to_device(batch, torch.device("cuda:0")) 82 | dataset.augment_batch_data(batch) 83 | print(batch['name'], batch['df'].shape, batch['df'].max(), batch['df'].min()) 84 | for idx in range(batch['df'].shape[0]): 85 | dataset.visualize_as_mesh(batch['df'][idx].squeeze(0).cpu().numpy(), f"{batch['name'][idx]}.obj") 86 | dataset.visualize_occupancy_as_voxels(1 - batch["empty"][idx].squeeze(0).cpu().numpy(), f"{batch['name'][idx]}_nempty.obj") 87 | dataset.visualize_float_grid(batch['weights'][idx].squeeze(0).cpu().numpy(), 0.0, 0.0, 2.0, f"{batch['name'][idx]}_weight.obj") 88 | 89 | 90 | @hydra.main(config_path='../config', 
config_name='stylegan2') 91 | def get_dataset_mean_std(config): 92 | from tqdm import tqdm 93 | import math 94 | dataset = DistanceFieldDataset(config=config) 95 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=False, pin_memory=True, num_workers=0) 96 | ctr, mean, var = 0., 0., 0. 97 | for batch_idx, batch in enumerate(tqdm(dataloader)): 98 | mean += batch['df'].mean().item() 99 | var += batch['df'].std().item() ** 2 100 | ctr += 1 101 | print(mean / ctr, math.sqrt((var / ctr))) 102 | 103 | 104 | if __name__ == "__main__": 105 | get_dataset_mean_std() 106 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from ballpark import business 5 | import numpy as np 6 | from cleanfid import fid 7 | from pathlib import Path 8 | 9 | 10 | def print_model_parameter_count(model): 11 | count = sum(p.numel() for p in model.parameters() if p.requires_grad) 12 | print(f"Number of parameters in {type(model).__name__}: {business(count, precision=3, prefix=True)}") 13 | 14 | 15 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 16 | assert isinstance(module, torch.nn.Module) 17 | assert not isinstance(module, torch.jit.ScriptModule) 18 | assert isinstance(inputs, (tuple, list)) 19 | 20 | # Register hooks. 21 | entries = [] 22 | nesting = [0] 23 | 24 | def pre_hook(_mod, _inputs): 25 | nesting[0] += 1 26 | 27 | def post_hook(mod, _inputs, outputs): 28 | nesting[0] -= 1 29 | if nesting[0] <= max_nesting: 30 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 31 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 32 | entries.append(EasyDict(mod=mod, outputs=outputs)) 33 | 34 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 35 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 36 | 37 | # Run module. 38 | outputs = module(*inputs) 39 | for hook in hooks: 40 | hook.remove() 41 | 42 | # Identify unique outputs, parameters, and buffers. 43 | tensors_seen = set() 44 | for e in entries: 45 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen if t.requires_grad] 46 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 47 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 48 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 49 | 50 | # Filter out redundant entries. 51 | if skip_redundant: 52 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 53 | 54 | # Construct table. 
55 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 56 | rows += [['---'] * len(rows[0])] 57 | param_total = 0 58 | buffer_total = 0 59 | submodule_names = {mod: name for name, mod in module.named_modules()} 60 | for e in entries: 61 | name = '' if e.mod is module else submodule_names[e.mod] 62 | param_size = sum(t.numel() for t in e.unique_params) 63 | buffer_size = sum(t.numel() for t in e.unique_buffers) 64 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 65 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 66 | rows += [[ 67 | name + (':0' if len(e.outputs) >= 2 else ''), 68 | str(param_size) if param_size else '-', 69 | str(buffer_size) if buffer_size else '-', 70 | (output_shapes + ['-'])[0], 71 | (output_dtypes + ['-'])[0], 72 | ]] 73 | for idx in range(1, len(e.outputs)): 74 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 75 | param_total += param_size 76 | buffer_total += buffer_size 77 | rows += [['---'] * len(rows[0])] 78 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 79 | 80 | # Print table. 81 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 82 | print() 83 | for row in rows: 84 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 85 | print() 86 | return outputs 87 | 88 | 89 | class EasyDict(dict): 90 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 91 | 92 | def __getattr__(self, name): 93 | try: 94 | return self[name] 95 | except KeyError: 96 | raise AttributeError(name) 97 | 98 | def __setattr__(self, name, value): 99 | self[name] = value 100 | 101 | def __delattr__(self, name): 102 | del self[name] 103 | 104 | 105 | def to_point_list(s): 106 | return np.concatenate([c[:, np.newaxis] for c in np.where(s)], axis=1) 107 | 108 | 109 | def get_parameters_from_state_dict(state_dict, filter_key): 110 | new_state_dict = OrderedDict() 111 | for k in state_dict: 112 | if k.startswith(filter_key): 113 | new_state_dict[k.replace(filter_key + '.', '')] = state_dict[k] 114 | return new_state_dict 115 | 116 | 117 | def compute_fid(output_dir_fake, output_dir_real, filepath, device): 118 | fid_score = fid.compute_fid(str(output_dir_real), str(output_dir_fake), device=device, dataset_res=256, num_workers=0) 119 | print(f'FID: {fid_score:.3f}') 120 | kid_score = fid.compute_kid(str(output_dir_real), str(output_dir_fake), device=device, dataset_res=256, num_workers=0) 121 | print(f'KID: {kid_score:.3f}') 122 | Path(filepath).write_text(f"fid = {fid_score}\nkid = {kid_score}") 123 | 124 | -------------------------------------------------------------------------------- /dataset/mesh_uniform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import trimesh 9 | from torch.utils.data.dataloader import default_collate 10 | from scipy.spatial.transform import Rotation 11 | 12 | from dataset import get_default_perspective_cam 13 | from model.differentiable_renderer import intrinsic_to_projection 14 | from util.misc import EasyDict 15 | 16 | 17 | def generate_random_camera(loc): 18 | y_angles = list(range(90, 270, 5)) 19 | weight_fn = lambda x: 0.5 + 0.125 * math.cos(2 * (math.pi / 45) * x) 20 | weights_y = [weight_fn(y) for y in y_angles] 21 | y_angle = random.choices(y_angles, weights_y, k=1)[0] 22 | 
camera_pose = np.eye(4) 23 | camera_pose[:3, :3] = Rotation.from_euler('y', y_angle, degrees=True).as_matrix() @ Rotation.from_euler('z', 180, degrees=True).as_matrix() @ Rotation.from_euler('x', 180, degrees=True).as_matrix() 24 | camera_translation = camera_pose[:3, :3] @ np.array([0, 0, 1.250]) + loc 25 | camera_pose[:3, 3] = camera_translation 26 | return camera_pose 27 | 28 | 29 | class FaceGraphMeshDataset(torch.utils.data.Dataset): 30 | 31 | def __init__(self, config, limit_dataset_size=None): 32 | self.dataset_directory = Path(config.dataset_path) 33 | self.mesh_directory = Path(config.mesh_path) 34 | self.items = list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size] 35 | self.target_name = "model_normalized.obj" 36 | self.projection_matrix = intrinsic_to_projection(get_default_perspective_cam()).float() 37 | self.generate_camera = generate_random_camera 38 | self.views_per_sample = config.views_per_sample 39 | 40 | def __len__(self): 41 | return len(self.items) 42 | 43 | def __getitem__(self, idx): 44 | selected_item = self.items[idx] 45 | pt_arxiv = torch.load(os.path.join(self.dataset_directory, f'{selected_item}.pt')) 46 | edge_index = pt_arxiv['conv_data'][0][0].long() 47 | num_sub_vertices = [pt_arxiv['conv_data'][i][0].shape[0] for i in range(1, len(pt_arxiv['conv_data']))] 48 | pad_sizes = [pt_arxiv['conv_data'][i][2].shape[0] for i in range(len(pt_arxiv['conv_data']))] 49 | sub_edges = [pt_arxiv['conv_data'][i][0].long() for i in range(1, len(pt_arxiv['conv_data']))] 50 | pool_maps = pt_arxiv['pool_locations'] 51 | is_pad = [pt_arxiv['conv_data'][i][4].bool() for i in range(len(pt_arxiv['conv_data']))] 52 | level_masks = [torch.zeros(pt_arxiv['conv_data'][i][0].shape[0]).long() for i in range(len(pt_arxiv['conv_data']))] 53 | 54 | # noinspection PyTypeChecker 55 | mesh = trimesh.load(self.mesh_directory / selected_item / self.target_name, process=False) 56 | mvp = torch.stack([torch.matmul(self.projection_matrix, torch.from_numpy(np.linalg.inv(self.generate_camera((mesh.bounds[0] + mesh.bounds[1]) / 2))).float()) 57 | for _ in range(self.views_per_sample)], dim=0) 58 | vertices = torch.from_numpy(mesh.vertices).float() 59 | indices = torch.from_numpy(mesh.faces).int() 60 | tri_indices = torch.cat([indices[:, [0, 1, 2]], indices[:, [0, 2, 3]]], 0) 61 | vctr = torch.tensor(list(range(vertices.shape[0]))).long() 62 | r, g, b = random.randint(0, 255) / 255. - 0.5, random.randint(0, 255) / 255. - 0.5, random.randint(0, 255) / 255. 
- 0.5 63 | pt_arxiv['target_colors'][:, 0] = r 64 | pt_arxiv['target_colors'][:, 1] = g 65 | pt_arxiv['target_colors'][:, 2] = b 66 | return { 67 | "name": selected_item, 68 | "y": pt_arxiv['target_colors'].float() * 2, 69 | "vertex_ctr": vctr, 70 | "vertices": vertices, 71 | "indices_quad": indices, 72 | "mvp": mvp, 73 | "indices": tri_indices, 74 | "ranges": torch.tensor([0, tri_indices.shape[0]]).int(), 75 | "graph_data": self.get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks) 76 | } 77 | 78 | @staticmethod 79 | def get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks): 80 | return EasyDict({ 81 | 'face_neighborhood': edge_index, 82 | 'sub_neighborhoods': sub_edges, 83 | 'pads': pad_sizes, 84 | 'node_counts': num_sub_vertices, 85 | 'pool_maps': pool_maps, 86 | 'is_pad': is_pad, 87 | 'level_masks': level_masks 88 | }) 89 | 90 | def visualize_graph_with_predictions(self, name, prediction, output_dir, output_suffix): 91 | output_dir = Path(output_dir) 92 | output_dir.mkdir(exist_ok=True, parents=True) 93 | # noinspection PyTypeChecker 94 | mesh = trimesh.load(Path(self.raw_dir, name) / self.target_name, force='mesh', process=False) 95 | mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, face_colors=prediction + 0.5, process=False) 96 | mesh.export(output_dir / f"{name}_{output_suffix}.obj") 97 | 98 | @staticmethod 99 | def batch_mask(t, graph_data, idx, level=0): 100 | return t[graph_data['level_masks'][level] == idx] 101 | -------------------------------------------------------------------------------- /model/uv/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | @torch.jit.script 6 | def clamp_gain(x: torch.Tensor, g: float, c: float): 7 | return torch.clamp(x * g, -c, c) 8 | 9 | 10 | def normalize_2nd_moment(x, dim=1, eps=1e-8): 11 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 12 | 13 | 14 | def identity(x): 15 | return x 16 | 17 | 18 | def leaky_relu_0_2(x): 19 | return torch.nn.functional.leaky_relu(x, 0.2) 20 | 21 | 22 | activation_funcs = { 23 | "linear": { 24 | "fn": identity, 25 | "def_gain": 1 26 | }, 27 | "lrelu": { 28 | "fn": leaky_relu_0_2, 29 | "def_gain": np.sqrt(2) 30 | } 31 | } 32 | 33 | 34 | class FullyConnectedLayer(torch.nn.Module): 35 | 36 | def __init__(self, in_features, out_features, bias=True, activation='linear', lr_multiplier=1, bias_init=0): 37 | super().__init__() 38 | self.activation = activation_funcs[activation]['fn'] 39 | self.activation_gain = activation_funcs[activation]['def_gain'] 40 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) 41 | self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None 42 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 43 | self.bias_gain = lr_multiplier 44 | 45 | def forward(self, x): 46 | w = self.weight * self.weight_gain 47 | b = self.bias 48 | if b is not None and self.bias_gain != 1: 49 | b = b * self.bias_gain 50 | x = self.activation(torch.addmm(b.unsqueeze(0), x, w.t())) * self.activation_gain 51 | return x 52 | 53 | 54 | class SmoothDownsample(torch.nn.Module): 55 | 56 | def __init__(self): 57 | super().__init__() 58 | kernel = [[1, 3, 3, 1], 59 | [3, 9, 9, 3], 60 | [3, 9, 9, 3], 61 | [1, 3, 3, 1]] 62 | kernel = torch.tensor([[kernel]], dtype=torch.float) 63 | kernel /= kernel.sum() 64 | 
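        # The 4x4 kernel is the outer product of the binomial filter [1, 3, 3, 1] with itself,
        # normalised to sum to 1; it is registered as a frozen (requires_grad=False) parameter and
        # applied per channel as a low-pass filter before the 2x nearest-neighbour downsampling.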
self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 65 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 66 | 67 | def forward(self, x: torch.Tensor): 68 | b, c, h, w = x.shape 69 | x = x.view(-1, 1, h, w) 70 | x = self.pad(x) 71 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h, w) 72 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='nearest', recompute_scale_factor=False) 73 | return x 74 | 75 | 76 | class SmoothUpsample(torch.nn.Module): 77 | 78 | def __init__(self): 79 | super().__init__() 80 | kernel = [[1, 3, 3, 1], 81 | [3, 9, 9, 3], 82 | [3, 9, 9, 3], 83 | [1, 3, 3, 1]] 84 | kernel = torch.tensor([[kernel]], dtype=torch.float) 85 | kernel /= kernel.sum() 86 | self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 87 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 88 | 89 | def forward(self, x: torch.Tensor): 90 | b, c, h, w = x.shape 91 | x = x.view(-1, 1, h, w) 92 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest') 93 | x = self.pad(x) 94 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h * 2, w * 2) 95 | return x 96 | 97 | 98 | class EqualizedConv2d(torch.nn.Module): 99 | 100 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, activation='linear', resample=identity): 101 | super().__init__() 102 | self.resample = resample 103 | self.padding = kernel_size // 2 104 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) 105 | self.activation = activation_funcs[activation]['fn'] 106 | self.activation_gain = activation_funcs[activation]['def_gain'] 107 | weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]) 108 | bias = torch.zeros([out_channels]) if bias else None 109 | self.weight = torch.nn.Parameter(weight) 110 | self.bias = torch.nn.Parameter(bias) if bias is not None else None 111 | 112 | def forward(self, x, gain=1): 113 | w = self.weight * self.weight_gain 114 | b = self.bias[None, :, None, None] if self.bias is not None else 0 115 | x = self.resample(x) 116 | x = torch.nn.functional.conv2d(x, w, padding=self.padding) 117 | return clamp_gain(self.activation(x + b), self.activation_gain * gain, 256 * gain) 118 | 119 | 120 | def modulated_conv2d(x, weight, styles, padding=0, demodulate=True): 121 | batch_size = x.shape[0] 122 | out_channels, in_channels, kh, kw = weight.shape 123 | 124 | # Calculate per-sample weights and demodulation coefficients. 125 | w = weight.unsqueeze(0) # [NOIkk] 126 | w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] 127 | if demodulate: 128 | dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] 129 | w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] 130 | 131 | # Execute as one fused op using grouped convolution. 
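    # The per-sample weights [N, O, I, k, k] are flattened to [N*O, I, k, k] and the batch is folded
    # into the channel dimension ([1, N*I, H, W]); with groups=batch_size, each sample in the batch
    # is convolved with its own modulated filters in a single conv2d call, then reshaped back.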
132 | batch_size = int(batch_size) 133 | x = x.reshape(1, -1, *x.shape[2:]) 134 | w = w.reshape(-1, in_channels, kh, kw) 135 | x = torch.nn.functional.conv2d(x, w, padding=padding, groups=batch_size) 136 | x = x.reshape(batch_size, -1, *x.shape[2:]) 137 | return x 138 | -------------------------------------------------------------------------------- /model/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | @torch.jit.script 6 | def clamp_gain(x: torch.Tensor, g: float, c: float): 7 | return torch.clamp(x * g, -c, c) 8 | 9 | 10 | def normalize_2nd_moment(x, dim=1, eps=1e-8): 11 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 12 | 13 | 14 | def identity(x): 15 | return x 16 | 17 | 18 | def leaky_relu_0_2(x): 19 | return torch.nn.functional.leaky_relu(x, 0.2) 20 | 21 | 22 | activation_funcs = { 23 | "linear": { 24 | "fn": identity, 25 | "def_gain": 1 26 | }, 27 | "lrelu": { 28 | "fn": leaky_relu_0_2, 29 | "def_gain": np.sqrt(2) 30 | } 31 | } 32 | 33 | 34 | class FullyConnectedLayer(torch.nn.Module): 35 | 36 | def __init__(self, in_features, out_features, bias=True, activation='linear', lr_multiplier=1, bias_init=0): 37 | super().__init__() 38 | self.activation = activation_funcs[activation]['fn'] 39 | self.activation_gain = activation_funcs[activation]['def_gain'] 40 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) 41 | self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None 42 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 43 | self.bias_gain = lr_multiplier 44 | 45 | def forward(self, x): 46 | w = self.weight * self.weight_gain 47 | b = self.bias 48 | if b is not None and self.bias_gain != 1: 49 | b = b * self.bias_gain 50 | x = self.activation(torch.addmm(b.unsqueeze(0), x, w.t())) * self.activation_gain 51 | return x 52 | 53 | 54 | class SmoothDownsample(torch.nn.Module): 55 | 56 | def __init__(self): 57 | super().__init__() 58 | kernel = [[1, 3, 3, 1], 59 | [3, 9, 9, 3], 60 | [3, 9, 9, 3], 61 | [1, 3, 3, 1]] 62 | kernel = torch.tensor([[kernel]], dtype=torch.float) 63 | kernel /= kernel.sum() 64 | self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 65 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 66 | 67 | def forward(self, x: torch.Tensor): 68 | b, c, h, w = x.shape 69 | x = x.view(-1, 1, h, w) 70 | x = self.pad(x) 71 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h, w) 72 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='nearest', recompute_scale_factor=False) 73 | return x 74 | 75 | 76 | class SmoothUpsample(torch.nn.Module): 77 | 78 | def __init__(self): 79 | super().__init__() 80 | kernel = [[1, 3, 3, 1], 81 | [3, 9, 9, 3], 82 | [3, 9, 9, 3], 83 | [1, 3, 3, 1]] 84 | kernel = torch.tensor([[kernel]], dtype=torch.float) 85 | kernel /= kernel.sum() 86 | self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 87 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 88 | 89 | def forward(self, x: torch.Tensor): 90 | b, c, h, w = x.shape 91 | x = x.view(-1, 1, h, w) 92 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest') 93 | x = self.pad(x) 94 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h * 2, w * 2) 95 | return x 96 | 97 | 98 | class EqualizedConv2d(torch.nn.Module): 99 | 100 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, activation='linear', 
resample=identity): 101 | super().__init__() 102 | self.resample = resample 103 | self.padding = kernel_size // 2 104 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) 105 | self.activation = activation_funcs[activation]['fn'] 106 | self.activation_gain = activation_funcs[activation]['def_gain'] 107 | weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]) 108 | bias = torch.zeros([out_channels]) if bias else None 109 | self.weight = torch.nn.Parameter(weight) 110 | self.bias = torch.nn.Parameter(bias) if bias is not None else None 111 | 112 | def forward(self, x, gain=1): 113 | w = self.weight * self.weight_gain 114 | b = self.bias[None, :, None, None] if self.bias is not None else 0 115 | x = self.resample(x) 116 | x = torch.nn.functional.conv2d(x, w, padding=self.padding) 117 | return clamp_gain(self.activation(x + b), self.activation_gain * gain, 256 * gain) 118 | 119 | 120 | def modulated_conv2d(x, weight, styles, padding=0, demodulate=True): 121 | batch_size = x.shape[0] 122 | out_channels, in_channels, kh, kw = weight.shape 123 | 124 | # Calculate per-sample weights and demodulation coefficients. 125 | w = weight.unsqueeze(0) # [NOIkk] 126 | w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] 127 | if demodulate: 128 | dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] 129 | w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] 130 | 131 | # Execute as one fused op using grouped convolution. 132 | batch_size = int(batch_size) 133 | x = x.reshape(1, -1, *x.shape[2:]) 134 | w = w.reshape(-1, in_channels, kh, kw) 135 | x = torch.nn.functional.conv2d(x, w, padding=padding, groups=batch_size) 136 | x = x.reshape(batch_size, -1, *x.shape[2:]) 137 | return x 138 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | @torch.jit.script 6 | def clamp_gain(x: torch.Tensor, g: float, c: float): 7 | return torch.clamp(x * g, -c, c) 8 | 9 | 10 | def normalize_2nd_moment(x, dim=1, eps=1e-8): 11 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 12 | 13 | 14 | def identity(*x): 15 | return x[0] 16 | 17 | 18 | def leaky_relu_0_2(x): 19 | return torch.nn.functional.leaky_relu(x, 0.2) 20 | 21 | 22 | activation_funcs = { 23 | "linear": { 24 | "fn": identity, 25 | "def_gain": 1 26 | }, 27 | "lrelu": { 28 | "fn": leaky_relu_0_2, 29 | "def_gain": np.sqrt(2) 30 | } 31 | } 32 | 33 | 34 | class FullyConnectedLayer(torch.nn.Module): 35 | 36 | def __init__(self, in_features, out_features, bias=True, activation='linear', lr_multiplier=1, bias_init=0): 37 | super().__init__() 38 | self.activation = activation_funcs[activation]['fn'] 39 | self.activation_gain = activation_funcs[activation]['def_gain'] 40 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) 41 | self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None 42 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 43 | self.bias_gain = lr_multiplier 44 | 45 | def forward(self, x): 46 | w = self.weight * self.weight_gain 47 | b = self.bias 48 | if b is not None and self.bias_gain != 1: 49 | b = b * self.bias_gain 50 | x = self.activation(torch.addmm(b.unsqueeze(0), x, w.t())) * self.activation_gain 51 | return x 52 | 53 | 54 | class SmoothDownsample(torch.nn.Module): 55 | 56 | def __init__(self): 57 | 
super().__init__() 58 | kernel = [[1, 3, 3, 1], 59 | [3, 9, 9, 3], 60 | [3, 9, 9, 3], 61 | [1, 3, 3, 1]] 62 | kernel = torch.tensor([[kernel]], dtype=torch.float) 63 | kernel /= kernel.sum() 64 | self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 65 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 66 | 67 | def forward(self, x: torch.Tensor): 68 | b, c, h, w = x.shape 69 | x = x.reshape(-1, 1, h, w) 70 | x = self.pad(x) 71 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h, w) 72 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='nearest', recompute_scale_factor=False) 73 | return x 74 | 75 | 76 | class SmoothUpsample(torch.nn.Module): 77 | 78 | def __init__(self): 79 | super().__init__() 80 | kernel = [[1, 3, 3, 1], 81 | [3, 9, 9, 3], 82 | [3, 9, 9, 3], 83 | [1, 3, 3, 1]] 84 | kernel = torch.tensor([[kernel]], dtype=torch.float) 85 | kernel /= kernel.sum() 86 | self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 87 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 88 | 89 | def forward(self, x: torch.Tensor): 90 | b, c, h, w = x.shape 91 | x = x.reshape(-1, 1, h, w) 92 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest') 93 | x = self.pad(x) 94 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h * 2, w * 2) 95 | return x 96 | 97 | 98 | class EqualizedConv2d(torch.nn.Module): 99 | 100 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, activation='linear', resample=identity): 101 | super().__init__() 102 | self.resample = resample 103 | self.padding = kernel_size // 2 104 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) 105 | self.activation = activation_funcs[activation]['fn'] 106 | self.activation_gain = activation_funcs[activation]['def_gain'] 107 | weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]) 108 | bias = torch.zeros([out_channels]) if bias else None 109 | self.weight = torch.nn.Parameter(weight) 110 | self.bias = torch.nn.Parameter(bias) if bias is not None else None 111 | 112 | def forward(self, x, gain=1): 113 | w = self.weight * self.weight_gain 114 | b = self.bias[None, :, None, None] if self.bias is not None else 0 115 | x = self.resample(x) 116 | x = torch.nn.functional.conv2d(x, w, padding=self.padding) 117 | return clamp_gain(self.activation(x + b), self.activation_gain * gain, 256 * gain) 118 | 119 | 120 | def modulated_conv2d(x, weight, styles, padding=0, demodulate=True): 121 | batch_size = x.shape[0] 122 | out_channels, in_channels, kh, kw = weight.shape 123 | 124 | # Calculate per-sample weights and demodulation coefficients. 125 | w = weight.unsqueeze(0) # [NOIkk] 126 | w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] 127 | if demodulate: 128 | dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] 129 | w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] 130 | 131 | # Execute as one fused op using grouped convolution. 
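    # Shape walk-through (illustrative values): for batch_size=2, in_channels=3, out_channels=8,
    # kh=kw=3, padding=1 and x of shape [2, 3, 64, 64], w becomes [16, 3, 3, 3], x is reshaped to
    # [1, 6, 64, 64], conv2d(..., groups=2) yields [1, 16, 64, 64], and the final reshape restores
    # [2, 8, 64, 64].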
132 | batch_size = int(batch_size) 133 | x = x.reshape(1, -1, *x.shape[2:]) 134 | w = w.reshape(-1, in_channels, kh, kw) 135 | x = torch.nn.functional.conv2d(x, w, padding=padding, groups=batch_size) 136 | x = x.reshape(batch_size, -1, *x.shape[2:]) 137 | return x 138 | 139 | -------------------------------------------------------------------------------- /model/differentiable_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import nvdiffrast.torch as dr 4 | from torchvision.ops import masks_to_boxes 5 | 6 | 7 | def transform_pos(pos, projection_matrix, world_to_cam_matrix): 8 | # (x,y,z) -> (x,y,z,1) 9 | t_mtx = torch.matmul(projection_matrix, world_to_cam_matrix) 10 | # noinspection PyArgumentList 11 | posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1) 12 | return torch.matmul(posw, t_mtx.t()) 13 | 14 | 15 | def transform_pos_mvp(pos, mvp): 16 | # noinspection PyArgumentList 17 | posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1) 18 | return torch.bmm(posw.unsqueeze(0).expand(mvp.shape[0], -1, -1), mvp.permute((0, 2, 1))).reshape((-1, 4)) 19 | 20 | 21 | def render(glctx, pos_clip, pos_idx, vtx_col, col_idx, resolution, ranges, _colorspace, background=None): 22 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[resolution, resolution], ranges=ranges) 23 | color, _ = dr.interpolate(vtx_col[None, ...], rast_out, col_idx) 24 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 25 | mask = color[..., -1:] == 0 26 | if background is None: 27 | one_tensor = torch.ones((color.shape[0], color.shape[3], 1, 1), device=color.device) 28 | else: 29 | one_tensor = background 30 | one_tensor_permuted = one_tensor.permute((0, 2, 3, 1)).contiguous() 31 | color = torch.where(mask, one_tensor_permuted, color) 32 | return color[:, :, :, :-1] 33 | 34 | 35 | def render_in_bounds(glctx, pos_clip, pos_idx, vtx_col, col_idx, resolution, ranges, color_space, background=None): 36 | render_resolution = int(resolution * 1.2) 37 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[render_resolution, render_resolution], ranges=ranges) 38 | color, _ = dr.interpolate(vtx_col[None, ...], rast_out, col_idx) 39 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 40 | mask = color[..., -1:] == 0 41 | if background is None: 42 | if color_space == 'rgb': 43 | one_tensor = torch.ones((color.shape[0], color.shape[3], 1, 1), device=color.device) 44 | else: 45 | one_tensor = torch.zeros((color.shape[0], color.shape[3], 1, 1), device=color.device) 46 | one_tensor[:, 0, :, :] = 1 47 | else: 48 | one_tensor = background 49 | one_tensor_permuted = one_tensor.permute((0, 2, 3, 1)).contiguous() 50 | color = torch.where(mask, one_tensor_permuted, color) # [:, :, :, :-1] 51 | color[..., -1:] = mask.float() 52 | color_crops = [] 53 | boxes = masks_to_boxes(torch.logical_not(mask.squeeze(-1))) 54 | for img_idx in range(color.shape[0]): 55 | x1, y1, x2, y2 = [int(val) for val in boxes[img_idx, :].tolist()] 56 | color_crop = color[img_idx, y1: y2, x1: x2, :].permute((2, 0, 1)) 57 | pad = [[0, 0], [0, 0]] 58 | if y2 - y1 > x2 - x1: 59 | total_pad = (y2 - y1) - (x2 - x1) 60 | pad[0][0] = total_pad // 2 61 | pad[0][1] = total_pad - pad[0][0] 62 | pad[1][0], pad[1][1] = 0, 0 63 | additional_pad = int((y2 - y1) * 0.1) 64 | else: 65 | total_pad = (x2 - x1) - (y2 - y1) 66 | pad[0][0], pad[0][1] = 0, 0 67 | pad[1][0] = total_pad // 2 68 | pad[1][1] = total_pad - 
pad[1][0] 69 | additional_pad = int((x2 - x1) * 0.1) 70 | for i in range(4): 71 | pad[i // 2][i % 2] += additional_pad 72 | 73 | padded = torch.ones((color_crop.shape[0], color_crop.shape[1] + pad[1][0] + pad[1][1], color_crop.shape[2] + pad[0][0] + pad[0][1]), device=color_crop.device) 74 | padded[:3, :, :] = padded[:3, :, :] * one_tensor[img_idx, :3, :, :] 75 | padded[:, pad[1][0]: padded.shape[1] - pad[1][1], pad[0][0]: padded.shape[2] - pad[0][1]] = color_crop 76 | # color_crop = T.Pad((pad[0][0], pad[1][0], pad[0][1], pad[1][1]), 1)(color_crop) 77 | color_crop = torch.nn.functional.interpolate(padded.unsqueeze(0), size=(resolution, resolution), mode='bilinear', align_corners=False).permute((0, 2, 3, 1)) 78 | color_crops.append(color_crop) 79 | return torch.cat(color_crops, dim=0) 80 | 81 | 82 | def intrinsic_to_projection(intrinsic_matrix): 83 | near, far = 0.1, 50. 84 | a, b = -(far + near) / (far - near), -2 * far * near / (far - near) 85 | projection_matrix = torch.tensor([ 86 | intrinsic_matrix[0][0] / intrinsic_matrix[0][2], 0, 0, 0, 87 | 0, -intrinsic_matrix[1][1] / intrinsic_matrix[1][2], 0, 0, 88 | 0, 0, a, b, 89 | 0, 0, -1, 0 90 | ]).float().reshape((4, 4)) 91 | return projection_matrix 92 | 93 | 94 | class DifferentiableRenderer(nn.Module): 95 | 96 | def __init__(self, resolution, mode='standard', color_space='rgb', num_channels=3): 97 | super().__init__() 98 | self.glctx = dr.RasterizeGLContext() 99 | self.resolution = resolution 100 | self.render_func = render 101 | self.color_space = color_space 102 | self.num_channels = num_channels 103 | if mode == 'bounds': 104 | self.render_func = render_in_bounds 105 | 106 | def render(self, vertex_positions, triface_indices, vertex_colors, ranges=None, background=None, resolution=None): 107 | if ranges is None: 108 | ranges = torch.tensor([[0, triface_indices.shape[0]]]).int() 109 | if resolution is None: 110 | resolution = self.resolution 111 | color = self.render_func(self.glctx, vertex_positions, triface_indices, vertex_colors, triface_indices, resolution, ranges, self.color_space, background) 112 | return color[:, :, :, :self.num_channels] 113 | 114 | def get_visible_triangles(self, vertex_positions, triface_indices, ranges=None, resolution=None): 115 | if ranges is None: 116 | ranges = torch.tensor([[0, triface_indices.shape[0]]]).int() 117 | if resolution is None: 118 | resolution = self.resolution 119 | rast, _ = dr.rasterize(self.glctx, vertex_positions, triface_indices, resolution=[resolution, resolution], ranges=ranges) 120 | return rast[..., 3] 121 | -------------------------------------------------------------------------------- /trainer/train_autoencoder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import numpy as np 5 | import hydra 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.utilities import rank_zero_only 8 | from dataset.distance_field import DistanceFieldDataset 9 | from model.autoencoder import AutoEncoder32 10 | from trainer import create_trainer 11 | from util.df_metrics import IoU, Chamfer3D, Precision, Recall 12 | 13 | 14 | class AutoencoderTrainer(pl.LightningModule): 15 | 16 | def __init__(self, config): 17 | super().__init__() 18 | self.save_hyperparameters(config) 19 | self.config = config 20 | self.train_data = DistanceFieldDataset(config, interval=[0, -1024]) 21 | self.val_data = DistanceFieldDataset(config, interval=[-1024, None]) 22 | self.num_vis_samples = 48 23 | self.model = 
AutoEncoder32() 24 | self.metrics = torch.nn.ModuleList([IoU(compute_on_step=False), Chamfer3D(compute_on_step=False), 25 | Precision(compute_on_step=False), Recall(compute_on_step=False)]) 26 | 27 | def configure_optimizers(self): 28 | opt = torch.optim.Adam(list(self.model.parameters()), lr=self.config.lr_e, eps=1e-8, weight_decay=1e-4) 29 | return opt 30 | 31 | def forward(self, batch): 32 | return self.model(self.normalize_df(batch['df'])) 33 | 34 | def training_step(self, batch, batch_idx): 35 | self.train_data.augment_batch_data(batch) 36 | predicted = self.network_pred_to_df(self.forward(batch)) 37 | weights = self.adjust_weights(predicted >= self.train_data.trunc, batch) 38 | loss_l1 = (torch.abs(predicted - batch['df']) * weights).mean() 39 | self.log("train/l1", loss_l1, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 40 | return loss_l1 41 | 42 | def validation_step(self, batch, batch_idx): 43 | predicted = self.network_pred_to_df(self.forward(batch)) 44 | loss_l1 = self.record_evaluation_for_batch(predicted, batch) 45 | self.log("val/l1", loss_l1, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 46 | if batch_idx < self.num_vis_samples: 47 | output_dir = Path(f'runs/{self.config.experiment}/visualization/{self.current_epoch:04d}') 48 | output_dir.mkdir(exist_ok=True, parents=True) 49 | for b in range(predicted.shape[0]): 50 | self.val_data.visualize_as_mesh(predicted[b][0].cpu().numpy(), output_dir / f"{batch['name'][b]}_pred.obj") 51 | self.val_data.visualize_as_mesh(batch['df'][b][0].cpu().numpy(), output_dir / f"{batch['name'][b]}_gt.obj") 52 | 53 | @rank_zero_only 54 | def validation_epoch_end(self, _outputs): 55 | iou, cd, precision, recall = self.metrics[0].compute(), self.metrics[1].compute(), self.metrics[2].compute(), self.metrics[3].compute() 56 | self.log("val/iou", iou, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 57 | self.log("val/cd", cd, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 58 | self.log("val/precision", precision, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 59 | self.log("val/recall", recall, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 60 | print(f"\nIoU = {iou:.3f} | CD = {cd:.3f} | P = {precision:.3f} | R = {recall:.3f}") 61 | 62 | def train_dataloader(self): 63 | return torch.utils.data.DataLoader(self.train_data, batch_size=self.config.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=self.config.num_workers) 64 | 65 | def val_dataloader(self): 66 | return torch.utils.data.DataLoader(self.val_data, batch_size=self.config.batch_size, shuffle=False, pin_memory=True, drop_last=False, num_workers=self.config.num_workers) 67 | 68 | def test_dataloader(self): 69 | data = DistanceFieldDataset(self.config, interval=[0, None]) 70 | return torch.utils.data.DataLoader(data, batch_size=self.config.batch_size, shuffle=False, pin_memory=False, drop_last=False, num_workers=self.config.num_workers) 71 | 72 | @staticmethod 73 | def adjust_weights(pred_empty, batch): 74 | weights = batch['weights'].clone().detach() 75 | weights[batch['empty'] & pred_empty] = 0 76 | return weights 77 | 78 | def network_pred_to_df(self, clamped_out): 79 | return (clamped_out + 1) * self.train_data.trunc / 2 80 | 81 | def normalize_df(self, df): 82 | return self.train_data.normalize(df) 83 | 84 | def record_evaluation_for_batch(self, pred_shape_df, batch): 85 | target_shape = batch['df'] <= 
self.train_data.vox_size 86 | predicted_shape = pred_shape_df <= self.train_data.vox_size 87 | for metric in self.metrics: 88 | metric(predicted_shape, target_shape) 89 | loss_l1 = torch.abs(pred_shape_df - batch['df']).mean() 90 | return loss_l1 91 | 92 | def test_step(self, batch, batch_idx): 93 | predicted = self.network_pred_to_df(self.forward(batch)) 94 | code = self.model.encoder(self.normalize_df(batch['df'])) 95 | b, c, d, h, w = code.shape 96 | code = code.reshape(b, c * d * h * w).cpu().numpy() 97 | self.record_evaluation_for_batch(predicted, batch) 98 | output_dir = Path(f'runs/{self.config.experiment}/latent/{self.current_epoch:04d}') 99 | output_dir.mkdir(exist_ok=True, parents=True) 100 | for b in range(code.shape[0]): 101 | np.save(output_dir / f'{batch["name"][b]}.npy', code[b]) 102 | 103 | def test_epoch_end(self, outputs): 104 | iou, cd, precision, recall = self.metrics[0].compute(), self.metrics[1].compute(), self.metrics[2].compute(), self.metrics[3].compute() 105 | print(f"\nIoU = {iou:.3f} | CD = {cd:.3f} | P = {precision:.3f} | R = {recall:.3f}") 106 | 107 | 108 | @hydra.main(config_path='../config', config_name='stylegan2') 109 | def main(config): 110 | trainer = create_trainer("Autoencoder32", config) 111 | model = AutoencoderTrainer(config) 112 | trainer.fit(model) 113 | 114 | 115 | @hydra.main(config_path='../config', config_name='stylegan2') 116 | def infer(config): 117 | trainer = create_trainer("Autoencoder32", config) 118 | model = AutoencoderTrainer(config) 119 | model.load_state_dict(torch.load(config.resume)['state_dict']) 120 | trainer.test(model) 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Texturify Generating Textures on 3D Shape Surfaces 2 | 3 | ## Dependencies 4 | 5 | Install python requirements: 6 | 7 | ```commandline 8 | pip install -r requirements.txt 9 | pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 10 | ``` 11 | 12 | Install trimesh from our fork: 13 | ```bash 14 | cd ~ 15 | git clone git@github.com:nihalsid/trimesh.git 16 | cd trimesh 17 | python setup.py install 18 | ``` 19 | 20 | Also, for differentiable rendering we use `nvdiffrast`. You'll need to install its dependencies: 21 | 22 | ```bash 23 | sudo apt-get update && sudo apt-get install -y --no-install-recommends \ 24 | pkg-config \ 25 | libglvnd0 \ 26 | libgl1 \ 27 | libglx0 \ 28 | libegl1 \ 29 | libgles2 \ 30 | libglvnd-dev \ 31 | libgl1-mesa-dev \ 32 | libegl1-mesa-dev \ 33 | libgles2-mesa-dev \ 34 | cmake \ 35 | curl 36 | ``` 37 | 38 | Install `nvdiffrast` from official source: 39 | 40 | ```bash 41 | cd ~ 42 | git clone git@github.com:NVlabs/nvdiffrast.git 43 | cd nvdiffrast 44 | pip install . 45 | ``` 46 | 47 | Apart from this, you will need approporiate versions of torch-scatter, torch-sparse, torch-spline-conv, torch-geometric, depending on your torch+cuda combination. E.g. 
for torch-1.10 + cuda11.3 you'd need: 48 | 49 | ```commandline 50 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu113.html 51 | ``` 52 | ## Dataset 53 | 54 | From project root execute: 55 | ```bash 56 | mkdir data 57 | cd data 58 | wget https://www.dropbox.com/s/or9tfmunvndibv0/data.zip 59 | unzip data.zip 60 | ``` 61 | 62 | For custom data processing check out https://github.com/nihalsid/CADTextures 63 | 64 | ## Output Directories 65 | 66 | Create a symlink `runs` in project root from a directory `OUTPUTDIR` where outputs would be stored 67 | ```bash 68 | ln -s OUTPUTDIR runs 69 | ``` 70 | 71 | ## Running Experiments 72 | 73 | Configuration provided with hydra config file `config/stylegan2.yaml`. Example training: 74 | 75 | ```bash 76 | python trainer/train_stylegan_real_feature.py wandb_main=False val_check_interval=5 experiment=test_run lr_d=0.001 sanity_steps=1 lambda_gp=14 image_size=512 batch_size=4 num_mapping_layers=5 views_per_sample=2 g_channel_base=32768 random_bg=grayscale num_vis_images=256 preload=False dataset_path=data/Photoshape/shapenet-chairs-manifold-highres-part_processed_color mesh_path=data/Photoshape/shapenet-chairs-manifold-highres pairmeta_path=data/Photoshape-model/metadata/pairs.json image_path=data/Photoshape/exemplars mask_path=data/Photoshape/exemplars_mask 77 | ``` 78 | 79 | ## Checkpoints 80 | 81 | Available [here](https://www.dropbox.com/scl/fi/cz9arygdbz05gucapldd1/texturify_checkpoints.zip?rlkey=n19t8x0zq13i7hmodfnjkgrst&dl=0). 82 | 83 | ## Configuration 84 | 85 | Configuration can be overriden with command line flags. 86 | 87 | | Key | Description | Default | 88 | | ----|-------------|---------| 89 | |`dataset_path`| Directory with processed data|| 90 | |`mesh_path`| Directory with processed mesh (highest res)|| 91 | |`pairmeta_path`| Directory with metadata for image-shape pairs (photoshape specific)|| 92 | |`df_path`| not used anymore || 93 | |`image_path`| real images || 94 | |`mask_path`| real image segmentation masks || 95 | |`condition_path`| not used anymore || 96 | |`stat_path`| not used anymore || 97 | |`uv_path`| processed uv data (for uv baseline) || 98 | |`silhoutte_path`| texture atlas silhoutte data (for uv baseline) || 99 | |`mesh_resolution`| not used anymore|| 100 | |`experiment`| Experiment name used for logs |`fast_dev`| 101 | |`wandb_main`| If false, results logged to "-dev" wandb project (for dev logs)|`False`| 102 | |`num_mapping_layers`| Number of layers in the mapping network |2| 103 | |`lr_g`| Generator learning rate | 0.002| 104 | |`lr_d`| Discriminator learning rate |0.00235| 105 | |`lr_e`| Encoder learning rate |0.0001| 106 | |`lambda_gp`| Gradient penalty weight | 0.0256 | 107 | |`lambda_plp`| Path length penalty weight |2| 108 | |`lazy_gradient_penalty_interval`| Gradient penalty regularizer interval |16| 109 | |`lazy_path_penalty_after`| Iteration after which path lenght penalty is active |0| 110 | |`lazy_path_penalty_interval`| Path length penalty regularizer interval |4| 111 | |`latent_dim`| Latent dim of starting noise and mapping network output |512| 112 | |`image_size`| Size of generated images |64| 113 | |`num_eval_images`| Number of images on which FID is computed |8096| 114 | |`num_vis_images`| Number of image visualized |1024| 115 | |`batch_size`| Mini batch size |16| 116 | |`num_workers`| Number of dataloader workers|8| 117 | |`seed`| RNG seed |null| 118 | |`save_epoch`| Epoch interval for checkpoint saves |1| 119 | 
|`sanity_steps`| Validation sanity runs before training start |1| 120 | |`max_epoch`| Maximum training epochs |250| 121 | |`val_check_interval`| Epoch interval for evaluating metrics and saving generated samples |1| 122 | |`resume`| Resume checkpoint |`null`| 123 | 124 | 125 | References 126 | ========== 127 | Official stylegan2-ada code and paper. 128 | 129 | ``` 130 | @article{Karras2019stylegan2, 131 | title = {Analyzing and Improving the Image Quality of {StyleGAN}}, 132 | author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila}, 133 | journal = {CoRR}, 134 | volume = {abs/1912.04958}, 135 | year = {2019}, 136 | } 137 | ``` 138 | 139 | 140 | License 141 | ===================== 142 | 143 | Copyright © 2021 nihalsid 144 | 145 | Permission is hereby granted, free of charge, to any person 146 | obtaining a copy of this software and associated documentation 147 | files (the “Software”), to deal in the Software without 148 | restriction, including without limitation the rights to use, 149 | copy, modify, merge, publish, distribute, sublicense, and/or sell 150 | copies of the Software, and to permit persons to whom the 151 | Software is furnished to do so, subject to the following 152 | conditions: 153 | 154 | The above copyright notice and this permission notice shall be 155 | included in all copies or substantial portions of the Software. 156 | 157 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, 158 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 159 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 160 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 161 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 162 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 163 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 164 | OTHER DEALINGS IN THE SOFTWARE. 
165 | 166 | -------------------------------------------------------------------------------- /util/camera.py: -------------------------------------------------------------------------------- 1 | # borrowed form keunhong/toolbox 2 | import numpy as np 3 | import math 4 | from numpy import linalg 5 | 6 | 7 | def normalized(vec): 8 | return vec / linalg.norm(vec) 9 | 10 | 11 | def normalize_to_range(array, lo, hi): 12 | if hi <= lo: 13 | raise ValueError('Range must be increasing but {} >= {}.'.format( 14 | lo, hi)) 15 | min_val = array.min() 16 | max_val = array.max() 17 | scale = max_val - min_val if (min_val < max_val) else 1 18 | return (array - min_val) / scale * (hi - lo) + lo 19 | 20 | 21 | class BaseCamera: 22 | def __init__(self, size, near, far, clear_color=(1.0, 1.0, 1.0, 1.0)): 23 | self.size = size 24 | self.near = near 25 | self.far = far 26 | self.clear_color = clear_color 27 | if len(self.clear_color) == 3: 28 | self.clear_color = (*self.clear_color, 1.0) 29 | self.position = None 30 | self.up = None 31 | self.lookat = None 32 | 33 | @property 34 | def left(self): 35 | return -self.size[0] / 2 36 | 37 | @property 38 | def right(self): 39 | return self.size[0] / 2 40 | 41 | @property 42 | def top(self): 43 | return self.size[1] / 2 44 | 45 | @property 46 | def bottom(self): 47 | return -self.size[1] / 2 48 | 49 | @property 50 | def forward(self): 51 | return normalized(np.subtract(self.lookat, self.position)) 52 | 53 | def projection_mat(self): 54 | raise NotImplementedError 55 | 56 | def rotation_mat(self): 57 | rotation_mat = np.eye(3) 58 | rotation_mat[0, :] = normalized(np.cross(self.forward, self.up)) 59 | rotation_mat[2, :] = -self.forward 60 | # We recompute the 'up' vector portion of the matrix as the cross 61 | # product of the forward and sideways vector so that we have an ortho- 62 | # normal basis. 
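        # Row 0 is the camera's right vector, row 2 points backwards (-forward), and row 1,
        # recomputed below, is the orthonormalised up vector, so the matrix maps world axes to camera axes.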
63 | rotation_mat[1, :] = np.cross(rotation_mat[2, :], rotation_mat[0, :]) 64 | return rotation_mat 65 | 66 | def translation_vec(self): 67 | rotation_mat = self.rotation_mat() 68 | return -rotation_mat.T @ self.position 69 | 70 | def view_mat(self): 71 | rotation_mat = self.rotation_mat() 72 | position = rotation_mat.dot(self.position) 73 | 74 | view_mat = np.eye(4) 75 | view_mat[:3, :3] = rotation_mat 76 | view_mat[:3, 3] = -position 77 | 78 | return view_mat 79 | 80 | def cam_to_world(self): 81 | cam_to_world = np.eye(4) 82 | cam_to_world[:3, :3] = self.rotation_mat().T 83 | cam_to_world[:3, 3] = self.position 84 | return cam_to_world 85 | 86 | def handle_mouse(self, last_pos, cur_pos): 87 | pass 88 | 89 | def apply_projection(self, points): 90 | homo = euclidean_to_homogeneous(points) 91 | proj = self.projection_mat().dot(self.view_mat().dot(homo.T)).T 92 | proj = homogeneous_to_euclidean(proj)[:, :2] 93 | proj = (proj + 1) / 2 94 | proj[:, 0] = (proj[:, 0] * self.size[0]) 95 | proj[:, 1] = self.size[1] - (proj[:, 1] * self.size[1]) 96 | return np.fliplr(proj) 97 | 98 | def get_position(self): 99 | return linalg.inv(self.view_mat())[:3, 3] 100 | 101 | def serialize(self): 102 | raise NotImplementedError() 103 | 104 | 105 | class PerspectiveCamera(BaseCamera): 106 | 107 | def __init__(self, size, near, far, fov, position, lookat, up, 108 | *args, **kwargs): 109 | super().__init__(size, near, far, *args, **kwargs) 110 | 111 | self.fov = fov 112 | self._position = np.array(position, dtype=np.float32) 113 | self.lookat = np.array(lookat, dtype=np.float32) 114 | self.up = normalized(np.array(up)) 115 | 116 | @property 117 | def position(self): 118 | return self._position 119 | 120 | @position.setter 121 | def position(self, position): 122 | self._position = np.array(position) 123 | 124 | def projection_mat(self): 125 | mat = perspective(self.fov, self.size[0] / self.size[1], self.near, self.far).T 126 | return mat 127 | 128 | def view_mat(self): 129 | rotation_mat = np.eye(3) 130 | rotation_mat[0, :] = normalized(np.cross(self.forward, self.up)) 131 | rotation_mat[2, :] = -self.forward 132 | # We recompute the 'up' vector portion of the matrix as the cross 133 | # product of the forward and sideways vector so that we have an ortho- 134 | # normal basis. 
135 | rotation_mat[1, :] = np.cross(rotation_mat[2, :], rotation_mat[0, :]) 136 | 137 | position = rotation_mat.dot(self.position) 138 | 139 | view_mat = np.eye(4) 140 | view_mat[:3, :3] = rotation_mat 141 | view_mat[:3, 3] = -position 142 | return view_mat 143 | 144 | def serialize(self): 145 | return { 146 | 'type': 'perspective', 147 | 'size': self.size, 148 | 'near': float(self.near), 149 | 'far': float(self.far), 150 | 'fov': self.fov, 151 | 'position': self.position.tolist(), 152 | 'lookat': self.lookat.tolist(), 153 | 'up': self.up.tolist(), 154 | 'clear_color': self.clear_color, 155 | } 156 | 157 | 158 | def spherical_to_cartesian(radius, azimuth, elevation): 159 | x = radius * math.cos(azimuth + 3 * math.pi / 2) * math.sin(elevation) 160 | y = radius * math.cos(elevation) 161 | z = radius * math.sin(azimuth + 3 * math.pi / 2) * math.sin(elevation) 162 | return x, y, z 163 | 164 | 165 | def spherical_coord_to_cam(fov, azimuth, elevation, max_len=500, cam_dist=1.75): 166 | shape = (max_len * 2, max_len * 2) 167 | camera = PerspectiveCamera( 168 | size=shape, fov=fov, near=0.1, far=5000.0, 169 | position=(0, 0, -cam_dist), clear_color=(1, 1, 1, 1), 170 | lookat=(0, 0, 0), up=(0, 1, 0)) 171 | camera.position = spherical_to_cartesian(cam_dist, azimuth, elevation) 172 | return camera 173 | 174 | 175 | def euclidean_to_homogeneous(points): 176 | ones = np.ones((points.shape[0], 1)) 177 | return np.concatenate((points, ones), 1) 178 | 179 | 180 | def homogeneous_to_euclidean(points): 181 | ndims = points.shape[1] 182 | euclidean_points = np.array(points[:, 0:ndims - 1]) / points[:, -1, None] 183 | return euclidean_points 184 | 185 | 186 | def perspective(fovy, aspect, znear, zfar): 187 | assert(znear != zfar) 188 | h = math.tan(fovy / 360.0 * math.pi) * znear 189 | w = h * aspect 190 | return frustum(-w, w, h, -h, znear, zfar) 191 | 192 | 193 | def frustum(left, right, bottom, top, znear, zfar): 194 | M = np.zeros((4, 4), dtype=np.float32) 195 | M[0, 0] = +2.0 * znear / float(right - left) 196 | M[2, 0] = (right + left) / float(right - left) 197 | M[1, 1] = +2.0 * znear / float(top - bottom) 198 | M[2, 1] = (top + bottom) / float(top - bottom) 199 | M[2, 2] = -(zfar + znear) / float(zfar - znear) 200 | M[3, 2] = -2.0 * znear * zfar / float(zfar - znear) 201 | M[2, 3] = -1.0 202 | return M -------------------------------------------------------------------------------- /test/test_raycast_rgbd.py: -------------------------------------------------------------------------------- 1 | from model.styleganvox.raycast_rgbd.raycast_rgbd import Raycast2DHandler 2 | import numpy as np 3 | import torch 4 | import random 5 | import math 6 | from PIL import Image 7 | from scipy.spatial.transform import Rotation 8 | from pathlib import Path 9 | import json 10 | import time 11 | from util.camera import spherical_coord_to_cam 12 | 13 | 14 | def get_random_views(num_views): 15 | elevation_params = [1.407, 0.207, 0.785, 1.767] 16 | azimuth = random.sample(np.arange(0, 2 * math.pi).tolist(), num_views) 17 | elevation = np.random.uniform(low=elevation_params[2], high=elevation_params[3], size=num_views).tolist() 18 | return [{'fov': 50, 'azimuth': a, 'elevation': e} for a, e in zip(azimuth, elevation)] 19 | 20 | 21 | def meta_to_pair(c): 22 | return f'shape{c["shape_id"]:05d}_rank{(c["rank"] - 1):02d}_pair{c["id"]}' 23 | 24 | 25 | def load_pair_meta_views(image_path, pairmeta_path): 26 | dataset_images = [x.stem for x in image_path.iterdir()] 27 | loaded_json = json.loads(Path(pairmeta_path).read_text()) 28 | 
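    # Keep only metadata entries whose rendered exemplar (named shape*_rank*_pair*) is actually
    # present in image_path, keyed by that name.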
ret_dict = {} 29 | for k in loaded_json.keys(): 30 | if meta_to_pair(loaded_json[k]) in dataset_images: 31 | ret_dict[meta_to_pair(loaded_json[k])] = loaded_json[k] 32 | return ret_dict 33 | 34 | 35 | def test_raycast(): 36 | import os 37 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 38 | voxel_size = 0.020834 39 | dims = (96, 96, 96) 40 | batch_size, render_shape = 4, (256, 256) 41 | trunc = 5 * voxel_size 42 | pairmeta_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/metadata/pairs.json") 43 | image_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/exemplars") 44 | 45 | sdf = torch.from_numpy(np.load("/cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/shapenet-chairs/shape02320_rank00_pair83185/096.npy")).cuda() 46 | color = torch.from_numpy(np.load("/cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/shapenet-chairs/shape02320_rank00_pair83185/096_if.npy")).cuda() 47 | 48 | sdf = sdf.unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1, -1, -1).float() - 0.01 49 | color = color.permute((3, 0, 1, 2)).unsqueeze(0).expand(batch_size, -1, -1, -1, -1).float() 50 | 51 | raycast_handler = Raycast2DHandler(batch_size, dims, render_shape, voxel_size, trunc) 52 | 53 | views_photoshape = load_pair_meta_views(image_path, pairmeta_path) 54 | view_keys = sorted(list(views_photoshape.keys())) 55 | 56 | repeats = 10 57 | 58 | for _ in range(repeats): 59 | projections, views = [], [] 60 | 61 | for b_idx in range(batch_size): 62 | c_v = views_photoshape[random.choice(view_keys)] 63 | # c_v = views_photoshape[view_keys[0]] 64 | perspective_cam = spherical_coord_to_cam(c_v['fov'], c_v['azimuth'], c_v['elevation']) 65 | 66 | y_angle = c_v['azimuth'] * 180 / math.pi 67 | x_angle = 90 - c_v['elevation'] * 180 / math.pi 68 | z_angle = 180 #random.random() * 90 69 | camera_rot = np.eye(4) 70 | camera_rot[:3, :3] = Rotation.from_euler('x', x_angle, degrees=True).as_matrix() @ Rotation.from_euler('z', z_angle, degrees=True).as_matrix() @ Rotation.from_euler('y', y_angle, degrees=True).as_matrix() 71 | camera_translation = np.eye(4) 72 | camera_translation[:3, 3] = np.array([0, 0, 1.75]) 73 | camera_pose = camera_translation @ camera_rot 74 | 75 | translate = torch.tensor([[1.0, 0, 0, 48], [0, 1.0, 0, 48], [0, 0, 1.0, 48], [0, 0, 0, 1.0]]).to(sdf.device).float() 76 | scale = torch.tensor([[48.0, 0, 0, 0], [0, 48.0, 0, 0], [0, 0, 48.0, 0], [0, 0, 0, 1.0]]).to(sdf.device).float() 77 | world2grid = translate @ scale 78 | translate_0 = torch.tensor([[1.0, 0, 0, 0], [0, 1.0, 0, 0.0], [0, 0, 1.0, -0.5], [0, 0, 0, 1.0]]).to(sdf.device).float() 79 | # view_matrix = torch.from_numpy(perspective_cam.view_mat()).to(sdf.device).float() @ scale @ translate 80 | # view_matrix = translate @ scale @ translate_0 81 | # view_matrix = world2grid @ translate_0 82 | # view_matrix = world2grid @ torch.from_numpy(perspective_cam.view_mat()).to(sdf.device).float() 83 | view_matrix = world2grid @ torch.linalg.inv(torch.from_numpy(camera_pose).to(sdf.device).float()) 84 | # views.append(torch.linalg.inv(view_matrix)) 85 | views.append(view_matrix) 86 | camera_intrinsics = torch.zeros((4,)).to(sdf.device) 87 | f = render_shape[1] / (2 * np.tan(c_v['fov'] * np.pi / 180 / 2.0)) 88 | camera_intrinsics[0] = f 89 | camera_intrinsics[1] = f 90 | camera_intrinsics[2] = render_shape[0] / 2 91 | camera_intrinsics[3] = render_shape[0] / 2 92 | # camera_intrinsics = torch.ones((4,)).to(sdf.device) 93 | projections.append(camera_intrinsics) 94 | 95 | r_color, r_depth, r_normals = raycast_handler.raycast_sdf(sdf, color, 
torch.stack(views, dim=0), torch.stack(projections, dim=0)) 96 | print(r_color.max(), r_color.min(), r_color.shape) 97 | for idx in range(r_color.shape[0]): 98 | color_i = r_color[idx].cpu().numpy() 99 | color_i[color_i == -float('inf')] = 0 100 | Image.fromarray(color_i.astype(np.uint8)).save(f"render_{idx}.jpg") 101 | 102 | time.sleep(2) 103 | 104 | 105 | def test_views(): 106 | import trimesh 107 | import pyrender 108 | from pyrender import RenderFlags 109 | mesh = trimesh.load("/cluster/gimli/ysiddiqui/CADTextures/Photoshape-model/shapenet-chairs/shape02320_rank00_pair83185/096.obj", process=True) 110 | for idx, c_v in enumerate(get_random_views(4)): 111 | spherical_camera = spherical_coord_to_cam(c_v['fov'], c_v['azimuth'], c_v['elevation']) 112 | translate_0 = np.array([[1.0, 0, 0, 0], [0, 1.0, 0, 0.00], [0, 0, 1.0, -1.75], [0, 0, 0, 1.0]]) 113 | translate = np.array([[1.0, 0, 0, -48], [0, 1.0, 0, -48], [0, 0, 1.0, -48], [0, 0, 0, 1.0]]) 114 | scale = np.array([[1.0 / 48, 0, 0, 0], [0, 1.0 / 48, 0, 0], [0, 0, 1.0 / 48, 0], [0, 0, 0, 1.0]]) 115 | # view = spherical_camera.view_mat() @ scale @ translate 116 | view = translate_0 @ scale @ translate 117 | # view = spherical_camera.view_mat() 118 | camera_pose = np.linalg.inv(view) 119 | r = pyrender.OffscreenRenderer(256, 256) 120 | camera = pyrender.PerspectiveCamera(yfov=np.pi * c_v['fov'] / 180, aspectRatio=1.0, znear=0.001) 121 | camera_intrinsics = np.eye(4, dtype=np.float32) 122 | camera_intrinsics[0, 0] = camera_intrinsics[1, 1] = 256 / (2 * np.tan(camera.yfov / 2.0)) 123 | camera_intrinsics[0, 2] = camera_intrinsics[1, 2] = camera_intrinsics[2, 2] = -256 / 2 124 | scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 0.0], ambient_light=[0.5, 0.5, 0.5]) 125 | geo = pyrender.Mesh.from_trimesh(mesh) 126 | scene.add(geo) 127 | scene.add(camera, pose=camera_pose) 128 | color_flat, depth = r.render(scene, flags=RenderFlags.FLAT | RenderFlags.SKIP_CULL_FACES) 129 | Image.fromarray(color_flat).save(f"render_{idx}.jpg") 130 | 131 | 132 | if __name__ == "__main__": 133 | # test_views() 134 | test_raycast() 135 | -------------------------------------------------------------------------------- /dataset/mesh_real_sdfgrid.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from collections import defaultdict 4 | from pathlib import Path 5 | import numpy as np 6 | import math 7 | 8 | import torch 9 | import torchvision.transforms as T 10 | from torchvision.io import read_image 11 | from tqdm import tqdm 12 | 13 | from scipy.spatial.transform import Rotation 14 | 15 | 16 | class SDFGridDataset(torch.utils.data.Dataset): 17 | 18 | def __init__(self, config, limit_dataset_size=None): 19 | self.dataset_directory = Path(config.dataset_path) 20 | self.mesh_directory = Path(config.mesh_path) 21 | self.image_size = config.image_size 22 | self.bg_color = config.random_bg 23 | self.real_images = {x.name.split('.')[0]: x for x in Path(config.image_path).iterdir() if x.name.endswith('.jpg') or x.name.endswith('.png')} 24 | self.masks = {x: Path(config.mask_path) / self.real_images[x].name for x in self.real_images} 25 | self.items = sorted(list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size]) 26 | self.views_per_sample = 1 27 | self.erode = config.erode 28 | self.real_pad = 150 29 | self.pair_meta, self.all_views = self.load_pair_meta(config.pairmeta_path) 30 | self.real_images_preloaded, self.masks_preloaded = {}, {} 31 | if config.preload: 32 | 
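            # Optionally cache every exemplar image and mask in memory so __getitem__ avoids repeated disk reads.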
self.preload_real_images() 33 | 34 | def __len__(self): 35 | return len(self.items) 36 | 37 | def __getitem__(self, idx): 38 | selected_item = self.items[idx] 39 | sdf_grid = torch.from_numpy(np.load(self.dataset_directory / selected_item / "064.npy")).unsqueeze(0) - 0.0075 40 | color_grid = 2 * torch.from_numpy(np.load(self.dataset_directory / selected_item / "064_if.npy")).permute((3, 0, 1, 2)) / 255.0 - 1 41 | real_sample, masks, view_matrices, projection_matrices = self.get_image_and_view(selected_item) 42 | if self.bg_color == 'white': 43 | bg = torch.tensor(1.).float() 44 | else: 45 | bg = torch.tensor(random.random() * 2 - 1).float() 46 | return { 47 | "name": selected_item, 48 | "x": sdf_grid.float(), 49 | "y": color_grid.float(), 50 | "view": view_matrices[0], 51 | "intrinsic": projection_matrices[0], 52 | "real": real_sample[0], 53 | "mask": masks[0], 54 | "bg": bg 55 | } 56 | 57 | def get_image_and_view(self, shape): 58 | shape_id = int(shape.split('_')[0].split('shape')[1]) 59 | image_selections = self.get_image_selections(shape_id) 60 | view_selections = random.sample(self.all_views, self.views_per_sample) 61 | images, masks, view_matrices, projection_matrices = [], [], [], [] 62 | for c_i, c_v in zip(image_selections, view_selections): 63 | images.append(self.get_real_image(self.meta_to_pair(c_i))) 64 | masks.append(self.get_real_mask(self.meta_to_pair(c_i))) 65 | view_matrix, projection_matrix = self.get_camera(c_v['fov'], c_v['azimuth'], c_v['elevation']) 66 | view_matrices.append(view_matrix) 67 | projection_matrices.append(projection_matrix) 68 | image = torch.cat(images, dim=0) 69 | masks = torch.cat(masks, dim=0) 70 | return image, masks, torch.stack(view_matrices, dim=0), torch.stack(projection_matrices, dim=0) 71 | 72 | def get_real_image(self, name): 73 | if name not in self.real_images_preloaded.keys(): 74 | return self.process_real_image(self.real_images[name]) 75 | else: 76 | return self.real_images_preloaded[name] 77 | 78 | def get_real_mask(self, name): 79 | if name not in self.masks_preloaded.keys(): 80 | return self.process_real_mask(self.masks[name]) 81 | else: 82 | return self.masks_preloaded[name] 83 | 84 | @staticmethod 85 | def erode_mask(mask): 86 | import cv2 as cv 87 | mask = mask.squeeze(0).numpy().astype(np.uint8) 88 | kernel_size = 3 89 | element = cv.getStructuringElement(cv.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1), (kernel_size, kernel_size)) 90 | mask = cv.erode(mask, element) 91 | return torch.from_numpy(mask).unsqueeze(0) 92 | 93 | def process_real_mask(self, path): 94 | resize = T.Resize(size=(self.image_size, self.image_size)) 95 | pad = T.Pad(padding=(self.real_pad, self.real_pad), fill=0) 96 | if self.erode: 97 | eroded_mask = self.erode_mask(read_image(str(path))) 98 | else: 99 | eroded_mask = read_image(str(path)) 100 | t_mask = resize(pad((eroded_mask > 0).float())) 101 | return t_mask.unsqueeze(0) 102 | 103 | def get_image_selections(self, shape_id): 104 | candidates = self.pair_meta[shape_id] 105 | if len(candidates) < self.views_per_sample: 106 | while len(candidates) < self.views_per_sample: 107 | meta = self.pair_meta[random.choice(list(self.pair_meta.keys()))] 108 | candidates.extend(meta[:self.views_per_sample - len(candidates)]) 109 | else: 110 | candidates = random.sample(candidates, self.views_per_sample) 111 | return candidates 112 | 113 | def process_real_image(self, path): 114 | resize = T.Resize(size=(self.image_size, self.image_size)) 115 | pad = T.Pad(padding=(self.real_pad, self.real_pad), fill=1) 116 | 
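        # The image is scaled to [-1, 1], padded on all sides by real_pad pixels of white (fill=1 in
        # the normalised range), then resized to image_size, presumably so the real exemplar's
        # framing roughly matches that of the rendered views.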
t_image = resize(pad(read_image(str(path)).float() / 127.5 - 1)) 117 | return t_image.unsqueeze(0) 118 | 119 | def load_pair_meta(self, pairmeta_path): 120 | loaded_json = json.loads(Path(pairmeta_path).read_text()) 121 | ret_dict = defaultdict(list) 122 | ret_views = [] 123 | for k in loaded_json.keys(): 124 | if self.meta_to_pair(loaded_json[k]) in self.real_images.keys(): 125 | ret_dict[loaded_json[k]['shape_id']].append(loaded_json[k]) 126 | ret_views.append(loaded_json[k]) 127 | return ret_dict, ret_views 128 | 129 | def preload_real_images(self): 130 | for ri in tqdm(self.real_images.keys(), desc='preload'): 131 | self.real_images_preloaded[ri] = self.process_real_image(self.real_images[ri]) 132 | self.masks_preloaded[ri] = self.process_real_mask(self.masks[ri]) 133 | 134 | @staticmethod 135 | def meta_to_pair(c): 136 | return f'shape{c["shape_id"]:05d}_rank{(c["rank"] - 1):02d}_pair{c["id"]}' 137 | 138 | def get_camera(self, fov, azimuth, elevation): 139 | y_angle = azimuth * 180 / math.pi 140 | x_angle = 90 - elevation * 180 / math.pi 141 | z_angle = 180 142 | camera_rot = np.eye(4) 143 | camera_rot[:3, :3] = Rotation.from_euler('x', x_angle, degrees=True).as_matrix() @ Rotation.from_euler('z', z_angle, degrees=True).as_matrix() @ Rotation.from_euler('y', y_angle, degrees=True).as_matrix() 144 | camera_translation = np.eye(4) 145 | camera_translation[:3, 3] = np.array([0, 0, 1.75]) 146 | camera_pose = camera_translation @ camera_rot 147 | translate = torch.tensor([[1.0, 0, 0, 32], [0, 1.0, 0, 32], [0, 0, 1.0, 32], [0, 0, 0, 1.0]]).float() 148 | scale = torch.tensor([[32.0, 0, 0, 0], [0, 32.0, 0, 0], [0, 0, 32.0, 0], [0, 0, 0, 1.0]]).float() 149 | world2grid = translate @ scale 150 | view_matrix = world2grid @ torch.linalg.inv(torch.from_numpy(camera_pose).float()) 151 | camera_intrinsics = torch.zeros((4,)) 152 | f = self.image_size / (2 * np.tan(fov * np.pi / 180 / 2.0)) 153 | camera_intrinsics[0] = f 154 | camera_intrinsics[1] = f 155 | camera_intrinsics[2] = self.image_size / 2 156 | camera_intrinsics[3] = self.image_size / 2 157 | return view_matrix, camera_intrinsics 158 | -------------------------------------------------------------------------------- /dataset/mesh_real_volume.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import torch 7 | import trimesh 8 | from torchvision.io import read_image 9 | from tqdm import tqdm 10 | import json 11 | import numpy as np 12 | 13 | from util.camera import spherical_coord_to_cam 14 | from util.misc import EasyDict 15 | import torchvision.transforms as T 16 | 17 | 18 | class FaceGraphMeshDataset(torch.utils.data.Dataset): 19 | 20 | def __init__(self, config, limit_dataset_size=None): 21 | self.dataset_directory = Path(config.dataset_path) 22 | self.mesh_directory = Path(config.mesh_path) 23 | self.df_directory = Path(config.df_path) 24 | self.image_size = config.image_size 25 | self.real_images = {x.name.split('.')[0]: x for x in Path(config.image_path).iterdir() if x.name.endswith('.jpg') or x.name.endswith('.png')} 26 | self.items = list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size] 27 | self.target_name = "model_normalized.obj" 28 | self.views_per_sample = config.views_per_sample 29 | self.mean = config.df_mean 30 | self.std = config.df_std 31 | self.pair_meta, self.all_views = self.load_pair_meta(config.pairmeta_path) 32 | self.real_images_preloaded = {} 33 | if 
config.preload: 34 | self.preload_real_images() 35 | 36 | def __len__(self): 37 | return len(self.items) 38 | 39 | def __getitem__(self, idx): 40 | selected_item = self.items[idx] 41 | pt_arxiv = torch.load(os.path.join(self.dataset_directory, f'{selected_item}.pt')) 42 | edge_index = pt_arxiv['conv_data'][0][0].long() 43 | num_sub_vertices = [pt_arxiv['conv_data'][i][0].shape[0] for i in range(1, len(pt_arxiv['conv_data']))] 44 | pad_sizes = [pt_arxiv['conv_data'][i][2].shape[0] for i in range(len(pt_arxiv['conv_data']))] 45 | sub_edges = [pt_arxiv['conv_data'][i][0].long() for i in range(1, len(pt_arxiv['conv_data']))] 46 | pool_maps = pt_arxiv['pool_locations'] 47 | is_pad = [pt_arxiv['conv_data'][i][4].bool() for i in range(len(pt_arxiv['conv_data']))] 48 | level_masks = [torch.zeros(pt_arxiv['conv_data'][i][0].shape[0]).long() for i in range(len(pt_arxiv['conv_data']))] 49 | 50 | # noinspection PyTypeChecker 51 | mesh = trimesh.load(self.mesh_directory / selected_item / self.target_name, process=False) 52 | vertices = torch.from_numpy(mesh.vertices).float() 53 | indices = torch.from_numpy(mesh.faces).int() 54 | tri_indices = torch.cat([indices[:, [0, 1, 2]], indices[:, [0, 2, 3]]], 0) 55 | vctr = torch.tensor(list(range(vertices.shape[0]))).long() 56 | df = torch.from_numpy(np.load(str(self.df_directory / selected_item / "032.npy"))).float().unsqueeze(0) 57 | df = (df - self.mean) / self.std 58 | real_sample, mvp = self.get_image_and_view(selected_item) 59 | 60 | return { 61 | "name": selected_item, 62 | "y": pt_arxiv['target_colors'].float() * 2, 63 | "df": df, 64 | "vertex_ctr": vctr, 65 | "vertices": vertices, 66 | "indices_quad": indices, 67 | "mvp": mvp, 68 | "real": real_sample, 69 | "indices": tri_indices, 70 | "ranges": torch.tensor([0, tri_indices.shape[0]]).int(), 71 | "graph_data": self.get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks) 72 | } 73 | 74 | @staticmethod 75 | def get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks): 76 | return EasyDict({ 77 | 'face_neighborhood': edge_index, 78 | 'sub_neighborhoods': sub_edges, 79 | 'pads': pad_sizes, 80 | 'node_counts': num_sub_vertices, 81 | 'pool_maps': pool_maps, 82 | 'is_pad': is_pad, 83 | 'level_masks': level_masks 84 | }) 85 | 86 | def visualize_graph_with_predictions(self, name, prediction, output_dir, output_suffix): 87 | output_dir = Path(output_dir) 88 | output_dir.mkdir(exist_ok=True, parents=True) 89 | # noinspection PyTypeChecker 90 | mesh = trimesh.load(Path(self.raw_dir, name) / self.target_name, force='mesh', process=False) 91 | mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, face_colors=prediction + 0.5, process=False) 92 | mesh.export(output_dir / f"{name}_{output_suffix}.obj") 93 | 94 | @staticmethod 95 | def batch_mask(t, graph_data, idx, level=0): 96 | return t[graph_data['level_masks'][level] == idx] 97 | 98 | def get_image_and_view(self, shape): 99 | shape_id = int(shape.split('_')[0].split('shape')[1]) 100 | image_selections = self.get_image_selections(shape_id) 101 | images, cameras = [], [] 102 | for c_i in image_selections: 103 | images.append(self.get_real_image(self.meta_to_pair(c_i))) 104 | perspective_cam = spherical_coord_to_cam(c_i['fov'], c_i['azimuth'], c_i['elevation']) 105 | # projection_matrix = intrinsic_to_projection(get_default_perspective_cam()).float() 106 | projection_matrix = torch.from_numpy(perspective_cam.projection_mat()).float() 107 | # view_matrix = 
torch.from_numpy(np.linalg.inv(generate_camera(np.zeros(3), c['azimuth'], c['elevation']))).float() 108 | view_matrix = torch.from_numpy(perspective_cam.view_mat()).float() 109 | cameras.append(torch.matmul(projection_matrix, view_matrix)) 110 | image = torch.cat(images, dim=0) 111 | mvp = torch.stack(cameras, dim=0) 112 | return image, mvp 113 | 114 | def get_real_image(self, name): 115 | if name not in self.real_images_preloaded.keys(): 116 | resize = T.Resize(size=(self.image_size, self.image_size)) 117 | pad = T.Pad(padding=(100, 100), fill=1) 118 | image = read_image(str(self.real_images[name])).float() / 127.5 - 1 119 | return resize(pad(image)).unsqueeze(0) 120 | else: 121 | return self.real_images_preloaded[name] 122 | 123 | def get_image_selections(self, shape_id): 124 | candidates = self.pair_meta[shape_id] 125 | if len(candidates) < self.views_per_sample: 126 | while len(candidates) < self.views_per_sample: 127 | meta = self.pair_meta[random.choice(list(self.pair_meta.keys()))] 128 | candidates.extend(meta[:self.views_per_sample - len(candidates)]) 129 | else: 130 | candidates = random.sample(candidates, self.views_per_sample) 131 | return candidates 132 | 133 | def process_real_image(self, path): 134 | resize = T.Resize(size=(self.image_size, self.image_size)) 135 | pad = T.Pad(padding=(100, 100), fill=1) 136 | return resize(pad(read_image(str(path)).float() / 127.5 - 1)).unsqueeze(0) 137 | 138 | def load_pair_meta(self, pairmeta_path): 139 | loaded_json = json.loads(Path(pairmeta_path).read_text()) 140 | ret_dict = defaultdict(list) 141 | ret_views = [] 142 | for k in loaded_json.keys(): 143 | if self.meta_to_pair(loaded_json[k]) in self.real_images.keys(): 144 | ret_dict[loaded_json[k]['shape_id']].append(loaded_json[k]) 145 | ret_views.append(loaded_json[k]) 146 | return ret_dict, ret_views 147 | 148 | def preload_real_images(self): 149 | for ri in tqdm(self.real_images.keys(), desc='preload'): 150 | self.real_images_preloaded[ri] = self.process_real_image(self.real_images[ri]) 151 | 152 | @staticmethod 153 | def meta_to_pair(c): 154 | return f'shape{c["shape_id"]:05d}_rank{(c["rank"] - 1):02d}_pair{c["id"]}' 155 | -------------------------------------------------------------------------------- /dataset/mesh_real.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import cv2 5 | from collections import defaultdict 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torchvision.transforms as T 11 | import trimesh 12 | from torchvision.io import read_image 13 | from tqdm import tqdm 14 | 15 | from util.camera import spherical_coord_to_cam 16 | from util.misc import EasyDict 17 | 18 | 19 | class FaceGraphMeshDataset(torch.utils.data.Dataset): 20 | 21 | def __init__(self, config, limit_dataset_size=None): 22 | self.dataset_directory = Path(config.dataset_path) 23 | self.mesh_directory = Path(config.mesh_path) 24 | self.image_size = config.image_size 25 | self.real_images = {x.name.split('.')[0]: x for x in Path(config.image_path).iterdir() if x.name.endswith('.jpg') or x.name.endswith('.png')} 26 | self.items = list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size] 27 | self.condition_directory = None 28 | if config.condition_path is not None: 29 | self.condition_directory = Path(config.condition_path) 30 | self.target_name = "model_normalized.obj" 31 | self.views_per_sample = config.views_per_sample 32 | self.pair_meta, self.all_views = 
self.load_pair_meta(config.pairmeta_path) 33 | self.real_images_preloaded = {} 34 | if config.preload: 35 | self.preload_real_images() 36 | 37 | def __len__(self): 38 | return len(self.items) 39 | 40 | def __getitem__(self, idx): 41 | selected_item = self.items[idx] 42 | pt_arxiv = torch.load(os.path.join(self.dataset_directory, f'{selected_item}.pt')) 43 | edge_index = pt_arxiv['conv_data'][0][0].long() 44 | num_sub_vertices = [pt_arxiv['conv_data'][i][0].shape[0] for i in range(1, len(pt_arxiv['conv_data']))] 45 | pad_sizes = [pt_arxiv['conv_data'][i][2].shape[0] for i in range(len(pt_arxiv['conv_data']))] 46 | sub_edges = [pt_arxiv['conv_data'][i][0].long() for i in range(1, len(pt_arxiv['conv_data']))] 47 | pool_maps = pt_arxiv['pool_locations'] 48 | is_pad = [pt_arxiv['conv_data'][i][4].bool() for i in range(len(pt_arxiv['conv_data']))] 49 | level_masks = [torch.zeros(pt_arxiv['conv_data'][i][0].shape[0]).long() for i in range(len(pt_arxiv['conv_data']))] 50 | 51 | # noinspection PyTypeChecker 52 | mesh = trimesh.load(self.mesh_directory / selected_item / self.target_name, process=False) 53 | vertices = torch.from_numpy(mesh.vertices).float() 54 | indices = torch.from_numpy(mesh.faces).int() 55 | tri_indices = torch.cat([indices[:, [0, 1, 2]], indices[:, [0, 2, 3]]], 0) 56 | vctr = torch.tensor(list(range(vertices.shape[0]))).long() 57 | 58 | real_sample, mvp = self.get_image_and_view(selected_item) 59 | condition = None 60 | if self.condition_directory is not None: 61 | condition = np.load(str(self.condition_directory / f"{selected_item}.npy")) 62 | condition = torch.from_numpy(condition).float() 63 | 64 | return { 65 | "name": selected_item, 66 | "y": pt_arxiv['target_colors'].float() * 2, 67 | "condition": condition, 68 | "vertex_ctr": vctr, 69 | "vertices": vertices, 70 | "indices_quad": indices, 71 | "mvp": mvp, 72 | "real": real_sample, 73 | "indices": tri_indices, 74 | "ranges": torch.tensor([0, tri_indices.shape[0]]).int(), 75 | "graph_data": self.get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks) 76 | } 77 | 78 | @staticmethod 79 | def get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks): 80 | return EasyDict({ 81 | 'face_neighborhood': edge_index, 82 | 'sub_neighborhoods': sub_edges, 83 | 'pads': pad_sizes, 84 | 'node_counts': num_sub_vertices, 85 | 'pool_maps': pool_maps, 86 | 'is_pad': is_pad, 87 | 'level_masks': level_masks 88 | }) 89 | 90 | def visualize_graph_with_predictions(self, name, prediction, output_dir, output_suffix): 91 | output_dir = Path(output_dir) 92 | output_dir.mkdir(exist_ok=True, parents=True) 93 | # noinspection PyTypeChecker 94 | mesh = trimesh.load(Path(self.raw_dir, name) / self.target_name, force='mesh', process=False) 95 | mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, face_colors=prediction + 0.5, process=False) 96 | mesh.export(output_dir / f"{name}_{output_suffix}.obj") 97 | 98 | @staticmethod 99 | def batch_mask(t, graph_data, idx, level=0): 100 | return t[graph_data['level_masks'][level] == idx] 101 | 102 | def get_image_and_view(self, shape): 103 | shape_id = int(shape.split('_')[0].split('shape')[1]) 104 | image_selections = self.get_image_selections(shape_id) 105 | view_selections = random.sample(self.all_views, self.views_per_sample) 106 | images, cameras = [], [] 107 | for c_i, c_v in zip(image_selections, view_selections): 108 | images.append(self.get_real_image(self.meta_to_pair(c_i))) 109 | perspective_cam = 
spherical_coord_to_cam(c_v['fov'], c_v['azimuth'], c_v['elevation']) 110 | # projection_matrix = intrinsic_to_projection(get_default_perspective_cam()).float() 111 | projection_matrix = torch.from_numpy(perspective_cam.projection_mat()).float() 112 | # view_matrix = torch.from_numpy(np.linalg.inv(generate_camera(np.zeros(3), c['azimuth'], c['elevation']))).float() 113 | view_matrix = torch.from_numpy(perspective_cam.view_mat()).float() 114 | cameras.append(torch.matmul(projection_matrix, view_matrix)) 115 | image = torch.cat(images, dim=0) 116 | mvp = torch.stack(cameras, dim=0) 117 | return image, mvp 118 | 119 | def get_real_image(self, name): 120 | if name not in self.real_images_preloaded.keys(): 121 | return self.process_real_image(self.real_images[name]) 122 | else: 123 | return self.real_images_preloaded[name] 124 | 125 | def get_image_selections(self, shape_id): 126 | candidates = self.pair_meta[shape_id] 127 | if len(candidates) < self.views_per_sample: 128 | while len(candidates) < self.views_per_sample: 129 | meta = self.pair_meta[random.choice(list(self.pair_meta.keys()))] 130 | candidates.extend(meta[:self.views_per_sample - len(candidates)]) 131 | else: 132 | candidates = random.sample(candidates, self.views_per_sample) 133 | return candidates 134 | 135 | def process_real_image(self, path): 136 | resize = T.Resize(size=(self.image_size, self.image_size)) 137 | pad = T.Pad(padding=(100, 100), fill=1) 138 | t_image = resize(pad(read_image(str(path)).float() / 127.5 - 1)) 139 | return t_image.unsqueeze(0) 140 | 141 | def load_pair_meta(self, pairmeta_path): 142 | loaded_json = json.loads(Path(pairmeta_path).read_text()) 143 | ret_dict = defaultdict(list) 144 | ret_views = [] 145 | for k in loaded_json.keys(): 146 | if self.meta_to_pair(loaded_json[k]) in self.real_images.keys(): 147 | ret_dict[loaded_json[k]['shape_id']].append(loaded_json[k]) 148 | ret_views.append(loaded_json[k]) 149 | return ret_dict, ret_views 150 | 151 | def preload_real_images(self): 152 | for ri in tqdm(self.real_images.keys(), desc='preload'): 153 | self.real_images_preloaded[ri] = self.process_real_image(self.real_images[ri]) 154 | 155 | @staticmethod 156 | def meta_to_pair(c): 157 | return f'shape{c["shape_id"]:05d}_rank{(c["rank"] - 1):02d}_pair{c["id"]}' 158 | -------------------------------------------------------------------------------- /dataset/mesh_cube.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import trimesh 9 | from torch.utils.data.dataloader import default_collate 10 | from scipy.spatial.transform import Rotation 11 | 12 | from dataset.mesh_uniform import get_default_perspective_cam 13 | from model.differentiable_renderer import intrinsic_to_projection, transform_pos_mvp 14 | from util.misc import EasyDict 15 | 16 | 17 | def generate_random_camera(loc): 18 | x_angles, y_angles = list(range(90, 270, 5)), list(range(0, 360, 5)) 19 | weight_fn = lambda x: 0.5 + 0.125 * math.cos(2 * (math.pi / 45) * x) 20 | weights_x, weights_y = [weight_fn(x) for x in x_angles], [weight_fn(y) for y in y_angles] 21 | x_angle, y_angle = random.choices(x_angles, weights_x, k=1)[0], random.choices(y_angles, weights_y, k=1)[0] 22 | camera_pose = np.eye(4) 23 | camera_pose[:3, :3] = Rotation.from_euler('y', y_angle, degrees=True).as_matrix() @ Rotation.from_euler('z', 180, degrees=True).as_matrix() @ Rotation.from_euler('x', x_angle, 
degrees=True).as_matrix() 24 | # camera_translation = camera_pose[:3, :3] @ np.array([0, 0, 1.025]) + loc 25 | camera_translation = camera_pose[:3, :3] @ np.array([0, 0, 1.925]) + loc 26 | camera_pose[:3, 3] = camera_translation 27 | return camera_pose 28 | 29 | 30 | def generate_fixed_camera(loc): 31 | camera_pose = np.eye(4) 32 | camera_pose[:3, :3] = Rotation.from_euler('y', 0, degrees=True).as_matrix() @ Rotation.from_euler('z', 180, degrees=True).as_matrix() @ Rotation.from_euler('x', 0, degrees=True).as_matrix() 33 | # camera_translation = camera_pose[:3, :3] @ np.array([0, 0, 1.025]) + loc 34 | camera_translation = camera_pose[:3, :3] @ np.array([0, 0, 1.925]) + loc 35 | camera_pose[:3, 3] = camera_translation 36 | return camera_pose 37 | 38 | 39 | class FaceGraphMeshDataset(torch.utils.data.Dataset): 40 | 41 | def __init__(self, config, limit_dataset_size=None): 42 | self.dataset_directory = Path(config.dataset_path) 43 | self.mesh_directory = Path(config.mesh_path) 44 | self.items = list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size] 45 | self.target_name = "model_normalized.obj" 46 | self.mask = lambda x, bs: torch.ones((x.shape[0],)).float().to(x.device) 47 | self.indices_src, self.indices_dest_i, self.indices_dest_j, self.faces_to_uv = [], [], [], None 48 | self.mesh_resolution = config.mesh_resolution 49 | self.setup_cube_texture_fast_visualization_buffers() 50 | self.projection_matrix = intrinsic_to_projection(get_default_perspective_cam()).float() 51 | self.plane = "Plane" in self.dataset_directory.name 52 | self.generate_camera = generate_fixed_camera if self.plane else generate_random_camera 53 | self.views_per_sample = 1 if self.plane else config.views_per_sample 54 | print("Plane Rendering: ", self.plane) 55 | 56 | def __len__(self): 57 | return len(self.items) 58 | 59 | def __getitem__(self, idx): 60 | selected_item = self.items[idx] 61 | pt_arxiv = torch.load(os.path.join(self.dataset_directory, f'{selected_item}.pt')) 62 | edge_index = pt_arxiv['conv_data'][0][0].long() 63 | num_sub_vertices = [pt_arxiv['conv_data'][i][0].shape[0] for i in range(1, len(pt_arxiv['conv_data']))] 64 | pad_sizes = [pt_arxiv['conv_data'][i][2].shape[0] for i in range(len(pt_arxiv['conv_data']))] 65 | sub_edges = [pt_arxiv['conv_data'][i][0].long() for i in range(1, len(pt_arxiv['conv_data']))] 66 | pool_maps = pt_arxiv['pool_locations'] 67 | is_pad = [pt_arxiv['conv_data'][i][4].bool() for i in range(len(pt_arxiv['conv_data']))] 68 | level_masks = [torch.zeros(pt_arxiv['conv_data'][i][0].shape[0]).long() for i in range(len(pt_arxiv['conv_data']))] 69 | 70 | # noinspection PyTypeChecker 71 | mesh = trimesh.load(self.mesh_directory / '_'.join(selected_item.split('_')[:-2]) / self.target_name, process=False) 72 | mvp = torch.stack([torch.matmul(self.projection_matrix, torch.from_numpy(np.linalg.inv(self.generate_camera((mesh.bounds[0] + mesh.bounds[1]) / 2))).float()) 73 | for _ in range(self.views_per_sample)], dim=0) 74 | vertices = torch.from_numpy(mesh.vertices).float() 75 | indices = torch.from_numpy(mesh.faces).int() 76 | tri_indices = torch.cat([indices[:, [0, 1, 2]], indices[:, [0, 2, 3]]], 0) 77 | vctr = torch.tensor(list(range(vertices.shape[0]))).long() 78 | return { 79 | "name": selected_item, 80 | "y": pt_arxiv['target_colors'].float() * 2, 81 | "vertex_ctr": vctr, 82 | "vertices": vertices, 83 | "indices_quad": indices, 84 | "mvp": mvp, 85 | "indices": tri_indices, 86 | "ranges": torch.tensor([0, tri_indices.shape[0]]).int(), 87 | "graph_data": 
self.get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks) 88 | } 89 | 90 | @staticmethod 91 | def get_item_as_graphdata(edge_index, sub_edges, pad_sizes, num_sub_vertices, pool_maps, is_pad, level_masks): 92 | return EasyDict({ 93 | 'face_neighborhood': edge_index, 94 | 'sub_neighborhoods': sub_edges, 95 | 'pads': pad_sizes, 96 | 'node_counts': num_sub_vertices, 97 | 'pool_maps': pool_maps, 98 | 'is_pad': is_pad, 99 | 'level_masks': level_masks 100 | }) 101 | 102 | def visualize_graph_with_predictions(self, name, prediction, output_dir, output_suffix): 103 | output_dir = Path(output_dir) 104 | output_dir.mkdir(exist_ok=True, parents=True) 105 | # noinspection PyTypeChecker 106 | mesh = trimesh.load(Path(self.raw_dir, "_".join(name.split('_')[:-2])) / self.target_name, force='mesh', process=False) 107 | mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, face_colors=prediction + 0.5, process=False) 108 | mesh.export(output_dir / f"{name}_{output_suffix}.obj") 109 | 110 | def to_image(self, face_colors, level_mask): 111 | batch_size = level_mask.max() + 1 112 | image = torch.zeros((batch_size, 3, self.mesh_resolution, self.mesh_resolution), device=face_colors.device) 113 | indices_dest_i = torch.tensor(self.indices_dest_i * batch_size, device=face_colors.device).long() 114 | indices_dest_j = torch.tensor(self.indices_dest_j * batch_size, device=face_colors.device).long() 115 | indices_src = torch.tensor(self.indices_src * batch_size, device=face_colors.device).long() 116 | image[level_mask, :, indices_dest_i, indices_dest_j] = face_colors[indices_src + level_mask * len(self.indices_src), :] 117 | return image 118 | 119 | @staticmethod 120 | def batch_mask(t, graph_data, idx, level=0): 121 | return t[graph_data['level_masks'][level] == idx] 122 | 123 | def setup_cube_texture_fast_visualization_buffers(self): 124 | # noinspection PyTypeChecker 125 | mesh = trimesh.load(self.mesh_directory / "coloredbrodatz_D48_COLORED" / self.target_name, process=False) 126 | vertex_to_uv = np.array(mesh.visual.uv) 127 | faces_to_vertices = np.array(mesh.faces) 128 | a = vertex_to_uv[faces_to_vertices[:, 0], :] 129 | b = vertex_to_uv[faces_to_vertices[:, 1], :] 130 | c = vertex_to_uv[faces_to_vertices[:, 2], :] 131 | d = vertex_to_uv[faces_to_vertices[:, 3], :] 132 | self.faces_to_uv = (a + b + c + d) / 4 133 | for v_idx in range(self.faces_to_uv.shape[0]): 134 | j = int(round(self.faces_to_uv[v_idx][0] * (self.mesh_resolution - 1))) 135 | i = (self.mesh_resolution - 1) - int(round(self.faces_to_uv[v_idx][1] * (self.mesh_resolution - 1))) 136 | self.indices_dest_i.append(i) 137 | self.indices_dest_j.append(j) 138 | self.indices_src.append(v_idx) 139 | -------------------------------------------------------------------------------- /data_processing/create_uv_charts.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from pathlib import Path 3 | import scipy 4 | import numpy as np 5 | from torchvision.utils import save_image 6 | from tqdm import tqdm 7 | import torch 8 | import trimesh 9 | from dataset import GraphDataLoader, to_device, to_vertex_colors_scatter 10 | from dataset.mesh_real_features_uv import split_into_six 11 | from model.differentiable_renderer import DifferentiableRenderer 12 | from PIL import Image 13 | 14 | 15 | @hydra.main(config_path='../config', config_name='stylegan2') 16 | def create_silhouttes(config): 17 | from dataset.mesh_real_features_atlas import 
FaceGraphMeshDataset 18 | config.image_size = 512 19 | dataset = FaceGraphMeshDataset(config) 20 | dataloader = GraphDataLoader(dataset, batch_size=1, num_workers=0) 21 | render_helper = DifferentiableRenderer(config.image_size, 'bounds', config.colorspace, num_channels=4).cuda() 22 | Path("runs/uv_mask").mkdir(exist_ok=True) 23 | Path("runs/uv_positions").mkdir(exist_ok=True) 24 | Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_normals").mkdir(exist_ok=True) 25 | for batch_idx, batch in enumerate(tqdm(dataloader)): 26 | batch = to_device(batch, torch.device("cuda:0")) 27 | rendered_color_gt = render_helper.render(batch['vertices'], batch['indices'], 28 | to_vertex_colors_scatter(batch["x"][:, :3], batch), 29 | batch["ranges"].cpu(), batch['bg']).permute((0, 3, 1, 2)) 30 | rendered_color_gt_texture = render_helper.render(batch['vertices'], batch['indices'], 31 | to_vertex_colors_scatter(batch["y"][:, :3], batch), 32 | batch["ranges"].cpu(), batch['bg']).permute((0, 3, 1, 2))[:, :3, :, :] 33 | rendered_color_gt_texture = ((rendered_color_gt_texture * 0.5 + 0.5) * 255).int() 34 | rendered_color_gt_pos = rendered_color_gt[:, :3, :, :] 35 | rendered_color_gt_mask = ((1 - rendered_color_gt[:, 3, :, :]) * 255).int() 36 | 37 | row_0 = torch.cat([rendered_color_gt_mask[0, :, :], rendered_color_gt_mask[1, :, :], rendered_color_gt_mask[2, :, :]], dim=-1) 38 | row_1 = torch.cat([rendered_color_gt_mask[3, :, :], rendered_color_gt_mask[4, :, :], rendered_color_gt_mask[5, :, :]], dim=-1) 39 | mask = Image.fromarray(torch.cat([row_0, row_1], dim=-2).cpu().numpy().astype(np.uint8)) 40 | 41 | row_0 = torch.cat([rendered_color_gt_pos[0, :, :, :], rendered_color_gt_pos[1, :, :, :], rendered_color_gt_pos[2, :, :, :]], dim=-1) 42 | row_1 = torch.cat([rendered_color_gt_pos[3, :, :, :], rendered_color_gt_pos[4, :, :, :], rendered_color_gt_pos[5, :, :, :]], dim=-1) 43 | positions = torch.cat([row_0, row_1], dim=-2).permute((1, 2, 0)).cpu().numpy() 44 | 45 | row_0 = torch.cat([rendered_color_gt_texture[0, :, :, :], rendered_color_gt_texture[1, :, :, :], rendered_color_gt_texture[2, :, :, :]], dim=-1) 46 | row_1 = torch.cat([rendered_color_gt_texture[3, :, :, :], rendered_color_gt_texture[4, :, :, :], rendered_color_gt_texture[5, :, :, :]], dim=-1) 47 | colors = Image.fromarray(torch.cat([row_0, row_1], dim=-2).permute((1, 2, 0)).cpu().numpy().astype(np.uint8)) 48 | 49 | mask.save(str(Path("runs/uv_mask") / f"{batch['name'][0]}.jpg")) 50 | colors.save(str(Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_normals") / f"{batch['name'][0]}.jpg")) 51 | np.savez_compressed(str(Path("runs/uv_positions") / f"{batch['name'][0]}.npz"), positions) 52 | 53 | 54 | def create_uv_mapping(proc, num_proc): 55 | mesh_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/shapenet-chairs-manifold-highres") 56 | map_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_positions") 57 | mask_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_mask") 58 | output_path = Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_map") 59 | output_path.mkdir(exist_ok=True) 60 | meshes = list(mesh_path.iterdir()) 61 | meshes = [x for i, x in enumerate(meshes) if i % num_proc == proc] 62 | for mesh in tqdm(meshes): 63 | if not (map_path / f'{mesh.stem}.npz').exists(): 64 | continue 65 | tmesh = trimesh.load(mesh / "model_normalized.obj", process=False) 66 | uv_map = split_into_six(np.load(str(map_path / f'{mesh.stem}.npz'))['arr']) 67 | uv_mask = split_into_six(np.array(Image.open(mask_path / 
f'{mesh.stem}.jpg'))[:, :, np.newaxis]) 68 | 69 | selected_mapping = np.zeros([tmesh.vertices.shape[0], 4]) 70 | selected_mapping[:, 3] = float('inf') 71 | 72 | all_normals = np.array([[0., 1., 0.], [0., -1., 0.], [0., 0., -1.], [1., 0., 0.], [0., 0., 1.], [-1., 0., 0.]], dtype=np.float32) 73 | 74 | for mp_idx in range(uv_map.shape[0]): 75 | all_pos = uv_map[mp_idx].reshape((-1, 3)) 76 | max_i = uv_map[mp_idx].shape[0] 77 | max_j = uv_map[mp_idx].shape[1] 78 | pixel_coordinates_i, pixel_coordinates_j = np.meshgrid(list(range(max_i)), list(range(max_j)), indexing='ij') 79 | pixel_coordinates_i = pixel_coordinates_i / max_i 80 | pixel_coordinates_j = pixel_coordinates_j / max_j 81 | pixel_coordinates = np.stack([pixel_coordinates_i, pixel_coordinates_j], axis=-1).reshape((-1, 2)) 82 | valid_pos = uv_mask[mp_idx].flatten() == 255 83 | pixel_coordinates = pixel_coordinates[valid_pos] 84 | kdtree = scipy.spatial.cKDTree(all_pos[valid_pos]) 85 | dist, indices = kdtree.query(np.array(tmesh.vertices), k=1) 86 | kdtree_normals = scipy.spatial.cKDTree(all_normals) 87 | dist_norm, indices_norm = kdtree_normals.query(np.array(tmesh.vertex_normals), k=1) 88 | dmask = indices_norm == mp_idx 89 | selected_mapping[dmask, 1:3] = pixel_coordinates[indices, :][dmask, :] 90 | selected_mapping[dmask, 0] = mp_idx 91 | selected_mapping[dmask, 3] = dist[dmask] 92 | 93 | np.save(str(output_path / f'{mesh.stem}.npy'), selected_mapping[:, 0:3]) 94 | print(mesh) 95 | 96 | 97 | @hydra.main(config_path='../config', config_name='stylegan2') 98 | def render_with_uv(config): 99 | from dataset.mesh_real_features_uv import FaceGraphMeshDataset 100 | config.image_size = 128 101 | config.batch_size = 1 102 | dataset = FaceGraphMeshDataset(config) 103 | dataloader = GraphDataLoader(dataset, batch_size=config.batch_size, num_workers=0) 104 | render_helper = DifferentiableRenderer(config.image_size, 'bounds', config.colorspace).cuda() 105 | for batch_idx, batch in enumerate(tqdm(dataloader)): 106 | batch = to_device(batch, torch.device("cuda:0")) 107 | texture_map = [] 108 | for name in batch['name']: 109 | texture_map.append(torch.from_numpy(split_into_six(np.array(Image.open(Path("/cluster/gimli/ysiddiqui/CADTextures/Photoshape/uv_normals") / f'{name}.jpg')))).permute(0, 3, 1, 2)) 110 | texture_map = torch.nn.functional.interpolate(torch.cat(texture_map, dim=0).to(batch['vertices'].device).float() / 255 * 2 - 1, size=(config.image_size, config.image_size), mode='bilinear', align_corners=True) 111 | vertices_mapped = texture_map[batch["uv"][:, 0].long(), :, (batch["uv"][:, 1] * config.image_size).long(), (batch["uv"][:, 2] * config.image_size).long()] 112 | rendered_texture = render_helper.render(batch['vertices'], batch['indices'], vertices_mapped, batch["ranges"].cpu(), None).permute((0, 3, 1, 2)) 113 | save_image(rendered_texture, f"runs/images/{batch_idx:04d}.png", nrow=6, value_range=(-1, 1), normalize=True) 114 | 115 | 116 | if __name__ == '__main__': 117 | import argparse 118 | 119 | # parser = argparse.ArgumentParser() 120 | # parser.add_argument('-n', '--num_proc', default=1, type=int) 121 | # parser.add_argument('-p', '--proc', default=0, type=int) 122 | # args = parser.parse_args() 123 | 124 | # create_silhouttes() 125 | # create_uv_mapping(args.proc, args.num_proc) 126 | render_with_uv() 127 | -------------------------------------------------------------------------------- /data_processing/create_uv_charts_car.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 
from pathlib import Path 3 | import scipy 4 | import numpy as np 5 | from torchvision.utils import save_image 6 | from tqdm import tqdm 7 | import torch 8 | import trimesh 9 | from dataset import GraphDataLoader, to_device, to_vertex_colors_scatter 10 | from dataset.mesh_real_features_uv import split_into_six 11 | from model.differentiable_renderer import DifferentiableRenderer 12 | from PIL import Image 13 | 14 | 15 | @hydra.main(config_path='../config', config_name='stylegan2_car') 16 | def create_silhouttes(config): 17 | from dataset.meshcar_real_features_atlas import FaceGraphMeshDataset 18 | config.image_size = 512 19 | dataset = FaceGraphMeshDataset(config) 20 | dataloader = GraphDataLoader(dataset, batch_size=1, num_workers=0) 21 | render_helper = DifferentiableRenderer(config.image_size, 'bounds', config.colorspace, num_channels=4).cuda() 22 | Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_mask").mkdir(exist_ok=True) 23 | Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_positions").mkdir(exist_ok=True) 24 | Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_normals").mkdir(exist_ok=True) 25 | for batch_idx, batch in enumerate(tqdm(dataloader)): 26 | batch = to_device(batch, torch.device("cuda:0")) 27 | rendered_color_gt = render_helper.render(batch['vertices'], batch['indices'], 28 | to_vertex_colors_scatter(batch["x"][:, :3], batch), 29 | batch["ranges"].cpu(), batch['bg']).permute((0, 3, 1, 2)) 30 | rendered_color_gt_texture = render_helper.render(batch['vertices'], batch['indices'], 31 | to_vertex_colors_scatter(batch["y"][:, :3], batch), 32 | batch["ranges"].cpu(), batch['bg']).permute((0, 3, 1, 2))[:, :3, :, :] 33 | 34 | rendered_color_gt_texture = ((rendered_color_gt_texture * 0.5 + 0.5) * 255).int() 35 | rendered_color_gt_pos = rendered_color_gt[:, :3, :, :] 36 | rendered_color_gt_mask = ((1 - rendered_color_gt[:, 3, :, :]) * 255).int() 37 | 38 | row_0 = torch.cat([rendered_color_gt_mask[0, :, :], rendered_color_gt_mask[1, :, :], rendered_color_gt_mask[2, :, :]], dim=-1) 39 | row_1 = torch.cat([rendered_color_gt_mask[3, :, :], rendered_color_gt_mask[4, :, :], rendered_color_gt_mask[5, :, :]], dim=-1) 40 | mask = Image.fromarray(torch.cat([row_0, row_1], dim=-2).cpu().numpy().astype(np.uint8)) 41 | 42 | row_0 = torch.cat([rendered_color_gt_pos[0, :, :, :], rendered_color_gt_pos[1, :, :, :], rendered_color_gt_pos[2, :, :, :]], dim=-1) 43 | row_1 = torch.cat([rendered_color_gt_pos[3, :, :, :], rendered_color_gt_pos[4, :, :, :], rendered_color_gt_pos[5, :, :, :]], dim=-1) 44 | positions = torch.cat([row_0, row_1], dim=-2).permute((1, 2, 0)).cpu().numpy() 45 | 46 | row_0 = torch.cat([rendered_color_gt_texture[0, :, :, :], rendered_color_gt_texture[1, :, :, :], rendered_color_gt_texture[2, :, :, :]], dim=-1) 47 | row_1 = torch.cat([rendered_color_gt_texture[3, :, :, :], rendered_color_gt_texture[4, :, :, :], rendered_color_gt_texture[5, :, :, :]], dim=-1) 48 | colors = Image.fromarray(torch.cat([row_0, row_1], dim=-2).permute((1, 2, 0)).cpu().numpy().astype(np.uint8)) 49 | 50 | mask.save(str(Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_mask") / f"{batch['name'][0]}.jpg")) 51 | colors.save(str(Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_normals") / f"{batch['name'][0]}.jpg")) 52 | np.savez_compressed(str(Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_positions") / f"{batch['name'][0]}.npz"), positions) 53 | 54 | 55 | def create_uv_mapping(proc, num_proc): 56 | mesh_path = 
Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/manifold_combined") 57 | map_path = Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_positions") 58 | mask_path = Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_mask") 59 | output_path = Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_map_first") 60 | output_path.mkdir(exist_ok=True) 61 | meshes = list(mesh_path.iterdir()) 62 | meshes = [x for i, x in enumerate(meshes) if i % num_proc == proc] 63 | for mesh in tqdm(meshes): 64 | if not (map_path / f'{mesh.stem}.npz').exists(): 65 | continue 66 | tmesh = trimesh.load(mesh / "model_normalized.obj", process=False) 67 | uv_map = split_into_six(np.load(str(map_path / f'{mesh.stem}.npz'))['arr_0']) 68 | uv_mask = split_into_six(np.array(Image.open(mask_path / f'{mesh.stem}.jpg'))[:, :, np.newaxis]) 69 | 70 | selected_mapping = np.zeros([tmesh.vertices.shape[0], 4]) 71 | selected_mapping[:, 3] = float('inf') 72 | 73 | all_normals = np.array([[0., 1., 0.], [0., -1., 0.], [0., 0., -1.], [1., 0., 0.], [0., 0., 1.], [-1., 0., 0.]], dtype=np.float32) 74 | 75 | map_indices = [5, 3, 2, 4, 0, 1] 76 | for mp_idx in map_indices: 77 | all_pos = uv_map[mp_idx].reshape((-1, 3)) 78 | max_i = uv_map[mp_idx].shape[0] 79 | max_j = uv_map[mp_idx].shape[1] 80 | pixel_coordinates_i, pixel_coordinates_j = np.meshgrid(list(range(max_i)), list(range(max_j)), indexing='ij') 81 | pixel_coordinates_i = pixel_coordinates_i / max_i 82 | pixel_coordinates_j = pixel_coordinates_j / max_j 83 | pixel_coordinates = np.stack([pixel_coordinates_i, pixel_coordinates_j], axis=-1).reshape((-1, 2)) 84 | valid_pos = uv_mask[mp_idx].flatten() == 255 85 | pixel_coordinates = pixel_coordinates[valid_pos] 86 | 87 | kdtree = scipy.spatial.cKDTree(all_pos[valid_pos]) 88 | dist, indices = kdtree.query(np.array(tmesh.vertices), k=1) 89 | dmask_0 = dist < 8e-3 90 | dmask_1 = selected_mapping[:, 3] == float('inf') 91 | dmask = np.logical_and(dmask_0, dmask_1) 92 | selected_mapping[dmask, 1:3] = pixel_coordinates[indices, :][dmask, :] 93 | selected_mapping[dmask, 0] = mp_idx 94 | selected_mapping[dmask, 3] = dist[dmask] 95 | 96 | np.save(str(output_path / f'{mesh.stem}.npy'), selected_mapping[:, 0:3]) 97 | print(mesh) 98 | 99 | 100 | @hydra.main(config_path='../config', config_name='stylegan2_car') 101 | def render_with_uv(config): 102 | from dataset.meshcar_real_features_uv import FaceGraphMeshDataset 103 | config.image_size = 128 104 | config.batch_size = 1 105 | dataset = FaceGraphMeshDataset(config) 106 | dataloader = GraphDataLoader(dataset, batch_size=config.batch_size, num_workers=0) 107 | render_helper = DifferentiableRenderer(config.image_size, 'bounds', config.colorspace).cuda() 108 | for batch_idx, batch in enumerate(tqdm(dataloader)): 109 | batch = to_device(batch, torch.device("cuda:0")) 110 | texture_map = [] 111 | for name in batch['name']: 112 | texture_map.append(torch.from_numpy(split_into_six(np.array(Image.open(Path("/cluster/gimli/ysiddiqui/CADTextures/CompCars/uv_normals") / f'{name}.jpg')))).permute(0, 3, 1, 2)) 113 | texture_map = torch.nn.functional.interpolate(torch.cat(texture_map, dim=0).to(batch['vertices'].device).float() / 255 * 2 - 1, size=(config.image_size, config.image_size), mode='bilinear', align_corners=True) 114 | vertices_mapped = texture_map[batch["uv"][:, 0].long(), :, (batch["uv"][:, 1] * config.image_size).long(), (batch["uv"][:, 2] * config.image_size).long()] 115 | rendered_texture = render_helper.render(batch['vertices'], batch['indices'], vertices_mapped, batch["ranges"].cpu(), 
None).permute((0, 3, 1, 2)) 116 | save_image(rendered_texture, f"runs/images/{batch_idx:04d}.png", nrow=6, value_range=(-1, 1), normalize=True) 117 | 118 | 119 | if __name__ == '__main__': 120 | import argparse 121 | 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-n', '--num_proc', default=1, type=int) 124 | parser.add_argument('-p', '--proc', default=0, type=int) 125 | args = parser.parse_args() 126 | 127 | # create_silhouttes() 128 | create_uv_mapping(args.proc, args.num_proc) 129 | # render_with_uv() 130 | -------------------------------------------------------------------------------- /model/differentiable_renderer_light.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import nvdiffrast.torch as dr 4 | from torchvision.ops import masks_to_boxes 5 | 6 | 7 | def transform_pos(pos, projection_matrix, world_to_cam_matrix): 8 | # (x,y,z) -> (x,y,z,1) 9 | t_mtx = torch.matmul(projection_matrix, world_to_cam_matrix) 10 | # noinspection PyArgumentList 11 | posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1) 12 | return torch.matmul(posw, t_mtx.t()) 13 | 14 | 15 | def transform_pos_mvp(pos, mvp): 16 | # noinspection PyArgumentList 17 | posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1) 18 | return torch.bmm(posw.unsqueeze(0).expand(mvp.shape[0], -1, -1), mvp.permute((0, 2, 1))).reshape((-1, 4)) 19 | 20 | 21 | def render(glctx, pos_clip, pos_idx, vtx_col, col_idx, resolution, ranges, _colorspace, background=None): 22 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[resolution, resolution], ranges=ranges) 23 | color, _ = dr.interpolate(vtx_col[None, ...], rast_out, col_idx) 24 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 25 | mask = color[..., -1:] == 0 26 | if background is None: 27 | one_tensor = torch.ones((color.shape[0], color.shape[3], 1, 1), device=color.device) 28 | else: 29 | one_tensor = background 30 | one_tensor_permuted = one_tensor.permute((0, 2, 3, 1)).contiguous() 31 | color = torch.where(mask, one_tensor_permuted, color) 32 | return color[:, :, :, :-1] 33 | 34 | 35 | def render_specular_effects(glctx, lightdir, view_vec, pos_clip, pos_idx, normals, resolution, ranges): 36 | reflvec = view_vec - 2.0 * normals * torch.sum(normals * view_vec, -1, keepdim=True) # Reflection vectors at vertices. 37 | reflvec = reflvec / torch.sum(reflvec ** 2, -1, keepdim=True) ** 0.5 38 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, [resolution, resolution], ranges=ranges) 39 | refl, _ = dr.interpolate(reflvec, rast_out, pos_idx) 40 | refl = refl / (torch.sum(refl ** 2, -1, keepdim=True) + 1e-8) ** 0.5 # Normalize. 41 | ldotr = torch.sum(-lightdir * refl, -1, keepdim=True) # L dot R. 42 | return ldotr 43 | 44 | 45 | def render_in_bounds(glctx, pos_clip, pos_idx, vtx_col, col_idx, normals, vtx_shininess, lightdirs, view_vec, shininess, resolution, ranges, color_space, background=None): 46 | # color 47 | render_resolution = int(resolution * 1.2) 48 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[render_resolution, render_resolution], ranges=ranges) 49 | color, _ = dr.interpolate(vtx_col[None, ...], rast_out, col_idx) 50 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 51 | mask = color[..., -1:] == 0 52 | # light 53 | reflvec = vtx_shininess * (view_vec - 2.0 * normals * torch.sum(normals * view_vec, -1, keepdim=True)) # Reflection vectors at vertices. 
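# The per-vertex reflection vectors computed above are normalized next, rasterized, and interpolated per pixel;
# the specular term is then accumulated as max(0, L.R)**shininess over all light directions and clamped to [0, 1].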
54 | reflvec = reflvec / (torch.sum(reflvec ** 2, -1, keepdim=True) ** 0.5 + 1e-8) 55 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, [render_resolution, render_resolution], ranges=ranges) 56 | refl, _ = dr.interpolate(reflvec, rast_out, pos_idx) 57 | refl = refl / (torch.sum(refl ** 2, -1, keepdim=True) + 1e-8) ** 0.5 # Normalize. 58 | 59 | total_specular = torch.zeros(list(color.shape[:3]) + [1], device=color.device) 60 | for lightdir in lightdirs: 61 | ldotr = torch.sum(-lightdir * refl, -1, keepdim=True) # L dot R. 62 | total_specular += torch.max(torch.zeros_like(ldotr), ldotr) ** shininess 63 | total_specular = torch.clamp(total_specular, 0, 1) 64 | 65 | if background is None: 66 | if color_space == 'rgb': 67 | one_tensor = torch.ones((color.shape[0], color.shape[3], 1, 1), device=color.device) 68 | else: 69 | one_tensor = torch.zeros((color.shape[0], color.shape[3], 1, 1), device=color.device) 70 | one_tensor[:, 0, :, :] = 1 71 | else: 72 | one_tensor = background 73 | one_tensor_permuted = one_tensor.permute((0, 2, 3, 1)).contiguous() 74 | color = torch.where(mask, one_tensor_permuted, color) # [:, :, :, :-1] 75 | color[..., -1:] = mask.float() 76 | color_crops, specular_crops = [], [] 77 | boxes = masks_to_boxes(torch.logical_not(mask.squeeze(-1))) 78 | for img_idx in range(color.shape[0]): 79 | x1, y1, x2, y2 = [int(val) for val in boxes[img_idx, :].tolist()] 80 | color_crop = color[img_idx, y1: y2, x1: x2, :].permute((2, 0, 1)) 81 | specular_crop = total_specular[img_idx, y1: y2, x1: x2, :].permute((2, 0, 1)) 82 | pad = [[0, 0], [0, 0]] 83 | if y2 - y1 > x2 - x1: 84 | total_pad = (y2 - y1) - (x2 - x1) 85 | pad[0][0] = total_pad // 2 86 | pad[0][1] = total_pad - pad[0][0] 87 | pad[1][0], pad[1][1] = 0, 0 88 | additional_pad = int((y2 - y1) * 0.1) 89 | else: 90 | total_pad = (x2 - x1) - (y2 - y1) 91 | pad[0][0], pad[0][1] = 0, 0 92 | pad[1][0] = total_pad // 2 93 | pad[1][1] = total_pad - pad[1][0] 94 | additional_pad = int((x2 - x1) * 0.1) 95 | for i in range(4): 96 | pad[i // 2][i % 2] += additional_pad 97 | 98 | padded = torch.ones((color_crop.shape[0], color_crop.shape[1] + pad[1][0] + pad[1][1], color_crop.shape[2] + pad[0][0] + pad[0][1]), device=color_crop.device) 99 | padded[:3, :, :] = padded[:3, :, :] * one_tensor[img_idx, :3, :, :] 100 | padded[:, pad[1][0]: padded.shape[1] - pad[1][1], pad[0][0]: padded.shape[2] - pad[0][1]] = color_crop 101 | specular_padded = torch.zeros((specular_crop.shape[0], specular_crop.shape[1] + pad[1][0] + pad[1][1], specular_crop.shape[2] + pad[0][0] + pad[0][1]), device=specular_crop.device) 102 | specular_padded[:, pad[1][0]: padded.shape[1] - pad[1][1], pad[0][0]: padded.shape[2] - pad[0][1]] = specular_crop 103 | # color_crop = T.Pad((pad[0][0], pad[1][0], pad[0][1], pad[1][1]), 1)(color_crop) 104 | color_crop = torch.nn.functional.interpolate(padded.unsqueeze(0), size=(resolution, resolution), mode='bilinear', align_corners=False).permute((0, 2, 3, 1)) 105 | specular_crop = torch.nn.functional.interpolate(specular_padded.unsqueeze(0), size=(resolution, resolution), mode='bilinear', align_corners=False).permute((0, 2, 3, 1)) 106 | color_crops.append(color_crop) 107 | specular_crops.append(specular_crop) 108 | return torch.cat(color_crops, dim=0), torch.cat(specular_crops, dim=0) 109 | 110 | 111 | def intrinsic_to_projection(intrinsic_matrix): 112 | near, far = 0.1, 50. 
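# a, b below are the depth-mapping coefficients of an OpenGL-style perspective projection: combined with the
# last row (0, 0, -1, 0), camera-space z = -near maps to NDC z = -1 and z = -far maps to NDC z = +1.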
113 | a, b = -(far + near) / (far - near), -2 * far * near / (far - near) 114 | projection_matrix = torch.tensor([ 115 | intrinsic_matrix[0][0] / intrinsic_matrix[0][2], 0, 0, 0, 116 | 0, -intrinsic_matrix[1][1] / intrinsic_matrix[1][2], 0, 0, 117 | 0, 0, a, b, 118 | 0, 0, -1, 0 119 | ]).float().reshape((4, 4)) 120 | return projection_matrix 121 | 122 | 123 | class DifferentiableRenderer(nn.Module): 124 | 125 | def __init__(self, resolution, mode='standard', color_space='rgb', num_channels=3): 126 | super().__init__() 127 | self.glctx = dr.RasterizeGLContext() 128 | self.resolution = resolution 129 | self.render_func = render 130 | self.color_space = color_space 131 | self.num_channels = num_channels 132 | if mode == 'bounds': 133 | self.render_func = render_in_bounds 134 | 135 | def render(self, vertex_positions, triface_indices, vertex_colors, normals, vertex_shininess, light_directions, view_direction, shininess, ranges=None, background=None, resolution=None): 136 | if ranges is None: 137 | ranges = torch.tensor([[0, triface_indices.shape[0]]]).int() 138 | if resolution is None: 139 | resolution = self.resolution 140 | color, specular = self.render_func(self.glctx, vertex_positions, triface_indices, vertex_colors, triface_indices, normals, vertex_shininess, light_directions, view_direction, shininess, resolution, ranges, self.color_space, background) 141 | return color[:, :, :, :self.num_channels], specular 142 | -------------------------------------------------------------------------------- /dataset/mesh_real_pigan.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import trimesh 4 | from collections import defaultdict 5 | from pathlib import Path 6 | import numpy as np 7 | 8 | import torch 9 | import torchvision.transforms as T 10 | from torchvision.io import read_image 11 | from tqdm import tqdm 12 | 13 | from util.camera import spherical_coord_to_cam 14 | 15 | 16 | class SDFGridDataset(torch.utils.data.Dataset): 17 | 18 | def __init__(self, config, limit_dataset_size=None): 19 | self.dataset_directory = Path(config.dataset_path) 20 | self.mesh_directory = Path(config.dataset_path) 21 | try: 22 | self.image_size = config.img_size 23 | except Exception as err: 24 | self.image_size = config.image_size 25 | self.bg_color = config.random_bg 26 | self.real_images = {x.name.split('.')[0]: x for x in Path(config.image_path).iterdir() if x.name.endswith('.jpg') or x.name.endswith('.png')} 27 | self.masks = {x: Path(config.mask_path) / self.real_images[x].name for x in self.real_images} 28 | self.items = sorted(list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size]) 29 | problem_files = {"shape03283_rank00_pair26092", "shape04324_rank00_pair46014", "shape04142_rank00_pair284841", "shape03765_rank01_pair164915", "shape03704_rank02_pair249529", "shape06471_rank00_pair221396"} 30 | self.items = [x for x in self.items if x not in problem_files] 31 | self.views_per_sample = 1 32 | self.erode = config.erode 33 | self.real_pad = 100 34 | self.color_generator = random_color if config.random_bg == 'color' else (random_grayscale if config.random_bg == 'grayscale' else white) 35 | self.pair_meta, self.all_views = self.load_pair_meta(config.pairmeta_path) 36 | self.real_images_preloaded, self.masks_preloaded = {}, {} 37 | if config.preload: 38 | self.preload_real_images() 39 | 40 | def __len__(self): 41 | return len(self.items) 42 | 43 | def __getitem__(self, idx): 44 | selected_item = self.items[idx] 45 | 
mesh = trimesh.load(self.dataset_directory / selected_item / "model_normalized.obj", process=False) 46 | faces = torch.from_numpy(mesh.triangles.mean(1)).float() 47 | vertices = torch.from_numpy(mesh.vertices).float() 48 | vctr = torch.tensor(list(range(vertices.shape[0]))).long() 49 | indices = torch.from_numpy(mesh.faces).int() 50 | sdf_grid = torch.from_numpy(np.load(self.dataset_directory / selected_item / "064.npy")).unsqueeze(0) - 0.0075 51 | color_grid = torch.from_numpy(np.load(self.dataset_directory / selected_item / "064_if.npy")).permute((3, 0, 1, 2)) / 127.5 - 1 52 | tri_indices = torch.cat([indices[:, [0, 1, 2]], indices[:, [0, 2, 3]]], 0) 53 | real_sample, masks, mvp = self.get_image_and_view(selected_item) 54 | background = self.color_generator(self.views_per_sample) 55 | return { 56 | "name": selected_item, 57 | "sdf_x": sdf_grid.float(), 58 | "csdf_y": color_grid.float(), 59 | "faces": faces, 60 | "vertex_ctr": vctr, 61 | "vertices": vertices, 62 | "indices_quad": indices, 63 | "indices": tri_indices, 64 | "ranges": torch.tensor([0, tri_indices.shape[0]]).int(), 65 | "mvp": mvp, 66 | "real": real_sample[0].unsqueeze(0), 67 | "mask": masks[0].unsqueeze(0), 68 | "bg": torch.cat([background, torch.ones([background.shape[0], 1, 1, 1])], dim=1) 69 | } 70 | 71 | def get_image_and_view(self, shape): 72 | shape_id = int(shape.split('_')[0].split('shape')[1]) 73 | image_selections = self.get_image_selections(shape_id) 74 | view_selections = random.sample(self.all_views, self.views_per_sample) 75 | images, masks, cameras = [], [], [] 76 | for c_i, c_v in zip(image_selections, view_selections): 77 | images.append(self.get_real_image(self.meta_to_pair(c_i))) 78 | masks.append(self.get_real_mask(self.meta_to_pair(c_i))) 79 | perspective_cam = spherical_coord_to_cam(c_v['fov'], c_v['azimuth'], c_v['elevation']) 80 | projection_matrix = torch.from_numpy(perspective_cam.projection_mat()).float() 81 | view_matrix = torch.from_numpy(perspective_cam.view_mat()).float() 82 | cameras.append(torch.matmul(projection_matrix, view_matrix)) 83 | image = torch.cat(images, dim=0) 84 | masks = torch.cat(masks, dim=0) 85 | mvp = torch.stack(cameras, dim=0) 86 | return image, masks, mvp 87 | 88 | def get_real_image(self, name): 89 | if name not in self.real_images_preloaded.keys(): 90 | return self.process_real_image(self.real_images[name]) 91 | else: 92 | return self.real_images_preloaded[name] 93 | 94 | def get_real_mask(self, name): 95 | if name not in self.masks_preloaded.keys(): 96 | return self.process_real_mask(self.masks[name]) 97 | else: 98 | return self.masks_preloaded[name] 99 | 100 | @staticmethod 101 | def erode_mask(mask): 102 | import cv2 as cv 103 | mask = mask.squeeze(0).numpy().astype(np.uint8) 104 | kernel_size = 3 105 | element = cv.getStructuringElement(cv.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1), (kernel_size, kernel_size)) 106 | mask = cv.erode(mask, element) 107 | return torch.from_numpy(mask).unsqueeze(0) 108 | 109 | def process_real_mask(self, path): 110 | resize = T.Resize(size=(self.image_size, self.image_size)) 111 | pad = T.Pad(padding=(self.real_pad, self.real_pad), fill=0) 112 | if self.erode: 113 | eroded_mask = self.erode_mask(read_image(str(path))) 114 | else: 115 | eroded_mask = read_image(str(path)) 116 | t_mask = resize(pad((eroded_mask > 0).float())) 117 | return t_mask.unsqueeze(0) 118 | 119 | def get_image_selections(self, shape_id): 120 | candidates = self.pair_meta[shape_id] 121 | if len(candidates) < self.views_per_sample: 122 | while 
len(candidates) < self.views_per_sample: 123 | meta = self.pair_meta[random.choice(list(self.pair_meta.keys()))] 124 | candidates.extend(meta[:self.views_per_sample - len(candidates)]) 125 | else: 126 | candidates = random.sample(candidates, self.views_per_sample) 127 | return candidates 128 | 129 | def process_real_image(self, path): 130 | resize = T.Resize(size=(self.image_size, self.image_size)) 131 | pad = T.Pad(padding=(self.real_pad, self.real_pad), fill=1) 132 | t_image = resize(pad(read_image(str(path)).float() / 127.5 - 1)) 133 | return t_image.unsqueeze(0) 134 | 135 | def load_pair_meta(self, pairmeta_path): 136 | loaded_json = json.loads(Path(pairmeta_path).read_text()) 137 | ret_dict = defaultdict(list) 138 | ret_views = [] 139 | for k in loaded_json.keys(): 140 | if self.meta_to_pair(loaded_json[k]) in self.real_images.keys(): 141 | ret_dict[loaded_json[k]['shape_id']].append(loaded_json[k]) 142 | ret_views.append(loaded_json[k]) 143 | return ret_dict, ret_views 144 | 145 | def preload_real_images(self): 146 | for ri in tqdm(self.real_images.keys(), desc='preload'): 147 | self.real_images_preloaded[ri] = self.process_real_image(self.real_images[ri]) 148 | self.masks_preloaded[ri] = self.process_real_mask(self.masks[ri]) 149 | 150 | @staticmethod 151 | def meta_to_pair(c): 152 | return f'shape{c["shape_id"]:05d}_rank{(c["rank"] - 1):02d}_pair{c["id"]}' 153 | 154 | 155 | def random_color(num_views): 156 | randoms = [] 157 | for i in range(num_views): 158 | r, g, b = random.randint(0, 255) / 127.5 - 1, random.randint(0, 255) / 127.5 - 1, random.randint(0, 255) / 127.5 - 1 159 | randoms.append(torch.from_numpy(np.array([r, g, b]).reshape((1, 3, 1, 1))).float()) 160 | return torch.cat(randoms, dim=0) 161 | 162 | 163 | def random_grayscale(num_views): 164 | randoms = [] 165 | for i in range(num_views): 166 | c = random.randint(0, 255) / 127.5 - 1 167 | randoms.append(torch.from_numpy(np.array([c, c, c]).reshape((1, 3, 1, 1))).float()) 168 | return torch.cat(randoms, dim=0) 169 | 170 | 171 | def white(num_views): 172 | return torch.from_numpy(np.array([1, 1, 1]).reshape((1, 3, 1, 1))).expand(num_views, -1, -1, -1).float() 173 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from model import SmoothDownsample, EqualizedConv2d, FullyConnectedLayer, normalize_2nd_moment 5 | 6 | 7 | class Discriminator(torch.nn.Module): 8 | 9 | def __init__(self, img_resolution, img_channels, w_num_layers=0, c_dim=0, mbstd_on=1, channel_base=16384, channel_max=512): 10 | super().__init__() 11 | self.img_resolution = img_resolution 12 | self.img_resolution_log2 = int(np.log2(img_resolution)) 13 | self.img_channels = img_channels 14 | self.c_dim = c_dim 15 | self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] 16 | channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} 17 | self.module_list = [EqualizedConv2d(img_channels, channels_dict[img_resolution], kernel_size=1, activation='lrelu')] 18 | for res in self.block_resolutions: 19 | in_channels = channels_dict[res] 20 | out_channels = channels_dict[res // 2] 21 | self.module_list.append(DiscriminatorBlock(in_channels, out_channels)) 22 | self.module_list.append(DiscriminatorEpilogue(channels_dict[4], resolution=4, cmap_dim=(0 if c_dim == 0 else channels_dict[4]), mbstd_num_channels=mbstd_on)) 23 | 
self.module_list = torch.nn.ModuleList(self.module_list) 24 | if c_dim > 0: 25 | self.mapping = DiscriminatorMappingNetwork(c_dim=c_dim, cmap_dim=channels_dict[4], num_layers=w_num_layers) 26 | 27 | def forward(self, x, c=None): 28 | if self.c_dim > 0: 29 | c = self.mapping(c) 30 | for net in self.module_list[:-1]: 31 | x = net(x) 32 | x = self.module_list[-1](x, c) 33 | return x 34 | 35 | 36 | class DiscriminatorProgressive(torch.nn.Module): 37 | 38 | def __init__(self, img_resolution, img_channels, w_num_layers=0, c_dim=0, mbstd_on=1, channel_base=16384, channel_max=512): 39 | super().__init__() 40 | self.img_resolution = img_resolution 41 | self.img_resolution_log2 = int(np.log2(img_resolution)) 42 | self.img_channels = img_channels 43 | self.c_dim = c_dim 44 | self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] 45 | channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} 46 | self.module_list = [] 47 | self.shape_to_str = lambda x: f'{x:03d}' 48 | self.fromRGB = torch.nn.ModuleDict({ 49 | self.shape_to_str(res): EqualizedConv2d(img_channels, channels_dict[res], kernel_size=1, activation='lrelu') 50 | for res in self.block_resolutions[:-2]}) 51 | for res in self.block_resolutions: 52 | in_channels = channels_dict[res] 53 | out_channels = channels_dict[res // 2] 54 | self.module_list.append(DiscriminatorBlock(in_channels, out_channels)) 55 | self.module_list.append(DiscriminatorEpilogue(channels_dict[4], resolution=4, cmap_dim=(0 if c_dim == 0 else channels_dict[4]), mbstd_num_channels=mbstd_on)) 56 | self.module_list = torch.nn.ModuleList(self.module_list) 57 | if c_dim > 0: 58 | self.mapping = DiscriminatorMappingNetwork(c_dim=c_dim, cmap_dim=channels_dict[4], num_layers=w_num_layers) 59 | 60 | def forward(self, x_in, alpha, c=None): 61 | if self.c_dim > 0: 62 | c = self.mapping(c) 63 | module_start = self.block_resolutions.index(x_in.shape[-1]) 64 | x = self.fromRGB[self.shape_to_str(x_in.shape[-1])](x_in) 65 | for i, net in enumerate(self.module_list[module_start:-1]): 66 | if i == 1: 67 | x = x * alpha + self.fromRGB[self.shape_to_str(x_in.shape[-1] // 2)](torch.nn.functional.interpolate(x_in, scale_factor=0.5, mode='nearest', recompute_scale_factor=False)) 68 | x = net(x) 69 | x = self.module_list[-1](x, c) 70 | return x 71 | 72 | 73 | class DiscriminatorBlock(torch.nn.Module): 74 | 75 | def __init__(self, in_channels, out_channels, activation='lrelu'): 76 | super().__init__() 77 | self.in_channels = in_channels 78 | self.num_layers = 0 79 | downsampler = SmoothDownsample() 80 | self.conv0 = EqualizedConv2d(in_channels, in_channels, kernel_size=3, activation=activation) 81 | self.conv1 = EqualizedConv2d(in_channels, out_channels, kernel_size=3, activation=activation, resample=downsampler) 82 | self.skip = EqualizedConv2d(in_channels, out_channels, kernel_size=1, bias=False, resample=downsampler) 83 | 84 | def forward(self, x): 85 | y = self.skip(x, gain=np.sqrt(0.5)) 86 | x = self.conv0(x) 87 | x = self.conv1(x, gain=np.sqrt(0.5)) 88 | x = y.add_(x) 89 | return x 90 | 91 | 92 | class DiscriminatorEpilogue(torch.nn.Module): 93 | 94 | def __init__(self, in_channels, resolution, cmap_dim=0, mbstd_group_size=4, mbstd_num_channels=1, activation='lrelu'): 95 | super().__init__() 96 | self.in_channels = in_channels 97 | self.resolution = resolution 98 | self.cmap_dim = cmap_dim 99 | 100 | self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None 
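# The minibatch-std layer appends mbstd_num_channels statistics channels to the feature map,
# which is why the convolution below consumes in_channels + mbstd_num_channels inputs.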
101 | self.conv = EqualizedConv2d(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation) 102 | self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) 103 | self.out = FullyConnectedLayer(in_channels, cmap_dim if cmap_dim > 0 else 1) 104 | 105 | def forward(self, x, cmap=None): 106 | if self.mbstd is not None: 107 | x = self.mbstd(x) 108 | x = self.conv(x) 109 | x = self.fc(x.flatten(1)) 110 | x = self.out(x) 111 | if self.cmap_dim > 0: 112 | x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) 113 | return x 114 | 115 | 116 | class MinibatchStdLayer(torch.nn.Module): 117 | 118 | def __init__(self, group_size, num_channels=1): 119 | super().__init__() 120 | self.group_size = group_size 121 | self.num_channels = num_channels 122 | 123 | def forward(self, x): 124 | N, C, H, W = x.shape 125 | G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N 126 | F = self.num_channels 127 | c = C // F 128 | 129 | y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. 130 | y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. 131 | y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. 132 | y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. 133 | y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. 134 | y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. 135 | y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. 136 | x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. 137 | return x 138 | 139 | 140 | class DiscriminatorMappingNetwork(torch.nn.Module): 141 | 142 | def __init__(self, cmap_dim, c_dim, num_layers=8, activation='lrelu', lr_multiplier=0.01): 143 | super().__init__() 144 | self.w_dim = cmap_dim 145 | self.c_dim = c_dim 146 | self.num_layers = num_layers 147 | 148 | features_list = [c_dim] + [self.w_dim] * num_layers 149 | 150 | self.embed = FullyConnectedLayer(c_dim, self.w_dim) 151 | 152 | self.layers = torch.nn.ModuleList() 153 | for idx in range(num_layers): 154 | in_features = features_list[idx] 155 | out_features = features_list[idx + 1] 156 | self.layers.append(FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)) 157 | 158 | def forward(self, c): 159 | x = normalize_2nd_moment(self.embed(c)) 160 | 161 | # Main layers. 
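# The stacked fully connected layers refine the embedded condition; the result is used as the conditioning
# map (cmap) in DiscriminatorEpilogue, where it is dotted with the image features before the final score.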
162 | for idx in range(self.num_layers): 163 | x = self.layers[idx](x) 164 | 165 | return x 166 | 167 | 168 | if __name__ == '__main__': 169 | from util.misc import print_model_parameter_count, print_module_summary 170 | img_res = 512 171 | model = DiscriminatorProgressive(img_resolution=img_res, img_channels=3) 172 | print_module_summary(model, (torch.randn((16, 3, img_res, img_res)), 1)) 173 | print_model_parameter_count(model) 174 | -------------------------------------------------------------------------------- /model/pigan/discriminators.py: -------------------------------------------------------------------------------- 1 | """Discrimators used in pi-GAN""" 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from model.pigan.sgdiscriminators import * 8 | 9 | 10 | class GlobalAveragePooling(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | def forward(self, x): 14 | return x.mean([2, 3]) 15 | 16 | class AdapterBlock(nn.Module): 17 | def __init__(self, output_channels): 18 | super().__init__() 19 | self.model = nn.Sequential( 20 | nn.Conv2d(3, output_channels, 1, padding=0), 21 | nn.LeakyReLU(0.2) 22 | ) 23 | def forward(self, input): 24 | return self.model(input) 25 | 26 | 27 | def kaiming_leaky_init(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | torch.nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') 31 | 32 | 33 | class AddCoords(nn.Module): 34 | """ 35 | Source: https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py 36 | """ 37 | 38 | def __init__(self, with_r=False): 39 | super().__init__() 40 | self.with_r = with_r 41 | 42 | def forward(self, input_tensor): 43 | """ 44 | Args: 45 | input_tensor: shape(batch, channel, x_dim, y_dim) 46 | """ 47 | batch_size, _, x_dim, y_dim = input_tensor.size() 48 | 49 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 50 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 51 | 52 | xx_channel = xx_channel.float() / (x_dim - 1) 53 | yy_channel = yy_channel.float() / (y_dim - 1) 54 | 55 | xx_channel = xx_channel * 2 - 1 56 | yy_channel = yy_channel * 2 - 1 57 | 58 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 59 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 60 | 61 | ret = torch.cat([ 62 | input_tensor, 63 | xx_channel.type_as(input_tensor), 64 | yy_channel.type_as(input_tensor)], dim=1) 65 | 66 | if self.with_r: 67 | rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)) 68 | ret = torch.cat([ret, rr], dim=1) 69 | 70 | return ret 71 | 72 | class CoordConv(nn.Module): 73 | """ 74 | Source: https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py 75 | """ 76 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 77 | super().__init__() 78 | self.addcoords = AddCoords(with_r=with_r) 79 | in_size = in_channels+2 80 | if with_r: 81 | in_size += 1 82 | self.conv = nn.Conv2d(in_size, out_channels, **kwargs) 83 | 84 | def forward(self, x): 85 | ret = self.addcoords(x) 86 | ret = self.conv(ret) 87 | return ret 88 | 89 | class ResidualCoordConvBlock(nn.Module): 90 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, downsample=False, groups=1): 91 | super().__init__() 92 | p = kernel_size//2 93 | self.network = nn.Sequential( 94 | CoordConv(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=p), 95 | 
nn.LeakyReLU(0.2, inplace=True), 96 | CoordConv(planes, planes, kernel_size=kernel_size, padding=p), 97 | nn.LeakyReLU(0.2, inplace=True) 98 | ) 99 | self.network.apply(kaiming_leaky_init) 100 | 101 | self.proj = nn.Conv2d(inplanes, planes, 1) if inplanes != planes else None 102 | self.downsample = downsample 103 | 104 | def forward(self, identity): 105 | y = self.network(identity) 106 | 107 | if self.downsample: y = nn.functional.avg_pool2d(y, 2) 108 | if self.downsample: identity = nn.functional.avg_pool2d(identity, 2) 109 | identity = identity if self.proj is None else self.proj(identity) 110 | 111 | y = (y + identity)/math.sqrt(2) 112 | return y 113 | 114 | 115 | class ProgressiveDiscriminator(nn.Module): 116 | """Implementation of a progressive growing discriminator with ResidualCoordConv Blocks""" 117 | 118 | def __init__(self): 119 | super().__init__() 120 | self.epoch = 0 121 | self.step = 0 122 | self.layers = nn.ModuleList( 123 | [ 124 | ResidualCoordConvBlock(16, 32, downsample=True), # 512x512 -> 256x256 125 | ResidualCoordConvBlock(32, 64, downsample=True), # 256x256 -> 128x128 126 | ResidualCoordConvBlock(64, 128, downsample=True), # 128x128 -> 64x64 127 | ResidualCoordConvBlock(128, 256, downsample=True), # 64x64 -> 32x32 128 | ResidualCoordConvBlock(256, 400, downsample=True), # 32x32 -> 16x16 129 | ResidualCoordConvBlock(400, 400, downsample=True), # 16x16 -> 8x8 130 | ResidualCoordConvBlock(400, 400, downsample=True), # 8x8 -> 4x4 131 | ResidualCoordConvBlock(400, 400, downsample=True), # 4x4 -> 2x2 132 | ]) 133 | 134 | self.fromRGB = nn.ModuleList( 135 | [ 136 | AdapterBlock(16), 137 | AdapterBlock(32), 138 | AdapterBlock(64), 139 | AdapterBlock(128), 140 | AdapterBlock(256), 141 | AdapterBlock(400), 142 | AdapterBlock(400), 143 | AdapterBlock(400), 144 | AdapterBlock(400) 145 | ]) 146 | self.final_layer = nn.Conv2d(400, 1, 2) 147 | self.img_size_to_layer = {2:8, 4:7, 8:6, 16:5, 32:4, 64:3, 128:2, 256:1, 512:0} 148 | 149 | 150 | def forward(self, input, alpha): 151 | start = self.img_size_to_layer[input.shape[-1]] 152 | 153 | x = self.fromRGB[start](input) 154 | for i, layer in enumerate(self.layers[start:]): 155 | if i == 1: 156 | x = alpha * x + (1 - alpha) * self.fromRGB[start+1](F.interpolate(input, scale_factor=0.5, mode='nearest')) 157 | x = layer(x) 158 | 159 | x = self.final_layer(x).reshape(x.shape[0], 1) 160 | 161 | return x 162 | 163 | 164 | class ProgressiveEncoderDiscriminator(nn.Module): 165 | """ 166 | Implementation of a progressive growing discriminator with ResidualCoordConv Blocks. 167 | Identical to ProgressiveDiscriminator except it also predicts camera angles and latent codes. 
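The final 2x2 convolution outputs 1 + 256 + 2 values per sample: a realness score, a 256-dimensional latent code, and a 2-dimensional camera position.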
168 | """ 169 | 170 | def __init__(self, **kwargs): 171 | super().__init__() 172 | self.epoch = 0 173 | self.step = 0 174 | self.layers = nn.ModuleList( 175 | [ 176 | ResidualCoordConvBlock(16, 32, downsample=True), # 512x512 -> 256x256 177 | ResidualCoordConvBlock(32, 64, downsample=True), # 256x256 -> 128x128 178 | ResidualCoordConvBlock(64, 128, downsample=True), # 128x128 -> 64x64 179 | ResidualCoordConvBlock(128, 256, downsample=True), # 64x64 -> 32x32 180 | ResidualCoordConvBlock(256, 400, downsample=True), # 32x32 -> 16x16 181 | ResidualCoordConvBlock(400, 400, downsample=True), # 16x16 -> 8x8 182 | ResidualCoordConvBlock(400, 400, downsample=True), # 8x8 -> 4x4 183 | ResidualCoordConvBlock(400, 400, downsample=True), # 4x4 -> 2x2 184 | ]) 185 | 186 | self.fromRGB = nn.ModuleList( 187 | [ 188 | AdapterBlock(16), 189 | AdapterBlock(32), 190 | AdapterBlock(64), 191 | AdapterBlock(128), 192 | AdapterBlock(256), 193 | AdapterBlock(400), 194 | AdapterBlock(400), 195 | AdapterBlock(400), 196 | AdapterBlock(400) 197 | ]) 198 | self.final_layer = nn.Conv2d(400, 1 + 256 + 2, 2) 199 | self.img_size_to_layer = {2:8, 4:7, 8:6, 16:5, 32:4, 64:3, 128:2, 256:1, 512:0} 200 | 201 | 202 | def forward(self, input, alpha, instance_noise=0, **kwargs): 203 | if instance_noise > 0: 204 | input = input + torch.randn_like(input) * instance_noise 205 | 206 | start = self.img_size_to_layer[input.shape[-1]] 207 | x = self.fromRGB[start](input) 208 | for i, layer in enumerate(self.layers[start:]): 209 | if i == 1: 210 | x = alpha * x + (1 - alpha) * self.fromRGB[start+1](F.interpolate(input, scale_factor=0.5, mode='nearest')) 211 | x = layer(x) 212 | 213 | x = self.final_layer(x).reshape(x.shape[0], -1) 214 | 215 | prediction = x[..., 0:1] 216 | latent = x[..., 1:257] 217 | position = x[..., 257:259] 218 | 219 | return prediction, latent, position -------------------------------------------------------------------------------- /dataset/meshcar_real_sdfgrid.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | import numpy as np 4 | import math 5 | 6 | import torch 7 | import torchvision.transforms as T 8 | from PIL import Image 9 | from torchvision.io import read_image 10 | from torchvision.transforms import InterpolationMode 11 | from tqdm import tqdm 12 | 13 | from scipy.spatial.transform import Rotation 14 | 15 | 16 | class SDFGridDataset(torch.utils.data.Dataset): 17 | 18 | def __init__(self, config, limit_dataset_size=None): 19 | self.dataset_directory = Path(config.dataset_path) 20 | self.mesh_directory = Path(config.mesh_path) 21 | self.image_size = config.image_size 22 | self.bg_color = config.random_bg 23 | self.real_images = {x.name.split('.')[0]: x for x in Path(config.image_path).iterdir() if x.name.endswith('.jpg') or x.name.endswith('.png')} 24 | self.masks = {x: Path(config.mask_path) / self.real_images[x].name for x in self.real_images} 25 | self.items = sorted(list(x.stem for x in Path(config.dataset_path).iterdir())[:limit_dataset_size]) 26 | self.views_per_sample = 1 27 | self.erode = config.erode 28 | self.real_pad_factor = 0.1 29 | self.real_images_preloaded, self.masks_preloaded = {}, {} 30 | if config.preload: 31 | self.preload_real_images() 32 | 33 | def __len__(self): 34 | return len(self.items) 35 | 36 | def __getitem__(self, idx): 37 | selected_item = self.items[idx] 38 | sdf_grid = torch.from_numpy(np.load(self.mesh_directory / selected_item / "064.npy")).unsqueeze(0) - (0.03125 / 2) 39 | normal_grid = 
compute_normals_from_sdf_dense(sdf_grid) 40 | real_sample, masks, view_matrices, projection_matrices = self.get_image_and_view() 41 | if self.bg_color == 'white': 42 | bg = torch.tensor(1.).float() 43 | else: 44 | bg = torch.tensor(random.random() * 2 - 1).float() 45 | return { 46 | "name": selected_item, 47 | "x": sdf_grid.float(), 48 | "y": normal_grid.float(), 49 | "view": view_matrices[0], 50 | "intrinsic": projection_matrices[0], 51 | "real": real_sample[0], 52 | "mask": masks[0], 53 | "bg": bg 54 | } 55 | 56 | def get_image_and_view(self): 57 | total_selections = len(self.real_images.keys()) // 8 58 | available_views = get_car_views() 59 | view_indices = random.sample(list(range(8)), self.views_per_sample) 60 | sampled_view = [available_views[vidx] for vidx in view_indices] 61 | image_indices = random.sample(list(range(total_selections)), self.views_per_sample) 62 | image_selections = [f'{(iidx * 8 + vidx):05d}' for (iidx, vidx) in zip(image_indices, view_indices)] 63 | images, masks, view_matrices, projection_matrices = [], [], [], [] 64 | for c_i, c_v in zip(image_selections, sampled_view): 65 | images.append(self.get_real_image(c_i)) 66 | masks.append(self.get_real_mask(c_i)) 67 | view_matrix, projection_matrix = self.get_camera(c_v['fov'], c_v['azimuth'], c_v['elevation']) 68 | view_matrices.append(view_matrix) 69 | projection_matrices.append(projection_matrix) 70 | image = torch.cat(images, dim=0) 71 | masks = torch.cat(masks, dim=0) 72 | image = image * masks.expand(-1, 3, -1, -1) + (1 - masks).expand(-1, 3, -1, -1) * torch.ones_like(image) 73 | return image, masks, torch.stack(view_matrices, dim=0), torch.stack(projection_matrices, dim=0) 74 | 75 | def get_real_image(self, name): 76 | if name not in self.real_images_preloaded.keys(): 77 | return self.process_real_image(self.real_images[name]) 78 | else: 79 | return self.real_images_preloaded[name] 80 | 81 | def get_real_mask(self, name): 82 | if name not in self.masks_preloaded.keys(): 83 | return self.process_real_mask(self.masks[name]) 84 | else: 85 | return self.masks_preloaded[name] 86 | 87 | @staticmethod 88 | def erode_mask(mask): 89 | import cv2 as cv 90 | mask = mask.squeeze(0).numpy().astype(np.uint8) 91 | kernel_size = 1 92 | element = cv.getStructuringElement(cv.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1), (kernel_size, kernel_size)) 93 | mask = cv.erode(mask, element) 94 | return torch.from_numpy(mask).unsqueeze(0) 95 | 96 | def process_real_mask(self, path): 97 | pad_size = int(self.image_size * self.real_pad_factor) 98 | resize = T.Resize(size=(self.image_size - 2 * pad_size, self.image_size - 2 * pad_size), interpolation=InterpolationMode.NEAREST) 99 | pad = T.Pad(padding=(pad_size, pad_size), fill=0) 100 | mask_im = read_image(str(path))[:1, :, :] 101 | if self.erode: 102 | eroded_mask = self.erode_mask(mask_im) 103 | else: 104 | eroded_mask = mask_im 105 | t_mask = pad(resize((eroded_mask > 128).float())) 106 | return t_mask.unsqueeze(0) 107 | 108 | def process_real_image(self, path): 109 | pad_size = int(self.image_size * self.real_pad_factor) 110 | resize = T.Resize(size=(self.image_size - 2 * pad_size, self.image_size - 2 * pad_size), interpolation=InterpolationMode.BICUBIC) 111 | pad = T.Pad(padding=(pad_size, pad_size), fill=1) 112 | t_image = pad(torch.from_numpy(np.array(resize(Image.open(str(path)))).transpose((2, 0, 1))).float() / 127.5 - 1) 113 | return t_image.unsqueeze(0) 114 | 115 | def preload_real_images(self): 116 | for ri in tqdm(self.real_images.keys(), desc='preload'): 117 | 
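# Decoded images and masks are cached in memory so that later get_real_image / get_real_mask calls skip disk reads and resizing.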
self.real_images_preloaded[ri] = self.process_real_image(self.real_images[ri]) 118 | self.masks_preloaded[ri] = self.process_real_mask(self.masks[ri]) 119 | 120 | def get_camera(self, fov, azimuth, elevation): 121 | y_angle = azimuth * 180 / math.pi 122 | x_angle = 90 - elevation * 180 / math.pi 123 | z_angle = 180 124 | camera_rot = np.eye(4) 125 | camera_rot[:3, :3] = Rotation.from_euler('x', x_angle, degrees=True).as_matrix() @ Rotation.from_euler('z', z_angle, degrees=True).as_matrix() @ Rotation.from_euler('y', y_angle, degrees=True).as_matrix() 126 | camera_translation = np.eye(4) 127 | camera_translation[:3, 3] = np.array([0, 0, 1.75]) 128 | camera_pose = camera_translation @ camera_rot 129 | translate = torch.tensor([[1.0, 0, 0, 32], [0, 1.0, 0, 32], [0, 0, 1.0, 32], [0, 0, 0, 1.0]]).float() 130 | scale = torch.tensor([[32.0, 0, 0, 0], [0, 32.0, 0, 0], [0, 0, 32.0, 0], [0, 0, 0, 1.0]]).float() 131 | world2grid = translate @ scale 132 | view_matrix = world2grid @ torch.linalg.inv(torch.from_numpy(camera_pose).float()) 133 | camera_intrinsics = torch.zeros((4,)) 134 | f = self.image_size / (2 * np.tan(fov * np.pi / 180 / 2.0)) 135 | camera_intrinsics[0] = f 136 | camera_intrinsics[1] = f 137 | camera_intrinsics[2] = self.image_size / 2 138 | camera_intrinsics[3] = self.image_size / 2 139 | return view_matrix, camera_intrinsics 140 | 141 | 142 | def get_car_views(): 143 | # front, back, right, left, front_right, front_left, back_right, back_left 144 | azimuth = [3 * math.pi / 2, math.pi / 2, 145 | 0, math.pi, 146 | math.pi + math.pi / 3, 0 - math.pi / 3, 147 | math.pi / 2 + math.pi / 6, math.pi / 2 - math.pi / 6] 148 | azimuth_noise = [0, 0, 149 | 0, 0, 150 | (random.random() - 0.5) * math.pi / 8, (random.random() - 0.5) * math.pi / 8, 151 | (random.random() - 0.5) * math.pi / 8, (random.random() - 0.5) * math.pi / 8, ] 152 | elevation = [math.pi / 2, math.pi / 2, 153 | math.pi / 2, math.pi / 2, 154 | math.pi / 2 - math.pi / 48, math.pi / 2 - math.pi / 48, 155 | math.pi / 2 - math.pi / 48, math.pi / 2 - math.pi / 48] 156 | elevation_noise = [-random.random() * math.pi / 70, -random.random() * math.pi / 70, 157 | 0, 0, 158 | -random.random() * math.pi / 32, -random.random() * math.pi / 32, 159 | 0, 0] 160 | return [{'azimuth': a + an, 'elevation': e + en, 'fov': 50} for a, an, e, en in zip(azimuth, azimuth_noise, elevation, elevation_noise)] 161 | 162 | 163 | def compute_normals_from_sdf_dense(sdf): 164 | sdf = sdf.unsqueeze(0) 165 | dims = sdf.shape[2:] 166 | sdfx = sdf[:, :, 1:dims[0] - 1, 1:dims[1] - 1, 2:dims[2]] - sdf[:, :, 1:dims[0] - 1, 1:dims[1] - 1, 0:dims[2] - 2] 167 | sdfy = sdf[:, :, 1:dims[0] - 1, 2:dims[1], 1:dims[2] - 1] - sdf[:, :, 1:dims[0] - 1, 0:dims[1] - 2, 1:dims[2] - 1] 168 | sdfz = sdf[:, :, 2:dims[0], 1:dims[1] - 1, 1:dims[2] - 1] - sdf[:, :, 0:dims[0] - 2, 1:dims[1] - 1, 1:dims[2] - 1] 169 | normals = torch.cat([sdfx, sdfy, sdfz], 1) 170 | normals = torch.nn.functional.pad(normals, [1, 1, 1, 1, 1, 1], value=0) 171 | normals = -torch.nn.functional.normalize(normals, p=2, dim=1, eps=1e-5, out=None) 172 | return normals.squeeze(0) 173 | -------------------------------------------------------------------------------- /model/styleganvox/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | @torch.jit.script 6 | def clamp_gain(x: torch.Tensor, g: float, c: float): 7 | return torch.clamp(x * g, -c, c) 8 | 9 | 10 | def normalize_2nd_moment(x, dim=1, eps=1e-8): 11 | return x * 
(x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 12 | 13 | 14 | def identity(x): 15 | return x 16 | 17 | 18 | def leaky_relu_0_2(x): 19 | return torch.nn.functional.leaky_relu(x, 0.2) 20 | 21 | 22 | activation_funcs = { 23 | "linear": { 24 | "fn": identity, 25 | "def_gain": 1 26 | }, 27 | "lrelu": { 28 | "fn": leaky_relu_0_2, 29 | "def_gain": np.sqrt(2) 30 | } 31 | } 32 | 33 | 34 | class FullyConnectedLayer(torch.nn.Module): 35 | 36 | def __init__(self, in_features, out_features, bias=True, activation='linear', lr_multiplier=1, bias_init=0): 37 | super().__init__() 38 | self.activation = activation_funcs[activation]['fn'] 39 | self.activation_gain = activation_funcs[activation]['def_gain'] 40 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) 41 | self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None 42 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 43 | self.bias_gain = lr_multiplier 44 | 45 | def forward(self, x): 46 | w = self.weight * self.weight_gain 47 | b = self.bias 48 | if b is not None and self.bias_gain != 1: 49 | b = b * self.bias_gain 50 | x = self.activation(torch.addmm(b.unsqueeze(0), x, w.t())) * self.activation_gain 51 | return x 52 | 53 | 54 | class SmoothDownsample(torch.nn.Module): 55 | 56 | def __init__(self): 57 | super().__init__() 58 | kernel = [[1, 3, 3, 1], 59 | [3, 9, 9, 3], 60 | [3, 9, 9, 3], 61 | [1, 3, 3, 1]] 62 | kernel = torch.tensor([[kernel]], dtype=torch.float) 63 | kernel /= kernel.sum() 64 | self.kernel = torch.nn.Parameter(kernel, requires_grad=False) 65 | self.pad = torch.nn.ReplicationPad2d((2, 1, 2, 1)) 66 | 67 | def forward(self, x: torch.Tensor): 68 | b, c, h, w = x.shape 69 | x = x.view(-1, 1, h, w) 70 | x = self.pad(x) 71 | x = torch.nn.functional.conv2d(x, self.kernel).view(b, c, h, w) 72 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='nearest', recompute_scale_factor=False) 73 | return x 74 | 75 | 76 | class SmoothUpsample(torch.nn.Module): 77 | 78 | def __init__(self): 79 | super().__init__() 80 | 81 | def forward(self, x: torch.Tensor): 82 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True) 83 | return x 84 | 85 | 86 | class EqualizedConv2d(torch.nn.Module): 87 | 88 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, activation='linear', resample=identity): 89 | super().__init__() 90 | self.resample = resample 91 | self.padding = kernel_size // 2 92 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) 93 | self.activation = activation_funcs[activation]['fn'] 94 | self.activation_gain = activation_funcs[activation]['def_gain'] 95 | weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]) 96 | bias = torch.zeros([out_channels]) if bias else None 97 | self.weight = torch.nn.Parameter(weight) 98 | self.bias = torch.nn.Parameter(bias) if bias is not None else None 99 | 100 | def forward(self, x, gain=1): 101 | w = self.weight * self.weight_gain 102 | b = self.bias[None, :, None, None] if self.bias is not None else 0 103 | x = self.resample(x) 104 | x = torch.nn.functional.conv2d(x, w, padding=self.padding) 105 | return clamp_gain(self.activation(x + b), self.activation_gain * gain, 256 * gain) 106 | 107 | 108 | def modulated_conv2d(x, weight, styles, padding=0, demodulate=True): 109 | batch_size = x.shape[0] 110 | out_channels, in_channels, kh, kw = weight.shape 111 | 112 | # Calculate per-sample weights and demodulation 
coefficients. 113 | w = weight.unsqueeze(0) # [NOIkk] 114 | w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] 115 | if demodulate: 116 | dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] 117 | w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] 118 | 119 | # Execute as one fused op using grouped convolution. 120 | batch_size = int(batch_size) 121 | x = x.reshape(1, -1, *x.shape[2:]) 122 | w = w.reshape(-1, in_channels, kh, kw) 123 | x = torch.nn.functional.conv2d(x, w, padding=padding, groups=batch_size) 124 | x = x.reshape(batch_size, -1, *x.shape[2:]) 125 | return x 126 | 127 | 128 | def modulated_conv3d(x, weight, styles, padding=0, demodulate=True): 129 | batch_size = x.shape[0] 130 | out_channels, in_channels, kd, kh, kw = weight.shape 131 | 132 | # Calculate per-sample weights and demodulation coefficients. 133 | w = weight.unsqueeze(0) # [NOIkkk] 134 | w = w * styles.reshape(batch_size, 1, -1, 1, 1, 1) # [NOIkkk] 135 | 136 | if demodulate: 137 | dcoefs = (w.square().sum(dim=[2, 3, 4, 5]) + 1e-8).rsqrt() # [NO] 138 | w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1, 1) # [NOIkkk] 139 | 140 | # Execute as one fused op using grouped convolution. 141 | batch_size = int(batch_size) 142 | x = x.reshape(1, -1, *x.shape[2:]) 143 | w = w.reshape(-1, in_channels, kd, kh, kw) 144 | x = torch.nn.functional.conv3d(x, w, padding=padding, groups=batch_size) 145 | x = x.reshape(batch_size, -1, *x.shape[2:]) 146 | return x 147 | 148 | 149 | class Conv3DResBlock(torch.nn.Module): 150 | 151 | def __init__(self, nf_in, nf_out, activation): 152 | super().__init__() 153 | self.nf_in = nf_in 154 | self.nf_out = nf_out 155 | self.norm_0 = torch.nn.BatchNorm3d(nf_in) 156 | self.conv_0 = torch.nn.Conv3d(nf_in, nf_out, kernel_size=3, padding=1) 157 | self.norm_1 = torch.nn.BatchNorm3d(nf_out) 158 | self.conv_1 = torch.nn.Conv3d(nf_out, nf_out, kernel_size=3, padding=1) 159 | self.activation = activation 160 | if nf_in != nf_out: 161 | self.nin_shortcut = torch.nn.Conv3d(nf_in, nf_out, kernel_size=1, padding=0) 162 | 163 | def forward(self, x): 164 | h = x 165 | h = self.norm_0(h) 166 | h = self.activation(h) 167 | h = self.conv_0(h) 168 | 169 | h = self.norm_1(h) 170 | h = self.activation(h) 171 | h = self.conv_1(h) 172 | 173 | if self.nf_in != self.nf_out: 174 | x = self.nin_shortcut(x) 175 | 176 | return x + h 177 | 178 | 179 | class Conv3DBlock(torch.nn.Module): 180 | 181 | def __init__(self, nf_in, nf_out, activation): 182 | super().__init__() 183 | self.nf_in = nf_in 184 | self.nf_out = nf_out 185 | self.norm_0 = torch.nn.BatchNorm3d(nf_in) 186 | self.conv_0 = torch.nn.Conv3d(nf_in, nf_out, kernel_size=3, padding=1) 187 | self.activation = activation 188 | 189 | def forward(self, x): 190 | h = x 191 | h = self.norm_0(h) 192 | h = self.activation(h) 193 | h = self.conv_0(h) 194 | return h 195 | 196 | 197 | class SDFEncoder(torch.nn.Module): 198 | def __init__(self, in_channels, layer_dims=(32, 64, 64, 128, 128, 256, 256, 256)): 199 | super().__init__() 200 | self.activation = torch.nn.LeakyReLU() 201 | self.enc_conv_in = torch.nn.Conv3d(in_channels, layer_dims[0], 1) 202 | self.down_0_block_0 = Conv3DBlock(layer_dims[0], layer_dims[1], self.activation) 203 | self.down_0_block_1 = Conv3DBlock(layer_dims[1], layer_dims[2], self.activation) 204 | self.down_1_block_0 = Conv3DBlock(layer_dims[2], layer_dims[3], self.activation) 205 | self.down_2_block_0 = Conv3DBlock(layer_dims[3], layer_dims[4], self.activation) 206 | self.down_3_block_0 = Conv3DBlock(layer_dims[4], layer_dims[5], 
self.activation) 207 | self.down_4_block_0 = Conv3DBlock(layer_dims[5], layer_dims[6], self.activation) 208 | self.enc_mid_block_0 = Conv3DBlock(layer_dims[6], layer_dims[7], self.activation) 209 | self.enc_out_conv = torch.nn.Conv3d(layer_dims[7], layer_dims[7], kernel_size=3, padding=1) 210 | self.enc_out_norm = torch.nn.BatchNorm3d(layer_dims[7]) 211 | 212 | def forward(self, x): 213 | level_feats = [] 214 | pool_ctr = 0 215 | x = self.enc_conv_in(x) 216 | x = self.down_0_block_0(x) 217 | x = self.down_0_block_1(x) 218 | level_feats.append(x) 219 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='trilinear', recompute_scale_factor=False, align_corners=True) 220 | pool_ctr += 1 221 | 222 | x = self.down_1_block_0(x) 223 | level_feats.append(x) 224 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='trilinear', recompute_scale_factor=False, align_corners=True) 225 | pool_ctr += 1 226 | 227 | x = self.down_2_block_0(x) 228 | level_feats.append(x) 229 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='trilinear', recompute_scale_factor=False, align_corners=True) 230 | pool_ctr += 1 231 | 232 | x = self.down_3_block_0(x) 233 | level_feats.append(x) 234 | x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='trilinear', recompute_scale_factor=False, align_corners=True) 235 | pool_ctr += 1 236 | 237 | x = self.down_4_block_0(x) 238 | # level_feats.append(x) 239 | # x = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='trilinear', recompute_scale_factor=False, align_corners=True) 240 | # pool_ctr += 1 241 | 242 | x = self.enc_mid_block_0(x) 243 | x = self.enc_out_norm(x) 244 | x = self.activation(x) 245 | x = self.enc_out_conv(x) 246 | 247 | level_feats.append(x) 248 | 249 | return level_feats 250 | 251 | 252 | if __name__ == '__main__': 253 | model = SDFEncoder(1).cuda() 254 | x = torch.randn(8, 1, 64, 64, 64).cuda() 255 | ll = model(x) 256 | for y in ll: 257 | print(y.shape) 258 | 259 | -------------------------------------------------------------------------------- /model/styleganvox/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from model.styleganvox import activation_funcs, FullyConnectedLayer, clamp_gain, modulated_conv3d, SmoothUpsample, normalize_2nd_moment, identity, SDFEncoder 4 | 5 | 6 | class Generator(torch.nn.Module): 7 | 8 | def __init__(self, z_dim, w_dim, w_num_layers, img_resolution, img_channels, synthesis_layer='stylegan2'): 9 | super().__init__() 10 | self.z_dim = z_dim 11 | self.w_dim = w_dim 12 | self.img_resolution = img_resolution 13 | self.img_channels = img_channels 14 | self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, synthesis_layer=synthesis_layer) 15 | self.num_ws = self.synthesis.num_ws 16 | self.mapping = MappingNetwork(z_dim=z_dim, w_dim=w_dim, num_ws=self.num_ws, num_layers=w_num_layers) 17 | 18 | def forward(self, z, shape, truncation_psi=1, truncation_cutoff=None, noise_mode='random'): 19 | ws = self.mapping(z, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) 20 | img = self.synthesis(ws, shape, noise_mode) 21 | return img 22 | 23 | 24 | class SynthesisNetwork(torch.nn.Module): 25 | 26 | def __init__(self, w_dim, img_resolution, img_channels, synthesis_layer='stylegan2'): 27 | super().__init__() 28 | 29 | self.w_dim = w_dim 30 | self.img_resolution = img_resolution 31 | self.img_resolution_log2 = int(np.log2(img_resolution)) 32 | 
self.img_channels = img_channels 33 | self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] 34 | self.num_ws = 2 * (len(self.block_resolutions) + 1) 35 | channels_dict = {4: 512, 8: 512, 16: 512, 32: 256, 64: 128} 36 | channels_dict_geo = {4: 256, 8: 256, 16: 128, 32: 128, 64: 64, 128: 0} 37 | self.blocks = torch.nn.ModuleList() 38 | self.first_block = SynthesisPrologue(channels_dict[self.block_resolutions[0]], w_dim=w_dim, geo_channels=channels_dict_geo[4], 39 | resolution=self.block_resolutions[0], img_channels=img_channels, 40 | synthesis_layer=synthesis_layer) 41 | for res in self.block_resolutions[1:]: 42 | in_channels = channels_dict[res // 2] 43 | geo_channels = channels_dict_geo[res] 44 | out_channels = channels_dict[res] 45 | block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, geo_channels=geo_channels, resolution=res, img_channels=img_channels, synthesis_layer=synthesis_layer) 46 | self.blocks.append(block) 47 | 48 | def forward(self, ws, shape, noise_mode='random'): 49 | split_ws = [ws[:, 0:2, :]] + [ws[:, 2 * n + 1: 2 * n + 4, :] for n in range(len(self.block_resolutions))] 50 | x, img = self.first_block(split_ws[0], noise_mode) 51 | for i in range(len(self.block_resolutions) - 1): 52 | x, img = self.blocks[i](x, img, split_ws[i + 1], shape[len(shape) - 1 - i], noise_mode) 53 | return img 54 | 55 | 56 | class SynthesisPrologue(torch.nn.Module): 57 | 58 | def __init__(self, out_channels, w_dim, geo_channels, resolution, img_channels, synthesis_layer): 59 | super().__init__() 60 | self.w_dim = w_dim 61 | self.resolution = resolution 62 | self.img_channels = img_channels 63 | self.img_channels = img_channels 64 | self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution, resolution])) 65 | self.conv1 = SynthesisLayer(out_channels, out_channels - geo_channels, w_dim=w_dim, resolution=resolution) 66 | self.torgb = ToRGBLayer(out_channels - geo_channels, img_channels, w_dim=w_dim) 67 | 68 | def forward(self, ws, noise_mode): 69 | w_iter = iter(ws.unbind(dim=1)) 70 | x = self.const.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1, 1]) 71 | x = self.conv1(x, next(w_iter), noise_mode=noise_mode) 72 | img = self.torgb(x, next(w_iter)) 73 | return x, img 74 | 75 | 76 | class SynthesisBlock(torch.nn.Module): 77 | 78 | def __init__(self, in_channels, out_channels, w_dim, geo_channels, resolution, img_channels, synthesis_layer): 79 | super().__init__() 80 | self.in_channels = in_channels 81 | self.w_dim = w_dim 82 | self.resolution = resolution 83 | self.img_channels = img_channels 84 | self.num_conv = 0 85 | self.num_torgb = 0 86 | self.resampler = SmoothUpsample() 87 | self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, resampler=self.resampler) 88 | self.conv1 = SynthesisLayer(out_channels, out_channels - geo_channels, w_dim=w_dim, resolution=resolution) 89 | self.torgb = ToRGBLayer(out_channels - geo_channels, img_channels, w_dim=w_dim) 90 | 91 | def forward(self, x, img, ws, shape, noise_mode): 92 | w_iter = iter(ws.unbind(dim=1)) 93 | 94 | x = torch.cat([x, shape], dim=1) 95 | x = self.conv0(x, next(w_iter), noise_mode=noise_mode) 96 | x = self.conv1(x, next(w_iter), noise_mode=noise_mode) 97 | 98 | y = self.torgb(x, next(w_iter)) 99 | img = self.resampler(img) 100 | img = img.add_(y) 101 | 102 | return x, img 103 | 104 | 105 | class ToRGBLayer(torch.nn.Module): 106 | 107 | def __init__(self, in_channels, out_channels, w_dim, kernel_size=1): 108 | super().__init__() 109 | self.affine = 
FullyConnectedLayer(w_dim, in_channels, bias_init=1) 110 | self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size, kernel_size])) 111 | self.bias = torch.nn.Parameter(torch.zeros([out_channels])) 112 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 3)) 113 | 114 | def forward(self, x, w): 115 | styles = self.affine(w) * self.weight_gain 116 | x = modulated_conv3d(x=x, weight=self.weight, styles=styles, demodulate=False) 117 | return torch.clamp(x + self.bias[None, :, None, None, None], -256, 256) 118 | 119 | 120 | class SynthesisLayer(torch.nn.Module): 121 | 122 | def __init__(self, in_channels, out_channels, w_dim, resolution, kernel_size=3, resampler=identity, activation='lrelu'): 123 | super().__init__() 124 | self.resolution = resolution 125 | self.resampler = resampler 126 | self.activation = activation_funcs[activation]['fn'] 127 | self.activation_gain = activation_funcs[activation]['def_gain'] 128 | self.padding = kernel_size // 2 129 | self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) 130 | self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size, kernel_size])) 131 | 132 | self.register_buffer('noise_const', torch.randn([resolution, resolution, resolution])) 133 | self.noise_strength = torch.nn.Parameter(torch.zeros([1])) 134 | 135 | self.bias = torch.nn.Parameter(torch.zeros([out_channels])) 136 | 137 | def forward(self, x, w, noise_mode, gain=1): 138 | styles = self.affine(w) 139 | 140 | noise = None 141 | if noise_mode == 'random': 142 | noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution, self.resolution], device=x.device) * self.noise_strength 143 | if noise_mode == 'const': 144 | noise = self.noise_const * self.noise_strength 145 | 146 | x = modulated_conv3d(x=x, weight=self.weight, styles=styles, padding=self.padding) 147 | x = self.resampler(x) 148 | x = x + noise 149 | 150 | return clamp_gain(self.activation(x + self.bias[None, :, None, None, None]), self.activation_gain * gain, 256 * gain) 151 | 152 | 153 | class MappingNetwork(torch.nn.Module): 154 | 155 | def __init__(self, z_dim, w_dim, num_ws, num_layers=8, activation='lrelu', lr_multiplier=0.01, w_avg_beta=0.995): 156 | super().__init__() 157 | self.z_dim = z_dim 158 | self.w_dim = w_dim 159 | self.num_ws = num_ws 160 | self.num_layers = num_layers 161 | self.w_avg_beta = w_avg_beta 162 | 163 | features_list = [z_dim] + [w_dim] * num_layers 164 | 165 | self.layers = torch.nn.ModuleList() 166 | for idx in range(num_layers): 167 | in_features = features_list[idx] 168 | out_features = features_list[idx + 1] 169 | self.layers.append(FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)) 170 | 171 | if num_ws is not None and w_avg_beta is not None: 172 | self.register_buffer('w_avg', torch.zeros([w_dim])) 173 | 174 | def forward(self, z, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): 175 | # Embed, normalize, and concat inputs. 176 | x = normalize_2nd_moment(z) 177 | 178 | # Main layers. 179 | for idx in range(self.num_layers): 180 | x = self.layers[idx](x) 181 | 182 | # Update moving average of W. 183 | if self.w_avg_beta is not None and self.training and not skip_w_avg_update: 184 | self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) 185 | 186 | # Broadcast. 187 | if self.num_ws is not None: 188 | x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) 189 | 190 | # Apply truncation. 
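# With truncation_psi < 1, the broadcast codes are lerped toward the running average w_avg; truncation_cutoff limits this to the first entries of the ws sequence.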
191 | if truncation_psi != 1: 192 | if self.num_ws is None or truncation_cutoff is None: 193 | x = self.w_avg.lerp(x, truncation_psi) 194 | else: 195 | x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) 196 | 197 | return x 198 | 199 | 200 | def test_generator(): 201 | import time 202 | batch_size = 2 203 | G = Generator(512, 512, 2, 64, 3).cuda() 204 | E = SDFEncoder(1).cuda() 205 | print_model_parameter_count(G) 206 | print_model_parameter_count(E) 207 | for batch_idx in range(16): 208 | # sanity test forward pass 209 | z = torch.randn(batch_size, 512).to(torch.device("cuda:0")) 210 | x = torch.randn(batch_size, 1, 64, 64, 64).to(torch.device("cuda:0")) 211 | shape = E(x) 212 | w = G.mapping(z) 213 | t0 = time.time() 214 | fake = G.synthesis(w, shape) 215 | print('Time for fake:', time.time() - t0, ', shape:', fake.shape) 216 | # sanity test backwards 217 | loss = torch.abs(fake - torch.rand_like(fake)).mean() 218 | t0 = time.time() 219 | loss.backward() 220 | print('Time for backwards:', time.time() - t0) 221 | print('backwards done') 222 | break 223 | 224 | 225 | if __name__ == '__main__': 226 | from util.misc import print_model_parameter_count, print_module_summary 227 | 228 | # model = Generator(512, 512, 2, 16, 3).cuda() 229 | # print_module_summary(model, (torch.randn((16, 512)).cuda(), )) 230 | # print_model_parameter_count(model) 231 | 232 | test_generator() 233 | --------------------------------------------------------------------------------
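A minimal usage sketch for the SDF-conditioned volumetric generator above (it mirrors test_generator; the batch size, 64^3 grid resolution, and truncation_psi value are illustrative assumptions, not repository defaults):

import torch
from model.styleganvox import SDFEncoder
from model.styleganvox.generator import Generator

G = Generator(z_dim=512, w_dim=512, w_num_layers=2, img_resolution=64, img_channels=3).cuda()
E = SDFEncoder(1).cuda()

sdf = torch.randn(2, 1, 64, 64, 64).cuda()   # stand-in for a preprocessed 64^3 SDF grid
shape_feats = E(sdf)                         # multi-resolution geometry features consumed by the synthesis blocks
z = torch.randn(2, 512).cuda()

# truncation_psi < 1 pulls the broadcast style codes toward the running average w_avg (see MappingNetwork.forward);
# noise_mode='const' reuses the fixed per-layer noise buffers instead of sampling fresh noise.
color_grid = G(z, shape_feats, truncation_psi=0.7, noise_mode='const')
print(color_grid.shape)                      # [2, 3, 64, 64, 64]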