├── teaser.gif ├── README.md ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── utils.py ├── modules.py ├── model.py └── LICENSE /teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/PointInfinity/HEAD/teaser.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointInfinity 2 | 3 | [[Project Page]](https://zixuanh.com/projects/pointinfinity) [[Paper]](https://zixuanh.com/projects/pointinfinity/paper.pdf) 4 | 5 | 6 | 7 | This repository currently includes the core code of the default denoiser in PointInfinity. 8 | 9 | ## License 10 | The majority of PointInfinity is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: the Point-E portions are licensed under the MIT license. 11 | 12 | ## Acknowledgement and Reference 13 | Part of this implementation is based on [MCC](https://github.com/facebookresearch/MCC), [Point-E](https://github.com/openai/point-e) and [RIN](https://arxiv.org/pdf/2212.11972). If you find our work helpful, please consider citing these works, as well as ours: 14 | ``` 15 | @inproceedings{huang2024pointinfinity, 16 | title={PointInfinity: Resolution-Invariant Point Diffusion Models}, 17 | author={Huang, Zixuan and Johnson, Justin and Debnath, Shoubhik and Rehg, James M and Wu, Chao-Yuan}, 18 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 19 | year={2024} 20 | } 21 | ``` -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PointInfinity 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1.
Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to PointInfinity, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
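The fixed positional encodings in `utils.py` (`get_1d_sincos_pos_embed_from_grid` and `get_2d_sincos_pos_embed`) map each integer grid coordinate to sines and cosines at geometrically spaced frequencies. The NumPy sketch below restates that computation so the shapes and channel layout are explicit. It rests on a few assumptions:

* The helper names `sincos_1d` and `sincos_2d` are hypothetical.
* Everything is computed in `float64`.
* Only the square-grid case without a cls token is covered.

It is an illustration, not the repository's implementation.

```python
import numpy as np


def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Hypothetical helper: encode each position as [sin(pos * w_k), cos(pos * w_k)].

    Uses the same frequency schedule as utils.py:
    w_k = 1 / 10000 ** (k / (embed_dim / 2)), for k = 0 .. embed_dim / 2 - 1.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega                          # (D/2,), from 1 down toward 1e-4
    pos = np.asarray(pos, dtype=np.float64).reshape(-1)   # (M,)
    out = np.outer(pos, omega)                            # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)


def sincos_2d(embed_dim: int, grid_size: int) -> np.ndarray:
    """Hypothetical helper: square-grid encoding without a cls token, shape (G*G, D)."""
    assert embed_dim % 4 == 0  # each axis gets D/2 channels, split again into sin/cos
    coords = np.arange(grid_size, dtype=np.float64)
    col, row = np.meshgrid(coords, coords)    # default 'xy' indexing: col[r, c] = c, row[r, c] = r
    emb_col = sincos_1d(embed_dim // 2, col)  # first D/2 channels encode the column index
    emb_row = sincos_1d(embed_dim // 2, row)  # last D/2 channels encode the row index
    return np.concatenate([emb_col, emb_row], axis=1)


if __name__ == "__main__":
    pe = sincos_2d(768, 14)  # e.g. a 14 x 14 patch grid
    print(pe.shape)          # (196, 768)
```

With the default `'xy'` meshgrid indexing, token `r * G + c` stores the encoding of column `c` in its first `D/2` channels and the encoding of row `r` in its last `D/2` channels. This matches the `# here w goes first` comment in `get_2d_sincos_pos_embed`.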
6 | # -------------------------------------------------------- 7 | # util methods 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # MCC: https://github.com/facebookresearch/MCC 19 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 20 | # MoCo v3: https://github.com/facebookresearch/moco-v3 21 | # -------------------------------------------------------- 22 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 23 | """ 24 | grid_size: int of the grid height and width 25 | return: 26 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token) 27 | """ 28 | grid_h = np.arange(grid_size, dtype=np.float32) 29 | grid_w = np.arange(grid_size, dtype=np.float32) 30 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 31 | grid = np.stack(grid, axis=0) 32 | 33 | grid = grid.reshape([2, 1, grid_size, grid_size]) 34 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 35 | if cls_token: 36 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 37 | return pos_embed 38 | 39 | 40 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 41 | assert embed_dim % 2 == 0 42 | 43 | # use half of dimensions to encode grid_h 44 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 45 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 46 | 47 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 48 | return emb 49 | 50 | 51 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 52 | """ 53 | embed_dim: output dimension for each position 54 | pos: a list of positions to be encoded: size (M,) 55 | out: (M, D) 56 | """ 57 | assert
embed_dim % 2 == 0 58 | omega = np.arange(embed_dim // 2, dtype=np.float64) 59 | omega /= embed_dim / 2. 60 | omega = 1. / 10000**omega # (D/2,) 61 | 62 | pos = pos.reshape(-1) # (M,) 63 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 64 | 65 | emb_sin = np.sin(out) # (M, D/2) 66 | emb_cos = np.cos(out) # (M, D/2) 67 | 68 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 69 | return emb 70 | 71 | # -------------------------------------------------------- 72 | # Timestep embedding 73 | # References: 74 | # Point-E: https://github.com/openai/point-e 75 | # -------------------------------------------------------- 76 | def timestep_embedding(timesteps, dim, max_period=10000): 77 | """ 78 | Create sinusoidal timestep embeddings. 79 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 80 | These may be fractional. 81 | :param dim: the dimension of the output. 82 | :param max_period: controls the minimum frequency of the embeddings. 83 | :return: an [N x dim] Tensor of positional embeddings. 84 | """ 85 | half = dim // 2 86 | freqs = torch.exp( 87 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 88 | ).to(device=timesteps.device) 89 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 90 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 91 | if dim % 2: 92 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 93 | return embedding 94 | 95 | # -------------------------------------------------------- 96 | # Image preprocessor 97 | # References: 98 | # MCC: https://github.com/facebookresearch/MCC 99 | # -------------------------------------------------------- 100 | def preprocess_img(x): 101 | """ 102 | Preprocess images for MCC encoder.
103 | """ 104 | if x.shape[2] != 224: 105 | x = F.interpolate( 106 | x, 107 | scale_factor=224./x.shape[2], 108 | mode="bilinear", 109 | ) 110 | resnet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).reshape((1, 3, 1, 1)) 111 | resnet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).reshape((1, 3, 1, 1)) 112 | imgs_normed = (x - resnet_mean) / resnet_std 113 | return imgs_normed 114 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # RIN: https://arxiv.org/pdf/2212.11972 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from timm.models.vision_transformer import Mlp, DropPath 15 | from utils import timestep_embedding 16 | 17 | class CrossAttention(nn.Module): 18 | def __init__( 19 | self, 20 | dim, 21 | kv_dim=None, 22 | num_heads=16, 23 | qkv_bias=False, 24 | attn_drop=0., 25 | proj_drop=0., 26 | ): 27 | super().__init__() 28 | self.num_heads = num_heads 29 | head_dim = dim // num_heads 30 | self.scale = head_dim ** -0.5 31 | 32 | kv_dim = dim if not kv_dim else kv_dim 33 | self.wq = nn.Linear(dim, dim, bias=qkv_bias) 34 | self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias) 35 | self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias) 36 | self.attn_drop_rate = attn_drop 37 | self.attn_drop = nn.Dropout(self.attn_drop_rate) 38 | self.proj = nn.Linear(dim, dim) 39 | self.proj_drop = nn.Dropout(proj_drop) 40 | 41 | def forward(self, x_q, x_kv): 42 | B, N_q, C = x_q.shape 43 | B, N_kv, _ = x_kv.shape 44 | # [B, N_q, C] -> [B, N_q, H, C/H] -> [B, H, N_q, C/H] 45 | q = 
self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 46 | # [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H] 47 | k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 48 | # [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H] 49 | v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 50 | 51 | # [B, H, N_q, C/H] @ [B, H, C/H, N_kv] -> [B, H, N_q, N_kv] 52 | attn = (q @ k.transpose(-2, -1)) * self.scale 53 | attn = attn.softmax(dim=-1) 54 | attn = self.attn_drop(attn) 55 | 56 | # [B, H, N_q, N_kv] @ [B, H, N_kv, C/H] -> [B, H, N_q, C/H] 57 | x = attn @ v 58 | 59 | # [B, H, N_q, C/H] -> [B, N_q, C] 60 | x = x.transpose(1, 2).reshape(B, N_q, C) 61 | x = self.proj(x) 62 | x = self.proj_drop(x) 63 | return x 64 | 65 | class Compute_Block(nn.Module): 66 | 67 | def __init__(self, z_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 68 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 69 | super().__init__() 70 | self.norm_z1 = norm_layer(z_dim) 71 | self.attn = CrossAttention( 72 | z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 73 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 74 | self.norm_z2 = norm_layer(z_dim) 75 | mlp_hidden_dim = int(z_dim * mlp_ratio) 76 | self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 77 | 78 | def forward(self, z): 79 | zn = self.norm_z1(z) 80 | z = z + self.drop_path(self.attn(zn, zn)) 81 | z = z + self.drop_path(self.mlp(self.norm_z2(z))) 82 | return z 83 | 84 | class Read_Block(nn.Module): 85 | 86 | def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 87 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 88 | super().__init__() 89 | self.norm_x = norm_layer(x_dim) 90 | self.norm_z1 = norm_layer(z_dim) 91 | self.attn = CrossAttention( 92 | z_dim, x_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 93 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 94 | self.norm_z2 = norm_layer(z_dim) 95 | mlp_hidden_dim = int(z_dim * mlp_ratio) 96 | self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 97 | 98 | def forward(self, z, x): 99 | z = z + self.drop_path(self.attn(self.norm_z1(z), self.norm_x(x))) 100 | z = z + self.drop_path(self.mlp(self.norm_z2(z))) 101 | return z 102 | 103 | class Write_Block(nn.Module): 104 | 105 | def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 106 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 107 | super().__init__() 108 | self.norm_z = norm_layer(z_dim) 109 | self.norm_x1 = norm_layer(x_dim) 110 | self.attn = CrossAttention( 111 | x_dim, z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 112 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 113 | self.norm_x2 = norm_layer(x_dim) 114 | mlp_hidden_dim = int(x_dim * mlp_ratio) 115 | self.mlp = Mlp(in_features=x_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 116 | 117 | def forward(self, z, x): 118 | x = x + self.drop_path(self.attn(self.norm_x1(x), self.norm_z(z))) 119 | x = x + self.drop_path(self.mlp(self.norm_x2(x))) 120 | return x 121 | 122 | class RCW_Block(nn.Module): 123 | 124 | def __init__(self, z_dim, x_dim, num_compute_layers=4, num_heads=16, 125 | mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 126 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 127 | super().__init__() 128 | self.read = Read_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, 129 | attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer) 130 | self.write = Write_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, 131 | attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer) 132 | self.compute = nn.ModuleList([ 133 | Compute_Block(z_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, 134 | attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer) 135 | for _ in range(num_compute_layers) 136 | ]) 137 | 138 | def forward(self, z, x): 139 | z = self.read(z, x) 140 | for layer in self.compute: 141 | z = layer(z) 142 | x = self.write(z, x) 143 | return z, x 144 | 145 | class Denoiser_backbone(nn.Module): 146 | def __init__(self, input_channels=3, output_channels=3, 147 | num_z=256, num_x=4096, z_dim=768, x_dim=512, 148 | num_blocks=6, num_compute_layers=4, num_heads=8, 149 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 150 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 151 | super().__init__() 152 | 153 | self.num_z = num_z 154 | self.num_x = num_x 155 | self.z_dim = z_dim 156 | 157 | # input blocks 158 | 
self.input_proj = nn.Linear(input_channels, x_dim) 159 | self.ln_pre = nn.LayerNorm(x_dim) 160 | self.z_init = nn.Parameter(torch.zeros(1, num_z, z_dim)) 161 | 162 | # timestep embedding 163 | mlp_hidden_dim = int(z_dim * mlp_ratio) 164 | self.time_embed = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim) 165 | 166 | # RCW blocks 167 | self.latent_mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 168 | self.ln_latent = nn.LayerNorm(z_dim) 169 | self.blocks = nn.ModuleList([ 170 | RCW_Block(z_dim, x_dim, num_compute_layers=num_compute_layers, 171 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 172 | drop=drop, attn_drop=attn_drop, drop_path=drop_path, 173 | act_layer=act_layer, norm_layer=norm_layer) 174 | for _ in range(num_blocks) 175 | ]) 176 | 177 | # output blocks 178 | self.ln_post = nn.LayerNorm(x_dim) 179 | self.output_proj = nn.Linear(x_dim, output_channels) 180 | 181 | self.initialize_weights() 182 | 183 | def initialize_weights(self): 184 | nn.init.normal_(self.z_init, std=.02) 185 | 186 | # initialize nn.Linear and nn.LayerNorm 187 | self.apply(self._init_weights) 188 | 189 | nn.init.constant_(self.ln_latent.weight, 0) 190 | nn.init.constant_(self.ln_latent.bias, 0) 191 | 192 | def _init_weights(self, m): 193 | if isinstance(m, nn.Linear): 194 | torch.nn.init.xavier_uniform_(m.weight) 195 | if isinstance(m, nn.Linear) and m.bias is not None: 196 | nn.init.constant_(m.bias, 0) 197 | elif isinstance(m, nn.LayerNorm): 198 | nn.init.constant_(m.bias, 0) 199 | nn.init.constant_(m.weight, 1.0) 200 | 201 | def forward(self, x, t, cond, prev_latent): 202 | """ 203 | Forward pass of the model. 
204 | 205 | Parameters: 206 | x: [B, num_x, C_in] 207 | t: [B] 208 | cond: [B, num_cond, C_latent] 209 | prev_latent: [B, num_z + num_cond + 1, C_latent] 210 | 211 | Returns: 212 | x_denoised: [B, num_x, C_out] 213 | z: [B, num_z + num_cond + 1, C_latent] 214 | """ 215 | B, num_x, _ = x.shape 216 | num_cond = cond.shape[1] 217 | assert num_x == self.num_x 218 | if prev_latent is not None: 219 | _, num_z, _ = prev_latent.shape 220 | assert num_z == self.num_z + num_cond + 1 221 | else: 222 | prev_latent = torch.zeros(B, self.num_z + num_cond + 1, self.z_dim).to(x.device) 223 | 224 | # timestep embedding, [B, 1, z_dim] 225 | t_embed = self.time_embed(timestep_embedding(t, self.z_dim)).unsqueeze(1) 226 | 227 | # project x -> [B, num_x, C_x] 228 | x = self.input_proj(x) 229 | x = self.ln_pre(x) 230 | 231 | # latent self-conditioning 232 | z = self.z_init.repeat(B, 1, 1) # [B, num_z, z_dim] 233 | z = torch.cat([z, cond, t_embed], dim=1) # [B, num_z + num_cond + 1, z_dim] 234 | prev_latent = prev_latent + self.latent_mlp(prev_latent.detach()) 235 | z = z + self.ln_latent(prev_latent) 236 | 237 | # compute 238 | for blk in self.blocks: 239 | z, x = blk(z, x) 240 | 241 | # output proj 242 | x = self.ln_post(x) 243 | x_denoised = self.output_proj(x) 244 | return x_denoised, z 245 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # MCC: https://github.com/facebookresearch/MCC 9 | # Point-E: https://github.com/openai/point-e 10 | # RIN: https://arxiv.org/pdf/2212.11972 11 | # This code includes the implementation of our default two-stream model. 
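In the RCW blocks of `modules.py`, only `Read_Block` and `Write_Block` touch the point tokens `x`. `Compute_Block` attends among the latents alone. The output length of `CrossAttention` always follows the query set, whatever the key/value length.

The NumPy sketch below illustrates that shape argument with a simplified read-compute-write step. It is an assumption-laden sketch, not the repository code:

* LayerNorm, the MLPs, DropPath, dropout, biases and the output projection are all omitted.
* A single compute step stands in for `num_compute_layers`.
* The helpers `cross_attention` and `rcw_step` and the random weights are hypothetical.

```python
import numpy as np


def softmax(a: np.ndarray, axis: int = -1) -> np.ndarray:
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(x_q, x_kv, wq, wk, wv, num_heads):
    """Multi-head attention of x_q (B, N_q, C) over x_kv (B, N_kv, C_kv).

    wq: (C, C); wk and wv: (C_kv, C). Returns (B, N_q, C): the token count
    follows the queries, whatever N_kv is.
    """
    B, N_q, C = x_q.shape
    N_kv = x_kv.shape[1]
    hd = C // num_heads
    q = (x_q @ wq).reshape(B, N_q, num_heads, hd).transpose(0, 2, 1, 3)    # (B, H, N_q, hd)
    k = (x_kv @ wk).reshape(B, N_kv, num_heads, hd).transpose(0, 2, 1, 3)  # (B, H, N_kv, hd)
    v = (x_kv @ wv).reshape(B, N_kv, num_heads, hd).transpose(0, 2, 1, 3)  # (B, H, N_kv, hd)
    attn = softmax((q @ k.transpose(0, 1, 3, 2)) * hd ** -0.5)            # (B, H, N_q, N_kv)
    return (attn @ v).transpose(0, 2, 1, 3).reshape(B, N_q, C)


def rcw_step(z, x, params, num_heads):
    """One simplified read -> compute -> write step with residual connections."""
    z = z + cross_attention(z, x, *params["read"], num_heads)     # latents read from points
    z = z + cross_attention(z, z, *params["compute"], num_heads)  # latent-only self-attention
    x = x + cross_attention(x, z, *params["write"], num_heads)    # points read back from latents
    return z, x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    z_dim, x_dim, heads = 16, 12, 4  # toy sizes; the repository defaults are much larger

    def w(n_in, n_out):
        return 0.1 * rng.standard_normal((n_in, n_out))

    params = {
        "read": (w(z_dim, z_dim), w(x_dim, z_dim), w(x_dim, z_dim)),
        "compute": (w(z_dim, z_dim), w(z_dim, z_dim), w(z_dim, z_dim)),
        "write": (w(x_dim, x_dim), w(z_dim, x_dim), w(z_dim, x_dim)),
    }
    z0 = rng.standard_normal((2, 8, z_dim))
    for n_points in (64, 1024):
        z1, x1 = rcw_step(z0, rng.standard_normal((2, n_points, x_dim)), params, heads)
        print(z1.shape, x1.shape)  # z1: (2, 8, 16) for every n_points; x1: (2, n_points, 12)
```

The attention matrices have these shapes:

* Read: `(num_z, num_x)`.
* Write: `(num_x, num_z)`.
* Compute: `(num_z, num_z)`.

So only the read and write steps scale with the number of points. The sketch accepts any `n_points`, but `Denoiser_backbone.forward` in the repository still asserts `num_x == self.num_x`.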
12 | # Our default two-stream implementation is based on RIN and MCC. 13 | # Other backbones in the two-stream family, such as PerceiverIO, will also work. 14 | # -------------------------------------------------------- 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from functools import partial 21 | from timm.models.vision_transformer import PatchEmbed, Block 22 | from utils import get_2d_sincos_pos_embed, preprocess_img 23 | from modules import Denoiser_backbone 24 | 25 | class XYZPosEmbed(nn.Module): 26 | """ 27 | Embeds an XYZ map as one token per 8x8 window: points are linearly projected, invalid points use a learned token, and one Transformer block summarizes each window into its cls token. 28 | """ 29 | def __init__(self, embed_dim, num_heads): 30 | super().__init__() 31 | self.embed_dim = embed_dim 32 | 33 | self.two_d_pos_embed = nn.Parameter( 34 | torch.zeros(1, 64 + 1, embed_dim), requires_grad=False) 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.win_size = 8 38 | 39 | self.pos_embed = nn.Linear(3, embed_dim) 40 | 41 | self.blocks = nn.ModuleList([ 42 | Block(embed_dim, num_heads=num_heads, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 43 | for _ in range(1) 44 | ]) 45 | 46 | self.invalid_xyz_token = nn.Parameter(torch.zeros(embed_dim,)) 47 | 48 | self.initialize_weights() 49 | 50 | def initialize_weights(self): 51 | torch.nn.init.normal_(self.cls_token, std=.02) 52 | 53 | two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], 8, cls_token=True) 54 | self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0)) 55 | 56 | torch.nn.init.normal_(self.invalid_xyz_token, std=.02) 57 | 58 | def forward(self, seen_xyz, valid_seen_xyz): 59 | emb = self.pos_embed(seen_xyz) 60 | 61 | emb[~valid_seen_xyz] = 0.0 62 | emb[~valid_seen_xyz] += self.invalid_xyz_token 63 | 64 | B, H, W, C = emb.shape 65 | emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C) 66 | emb = emb.permute(0, 1, 3, 2, 4,
5).contiguous().view(-1, self.win_size * self.win_size, C) 67 | 68 | emb = emb + self.two_d_pos_embed[:, 1:, :] 69 | cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :] 70 | 71 | cls_tokens = cls_token.expand(emb.shape[0], -1, -1) 72 | emb = torch.cat((cls_tokens, emb), dim=1) 73 | for _, blk in enumerate(self.blocks): 74 | emb = blk(emb) 75 | return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1) 76 | 77 | class MCCEncoder(nn.Module): 78 | """ 79 | MCC's RGB and XYZ encoder 80 | """ 81 | def __init__(self, 82 | img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, 83 | num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, drop_path=0.1): 84 | super().__init__() 85 | 86 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 87 | num_patches = self.patch_embed.num_patches 88 | self.n_tokens = num_patches + 1 89 | self.embed_dim = embed_dim 90 | 91 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 92 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) 93 | 94 | self.blocks = nn.ModuleList([ 95 | Block( 96 | embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, 97 | drop_path=drop_path 98 | ) for i in range(depth)]) 99 | 100 | self.norm = norm_layer(embed_dim) 101 | 102 | self.cls_token_xyz = nn.Parameter(torch.zeros(1, 1, embed_dim)) 103 | 104 | self.xyz_pos_embed = XYZPosEmbed(embed_dim, num_heads) 105 | 106 | self.blocks_xyz = nn.ModuleList([ 107 | Block( 108 | embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, 109 | drop_path=drop_path 110 | ) for i in range(depth)]) 111 | 112 | self.norm_xyz = norm_layer(embed_dim) 113 | 114 | self.initialize_weights() 115 | 116 | def initialize_weights(self): 117 | 118 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 119 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 120 | 121 | 
# initialize patch_embed like nn.Linear (instead of nn.Conv2d) 122 | w = self.patch_embed.proj.weight.data 123 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 124 | 125 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 126 | torch.nn.init.normal_(self.cls_token, std=.02) 127 | torch.nn.init.normal_(self.cls_token_xyz, std=.02) 128 | 129 | # initialize nn.Linear and nn.LayerNorm 130 | self.apply(self._init_weights) 131 | 132 | def _init_weights(self, m): 133 | if isinstance(m, nn.Linear): 134 | torch.nn.init.xavier_uniform_(m.weight) 135 | if isinstance(m, nn.Linear) and m.bias is not None: 136 | nn.init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.LayerNorm): 138 | nn.init.constant_(m.bias, 0) 139 | nn.init.constant_(m.weight, 1.0) 140 | 141 | def forward(self, x, seen_xyz, valid_seen_xyz): 142 | 143 | # get tokens 144 | x = self.patch_embed(x) 145 | x = x + self.pos_embed[:, 1:, :] 146 | y = self.xyz_pos_embed(seen_xyz, valid_seen_xyz) 147 | 148 | ##### forward E_XYZ ##### 149 | # append cls token 150 | cls_token_xyz = self.cls_token_xyz 151 | cls_tokens_xyz = cls_token_xyz.expand(y.shape[0], -1, -1) 152 | 153 | y = torch.cat((cls_tokens_xyz, y), dim=1) 154 | # apply Transformer blocks 155 | for blk in self.blocks_xyz: 156 | y = blk(y) 157 | y = self.norm_xyz(y) 158 | 159 | ##### forward E_RGB ##### 160 | # append cls token 161 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 162 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 163 | 164 | x = torch.cat((cls_tokens, x), dim=1) 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | x = self.norm(x) 169 | 170 | # combine encodings 171 | return torch.cat([x, y], dim=2) 172 | 173 | class TwoStreamDenoiser(nn.Module): 174 | ''' 175 | Full Point diffusion model using MCC's encoders with the Two Stream backbone 176 | ''' 177 | def __init__( 178 | self, 179 | num_points: int = 1024, 180 | num_latents: int = 256, 181 | 
cond_drop_prob: float = 0.1, 182 | input_channels: int = 6, 183 | output_channels: int = 6, 184 | latent_dim: int = 768, 185 | num_blocks: int = 6, 186 | num_compute_layers: int = 4, 187 | **kwargs, 188 | ): 189 | super().__init__() 190 | # define encoders 191 | self.mcc_encoder = MCCEncoder(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 192 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 193 | # define backbone 194 | self.denoiser_backbone = Denoiser_backbone(input_channels=input_channels, output_channels=output_channels, 195 | num_x=num_points, num_z=num_latents, z_dim=latent_dim, 196 | num_blocks=num_blocks, num_compute_layers=num_compute_layers) 197 | self.cond_embed = nn.Sequential( 198 | nn.LayerNorm( 199 | normalized_shape=(self.mcc_encoder.embed_dim*2,) 200 | ), 201 | nn.Linear(self.mcc_encoder.embed_dim*2, self.denoiser_backbone.z_dim), 202 | ) 203 | self.cond_drop_prob = cond_drop_prob 204 | self.num_points = num_points 205 | 206 | def cached_model_kwargs(self, model_kwargs): 207 | with torch.no_grad(): 208 | cond_dict = {} 209 | images = preprocess_img(model_kwargs["images"]) 210 | embeddings = self.mcc_encoder( 211 | images, 212 | model_kwargs["seen_xyz"], 213 | model_kwargs["seen_xyz_mask"], 214 | ) 215 | cond_dict["embeddings"] = embeddings 216 | if "prev_latent" in model_kwargs: 217 | cond_dict["prev_latent"] = model_kwargs["prev_latent"] 218 | return cond_dict 219 | 220 | def forward( 221 | self, 222 | x, 223 | t, 224 | images=None, 225 | seen_xyz=None, 226 | seen_xyz_mask=None, 227 | embeddings=None, 228 | prev_latent=None, 229 | ): 230 | """ 231 | Forward pass through the model. 232 | 233 | Parameters: 234 | x: Tensor of shape [B, C, N_points], raw input point cloud. 235 | t: Tensor of shape [B], time step. 236 | images (Tensor, optional): A batch of images to condition on. 237 | seen_xyz (Tensor, optional): A batch of xyz maps to condition on. 238 | seen_xyz_mask (Tensor, optional): Validity mask for xyz maps. 
239 | embeddings (Tensor, optional): Precomputed condition embeddings (avoids re-running the 240 | MCC encoder at every step of diffusion inference). 241 | prev_latent (Tensor, optional): Self-conditioning latent. 242 | 243 | Returns: 244 | x_denoised: Tensor of shape [B, C, N_points], the predicted clean point cloud or noise; latent: the backbone latent, which can be passed back as prev_latent. 245 | """ 246 | assert images is not None or embeddings is not None, "must specify images or embeddings" 247 | assert images is None or embeddings is None, "cannot specify both images and embeddings" 248 | assert x.shape[-1] == self.num_points 249 | 250 | # get the condition vectors with MCC encoders 251 | if images is not None: 252 | images = preprocess_img(images) 253 | cond_vec = self.mcc_encoder(images, seen_xyz, seen_xyz_mask) 254 | else: 255 | cond_vec = embeddings 256 | # condition dropout 257 | if self.training: 258 | mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob 259 | cond_vec = cond_vec * mask[:, None, None].to(cond_vec) 260 | cond_vec = self.cond_embed(cond_vec) 261 | 262 | # denoiser forward 263 | x_denoised, latent = self.denoiser_backbone(x.permute(0, 2, 1).contiguous(), t, cond_vec, prev_latent=prev_latent) 264 | x_denoised = x_denoised.permute(0, 2, 1).contiguous() 265 | return x_denoised, latent -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis.
Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. 
Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. 
For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. 
NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 
197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. 
indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 
288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. 
This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. 
To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 
399 | 400 | Creative Commons may be contacted at creativecommons.org. --------------------------------------------------------------------------------