├── .gitignore ├── assets └── teaser.jpg ├── examples ├── generated │ ├── fabric.jpg │ ├── pasta.jpg │ ├── stone.jpg │ ├── wood.jpg │ └── ironwall.jpg ├── specular │ ├── ceiling.jpg │ ├── chain.jpg │ ├── metal.jpg │ └── titanium.jpg └── in_the_wild │ ├── wild_1.jpg │ ├── wild_2.jpg │ ├── wild_3.jpg │ ├── wild_4.jpg │ └── wild_5.jpg ├── requirements.txt ├── chord ├── module │ ├── __init__.py │ ├── base.py │ ├── light.py │ ├── stable_diffusion.py │ └── chord.py ├── __init__.py ├── util.py └── io.py ├── config └── chord.yaml ├── test.py ├── README.md ├── LICENSE └── demo_gradio.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .venv 3 | *.ckpt 4 | output -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/assets/teaser.jpg -------------------------------------------------------------------------------- /examples/generated/fabric.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/generated/fabric.jpg -------------------------------------------------------------------------------- /examples/generated/pasta.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/generated/pasta.jpg -------------------------------------------------------------------------------- /examples/generated/stone.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/generated/stone.jpg -------------------------------------------------------------------------------- /examples/generated/wood.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/generated/wood.jpg -------------------------------------------------------------------------------- /examples/specular/ceiling.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/specular/ceiling.jpg -------------------------------------------------------------------------------- /examples/specular/chain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/specular/chain.jpg -------------------------------------------------------------------------------- /examples/specular/metal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/specular/metal.jpg -------------------------------------------------------------------------------- /examples/generated/ironwall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/generated/ironwall.jpg -------------------------------------------------------------------------------- /examples/in_the_wild/wild_1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/in_the_wild/wild_1.jpg -------------------------------------------------------------------------------- /examples/in_the_wild/wild_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/in_the_wild/wild_2.jpg -------------------------------------------------------------------------------- /examples/in_the_wild/wild_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/in_the_wild/wild_3.jpg -------------------------------------------------------------------------------- /examples/in_the_wild/wild_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/in_the_wild/wild_4.jpg -------------------------------------------------------------------------------- /examples/in_the_wild/wild_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/in_the_wild/wild_5.jpg -------------------------------------------------------------------------------- /examples/specular/titanium.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubisoft/ubisoft-laforge-chord/HEAD/examples/specular/titanium.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface_hub[hf_xet] 2 | diffusers 3 | transformers==4.57.1 4 | tokenizers==0.22.1 5 | typer 6 | omegaconf 7 | imageio 8 | tqdm -------------------------------------------------------------------------------- /chord/module/__init__.py: -------------------------------------------------------------------------------- 1 | modules = {} 2 | 3 | def register(name): 4 | def decorator(cls): 5 | modules[name] = cls 6 | return cls 7 | return decorator 8 | 9 | 10 | def make(name, config): 11 | model = modules[name](config) 12 | return model 13 | 14 | 15 | from . 
import ( 16 | light, 17 | stable_diffusion, 18 | chord, 19 | ) -------------------------------------------------------------------------------- /chord/module/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Base(nn.Module): 5 | def __init__(self, config): 6 | super().__init__() 7 | self.config = config 8 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 9 | self.setup() 10 | 11 | def setup(self): 12 | raise NotImplementedError 13 | 14 | -------------------------------------------------------------------------------- /chord/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from chord.module import make 4 | from chord.module.chord import post_decoder 5 | 6 | class ChordModel(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.model = make(config.model.name, config.model) 10 | 11 | def forward(self, x: torch.Tensor): 12 | x = {"render": x} 13 | pred = self.model(x) 14 | return post_decoder(pred) -------------------------------------------------------------------------------- /config/chord.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: chord 3 | roughness_step: 5. 4 | metallic_step: 1. 5 | # format: "OutputMapName": ConvInInput1_ConvInInput2_{0/1} 6 | # 0/1 stands for using gt/pred image; 7 | chain_type: chord 8 | chain_library: 9 | chord: 10 | basecolor: render_0 11 | normal: render_approxIrr_01 12 | rou_met: render_approxRM_01 13 | rgbx_prompts: 14 | basecolor: Basecolor 15 | normal: Normal 16 | roughness: Roughness 17 | metallic: Metallic 18 | irradiance: Irradiance 19 | rou_met: Roughness and Metallic 20 | prior_light: 21 | name: distant-light 22 | direction: [-1.0, -1.0, 1.0] # Top-left corner towards bottom right 23 | color: [23.47, 21.31, 20.79] 24 | power: 0.1 25 | stable_diffusion: 26 | name: stable_diffusion 27 | fp16: true 28 | vae_padding: circular 29 | version: 2.1 30 | -------------------------------------------------------------------------------- /chord/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def vector_dot(A: torch.Tensor, B: torch.Tensor, min=0.0) -> torch.Tensor: 4 | return torch.clamp((A * B).sum(1, keepdim=True), min=min, max=1.0) 5 | 6 | def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: 7 | return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)).to(f.dtype) 8 | 9 | def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: 10 | return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055).to(f.dtype) 11 | 12 | def tone_gamma(x: torch.Tensor) -> torch.Tensor: 13 | x = 1 - torch.exp(-x) 14 | return torch.pow(x, 1.0/2.2) 15 | 16 | # safe division for value range 0-1 17 | class safe_01_div(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, a, b): 20 | ctx.save_for_backward(a, b) 21 | return torch.div(a, torch.clamp(b, min=1e-4, max=1.0)) 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | a, b = ctx.saved_tensors 26 | grad_input = grad_output.clone() 27 | 28 | return torch.div(1, torch.clamp(b, min=1e-4, max=1.0)) * grad_input, -1 * torch.div(a, torch.clamp(b, min=1e-2, max=1.0)**2) * grad_input 29 | 30 | 31 | def get_positions(h, w, real_size, use_pixel_centers=True) -> torch.Tensor: 32 | pixel_center = 0.5 
if use_pixel_centers else 0 33 | i, j = torch.meshgrid( 34 | torch.arange(h) + pixel_center, 35 | torch.arange(w) + pixel_center, 36 | indexing='ij' 37 | ) 38 | if not isinstance(real_size, list): 39 | real_size = [real_size] * 2 40 | pos = torch.stack([(i / h - 0.5) * real_size[0], (j / w - 0.5) * real_size[1], torch.zeros_like(i)], dim=-1) 41 | return pos 42 | 43 | # N, H: (Bx3xHxW), roughness: (Bx1xHxW) 44 | # The "D", facet distribution function in Cook-Torrence model 45 | def DistributionGGX(cosNH, roughness): 46 | a = roughness * roughness 47 | a2 = a * a 48 | cosNH2 = cosNH * cosNH 49 | num = a2 50 | denom = cosNH2 * (a2 - 1.0) + 1.0 51 | denom = torch.pi * denom * denom 52 | return num / denom 53 | 54 | # NdotV, roughness: (Bx1xHxW) 55 | def GeometrySchlickGGX(NdotV: torch.Tensor, roughness: torch.Tensor) -> torch.Tensor: 56 | r = (roughness + 1.0) 57 | k = (r*r) / 8.0 58 | 59 | num = NdotV 60 | denom = NdotV * (1.0 - k) + k 61 | 62 | return num / denom 63 | 64 | # cosTheta, F0 (Bx1xHxW) 65 | # The "F" 66 | def fresnelSchlick(cosTheta: torch.Tensor, F0: torch.Tensor) -> torch.Tensor: 67 | return F0 + (1.0 - F0) * torch.pow(1.0 - cosTheta, 5.0) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | from torchvision.transforms import v2 5 | import typer 6 | from pathlib import Path 7 | from typing_extensions import Annotated 8 | from omegaconf import OmegaConf 9 | import tqdm 10 | from huggingface_hub import hf_hub_download 11 | 12 | from chord import ChordModel 13 | from chord.io import read_image, save_maps, load_torch_file 14 | 15 | app = typer.Typer(pretty_exceptions_show_locals=False) 16 | 17 | def setup_python_path(): 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) 19 | if current_dir not in sys.path: 20 | sys.path.insert(0, current_dir) 21 | 22 | def get_image_files(indir): 23 | files = [] 24 | for ext in ("*.png", "*.jpg", "*.jpeg"): 25 | for path in Path(indir).rglob(ext): 26 | files.append(str(path)) 27 | return files 28 | 29 | @app.command() 30 | def inference( 31 | input_dir: Annotated[str, typer.Option(default=..., help="Paths to the input image directory.")], 32 | output_dir: Annotated[str, typer.Option(help="Directory to save output maps.")] = "output", 33 | config_path: Annotated[str, typer.Option(help="Path to the config file.")] = "config/chord.yaml", 34 | ): 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | config = OmegaConf.load(config_path) 37 | model = ChordModel(config) 38 | ckpt_path = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.safetensors") 39 | print(f"[INFO] Loading model from: {ckpt_path}") 40 | state_dict = load_torch_file(ckpt_path) 41 | model.load_state_dict(state_dict) 42 | model.eval() 43 | model.to(device) 44 | 45 | os.makedirs(output_dir, exist_ok=True) 46 | image_files = get_image_files(input_dir) 47 | print(f"[INFO] found {len(image_files)} images in {input_dir}") 48 | print(f"[INFO] saving results to {output_dir}") 49 | 50 | for image_file in tqdm.tqdm(image_files, desc="[INFO] processing images"): 51 | image = read_image(image_file).to(device) 52 | ori_h, ori_w = image.shape[-2:] 53 | x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0) 54 | image_name = Path(image_file).stem 55 | with torch.no_grad() as no_grad, torch.autocast(device_type="cuda") as amp: 56 | output = model(x) 57 | for 
key in output.keys(): 58 | output[key] = v2.Resize(size=(ori_h, ori_w), antialias=True)(output[key]) 59 | output.update({"input": image}) 60 | save_maps(os.path.join(output_dir, image_name), output) 61 | 62 | if __name__ == "__main__": 63 | setup_python_path() 64 | app() -------------------------------------------------------------------------------- /chord/module/light.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | import torch.nn.functional as Fn 4 | import math 5 | import copy 6 | 7 | from . import register 8 | from .base import Base 9 | 10 | class BaseLight(Base): 11 | """ 12 | Base class for light models. 13 | """ 14 | 15 | def setup(self): 16 | pass 17 | 18 | def forward(self, x: Optional[torch.Tensor] = None): 19 | """ 20 | Get the light intensity. 21 | 22 | Args: 23 | x: positions of shape (..., 3). 24 | 25 | Returns: 26 | color: radiance intensity of shape (..., 3) 27 | d: directions of shape (..., 3). 28 | """ 29 | raise NotImplementedError 30 | 31 | 32 | @register("point-light") 33 | class PointLight(BaseLight): 34 | """Point light definitions 35 | """ 36 | def setup(self): 37 | """Initialize point light. 38 | 39 | Args: 40 | position (float, float, float): World coordinate of the light. 41 | color (float, float, float): Light color in (R, G, B). 42 | power (float): Light power, it will be directly multiplied to each color channel. 43 | """ 44 | position = self.config.get("position", [0., 0., 10.]) 45 | color = self.config.get("color", [23.47, 21.31, 20.79]) 46 | power = self.config.get("power", 10.) 47 | 48 | self.register_buffer("position", torch.tensor(position)) 49 | self.register_buffer("color", torch.tensor(color) * power) 50 | 51 | def forward(self, x: Optional[torch.Tensor] = None): 52 | """Compute light radiance and direction. 53 | 54 | Args: 55 | x : World coordinate of the interacting surface. [B, H, W, 3] 56 | Returns: 57 | color: radiance intensity of shape [B, H, W, 3] 58 | d: directions of shape [B, H, W, 3], V = (light_pos - world_pos) 59 | """ 60 | distance = torch.norm(self.position - x, dim=-1, keepdim=True) 61 | attenuation = 1.0 / (distance ** 2) 62 | radiance = self.color * attenuation 63 | direction = Fn.normalize(self.position - x, dim=-1) 64 | return radiance, direction 65 | 66 | @register("distant-light") 67 | class DistantLight(BaseLight): 68 | """Distant light definitions 69 | """ 70 | def setup(self): 71 | """Initialize distant light. 72 | 73 | Args: 74 | direction (float, float, float):The direction of light vector. 75 | color (float, float, float): Light color in (R, G, B). 76 | power (float): Light power, it will be directly multiplied to each color channel. 77 | """ 78 | direction = self.config.get("direction", [0., 0., 1.]) 79 | color = self.config.get("color", [23.47, 21.31, 20.79]) 80 | power = self.config.get("power", 0.1) 81 | 82 | self.register_buffer("color", torch.tensor(color) * power) 83 | self.register_buffer("direction", Fn.normalize(torch.tensor(direction), dim=0)) 84 | 85 | def forward(self, x: Optional[torch.Tensor] = None): 86 | """Compute light radiance and direction. 87 | 88 | Args: 89 | x : World coordinate of the interacting surface. 
[B, H, W, 3] 90 | Returns: 91 | color: radiance intensity of shape [B, H, W, 3] 92 | d: directions of shape [B, H, W, 3] 93 | """ 94 | radiance = self.color.repeat(*x.shape[:-1], 1) 95 | direction = self.direction.repeat(*x.shape[:-1], 1) 96 | return radiance, direction -------------------------------------------------------------------------------- /chord/io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import imageio.v3 as imageio 3 | import numpy as np 4 | import warnings 5 | import os 6 | import safetensors 7 | 8 | import torchvision.transforms.functional as F 9 | 10 | def read_image(filename: str, out: torch.Tensor=None) -> torch.Tensor: 11 | ''' 12 | Read a local image file into a float tensor (pixel values are normalized to [0, 1], CxHxW) 13 | 14 | Args: 15 | filename: Image file path. 16 | out: Fill in this tensor rather than return a new tensor if provided. 17 | 18 | Returns: 19 | Loaded image tensor. 20 | ''' 21 | with warnings.catch_warnings(): 22 | warnings.simplefilter("ignore") # ignore PIL's user warning that reads fp16 img as fp32 23 | img: np.ndarray = imageio.imread(filename) 24 | 25 | # Convert the image array to float tensor according to its data type 26 | res = None 27 | if img.dtype == np.uint8: 28 | img = img.astype(np.float32) / 255.0 29 | elif img.dtype == np.uint16 or img.dtype == np.int32: 30 | img = img.astype(np.float32) / 65535.0 31 | else: 32 | raise ValueError(f'Unrecognized image pixel value type: {img.dtype}') 33 | if img.ndim == 2: 34 | res = torch.from_numpy(img).unsqueeze(0) # 1xHxW for grayscale images 35 | elif img.ndim == 3: 36 | res = torch.from_numpy(img).movedim(2, 0)[:3] # HxWxC to CxHxW 37 | else: 38 | raise ValueError(f'Unrecognized image dimension: {img.shape}') 39 | 40 | if out is None: 41 | return res 42 | out.copy_(res) 43 | 44 | def create_img(img: torch.Tensor): 45 | ''' 46 | Convert tensor to PIL image 47 | 48 | Args: 49 | path: Image tensor CxHxW. Squeeze if BxCxHxW and B==1 50 | 51 | Returns: 52 | PIL image 53 | ''' 54 | if img.dim() == 4: 55 | assert img.shape[0] == 1 56 | img = img.squeeze(0) 57 | 58 | if img.shape[0] == 4: 59 | out_img = F.to_pil_image(img, mode="CMYK") 60 | out_img = out_img.convert('RGB') 61 | elif img.shape[0] == 3: 62 | out_img = F.to_pil_image(img, mode="RGB") 63 | elif img.shape[0] == 1: 64 | out_img = F.to_pil_image(img, mode="L") 65 | else: 66 | raise ValueError("Unsupported image dimension.") 67 | return out_img 68 | 69 | def save_maps(path: str, maps: dict): 70 | ''' 71 | Save SVBRDF maps to a given path. 72 | 73 | Args: 74 | path: Output path. 75 | maps: Named maps of tensor images. 
76 | ''' 77 | if not os.path.exists(path): 78 | os.makedirs(path) 79 | for name, image in maps.items(): 80 | out_img = create_img(image) 81 | out_img.save(os.path.join(path, name+".png")) 82 | 83 | def load_torch_file(ckpt, device=None): 84 | if device is None: 85 | device = torch.device("cpu") 86 | if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): 87 | with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: 88 | state_dict = {} 89 | for k in f.keys(): 90 | tensor = f.get_tensor(k) 91 | state_dict[k] = tensor 92 | else: 93 | torch_args = {} 94 | ckpt = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) 95 | 96 | if "state_dict" in ckpt: 97 | state_dict = ckpt["state_dict"] 98 | else: 99 | state_dict = ckpt 100 | 101 | return state_dict -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images 3 | 4 | arXiv 5 | Project Page 6 | Demo 7 | Custom Node 8 | 9 | [Zhi Ying](https://orcid.org/0009-0008-8390-3366)\*, [Boxiang Rong](https://ribosome-rbx.github.io/)\*, [Jingyu Wang](https://ccetaw.github.io/), [Maoyuan Xu](https://ultraman-blazar.github.io/) 10 | 11 | ![teaser](assets/teaser.jpg) 12 |
13 | 14 | Official implementation of the paper "**Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**". 15 | 16 | ## Setup environment 17 | 18 | 1. Clone github repo: 19 | 20 | ```shell 21 | git clone https://github.com/ubisoft/ubisoft-laforge-chord 22 | cd ubisoft-laforge-chord 23 | ``` 24 | 25 | 2. Install dependencies. The example below uses [uv](https://docs.astral.sh/uv/getting-started/) to manage the virtual environment: 26 | 27 | ```shell 28 | # Get Python environment 29 | uv venv --python 3.12 30 | 31 | # On Linux/WSL 32 | source .venv/bin/activate 33 | 34 | # Or on Windows 35 | .venv\Scripts\activate 36 | 37 | # If you encounter the following error on Windows: 38 | # File .venv\Scripts\activate.ps1 cannot be loaded because running scripts is disabled on this system 39 | # Run the command: Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass 40 | 41 | # Install dependencies 42 | uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128 43 | uv pip install -r requirements.txt 44 | ``` 45 | 46 | 3. Agree to the model's term from [here](https://huggingface.co/Ubisoft/ubisoft-laforge-chord), then log in: 47 | 48 | ``` 49 | huggingface-cli login 50 | ``` 51 | 52 | 4. (Optional) Install Gradio for running demo locally: 53 | 54 | ``` 55 | uv pip install gradio 56 | ``` 57 | 58 | ## Usage example 59 | 60 | Run test: 61 | 62 | ```shell 63 | python test.py --input-dir examples 64 | ``` 65 | 66 | Run the Gradio demo locally: 67 | 68 | ```shell 69 | python demo_gradio.py 70 | ``` 71 | 72 | ## License 73 | 74 | This project is released under the **Ubisoft Machine Learning License (Research-Only - Copyleft)**. See the full terms in the [LICENSE](LICENSE) file. 75 | 76 | ## Citation 77 | 78 | If you find our work useful, please consider citing: 79 | 80 | ``` 81 | @inproceedings{ying2025chord, 82 | author = {Ying, Zhi and Rong, Boxiang and Wang, Jingyu and Xu, Maoyuan}, 83 | title = {Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images}, 84 | year = {2025}, 85 | isbn = {9798400721373}, 86 | publisher = {Association for Computing Machinery}, 87 | address = {New York, NY, USA}, 88 | url = {https://doi.org/10.1145/3757377.3763848}, 89 | doi = {10.1145/3757377.3763848}, 90 | booktitle = {Proceedings of the SIGGRAPH Asia 2025 Conference Papers}, 91 | articleno = {164}, 92 | numpages = {11}, 93 | keywords = {Appearance Modeling, Material Generation, Texture Synthesis, SVBRDF, Image-conditional Diffusion Models}, 94 | series = {SA Conference Papers '25} 95 | } 96 | ``` 97 | 98 | © [2025] Ubisoft Entertainment. All Rights Reserved. 99 | -------------------------------------------------------------------------------- /chord/module/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import v2 3 | 4 | from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler 5 | from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer 6 | 7 | from . 
import register 8 | from .base import Base 9 | 10 | 11 | def apply_padding(model, mode): 12 | for layer in [layer for _, layer in model.named_modules() if isinstance(layer, torch.nn.Conv2d)]: 13 | if mode == 'circular': 14 | layer.padding_mode = 'circular' 15 | else: 16 | layer.padding_mode = 'zeros' 17 | return model 18 | 19 | def freeze(model): 20 | model = model.eval() 21 | for param in model.parameters(): 22 | param.requires_grad = False 23 | return model 24 | 25 | @register("stable_diffusion") 26 | class StableDiffusion(Base): 27 | def setup(self): 28 | hf_key = self.config.get("hf_key", None) 29 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | fp16 = self.config.get("fp16", True) 31 | self.dtype = torch.bfloat16 if fp16 else torch.float32 32 | vae_padding = self.config.get("vae_padding", "zeros") 33 | 34 | self.sd_version = self.config.get("version", 2.1) 35 | local_files_only = False 36 | if hf_key is not None: 37 | print(f"[INFO] using hugging face custom model key: {hf_key}") 38 | model_key = hf_key 39 | local_files_only = True 40 | elif str(self.sd_version) == "2.1": 41 | # model_key = "stabilityai/stable-diffusion-2-1" 42 | # StabilityAI deleted the original 2.1 model from HF, use a community version 43 | model_key = "RedbeardNZ/stable-diffusion-2-1-base" 44 | else: 45 | raise ValueError( 46 | f"Stable-diffusion version {self.sd_version} not supported." 47 | ) 48 | 49 | # Load components separately to avoid download unnecessary weights 50 | # 1. UNet (diffusion backbone) 51 | unet_config = UNet2DConditionModel.load_config(model_key, subfolder="unet") 52 | self.unet = UNet2DConditionModel.from_config(unet_config, local_files_only=local_files_only) 53 | self.unet.to(self.device, dtype=self.dtype).eval() 54 | # 2. VAE (image autoencoder) 55 | vae_config = AutoencoderKL.load_config(model_key, subfolder="vae") 56 | self.vae = AutoencoderKL.from_config(vae_config, local_files_only=local_files_only) 57 | self.vae.to(self.device, dtype=self.dtype).eval() 58 | self.vae = apply_padding(freeze(self.vae), vae_padding) 59 | # 3. Text encoder (CLIP) 60 | text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only) 61 | self.text_encoder = CLIPTextModel(text_encoder_config) 62 | self.text_encoder.to(self.device, dtype=self.dtype).eval() 63 | # 4. Tokenizer (CLIP tokenizer, this one has vocab so from_pretrained is needed) 64 | self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only) 65 | # 5. 
Scheduler 66 | scheduler_config = DDIMScheduler.load_config(model_key, subfolder="scheduler") 67 | scheduler_config["prediction_type"] = "v_prediction" 68 | scheduler_config["timestep_spacing"] = "trailing" 69 | scheduler_config["rescale_betas_zero_snr"] = True 70 | self.scheduler = DDIMScheduler.from_config(scheduler_config) 71 | 72 | def encode_text(self, prompt, padding_mode="do_not_pad"): 73 | # prompt: [str] 74 | inputs = self.tokenizer( 75 | prompt, 76 | padding=padding_mode, 77 | max_length=self.tokenizer.model_max_length, 78 | return_tensors="pt", 79 | ) 80 | embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] 81 | return embeddings 82 | 83 | def decode_latents(self, latents): 84 | latents = 1 / self.vae.config.scaling_factor * latents 85 | imgs = self.vae.decode(latents).sample 86 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 87 | return imgs 88 | 89 | def encode_imgs(self, imgs): 90 | if imgs.shape[1] == 1: # for grayscale maps 91 | imgs = v2.functional.grayscale_to_rgb(imgs) 92 | imgs = 2 * imgs - 1 93 | posterior = self.vae.encode(imgs).latent_dist 94 | latents = posterior.sample() * self.vae.config.scaling_factor 95 | return latents 96 | 97 | def encode_imgs_deterministic(self, imgs): 98 | if imgs.shape[1] == 1: # for grayscale maps 99 | imgs = v2.functional.grayscale_to_rgb(imgs) 100 | imgs = 2 * imgs - 1 101 | h = self.vae.encoder(imgs) 102 | moments = self.vae.quant_conv(h) 103 | mean, logvar = torch.chunk(moments, 2, dim=1) 104 | latents = mean * self.vae.config.scaling_factor 105 | return latents -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Ubisoft Machine Learning License (Research-Only - Copyleft) 2 | 3 | This license governs the use, reproduction, and distribution of the Licensed 4 | Materials, including AI Models and associated source code for the sole purpose 5 | of scientific research. By accessing, downloading or using the Licensed 6 | Materials, you hereby accept to be bound by this [Ubisoft Machine Learning 7 | License (Research-Only - Copyleft)] agreement (hereinafter the “License”). 8 | 9 | 1. Licensed Materials 10 | 11 | - AI Models 12 | - Source Code 13 | 14 | 2. Definitions 15 | 16 | “Licensed Materials”: Refers to the AI Models and/or Source Code licensed under 17 | this agreement. 18 | "Source Code" means the preferred form of the work for making modifications to 19 | it corresponding to text written using human-readable programming language. 20 | "Object Code" means any non-source form of a work. 21 | “AI Model” means any machine learning based assembly or assemblies (including 22 | checkpoints), consisting of learnt weights, parameters (including optimizer 23 | states), corresponding to the model architecture as embodied in the Source Code. 24 | “Output” means the results of operating an AI Model as embodied in 25 | informational content resulting therefrom. 26 | “Derivative”: Any work derived from or based upon the Licensed Materials, 27 | including modifications. 28 | “Permitted Purpose”: Use for academic or research purposes only. Commercial 29 | use is strictly prohibited. 30 | “Distribution”: Any sharing of the Licensed Materials or Derivatives with third 31 | parties, including hosting as a service. 32 | “Licensor”: The rights holder or authorized entity granting this License. 33 | “You”: The individual or entity receiving and exercising rights under this 34 | License. 35 | 36 | 3. 
Grant of Rights 37 | 38 | Subject to compliance with the terms of this License, You are granted a 39 | worldwide, royalty-free, non-exclusive License to use, study, reproduce, 40 | modify, and distribute the Licensed Materials and Derivatives solely for the 41 | Permitted Purpose. As between You and Licensor, Licensor claims no rights in 42 | the Outputs You generate using the AI Models used in accordance with the 43 | Permitted Purpose. 44 | 45 | 4. Distribution of Licensed Materials and Derivatives 46 | 47 | Any Distribution of the Derivatives of the Licensed Materials, or the Licensed 48 | Materials shall be licensed under the same exact terms as this License. 49 | Redistribution shall include this License and retain all notices of author 50 | attribution and all modifications shall be clearly marked. 51 | 52 | 5. Use Restrictions 53 | 54 | You shall not use the Licensed Materials or its Derivatives for: 55 | - any other purposes than the Permitted Purpose, including for commercial 56 | purposes such as using the Licensed Materials in any activity intended for 57 | commercial advantage or monetary compensation directly or indirectly; 58 | - weaponry, warfare, military applications, surveillance, or any activity that 59 | may cause harm or violate human rights; 60 | - engaging or enabling fully automated decision-making that may adversely 61 | impacts a natural person's legal rights; 62 | - providing medical advice or making clinical decisions; 63 | - generating content that promotes or incites hatred, violence, discrimination, 64 | or harm based on race, ethnicity, religion, gender, sexual orientation, or 65 | any other protected characteristic; 66 | - generating content that includes depictions of sexual abuse, sexual 67 | violence, explicit pornography, or any form of non-consensual acts and/or 68 | generating content that includes depictions of child nudity, child 69 | pornography, or any form of child exploitation; 70 | 71 | 6. Disclaimer of Warranty 72 | 73 | THE LICENSED MATERIALS IS PROVIDED "AS IS" AND “AS AVAILABLE” WITHOUT 74 | WARRANTIES OF ANY KIND WHETHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 75 | THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, 76 | NON-INFRINGEMENT, CORRECTNESS, ACCURACY, OR RELIABILITY. THE LICENSOR DISCLAIMS 77 | ALL LIABILITY FOR DAMAGES RESULTING FROM THE USE OR INABILITY TO USE THE 78 | LICENSED MATERIALS. THE USE OF THE LICENSED MATERIALS AND ANY OUTPUTS YOU MAY 79 | GENERATE SHALL BE AT YOUR OWN RISK. 80 | 81 | 7. Termination 82 | 83 | This License terminates automatically if You violate any of its terms. Upon 84 | termination, You shall cease all use and distribution of the Licensed 85 | Materials and its Derivatives. 86 | 87 | 8. Governing Law 88 | 89 | The validity of this Agreement and any of its terms and provisions, as well as 90 | the rights and duties of the parties hereunder, shall be governed, interpreted 91 | and enforced in accordance with the laws of France. 92 | 93 | 9. Miscellaneous 94 | 95 | If any provision of this License is held to be invalid, illegal or 96 | unenforceable, the remaining provisions shall be unaffected thereby and remain 97 | valid as if such provision had not been set forth herein. 98 | 99 | Copyright (C) 2025 UBISOFT ENTERTAINMENT. All Rights Reserved. 
100 | -------------------------------------------------------------------------------- /demo_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | import copy 7 | from omegaconf import OmegaConf 8 | from torchvision.transforms import v2 9 | from torchvision.transforms.functional import to_pil_image 10 | from huggingface_hub import hf_hub_download 11 | 12 | from chord import ChordModel 13 | from chord.module import make 14 | from chord.util import get_positions, rgb_to_srgb 15 | from chord.io import load_torch_file 16 | 17 | EXAMPLES_USECASE_1 = [ 18 | [f"examples/generated/{f}"] 19 | for f in sorted(os.listdir("examples/generated")) 20 | ] 21 | EXAMPLES_USECASE_2 = [ 22 | [f"examples/in_the_wild/{f}"] 23 | for f in sorted(os.listdir("examples/in_the_wild")) 24 | ] 25 | EXAMPLES_USECASE_3 = [ 26 | [f"examples/specular/{f}"] 27 | for f in sorted(os.listdir("examples/specular")) 28 | ] 29 | 30 | MODEL_OBJ = None 31 | MODEL_CKPT_PATH = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.safetensors") 32 | def load_model(ckpt_path): 33 | print("Loading model from:", ckpt_path) 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | config = OmegaConf.load("config/chord.yaml") 36 | model = ChordModel(config) 37 | state_dict = load_torch_file(ckpt_path) 38 | model.load_state_dict(state_dict) 39 | model.eval() 40 | model.to(device) 41 | return model 42 | 43 | def run_model(model, img: Image.Image): 44 | to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) 45 | image = to_tensor(img).to(next(model.parameters()).device) 46 | x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0) 47 | with torch.no_grad() as no_grad, torch.autocast(device_type="cuda") as amp: 48 | output = model(x) 49 | output.update({"input": image}) 50 | return output 51 | 52 | def relit(model, maps): 53 | maps['metallic'] = maps.get('metalness', torch.zeros_like(maps['basecolor'])) 54 | device = next(model.parameters()).device 55 | h, w = maps["basecolor"].shape[-2:] 56 | light = make("point-light", {"position": [0, 0, 10]}).to(device) 57 | pos = get_positions(h, w, 10).to(device) 58 | camera = torch.tensor([0, 0, 10.0]).to(device) 59 | for key in maps: 60 | if maps[key].dim() == 3: 61 | maps[key] = maps[key].unsqueeze(0) 62 | maps[key] = maps[key].permute(0,2,3,1) # BxCxHxW -> BxHxWxC 63 | rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0,3,1,2) # GxBxHxWxC -> BxCxHxW 64 | return torch.clamp(rgb_to_srgb(rgb), 0, 1) 65 | 66 | def inference(img): 67 | global MODEL_OBJ 68 | 69 | if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != MODEL_CKPT_PATH: 70 | MODEL_OBJ = load_model(MODEL_CKPT_PATH) 71 | MODEL_OBJ._ckpt = MODEL_CKPT_PATH # store path inside object 72 | 73 | if img is None: 74 | return None, None, None, None, None 75 | 76 | ori_h, ori_w = img.size[1], img.size[0] 77 | out = run_model(MODEL_OBJ, img) 78 | maps = copy.deepcopy(out) 79 | rendered = relit(MODEL_OBJ, maps) 80 | resize_back = v2.Resize(size=(ori_h, ori_w), antialias=True) 81 | return ( 82 | to_pil_image(resize_back(out["basecolor"]).squeeze(0)), 83 | to_pil_image(resize_back(out["normal"]).squeeze(0)), 84 | to_pil_image(resize_back(out["roughness"]).squeeze(0)), 85 | to_pil_image(resize_back(out["metalness"]).squeeze(0)), 86 | to_pil_image(resize_back(rendered).squeeze(0)), 87 | ) 88 | 89 | with 
gr.Blocks(title="Chord") as demo: 90 | 91 | gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**") 92 | gr.Markdown("Upload an image or select an example to estimate PBR channels.") 93 | 94 | with gr.Row(): 95 | with gr.Column(): 96 | input_img = gr.Image(type="pil", label="Input Image", height=512) 97 | 98 | gr.Markdown("### Example Inputs — Generated Textures") 99 | gr.Examples( 100 | examples=EXAMPLES_USECASE_1, 101 | inputs=[input_img], 102 | label="Examples (Generated Textures)" 103 | ) 104 | 105 | gr.Markdown("### Example Inputs — In The Wild Photographs") 106 | gr.Examples( 107 | examples=EXAMPLES_USECASE_2, 108 | inputs=[input_img], 109 | label="Examples (In The Wild Photographs)" 110 | ) 111 | 112 | gr.Markdown("### Example Inputs — Specular Textures") 113 | gr.Examples( 114 | examples=EXAMPLES_USECASE_3, 115 | inputs=[input_img], 116 | label="Examples (Specular Textures)" 117 | ) 118 | 119 | run_button = gr.Button("Run Estimation") 120 | 121 | with gr.Column(): 122 | gr.Markdown("### Predicted Channels") 123 | basecolor_out = gr.Image(label="Basecolor", height=512) 124 | normal_out = gr.Image(label="Normal", height=512) 125 | roughness_out = gr.Image(label="Roughness", height=512) 126 | metallic_out = gr.Image(label="Metalness", height=512) 127 | 128 | gr.Markdown("### Relit Output") 129 | render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512) 130 | 131 | run_button.click( 132 | inference, 133 | inputs=[input_img], 134 | outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out] 135 | ) 136 | 137 | if __name__ == "__main__": 138 | demo.launch() 139 | -------------------------------------------------------------------------------- /chord/module/chord.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as Fn 5 | from torchvision.transforms import v2 6 | 7 | from . import register, make 8 | from .base import Base 9 | 10 | from chord.util import fresnelSchlick, GeometrySchlickGGX, DistributionGGX 11 | from chord.util import srgb_to_rgb, tone_gamma, get_positions, safe_01_div 12 | 13 | class dummy_module(nn.Module): 14 | def forward(self, x): return x 15 | 16 | def post_decoder(out_dict): 17 | out = {} 18 | for key in out_dict.keys(): 19 | if key.startswith("approx"): continue 20 | elif key == "normal": 21 | out[key] = Fn.normalize(2. * out_dict[key] - 1., dim=1) / 2. + 0.5 22 | elif key == "rou_met": 23 | out['roughness'], out['metalness'] = out_dict['rou_met'][:,0], out_dict['rou_met'][:,1] 24 | else: out[key] = out_dict[key] 25 | return out 26 | 27 | def process_irradiance(radiance, kernel_size=25, res=64): 28 | """ 29 | Process the irradiance using PyTorch, equivalent to the original OpenCV-based function. 30 | 31 | Args: 32 | radiance (torch.Tensor): Input radiance tensor (H, W). 33 | kernel_size (int): Size of the kernel for the median blur. 34 | res (int): Target resolution for resizing the image. 35 | 36 | Returns: 37 | torch.Tensor: Processed radiance tensor (res, res). 
38 | """ 39 | # Ensure the input radiance is a 4D tensor (B, 1, H, W) 40 | assert radiance.shape[1] == 1 and radiance.dim() == 4, f"Invalid radiance shape, got {radiance.shape}" 41 | # resize to low resolution 42 | resizer = v2.Resize(size=res, antialias=True) 43 | radiance = resizer(radiance) 44 | 45 | # Define a 11x11 averaging kernel 46 | kernel = torch.ones((1, 1, 11, 11), dtype=torch.float32).to(radiance) / 121.0 47 | # Apply convolution (averaging filter) 48 | radiance = Fn.pad(radiance, (5,)*4, mode="reflect") # Pad for edge handling 49 | radiance = Fn.conv2d(radiance, kernel, padding=0) # 'padding=2' to maintain input dimensions 50 | 51 | # Clamp values and scale to [0, 255] for median filtering 52 | radiance = torch.clamp(radiance * 255, 0, 255) # Remove batch/channel dims 53 | 54 | # Apply median filtering 55 | paded_radiance = Fn.pad(radiance, (kernel_size // 2,) * 4, mode="reflect") # Pad for edge handling 56 | unfolded = Fn.unfold(paded_radiance, kernel_size) # Extract patches 57 | radiance = torch.median(unfolded, dim=1).values.view(radiance.shape) # Median of patches 58 | 59 | # Normalize to [0, 1] 60 | rad_min, rad_max = radiance.amin([2,3], keepdim=True), radiance.amax([2,3], keepdim=True) 61 | radiance = (radiance - rad_min) / (rad_max - rad_min) 62 | return radiance 63 | 64 | def opt_light_dir(_radiance, _num_samples=6): 65 | ''' 66 | _radiance: (bs, 1, h, w) 67 | ''' 68 | assert _radiance.shape[1] == 1 and _radiance.dim()==4 69 | bs, _, h, w = _radiance.shape 70 | 71 | def evenly_sample(_num_samples, min=0, max=2*torch.pi): 72 | # returns torch.tensor([1, _num_samples]) 73 | return torch.tensor(range(_num_samples+1)) * (max - min) / _num_samples + min 74 | 75 | def compute_radiance_diff(angles): 76 | num = angles.shape[-1] 77 | dirs = torch.cat([torch.cos(angles), torch.sin(angles)]).T 78 | pos_dir = grid_pos.repeat(num, 1, 1, 1) 79 | pos_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) > 0 80 | neg_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) < 0 81 | samples_radiance = _radiance.repeat(1,num,1,1) 82 | radiance_diff = (samples_radiance*pos_mask[None] - samples_radiance*neg_mask[None]).sum([2,3]) 83 | return radiance_diff 84 | 85 | angle_min, angle_max = 0, 2*torch.pi 86 | grid_pos = Fn.normalize(get_positions(h,w,10)[...,:2], dim=-1, eps=1e-6).to(_radiance) 87 | while(((angle_max - angle_min) > (torch.pi/90))): 88 | angles = evenly_sample(_num_samples, angle_min, angle_max)[None].to(_radiance) 89 | diffs = compute_radiance_diff(angles).mean(0) 90 | angle_min = angles[:,diffs.argmax()].item() - (angle_max - angle_min)/_num_samples 91 | angle_max = angles[:,diffs.argmax()].item() + (angle_max - angle_min)/_num_samples 92 | 93 | light_angle = angles[:, diffs.argmax()] 94 | return torch.tensor([torch.cos(light_angle), torch.sin(light_angle)]).to(_radiance) 95 | 96 | 97 | def find_light_dir(raw_irradiance, light): 98 | raw_irradiance = v2.functional.rgb_to_grayscale(raw_irradiance) 99 | irradiance = process_irradiance(raw_irradiance) 100 | dir = opt_light_dir(irradiance) 101 | dir = torch.cat([dir, torch.tensor([0.5**0.5]).to(dir)]) 102 | _light = copy.deepcopy(light) 103 | _light.direction = dir 104 | return _light 105 | 106 | @register("chord") 107 | class Chord(Base): 108 | def setup(self): 109 | # Define forward chain 110 | self.chain_type = self.config.get("chain_type", "chord") 111 | self.chain = self.config.get("chain_library", {})[self.chain_type] 112 | self.prompts = self.config.get("rgbx_prompts", {}) 113 | self.roughness_step = 
self.config.get("roughness_step", 10) 114 | self.metallic_step = self.config.get("metallic_step", 0.2) 115 | 116 | self.sd = make(self.config.stable_diffusion.name, self.config.stable_diffusion) 117 | self.dtype = self.sd.dtype 118 | self.device = self.sd.device 119 | 120 | # LEGO-conditioning 121 | self.sd.unet.ConvIns = nn.ModuleDict() 122 | self.sd.unet.ConvOuts = nn.ModuleDict() 123 | self.sd.unet.FirstDownBlocks = nn.ModuleDict() 124 | self.sd.unet.LastUpBlocks = nn.ModuleDict() 125 | for key in list(set("_".join(self.chain.values()).split("_"))) + ["noise"]: 126 | if "0" in key or "1" in key: continue 127 | self.sd.unet.ConvIns[key] = nn.Conv2d(4, 320, 3, 1 , 1, device=self.device, dtype=self.dtype) 128 | self.sd.unet.ConvIns[key].load_state_dict(self.sd.unet.conv_in.state_dict()) 129 | for kout in list(set(self.chain.keys())): 130 | self.sd.unet.ConvOuts[kout] = nn.Conv2d(320, 4, 3, 1 , 1, device=self.device, dtype=self.dtype) 131 | self.sd.unet.ConvOuts[kout].load_state_dict(self.sd.unet.conv_out.state_dict()) 132 | self.sd.unet.LastUpBlocks[kout] = copy.deepcopy(self.sd.unet.up_blocks[-1]).to(self.device) 133 | self.sd.unet.FirstDownBlocks[kout] = copy.deepcopy(self.sd.unet.down_blocks[0]).to(self.device) 134 | self.sd.unet.ConvIns.train() 135 | self.sd.unet.ConvOuts.train() 136 | self.sd.unet.FirstDownBlocks.train() 137 | self.sd.unet.LastUpBlocks.train() 138 | self.sd.unet.conv_in = dummy_module() 139 | self.sd.unet.conv_out = dummy_module() 140 | 141 | # Load Lights 142 | if self.config.get("prior_light", None) is None: 143 | self.prior_light = make("point-light", {"position": [0, 0, 10]}) 144 | else: 145 | self.prior_light = make(self.config.prior_light.name, self.config.prior_light) 146 | 147 | # Init Embeddings 148 | self.text_emb = {} 149 | # Eq.3 150 | def compute_approxIrr(self, render, basecolor): 151 | approxIrr = safe_01_div.apply(srgb_to_rgb(render), srgb_to_rgb(basecolor)) 152 | return tone_gamma(approxIrr) 153 | # Eq.6 154 | @torch.no_grad() 155 | def compute_approxRouMet(self, render, maps, seperate=False, light=None): 156 | render = srgb_to_rgb(render) 157 | bs, _, h, w = render.shape 158 | light = find_light_dir(maps['approxIrr'], self.prior_light) if light is None else light 159 | # light.direction = estimate_light_dir(render, maps) 160 | pos = get_positions(h, w, 10).to(self.device) 161 | cameras = torch.tensor([0, 0, 10.0]).to(self.device) 162 | 163 | # sample grid 164 | r_samples = torch.arange(25, 225+self.roughness_step, self.roughness_step) / 255 165 | m_samples = torch.arange(0., 1.+self.metallic_step, self.metallic_step) 166 | 167 | grid_maps = {} # change map size into: gs, bs, h, w, c 168 | grid_maps['basecolor'] = maps['basecolor'][None].permute(0,1,3,4,2) 169 | grid_maps['normal'] = maps['normal'][None].permute(0,1,3,4,2) 170 | r_values = r_samples[:,None].repeat(1,len(m_samples)).reshape(-1,1,1,1,1).to(maps['basecolor']) 171 | m_values = m_samples[None].repeat(len(r_samples),1).reshape(-1,1,1,1,1).to(maps['basecolor']) 172 | # split into chunks to avoid OOM 173 | chunk_size = 25 174 | rgb_list, r_list, m_list = [], [], [] 175 | for _r, _m in zip(torch.split(r_values, chunk_size), torch.split(m_values, chunk_size)): 176 | grid_maps['roughness'], grid_maps['metallic'] = _r, _m 177 | _rgb = self.compute_render(grid_maps, cameras, pos, light) 178 | loss = (render[None].permute(0,1,3,4,2) - _rgb).abs().sum(-1,keepdim=True) 179 | min_idx = loss.argmin(dim=0,keepdim=True) 180 | r_list.append(torch.gather(grid_maps['roughness'].flatten(), 0, 
min_idx.flatten()).reshape(min_idx.shape)) 181 | m_list.append(torch.gather(grid_maps['metallic'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape)) 182 | rgb_list.append(torch.gather(_rgb, 0, min_idx.repeat(1,1,1,1,3))) 183 | rgb = torch.cat(rgb_list).permute(0,1,4,2,3) 184 | roughness = torch.cat(r_list).permute(0,1,4,2,3) 185 | metallic = torch.cat(m_list).permute(0,1,4,2,3) 186 | loss = (render[None] - rgb).abs().sum(2,keepdim=True) 187 | roughness = torch.gather(roughness, 0, loss.argmin(dim=0,keepdim=True))[0] 188 | metallic = torch.gather(metallic, 0, loss.argmin(dim=0,keepdim=True))[0] 189 | torch.cuda.empty_cache() 190 | if seperate: 191 | return roughness, metallic 192 | else: 193 | out = torch.cat([roughness, metallic, torch.zeros_like(roughness)], dim=1) 194 | return out 195 | 196 | 197 | @torch.no_grad() 198 | def compute_render(self, maps, camera_position, pos, light): 199 | ''' 200 | maps: gs, bs, h, w, c (gs: the number of grids) 201 | ''' 202 | def cos(x, y): 203 | return torch.clamp((x*y).sum(-1, keepdim=True), min=0, max=1) 204 | 205 | # pre-process 206 | albedo = srgb_to_rgb(maps['basecolor']) 207 | normal = maps['normal'].clone() 208 | normal[..., :2] = normal[..., [1,0]] 209 | N = Fn.normalize((normal - 0.5) * 2.0, dim=-1, eps=1e-6) 210 | roughness = maps['roughness'] 211 | metallic = maps['metallic'] 212 | V = Fn.normalize(camera_position - pos, dim=-1, eps=1e-6).repeat(1,1,1,1,1).to(self.device) 213 | irradiance, L = light(pos) 214 | irradiance, L = irradiance.repeat(1,1,1,1,1).to(self.device), L.repeat(1,1,1,1,1).to(self.device) 215 | # rendering 216 | H = Fn.normalize(L+V, dim=-1, eps=1e-6) 217 | f0 = torch.ones_like(albedo).to(self.device) * 0.04 218 | F0 = torch.lerp(f0, albedo, metallic) 219 | F = fresnelSchlick(cos(H,V), F0) 220 | ks = F 221 | 222 | diffuse = (1-ks) * albedo / torch.pi 223 | diffuse *= 1-metallic 224 | 225 | NDF = DistributionGGX(cos(N,H), roughness) 226 | G = GeometrySchlickGGX(cos(N,L), roughness) * GeometrySchlickGGX(cos(N,V), roughness) 227 | 228 | numerator = NDF * G * F 229 | denominator = 4.0 * cos(N,V) * cos(N,L) + 1e-3 230 | specular = numerator / denominator 231 | ambient = 0.3 * albedo 232 | 233 | rgb = (diffuse + specular) * irradiance * cos(N,L) + ambient 234 | 235 | return rgb 236 | 237 | def forward(self, maps:dict): 238 | # prepare 239 | bs = maps['render'].shape[0] 240 | self.sd.scheduler.set_timesteps(1) 241 | t = self.sd.scheduler.timesteps[0] 242 | # chain processing 243 | pred, pred_latent, arxiv_latent = {}, {}, {} 244 | for kout, info in self.chain.items(): 245 | info = info.split("_") 246 | keys, ids = info[:-1], info[-1] 247 | # Swap active LEGO blocks 248 | self.sd.unet.down_blocks[0] = self.sd.unet.FirstDownBlocks[kout] 249 | self.sd.unet.up_blocks[-1] = self.sd.unet.LastUpBlocks[kout] 250 | # Eq.2, summing input latents 251 | in_latent = 0 252 | for k, i in zip(keys, ids): 253 | if i=="0": 254 | if not k in arxiv_latent.keys(): arxiv_latent[k] = self.sd.encode_imgs_deterministic(maps[k]) 255 | zx = arxiv_latent[k] 256 | else: 257 | zx = pred_latent[k] 258 | in_latent += self.sd.unet.ConvIns[k](zx) 259 | in_latent = in_latent / len(keys) 260 | # single-step denoising 261 | embs = self.produce_embeddings(kout, bs) 262 | out_latent = self.sd.unet(in_latent, t, **embs)[0] 263 | out_latent = self.sd.unet.ConvOuts[kout](out_latent) 264 | pred_latent[kout] = self.sd.scheduler.step(out_latent, t, torch.zeros_like(zx)).pred_original_sample 265 | pred[kout] = self.sd.decode_latents(pred_latent[kout]).float() 266 | # 
compute intermediate representations 267 | if self.chain_type in ["chord"] and kout == "basecolor": 268 | pred['approxIrr'] = self.compute_approxIrr(maps['render'], pred['basecolor']) 269 | pred_latent['approxIrr'] = self.sd.encode_imgs_deterministic(pred['approxIrr']) 270 | if self.chain_type in ["chord"] and kout == "normal": 271 | pred['approxRM'] = self.compute_approxRouMet(maps['render'], pred, seperate=False) 272 | pred_latent['approxRM'] = self.sd.encode_imgs_deterministic(pred['approxRM']) 273 | 274 | return pred 275 | 276 | @torch.no_grad() 277 | def produce_embeddings(self, key, batch_size): 278 | if key not in self.text_emb.keys(): 279 | self.text_emb[key] = self.sd.encode_text(self.prompts[key], "max_length") 280 | prompt_emb = self.text_emb[key].expand(batch_size, -1, -1) 281 | return { "encoder_hidden_states": prompt_emb } --------------------------------------------------------------------------------
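For readers who want to call the model from their own scripts rather than through the `test.py` CLI, the following is a minimal sketch condensed from `test.py` and `chord/io.py`. The input image is one of the shipped examples; the output directory name `output/wood` is an illustrative choice, not a path defined by the repository.

```python
# Minimal sketch (condensed from test.py): load the published checkpoint and
# decompose a single texture image into PBR maps.
# Assumes you have accepted the model terms on Hugging Face and logged in
# with `huggingface-cli login` (see README).
import torch
from omegaconf import OmegaConf
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download

from chord import ChordModel
from chord.io import read_image, save_maps, load_torch_file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from the shipped config and the released checkpoint.
config = OmegaConf.load("config/chord.yaml")
model = ChordModel(config)
ckpt_path = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord",
                            filename="chord_v1.safetensors")
model.load_state_dict(load_torch_file(ckpt_path))
model.eval().to(device)

# read_image returns a CxHxW float tensor in [0, 1]; the model runs at 1024x1024.
image = read_image("examples/generated/wood.jpg").to(device)
x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)

with torch.no_grad(), torch.autocast(device_type=device.type):
    maps = model(x)  # dict with keys: basecolor, normal, roughness, metalness

save_maps("output/wood", maps)  # writes one PNG per map
```

Note that `test.py` additionally resizes each predicted map back to the input resolution before saving; the sketch above keeps the native 1024x1024 output for brevity.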