├── point_e ├── __init__.py ├── evals │ ├── __init__.py │ ├── scripts │ │ ├── evaluate_pis.py │ │ ├── evaluate_pfid.py │ │ └── blender_script.py │ ├── fid_is.py │ ├── pointnet2_cls_ssg.py │ ├── feature_extractor.py │ ├── npz_stream.py │ └── pointnet2_utils.py ├── models │ ├── __init__.py │ ├── util.py │ ├── checkpoint.py │ ├── download.py │ ├── configs.py │ ├── sdf.py │ ├── perceiver.py │ ├── pretrained_clip.py │ └── transformer.py ├── util │ ├── __init__.py │ ├── ply_util.py │ ├── plotting.py │ ├── mesh.py │ ├── pc_to_mesh.py │ └── point_cloud.py ├── diffusion │ ├── __init__.py │ ├── configs.py │ ├── sampler.py │ ├── k_diffusion.py │ └── gaussian_diffusion.py └── examples │ ├── paper_banner.gif │ ├── example_data │ ├── corgi.jpg │ ├── corgi.ply │ ├── pc_corgi.npz │ ├── cube_stack.jpg │ └── pc_cube_stack.npz │ ├── pointcloud2mesh.ipynb │ ├── image2pointcloud.ipynb │ └── text2pointcloud.ipynb ├── .gitignore ├── setup.py ├── Dockerfile ├── LICENSE ├── README.md └── model-card.md /point_e/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /point_e/evals/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /point_e/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /point_e/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /point_e/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | __pycache__/ 3 | point_e_model_cache/ 4 | .ipynb_checkpoints/ 5 | .DS_Store 6 | 7 | -------------------------------------------------------------------------------- /point_e/examples/paper_banner.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpaviot/point-e/main/point_e/examples/paper_banner.gif -------------------------------------------------------------------------------- /point_e/examples/example_data/corgi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpaviot/point-e/main/point_e/examples/example_data/corgi.jpg -------------------------------------------------------------------------------- /point_e/examples/example_data/corgi.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpaviot/point-e/main/point_e/examples/example_data/corgi.ply -------------------------------------------------------------------------------- /point_e/examples/example_data/pc_corgi.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpaviot/point-e/main/point_e/examples/example_data/pc_corgi.npz -------------------------------------------------------------------------------- /point_e/examples/example_data/cube_stack.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpaviot/point-e/main/point_e/examples/example_data/cube_stack.jpg -------------------------------------------------------------------------------- /point_e/examples/example_data/pc_cube_stack.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpaviot/point-e/main/point_e/examples/example_data/pc_cube_stack.npz -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="point-e", 5 | packages=[ 6 | "point_e", 7 | "point_e.diffusion", 8 | "point_e.evals", 9 | "point_e.models", 10 | "point_e.util", 11 | ], 12 | install_requires=[ 13 | "filelock", 14 | "Pillow", 15 | "torch", 16 | "fire", 17 | "humanize", 18 | "requests", 19 | "tqdm", 20 | "matplotlib", 21 | "scikit-image", 22 | "scipy", 23 | "numpy", 24 | "clip @ git+https://github.com/openai/CLIP.git", 25 | ], 26 | author="OpenAI", 27 | ) 28 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM jupyter/scipy-notebook:notebook-6.5.2 2 | MAINTAINER Thomas Paviot 3 | 4 | USER jovyan 5 | 6 | ###################################### 7 | # create dedicated conda environment # 8 | ###################################### 9 | RUN /opt/conda/bin/conda config --set always_yes yes --set changeps1 no 10 | RUN /opt/conda/bin/conda update -q conda 11 | RUN /opt/conda/bin/conda info -a 12 | RUN /opt/conda/bin/conda config --add channels https://conda.anaconda.org/conda-forge 13 | RUN /opt/conda/bin/conda create --name pointe python=3.9 14 | RUN source activate pointe 15 | RUN conda install pip 16 | 17 | ################### 18 | # Install point-e # 19 | ################### 20 | RUN git clone https://github.com/tpaviot/point-e.git 21 | RUN /opt/conda/bin/pip install -e point-e/. 22 | 23 | ##################### 24 | # back to user mode # 25 | ##################### 26 | WORKDIR /home/jovyan/point-e/point_e/examples 27 | -------------------------------------------------------------------------------- /point_e/models/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def timestep_embedding(timesteps, dim, max_period=10000): 7 | """ 8 | Create sinusoidal timestep embeddings. 9 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 10 | These may be fractional. 11 | :param dim: the dimension of the output. 12 | :param max_period: controls the minimum frequency of the embeddings. 13 | :return: an [N x dim] Tensor of positional embeddings. 
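    Example (illustrative):

        emb = timestep_embedding(torch.tensor([0.0, 10.0]), dim=128)  # -> shape [2, 128]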
14 | """ 15 | half = dim // 2 16 | freqs = torch.exp( 17 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 18 | ).to(device=timesteps.device) 19 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 20 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 21 | if dim % 2: 22 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 23 | return embedding 24 | -------------------------------------------------------------------------------- /point_e/evals/scripts/evaluate_pis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate P-IS of a batch of point clouds. 3 | 4 | The point cloud batch should be saved to an npz file, where there is an 5 | arr_0 key of shape [N x K x 3], where K is the dimensionality of each 6 | point cloud and N is the number of clouds. 7 | """ 8 | 9 | import argparse 10 | 11 | from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices 12 | from point_e.evals.fid_is import compute_inception_score 13 | from point_e.evals.npz_stream import NpzStreamer 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--cache_dir", type=str, default=None) 19 | parser.add_argument("batch", type=str) 20 | args = parser.parse_args() 21 | 22 | print("creating classifier...") 23 | clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir) 24 | 25 | print("computing batch predictions") 26 | _, preds = clf.features_and_preds(NpzStreamer(args.batch)) 27 | print(f"P-IS: {compute_inception_score(preds)}") 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /point_e/evals/scripts/evaluate_pfid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate P-FID between two batches of point clouds. 3 | 4 | The point cloud batches should be saved to two npz files, where there 5 | is an arr_0 key of shape [N x K x 3], where K is the dimensionality of 6 | each point cloud and N is the number of clouds. 
7 | """ 8 | 9 | import argparse 10 | 11 | from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices 12 | from point_e.evals.fid_is import compute_statistics 13 | from point_e.evals.npz_stream import NpzStreamer 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--cache_dir", type=str, default=None) 19 | parser.add_argument("batch_1", type=str) 20 | parser.add_argument("batch_2", type=str) 21 | args = parser.parse_args() 22 | 23 | print("creating classifier...") 24 | clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir) 25 | 26 | print("computing first batch activations") 27 | 28 | features_1, _ = clf.features_and_preds(NpzStreamer(args.batch_1)) 29 | stats_1 = compute_statistics(features_1) 30 | del features_1 31 | 32 | features_2, _ = clf.features_and_preds(NpzStreamer(args.batch_2)) 33 | stats_2 = compute_statistics(features_2) 34 | del features_2 35 | 36 | print(f"P-FID: {stats_1.frechet_distance(stats_2)}") 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point·E 2 | 3 | ![Animation of four 3D point clouds rotating](point_e/examples/paper_banner.gif) 4 | 5 | This is the official code and model release for [Point-E: A System for Generating 3D Point Clouds from Complex Prompts](https://arxiv.org/abs/2212.08751). 6 | 7 | # Usage 8 | 9 | Install with `pip install -e .`. 10 | 11 | To get started with examples, see the following notebooks: 12 | 13 | * [image2pointcloud.ipynb](point_e/examples/image2pointcloud.ipynb) - sample a point cloud, conditioned on some example synthetic view images. 14 | * [text2pointcloud.ipynb](point_e/examples/text2pointcloud.ipynb) - use our small, worse quality pure text-to-3D model to produce 3D point clouds directly from text descriptions. This model's capabilities are limited, but it does understand some simple categories and colors. 15 | * [pointcloud2mesh.ipynb](point_e/examples/pointcloud2mesh.ipynb) - try our SDF regression model for producing meshes from point clouds. 16 | 17 | For our P-FID and P-IS evaluation scripts, see: 18 | 19 | * [evaluate_pfid.py](point_e/evals/scripts/evaluate_pfid.py) 20 | * [evaluate_pis.py](point_e/evals/scripts/evaluate_pis.py) 21 | 22 | For our Blender rendering code, see [blender_script.py](point_e/evals/scripts/blender_script.py) 23 | 24 | # Samples 25 | 26 | You can download the seed images and point clouds corresponding to the paper banner images [here](https://openaipublic.azureedge.net/main/point-e/banner_pcs.zip). 27 | 28 | You can download the seed images used for COCO CLIP R-Precision evaluations [here](https://openaipublic.azureedge.net/main/point-e/coco_images.zip). 
29 | -------------------------------------------------------------------------------- /point_e/diffusion/configs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py 3 | """ 4 | 5 | from typing import Any, Dict 6 | 7 | import numpy as np 8 | 9 | from .gaussian_diffusion import ( 10 | GaussianDiffusion, 11 | SpacedDiffusion, 12 | get_named_beta_schedule, 13 | space_timesteps, 14 | ) 15 | 16 | BASE_DIFFUSION_CONFIG = { 17 | "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0], 18 | "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255], 19 | "mean_type": "epsilon", 20 | "schedule": "cosine", 21 | "timesteps": 1024, 22 | } 23 | 24 | DIFFUSION_CONFIGS = { 25 | "base40M-imagevec": BASE_DIFFUSION_CONFIG, 26 | "base40M-textvec": BASE_DIFFUSION_CONFIG, 27 | "base40M-uncond": BASE_DIFFUSION_CONFIG, 28 | "base40M": BASE_DIFFUSION_CONFIG, 29 | "base300M": BASE_DIFFUSION_CONFIG, 30 | "base1B": BASE_DIFFUSION_CONFIG, 31 | "upsample": { 32 | "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0], 33 | "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255], 34 | "mean_type": "epsilon", 35 | "schedule": "linear", 36 | "timesteps": 1024, 37 | }, 38 | } 39 | 40 | 41 | def diffusion_from_config(config: Dict[str, Any]) -> GaussianDiffusion: 42 | schedule = config["schedule"] 43 | steps = config["timesteps"] 44 | respace = config.get("respacing", None) 45 | mean_type = config.get("mean_type", "epsilon") 46 | betas = get_named_beta_schedule(schedule, steps) 47 | channel_scales = config.get("channel_scales", None) 48 | channel_biases = config.get("channel_biases", None) 49 | if channel_scales is not None: 50 | channel_scales = np.array(channel_scales) 51 | if channel_biases is not None: 52 | channel_biases = np.array(channel_biases) 53 | kwargs = dict( 54 | betas=betas, 55 | model_mean_type=mean_type, 56 | model_var_type="learned_range", 57 | loss_type="mse", 58 | channel_scales=channel_scales, 59 | channel_biases=channel_biases, 60 | ) 61 | if respace is None: 62 | return GaussianDiffusion(**kwargs) 63 | else: 64 | return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs) 65 | -------------------------------------------------------------------------------- /point_e/models/checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 3 | """ 4 | 5 | from typing import Callable, Iterable, Sequence, Union 6 | 7 | import torch 8 | 9 | 10 | def checkpoint( 11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 12 | inputs: Sequence[torch.Tensor], 13 | params: Iterable[torch.Tensor], 14 | flag: bool, 15 | ): 16 | """ 17 | Evaluate a function without caching intermediate activations, allowing for 18 | reduced memory at the expense of extra compute in the backward pass. 19 | :param func: the function to evaluate. 20 | :param inputs: the argument sequence to pass to `func`. 21 | :param params: a sequence of parameters `func` depends on but does not 22 | explicitly take as arguments. 23 | :param flag: if False, disable gradient checkpointing. 
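    Example (illustrative; mirrors the call in point_e/models/perceiver.py):

        out = checkpoint(self.attention, (x, data), (), True)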
24 | """ 25 | if flag: 26 | args = tuple(inputs) + tuple(params) 27 | return CheckpointFunction.apply(func, len(inputs), *args) 28 | else: 29 | return func(*inputs) 30 | 31 | 32 | class CheckpointFunction(torch.autograd.Function): 33 | @staticmethod 34 | def forward(ctx, run_function, length, *args): 35 | ctx.run_function = run_function 36 | ctx.input_tensors = list(args[:length]) 37 | ctx.input_params = list(args[length:]) 38 | with torch.no_grad(): 39 | output_tensors = ctx.run_function(*ctx.input_tensors) 40 | return output_tensors 41 | 42 | @staticmethod 43 | def backward(ctx, *output_grads): 44 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 45 | with torch.enable_grad(): 46 | # Fixes a bug where the first op in run_function modifies the 47 | # Tensor storage in place, which is not allowed for detach()'d 48 | # Tensors. 49 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 50 | output_tensors = ctx.run_function(*shallow_copies) 51 | input_grads = torch.autograd.grad( 52 | output_tensors, 53 | ctx.input_tensors + ctx.input_params, 54 | output_grads, 55 | allow_unused=True, 56 | ) 57 | del ctx.input_tensors 58 | del ctx.input_params 59 | del output_tensors 60 | return (None, None) + input_grads 61 | -------------------------------------------------------------------------------- /point_e/util/ply_util.py: -------------------------------------------------------------------------------- 1 | import io 2 | import struct 3 | from contextlib import contextmanager 4 | from typing import BinaryIO, Iterator, Optional 5 | 6 | import numpy as np 7 | 8 | 9 | def write_ply( 10 | raw_f: BinaryIO, 11 | coords: np.ndarray, 12 | rgb: Optional[np.ndarray] = None, 13 | faces: Optional[np.ndarray] = None, 14 | ): 15 | """ 16 | Write a PLY file for a mesh or a point cloud. 17 | 18 | :param coords: an [N x 3] array of floating point coordinates. 19 | :param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0]. 20 | :param faces: an [N x 3] array of triangles encoded as integer indices. 
21 | """ 22 | with buffered_writer(raw_f) as f: 23 | f.write(b"ply\n") 24 | f.write(b"format binary_little_endian 1.0\n") 25 | f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) 26 | f.write(b"property float x\n") 27 | f.write(b"property float y\n") 28 | f.write(b"property float z\n") 29 | if rgb is not None: 30 | f.write(b"property uchar red\n") 31 | f.write(b"property uchar green\n") 32 | f.write(b"property uchar blue\n") 33 | if faces is not None: 34 | f.write(bytes(f"element face {len(faces)}\n", "ascii")) 35 | f.write(b"property list uchar int vertex_index\n") 36 | f.write(b"end_header\n") 37 | 38 | if rgb is not None: 39 | rgb = (rgb * 255.499).round().astype(int) 40 | vertices = [ 41 | (*coord, *rgb) 42 | for coord, rgb in zip( 43 | coords.tolist(), 44 | rgb.tolist(), 45 | ) 46 | ] 47 | format = struct.Struct("<3f3B") 48 | for item in vertices: 49 | f.write(format.pack(*item)) 50 | else: 51 | format = struct.Struct("<3f") 52 | for vertex in coords.tolist(): 53 | f.write(format.pack(*vertex)) 54 | 55 | if faces is not None: 56 | format = struct.Struct(" Iterator[io.BufferedIOBase]: 63 | if isinstance(raw_f, io.BufferedIOBase): 64 | yield raw_f 65 | else: 66 | f = io.BufferedWriter(raw_f) 67 | yield f 68 | f.flush() 69 | -------------------------------------------------------------------------------- /point_e/util/plotting.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from .point_cloud import PointCloud 7 | 8 | 9 | def plot_point_cloud( 10 | pc: PointCloud, 11 | color: bool = True, 12 | grid_size: int = 1, 13 | fixed_bounds: Optional[Tuple[Tuple[float, float, float], Tuple[float, float, float]]] = ( 14 | (-0.75, -0.75, -0.75), 15 | (0.75, 0.75, 0.75), 16 | ), 17 | ): 18 | """ 19 | Render a point cloud as a plot to the given image path. 20 | 21 | :param pc: the PointCloud to plot. 22 | :param image_path: the path to save the image, with a file extension. 23 | :param color: if True, show the RGB colors from the point cloud. 24 | :param grid_size: the number of random rotations to render. 
25 | """ 26 | fig = plt.figure(figsize=(8, 8)) 27 | 28 | for i in range(grid_size): 29 | for j in range(grid_size): 30 | ax = fig.add_subplot(grid_size, grid_size, 1 + j + i * grid_size, projection="3d") 31 | color_args = {} 32 | if color: 33 | color_args["c"] = np.stack( 34 | [pc.channels["R"], pc.channels["G"], pc.channels["B"]], axis=-1 35 | ) 36 | c = pc.coords 37 | 38 | if grid_size > 1: 39 | theta = np.pi * 2 * (i * grid_size + j) / (grid_size**2) 40 | rotation = np.array( 41 | [ 42 | [np.cos(theta), -np.sin(theta), 0.0], 43 | [np.sin(theta), np.cos(theta), 0.0], 44 | [0.0, 0.0, 1.0], 45 | ] 46 | ) 47 | c = c @ rotation 48 | 49 | ax.scatter(c[:, 0], c[:, 1], c[:, 2], **color_args) 50 | 51 | if fixed_bounds is None: 52 | min_point = c.min(0) 53 | max_point = c.max(0) 54 | size = (max_point - min_point).max() / 2 55 | center = (min_point + max_point) / 2 56 | ax.set_xlim3d(center[0] - size, center[0] + size) 57 | ax.set_ylim3d(center[1] - size, center[1] + size) 58 | ax.set_zlim3d(center[2] - size, center[2] + size) 59 | else: 60 | ax.set_xlim3d(fixed_bounds[0][0], fixed_bounds[1][0]) 61 | ax.set_ylim3d(fixed_bounds[0][1], fixed_bounds[1][1]) 62 | ax.set_zlim3d(fixed_bounds[0][2], fixed_bounds[1][2]) 63 | 64 | return fig 65 | -------------------------------------------------------------------------------- /point_e/evals/fid_is.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py 3 | """ 4 | 5 | 6 | import warnings 7 | 8 | import numpy as np 9 | from scipy import linalg 10 | 11 | 12 | class InvalidFIDException(Exception): 13 | pass 14 | 15 | 16 | class FIDStatistics: 17 | def __init__(self, mu: np.ndarray, sigma: np.ndarray): 18 | self.mu = mu 19 | self.sigma = sigma 20 | 21 | def frechet_distance(self, other, eps=1e-6): 22 | """ 23 | Compute the Frechet distance between two sets of statistics. 
24 | """ 25 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 26 | mu1, sigma1 = self.mu, self.sigma 27 | mu2, sigma2 = other.mu, other.sigma 28 | 29 | mu1 = np.atleast_1d(mu1) 30 | mu2 = np.atleast_1d(mu2) 31 | 32 | sigma1 = np.atleast_2d(sigma1) 33 | sigma2 = np.atleast_2d(sigma2) 34 | 35 | assert ( 36 | mu1.shape == mu2.shape 37 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" 38 | assert ( 39 | sigma1.shape == sigma2.shape 40 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" 41 | 42 | diff = mu1 - mu2 43 | 44 | # product might be almost singular 45 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 46 | if not np.isfinite(covmean).all(): 47 | msg = ( 48 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 49 | % eps 50 | ) 51 | warnings.warn(msg) 52 | offset = np.eye(sigma1.shape[0]) * eps 53 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 54 | 55 | # numerical error might give slight imaginary component 56 | if np.iscomplexobj(covmean): 57 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 58 | m = np.max(np.abs(covmean.imag)) 59 | raise ValueError("Imaginary component {}".format(m)) 60 | covmean = covmean.real 61 | 62 | tr_covmean = np.trace(covmean) 63 | 64 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 65 | 66 | 67 | def compute_statistics(feats: np.ndarray) -> FIDStatistics: 68 | mu = np.mean(feats, axis=0) 69 | sigma = np.cov(feats, rowvar=False) 70 | return FIDStatistics(mu, sigma) 71 | 72 | 73 | def compute_inception_score(preds: np.ndarray, split_size: int = 5000) -> float: 74 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 75 | scores = [] 76 | for i in range(0, len(preds), split_size): 77 | part = preds[i : i + split_size] 78 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 79 | kl = np.mean(np.sum(kl, 1)) 80 | scores.append(np.exp(kl)) 81 | return float(np.mean(scores)) 82 | -------------------------------------------------------------------------------- /point_e/util/mesh.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import BinaryIO, Dict, Optional, Union 3 | 4 | import numpy as np 5 | 6 | from .ply_util import write_ply 7 | 8 | 9 | @dataclass 10 | class TriMesh: 11 | """ 12 | A 3D triangle mesh with optional data at the vertices and faces. 13 | """ 14 | 15 | # [N x 3] array of vertex coordinates. 16 | verts: np.ndarray 17 | 18 | # [M x 3] array of triangles, pointing to indices in verts. 19 | faces: np.ndarray 20 | 21 | # [P x 3] array of normal vectors per face. 22 | normals: Optional[np.ndarray] = None 23 | 24 | # Extra data per vertex and face. 25 | vertex_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict) 26 | face_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict) 27 | 28 | @classmethod 29 | def load(cls, f: Union[str, BinaryIO]) -> "TriMesh": 30 | """ 31 | Load the mesh from a .npz file. 
32 | """ 33 | if isinstance(f, str): 34 | with open(f, "rb") as reader: 35 | return cls.load(reader) 36 | else: 37 | obj = np.load(f) 38 | keys = list(obj.keys()) 39 | verts = obj["verts"] 40 | faces = obj["faces"] 41 | normals = obj["normals"] if "normals" in keys else None 42 | vertex_channels = {} 43 | face_channels = {} 44 | for key in keys: 45 | if key.startswith("v_"): 46 | vertex_channels[key[2:]] = obj[key] 47 | elif key.startswith("f_"): 48 | face_channels[key[2:]] = obj[key] 49 | return cls( 50 | verts=verts, 51 | faces=faces, 52 | normals=normals, 53 | vertex_channels=vertex_channels, 54 | face_channels=face_channels, 55 | ) 56 | 57 | def save(self, f: Union[str, BinaryIO]): 58 | """ 59 | Save the mesh to a .npz file. 60 | """ 61 | if isinstance(f, str): 62 | with open(f, "wb") as writer: 63 | self.save(writer) 64 | else: 65 | obj_dict = dict(verts=self.verts, faces=self.faces) 66 | if self.normals is not None: 67 | obj_dict["normals"] = self.normals 68 | for k, v in self.vertex_channels.items(): 69 | obj_dict[f"v_{k}"] = v 70 | for k, v in self.face_channels.items(): 71 | obj_dict[f"f_{k}"] = v 72 | np.savez(f, **obj_dict) 73 | 74 | def has_vertex_colors(self) -> bool: 75 | return self.vertex_channels is not None and all(x in self.vertex_channels for x in "RGB") 76 | 77 | def write_ply(self, raw_f: BinaryIO): 78 | write_ply( 79 | raw_f, 80 | coords=self.verts, 81 | rgb=( 82 | np.stack([self.vertex_channels[x] for x in "RGB"], axis=1) 83 | if self.has_vertex_colors() 84 | else None 85 | ), 86 | faces=self.faces, 87 | ) 88 | -------------------------------------------------------------------------------- /point_e/models/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py 3 | """ 4 | 5 | import os 6 | from functools import lru_cache 7 | from typing import Dict, Optional 8 | 9 | import requests 10 | import torch 11 | from filelock import FileLock 12 | from tqdm.auto import tqdm 13 | 14 | MODEL_PATHS = { 15 | "base40M-imagevec": "https://openaipublic.azureedge.net/main/point-e/base_40m_imagevec.pt", 16 | "base40M-textvec": "https://openaipublic.azureedge.net/main/point-e/base_40m_textvec.pt", 17 | "base40M-uncond": "https://openaipublic.azureedge.net/main/point-e/base_40m_uncond.pt", 18 | "base40M": "https://openaipublic.azureedge.net/main/point-e/base_40m.pt", 19 | "base300M": "https://openaipublic.azureedge.net/main/point-e/base_300m.pt", 20 | "base1B": "https://openaipublic.azureedge.net/main/point-e/base_1b.pt", 21 | "upsample": "https://openaipublic.azureedge.net/main/point-e/upsample_40m.pt", 22 | "sdf": "https://openaipublic.azureedge.net/main/point-e/sdf.pt", 23 | "pointnet": "https://openaipublic.azureedge.net/main/point-e/pointnet.pt", 24 | } 25 | 26 | 27 | @lru_cache() 28 | def default_cache_dir() -> str: 29 | return os.path.join(os.path.abspath(os.getcwd()), "point_e_model_cache") 30 | 31 | 32 | def fetch_file_cached( 33 | url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096 34 | ) -> str: 35 | """ 36 | Download the file at the given URL into a local file and return the path. 37 | If cache_dir is specified, it will be used to download the files. 38 | Otherwise, default_cache_dir() is used. 
39 | """ 40 | if cache_dir is None: 41 | cache_dir = default_cache_dir() 42 | os.makedirs(cache_dir, exist_ok=True) 43 | local_path = os.path.join(cache_dir, url.split("/")[-1]) 44 | if os.path.exists(local_path): 45 | return local_path 46 | 47 | response = requests.get(url, stream=True) 48 | size = int(response.headers.get("content-length", "0")) 49 | with FileLock(local_path + ".lock"): 50 | if progress: 51 | pbar = tqdm(total=size, unit="iB", unit_scale=True) 52 | tmp_path = local_path + ".tmp" 53 | with open(tmp_path, "wb") as f: 54 | for chunk in response.iter_content(chunk_size): 55 | if progress: 56 | pbar.update(len(chunk)) 57 | f.write(chunk) 58 | os.rename(tmp_path, local_path) 59 | if progress: 60 | pbar.close() 61 | return local_path 62 | 63 | 64 | def load_checkpoint( 65 | checkpoint_name: str, 66 | device: torch.device, 67 | progress: bool = True, 68 | cache_dir: Optional[str] = None, 69 | chunk_size: int = 4096, 70 | ) -> Dict[str, torch.Tensor]: 71 | if checkpoint_name not in MODEL_PATHS: 72 | raise ValueError( 73 | f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}." 74 | ) 75 | path = fetch_file_cached( 76 | MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size 77 | ) 78 | return torch.load(path, map_location=device) 79 | -------------------------------------------------------------------------------- /point_e/examples/pointcloud2mesh.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from PIL import Image\n", 10 | "import torch\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "from tqdm.auto import tqdm\n", 13 | "\n", 14 | "from point_e.models.download import load_checkpoint\n", 15 | "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n", 16 | "from point_e.util.pc_to_mesh import marching_cubes_mesh\n", 17 | "from point_e.util.plotting import plot_point_cloud\n", 18 | "from point_e.util.point_cloud import PointCloud" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 28 | "\n", 29 | "print('creating SDF model...')\n", 30 | "name = 'sdf'\n", 31 | "model = model_from_config(MODEL_CONFIGS[name], device)\n", 32 | "model.eval()\n", 33 | "\n", 34 | "print('loading SDF model...')\n", 35 | "model.load_state_dict(load_checkpoint(name, device))" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# Load a point cloud we want to convert into a mesh.\n", 45 | "pc = PointCloud.load('example_data/pc_corgi.npz')\n", 46 | "\n", 47 | "# Plot the point cloud as a sanity check.\n", 48 | "fig = plot_point_cloud(pc, grid_size=2)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# Produce a mesh (with vertex colors)\n", 58 | "mesh = marching_cubes_mesh(\n", 59 | " pc=pc,\n", 60 | " model=model,\n", 61 | " batch_size=4096,\n", 62 | " grid_size=32, # increase to 128 for resolution used in evals\n", 63 | " progress=True,\n", 64 | ")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Write the mesh to 
a PLY file to import into some other program.\n", 74 | "with open('mesh.ply', 'wb') as f:\n", 75 | " mesh.write_ply(f)" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3.9.9 64-bit ('3.9.9')", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.9.9" 96 | }, 97 | "orig_nbformat": 4, 98 | "vscode": { 99 | "interpreter": { 100 | "hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e" 101 | } 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 2 106 | } 107 | -------------------------------------------------------------------------------- /point_e/util/pc_to_mesh.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import skimage 5 | import torch 6 | from tqdm.auto import tqdm 7 | 8 | from point_e.models.sdf import PointCloudSDFModel 9 | 10 | from .mesh import TriMesh 11 | from .point_cloud import PointCloud 12 | 13 | 14 | def marching_cubes_mesh( 15 | pc: PointCloud, 16 | model: PointCloudSDFModel, 17 | batch_size: int = 4096, 18 | grid_size: int = 128, 19 | side_length: float = 1.02, 20 | fill_vertex_channels: bool = True, 21 | progress: bool = False, 22 | ) -> TriMesh: 23 | """ 24 | Run marching cubes on the SDF predicted from a point cloud to produce a 25 | mesh representing the 3D surface. 26 | 27 | :param pc: the point cloud to apply marching cubes to. 28 | :param model: the model to use to predict SDF values. 29 | :param grid_size: the number of samples along each axis. A total of 30 | grid_size**3 function evaluations are performed. 31 | :param side_length: the size of the cube containing the model, which is 32 | assumed to be centered at the origin. 33 | :param fill_vertex_channels: if True, use the nearest neighbor of each mesh 34 | vertex in the point cloud to compute vertex 35 | data (e.g. colors). 36 | """ 37 | voxel_size = side_length / (grid_size - 1) 38 | min_coord = -side_length / 2 39 | 40 | def int_coord_to_float(int_coords: torch.Tensor) -> torch.Tensor: 41 | return int_coords.float() * voxel_size + min_coord 42 | 43 | with torch.no_grad(): 44 | cond = model.encode_point_clouds( 45 | torch.from_numpy(pc.coords).permute(1, 0).to(model.device)[None] 46 | ) 47 | 48 | indices = range(0, grid_size**3, batch_size) 49 | if progress: 50 | indices = tqdm(indices) 51 | 52 | volume = [] 53 | for i in indices: 54 | indices = torch.arange( 55 | i, min(i + batch_size, grid_size**3), step=1, dtype=torch.int64, device=model.device 56 | ) 57 | zs = int_coord_to_float(indices % grid_size) 58 | ys = int_coord_to_float(torch.div(indices, grid_size, rounding_mode="trunc") % grid_size) 59 | xs = int_coord_to_float(torch.div(indices, grid_size**2, rounding_mode="trunc")) 60 | coords = torch.stack([xs, ys, zs], dim=0) 61 | with torch.no_grad(): 62 | volume.append(model(coords[None], encoded=cond)[0]) 63 | volume_np = torch.cat(volume).view(grid_size, grid_size, grid_size).cpu().numpy() 64 | 65 | if np.all(volume_np < 0) or np.all(volume_np > 0): 66 | # The volume is invalid for some reason, which will break 67 | # marching cubes unless we center it. 
68 | volume_np -= np.mean(volume_np) 69 | 70 | verts, faces, normals, _ = skimage.measure.marching_cubes( 71 | volume=volume_np, 72 | level=0, 73 | allow_degenerate=False, 74 | spacing=(voxel_size,) * 3, 75 | ) 76 | 77 | # The triangles follow the left-hand rule, but we want to 78 | # follow the right-hand rule. 79 | # This syntax might seem roundabout, but we get incorrect 80 | # results if we do: x[:,0], x[:,1] = x[:,1], x[:,0] 81 | old_f1 = faces[:, 0].copy() 82 | faces[:, 0] = faces[:, 1] 83 | faces[:, 1] = old_f1 84 | 85 | verts += min_coord 86 | return TriMesh( 87 | verts=verts, 88 | faces=faces, 89 | normals=normals, 90 | vertex_channels=None if not fill_vertex_channels else _nearest_vertex_channels(pc, verts), 91 | ) 92 | 93 | 94 | def _nearest_vertex_channels(pc: PointCloud, verts: np.ndarray) -> Dict[str, np.ndarray]: 95 | nearest = pc.nearest_points(verts) 96 | return {ch: arr[nearest] for ch, arr in pc.channels.items()} 97 | -------------------------------------------------------------------------------- /point_e/examples/image2pointcloud.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from PIL import Image\n", 10 | "import torch\n", 11 | "from tqdm.auto import tqdm\n", 12 | "\n", 13 | "from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n", 14 | "from point_e.diffusion.sampler import PointCloudSampler\n", 15 | "from point_e.models.download import load_checkpoint\n", 16 | "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n", 17 | "from point_e.util.plotting import plot_point_cloud" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 27 | "\n", 28 | "print('creating base model...')\n", 29 | "base_name = 'base40M' # use base300M or base1B for better results\n", 30 | "base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n", 31 | "base_model.eval()\n", 32 | "base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n", 33 | "\n", 34 | "print('creating upsample model...')\n", 35 | "upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n", 36 | "upsampler_model.eval()\n", 37 | "upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n", 38 | "\n", 39 | "print('downloading base checkpoint...')\n", 40 | "base_model.load_state_dict(load_checkpoint(base_name, device))\n", 41 | "\n", 42 | "print('downloading upsampler checkpoint...')\n", 43 | "upsampler_model.load_state_dict(load_checkpoint('upsample', device))" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "sampler = PointCloudSampler(\n", 53 | " device=device,\n", 54 | " models=[base_model, upsampler_model],\n", 55 | " diffusions=[base_diffusion, upsampler_diffusion],\n", 56 | " num_points=[1024, 4096 - 1024],\n", 57 | " aux_channels=['R', 'G', 'B'],\n", 58 | " guidance_scale=[3.0, 3.0],\n", 59 | ")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Load an image to condition on.\n", 69 | "img = Image.open('example_data/cube_stack.jpg')\n", 70 | "\n", 71 | "# Produce a sample from the model.\n", 72 | "samples 
= None\n", 73 | "for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):\n", 74 | " samples = x" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "pc = sampler.output_to_point_clouds(samples)[0]\n", 84 | "fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 3.9.9 64-bit ('3.9.9')", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.9.9" 105 | }, 106 | "orig_nbformat": 4, 107 | "vscode": { 108 | "interpreter": { 109 | "hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e" 110 | } 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 2 115 | } 116 | -------------------------------------------------------------------------------- /point_e/examples/text2pointcloud.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from tqdm.auto import tqdm\n", 11 | "\n", 12 | "from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n", 13 | "from point_e.diffusion.sampler import PointCloudSampler\n", 14 | "from point_e.models.download import load_checkpoint\n", 15 | "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n", 16 | "from point_e.util.plotting import plot_point_cloud" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 26 | "\n", 27 | "print('creating base model...')\n", 28 | "base_name = 'base40M-textvec'\n", 29 | "base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n", 30 | "base_model.eval()\n", 31 | "base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n", 32 | "\n", 33 | "print('creating upsample model...')\n", 34 | "upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n", 35 | "upsampler_model.eval()\n", 36 | "upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n", 37 | "\n", 38 | "print('downloading base checkpoint...')\n", 39 | "base_model.load_state_dict(load_checkpoint(base_name, device))\n", 40 | "\n", 41 | "print('downloading upsampler checkpoint...')\n", 42 | "upsampler_model.load_state_dict(load_checkpoint('upsample', device))" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "sampler = PointCloudSampler(\n", 52 | " device=device,\n", 53 | " models=[base_model, upsampler_model],\n", 54 | " diffusions=[base_diffusion, upsampler_diffusion],\n", 55 | " num_points=[1024, 4096 - 1024],\n", 56 | " aux_channels=['R', 'G', 'B'],\n", 57 | " guidance_scale=[3.0, 0.0],\n", 58 | " model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all\n", 59 | ")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": 
{}, 66 | "outputs": [], 67 | "source": [ 68 | "# Set a prompt to condition on.\n", 69 | "prompt = 'a red motorcycle'\n", 70 | "\n", 71 | "# Produce a sample from the model.\n", 72 | "samples = None\n", 73 | "for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):\n", 74 | " samples = x" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "pc = sampler.output_to_point_clouds(samples)[0]\n", 84 | "fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 3.9.9 64-bit ('3.9.9')", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.9.9 (main, Aug 15 2022, 16:40:41) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]" 105 | }, 106 | "orig_nbformat": 4, 107 | "vscode": { 108 | "interpreter": { 109 | "hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e" 110 | } 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 2 115 | } 116 | -------------------------------------------------------------------------------- /point_e/evals/pointnet2_cls_ssg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet2_cls_ssg.py 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 benny 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 
25 | """ 26 | 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | from .pointnet2_utils import PointNetSetAbstraction 31 | 32 | 33 | class get_model(nn.Module): 34 | def __init__(self, num_class, normal_channel=True, width_mult=1): 35 | super(get_model, self).__init__() 36 | self.width_mult = width_mult 37 | in_channel = 6 if normal_channel else 3 38 | self.normal_channel = normal_channel 39 | self.sa1 = PointNetSetAbstraction( 40 | npoint=512, 41 | radius=0.2, 42 | nsample=32, 43 | in_channel=in_channel, 44 | mlp=[64 * width_mult, 64 * width_mult, 128 * width_mult], 45 | group_all=False, 46 | ) 47 | self.sa2 = PointNetSetAbstraction( 48 | npoint=128, 49 | radius=0.4, 50 | nsample=64, 51 | in_channel=128 * width_mult + 3, 52 | mlp=[128 * width_mult, 128 * width_mult, 256 * width_mult], 53 | group_all=False, 54 | ) 55 | self.sa3 = PointNetSetAbstraction( 56 | npoint=None, 57 | radius=None, 58 | nsample=None, 59 | in_channel=256 * width_mult + 3, 60 | mlp=[256 * width_mult, 512 * width_mult, 1024 * width_mult], 61 | group_all=True, 62 | ) 63 | self.fc1 = nn.Linear(1024 * width_mult, 512 * width_mult) 64 | self.bn1 = nn.BatchNorm1d(512 * width_mult) 65 | self.drop1 = nn.Dropout(0.4) 66 | self.fc2 = nn.Linear(512 * width_mult, 256 * width_mult) 67 | self.bn2 = nn.BatchNorm1d(256 * width_mult) 68 | self.drop2 = nn.Dropout(0.4) 69 | self.fc3 = nn.Linear(256 * width_mult, num_class) 70 | 71 | def forward(self, xyz, features=False): 72 | B, _, _ = xyz.shape 73 | if self.normal_channel: 74 | norm = xyz[:, 3:, :] 75 | xyz = xyz[:, :3, :] 76 | else: 77 | norm = None 78 | l1_xyz, l1_points = self.sa1(xyz, norm) 79 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 80 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 81 | x = l3_points.view(B, 1024 * self.width_mult) 82 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 83 | result_features = self.bn2(self.fc2(x)) 84 | x = self.drop2(F.relu(result_features)) 85 | x = self.fc3(x) 86 | x = F.log_softmax(x, -1) 87 | 88 | if features: 89 | return x, l3_points, result_features 90 | else: 91 | return x, l3_points 92 | 93 | 94 | class get_loss(nn.Module): 95 | def __init__(self): 96 | super(get_loss, self).__init__() 97 | 98 | def forward(self, pred, target, trans_feat): 99 | total_loss = F.nll_loss(pred, target) 100 | 101 | return total_loss 102 | -------------------------------------------------------------------------------- /point_e/evals/feature_extractor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from multiprocessing.pool import ThreadPool 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from point_e.models.download import load_checkpoint 9 | 10 | from .npz_stream import NpzStreamer 11 | from .pointnet2_cls_ssg import get_model 12 | 13 | 14 | def get_torch_devices() -> List[Union[str, torch.device]]: 15 | if torch.cuda.is_available(): 16 | return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())] 17 | else: 18 | return ["cpu"] 19 | 20 | 21 | class FeatureExtractor(ABC): 22 | @property 23 | @abstractmethod 24 | def supports_predictions(self) -> bool: 25 | pass 26 | 27 | @property 28 | @abstractmethod 29 | def feature_dim(self) -> int: 30 | pass 31 | 32 | @property 33 | @abstractmethod 34 | def num_classes(self) -> int: 35 | pass 36 | 37 | @abstractmethod 38 | def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]: 39 | """ 40 | For a stream of point 
cloud batches, compute feature vectors and class 41 | predictions. 42 | 43 | :param point_clouds: a streamer for a sample batch. Typically, arr_0 44 | will contain the XYZ coordinates. 45 | :return: a tuple (features, predictions) 46 | - features: a [B x feature_dim] array of feature vectors. 47 | - predictions: a [B x num_classes] array of probabilities. 48 | """ 49 | 50 | 51 | class PointNetClassifier(FeatureExtractor): 52 | def __init__( 53 | self, 54 | devices: List[Union[str, torch.device]], 55 | device_batch_size: int = 64, 56 | cache_dir: Optional[str] = None, 57 | ): 58 | state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[ 59 | "model_state_dict" 60 | ] 61 | 62 | self.device_batch_size = device_batch_size 63 | self.devices = devices 64 | self.models = [] 65 | for device in devices: 66 | model = get_model(num_class=40, normal_channel=False, width_mult=2) 67 | model.load_state_dict(state_dict) 68 | model.to(device) 69 | model.eval() 70 | self.models.append(model) 71 | 72 | @property 73 | def supports_predictions(self) -> bool: 74 | return True 75 | 76 | @property 77 | def feature_dim(self) -> int: 78 | return 256 79 | 80 | @property 81 | def num_classes(self) -> int: 82 | return 40 83 | 84 | def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]: 85 | batch_size = self.device_batch_size * len(self.devices) 86 | point_clouds = (x["arr_0"] for x in streamer.stream(batch_size, ["arr_0"])) 87 | 88 | output_features = [] 89 | output_predictions = [] 90 | 91 | with ThreadPool(len(self.devices)) as pool: 92 | for batch in point_clouds: 93 | batch = normalize_point_clouds(batch) 94 | batches = [] 95 | for i, device in zip(range(0, len(batch), self.device_batch_size), self.devices): 96 | batches.append( 97 | torch.from_numpy(batch[i : i + self.device_batch_size]) 98 | .permute(0, 2, 1) 99 | .to(dtype=torch.float32, device=device) 100 | ) 101 | 102 | def compute_features(i_batch): 103 | i, batch = i_batch 104 | with torch.no_grad(): 105 | return self.models[i](batch, features=True) 106 | 107 | for logits, _, features in pool.imap(compute_features, enumerate(batches)): 108 | output_features.append(features.cpu().numpy()) 109 | output_predictions.append(logits.exp().cpu().numpy()) 110 | 111 | return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0) 112 | 113 | 114 | def normalize_point_clouds(pc: np.ndarray) -> np.ndarray: 115 | centroids = np.mean(pc, axis=1, keepdims=True) 116 | pc = pc - centroids 117 | m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True) 118 | pc = pc / m 119 | return pc 120 | -------------------------------------------------------------------------------- /point_e/models/configs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .sdf import CrossAttentionPointCloudSDFModel 7 | from .transformer import ( 8 | CLIPImageGridPointDiffusionTransformer, 9 | CLIPImageGridUpsamplePointDiffusionTransformer, 10 | CLIPImagePointDiffusionTransformer, 11 | PointDiffusionTransformer, 12 | UpsamplePointDiffusionTransformer, 13 | ) 14 | 15 | MODEL_CONFIGS = { 16 | "base40M-imagevec": { 17 | "cond_drop_prob": 0.1, 18 | "heads": 8, 19 | "init_scale": 0.25, 20 | "input_channels": 6, 21 | "layers": 12, 22 | "n_ctx": 1024, 23 | "name": "CLIPImagePointDiffusionTransformer", 24 | "output_channels": 12, 25 | "time_token_cond": True, 26 | 
"token_cond": True, 27 | "width": 512, 28 | }, 29 | "base40M-textvec": { 30 | "cond_drop_prob": 0.1, 31 | "heads": 8, 32 | "init_scale": 0.25, 33 | "input_channels": 6, 34 | "layers": 12, 35 | "n_ctx": 1024, 36 | "name": "CLIPImagePointDiffusionTransformer", 37 | "output_channels": 12, 38 | "time_token_cond": True, 39 | "token_cond": True, 40 | "width": 512, 41 | }, 42 | "base40M-uncond": { 43 | "heads": 8, 44 | "init_scale": 0.25, 45 | "input_channels": 6, 46 | "layers": 12, 47 | "n_ctx": 1024, 48 | "name": "PointDiffusionTransformer", 49 | "output_channels": 12, 50 | "time_token_cond": True, 51 | "width": 512, 52 | }, 53 | "base40M": { 54 | "cond_drop_prob": 0.1, 55 | "heads": 8, 56 | "init_scale": 0.25, 57 | "input_channels": 6, 58 | "layers": 12, 59 | "n_ctx": 1024, 60 | "name": "CLIPImageGridPointDiffusionTransformer", 61 | "output_channels": 12, 62 | "time_token_cond": True, 63 | "width": 512, 64 | }, 65 | "base300M": { 66 | "cond_drop_prob": 0.1, 67 | "heads": 16, 68 | "init_scale": 0.25, 69 | "input_channels": 6, 70 | "layers": 24, 71 | "n_ctx": 1024, 72 | "name": "CLIPImageGridPointDiffusionTransformer", 73 | "output_channels": 12, 74 | "time_token_cond": True, 75 | "width": 1024, 76 | }, 77 | "base1B": { 78 | "cond_drop_prob": 0.1, 79 | "heads": 32, 80 | "init_scale": 0.25, 81 | "input_channels": 6, 82 | "layers": 24, 83 | "n_ctx": 1024, 84 | "name": "CLIPImageGridPointDiffusionTransformer", 85 | "output_channels": 12, 86 | "time_token_cond": True, 87 | "width": 2048, 88 | }, 89 | "upsample": { 90 | "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0], 91 | "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255], 92 | "cond_ctx": 1024, 93 | "cond_drop_prob": 0.1, 94 | "heads": 8, 95 | "init_scale": 0.25, 96 | "input_channels": 6, 97 | "layers": 12, 98 | "n_ctx": 3072, 99 | "name": "CLIPImageGridUpsamplePointDiffusionTransformer", 100 | "output_channels": 12, 101 | "time_token_cond": True, 102 | "width": 512, 103 | }, 104 | "sdf": { 105 | "decoder_heads": 4, 106 | "decoder_layers": 4, 107 | "encoder_heads": 4, 108 | "encoder_layers": 8, 109 | "init_scale": 0.25, 110 | "n_ctx": 4096, 111 | "name": "CrossAttentionPointCloudSDFModel", 112 | "width": 256, 113 | }, 114 | } 115 | 116 | 117 | def model_from_config(config: Dict[str, Any], device: torch.device) -> nn.Module: 118 | config = config.copy() 119 | name = config.pop("name") 120 | if name == "PointDiffusionTransformer": 121 | return PointDiffusionTransformer(device=device, dtype=torch.float32, **config) 122 | elif name == "CLIPImagePointDiffusionTransformer": 123 | return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config) 124 | elif name == "CLIPImageGridPointDiffusionTransformer": 125 | return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config) 126 | elif name == "UpsamplePointDiffusionTransformer": 127 | return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config) 128 | elif name == "CLIPImageGridUpsamplePointDiffusionTransformer": 129 | return CLIPImageGridUpsamplePointDiffusionTransformer( 130 | device=device, dtype=torch.float32, **config 131 | ) 132 | elif name == "CrossAttentionPointCloudSDFModel": 133 | return CrossAttentionPointCloudSDFModel(device=device, dtype=torch.float32, **config) 134 | raise ValueError(f"unknown model name: {name}") 135 | -------------------------------------------------------------------------------- /point_e/models/sdf.py: 
-------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .perceiver import SimplePerceiver 8 | from .transformer import Transformer 9 | 10 | 11 | class PointCloudSDFModel(nn.Module): 12 | @property 13 | @abstractmethod 14 | def device(self) -> torch.device: 15 | """ 16 | Get the device that should be used for input tensors. 17 | """ 18 | 19 | @property 20 | @abstractmethod 21 | def default_batch_size(self) -> int: 22 | """ 23 | Get a reasonable default number of query points for the model. 24 | In some cases, this might be the only supported size. 25 | """ 26 | 27 | @abstractmethod 28 | def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]: 29 | """ 30 | Encode a batch of point clouds to cache part of the SDF calculation 31 | done by forward(). 32 | 33 | :param point_clouds: a batch of [batch x 3 x N] points. 34 | :return: a state representing the encoded point cloud batch. 35 | """ 36 | 37 | def forward( 38 | self, 39 | x: torch.Tensor, 40 | point_clouds: Optional[torch.Tensor] = None, 41 | encoded: Optional[Dict[str, torch.Tensor]] = None, 42 | ) -> torch.Tensor: 43 | """ 44 | Predict the SDF at the coordinates x, given a batch of point clouds. 45 | 46 | Either point_clouds or encoded should be passed. Only exactly one of 47 | these arguments should be None. 48 | 49 | :param x: a [batch x 3 x N'] tensor of query points. 50 | :param point_clouds: a [batch x 3 x N] batch of point clouds. 51 | :param encoded: the result of calling encode_point_clouds(). 52 | :return: a [batch x N'] tensor of SDF predictions. 53 | """ 54 | assert point_clouds is not None or encoded is not None 55 | assert point_clouds is None or encoded is None 56 | if point_clouds is not None: 57 | encoded = self.encode_point_clouds(point_clouds) 58 | return self.predict_sdf(x, encoded) 59 | 60 | @abstractmethod 61 | def predict_sdf( 62 | self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]] 63 | ) -> torch.Tensor: 64 | """ 65 | Predict the SDF at the query points given the encoded point clouds. 66 | 67 | Each query point should be treated independently, only conditioning on 68 | the point clouds themselves. 69 | """ 70 | 71 | 72 | class CrossAttentionPointCloudSDFModel(PointCloudSDFModel): 73 | """ 74 | Encode point clouds using a transformer, and query points using cross 75 | attention to the encoded latents. 
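    Illustrative shapes: with the default n_ctx, the encoder consumes point clouds of shape
    [batch x 3 x 4096], and forward() maps query points of shape [batch x 3 x N'] to a
    [batch x N'] tensor of SDF values.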
76 | """ 77 | 78 | def __init__( 79 | self, 80 | *, 81 | device: torch.device, 82 | dtype: torch.dtype, 83 | n_ctx: int = 4096, 84 | width: int = 512, 85 | encoder_layers: int = 12, 86 | encoder_heads: int = 8, 87 | decoder_layers: int = 4, 88 | decoder_heads: int = 8, 89 | init_scale: float = 0.25, 90 | ): 91 | super().__init__() 92 | self._device = device 93 | self.n_ctx = n_ctx 94 | 95 | self.encoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype) 96 | self.encoder = Transformer( 97 | device=device, 98 | dtype=dtype, 99 | n_ctx=n_ctx, 100 | width=width, 101 | layers=encoder_layers, 102 | heads=encoder_heads, 103 | init_scale=init_scale, 104 | ) 105 | self.decoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype) 106 | self.decoder = SimplePerceiver( 107 | device=device, 108 | dtype=dtype, 109 | n_data=n_ctx, 110 | width=width, 111 | layers=decoder_layers, 112 | heads=decoder_heads, 113 | init_scale=init_scale, 114 | ) 115 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 116 | self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype) 117 | 118 | @property 119 | def device(self) -> torch.device: 120 | return self._device 121 | 122 | @property 123 | def default_batch_size(self) -> int: 124 | return self.n_query 125 | 126 | def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]: 127 | h = self.encoder_input_proj(point_clouds.permute(0, 2, 1)) 128 | h = self.encoder(h) 129 | return dict(latents=h) 130 | 131 | def predict_sdf( 132 | self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]] 133 | ) -> torch.Tensor: 134 | data = encoded["latents"] 135 | x = self.decoder_input_proj(x.permute(0, 2, 1)) 136 | x = self.decoder(x, data) 137 | x = self.ln_post(x) 138 | x = self.output_proj(x) 139 | return x[..., 0] 140 | -------------------------------------------------------------------------------- /point_e/models/perceiver.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .checkpoint import checkpoint 8 | from .transformer import MLP, init_linear 9 | 10 | 11 | class MultiheadCrossAttention(nn.Module): 12 | def __init__( 13 | self, 14 | *, 15 | device: torch.device, 16 | dtype: torch.dtype, 17 | n_data: int, 18 | width: int, 19 | heads: int, 20 | init_scale: float, 21 | data_width: Optional[int] = None, 22 | ): 23 | super().__init__() 24 | self.n_data = n_data 25 | self.width = width 26 | self.heads = heads 27 | self.data_width = width if data_width is None else data_width 28 | self.c_q = nn.Linear(width, width, device=device, dtype=dtype) 29 | self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype) 30 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 31 | self.attention = QKVMultiheadCrossAttention( 32 | device=device, dtype=dtype, heads=heads, n_data=n_data 33 | ) 34 | init_linear(self.c_q, init_scale) 35 | init_linear(self.c_kv, init_scale) 36 | init_linear(self.c_proj, init_scale) 37 | 38 | def forward(self, x, data): 39 | x = self.c_q(x) 40 | data = self.c_kv(data) 41 | x = checkpoint(self.attention, (x, data), (), True) 42 | x = self.c_proj(x) 43 | return x 44 | 45 | 46 | class QKVMultiheadCrossAttention(nn.Module): 47 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int): 48 | super().__init__() 49 | self.device = device 50 | self.dtype = dtype 51 | self.heads = heads 52 | self.n_data = 
n_data 53 | 54 | def forward(self, q, kv): 55 | _, n_ctx, _ = q.shape 56 | bs, n_data, width = kv.shape 57 | attn_ch = width // self.heads // 2 58 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 59 | q = q.view(bs, n_ctx, self.heads, -1) 60 | kv = kv.view(bs, n_data, self.heads, -1) 61 | k, v = torch.split(kv, attn_ch, dim=-1) 62 | weight = torch.einsum( 63 | "bthc,bshc->bhts", q * scale, k * scale 64 | ) # More stable with f16 than dividing afterwards 65 | wdtype = weight.dtype 66 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 67 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 68 | 69 | 70 | class ResidualCrossAttentionBlock(nn.Module): 71 | def __init__( 72 | self, 73 | *, 74 | device: torch.device, 75 | dtype: torch.dtype, 76 | n_data: int, 77 | width: int, 78 | heads: int, 79 | data_width: Optional[int] = None, 80 | init_scale: float = 1.0, 81 | ): 82 | super().__init__() 83 | 84 | if data_width is None: 85 | data_width = width 86 | 87 | self.attn = MultiheadCrossAttention( 88 | device=device, 89 | dtype=dtype, 90 | n_data=n_data, 91 | width=width, 92 | heads=heads, 93 | data_width=data_width, 94 | init_scale=init_scale, 95 | ) 96 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 97 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 98 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 99 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 100 | 101 | def forward(self, x: torch.Tensor, data: torch.Tensor): 102 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 103 | x = x + self.mlp(self.ln_3(x)) 104 | return x 105 | 106 | 107 | class SimplePerceiver(nn.Module): 108 | """ 109 | Only does cross attention 110 | """ 111 | 112 | def __init__( 113 | self, 114 | *, 115 | device: torch.device, 116 | dtype: torch.dtype, 117 | n_data: int, 118 | width: int, 119 | layers: int, 120 | heads: int, 121 | init_scale: float = 0.25, 122 | data_width: Optional[int] = None, 123 | ): 124 | super().__init__() 125 | self.width = width 126 | self.layers = layers 127 | init_scale = init_scale * math.sqrt(1.0 / width) 128 | self.resblocks = nn.ModuleList( 129 | [ 130 | ResidualCrossAttentionBlock( 131 | device=device, 132 | dtype=dtype, 133 | n_data=n_data, 134 | width=width, 135 | heads=heads, 136 | init_scale=init_scale, 137 | data_width=data_width, 138 | ) 139 | for _ in range(layers) 140 | ] 141 | ) 142 | 143 | def forward(self, x: torch.Tensor, data: torch.Tensor): 144 | for block in self.resblocks: 145 | x = block(x, data) 146 | return x 147 | -------------------------------------------------------------------------------- /point_e/util/point_cloud.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import BinaryIO, Dict, List, Optional, Union 4 | 5 | import numpy as np 6 | 7 | from .ply_util import write_ply 8 | 9 | COLORS = frozenset(["R", "G", "B", "A"]) 10 | 11 | 12 | def preprocess(data, channel): 13 | if channel in COLORS: 14 | return np.round(data * 255.0) 15 | return data 16 | 17 | 18 | @dataclass 19 | class PointCloud: 20 | """ 21 | An array of points sampled on a surface. Each point may have zero or more 22 | channel attributes. 23 | 24 | :param coords: an [N x 3] array of point coordinates. 25 | :param channels: a dict mapping names to [N] arrays of channel values. 
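    Color channels ("R", "G", "B", "A") are expected as floats in [0, 1];
    select_channels() maps them to the 0-255 range via preprocess().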
26 | """ 27 | 28 | coords: np.ndarray 29 | channels: Dict[str, np.ndarray] 30 | 31 | @classmethod 32 | def load(cls, f: Union[str, BinaryIO]) -> "PointCloud": 33 | """ 34 | Load the point cloud from a .npz file. 35 | """ 36 | if isinstance(f, str): 37 | with open(f, "rb") as reader: 38 | return cls.load(reader) 39 | else: 40 | obj = np.load(f) 41 | keys = list(obj.keys()) 42 | return PointCloud( 43 | coords=obj["coords"], 44 | channels={k: obj[k] for k in keys if k != "coords"}, 45 | ) 46 | 47 | def save(self, f: Union[str, BinaryIO]): 48 | """ 49 | Save the point cloud to a .npz file. 50 | """ 51 | if isinstance(f, str): 52 | with open(f, "wb") as writer: 53 | self.save(writer) 54 | else: 55 | np.savez(f, coords=self.coords, **self.channels) 56 | 57 | def write_ply(self, raw_f: BinaryIO): 58 | write_ply( 59 | raw_f, 60 | coords=self.coords, 61 | rgb=( 62 | np.stack([self.channels[x] for x in "RGB"], axis=1) 63 | if all(x in self.channels for x in "RGB") 64 | else None 65 | ), 66 | ) 67 | 68 | def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud": 69 | """ 70 | Sample a random subset of this PointCloud. 71 | 72 | :param num_points: maximum number of points to sample. 73 | :param subsample_kwargs: arguments to self.subsample(). 74 | :return: a reduced PointCloud, or self if num_points is not less than 75 | the current number of points. 76 | """ 77 | if len(self.coords) <= num_points: 78 | return self 79 | indices = np.random.choice(len(self.coords), size=(num_points,), replace=False) 80 | return self.subsample(indices, **subsample_kwargs) 81 | 82 | def farthest_point_sample( 83 | self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs 84 | ) -> "PointCloud": 85 | """ 86 | Sample a subset of the point cloud that is evenly distributed in space. 87 | 88 | First, a random point is selected. Then each successive point is chosen 89 | such that it is furthest from the currently selected points. 90 | 91 | The time complexity of this operation is O(NM), where N is the original 92 | number of points and M is the reduced number. Therefore, performance 93 | can be improved by randomly subsampling points with random_sample() 94 | before running farthest_point_sample(). 95 | 96 | :param num_points: maximum number of points to sample. 97 | :param init_idx: if specified, the first point to sample. 98 | :param subsample_kwargs: arguments to self.subsample(). 99 | :return: a reduced PointCloud, or self if num_points is not less than 100 | the current number of points. 101 | """ 102 | if len(self.coords) <= num_points: 103 | return self 104 | init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx 105 | indices = np.zeros([num_points], dtype=np.int64) 106 | indices[0] = init_idx 107 | sq_norms = np.sum(self.coords**2, axis=-1) 108 | 109 | def compute_dists(idx: int): 110 | # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B). 
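            # sq_norms caches ||A||^2 for every point, sq_norms[idx] supplies ||B||^2,
            # and the matrix-vector product gives A @ B, so each call costs O(N)
            # instead of recomputing pairwise differences.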
111 | return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx]) 112 | 113 | cur_dists = compute_dists(init_idx) 114 | for i in range(1, num_points): 115 | idx = np.argmax(cur_dists) 116 | indices[i] = idx 117 | cur_dists = np.minimum(cur_dists, compute_dists(idx)) 118 | return self.subsample(indices, **subsample_kwargs) 119 | 120 | def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud": 121 | if not average_neighbors: 122 | return PointCloud( 123 | coords=self.coords[indices], 124 | channels={k: v[indices] for k, v in self.channels.items()}, 125 | ) 126 | 127 | new_coords = self.coords[indices] 128 | neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords) 129 | 130 | # Make sure every point points to itself, which might not 131 | # be the case if points are duplicated or there is rounding 132 | # error. 133 | neighbor_indices[indices] = np.arange(len(indices)) 134 | 135 | new_channels = {} 136 | for k, v in self.channels.items(): 137 | v_sum = np.zeros_like(v[: len(indices)]) 138 | v_count = np.zeros_like(v[: len(indices)]) 139 | np.add.at(v_sum, neighbor_indices, v) 140 | np.add.at(v_count, neighbor_indices, 1) 141 | new_channels[k] = v_sum / v_count 142 | return PointCloud(coords=new_coords, channels=new_channels) 143 | 144 | def select_channels(self, channel_names: List[str]) -> np.ndarray: 145 | data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1) 146 | return data 147 | 148 | def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray: 149 | """ 150 | For each point in another set of points, compute the point in this 151 | pointcloud which is closest. 152 | 153 | :param points: an [N x 3] array of points. 154 | :param batch_size: the number of neighbor distances to compute at once. 155 | Smaller values save memory, while larger values may 156 | make the computation faster. 157 | :return: an [N] array of indices into self.coords. 158 | """ 159 | norms = np.sum(self.coords**2, axis=-1) 160 | all_indices = [] 161 | for i in range(0, len(points), batch_size): 162 | batch = points[i : i + batch_size] 163 | dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T) 164 | all_indices.append(np.argmin(dists, axis=-1)) 165 | return np.concatenate(all_indices, axis=0) 166 | 167 | def combine(self, other: "PointCloud") -> "PointCloud": 168 | assert self.channels.keys() == other.channels.keys() 169 | return PointCloud( 170 | coords=np.concatenate([self.coords, other.coords], axis=0), 171 | channels={ 172 | k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items() 173 | }, 174 | ) 175 | -------------------------------------------------------------------------------- /model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: Point-E 2 | 3 | This is the official codebase for running the point cloud diffusion models and SDF regression models described in [Point-E: A System for Generating 3D Point Clouds from Complex Prompts](https://arxiv.org/abs/2212.08751). These models were trained and released by OpenAI. 4 | Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated. 5 | 6 | # Model Details 7 | 8 | The Point-E models are trained for use as point cloud diffusion models and SDF regression models. 
9 | Our image-conditional models are often capable of producing coherent 3D point clouds, given a single rendering of a 3D object. However, the models sometimes fail to do so, either producing incorrect geometry where the rendering is occluded, or producing geometry that is inconsistent with visible parts of the rendering. The resulting point clouds are relatively low-resolution, and are often noisy and contain defects such as outliers or cracks. 10 | Our text-conditional model is sometimes capable of producing 3D point clouds which can be recognized as the provided text description, especially when the text description is simple. However, we find that this model fails to generalize to complex prompts or unusual objects. 11 | 12 | ## Model Date 13 | 14 | December 2022 15 | 16 | ## Model Versions 17 | 18 | * `base40M-imagevec` - a 40 million parameter image to point cloud model that conditions on a single CLIP ViT-L/14 image vector. This model can be used to generate point clouds from rendered images, but does not perform as well as our other models for this task. 19 | * `base40M-textvec` - a 40 million parameter text to point cloud model that conditions on a single CLIP ViT-L/14 text vector. This model can be used to directly generate point clouds from text descriptions, but only works for simple prompts. 20 | * `base40M-uncond` - a 40 million parameter point cloud diffusion model that generates unconditional samples. This is included only as a baseline. 21 | * `base40M` - a 40 million parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. This model can be used to generate point clouds from rendered images, but is not as good as the larger models trained on the same task. 22 | * `base300M` - a 300 million parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. This model can be used to generate point clouds from rendered images, but it is slightly worse than base1B 23 | * `base1B` - a 1 billion parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. 24 | * `upsample` - a 40 million parameter point cloud upsampling model that can optionally condition on an image as well. This takes a point cloud of 1024 points and upsamples it to 4096 points. 25 | * `sdf` - a small model for predicting signed distance functions from 3D point clouds. This can be used to predict meshes from point clouds. 26 | * `pointnet` - a small point cloud classification model used for our P-FID and P-IS evaluation metrics. 27 | 28 | ## Paper & samples 29 | 30 | [Paper](https://arxiv.org/abs/2212.08751) / [Sample point clouds](point_e/examples/paper_banner.gif) 31 | 32 | # Training data 33 | 34 | These models were trained on a dataset of several million 3D models. We filtered the dataset to avoid flat objects, and used [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md) to cluster the dataset and downweight clusters of 3D models which appeared to contain mostly unrecognizable objects. We additionally down-weighted clusters which appeared to consist of many similar-looking objects. We processed the resulting dataset into renders (RGB point clouds of 4K points each) and text captions from the associated metadata. 35 | Our SDF regression model was trained on a subset of the above dataset. In particular, we only retained 3D meshes which were manifold (i.e. watertight and free of singularities). 
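The model versions listed above correspond one-to-one to configuration names in `point_e/models/configs.py` and `point_e/diffusion/configs.py`. As a minimal, illustrative sketch (assuming the `MODEL_CONFIGS`/`DIFFUSION_CONFIGS` dictionaries and the `load_checkpoint` helper from `point_e/models/download.py`), the text-conditional stack can be assembled like this:

```python
import torch

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base text-conditional model (1024 points) plus the shared upsampler (to 4096 points).
base_name = "base40M-textvec"
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_model.load_state_dict(load_checkpoint(base_name, device))

upsampler_model = model_from_config(MODEL_CONFIGS["upsample"], device)
upsampler_model.eval()
upsampler_model.load_state_dict(load_checkpoint("upsample", device))

sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[
        diffusion_from_config(DIFFUSION_CONFIGS[base_name]),
        diffusion_from_config(DIFFUSION_CONFIGS["upsample"]),
    ],
    num_points=[1024, 4096 - 1024],
    aux_channels=["R", "G", "B"],
    guidance_scale=[3.0, 0.0],              # guide the base model only
    model_kwargs_key_filter=("texts", ""),  # the upsampler ignores the prompt
)
```

Sampling then proceeds as in `point_e/examples/text2pointcloud.ipynb`, e.g. `sampler.sample_batch(batch_size=1, model_kwargs=dict(texts=["a red motorcycle"]))` followed by `sampler.output_to_point_clouds(...)`.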
36 | 37 | # Evaluated Use 38 | 39 | We release these models to help advance research in generative modeling. Due to the limitations and biases of our models, we do not currently recommend it for commercial use. We understand that our models may be used in ways we haven't anticipated, and that it is difficult to define clear boundaries around what constitutes appropriate "research" use. In particular, we caution against using these models in applications where precision is critical, as subtle flaws in the outputs could lead to errors or inaccuracies. 40 | Functionally, these models are trained to be able to perform the following tasks for research purposes, and are evaluated on these tasks: 41 | 42 | * Generate 3D point clouds conditioned on single rendered images 43 | * Generate 3D point clouds conditioned on text 44 | * Create 3D meshes from noisy 3D point clouds 45 | 46 | Our image-conditional models are intended to produce coherent point clouds, given a representative rendering of a 3D object. However, at their current level of capabilities, the models sometimes fail to generate coherent output, either producing incorrect geometry where the rendering is occluded, or producing geometry that is inconsistent with visible parts of the rendering. The resulting point clouds are relatively low-resolution, and are often noisy and contain defects such as outliers or cracks. 47 | 48 | Our text-conditional model is sometimes capable of producing 3D point clouds which can be recognized as the provided text description, especially when the text description is simple. However, we find that this model fails to generalize to complex prompts or unusual objects. 49 | 50 | # Performance and Limitations 51 | 52 | Our image-conditional models are limited by the text-to-image model that is used to produce synthetic views. If the text-to-image model contains a bias or fails to understand a particular concept, these limitations will be passed down to the image-conditional point cloud model through conditioning images. 53 | While our main focus is on image-conditional models, we also experimented with a text-conditional model. We find that this model can sometimes produce 3D models of people that exhibit gender biases (for example, samples for "a man" tend to be wider and less narrow than samples for "a woman"). We additionally find that this model is sometimes capable of producing violent objects such as guns or tanks, although these generations are always low-quality and unrealistic. 54 | 55 | Since our dataset contains many simplistic, cartoonish 3D objects, our models are prone to mimicking this style. 56 | 57 | While these models were developed for research purposes, they have potential implications if used more broadly. For example, the ability to generate 3D point clouds from single images could help advance research in computer graphics, virtual reality, and robotics. The text-conditional model could allow for users to easily create 3D models from simple descriptions, which could be useful for rapid prototyping or 3D printing. 58 | 59 | The combination of these models with 3D printing could potentially be harmful, for example if used to prototype dangerous objects or when parts created by the model are trusted without external validation. 60 | 61 | Finally, point cloud models inherit many of the same risks and limitations as image-generation models, including the propensity to produce biased or otherwise harmful content or to carry dual-use risk. 
More research is needed on how these risks manifest themselves as capabilities improve. 62 | 63 | -------------------------------------------------------------------------------- /point_e/evals/npz_stream.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import io 3 | import os 4 | import re 5 | import zipfile 6 | from abc import ABC, abstractmethod 7 | from contextlib import contextmanager 8 | from dataclasses import dataclass 9 | from typing import Dict, Iterator, List, Optional, Sequence, Tuple 10 | 11 | import numpy as np 12 | 13 | 14 | @dataclass 15 | class NumpyArrayInfo: 16 | """ 17 | Information about an array in an npz file. 18 | """ 19 | 20 | name: str 21 | dtype: np.dtype 22 | shape: Tuple[int] 23 | 24 | @classmethod 25 | def infos_from_first_file(cls, glob_path: str) -> Dict[str, "NumpyArrayInfo"]: 26 | paths, _ = _npz_paths_and_length(glob_path) 27 | return cls.infos_from_file(paths[0]) 28 | 29 | @classmethod 30 | def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]: 31 | """ 32 | Extract the info of every array in an npz file. 33 | """ 34 | if not os.path.exists(npz_path): 35 | raise FileNotFoundError(f"batch of samples was not found: {npz_path}") 36 | results = {} 37 | with open(npz_path, "rb") as f: 38 | with zipfile.ZipFile(f, "r") as zip_f: 39 | for name in zip_f.namelist(): 40 | if not name.endswith(".npy"): 41 | continue 42 | key_name = name[: -len(".npy")] 43 | with zip_f.open(name, "r") as arr_f: 44 | version = np.lib.format.read_magic(arr_f) 45 | if version == (1, 0): 46 | header = np.lib.format.read_array_header_1_0(arr_f) 47 | elif version == (2, 0): 48 | header = np.lib.format.read_array_header_2_0(arr_f) 49 | else: 50 | raise ValueError(f"unknown numpy array version: {version}") 51 | shape, _, dtype = header 52 | results[key_name] = cls(name=key_name, dtype=dtype, shape=shape) 53 | return results 54 | 55 | @property 56 | def elem_shape(self) -> Tuple[int]: 57 | return self.shape[1:] 58 | 59 | def validate(self): 60 | if self.name in {"R", "G", "B"}: 61 | if len(self.shape) != 2: 62 | raise ValueError( 63 | f"expecting exactly 2-D shape for '{self.name}' but got: {self.shape}" 64 | ) 65 | elif self.name == "arr_0": 66 | if len(self.shape) < 2: 67 | raise ValueError(f"expecting at least 2-D shape but got: {self.shape}") 68 | elif len(self.shape) == 3: 69 | # For audio, we require continuous samples. 
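                # A 3-D 'arr_0' batch is treated as float-valued (audio-style) data;
                # any other rank falls through to the uint8 image check below.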
70 | if not np.issubdtype(self.dtype, np.floating): 71 | raise ValueError( 72 | f"invalid dtype for audio batch: {self.dtype} (expected float)" 73 | ) 74 | elif self.dtype != np.uint8: 75 | raise ValueError(f"invalid dtype for image batch: {self.dtype} (expected uint8)") 76 | 77 | 78 | class NpzStreamer: 79 | def __init__(self, glob_path: str): 80 | self.paths, self.trunc_length = _npz_paths_and_length(glob_path) 81 | self.infos = NumpyArrayInfo.infos_from_file(self.paths[0]) 82 | 83 | def keys(self) -> List[str]: 84 | return list(self.infos.keys()) 85 | 86 | def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]: 87 | cur_batch = None 88 | num_remaining = self.trunc_length 89 | for path in self.paths: 90 | if num_remaining is not None and num_remaining <= 0: 91 | break 92 | with open_npz_arrays(path, keys) as readers: 93 | combined_reader = CombinedReader(keys, readers) 94 | while num_remaining is None or num_remaining > 0: 95 | read_bs = batch_size 96 | if cur_batch is not None: 97 | read_bs -= _dict_batch_size(cur_batch) 98 | if num_remaining is not None: 99 | read_bs = min(read_bs, num_remaining) 100 | 101 | batch = combined_reader.read_batch(read_bs) 102 | if batch is None: 103 | break 104 | if num_remaining is not None: 105 | num_remaining -= _dict_batch_size(batch) 106 | if cur_batch is None: 107 | cur_batch = batch 108 | else: 109 | cur_batch = { 110 | # pylint: disable=unsubscriptable-object 111 | k: np.concatenate([cur_batch[k], v], axis=0) 112 | for k, v in batch.items() 113 | } 114 | if _dict_batch_size(cur_batch) == batch_size: 115 | yield cur_batch 116 | cur_batch = None 117 | if cur_batch is not None: 118 | yield cur_batch 119 | 120 | 121 | def _npz_paths_and_length(glob_path: str) -> Tuple[List[str], Optional[int]]: 122 | # Match slice syntax like path[:100]. 
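    # A glob such as "samples_*.npz[:5000]" caps the total number of rows streamed
    # across the matched files; without the suffix, max_count stays None and every
    # row is used.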
123 | count_match = re.match("^(.*)\\[:([0-9]*)\\]$", glob_path) 124 | if count_match: 125 | raw_path = count_match[1] 126 | max_count = int(count_match[2]) 127 | else: 128 | raw_path = glob_path 129 | max_count = None 130 | paths = sorted(glob.glob(raw_path)) 131 | if not len(paths): 132 | raise ValueError(f"no paths found matching: {glob_path}") 133 | return paths, max_count 134 | 135 | 136 | class NpzArrayReader(ABC): 137 | @abstractmethod 138 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 139 | pass 140 | 141 | 142 | class StreamingNpzArrayReader(NpzArrayReader): 143 | def __init__(self, arr_f, shape, dtype): 144 | self.arr_f = arr_f 145 | self.shape = shape 146 | self.dtype = dtype 147 | self.idx = 0 148 | 149 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 150 | if self.idx >= self.shape[0]: 151 | return None 152 | 153 | bs = min(batch_size, self.shape[0] - self.idx) 154 | self.idx += bs 155 | 156 | if self.dtype.itemsize == 0: 157 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 158 | 159 | read_count = bs * np.prod(self.shape[1:]) 160 | read_size = int(read_count * self.dtype.itemsize) 161 | data = _read_bytes(self.arr_f, read_size, "array data") 162 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 163 | 164 | 165 | class MemoryNpzArrayReader(NpzArrayReader): 166 | def __init__(self, arr): 167 | self.arr = arr 168 | self.idx = 0 169 | 170 | @classmethod 171 | def load(cls, path: str, arr_name: str): 172 | with open(path, "rb") as f: 173 | arr = np.load(f)[arr_name] 174 | return cls(arr) 175 | 176 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 177 | if self.idx >= self.arr.shape[0]: 178 | return None 179 | 180 | res = self.arr[self.idx : self.idx + batch_size] 181 | self.idx += batch_size 182 | return res 183 | 184 | 185 | @contextmanager 186 | def open_npz_arrays(path: str, arr_names: Sequence[str]) -> List[NpzArrayReader]: 187 | if not len(arr_names): 188 | yield [] 189 | return 190 | arr_name = arr_names[0] 191 | with open_array(path, arr_name) as arr_f: 192 | version = np.lib.format.read_magic(arr_f) 193 | header = None 194 | if version == (1, 0): 195 | header = np.lib.format.read_array_header_1_0(arr_f) 196 | elif version == (2, 0): 197 | header = np.lib.format.read_array_header_2_0(arr_f) 198 | 199 | if header is None: 200 | reader = MemoryNpzArrayReader.load(path, arr_name) 201 | else: 202 | shape, fortran, dtype = header 203 | if fortran or dtype.hasobject: 204 | reader = MemoryNpzArrayReader.load(path, arr_name) 205 | else: 206 | reader = StreamingNpzArrayReader(arr_f, shape, dtype) 207 | 208 | with open_npz_arrays(path, arr_names[1:]) as next_readers: 209 | yield [reader] + next_readers 210 | 211 | 212 | class CombinedReader: 213 | def __init__(self, keys: List[str], readers: List[NpzArrayReader]): 214 | self.keys = keys 215 | self.readers = readers 216 | 217 | def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]: 218 | batches = [r.read_batch(batch_size) for r in self.readers] 219 | any_none = any(x is None for x in batches) 220 | all_none = all(x is None for x in batches) 221 | if any_none != all_none: 222 | raise RuntimeError("different keys had different numbers of elements") 223 | if any_none: 224 | return None 225 | if any(len(x) != len(batches[0]) for x in batches): 226 | raise RuntimeError("different keys had different numbers of elements") 227 | return dict(zip(self.keys, batches)) 228 | 229 | 230 | def _read_bytes(fp, size, error_template="ran out of 
data"): 231 | """ 232 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 233 | 234 | Read from file-like object until size bytes are read. 235 | Raises ValueError if not EOF is encountered before size bytes are read. 236 | Non-blocking objects only supported if they derive from io objects. 237 | Required as e.g. ZipExtFile in python 2.6 can return less data than 238 | requested. 239 | """ 240 | data = bytes() 241 | while True: 242 | # io files (default in python3) return None or raise on 243 | # would-block, python2 file will truncate, probably nothing can be 244 | # done about that. note that regular files can't be non-blocking 245 | try: 246 | r = fp.read(size - len(data)) 247 | data += r 248 | if len(r) == 0 or len(data) == size: 249 | break 250 | except io.BlockingIOError: 251 | pass 252 | if len(data) != size: 253 | msg = "EOF: reading %s, expected %d bytes got %d" 254 | raise ValueError(msg % (error_template, size, len(data))) 255 | else: 256 | return data 257 | 258 | 259 | @contextmanager 260 | def open_array(path: str, arr_name: str): 261 | with open(path, "rb") as f: 262 | with zipfile.ZipFile(f, "r") as zip_f: 263 | if f"{arr_name}.npy" not in zip_f.namelist(): 264 | raise ValueError(f"missing {arr_name} in npz file") 265 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 266 | yield arr_f 267 | 268 | 269 | def _dict_batch_size(objs: Dict[str, np.ndarray]) -> int: 270 | return len(next(iter(objs.values()))) 271 | -------------------------------------------------------------------------------- /point_e/models/pretrained_clip.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from PIL import Image 7 | 8 | from .download import default_cache_dir 9 | 10 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 11 | 12 | 13 | class ImageCLIP(nn.Module): 14 | """ 15 | A wrapper around a pre-trained CLIP model that automatically handles 16 | batches of texts, images, and embeddings. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | device: torch.device, 22 | dtype: Optional[torch.dtype] = torch.float32, 23 | ensure_used_params: bool = True, 24 | clip_name: str = "ViT-L/14", 25 | cache_dir: Optional[str] = None, 26 | ): 27 | super().__init__() 28 | 29 | assert clip_name in ["ViT-L/14", "ViT-B/32"] 30 | 31 | self.device = device 32 | self.ensure_used_params = ensure_used_params 33 | 34 | # Lazy import because of torchvision. 
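        # Importing clip pulls in torchvision, so the import is deferred until an
        # ImageCLIP is actually constructed.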
35 | import clip 36 | 37 | self.clip_model, self.preprocess = clip.load( 38 | clip_name, device=device, download_root=cache_dir or default_cache_dir() 39 | ) 40 | self.clip_name = clip_name 41 | 42 | if dtype is not None: 43 | self.clip_model.to(dtype) 44 | self._tokenize = clip.tokenize 45 | 46 | @property 47 | def feature_dim(self) -> int: 48 | if self.clip_name == "ViT-L/14": 49 | return 768 50 | else: 51 | return 512 52 | 53 | @property 54 | def grid_size(self) -> int: 55 | if self.clip_name == "ViT-L/14": 56 | return 16 57 | else: 58 | return 7 59 | 60 | @property 61 | def grid_feature_dim(self) -> int: 62 | if self.clip_name == "ViT-L/14": 63 | return 1024 64 | else: 65 | return 768 66 | 67 | def forward( 68 | self, 69 | batch_size: int, 70 | images: Optional[Iterable[Optional[ImageType]]] = None, 71 | texts: Optional[Iterable[Optional[str]]] = None, 72 | embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, 73 | ) -> torch.Tensor: 74 | """ 75 | Generate a batch of embeddings from a mixture of images, texts, 76 | precomputed embeddings, and possibly empty values. 77 | 78 | For each batch element, at most one of images, texts, and embeddings 79 | should have a non-None value. Embeddings from multiple modalities 80 | cannot be mixed for a single batch element. If no modality is provided, 81 | a zero embedding will be used for the batch element. 82 | """ 83 | image_seq = [None] * batch_size if images is None else list(images) 84 | text_seq = [None] * batch_size if texts is None else list(texts) 85 | embedding_seq = [None] * batch_size if embeddings is None else list(embeddings) 86 | assert len(image_seq) == batch_size, "number of images should match batch size" 87 | assert len(text_seq) == batch_size, "number of texts should match batch size" 88 | assert len(embedding_seq) == batch_size, "number of embeddings should match batch size" 89 | 90 | if self.ensure_used_params: 91 | return self._static_multimodal_embed( 92 | images=image_seq, texts=text_seq, embeddings=embedding_seq 93 | ) 94 | 95 | result = torch.zeros((batch_size, self.feature_dim), device=self.device) 96 | index_images = [] 97 | index_texts = [] 98 | for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)): 99 | assert ( 100 | sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2 101 | ), "only one modality may be non-None per batch element" 102 | if image is not None: 103 | index_images.append((i, image)) 104 | elif text is not None: 105 | index_texts.append((i, text)) 106 | elif emb is not None: 107 | result[i] = emb.to(result) 108 | 109 | if len(index_images): 110 | embs = self.embed_images((img for _, img in index_images)) 111 | for (i, _), emb in zip(index_images, embs): 112 | result[i] = emb.to(result) 113 | if len(index_texts): 114 | embs = self.embed_text((text for _, text in index_texts)) 115 | for (i, _), emb in zip(index_texts, embs): 116 | result[i] = emb.to(result) 117 | 118 | return result 119 | 120 | def _static_multimodal_embed( 121 | self, 122 | images: List[Optional[ImageType]] = None, 123 | texts: List[Optional[str]] = None, 124 | embeddings: List[Optional[torch.Tensor]] = None, 125 | ) -> torch.Tensor: 126 | """ 127 | Like forward(), but always runs all encoders to ensure that 128 | the forward graph looks the same on every rank. 
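        All three encoders run on every batch element; the per-modality results are
        then combined with 0/1 masks derived from which input was actually provided,
        so every CLIP parameter contributes to the graph (e.g. under distributed
        data parallel).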
129 | """ 130 | image_emb = self.embed_images(images) 131 | text_emb = self.embed_text(t if t else "" for t in texts) 132 | joined_embs = torch.stack( 133 | [ 134 | emb.to(device=self.device, dtype=torch.float32) 135 | if emb is not None 136 | else torch.zeros(self.feature_dim, device=self.device) 137 | for emb in embeddings 138 | ], 139 | dim=0, 140 | ) 141 | 142 | image_flag = torch.tensor([x is not None for x in images], device=self.device)[ 143 | :, None 144 | ].expand_as(image_emb) 145 | text_flag = torch.tensor([x is not None for x in texts], device=self.device)[ 146 | :, None 147 | ].expand_as(image_emb) 148 | emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[ 149 | :, None 150 | ].expand_as(image_emb) 151 | 152 | return ( 153 | image_flag.float() * image_emb 154 | + text_flag.float() * text_emb 155 | + emb_flag.float() * joined_embs 156 | + self.clip_model.logit_scale * 0 # avoid unused parameters 157 | ) 158 | 159 | def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: 160 | """ 161 | :param xs: N images, stored as numpy arrays, tensors, or PIL images. 162 | :return: an [N x D] tensor of features. 163 | """ 164 | clip_inputs = self.images_to_tensor(xs) 165 | results = self.clip_model.encode_image(clip_inputs).float() 166 | return results / torch.linalg.norm(results, dim=-1, keepdim=True) 167 | 168 | def embed_text(self, prompts: Iterable[str]) -> torch.Tensor: 169 | """ 170 | Embed text prompts as an [N x D] tensor. 171 | """ 172 | enc = self.clip_model.encode_text( 173 | self._tokenize(list(prompts), truncate=True).to(self.device) 174 | ).float() 175 | return enc / torch.linalg.norm(enc, dim=-1, keepdim=True) 176 | 177 | def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: 178 | """ 179 | Embed images into latent grids. 180 | 181 | :param xs: an iterable of images to embed. 182 | :return: a tensor of shape [N x C x L], where L = self.grid_size**2. 
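        The class token is dropped and no output projection is applied, so the
        channel dimension C is self.grid_feature_dim rather than self.feature_dim.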
183 | """ 184 | if self.ensure_used_params: 185 | extra_value = 0.0 186 | for p in self.parameters(): 187 | extra_value = extra_value + p.mean() * 0.0 188 | else: 189 | extra_value = 0.0 190 | 191 | x = self.images_to_tensor(xs).to(self.clip_model.dtype) 192 | 193 | # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225 194 | vt = self.clip_model.visual 195 | x = vt.conv1(x) # shape = [*, width, grid, grid] 196 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 197 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 198 | x = torch.cat( 199 | [ 200 | vt.class_embedding.to(x.dtype) 201 | + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 202 | x, 203 | ], 204 | dim=1, 205 | ) # shape = [*, grid ** 2 + 1, width] 206 | x = x + vt.positional_embedding.to(x.dtype) 207 | x = vt.ln_pre(x) 208 | 209 | x = x.permute(1, 0, 2) # NLD -> LND 210 | x = vt.transformer(x) 211 | x = x.permute(1, 2, 0) # LND -> NDL 212 | 213 | return x[..., 1:].contiguous().float() + extra_value 214 | 215 | def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: 216 | return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device) 217 | 218 | 219 | class FrozenImageCLIP: 220 | def __init__(self, device: torch.device, **kwargs): 221 | self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs) 222 | for parameter in self.model.parameters(): 223 | parameter.requires_grad_(False) 224 | 225 | @property 226 | def feature_dim(self) -> int: 227 | return self.model.feature_dim 228 | 229 | @property 230 | def grid_size(self) -> int: 231 | return self.model.grid_size 232 | 233 | @property 234 | def grid_feature_dim(self) -> int: 235 | return self.model.grid_feature_dim 236 | 237 | def __call__( 238 | self, 239 | batch_size: int, 240 | images: Optional[Iterable[Optional[ImageType]]] = None, 241 | texts: Optional[Iterable[Optional[str]]] = None, 242 | embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, 243 | ) -> torch.Tensor: 244 | # We don't do a no_grad() here so that gradients could still 245 | # flow to the input embeddings argument. 246 | # This behavior is currently not used, but it could be. 247 | return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings) 248 | 249 | def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: 250 | with torch.no_grad(): 251 | return self.model.embed_images(xs) 252 | 253 | def embed_text(self, prompts: Iterable[str]) -> torch.Tensor: 254 | with torch.no_grad(): 255 | return self.model.embed_text(prompts) 256 | 257 | def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: 258 | with torch.no_grad(): 259 | return self.model.embed_images_grid(xs) 260 | 261 | 262 | def _image_to_pil(obj: Optional[ImageType]) -> Image.Image: 263 | if obj is None: 264 | return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8)) 265 | if isinstance(obj, np.ndarray): 266 | return Image.fromarray(obj.astype(np.uint8)) 267 | elif isinstance(obj, torch.Tensor): 268 | return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8)) 269 | else: 270 | return obj 271 | -------------------------------------------------------------------------------- /point_e/diffusion/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for sampling from a single- or multi-stage point cloud diffusion model. 
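A typical stack is a base model that generates 1024 points followed by an
upsampler that grows the cloud to 4096 points.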
3 | """ 4 | 5 | from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from point_e.util.point_cloud import PointCloud 11 | 12 | from .gaussian_diffusion import GaussianDiffusion 13 | from .k_diffusion import karras_sample_progressive 14 | 15 | 16 | class PointCloudSampler: 17 | """ 18 | A wrapper around a model or stack of models that produces conditional or 19 | unconditional sample tensors. 20 | 21 | By default, this will load models and configs from files. 22 | If you want to modify the sampler arguments of an existing sampler, call 23 | with_options() or with_args(). 24 | """ 25 | 26 | def __init__( 27 | self, 28 | device: torch.device, 29 | models: Sequence[nn.Module], 30 | diffusions: Sequence[GaussianDiffusion], 31 | num_points: Sequence[int], 32 | aux_channels: Sequence[str], 33 | model_kwargs_key_filter: Sequence[str] = ("*",), 34 | guidance_scale: Sequence[float] = (3.0, 3.0), 35 | clip_denoised: bool = True, 36 | use_karras: Sequence[bool] = (True, True), 37 | karras_steps: Sequence[int] = (64, 64), 38 | sigma_min: Sequence[float] = (1e-3, 1e-3), 39 | sigma_max: Sequence[float] = (120, 160), 40 | s_churn: Sequence[float] = (3, 0), 41 | ): 42 | n = len(models) 43 | assert n > 0 44 | 45 | if n > 1: 46 | if len(guidance_scale) == 1: 47 | # Don't guide the upsamplers by default. 48 | guidance_scale = list(guidance_scale) + [1.0] * (n - 1) 49 | if len(use_karras) == 1: 50 | use_karras = use_karras * n 51 | if len(karras_steps) == 1: 52 | karras_steps = karras_steps * n 53 | if len(sigma_min) == 1: 54 | sigma_min = sigma_min * n 55 | if len(sigma_max) == 1: 56 | sigma_max = sigma_max * n 57 | if len(s_churn) == 1: 58 | s_churn = s_churn * n 59 | if len(model_kwargs_key_filter) == 1: 60 | model_kwargs_key_filter = model_kwargs_key_filter * n 61 | if len(model_kwargs_key_filter) == 0: 62 | model_kwargs_key_filter = ["*"] * n 63 | assert len(guidance_scale) == n 64 | assert len(use_karras) == n 65 | assert len(karras_steps) == n 66 | assert len(sigma_min) == n 67 | assert len(sigma_max) == n 68 | assert len(s_churn) == n 69 | assert len(model_kwargs_key_filter) == n 70 | 71 | self.device = device 72 | self.num_points = num_points 73 | self.aux_channels = aux_channels 74 | self.model_kwargs_key_filter = model_kwargs_key_filter 75 | self.guidance_scale = guidance_scale 76 | self.clip_denoised = clip_denoised 77 | self.use_karras = use_karras 78 | self.karras_steps = karras_steps 79 | self.sigma_min = sigma_min 80 | self.sigma_max = sigma_max 81 | self.s_churn = s_churn 82 | 83 | self.models = models 84 | self.diffusions = diffusions 85 | 86 | @property 87 | def num_stages(self) -> int: 88 | return len(self.models) 89 | 90 | def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor: 91 | samples = None 92 | for x in self.sample_batch_progressive(batch_size, model_kwargs): 93 | samples = x 94 | return samples 95 | 96 | def sample_batch_progressive( 97 | self, batch_size: int, model_kwargs: Dict[str, Any] 98 | ) -> Iterator[torch.Tensor]: 99 | samples = None 100 | for ( 101 | model, 102 | diffusion, 103 | stage_num_points, 104 | stage_guidance_scale, 105 | stage_use_karras, 106 | stage_karras_steps, 107 | stage_sigma_min, 108 | stage_sigma_max, 109 | stage_s_churn, 110 | stage_key_filter, 111 | ) in zip( 112 | self.models, 113 | self.diffusions, 114 | self.num_points, 115 | self.guidance_scale, 116 | self.use_karras, 117 | self.karras_steps, 118 | self.sigma_min, 119 | self.sigma_max, 120 | 
self.s_churn, 121 | self.model_kwargs_key_filter, 122 | ): 123 | stage_model_kwargs = model_kwargs.copy() 124 | if stage_key_filter != "*": 125 | use_keys = set(stage_key_filter.split(",")) 126 | stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys} 127 | if samples is not None: 128 | stage_model_kwargs["low_res"] = samples 129 | if hasattr(model, "cached_model_kwargs"): 130 | stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs) 131 | sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points) 132 | 133 | if stage_guidance_scale != 1 and stage_guidance_scale != 0: 134 | for k, v in stage_model_kwargs.copy().items(): 135 | stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0) 136 | 137 | if stage_use_karras: 138 | samples_it = karras_sample_progressive( 139 | diffusion=diffusion, 140 | model=model, 141 | shape=sample_shape, 142 | steps=stage_karras_steps, 143 | clip_denoised=self.clip_denoised, 144 | model_kwargs=stage_model_kwargs, 145 | device=self.device, 146 | sigma_min=stage_sigma_min, 147 | sigma_max=stage_sigma_max, 148 | s_churn=stage_s_churn, 149 | guidance_scale=stage_guidance_scale, 150 | ) 151 | else: 152 | internal_batch_size = batch_size 153 | if stage_guidance_scale: 154 | model = self._uncond_guide_model(model, stage_guidance_scale) 155 | internal_batch_size *= 2 156 | samples_it = diffusion.p_sample_loop_progressive( 157 | model, 158 | shape=(internal_batch_size, *sample_shape[1:]), 159 | model_kwargs=stage_model_kwargs, 160 | device=self.device, 161 | clip_denoised=self.clip_denoised, 162 | ) 163 | for x in samples_it: 164 | samples = x["pred_xstart"][:batch_size] 165 | if "low_res" in stage_model_kwargs: 166 | samples = torch.cat( 167 | [stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1 168 | ) 169 | yield samples 170 | 171 | @classmethod 172 | def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler": 173 | assert all(x.device == samplers[0].device for x in samplers[1:]) 174 | assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:]) 175 | assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:]) 176 | return cls( 177 | device=samplers[0].device, 178 | models=[x for y in samplers for x in y.models], 179 | diffusions=[x for y in samplers for x in y.diffusions], 180 | num_points=[x for y in samplers for x in y.num_points], 181 | aux_channels=samplers[0].aux_channels, 182 | model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter], 183 | guidance_scale=[x for y in samplers for x in y.guidance_scale], 184 | clip_denoised=samplers[0].clip_denoised, 185 | use_karras=[x for y in samplers for x in y.use_karras], 186 | karras_steps=[x for y in samplers for x in y.karras_steps], 187 | sigma_min=[x for y in samplers for x in y.sigma_min], 188 | sigma_max=[x for y in samplers for x in y.sigma_max], 189 | s_churn=[x for y in samplers for x in y.s_churn], 190 | ) 191 | 192 | def _uncond_guide_model( 193 | self, model: Callable[..., torch.Tensor], scale: float 194 | ) -> Callable[..., torch.Tensor]: 195 | def model_fn(x_t, ts, **kwargs): 196 | half = x_t[: len(x_t) // 2] 197 | combined = torch.cat([half, half], dim=0) 198 | model_out = model(combined, ts, **kwargs) 199 | eps, rest = model_out[:, :3], model_out[:, 3:] 200 | cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0) 201 | half_eps = uncond_eps + scale * (cond_eps - uncond_eps) 202 | eps = torch.cat([half_eps, half_eps], dim=0) 203 | return torch.cat([eps, 
rest], dim=1) 204 | 205 | return model_fn 206 | 207 | def split_model_output( 208 | self, 209 | output: torch.Tensor, 210 | rescale_colors: bool = False, 211 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 212 | assert ( 213 | len(self.aux_channels) + 3 == output.shape[1] 214 | ), "there must be three spatial channels before aux" 215 | pos, joined_aux = output[:, :3], output[:, 3:] 216 | 217 | aux = {} 218 | for i, name in enumerate(self.aux_channels): 219 | v = joined_aux[:, i] 220 | if name in {"R", "G", "B", "A"}: 221 | v = v.clamp(0, 255).round() 222 | if rescale_colors: 223 | v = v / 255.0 224 | aux[name] = v 225 | return pos, aux 226 | 227 | def output_to_point_clouds(self, output: torch.Tensor) -> List[PointCloud]: 228 | res = [] 229 | for sample in output: 230 | xyz, aux = self.split_model_output(sample[None], rescale_colors=True) 231 | res.append( 232 | PointCloud( 233 | coords=xyz[0].t().cpu().numpy(), 234 | channels={k: v[0].cpu().numpy() for k, v in aux.items()}, 235 | ) 236 | ) 237 | return res 238 | 239 | def with_options( 240 | self, 241 | guidance_scale: float, 242 | clip_denoised: bool, 243 | use_karras: Sequence[bool] = (True, True), 244 | karras_steps: Sequence[int] = (64, 64), 245 | sigma_min: Sequence[float] = (1e-3, 1e-3), 246 | sigma_max: Sequence[float] = (120, 160), 247 | s_churn: Sequence[float] = (3, 0), 248 | ) -> "PointCloudSampler": 249 | return PointCloudSampler( 250 | device=self.device, 251 | models=self.models, 252 | diffusions=self.diffusions, 253 | num_points=self.num_points, 254 | aux_channels=self.aux_channels, 255 | model_kwargs_key_filter=self.model_kwargs_key_filter, 256 | guidance_scale=guidance_scale, 257 | clip_denoised=clip_denoised, 258 | use_karras=use_karras, 259 | karras_steps=karras_steps, 260 | sigma_min=sigma_min, 261 | sigma_max=sigma_max, 262 | s_churn=s_churn, 263 | ) 264 | -------------------------------------------------------------------------------- /point_e/diffusion/k_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://github.com/crowsonkb/k-diffusion 3 | 4 | Copyright (c) 2022 Katherine Crowson 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 
23 | """ 24 | 25 | import numpy as np 26 | import torch as th 27 | 28 | from .gaussian_diffusion import GaussianDiffusion, mean_flat 29 | 30 | 31 | class KarrasDenoiser: 32 | def __init__(self, sigma_data: float = 0.5): 33 | self.sigma_data = sigma_data 34 | 35 | def get_snr(self, sigmas): 36 | return sigmas**-2 37 | 38 | def get_sigmas(self, sigmas): 39 | return sigmas 40 | 41 | def get_scalings(self, sigma): 42 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 43 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 44 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 45 | return c_skip, c_out, c_in 46 | 47 | def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None): 48 | if model_kwargs is None: 49 | model_kwargs = {} 50 | if noise is None: 51 | noise = th.randn_like(x_start) 52 | 53 | terms = {} 54 | 55 | dims = x_start.ndim 56 | x_t = x_start + noise * append_dims(sigmas, dims) 57 | c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)] 58 | model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs) 59 | target = (x_start - c_skip * x_t) / c_out 60 | 61 | terms["mse"] = mean_flat((model_output - target) ** 2) 62 | terms["xs_mse"] = mean_flat((denoised - x_start) ** 2) 63 | 64 | if "vb" in terms: 65 | terms["loss"] = terms["mse"] + terms["vb"] 66 | else: 67 | terms["loss"] = terms["mse"] 68 | 69 | return terms 70 | 71 | def denoise(self, model, x_t, sigmas, **model_kwargs): 72 | c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)] 73 | rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44) 74 | model_output = model(c_in * x_t, rescaled_t, **model_kwargs) 75 | denoised = c_out * model_output + c_skip * x_t 76 | return model_output, denoised 77 | 78 | 79 | class GaussianToKarrasDenoiser: 80 | def __init__(self, model, diffusion): 81 | from scipy import interpolate 82 | 83 | self.model = model 84 | self.diffusion = diffusion 85 | self.alpha_cumprod_to_t = interpolate.interp1d( 86 | diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps) 87 | ) 88 | 89 | def sigma_to_t(self, sigma): 90 | alpha_cumprod = 1.0 / (sigma**2 + 1) 91 | if alpha_cumprod > self.diffusion.alphas_cumprod[0]: 92 | return 0 93 | elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]: 94 | return self.diffusion.num_timesteps - 1 95 | else: 96 | return float(self.alpha_cumprod_to_t(alpha_cumprod)) 97 | 98 | def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None): 99 | t = th.tensor( 100 | [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()], 101 | dtype=th.long, 102 | device=sigmas.device, 103 | ) 104 | c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim) 105 | out = self.diffusion.p_mean_variance( 106 | self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 107 | ) 108 | return None, out["pred_xstart"] 109 | 110 | 111 | def karras_sample(*args, **kwargs): 112 | last = None 113 | for x in karras_sample_progressive(*args, **kwargs): 114 | last = x["x"] 115 | return last 116 | 117 | 118 | def karras_sample_progressive( 119 | diffusion, 120 | model, 121 | shape, 122 | steps, 123 | clip_denoised=True, 124 | progress=False, 125 | model_kwargs=None, 126 | device=None, 127 | sigma_min=0.002, 128 | sigma_max=80, # higher for highres? 
129 | rho=7.0, 130 | sampler="heun", 131 | s_churn=0.0, 132 | s_tmin=0.0, 133 | s_tmax=float("inf"), 134 | s_noise=1.0, 135 | guidance_scale=0.0, 136 | ): 137 | sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device) 138 | x_T = th.randn(*shape, device=device) * sigma_max 139 | sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[ 140 | sampler 141 | ] 142 | 143 | if sampler != "ancestral": 144 | sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise) 145 | else: 146 | sampler_args = {} 147 | 148 | if isinstance(diffusion, KarrasDenoiser): 149 | 150 | def denoiser(x_t, sigma): 151 | _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs) 152 | if clip_denoised: 153 | denoised = denoised.clamp(-1, 1) 154 | return denoised 155 | 156 | elif isinstance(diffusion, GaussianDiffusion): 157 | model = GaussianToKarrasDenoiser(model, diffusion) 158 | 159 | def denoiser(x_t, sigma): 160 | _, denoised = model.denoise( 161 | x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs 162 | ) 163 | return denoised 164 | 165 | else: 166 | raise NotImplementedError 167 | 168 | if guidance_scale != 0 and guidance_scale != 1: 169 | 170 | def guided_denoiser(x_t, sigma): 171 | x_t = th.cat([x_t, x_t], dim=0) 172 | sigma = th.cat([sigma, sigma], dim=0) 173 | x_0 = denoiser(x_t, sigma) 174 | cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0) 175 | x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0) 176 | return x_0 177 | 178 | else: 179 | guided_denoiser = denoiser 180 | 181 | for obj in sample_fn( 182 | guided_denoiser, 183 | x_T, 184 | sigmas, 185 | progress=progress, 186 | **sampler_args, 187 | ): 188 | if isinstance(diffusion, GaussianDiffusion): 189 | yield diffusion.unscale_out_dict(obj) 190 | else: 191 | yield obj 192 | 193 | 194 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"): 195 | """Constructs the noise schedule of Karras et al. 
(2022).""" 196 | ramp = th.linspace(0, 1, n) 197 | min_inv_rho = sigma_min ** (1 / rho) 198 | max_inv_rho = sigma_max ** (1 / rho) 199 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 200 | return append_zero(sigmas).to(device) 201 | 202 | 203 | def to_d(x, sigma, denoised): 204 | """Converts a denoiser output to a Karras ODE derivative.""" 205 | return (x - denoised) / append_dims(sigma, x.ndim) 206 | 207 | 208 | def get_ancestral_step(sigma_from, sigma_to): 209 | """Calculates the noise level (sigma_down) to step down to and the amount 210 | of noise to add (sigma_up) when doing an ancestral sampling step.""" 211 | sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 212 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 213 | return sigma_down, sigma_up 214 | 215 | 216 | @th.no_grad() 217 | def sample_euler_ancestral(model, x, sigmas, progress=False): 218 | """Ancestral sampling with Euler method steps.""" 219 | s_in = x.new_ones([x.shape[0]]) 220 | indices = range(len(sigmas) - 1) 221 | if progress: 222 | from tqdm.auto import tqdm 223 | 224 | indices = tqdm(indices) 225 | 226 | for i in indices: 227 | denoised = model(x, sigmas[i] * s_in) 228 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) 229 | yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised} 230 | d = to_d(x, sigmas[i], denoised) 231 | # Euler method 232 | dt = sigma_down - sigmas[i] 233 | x = x + d * dt 234 | x = x + th.randn_like(x) * sigma_up 235 | yield {"x": x, "pred_xstart": x} 236 | 237 | 238 | @th.no_grad() 239 | def sample_heun( 240 | denoiser, 241 | x, 242 | sigmas, 243 | progress=False, 244 | s_churn=0.0, 245 | s_tmin=0.0, 246 | s_tmax=float("inf"), 247 | s_noise=1.0, 248 | ): 249 | """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" 250 | s_in = x.new_ones([x.shape[0]]) 251 | indices = range(len(sigmas) - 1) 252 | if progress: 253 | from tqdm.auto import tqdm 254 | 255 | indices = tqdm(indices) 256 | 257 | for i in indices: 258 | gamma = ( 259 | min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 260 | ) 261 | eps = th.randn_like(x) * s_noise 262 | sigma_hat = sigmas[i] * (gamma + 1) 263 | if gamma > 0: 264 | x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 265 | denoised = denoiser(x, sigma_hat * s_in) 266 | d = to_d(x, sigma_hat, denoised) 267 | yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised} 268 | dt = sigmas[i + 1] - sigma_hat 269 | if sigmas[i + 1] == 0: 270 | # Euler method 271 | x = x + d * dt 272 | else: 273 | # Heun's method 274 | x_2 = x + d * dt 275 | denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in) 276 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2) 277 | d_prime = (d + d_2) / 2 278 | x = x + d_prime * dt 279 | yield {"x": x, "pred_xstart": denoised} 280 | 281 | 282 | @th.no_grad() 283 | def sample_dpm( 284 | denoiser, 285 | x, 286 | sigmas, 287 | progress=False, 288 | s_churn=0.0, 289 | s_tmin=0.0, 290 | s_tmax=float("inf"), 291 | s_noise=1.0, 292 | ): 293 | """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. 
(2022).""" 294 | s_in = x.new_ones([x.shape[0]]) 295 | indices = range(len(sigmas) - 1) 296 | if progress: 297 | from tqdm.auto import tqdm 298 | 299 | indices = tqdm(indices) 300 | 301 | for i in indices: 302 | gamma = ( 303 | min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 304 | ) 305 | eps = th.randn_like(x) * s_noise 306 | sigma_hat = sigmas[i] * (gamma + 1) 307 | if gamma > 0: 308 | x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 309 | denoised = denoiser(x, sigma_hat * s_in) 310 | d = to_d(x, sigma_hat, denoised) 311 | yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised} 312 | # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule 313 | sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 314 | dt_1 = sigma_mid - sigma_hat 315 | dt_2 = sigmas[i + 1] - sigma_hat 316 | x_2 = x + d * dt_1 317 | denoised_2 = denoiser(x_2, sigma_mid * s_in) 318 | d_2 = to_d(x_2, sigma_mid, denoised_2) 319 | x = x + d_2 * dt_2 320 | yield {"x": x, "pred_xstart": denoised} 321 | 322 | 323 | def append_dims(x, target_dims): 324 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 325 | dims_to_append = target_dims - x.ndim 326 | if dims_to_append < 0: 327 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 328 | return x[(...,) + (None,) * dims_to_append] 329 | 330 | 331 | def append_zero(x): 332 | return th.cat([x, x.new_zeros([1])]) 333 | -------------------------------------------------------------------------------- /point_e/evals/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet_utils.py 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 benny 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 
25 | """ 26 | 27 | from time import time 28 | 29 | import numpy as np 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | 34 | 35 | def timeit(tag, t): 36 | print("{}: {}s".format(tag, time() - t)) 37 | return time() 38 | 39 | 40 | def pc_normalize(pc): 41 | l = pc.shape[0] 42 | centroid = np.mean(pc, axis=0) 43 | pc = pc - centroid 44 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 45 | pc = pc / m 46 | return pc 47 | 48 | 49 | def square_distance(src, dst): 50 | """ 51 | Calculate Euclid distance between each two points. 52 | 53 | src^T * dst = xn * xm + yn * ym + zn * zm; 54 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 55 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 56 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 57 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 58 | 59 | Input: 60 | src: source points, [B, N, C] 61 | dst: target points, [B, M, C] 62 | Output: 63 | dist: per-point square distance, [B, N, M] 64 | """ 65 | B, N, _ = src.shape 66 | _, M, _ = dst.shape 67 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 68 | dist += torch.sum(src**2, -1).view(B, N, 1) 69 | dist += torch.sum(dst**2, -1).view(B, 1, M) 70 | return dist 71 | 72 | 73 | def index_points(points, idx): 74 | """ 75 | 76 | Input: 77 | points: input points data, [B, N, C] 78 | idx: sample index data, [B, S] 79 | Return: 80 | new_points:, indexed points data, [B, S, C] 81 | """ 82 | device = points.device 83 | B = points.shape[0] 84 | view_shape = list(idx.shape) 85 | view_shape[1:] = [1] * (len(view_shape) - 1) 86 | repeat_shape = list(idx.shape) 87 | repeat_shape[0] = 1 88 | batch_indices = ( 89 | torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 90 | ) 91 | new_points = points[batch_indices, idx, :] 92 | return new_points 93 | 94 | 95 | def farthest_point_sample(xyz, npoint, deterministic=False): 96 | """ 97 | Input: 98 | xyz: pointcloud data, [B, N, 3] 99 | npoint: number of samples 100 | Return: 101 | centroids: sampled pointcloud index, [B, npoint] 102 | """ 103 | device = xyz.device 104 | B, N, C = xyz.shape 105 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 106 | distance = torch.ones(B, N).to(device) * 1e10 107 | if deterministic: 108 | farthest = torch.arange(0, B, dtype=torch.long).to(device) 109 | else: 110 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 111 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 112 | for i in range(npoint): 113 | centroids[:, i] = farthest 114 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 115 | dist = torch.sum((xyz - centroid) ** 2, -1) 116 | mask = dist < distance 117 | distance[mask] = dist[mask] 118 | farthest = torch.max(distance, -1)[1] 119 | return centroids 120 | 121 | 122 | def query_ball_point(radius, nsample, xyz, new_xyz): 123 | """ 124 | Input: 125 | radius: local region radius 126 | nsample: max sample number in local region 127 | xyz: all points, [B, N, 3] 128 | new_xyz: query points, [B, S, 3] 129 | Return: 130 | group_idx: grouped points index, [B, S, nsample] 131 | """ 132 | device = xyz.device 133 | B, N, C = xyz.shape 134 | _, S, _ = new_xyz.shape 135 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 136 | sqrdists = square_distance(new_xyz, xyz) 137 | group_idx[sqrdists > radius**2] = N 138 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 139 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 140 | mask = group_idx == N 141 | group_idx[mask] = 
group_first[mask] 142 | return group_idx 143 | 144 | 145 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, deterministic=False): 146 | """ 147 | Input: 148 | npoint: 149 | radius: 150 | nsample: 151 | xyz: input points position data, [B, N, 3] 152 | points: input points data, [B, N, D] 153 | Return: 154 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 155 | new_points: sampled points data, [B, npoint, nsample, 3+D] 156 | """ 157 | B, N, C = xyz.shape 158 | S = npoint 159 | fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic) # [B, npoint, C] 160 | new_xyz = index_points(xyz, fps_idx) 161 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 162 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 163 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 164 | 165 | if points is not None: 166 | grouped_points = index_points(points, idx) 167 | new_points = torch.cat( 168 | [grouped_xyz_norm, grouped_points], dim=-1 169 | ) # [B, npoint, nsample, C+D] 170 | else: 171 | new_points = grouped_xyz_norm 172 | if returnfps: 173 | return new_xyz, new_points, grouped_xyz, fps_idx 174 | else: 175 | return new_xyz, new_points 176 | 177 | 178 | def sample_and_group_all(xyz, points): 179 | """ 180 | Input: 181 | xyz: input points position data, [B, N, 3] 182 | points: input points data, [B, N, D] 183 | Return: 184 | new_xyz: sampled points position data, [B, 1, 3] 185 | new_points: sampled points data, [B, 1, N, 3+D] 186 | """ 187 | device = xyz.device 188 | B, N, C = xyz.shape 189 | new_xyz = torch.zeros(B, 1, C).to(device) 190 | grouped_xyz = xyz.view(B, 1, N, C) 191 | if points is not None: 192 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 193 | else: 194 | new_points = grouped_xyz 195 | return new_xyz, new_points 196 | 197 | 198 | class PointNetSetAbstraction(nn.Module): 199 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 200 | super(PointNetSetAbstraction, self).__init__() 201 | self.npoint = npoint 202 | self.radius = radius 203 | self.nsample = nsample 204 | self.mlp_convs = nn.ModuleList() 205 | self.mlp_bns = nn.ModuleList() 206 | last_channel = in_channel 207 | for out_channel in mlp: 208 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 209 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 210 | last_channel = out_channel 211 | self.group_all = group_all 212 | 213 | def forward(self, xyz, points): 214 | """ 215 | Input: 216 | xyz: input points position data, [B, C, N] 217 | points: input points data, [B, D, N] 218 | Return: 219 | new_xyz: sampled points position data, [B, C, S] 220 | new_points_concat: sample points feature data, [B, D', S] 221 | """ 222 | xyz = xyz.permute(0, 2, 1) 223 | if points is not None: 224 | points = points.permute(0, 2, 1) 225 | 226 | if self.group_all: 227 | new_xyz, new_points = sample_and_group_all(xyz, points) 228 | else: 229 | new_xyz, new_points = sample_and_group( 230 | self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training 231 | ) 232 | # new_xyz: sampled points position data, [B, npoint, C] 233 | # new_points: sampled points data, [B, npoint, nsample, C+D] 234 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 235 | for i, conv in enumerate(self.mlp_convs): 236 | bn = self.mlp_bns[i] 237 | new_points = F.relu(bn(conv(new_points))) 238 | 239 | new_points = torch.max(new_points, 2)[0] 240 | new_xyz = new_xyz.permute(0, 2, 1) 241 | return new_xyz, 
new_points 242 | 243 | 244 | class PointNetSetAbstractionMsg(nn.Module): 245 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 246 | super(PointNetSetAbstractionMsg, self).__init__() 247 | self.npoint = npoint 248 | self.radius_list = radius_list 249 | self.nsample_list = nsample_list 250 | self.conv_blocks = nn.ModuleList() 251 | self.bn_blocks = nn.ModuleList() 252 | for i in range(len(mlp_list)): 253 | convs = nn.ModuleList() 254 | bns = nn.ModuleList() 255 | last_channel = in_channel + 3 256 | for out_channel in mlp_list[i]: 257 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 258 | bns.append(nn.BatchNorm2d(out_channel)) 259 | last_channel = out_channel 260 | self.conv_blocks.append(convs) 261 | self.bn_blocks.append(bns) 262 | 263 | def forward(self, xyz, points): 264 | """ 265 | Input: 266 | xyz: input points position data, [B, C, N] 267 | points: input points data, [B, D, N] 268 | Return: 269 | new_xyz: sampled points position data, [B, C, S] 270 | new_points_concat: sample points feature data, [B, D', S] 271 | """ 272 | xyz = xyz.permute(0, 2, 1) 273 | if points is not None: 274 | points = points.permute(0, 2, 1) 275 | 276 | B, N, C = xyz.shape 277 | S = self.npoint 278 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training)) 279 | new_points_list = [] 280 | for i, radius in enumerate(self.radius_list): 281 | K = self.nsample_list[i] 282 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 283 | grouped_xyz = index_points(xyz, group_idx) 284 | grouped_xyz -= new_xyz.view(B, S, 1, C) 285 | if points is not None: 286 | grouped_points = index_points(points, group_idx) 287 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 288 | else: 289 | grouped_points = grouped_xyz 290 | 291 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 292 | for j in range(len(self.conv_blocks[i])): 293 | conv = self.conv_blocks[i][j] 294 | bn = self.bn_blocks[i][j] 295 | grouped_points = F.relu(bn(conv(grouped_points))) 296 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 297 | new_points_list.append(new_points) 298 | 299 | new_xyz = new_xyz.permute(0, 2, 1) 300 | new_points_concat = torch.cat(new_points_list, dim=1) 301 | return new_xyz, new_points_concat 302 | 303 | 304 | class PointNetFeaturePropagation(nn.Module): 305 | def __init__(self, in_channel, mlp): 306 | super(PointNetFeaturePropagation, self).__init__() 307 | self.mlp_convs = nn.ModuleList() 308 | self.mlp_bns = nn.ModuleList() 309 | last_channel = in_channel 310 | for out_channel in mlp: 311 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 312 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 313 | last_channel = out_channel 314 | 315 | def forward(self, xyz1, xyz2, points1, points2): 316 | """ 317 | Input: 318 | xyz1: input points position data, [B, C, N] 319 | xyz2: sampled input points position data, [B, C, S] 320 | points1: input points data, [B, D, N] 321 | points2: input points data, [B, D, S] 322 | Return: 323 | new_points: upsampled points data, [B, D', N] 324 | """ 325 | xyz1 = xyz1.permute(0, 2, 1) 326 | xyz2 = xyz2.permute(0, 2, 1) 327 | 328 | points2 = points2.permute(0, 2, 1) 329 | B, N, C = xyz1.shape 330 | _, S, _ = xyz2.shape 331 | 332 | if S == 1: 333 | interpolated_points = points2.repeat(1, N, 1) 334 | else: 335 | dists = square_distance(xyz1, xyz2) 336 | dists, idx = dists.sort(dim=-1) 337 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 338 | 339 | dist_recip = 1.0 / (dists + 
1e-8) 340 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 341 | weight = dist_recip / norm 342 | interpolated_points = torch.sum( 343 | index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2 344 | ) 345 | 346 | if points1 is not None: 347 | points1 = points1.permute(0, 2, 1) 348 | new_points = torch.cat([points1, interpolated_points], dim=-1) 349 | else: 350 | new_points = interpolated_points 351 | 352 | new_points = new_points.permute(0, 2, 1) 353 | for i, conv in enumerate(self.mlp_convs): 354 | bn = self.mlp_bns[i] 355 | new_points = F.relu(bn(conv(new_points))) 356 | return new_points 357 | -------------------------------------------------------------------------------- /point_e/models/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/openai/openai/blob/55363aa496049423c37124b440e9e30366db3ed6/orc/orc/diffusion/vit.py 3 | """ 4 | 5 | 6 | import math 7 | from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .checkpoint import checkpoint 13 | from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType 14 | from .util import timestep_embedding 15 | 16 | 17 | def init_linear(l, stddev): 18 | nn.init.normal_(l.weight, std=stddev) 19 | if l.bias is not None: 20 | nn.init.constant_(l.bias, 0.0) 21 | 22 | 23 | class MultiheadAttention(nn.Module): 24 | def __init__( 25 | self, 26 | *, 27 | device: torch.device, 28 | dtype: torch.dtype, 29 | n_ctx: int, 30 | width: int, 31 | heads: int, 32 | init_scale: float, 33 | ): 34 | super().__init__() 35 | self.n_ctx = n_ctx 36 | self.width = width 37 | self.heads = heads 38 | self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype) 39 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 40 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) 41 | init_linear(self.c_qkv, init_scale) 42 | init_linear(self.c_proj, init_scale) 43 | 44 | def forward(self, x): 45 | x = self.c_qkv(x) 46 | x = checkpoint(self.attention, (x,), (), True) 47 | x = self.c_proj(x) 48 | return x 49 | 50 | 51 | class MLP(nn.Module): 52 | def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float): 53 | super().__init__() 54 | self.width = width 55 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 56 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 57 | self.gelu = nn.GELU() 58 | init_linear(self.c_fc, init_scale) 59 | init_linear(self.c_proj, init_scale) 60 | 61 | def forward(self, x): 62 | return self.c_proj(self.gelu(self.c_fc(x))) 63 | 64 | 65 | class QKVMultiheadAttention(nn.Module): 66 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): 67 | super().__init__() 68 | self.device = device 69 | self.dtype = dtype 70 | self.heads = heads 71 | self.n_ctx = n_ctx 72 | 73 | def forward(self, qkv): 74 | bs, n_ctx, width = qkv.shape 75 | attn_ch = width // self.heads // 3 76 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 77 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 78 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 79 | weight = torch.einsum( 80 | "bthc,bshc->bhts", q * scale, k * scale 81 | ) # More stable with f16 than dividing afterwards 82 | wdtype = weight.dtype 83 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 84 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 85 | 86 | 87 | 
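# A minimal shape-check sketch of how QKVMultiheadAttention consumes the fused
# projection produced by MultiheadAttention.c_qkv (a tensor of shape
# [batch, n_ctx, heads * head_dim * 3]). The sizes below are illustrative
# assumptions only. Note that q and k are scaled by 1/sqrt(sqrt(head_dim))
# before the einsum, which is more stable in float16 than dividing the
# attention logits afterwards.
def _example_qkv_attention_shapes():
    attn = QKVMultiheadAttention(
        device=torch.device("cpu"), dtype=torch.float32, heads=4, n_ctx=16
    )
    qkv = torch.randn(2, 16, 4 * 32 * 3)  # batch=2, n_ctx=16, head_dim=32
    out = attn(qkv)
    assert out.shape == (2, 16, 4 * 32)
    return out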
class ResidualAttentionBlock(nn.Module): 88 | def __init__( 89 | self, 90 | *, 91 | device: torch.device, 92 | dtype: torch.dtype, 93 | n_ctx: int, 94 | width: int, 95 | heads: int, 96 | init_scale: float = 1.0, 97 | ): 98 | super().__init__() 99 | 100 | self.attn = MultiheadAttention( 101 | device=device, 102 | dtype=dtype, 103 | n_ctx=n_ctx, 104 | width=width, 105 | heads=heads, 106 | init_scale=init_scale, 107 | ) 108 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 109 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 110 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 111 | 112 | def forward(self, x: torch.Tensor): 113 | x = x + self.attn(self.ln_1(x)) 114 | x = x + self.mlp(self.ln_2(x)) 115 | return x 116 | 117 | 118 | class Transformer(nn.Module): 119 | def __init__( 120 | self, 121 | *, 122 | device: torch.device, 123 | dtype: torch.dtype, 124 | n_ctx: int, 125 | width: int, 126 | layers: int, 127 | heads: int, 128 | init_scale: float = 0.25, 129 | ): 130 | super().__init__() 131 | self.n_ctx = n_ctx 132 | self.width = width 133 | self.layers = layers 134 | init_scale = init_scale * math.sqrt(1.0 / width) 135 | self.resblocks = nn.ModuleList( 136 | [ 137 | ResidualAttentionBlock( 138 | device=device, 139 | dtype=dtype, 140 | n_ctx=n_ctx, 141 | width=width, 142 | heads=heads, 143 | init_scale=init_scale, 144 | ) 145 | for _ in range(layers) 146 | ] 147 | ) 148 | 149 | def forward(self, x: torch.Tensor): 150 | for block in self.resblocks: 151 | x = block(x) 152 | return x 153 | 154 | 155 | class PointDiffusionTransformer(nn.Module): 156 | def __init__( 157 | self, 158 | *, 159 | device: torch.device, 160 | dtype: torch.dtype, 161 | input_channels: int = 3, 162 | output_channels: int = 3, 163 | n_ctx: int = 1024, 164 | width: int = 512, 165 | layers: int = 12, 166 | heads: int = 8, 167 | init_scale: float = 0.25, 168 | time_token_cond: bool = False, 169 | ): 170 | super().__init__() 171 | self.input_channels = input_channels 172 | self.output_channels = output_channels 173 | self.n_ctx = n_ctx 174 | self.time_token_cond = time_token_cond 175 | self.time_embed = MLP( 176 | device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width) 177 | ) 178 | self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) 179 | self.backbone = Transformer( 180 | device=device, 181 | dtype=dtype, 182 | n_ctx=n_ctx + int(time_token_cond), 183 | width=width, 184 | layers=layers, 185 | heads=heads, 186 | init_scale=init_scale, 187 | ) 188 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 189 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) 190 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) 191 | with torch.no_grad(): 192 | self.output_proj.weight.zero_() 193 | self.output_proj.bias.zero_() 194 | 195 | def forward(self, x: torch.Tensor, t: torch.Tensor): 196 | """ 197 | :param x: an [N x C x T] tensor. 198 | :param t: an [N] tensor. 199 | :return: an [N x C' x T] tensor. 
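
        Example, with sizes matching the default configuration (illustrative
        only)::

            model = PointDiffusionTransformer(device=torch.device("cpu"), dtype=torch.float32)
            x = torch.randn(2, 3, 1024)        # [N x C x T]
            t = torch.randint(0, 1000, (2,))   # [N]
            out = model(x, t)                  # [N x C' x T] == [2, 3, 1024]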
200 | """ 201 | assert x.shape[-1] == self.n_ctx 202 | t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) 203 | return self._forward_with_cond(x, [(t_embed, self.time_token_cond)]) 204 | 205 | def _forward_with_cond( 206 | self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]] 207 | ) -> torch.Tensor: 208 | h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC 209 | for emb, as_token in cond_as_token: 210 | if not as_token: 211 | h = h + emb[:, None] 212 | extra_tokens = [ 213 | (emb[:, None] if len(emb.shape) == 2 else emb) 214 | for emb, as_token in cond_as_token 215 | if as_token 216 | ] 217 | if len(extra_tokens): 218 | h = torch.cat(extra_tokens + [h], dim=1) 219 | 220 | h = self.ln_pre(h) 221 | h = self.backbone(h) 222 | h = self.ln_post(h) 223 | if len(extra_tokens): 224 | h = h[:, sum(h.shape[1] for h in extra_tokens) :] 225 | h = self.output_proj(h) 226 | return h.permute(0, 2, 1) 227 | 228 | 229 | class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer): 230 | def __init__( 231 | self, 232 | *, 233 | device: torch.device, 234 | dtype: torch.dtype, 235 | n_ctx: int = 1024, 236 | token_cond: bool = False, 237 | cond_drop_prob: float = 0.0, 238 | frozen_clip: bool = True, 239 | cache_dir: Optional[str] = None, 240 | **kwargs, 241 | ): 242 | super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), **kwargs) 243 | self.n_ctx = n_ctx 244 | self.token_cond = token_cond 245 | self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device, cache_dir=cache_dir) 246 | self.clip_embed = nn.Linear( 247 | self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype 248 | ) 249 | self.cond_drop_prob = cond_drop_prob 250 | 251 | def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: 252 | with torch.no_grad(): 253 | return dict(embeddings=self.clip(batch_size, **model_kwargs)) 254 | 255 | def forward( 256 | self, 257 | x: torch.Tensor, 258 | t: torch.Tensor, 259 | images: Optional[Iterable[Optional[ImageType]]] = None, 260 | texts: Optional[Iterable[Optional[str]]] = None, 261 | embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, 262 | ): 263 | """ 264 | :param x: an [N x C x T] tensor. 265 | :param t: an [N] tensor. 266 | :param images: a batch of images to condition on. 267 | :param texts: a batch of texts to condition on. 268 | :param embeddings: a batch of CLIP embeddings to condition on. 269 | :return: an [N x C' x T] tensor. 
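
        Typically one style of conditioning is supplied per call: ``texts``
        (one string per batch element), ``images`` (one image per element),
        or ``embeddings`` (precomputed CLIP features, e.g. from
        ``cached_model_kwargs``); see ``pretrained_clip`` for how ``None``
        entries are handled. Illustrative sketch (constructing the model
        loads a CLIP checkpoint)::

            model = CLIPImagePointDiffusionTransformer(device=torch.device("cpu"), dtype=torch.float32)
            x = torch.randn(2, 3, 1024)
            t = torch.randint(0, 1000, (2,))
            out = model(x, t, texts=["a red cube", "a corgi"])  # [2, 3, 1024]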
270 | """ 271 | assert x.shape[-1] == self.n_ctx 272 | 273 | t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) 274 | clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings) 275 | assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0] 276 | 277 | if self.training: 278 | mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob 279 | clip_out = clip_out * mask[:, None].to(clip_out) 280 | 281 | # Rescale the features to have unit variance 282 | clip_out = math.sqrt(clip_out.shape[1]) * clip_out 283 | 284 | clip_embed = self.clip_embed(clip_out) 285 | 286 | cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)] 287 | return self._forward_with_cond(x, cond) 288 | 289 | 290 | class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer): 291 | def __init__( 292 | self, 293 | *, 294 | device: torch.device, 295 | dtype: torch.dtype, 296 | n_ctx: int = 1024, 297 | cond_drop_prob: float = 0.0, 298 | frozen_clip: bool = True, 299 | cache_dir: Optional[str] = None, 300 | **kwargs, 301 | ): 302 | clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)( 303 | device, 304 | cache_dir=cache_dir, 305 | ) 306 | super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs) 307 | self.n_ctx = n_ctx 308 | self.clip = clip 309 | self.clip_embed = nn.Sequential( 310 | nn.LayerNorm( 311 | normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype 312 | ), 313 | nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype), 314 | ) 315 | self.cond_drop_prob = cond_drop_prob 316 | 317 | def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: 318 | _ = batch_size 319 | with torch.no_grad(): 320 | return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"])) 321 | 322 | def forward( 323 | self, 324 | x: torch.Tensor, 325 | t: torch.Tensor, 326 | images: Optional[Iterable[ImageType]] = None, 327 | embeddings: Optional[Iterable[torch.Tensor]] = None, 328 | ): 329 | """ 330 | :param x: an [N x C x T] tensor. 331 | :param t: an [N] tensor. 332 | :param images: a batch of images to condition on. 333 | :param embeddings: a batch of CLIP latent grids to condition on. 334 | :return: an [N x C' x T] tensor. 
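
        Exactly one of ``images`` or ``embeddings`` must be given (the
        assertions below enforce this). ``embeddings`` are CLIP latent grids
        of shape [N x grid_feature_dim x grid_size**2], as produced by
        ``self.clip.embed_images_grid`` and cached by ``cached_model_kwargs``.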
335 | """ 336 | assert images is not None or embeddings is not None, "must specify images or embeddings" 337 | assert images is None or embeddings is None, "cannot specify both images and embeddings" 338 | assert x.shape[-1] == self.n_ctx 339 | 340 | t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) 341 | 342 | if images is not None: 343 | clip_out = self.clip.embed_images_grid(images) 344 | else: 345 | clip_out = embeddings 346 | 347 | if self.training: 348 | mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob 349 | clip_out = clip_out * mask[:, None, None].to(clip_out) 350 | 351 | clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC 352 | clip_embed = self.clip_embed(clip_out) 353 | 354 | cond = [(t_embed, self.time_token_cond), (clip_embed, True)] 355 | return self._forward_with_cond(x, cond) 356 | 357 | 358 | class UpsamplePointDiffusionTransformer(PointDiffusionTransformer): 359 | def __init__( 360 | self, 361 | *, 362 | device: torch.device, 363 | dtype: torch.dtype, 364 | cond_input_channels: Optional[int] = None, 365 | cond_ctx: int = 1024, 366 | n_ctx: int = 4096 - 1024, 367 | channel_scales: Optional[Sequence[float]] = None, 368 | channel_biases: Optional[Sequence[float]] = None, 369 | **kwargs, 370 | ): 371 | super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs) 372 | self.n_ctx = n_ctx 373 | self.cond_input_channels = cond_input_channels or self.input_channels 374 | self.cond_point_proj = nn.Linear( 375 | self.cond_input_channels, self.backbone.width, device=device, dtype=dtype 376 | ) 377 | 378 | self.register_buffer( 379 | "channel_scales", 380 | torch.tensor(channel_scales, dtype=dtype, device=device) 381 | if channel_scales is not None 382 | else None, 383 | ) 384 | self.register_buffer( 385 | "channel_biases", 386 | torch.tensor(channel_biases, dtype=dtype, device=device) 387 | if channel_biases is not None 388 | else None, 389 | ) 390 | 391 | def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor): 392 | """ 393 | :param x: an [N x C1 x T] tensor. 394 | :param t: an [N] tensor. 395 | :param low_res: an [N x C2 x T'] tensor of conditioning points. 396 | :return: an [N x C3 x T] tensor. 
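
        Example (sizes are illustrative assumptions; with the defaults the
        model adds T = 4096 - 1024 points conditioned on T' = 1024
        low-resolution points)::

            model = UpsamplePointDiffusionTransformer(device=torch.device("cpu"), dtype=torch.float32)
            x = torch.randn(2, 3, 4096 - 1024)
            low_res = torch.randn(2, 3, 1024)
            out = model(x, torch.randint(0, 1000, (2,)), low_res=low_res)  # [2, 3, 3072]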
397 | """ 398 | assert x.shape[-1] == self.n_ctx 399 | t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) 400 | low_res_embed = self._embed_low_res(low_res) 401 | cond = [(t_embed, self.time_token_cond), (low_res_embed, True)] 402 | return self._forward_with_cond(x, cond) 403 | 404 | def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor: 405 | if self.channel_scales is not None: 406 | x = x * self.channel_scales[None, :, None] 407 | if self.channel_biases is not None: 408 | x = x + self.channel_biases[None, :, None] 409 | return self.cond_point_proj(x.permute(0, 2, 1)) 410 | 411 | 412 | class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer): 413 | def __init__( 414 | self, 415 | *, 416 | device: torch.device, 417 | dtype: torch.dtype, 418 | n_ctx: int = 4096 - 1024, 419 | cond_drop_prob: float = 0.0, 420 | frozen_clip: bool = True, 421 | cache_dir: Optional[str] = None, 422 | **kwargs, 423 | ): 424 | clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)( 425 | device, 426 | cache_dir=cache_dir, 427 | ) 428 | super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs) 429 | self.n_ctx = n_ctx 430 | 431 | self.clip = clip 432 | self.clip_embed = nn.Sequential( 433 | nn.LayerNorm( 434 | normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype 435 | ), 436 | nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype), 437 | ) 438 | self.cond_drop_prob = cond_drop_prob 439 | 440 | def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: 441 | if "images" not in model_kwargs: 442 | zero_emb = torch.zeros( 443 | [batch_size, self.clip.grid_feature_dim, self.clip.grid_size**2], 444 | device=next(self.parameters()).device, 445 | ) 446 | return dict(embeddings=zero_emb, low_res=model_kwargs["low_res"]) 447 | with torch.no_grad(): 448 | return dict( 449 | embeddings=self.clip.embed_images_grid(model_kwargs["images"]), 450 | low_res=model_kwargs["low_res"], 451 | ) 452 | 453 | def forward( 454 | self, 455 | x: torch.Tensor, 456 | t: torch.Tensor, 457 | *, 458 | low_res: torch.Tensor, 459 | images: Optional[Iterable[ImageType]] = None, 460 | embeddings: Optional[Iterable[torch.Tensor]] = None, 461 | ): 462 | """ 463 | :param x: an [N x C1 x T] tensor. 464 | :param t: an [N] tensor. 465 | :param low_res: an [N x C2 x T'] tensor of conditioning points. 466 | :param images: a batch of images to condition on. 467 | :param embeddings: a batch of CLIP latent grids to condition on. 468 | :return: an [N x C3 x T] tensor. 469 | """ 470 | assert x.shape[-1] == self.n_ctx 471 | t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) 472 | low_res_embed = self._embed_low_res(low_res) 473 | 474 | if images is not None: 475 | clip_out = self.clip.embed_images_grid(images) 476 | elif embeddings is not None: 477 | clip_out = embeddings 478 | else: 479 | # Support unconditional generation. 
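            # A zeroed CLIP grid mirrors the cond_drop_prob masking applied
            # during training, so this branch doubles as the unconditional
            # input for classifier-free guidance.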
480 | clip_out = torch.zeros( 481 | [len(x), self.clip.grid_feature_dim, self.clip.grid_size**2], 482 | dtype=x.dtype, 483 | device=x.device, 484 | ) 485 | 486 | if self.training: 487 | mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob 488 | clip_out = clip_out * mask[:, None, None].to(clip_out) 489 | 490 | clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC 491 | clip_embed = self.clip_embed(clip_out) 492 | 493 | cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)] 494 | return self._forward_with_cond(x, cond) 495 | -------------------------------------------------------------------------------- /point_e/evals/scripts/blender_script.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run within Blender to render a 3D model as RGBAD images. 3 | 4 | Example usage 5 | 6 | blender -b -P blender_script.py -- \ 7 | --input_path ../../examples/example_data/corgi.ply \ 8 | --output_path render_out 9 | 10 | Pass `--camera_pose z-circular-elevated` for the rendering used to compute 11 | CLIP R-Precision results. 12 | 13 | The output directory will include metadata json files for each rendered view, 14 | as well as a global metadata file for the render. Each image will be saved as 15 | a collection of 16-bit PNG files for each channel (rgbad), as well as a full 16 | grayscale render of the view. 17 | """ 18 | 19 | import argparse 20 | import json 21 | import math 22 | import os 23 | import random 24 | import sys 25 | 26 | import bpy 27 | from mathutils import Vector 28 | from mathutils.noise import random_unit_vector 29 | 30 | MAX_DEPTH = 5.0 31 | FORMAT_VERSION = 6 32 | UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093] 33 | 34 | 35 | def clear_scene(): 36 | bpy.ops.object.select_all(action="SELECT") 37 | bpy.ops.object.delete() 38 | 39 | 40 | def clear_lights(): 41 | bpy.ops.object.select_all(action="DESELECT") 42 | for obj in bpy.context.scene.objects.values(): 43 | if isinstance(obj.data, bpy.types.Light): 44 | obj.select_set(True) 45 | bpy.ops.object.delete() 46 | 47 | 48 | def import_model(path): 49 | clear_scene() 50 | _, ext = os.path.splitext(path) 51 | ext = ext.lower() 52 | if ext == ".obj": 53 | bpy.ops.import_scene.obj(filepath=path) 54 | elif ext in [".glb", ".gltf"]: 55 | bpy.ops.import_scene.gltf(filepath=path) 56 | elif ext == ".stl": 57 | bpy.ops.import_mesh.stl(filepath=path) 58 | elif ext == ".fbx": 59 | bpy.ops.import_scene.fbx(filepath=path) 60 | elif ext == ".dae": 61 | bpy.ops.wm.collada_import(filepath=path) 62 | elif ext == ".ply": 63 | bpy.ops.import_mesh.ply(filepath=path) 64 | else: 65 | raise RuntimeError(f"unexpected extension: {ext}") 66 | 67 | 68 | def scene_root_objects(): 69 | for obj in bpy.context.scene.objects.values(): 70 | if not obj.parent: 71 | yield obj 72 | 73 | 74 | def scene_bbox(single_obj=None, ignore_matrix=False): 75 | bbox_min = (math.inf,) * 3 76 | bbox_max = (-math.inf,) * 3 77 | found = False 78 | for obj in scene_meshes() if single_obj is None else [single_obj]: 79 | found = True 80 | for coord in obj.bound_box: 81 | coord = Vector(coord) 82 | if not ignore_matrix: 83 | coord = obj.matrix_world @ coord 84 | bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) 85 | bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) 86 | if not found: 87 | raise RuntimeError("no objects in scene to compute bounding box for") 88 | return Vector(bbox_min), Vector(bbox_max) 89 | 90 | 91 | def scene_meshes(): 92 | for obj in 
bpy.context.scene.objects.values(): 93 | if isinstance(obj.data, (bpy.types.Mesh)): 94 | yield obj 95 | 96 | 97 | def normalize_scene(): 98 | bbox_min, bbox_max = scene_bbox() 99 | scale = 1 / max(bbox_max - bbox_min) 100 | 101 | for obj in scene_root_objects(): 102 | obj.scale = obj.scale * scale 103 | 104 | # Apply scale to matrix_world. 105 | bpy.context.view_layer.update() 106 | 107 | bbox_min, bbox_max = scene_bbox() 108 | offset = -(bbox_min + bbox_max) / 2 109 | for obj in scene_root_objects(): 110 | obj.matrix_world.translation += offset 111 | 112 | bpy.ops.object.select_all(action="DESELECT") 113 | 114 | 115 | def create_camera(): 116 | # https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/ 117 | camera_data = bpy.data.cameras.new(name="Camera") 118 | camera_object = bpy.data.objects.new("Camera", camera_data) 119 | bpy.context.scene.collection.objects.link(camera_object) 120 | bpy.context.scene.camera = camera_object 121 | 122 | 123 | def set_camera(direction, camera_dist=2.0): 124 | camera_pos = -camera_dist * direction 125 | bpy.context.scene.camera.location = camera_pos 126 | 127 | # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically 128 | rot_quat = direction.to_track_quat("-Z", "Y") 129 | bpy.context.scene.camera.rotation_euler = rot_quat.to_euler() 130 | 131 | bpy.context.view_layer.update() 132 | 133 | 134 | def randomize_camera(camera_dist=2.0): 135 | direction = random_unit_vector() 136 | set_camera(direction, camera_dist=camera_dist) 137 | 138 | 139 | def pan_camera(time, axis="Z", camera_dist=2.0, elevation=-0.1): 140 | angle = time * math.pi * 2 141 | direction = [-math.cos(angle), -math.sin(angle), -elevation] 142 | assert axis in ["X", "Y", "Z"] 143 | if axis == "X": 144 | direction = [direction[2], *direction[:2]] 145 | elif axis == "Y": 146 | direction = [direction[0], -elevation, direction[1]] 147 | direction = Vector(direction).normalized() 148 | set_camera(direction, camera_dist=camera_dist) 149 | 150 | 151 | def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, camera_dist_max=2.0): 152 | camera_dist = random.uniform(camera_dist_min, camera_dist_max) 153 | if camera_pose_mode == "random": 154 | randomize_camera(camera_dist=camera_dist) 155 | elif camera_pose_mode == "z-circular": 156 | pan_camera(time, axis="Z", camera_dist=camera_dist) 157 | elif camera_pose_mode == "z-circular-elevated": 158 | pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=0.2617993878) 159 | else: 160 | raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}") 161 | 162 | 163 | def create_light(location, energy=1.0, angle=0.5 * math.pi / 180): 164 | # https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92 165 | light_data = bpy.data.lights.new(name="Light", type="SUN") 166 | light_data.energy = energy 167 | light_data.angle = angle 168 | light_object = bpy.data.objects.new(name="Light", object_data=light_data) 169 | 170 | direction = -location 171 | rot_quat = direction.to_track_quat("-Z", "Y") 172 | light_object.rotation_euler = rot_quat.to_euler() 173 | bpy.context.view_layer.update() 174 | 175 | bpy.context.collection.objects.link(light_object) 176 | light_object.location = location 177 | 178 | 179 | def create_random_lights(count=4, distance=2.0, energy=1.5): 180 | clear_lights() 181 | for _ in range(count): 182 | create_light(random_unit_vector() * distance, energy=energy) 183 | 184 | 185 | def 
create_camera_light(): 186 | clear_lights() 187 | create_light(bpy.context.scene.camera.location, energy=5.0) 188 | 189 | 190 | def create_uniform_light(backend): 191 | clear_lights() 192 | # Random direction to decorrelate axis-aligned sides. 193 | pos = Vector(UNIFORM_LIGHT_DIRECTION) 194 | angle = 0.0092 if backend == "CYCLES" else math.pi 195 | create_light(pos, energy=5.0, angle=angle) 196 | create_light(-pos, energy=5.0, angle=angle) 197 | 198 | 199 | def create_vertex_color_shaders(): 200 | # By default, Blender will ignore vertex colors in both the 201 | # Eevee and Cycles backends, since these colors aren't 202 | # associated with a material. 203 | # 204 | # What we do here is create a simple material shader and link 205 | # the vertex color to the material color. 206 | for obj in bpy.context.scene.objects.values(): 207 | if not isinstance(obj.data, (bpy.types.Mesh)): 208 | continue 209 | 210 | if len(obj.data.materials): 211 | # We don't want to override any existing materials. 212 | continue 213 | 214 | color_keys = (obj.data.vertex_colors or {}).keys() 215 | if not len(color_keys): 216 | # Many objects will have no materials *or* vertex colors. 217 | continue 218 | 219 | mat = bpy.data.materials.new(name="VertexColored") 220 | mat.use_nodes = True 221 | 222 | # There should be a Principled BSDF by default. 223 | bsdf_node = None 224 | for node in mat.node_tree.nodes: 225 | if node.type == "BSDF_PRINCIPLED": 226 | bsdf_node = node 227 | assert bsdf_node is not None, "material has no Principled BSDF node to modify" 228 | 229 | socket_map = {} 230 | for input in bsdf_node.inputs: 231 | socket_map[input.name] = input 232 | 233 | # Make sure nothing lights the object except for the diffuse color. 234 | socket_map["Specular"].default_value = 0.0 235 | socket_map["Roughness"].default_value = 1.0 236 | 237 | v_color = mat.node_tree.nodes.new("ShaderNodeVertexColor") 238 | v_color.layer_name = color_keys[0] 239 | 240 | mat.node_tree.links.new(v_color.outputs[0], socket_map["Base Color"]) 241 | 242 | obj.data.materials.append(mat) 243 | 244 | 245 | def create_default_materials(): 246 | for obj in bpy.context.scene.objects.values(): 247 | if isinstance(obj.data, (bpy.types.Mesh)): 248 | if not len(obj.data.materials): 249 | mat = bpy.data.materials.new(name="DefaultMaterial") 250 | mat.use_nodes = True 251 | obj.data.materials.append(mat) 252 | 253 | 254 | def find_materials(): 255 | all_materials = set() 256 | for obj in bpy.context.scene.objects.values(): 257 | if not isinstance(obj.data, (bpy.types.Mesh)): 258 | continue 259 | for mat in obj.data.materials: 260 | all_materials.add(mat) 261 | return all_materials 262 | 263 | 264 | def get_socket_value(tree, socket): 265 | default = socket.default_value 266 | if not isinstance(default, float): 267 | default = list(default) 268 | for link in tree.links: 269 | if link.to_socket == socket: 270 | return (link.from_socket, default) 271 | return (None, default) 272 | 273 | 274 | def clear_socket_input(tree, socket): 275 | for link in list(tree.links): 276 | if link.to_socket == socket: 277 | tree.links.remove(link) 278 | 279 | 280 | def set_socket_value(tree, socket, socket_and_default): 281 | clear_socket_input(tree, socket) 282 | old_source_socket, default = socket_and_default 283 | if isinstance(default, float) and not isinstance(socket.default_value, float): 284 | # Codepath for setting Emission to a previous alpha value. 
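        # The saved default is a scalar alpha while this socket expects a
        # color, so broadcast it to an opaque RGBA value.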
285 | socket.default_value = [default] * 3 + [1.0] 286 | else: 287 | socket.default_value = default 288 | if old_source_socket is not None: 289 | tree.links.new(old_source_socket, socket) 290 | 291 | 292 | def setup_nodes(output_path, capturing_material_alpha: bool = False): 293 | tree = bpy.context.scene.node_tree 294 | links = tree.links 295 | 296 | for node in tree.nodes: 297 | tree.nodes.remove(node) 298 | 299 | # Helpers to perform math on links and constants. 300 | def node_op(op: str, *args, clamp=False): 301 | node = tree.nodes.new(type="CompositorNodeMath") 302 | node.operation = op 303 | if clamp: 304 | node.use_clamp = True 305 | for i, arg in enumerate(args): 306 | if isinstance(arg, (int, float)): 307 | node.inputs[i].default_value = arg 308 | else: 309 | links.new(arg, node.inputs[i]) 310 | return node.outputs[0] 311 | 312 | def node_clamp(x, maximum=1.0): 313 | return node_op("MINIMUM", x, maximum) 314 | 315 | def node_mul(x, y, **kwargs): 316 | return node_op("MULTIPLY", x, y, **kwargs) 317 | 318 | input_node = tree.nodes.new(type="CompositorNodeRLayers") 319 | input_node.scene = bpy.context.scene 320 | 321 | input_sockets = {} 322 | for output in input_node.outputs: 323 | input_sockets[output.name] = output 324 | 325 | if capturing_material_alpha: 326 | color_socket = input_sockets["Image"] 327 | else: 328 | raw_color_socket = input_sockets["Image"] 329 | 330 | # We apply sRGB here so that our fixed-point depth map and material 331 | # alpha values are not sRGB, and so that we perform ambient+diffuse 332 | # lighting in linear RGB space. 333 | color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace") 334 | color_node.from_color_space = "Linear" 335 | color_node.to_color_space = "sRGB" 336 | tree.links.new(raw_color_socket, color_node.inputs[0]) 337 | color_socket = color_node.outputs[0] 338 | split_node = tree.nodes.new(type="CompositorNodeSepRGBA") 339 | tree.links.new(color_socket, split_node.inputs[0]) 340 | # Create separate file output nodes for every channel we care about. 341 | # The process calling this script must decide how to recombine these 342 | # channels, possibly into a single image. 343 | for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]: 344 | output_node = tree.nodes.new(type="CompositorNodeOutputFile") 345 | output_node.base_path = f"{output_path}_{channel}" 346 | links.new(split_node.outputs[i], output_node.inputs[0]) 347 | 348 | if capturing_material_alpha: 349 | # No need to re-write depth here. 350 | return 351 | 352 | depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH)) 353 | output_node = tree.nodes.new(type="CompositorNodeOutputFile") 354 | output_node.base_path = f"{output_path}_depth" 355 | links.new(depth_out, output_node.inputs[0]) 356 | 357 | 358 | def render_scene(output_path, fast_mode: bool): 359 | use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH" 360 | if use_workbench: 361 | # We must use a different engine to compute depth maps. 362 | bpy.context.scene.render.engine = "BLENDER_EEVEE" 363 | bpy.context.scene.eevee.taa_render_samples = 1 # faster, since we discard image. 
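    # The sample counts below trade quality for speed: in fast mode Eevee uses
    # a single TAA sample and Cycles is capped at 256 samples; otherwise
    # Cycles gets a 40 s per-frame time limit so renders cannot hang.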
364 | if fast_mode: 365 | if bpy.context.scene.render.engine == "BLENDER_EEVEE": 366 | bpy.context.scene.eevee.taa_render_samples = 1 367 | elif bpy.context.scene.render.engine == "CYCLES": 368 | bpy.context.scene.cycles.samples = 256 369 | else: 370 | if bpy.context.scene.render.engine == "CYCLES": 371 | # We should still impose a per-frame time limit 372 | # so that we don't timeout completely. 373 | bpy.context.scene.cycles.time_limit = 40 374 | bpy.context.view_layer.update() 375 | bpy.context.scene.use_nodes = True 376 | bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True 377 | bpy.context.scene.view_settings.view_transform = "Raw" # sRGB done in graph nodes 378 | bpy.context.scene.render.film_transparent = True 379 | bpy.context.scene.render.resolution_x = 512 380 | bpy.context.scene.render.resolution_y = 512 381 | bpy.context.scene.render.image_settings.file_format = "PNG" 382 | bpy.context.scene.render.image_settings.color_mode = "BW" 383 | bpy.context.scene.render.image_settings.color_depth = "16" 384 | bpy.context.scene.render.filepath = output_path 385 | setup_nodes(output_path) 386 | bpy.ops.render.render(write_still=True) 387 | 388 | # The output images must be moved from their own sub-directories, or 389 | # discarded if we are using workbench for the color. 390 | for channel_name in ["r", "g", "b", "a", "depth"]: 391 | sub_dir = f"{output_path}_{channel_name}" 392 | image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0]) 393 | name, ext = os.path.splitext(output_path) 394 | if channel_name == "depth" or not use_workbench: 395 | os.rename(image_path, f"{name}_{channel_name}{ext}") 396 | else: 397 | os.remove(image_path) 398 | os.removedirs(sub_dir) 399 | 400 | if use_workbench: 401 | # Re-render RGBA using workbench with texture mode, since this seems 402 | # to show the most reasonable colors when lighting is broken. 403 | bpy.context.scene.use_nodes = False 404 | bpy.context.scene.render.engine = "BLENDER_WORKBENCH" 405 | bpy.context.scene.render.image_settings.color_mode = "RGBA" 406 | bpy.context.scene.render.image_settings.color_depth = "8" 407 | bpy.context.scene.display.shading.color_type = "TEXTURE" 408 | bpy.context.scene.display.shading.light = "FLAT" 409 | if fast_mode: 410 | # Single pass anti-aliasing. 
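            # FXAA is a cheap screen-space filter, which is sufficient for
            # this flat-shaded workbench pass (it is only used to recover
            # colors).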
411 | bpy.context.scene.display.render_aa = "FXAA" 412 | os.remove(output_path) 413 | bpy.ops.render.render(write_still=True) 414 | bpy.context.scene.render.image_settings.color_mode = "BW" 415 | bpy.context.scene.render.image_settings.color_depth = "16" 416 | 417 | 418 | def scene_fov(): 419 | x_fov = bpy.context.scene.camera.data.angle_x 420 | y_fov = bpy.context.scene.camera.data.angle_y 421 | width = bpy.context.scene.render.resolution_x 422 | height = bpy.context.scene.render.resolution_y 423 | if bpy.context.scene.camera.data.angle == x_fov: 424 | y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width) 425 | else: 426 | x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height) 427 | return x_fov, y_fov 428 | 429 | 430 | def write_camera_metadata(path): 431 | x_fov, y_fov = scene_fov() 432 | bbox_min, bbox_max = scene_bbox() 433 | matrix = bpy.context.scene.camera.matrix_world 434 | with open(path, "w") as f: 435 | json.dump( 436 | dict( 437 | format_version=FORMAT_VERSION, 438 | max_depth=MAX_DEPTH, 439 | bbox=[list(bbox_min), list(bbox_max)], 440 | origin=list(matrix.col[3])[:3], 441 | x_fov=x_fov, 442 | y_fov=y_fov, 443 | x=list(matrix.col[0])[:3], 444 | y=list(-matrix.col[1])[:3], 445 | z=list(-matrix.col[2])[:3], 446 | ), 447 | f, 448 | ) 449 | 450 | 451 | def save_rendering_dataset( 452 | input_path: str, 453 | output_path: str, 454 | num_images: int, 455 | backend: str, 456 | light_mode: str, 457 | camera_pose: str, 458 | camera_dist_min: float, 459 | camera_dist_max: float, 460 | fast_mode: bool, 461 | ): 462 | assert light_mode in ["random", "uniform", "camera"] 463 | assert camera_pose in ["random", "z-circular", "z-circular-elevated"] 464 | 465 | import_model(input_path) 466 | bpy.context.scene.render.engine = backend 467 | normalize_scene() 468 | if light_mode == "random": 469 | create_random_lights() 470 | elif light_mode == "uniform": 471 | create_uniform_light(backend) 472 | create_camera() 473 | create_vertex_color_shaders() 474 | for i in range(num_images): 475 | t = i / max(num_images - 1, 1) # same as np.linspace(0, 1, num_images) 476 | place_camera( 477 | t, 478 | camera_pose_mode=camera_pose, 479 | camera_dist_min=camera_dist_min, 480 | camera_dist_max=camera_dist_max, 481 | ) 482 | if light_mode == "camera": 483 | create_camera_light() 484 | render_scene( 485 | os.path.join(output_path, f"{i:05}.png"), 486 | fast_mode=fast_mode, 487 | ) 488 | write_camera_metadata(os.path.join(output_path, f"{i:05}.json")) 489 | with open(os.path.join(output_path, "info.json"), "w") as f: 490 | info = dict( 491 | backend=backend, 492 | light_mode=light_mode, 493 | fast_mode=fast_mode, 494 | format_version=FORMAT_VERSION, 495 | channels=["R", "G", "B", "A", "D"], 496 | scale=0.5, # The scene is bounded by [-scale, scale]. 
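            # This matches normalize_scene(), which rescales and recenters the
            # model to fit inside [-0.5, 0.5] on every axis.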
497 | ) 498 | json.dump(info, f) 499 | 500 | 501 | def main(): 502 | try: 503 | dash_index = sys.argv.index("--") 504 | except ValueError as exc: 505 | raise ValueError("arguments must be preceded by '--'") from exc 506 | 507 | raw_args = sys.argv[dash_index + 1 :] 508 | parser = argparse.ArgumentParser() 509 | parser.add_argument("--input_path", required=True, type=str) 510 | parser.add_argument("--output_path", required=True, type=str) 511 | parser.add_argument("--num_images", type=int, default=20) 512 | parser.add_argument("--backend", type=str, default="BLENDER_EEVEE") 513 | parser.add_argument("--light_mode", type=str, default="uniform") 514 | parser.add_argument("--camera_pose", type=str, default="random") 515 | parser.add_argument("--camera_dist_min", type=float, default=2.0) 516 | parser.add_argument("--camera_dist_max", type=float, default=2.0) 517 | parser.add_argument("--fast_mode", action="store_true") 518 | args = parser.parse_args(raw_args) 519 | 520 | save_rendering_dataset( 521 | input_path=args.input_path, 522 | output_path=args.output_path, 523 | num_images=args.num_images, 524 | backend=args.backend, 525 | light_mode=args.light_mode, 526 | camera_pose=args.camera_pose, 527 | camera_dist_min=args.camera_dist_min, 528 | camera_dist_max=args.camera_dist_max, 529 | fast_mode=args.fast_mode, 530 | ) 531 | 532 | 533 | main() 534 | -------------------------------------------------------------------------------- /point_e/diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py 3 | """ 4 | 5 | import math 6 | from typing import Any, Dict, Iterable, Optional, Sequence, Union 7 | 8 | import numpy as np 9 | import torch as th 10 | 11 | 12 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 13 | """ 14 | This is the deprecated API for creating beta schedules. 15 | 16 | See get_named_beta_schedule() for the new library of schedules. 17 | """ 18 | if beta_schedule == "linear": 19 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 20 | else: 21 | raise NotImplementedError(beta_schedule) 22 | assert betas.shape == (num_diffusion_timesteps,) 23 | return betas 24 | 25 | 26 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 27 | """ 28 | Get a pre-defined beta schedule for the given name. 29 | 30 | The beta schedule library consists of beta schedules which remain similar 31 | in the limit of num_diffusion_timesteps. 32 | Beta schedules may be added, but should not be removed or changed once 33 | they are committed to maintain backwards compatibility. 34 | """ 35 | if schedule_name == "linear": 36 | # Linear schedule from Ho et al, extended to work for any number of 37 | # diffusion steps. 
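        # The original DDPM schedule used betas from 1e-4 to 2e-2 over 1000
        # steps; scaling both endpoints by 1000/T keeps the total amount of
        # noise comparable for other step counts. For example, T = 4000 gives
        # betas running linearly from 2.5e-5 to 5e-3.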
38 | scale = 1000 / num_diffusion_timesteps 39 | return get_beta_schedule( 40 | "linear", 41 | beta_start=scale * 0.0001, 42 | beta_end=scale * 0.02, 43 | num_diffusion_timesteps=num_diffusion_timesteps, 44 | ) 45 | elif schedule_name == "cosine": 46 | return betas_for_alpha_bar( 47 | num_diffusion_timesteps, 48 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 49 | ) 50 | else: 51 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 52 | 53 | 54 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 55 | """ 56 | Create a beta schedule that discretizes the given alpha_t_bar function, 57 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 58 | 59 | :param num_diffusion_timesteps: the number of betas to produce. 60 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 61 | produces the cumulative product of (1-beta) up to that 62 | part of the diffusion process. 63 | :param max_beta: the maximum beta to use; use values lower than 1 to 64 | prevent singularities. 65 | """ 66 | betas = [] 67 | for i in range(num_diffusion_timesteps): 68 | t1 = i / num_diffusion_timesteps 69 | t2 = (i + 1) / num_diffusion_timesteps 70 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 71 | return np.array(betas) 72 | 73 | 74 | def space_timesteps(num_timesteps, section_counts): 75 | """ 76 | Create a list of timesteps to use from an original diffusion process, 77 | given the number of timesteps we want to take from equally-sized portions 78 | of the original process. 79 | For example, if there's 300 timesteps and the section counts are [10,15,20] 80 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 81 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 82 | :param num_timesteps: the number of diffusion steps in the original 83 | process to divide up. 84 | :param section_counts: either a list of numbers, or a string containing 85 | comma-separated numbers, indicating the step count 86 | per section. As a special case, use "ddimN" where N 87 | is a number of steps to use the striding from the 88 | DDIM paper. 89 | :return: a set of diffusion steps from the original process to use. 
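
    For example, space_timesteps(300, [10, 15, 20]) returns 45 steps in
    total, and space_timesteps(1000, "ddim50") selects every 20th step,
    i.e. {0, 20, ..., 980}.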
90 | """ 91 | if isinstance(section_counts, str): 92 | if section_counts.startswith("ddim"): 93 | desired_count = int(section_counts[len("ddim") :]) 94 | for i in range(1, num_timesteps): 95 | if len(range(0, num_timesteps, i)) == desired_count: 96 | return set(range(0, num_timesteps, i)) 97 | raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") 98 | elif section_counts.startswith("exact"): 99 | res = set(int(x) for x in section_counts[len("exact") :].split(",")) 100 | for x in res: 101 | if x < 0 or x >= num_timesteps: 102 | raise ValueError(f"timestep out of bounds: {x}") 103 | return res 104 | section_counts = [int(x) for x in section_counts.split(",")] 105 | size_per = num_timesteps // len(section_counts) 106 | extra = num_timesteps % len(section_counts) 107 | start_idx = 0 108 | all_steps = [] 109 | for i, section_count in enumerate(section_counts): 110 | size = size_per + (1 if i < extra else 0) 111 | if size < section_count: 112 | raise ValueError(f"cannot divide section of {size} steps into {section_count}") 113 | if section_count <= 1: 114 | frac_stride = 1 115 | else: 116 | frac_stride = (size - 1) / (section_count - 1) 117 | cur_idx = 0.0 118 | taken_steps = [] 119 | for _ in range(section_count): 120 | taken_steps.append(start_idx + round(cur_idx)) 121 | cur_idx += frac_stride 122 | all_steps += taken_steps 123 | start_idx += size 124 | return set(all_steps) 125 | 126 | 127 | class GaussianDiffusion: 128 | """ 129 | Utilities for training and sampling diffusion models. 130 | 131 | Ported directly from here: 132 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 133 | 134 | :param betas: a 1-D array of betas for each diffusion timestep from T to 1. 135 | :param model_mean_type: a string determining what the model outputs. 136 | :param model_var_type: a string determining how variance is output. 137 | :param loss_type: a string determining the loss function to use. 138 | :param discretized_t0: if True, use discrete gaussian loss for t=0. Only 139 | makes sense for images. 140 | :param channel_scales: a multiplier to apply to x_start in training_losses 141 | and sampling functions. 142 | """ 143 | 144 | def __init__( 145 | self, 146 | *, 147 | betas: Sequence[float], 148 | model_mean_type: str, 149 | model_var_type: str, 150 | loss_type: str, 151 | discretized_t0: bool = False, 152 | channel_scales: Optional[np.ndarray] = None, 153 | channel_biases: Optional[np.ndarray] = None, 154 | ): 155 | self.model_mean_type = model_mean_type 156 | self.model_var_type = model_var_type 157 | self.loss_type = loss_type 158 | self.discretized_t0 = discretized_t0 159 | self.channel_scales = channel_scales 160 | self.channel_biases = channel_biases 161 | 162 | # Use float64 for accuracy. 
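        # float64 keeps the long cumulative products below (alphas_cumprod and
        # its square roots) accurate even when the chain has thousands of
        # timesteps.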
163 | betas = np.array(betas, dtype=np.float64) 164 | self.betas = betas 165 | assert len(betas.shape) == 1, "betas must be 1-D" 166 | assert (betas > 0).all() and (betas <= 1).all() 167 | 168 | self.num_timesteps = int(betas.shape[0]) 169 | 170 | alphas = 1.0 - betas 171 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 172 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 173 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 174 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 175 | 176 | # calculations for diffusion q(x_t | x_{t-1}) and others 177 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 178 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 179 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 180 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 181 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 182 | 183 | # calculations for posterior q(x_{t-1} | x_t, x_0) 184 | self.posterior_variance = ( 185 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 186 | ) 187 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 188 | self.posterior_log_variance_clipped = np.log( 189 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 190 | ) 191 | self.posterior_mean_coef1 = ( 192 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 193 | ) 194 | self.posterior_mean_coef2 = ( 195 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) 196 | ) 197 | 198 | def get_sigmas(self, t): 199 | return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape) 200 | 201 | def q_mean_variance(self, x_start, t): 202 | """ 203 | Get the distribution q(x_t | x_0). 204 | 205 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 206 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 207 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 208 | """ 209 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 210 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 211 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 212 | return mean, variance, log_variance 213 | 214 | def q_sample(self, x_start, t, noise=None): 215 | """ 216 | Diffuse the data for a given number of diffusion steps. 217 | 218 | In other words, sample from q(x_t | x_0). 219 | 220 | :param x_start: the initial data batch. 221 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 222 | :param noise: if specified, the split-out normal noise. 223 | :return: A noisy version of x_start. 
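
        In closed form, the return expression below computes

            x_t = sqrt(alphas_cumprod[t]) * x_start
                  + sqrt(1 - alphas_cumprod[t]) * noise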
224 | """ 225 | if noise is None: 226 | noise = th.randn_like(x_start) 227 | assert noise.shape == x_start.shape 228 | return ( 229 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 230 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 231 | ) 232 | 233 | def q_posterior_mean_variance(self, x_start, x_t, t): 234 | """ 235 | Compute the mean and variance of the diffusion posterior: 236 | 237 | q(x_{t-1} | x_t, x_0) 238 | 239 | """ 240 | assert x_start.shape == x_t.shape 241 | posterior_mean = ( 242 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 243 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 244 | ) 245 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 246 | posterior_log_variance_clipped = _extract_into_tensor( 247 | self.posterior_log_variance_clipped, t, x_t.shape 248 | ) 249 | assert ( 250 | posterior_mean.shape[0] 251 | == posterior_variance.shape[0] 252 | == posterior_log_variance_clipped.shape[0] 253 | == x_start.shape[0] 254 | ) 255 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 256 | 257 | def p_mean_variance( 258 | self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None 259 | ): 260 | """ 261 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 262 | the initial x, x_0. 263 | 264 | :param model: the model, which takes a signal and a batch of timesteps 265 | as input. 266 | :param x: the [N x C x ...] tensor at time t. 267 | :param t: a 1-D Tensor of timesteps. 268 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 269 | :param denoised_fn: if not None, a function which applies to the 270 | x_start prediction before it is used to sample. Applies before 271 | clip_denoised. 272 | :param model_kwargs: if not None, a dict of extra keyword arguments to 273 | pass to the model. This can be used for conditioning. 274 | :return: a dict with the following keys: 275 | - 'mean': the model mean output. 276 | - 'variance': the model variance output. 277 | - 'log_variance': the log of 'variance'. 278 | - 'pred_xstart': the prediction for x_0. 279 | """ 280 | if model_kwargs is None: 281 | model_kwargs = {} 282 | 283 | B, C = x.shape[:2] 284 | assert t.shape == (B,) 285 | model_output = model(x, t, **model_kwargs) 286 | if isinstance(model_output, tuple): 287 | model_output, extra = model_output 288 | else: 289 | extra = None 290 | 291 | if self.model_var_type in ["learned", "learned_range"]: 292 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 293 | model_output, model_var_values = th.split(model_output, C, dim=1) 294 | if self.model_var_type == "learned": 295 | model_log_variance = model_var_values 296 | model_variance = th.exp(model_log_variance) 297 | else: 298 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 299 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 300 | # The model_var_values is [-1, 1] for [min_var, max_var]. 301 | frac = (model_var_values + 1) / 2 302 | model_log_variance = frac * max_log + (1 - frac) * min_log 303 | model_variance = th.exp(model_log_variance) 304 | else: 305 | model_variance, model_log_variance = { 306 | # for fixedlarge, we set the initial (log-)variance like so 307 | # to get a better decoder log likelihood. 
308 | "fixed_large": ( 309 | np.append(self.posterior_variance[1], self.betas[1:]), 310 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 311 | ), 312 | "fixed_small": ( 313 | self.posterior_variance, 314 | self.posterior_log_variance_clipped, 315 | ), 316 | }[self.model_var_type] 317 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 318 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 319 | 320 | def process_xstart(x): 321 | if denoised_fn is not None: 322 | x = denoised_fn(x) 323 | if clip_denoised: 324 | return x.clamp(-1, 1) 325 | return x 326 | 327 | if self.model_mean_type == "x_prev": 328 | pred_xstart = process_xstart( 329 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 330 | ) 331 | model_mean = model_output 332 | elif self.model_mean_type in ["x_start", "epsilon"]: 333 | if self.model_mean_type == "x_start": 334 | pred_xstart = process_xstart(model_output) 335 | else: 336 | pred_xstart = process_xstart( 337 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 338 | ) 339 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 340 | else: 341 | raise NotImplementedError(self.model_mean_type) 342 | 343 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 344 | return { 345 | "mean": model_mean, 346 | "variance": model_variance, 347 | "log_variance": model_log_variance, 348 | "pred_xstart": pred_xstart, 349 | "extra": extra, 350 | } 351 | 352 | def _predict_xstart_from_eps(self, x_t, t, eps): 353 | assert x_t.shape == eps.shape 354 | return ( 355 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 356 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 357 | ) 358 | 359 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 360 | assert x_t.shape == xprev.shape 361 | return ( # (xprev - coef2*x_t) / coef1 362 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 363 | - _extract_into_tensor( 364 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 365 | ) 366 | * x_t 367 | ) 368 | 369 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 370 | return ( 371 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 372 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 373 | 374 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 375 | """ 376 | Compute the mean for the previous step, given a function cond_fn that 377 | computes the gradient of a conditional log probability with respect to 378 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 379 | condition on y. 380 | 381 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 382 | """ 383 | gradient = cond_fn(x, t, **model_kwargs) 384 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 385 | return new_mean 386 | 387 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 388 | """ 389 | Compute what the p_mean_variance output would have been, should the 390 | model's score function be conditioned by cond_fn. 391 | 392 | See condition_mean() for details on cond_fn. 393 | 394 | Unlike condition_mean(), this instead uses the conditioning strategy 395 | from Song et al (2020). 
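
        Concretely, the code below shifts the predicted noise by the guidance
        gradient,

            eps' = eps - sqrt(1 - alphas_cumprod[t]) * cond_fn(x, t, **model_kwargs)

        and then re-derives pred_xstart and the posterior mean from eps'.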
396 | """ 397 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 398 | 399 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 400 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 401 | 402 | out = p_mean_var.copy() 403 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 404 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 405 | return out 406 | 407 | def p_sample( 408 | self, 409 | model, 410 | x, 411 | t, 412 | clip_denoised=False, 413 | denoised_fn=None, 414 | cond_fn=None, 415 | model_kwargs=None, 416 | ): 417 | """ 418 | Sample x_{t-1} from the model at the given timestep. 419 | 420 | :param model: the model to sample from. 421 | :param x: the current tensor at x_{t-1}. 422 | :param t: the value of t, starting at 0 for the first diffusion step. 423 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 424 | :param denoised_fn: if not None, a function which applies to the 425 | x_start prediction before it is used to sample. 426 | :param cond_fn: if not None, this is a gradient function that acts 427 | similarly to the model. 428 | :param model_kwargs: if not None, a dict of extra keyword arguments to 429 | pass to the model. This can be used for conditioning. 430 | :return: a dict containing the following keys: 431 | - 'sample': a random sample from the model. 432 | - 'pred_xstart': a prediction of x_0. 433 | """ 434 | out = self.p_mean_variance( 435 | model, 436 | x, 437 | t, 438 | clip_denoised=clip_denoised, 439 | denoised_fn=denoised_fn, 440 | model_kwargs=model_kwargs, 441 | ) 442 | noise = th.randn_like(x) 443 | nonzero_mask = ( 444 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 445 | ) # no noise when t == 0 446 | if cond_fn is not None: 447 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 448 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 449 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 450 | 451 | def p_sample_loop( 452 | self, 453 | model, 454 | shape, 455 | noise=None, 456 | clip_denoised=False, 457 | denoised_fn=None, 458 | cond_fn=None, 459 | model_kwargs=None, 460 | device=None, 461 | progress=False, 462 | temp=1.0, 463 | ): 464 | """ 465 | Generate samples from the model. 466 | 467 | :param model: the model module. 468 | :param shape: the shape of the samples, (N, C, H, W). 469 | :param noise: if specified, the noise from the encoder to sample. 470 | Should be of the same shape as `shape`. 471 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 472 | :param denoised_fn: if not None, a function which applies to the 473 | x_start prediction before it is used to sample. 474 | :param cond_fn: if not None, this is a gradient function that acts 475 | similarly to the model. 476 | :param model_kwargs: if not None, a dict of extra keyword arguments to 477 | pass to the model. This can be used for conditioning. 478 | :param device: if specified, the device to create the samples on. 479 | If not specified, use a model parameter's device. 480 | :param progress: if True, show a tqdm progress bar. 481 | :return: a non-differentiable batch of samples. 
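
        A minimal usage sketch (illustrative; `diffusion` stands for an instance of
        this class, and the model handle, batch size, and the 6-channel, 1024-point
        shape are example values, not requirements):

            samples = diffusion.p_sample_loop(
                model,
                shape=(batch_size, 6, 1024),  # e.g. XYZ + RGB channels of a point cloud
                progress=True,
            )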
482 | """ 483 | final = None 484 | for sample in self.p_sample_loop_progressive( 485 | model, 486 | shape, 487 | noise=noise, 488 | clip_denoised=clip_denoised, 489 | denoised_fn=denoised_fn, 490 | cond_fn=cond_fn, 491 | model_kwargs=model_kwargs, 492 | device=device, 493 | progress=progress, 494 | temp=temp, 495 | ): 496 | final = sample 497 | return final["sample"] 498 | 499 | def p_sample_loop_progressive( 500 | self, 501 | model, 502 | shape, 503 | noise=None, 504 | clip_denoised=False, 505 | denoised_fn=None, 506 | cond_fn=None, 507 | model_kwargs=None, 508 | device=None, 509 | progress=False, 510 | temp=1.0, 511 | ): 512 | """ 513 | Generate samples from the model and yield intermediate samples from 514 | each timestep of diffusion. 515 | 516 | Arguments are the same as p_sample_loop(). 517 | Returns a generator over dicts, where each dict is the return value of 518 | p_sample(). 519 | """ 520 | if device is None: 521 | device = next(model.parameters()).device 522 | assert isinstance(shape, (tuple, list)) 523 | if noise is not None: 524 | img = noise 525 | else: 526 | img = th.randn(*shape, device=device) * temp 527 | indices = list(range(self.num_timesteps))[::-1] 528 | 529 | if progress: 530 | # Lazy import so that we don't depend on tqdm. 531 | from tqdm.auto import tqdm 532 | 533 | indices = tqdm(indices) 534 | 535 | for i in indices: 536 | t = th.tensor([i] * shape[0], device=device) 537 | with th.no_grad(): 538 | out = self.p_sample( 539 | model, 540 | img, 541 | t, 542 | clip_denoised=clip_denoised, 543 | denoised_fn=denoised_fn, 544 | cond_fn=cond_fn, 545 | model_kwargs=model_kwargs, 546 | ) 547 | yield self.unscale_out_dict(out) 548 | img = out["sample"] 549 | 550 | def ddim_sample( 551 | self, 552 | model, 553 | x, 554 | t, 555 | clip_denoised=False, 556 | denoised_fn=None, 557 | cond_fn=None, 558 | model_kwargs=None, 559 | eta=0.0, 560 | ): 561 | """ 562 | Sample x_{t-1} from the model using DDIM. 563 | 564 | Same usage as p_sample(). 565 | """ 566 | out = self.p_mean_variance( 567 | model, 568 | x, 569 | t, 570 | clip_denoised=clip_denoised, 571 | denoised_fn=denoised_fn, 572 | model_kwargs=model_kwargs, 573 | ) 574 | if cond_fn is not None: 575 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 576 | 577 | # Usually our model outputs epsilon, but we re-derive it 578 | # in case we used x_start or x_prev prediction. 579 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 580 | 581 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 582 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 583 | sigma = ( 584 | eta 585 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 586 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 587 | ) 588 | # Equation 12. 589 | noise = th.randn_like(x) 590 | mean_pred = ( 591 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 592 | + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps 593 | ) 594 | nonzero_mask = ( 595 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 596 | ) # no noise when t == 0 597 | sample = mean_pred + nonzero_mask * sigma * noise 598 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 599 | 600 | def ddim_reverse_sample( 601 | self, 602 | model, 603 | x, 604 | t, 605 | clip_denoised=False, 606 | denoised_fn=None, 607 | cond_fn=None, 608 | model_kwargs=None, 609 | eta=0.0, 610 | ): 611 | """ 612 | Sample x_{t+1} from the model using DDIM reverse ODE. 
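
        This is the deterministic (eta == 0) DDIM update applied in the forward
        direction, which is why the assertion below rejects any other eta; the
        mean prediction is

            x_{t+1} = sqrt(alphas_cumprod_next[t]) * pred_xstart
                      + sqrt(1 - alphas_cumprod_next[t]) * eps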
613 | """ 614 | assert eta == 0.0, "Reverse ODE only for deterministic path" 615 | out = self.p_mean_variance( 616 | model, 617 | x, 618 | t, 619 | clip_denoised=clip_denoised, 620 | denoised_fn=denoised_fn, 621 | model_kwargs=model_kwargs, 622 | ) 623 | if cond_fn is not None: 624 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 625 | # Usually our model outputs epsilon, but we re-derive it 626 | # in case we used x_start or x_prev prediction. 627 | eps = ( 628 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 629 | - out["pred_xstart"] 630 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 631 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 632 | 633 | # Equation 12. reversed 634 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 635 | 636 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 637 | 638 | def ddim_sample_loop( 639 | self, 640 | model, 641 | shape, 642 | noise=None, 643 | clip_denoised=False, 644 | denoised_fn=None, 645 | cond_fn=None, 646 | model_kwargs=None, 647 | device=None, 648 | progress=False, 649 | eta=0.0, 650 | temp=1.0, 651 | ): 652 | """ 653 | Generate samples from the model using DDIM. 654 | 655 | Same usage as p_sample_loop(). 656 | """ 657 | final = None 658 | for sample in self.ddim_sample_loop_progressive( 659 | model, 660 | shape, 661 | noise=noise, 662 | clip_denoised=clip_denoised, 663 | denoised_fn=denoised_fn, 664 | cond_fn=cond_fn, 665 | model_kwargs=model_kwargs, 666 | device=device, 667 | progress=progress, 668 | eta=eta, 669 | temp=temp, 670 | ): 671 | final = sample 672 | return final["sample"] 673 | 674 | def ddim_sample_loop_progressive( 675 | self, 676 | model, 677 | shape, 678 | noise=None, 679 | clip_denoised=False, 680 | denoised_fn=None, 681 | cond_fn=None, 682 | model_kwargs=None, 683 | device=None, 684 | progress=False, 685 | eta=0.0, 686 | temp=1.0, 687 | ): 688 | """ 689 | Use DDIM to sample from the model and yield intermediate samples from 690 | each timestep of DDIM. 691 | 692 | Same usage as p_sample_loop_progressive(). 693 | """ 694 | if device is None: 695 | device = next(model.parameters()).device 696 | assert isinstance(shape, (tuple, list)) 697 | if noise is not None: 698 | img = noise 699 | else: 700 | img = th.randn(*shape, device=device) * temp 701 | indices = list(range(self.num_timesteps))[::-1] 702 | 703 | if progress: 704 | # Lazy import so that we don't depend on tqdm. 705 | from tqdm.auto import tqdm 706 | 707 | indices = tqdm(indices) 708 | 709 | for i in indices: 710 | t = th.tensor([i] * shape[0], device=device) 711 | with th.no_grad(): 712 | out = self.ddim_sample( 713 | model, 714 | img, 715 | t, 716 | clip_denoised=clip_denoised, 717 | denoised_fn=denoised_fn, 718 | cond_fn=cond_fn, 719 | model_kwargs=model_kwargs, 720 | eta=eta, 721 | ) 722 | yield self.unscale_out_dict(out) 723 | img = out["sample"] 724 | 725 | def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None): 726 | """ 727 | Get a term for the variational lower-bound. 728 | 729 | The resulting units are bits (rather than nats, as one might expect). 730 | This allows for comparison to other papers. 731 | 732 | :return: a dict with the following keys: 733 | - 'output': a shape [N] tensor of NLLs or KLs. 734 | - 'pred_xstart': the x_0 predictions. 
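
        Note: the KL and decoder-NLL terms below are divided by log(2) to convert
        nats to bits, and at t == 0 the decoder NLL (zeroed out unless
        discretized_t0 is set) is returned in place of the KL term.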
735 | """ 736 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 737 | x_start=x_start, x_t=x_t, t=t 738 | ) 739 | out = self.p_mean_variance( 740 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 741 | ) 742 | kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) 743 | kl = mean_flat(kl) / np.log(2.0) 744 | 745 | decoder_nll = -discretized_gaussian_log_likelihood( 746 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 747 | ) 748 | if not self.discretized_t0: 749 | decoder_nll = th.zeros_like(decoder_nll) 750 | assert decoder_nll.shape == x_start.shape 751 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 752 | 753 | # At the first timestep return the decoder NLL, 754 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 755 | output = th.where((t == 0), decoder_nll, kl) 756 | return { 757 | "output": output, 758 | "pred_xstart": out["pred_xstart"], 759 | "extra": out["extra"], 760 | } 761 | 762 | def training_losses( 763 | self, model, x_start, t, model_kwargs=None, noise=None 764 | ) -> Dict[str, th.Tensor]: 765 | """ 766 | Compute training losses for a single timestep. 767 | 768 | :param model: the model to evaluate loss on. 769 | :param x_start: the [N x C x ...] tensor of inputs. 770 | :param t: a batch of timestep indices. 771 | :param model_kwargs: if not None, a dict of extra keyword arguments to 772 | pass to the model. This can be used for conditioning. 773 | :param noise: if specified, the specific Gaussian noise to try to remove. 774 | :return: a dict with the key "loss" containing a tensor of shape [N]. 775 | Some mean or variance settings may also have other keys. 776 | """ 777 | x_start = self.scale_channels(x_start) 778 | if model_kwargs is None: 779 | model_kwargs = {} 780 | if noise is None: 781 | noise = th.randn_like(x_start) 782 | x_t = self.q_sample(x_start, t, noise=noise) 783 | 784 | terms = {} 785 | 786 | if self.loss_type == "kl" or self.loss_type == "rescaled_kl": 787 | vb_terms = self._vb_terms_bpd( 788 | model=model, 789 | x_start=x_start, 790 | x_t=x_t, 791 | t=t, 792 | clip_denoised=False, 793 | model_kwargs=model_kwargs, 794 | ) 795 | terms["loss"] = vb_terms["output"] 796 | if self.loss_type == "rescaled_kl": 797 | terms["loss"] *= self.num_timesteps 798 | extra = vb_terms["extra"] 799 | elif self.loss_type == "mse" or self.loss_type == "rescaled_mse": 800 | model_output = model(x_t, t, **model_kwargs) 801 | if isinstance(model_output, tuple): 802 | model_output, extra = model_output 803 | else: 804 | extra = {} 805 | 806 | if self.model_var_type in [ 807 | "learned", 808 | "learned_range", 809 | ]: 810 | B, C = x_t.shape[:2] 811 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 812 | model_output, model_var_values = th.split(model_output, C, dim=1) 813 | # Learn the variance using the variational bound, but don't let 814 | # it affect our mean prediction. 815 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 816 | terms["vb"] = self._vb_terms_bpd( 817 | model=lambda *args, r=frozen_out: r, 818 | x_start=x_start, 819 | x_t=x_t, 820 | t=t, 821 | clip_denoised=False, 822 | )["output"] 823 | if self.loss_type == "rescaled_mse": 824 | # Divide by 1000 for equivalence with initial implementation. 825 | # Without a factor of 1/1000, the VB term hurts the MSE term. 
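                    # (Together with the MSE term added further down, this matches
                    # the hybrid objective of Improved DDPM: loss = mse + vb, with
                    # the variational term re-weighted as noted above so it does
                    # not overwhelm the MSE term.)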
826 | terms["vb"] *= self.num_timesteps / 1000.0 827 | 828 | target = { 829 | "x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], 830 | "x_start": x_start, 831 | "epsilon": noise, 832 | }[self.model_mean_type] 833 | assert model_output.shape == target.shape == x_start.shape 834 | terms["mse"] = mean_flat((target - model_output) ** 2) 835 | if "vb" in terms: 836 | terms["loss"] = terms["mse"] + terms["vb"] 837 | else: 838 | terms["loss"] = terms["mse"] 839 | else: 840 | raise NotImplementedError(self.loss_type) 841 | 842 | if "losses" in extra: 843 | terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()}) 844 | for loss, scale in extra["losses"].values(): 845 | terms["loss"] = terms["loss"] + loss * scale 846 | 847 | return terms 848 | 849 | def _prior_bpd(self, x_start): 850 | """ 851 | Get the prior KL term for the variational lower-bound, measured in 852 | bits-per-dim. 853 | 854 | This term can't be optimized, as it only depends on the encoder. 855 | 856 | :param x_start: the [N x C x ...] tensor of inputs. 857 | :return: a batch of [N] KL values (in bits), one per batch element. 858 | """ 859 | batch_size = x_start.shape[0] 860 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 861 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 862 | kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) 863 | return mean_flat(kl_prior) / np.log(2.0) 864 | 865 | def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None): 866 | """ 867 | Compute the entire variational lower-bound, measured in bits-per-dim, 868 | as well as other related quantities. 869 | 870 | :param model: the model to evaluate loss on. 871 | :param x_start: the [N x C x ...] tensor of inputs. 872 | :param clip_denoised: if True, clip denoised samples. 873 | :param model_kwargs: if not None, a dict of extra keyword arguments to 874 | pass to the model. This can be used for conditioning. 875 | 876 | :return: a dict containing the following keys: 877 | - total_bpd: the total variational lower-bound, per batch element. 878 | - prior_bpd: the prior term in the lower-bound. 879 | - vb: an [N x T] tensor of terms in the lower-bound. 880 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 881 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
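
        The totals satisfy total_bpd = prior_bpd + vb.sum(dim=1), as computed at
        the end of the loop below.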
882 | """ 883 | device = x_start.device 884 | batch_size = x_start.shape[0] 885 | 886 | vb = [] 887 | xstart_mse = [] 888 | mse = [] 889 | for t in list(range(self.num_timesteps))[::-1]: 890 | t_batch = th.tensor([t] * batch_size, device=device) 891 | noise = th.randn_like(x_start) 892 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 893 | # Calculate VLB term at the current timestep 894 | with th.no_grad(): 895 | out = self._vb_terms_bpd( 896 | model, 897 | x_start=x_start, 898 | x_t=x_t, 899 | t=t_batch, 900 | clip_denoised=clip_denoised, 901 | model_kwargs=model_kwargs, 902 | ) 903 | vb.append(out["output"]) 904 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 905 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 906 | mse.append(mean_flat((eps - noise) ** 2)) 907 | 908 | vb = th.stack(vb, dim=1) 909 | xstart_mse = th.stack(xstart_mse, dim=1) 910 | mse = th.stack(mse, dim=1) 911 | 912 | prior_bpd = self._prior_bpd(x_start) 913 | total_bpd = vb.sum(dim=1) + prior_bpd 914 | return { 915 | "total_bpd": total_bpd, 916 | "prior_bpd": prior_bpd, 917 | "vb": vb, 918 | "xstart_mse": xstart_mse, 919 | "mse": mse, 920 | } 921 | 922 | def scale_channels(self, x: th.Tensor) -> th.Tensor: 923 | if self.channel_scales is not None: 924 | x = x * th.from_numpy(self.channel_scales).to(x).reshape( 925 | [1, -1, *([1] * (len(x.shape) - 2))] 926 | ) 927 | if self.channel_biases is not None: 928 | x = x + th.from_numpy(self.channel_biases).to(x).reshape( 929 | [1, -1, *([1] * (len(x.shape) - 2))] 930 | ) 931 | return x 932 | 933 | def unscale_channels(self, x: th.Tensor) -> th.Tensor: 934 | if self.channel_biases is not None: 935 | x = x - th.from_numpy(self.channel_biases).to(x).reshape( 936 | [1, -1, *([1] * (len(x.shape) - 2))] 937 | ) 938 | if self.channel_scales is not None: 939 | x = x / th.from_numpy(self.channel_scales).to(x).reshape( 940 | [1, -1, *([1] * (len(x.shape) - 2))] 941 | ) 942 | return x 943 | 944 | def unscale_out_dict( 945 | self, out: Dict[str, Union[th.Tensor, Any]] 946 | ) -> Dict[str, Union[th.Tensor, Any]]: 947 | return { 948 | k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items() 949 | } 950 | 951 | 952 | class SpacedDiffusion(GaussianDiffusion): 953 | """ 954 | A diffusion process which can skip steps in a base diffusion process. 955 | :param use_timesteps: (unordered) timesteps from the original diffusion 956 | process to retain. 957 | :param kwargs: the kwargs to create the base diffusion process. 
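
    A typical construction sketch (illustrative; the schedule helper name and the
    particular option values are assumptions):

        SpacedDiffusion(
            use_timesteps=space_timesteps(1024, "ddim64"),
            betas=get_named_beta_schedule("cosine", 1024),
            model_mean_type="epsilon",
            model_var_type="learned_range",
            loss_type="mse",
        )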
958 | """ 959 | 960 | def __init__(self, use_timesteps: Iterable[int], **kwargs): 961 | self.use_timesteps = set(use_timesteps) 962 | self.timestep_map = [] 963 | self.original_num_steps = len(kwargs["betas"]) 964 | 965 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 966 | last_alpha_cumprod = 1.0 967 | new_betas = [] 968 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 969 | if i in self.use_timesteps: 970 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 971 | last_alpha_cumprod = alpha_cumprod 972 | self.timestep_map.append(i) 973 | kwargs["betas"] = np.array(new_betas) 974 | super().__init__(**kwargs) 975 | 976 | def p_mean_variance(self, model, *args, **kwargs): 977 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 978 | 979 | def training_losses(self, model, *args, **kwargs): 980 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 981 | 982 | def condition_mean(self, cond_fn, *args, **kwargs): 983 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 984 | 985 | def condition_score(self, cond_fn, *args, **kwargs): 986 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 987 | 988 | def _wrap_model(self, model): 989 | if isinstance(model, _WrappedModel): 990 | return model 991 | return _WrappedModel(model, self.timestep_map, self.original_num_steps) 992 | 993 | 994 | class _WrappedModel: 995 | def __init__(self, model, timestep_map, original_num_steps): 996 | self.model = model 997 | self.timestep_map = timestep_map 998 | self.original_num_steps = original_num_steps 999 | 1000 | def __call__(self, x, ts, **kwargs): 1001 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 1002 | new_ts = map_tensor[ts] 1003 | return self.model(x, new_ts, **kwargs) 1004 | 1005 | 1006 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 1007 | """ 1008 | Extract values from a 1-D numpy array for a batch of indices. 1009 | 1010 | :param arr: the 1-D numpy array. 1011 | :param timesteps: a tensor of indices into the array to extract. 1012 | :param broadcast_shape: a larger shape of K dimensions with the batch 1013 | dimension equal to the length of timesteps. 1014 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 1015 | """ 1016 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 1017 | while len(res.shape) < len(broadcast_shape): 1018 | res = res[..., None] 1019 | return res + th.zeros(broadcast_shape, device=timesteps.device) 1020 | 1021 | 1022 | def normal_kl(mean1, logvar1, mean2, logvar2): 1023 | """ 1024 | Compute the KL divergence between two gaussians. 1025 | Shapes are automatically broadcasted, so batches can be compared to 1026 | scalars, among other use cases. 1027 | """ 1028 | tensor = None 1029 | for obj in (mean1, logvar1, mean2, logvar2): 1030 | if isinstance(obj, th.Tensor): 1031 | tensor = obj 1032 | break 1033 | assert tensor is not None, "at least one argument must be a Tensor" 1034 | 1035 | # Force variances to be Tensors. Broadcasting helps convert scalars to 1036 | # Tensors, but it does not work for th.exp(). 
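    # For reference, the expression returned at the end of this function is the
    # standard closed form
    #     KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
    #         = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
    #                  + (mean1 - mean2) ** 2 * exp(-logvar2) - 1)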
1037 | logvar1, logvar2 = [ 1038 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2) 1039 | ] 1040 | 1041 | return 0.5 * ( 1042 | -1.0 1043 | + logvar2 1044 | - logvar1 1045 | + th.exp(logvar1 - logvar2) 1046 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 1047 | ) 1048 | 1049 | 1050 | def approx_standard_normal_cdf(x): 1051 | """ 1052 | A fast approximation of the cumulative distribution function of the 1053 | standard normal. 1054 | """ 1055 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 1056 | 1057 | 1058 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 1059 | """ 1060 | Compute the log-likelihood of a Gaussian distribution discretizing to a 1061 | given image. 1062 | :param x: the target images. It is assumed that this was uint8 values, 1063 | rescaled to the range [-1, 1]. 1064 | :param means: the Gaussian mean Tensor. 1065 | :param log_scales: the Gaussian log stddev Tensor. 1066 | :return: a tensor like x of log probabilities (in nats). 1067 | """ 1068 | assert x.shape == means.shape == log_scales.shape 1069 | centered_x = x - means 1070 | inv_stdv = th.exp(-log_scales) 1071 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 1072 | cdf_plus = approx_standard_normal_cdf(plus_in) 1073 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 1074 | cdf_min = approx_standard_normal_cdf(min_in) 1075 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 1076 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 1077 | cdf_delta = cdf_plus - cdf_min 1078 | log_probs = th.where( 1079 | x < -0.999, 1080 | log_cdf_plus, 1081 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 1082 | ) 1083 | assert log_probs.shape == x.shape 1084 | return log_probs 1085 | 1086 | 1087 | def mean_flat(tensor): 1088 | """ 1089 | Take the mean over all non-batch dimensions. 1090 | """ 1091 | return tensor.flatten(1).mean(1) 1092 | --------------------------------------------------------------------------------