├── src ├── datasets │ ├── __init__.py │ ├── slam_utils.py │ ├── umi_utils.py │ ├── umi_pos_utils.py │ ├── DistributedSaveableSampler.py │ └── write_translations.py ├── utils │ ├── trainers │ │ ├── __init__.py │ │ ├── Base_trainer.py │ │ └── Score_sde_trainer.py │ ├── __init__.py │ ├── init_instance_dirs.py │ ├── tlog.py │ ├── Attribute_dict.py │ ├── camera_tools.py │ └── score_tools.py ├── models │ ├── score_sde │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── upfirdn2d.cpp │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ ├── __init__.py │ │ ├── configs │ │ │ ├── LDM.py │ │ │ └── conditional.py │ │ ├── CrossAttnBlockpp.py │ │ ├── ema.py │ │ ├── utils.py │ │ ├── ddpm.py │ │ ├── normalization.py │ │ ├── sde_lib.py │ │ ├── up_or_down_sampling.py │ │ ├── layerspp.py │ │ └── ncsnpp.py │ └── vqgan │ │ ├── configs │ │ ├── vqgan_32_256.py │ │ └── vqgan_32_4.py │ │ ├── quantize.py │ │ └── vqgan.py └── scripts │ ├── free-port.py │ ├── mini_train.py │ ├── generate_inference_json.py │ └── sample-imgs-multi.py ├── data ├── mask_box_maxlens.png ├── colmap.sh ├── generate_colmap_labels.py └── generate_slam_labels.py ├── LICENSE ├── .gitignore └── environment.yaml /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .DistributedSaveableSampler import DistributedSaveableSampler 2 | -------------------------------------------------------------------------------- /data/mask_box_maxlens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ErinZhang1998/dmd_diffusion/HEAD/data/mask_box_maxlens.png -------------------------------------------------------------------------------- /src/utils/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Base_trainer import Base_trainer 2 | from .Score_sde_trainer import Score_sde_trainer 3 | -------------------------------------------------------------------------------- /src/models/score_sde/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /src/scripts/free-port.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import sys 3 | 4 | ip = sys.argv[1] 5 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 6 | s.bind((ip, 0)) 7 | addr = s.getsockname() 8 | print(addr[1],end='') 9 | s.close() 10 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.Attribute_dict import * 2 | globals = Attribute_dict({}) 3 | globals.config = Attribute_dict({}) 4 | 5 | from .tlog import tlog 6 | from .init_instance_dirs import * 7 | from .camera_tools import rel_camera_ray_encoding, abs_cameras_freq_encoding, freq_enc 8 | from .score_tools import Score_sde_model, Score_modifier, Score_sde_monocular_model 9 | -------------------------------------------------------------------------------- /src/utils/init_instance_dirs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import globals 3 | from os.path import join as pjoin 4 | 5 | 
def init_instance_dirs(required_dirs=[]): 6 | root = globals.instance_data_path 7 | os.makedirs(root, exist_ok=True) 8 | for sub_path in required_dirs: 9 | rel_path = pjoin(root,sub_path) 10 | if not os.path.isdir(rel_path): 11 | os.mkdir(rel_path) 12 | -------------------------------------------------------------------------------- /src/models/vqgan/configs/vqgan_32_256.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | vqgan_32_256_config = { 4 | 'ckpt_path': "../instance-data/vqgan-32-256.ckpt", 5 | 'embed_dim': 256, 6 | # 'n_embed': 16384, 7 | 'n_embed': 1024, 8 | 'ddconfig': { 9 | 'double_z': False, 10 | 'z_channels': 256, 11 | 'resolution': 128, 12 | 'in_channels': 3, 13 | 'out_ch': 3, 14 | 'ch': 128, 15 | 'ch_mult': [ 1,2,4 ], # num_down = len(ch_mult)-1 16 | 'num_res_blocks': 2, 17 | 'attn_resolutions': [ 16 ], 18 | 'dropout': 0.0, 19 | } 20 | } 21 | 22 | -------------------------------------------------------------------------------- /src/models/vqgan/configs/vqgan_32_4.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | vqgan_32_4_config = { 4 | 'ckpt_path': "/home/zhang401/Documents/bc_code/taming-32-4-realestate-256.ckpt", 5 | 'embed_dim': 4, 6 | # 'n_embed': 16384, 7 | 'n_embed': 16384, 8 | 'ddconfig': { 9 | 'double_z': False, 10 | 'z_channels': 4, 11 | 'resolution': 256, 12 | 'in_channels': 3, 13 | 'out_ch': 3, 14 | 'ch': 128, 15 | 'ch_mult': [ 1,2,2,4], # num_down = len(ch_mult)-1 16 | 'num_res_blocks': 2, 17 | 'attn_resolutions': [ 16 ], 18 | 'dropout': 0.0, 19 | } 20 | } 21 | 22 | -------------------------------------------------------------------------------- /src/utils/tlog.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | 3 | def tlog(s, mode='debug'): 4 | ''' 5 | prints a string with an appropriate header 6 | ''' 7 | if mode == 'note': 8 | header = colored('[Note]','red','on_cyan') 9 | elif mode == 'iter': 10 | header = colored(' ','grey','on_white') 11 | elif mode == 'debug': 12 | header = colored('[Debug]','grey','on_yellow') 13 | elif mode == 'error': 14 | header = colored('[Error]','cyan','on_red') 15 | else: 16 | header = colored('[Invalid print mode]','white','on_red') 17 | 18 | out = '{} {}'.format(header, s) 19 | print(out) 20 | -------------------------------------------------------------------------------- /src/models/score_sde/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /src/models/score_sde/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /data/colmap.sh: -------------------------------------------------------------------------------- 1 | # Set GPU. 2 | export CUDA_VISIBLE_DEVICES="0" 3 | 4 | export DATASET_PATH=PATH_TO_DATASET 5 | export LD_LIBRARY_PATH=/usr/local/cuda-11.2/lib64 6 | export colmap=PATH_TO_COLMAP_INSTALL/bin/colmap 7 | export mask=PATH_TO_GRABBER_MASK 8 | 9 | for filename in $DATASET_PATH/*; do 10 | echo $filename 11 | $colmap database_creator --database_path $filename/db.db 12 | $colmap feature_extractor --database_path $filename/db.db --image_path $filename/images --ImageReader.camera_mask_path $mask --ImageReader.single_camera 1 --SiftExtraction.use_gpu 1 13 | colmap exhaustive_matcher --database_path $filename/db.db --SiftMatching.use_gpu 1 14 | mkdir -p $filename/sparse 15 | $colmap mapper --database_path $filename/db.db --image_path $filename/images --output_path $filename/sparse 16 | for f in $filename/sparse/*; do 17 | $colmap model_converter --input_path $f --output_path $f --output_type TXT 18 | $colmap model_converter --input_path $f --output_path $f/export.ply --output_type PLY 19 | done 20 | done 21 | 22 | -------------------------------------------------------------------------------- /src/datasets/slam_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import collections 5 | 6 | def read_pose(file): 7 | data = json.load(open(file, "r")) 8 | dct = collections.defaultdict(list) 9 | frame_names = sorted(list(data.keys())) 10 | for frame in frame_names: 11 | pose = data[frame] 12 | pose = np.asarray(pose) 13 | orb_pose = orb_to_blender(pose) 14 | dct['imgs'].append("images/" + frame) 15 | dct['poses'].append(orb_pose) 16 | dct['poses_orig'].append(pose) 17 | return dict(dct) 18 | 19 | def orb_to_blender(camera_local): 20 | pre_conversion = np.array([ 21 | [1,0,0,0], 22 | [0,-1,0,0], 23 | [0,0,-1,0], 24 | [0,0,0,1], 25 | ]) 26 | conversion = np.array([ 27 | [1,0,0,0], 28 | [0,0,1,0], 29 | [0,-1,0,0], 30 | [0,0,0,1], 31 | ]) 32 | 33 | orb_world = np.matmul(camera_local,pre_conversion) 34 | blender_world = np.matmul(conversion,orb_world) 35 | 36 | return blender_world -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 
3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/models/score_sde/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); // added this to ensure we are on the same device as input 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /src/utils/Attribute_dict.py: -------------------------------------------------------------------------------- 1 | import re as reg 2 | 3 | class Attribute_dict: 4 | ''' 5 | recursively convert dictionary to an object where keys are attributes 6 | ''' 7 | def __init__(self,dict_in): 8 | for k in dict_in: 9 | v = dict_in[k] 10 | if type(v) is dict: 11 | self.__dict__[k] = Attribute_dict(v) 12 | else: 13 | self.__dict__[k] = v 14 | 15 | def update(self,dict_in): 16 | for k in dict_in: 17 | v = dict_in[k] 18 | self.__dict__[k] = Attribute_dict(v) if type(v) is dict else v 19 | 20 | def dict(self): 21 | out_dict = {} 22 | for k in self.__dict__: 23 | v = self.__dict__[k] 24 | out_dict[k] = v.dict() if type(v) is Attribute_dict else v 25 | return out_dict 26 | 27 | def __str__(self,indent='',node_indent=''): 28 | local_indent = ' | ' 29 | local_node_indent = ' |--' 30 | str_out = '' 31 | for k in self.__dict__: 32 | v = self.__dict__[k] 33 | if type(v) is Attribute_dict: 34 | str_out += node_indent + k+': \n' 
35 | contents_str = v.__str__(indent=indent+local_indent,node_indent=indent+local_node_indent) 36 | contents_str = contents_str 37 | if not contents_str == '': # don't print empty line for empty object 38 | str_out += contents_str+'\n' 39 | else: 40 | str_out += node_indent 41 | str_out += '{}: {}\n'.format(k,v) 42 | 43 | return str_out[:-1] 44 | -------------------------------------------------------------------------------- /src/datasets/umi_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | 7 | def umi_read_poses_from_folder(data_save_dir, focal_length): 8 | subfolders = sorted(list(os.listdir(data_save_dir))) 9 | all_poses = {} 10 | for i, subfolder in tqdm(enumerate(subfolders)): 11 | subfolder_path = os.path.join(data_save_dir, subfolder) 12 | 13 | file = os.path.join(subfolder_path, 'raw_labels.json') 14 | data = json.load(open(file, "r")) 15 | dct = collections.defaultdict(list) 16 | frame_names = sorted(list(data.keys())) 17 | for frame in frame_names: 18 | pose = data[frame] 19 | pose = np.asarray(pose) 20 | dct['imgs'].append("images/" + frame) 21 | dct['poses'].append(pose) 22 | 23 | gripper_state_json = os.path.join(subfolder_path, 'gripper_state.json') 24 | gripper_state_data = json.load(open(gripper_state_json, "r")) 25 | gripper_state = [] 26 | for img_path in dct['imgs']: 27 | img_pathkey = img_path.split("/")[-1] 28 | assert img_pathkey in gripper_state_data, f"{img_pathkey} not in {subfolder_path}/gripper_state.json" 29 | gripper_state.append(int(gripper_state_data[img_pathkey])) 30 | dct['gripper_state'] = gripper_state 31 | dct['sample_param'] = (1,15) 32 | dct['focal_length'] = float(focal_length) 33 | dct['has_gripper_state'] = True 34 | all_poses[subfolder_path] = dct 35 | 36 | return all_poses 37 | -------------------------------------------------------------------------------- /src/models/score_sde/configs/LDM.py: -------------------------------------------------------------------------------- 1 | 2 | class LDM_config: 3 | def __init__(self): 4 | self.model = lambda: None 5 | self.seed = lambda: None 6 | self.data = lambda: None 7 | self.eval = lambda: None 8 | self.sampling = lambda: None 9 | self.training = lambda: None 10 | 11 | self.training.continuous = True 12 | 13 | self.data.dataset = 'LSUN' 14 | self.data.image_size = 32 15 | self.data.random_flip = True 16 | self.data.uniform_dequantization = False 17 | self.data.centered = False 18 | self.data.num_channels = 4 19 | 20 | self.sampling.n_steps_each = 1 21 | self.sampling.noise_removal = True 22 | self.sampling.probability_flow = False 23 | self.sampling.snr = 0.075 24 | self.sampling.method = 'pc' 25 | self.sampling.predictor = 'reverse_diffusion' 26 | self.sampling.corrector = 'langevin' 27 | 28 | self.model.sigma_max = 378 29 | self.model.sigma_min = 0.01 30 | self.model.num_scales = 2000 31 | self.model.beta_min = 0.1 32 | self.model.beta_max = 20. 33 | self.model.dropout = 0. 
34 | self.model.embedding_type = 'fourier' 35 | self.model.name = 'ncsnpp' 36 | self.model.scale_by_sigma = True 37 | self.model.ema_rate = 0.999 38 | self.model.normalization = 'GroupNorm' 39 | self.model.nonlinearity = 'swish' 40 | self.model.nf = 128 41 | self.model.ch_mult = (2, 2, 2, 2) 42 | self.model.num_res_blocks = 2 43 | self.model.attn_resolutions = (32,16,8,4) 44 | self.model.resamp_with_conv = True 45 | self.model.conditional = True 46 | self.model.fir = True 47 | self.model.fir_kernel = [1, 3, 3, 1] 48 | self.model.skip_rescale = True 49 | self.model.resblock_type = 'biggan' 50 | self.model.progressive = 'output_skip' 51 | self.model.progressive_input = 'input_skip' 52 | self.model.progressive_combine = 'sum' 53 | self.model.attention_type = 'ddpm' 54 | self.model.init_scale = 0. 55 | self.model.fourier_scale = 16 56 | self.model.conv_size = 3 57 | 58 | -------------------------------------------------------------------------------- /src/models/score_sde/configs/conditional.py: -------------------------------------------------------------------------------- 1 | 2 | class Conditional_config: 3 | def __init__(self): 4 | self.model = lambda: None 5 | self.seed = lambda: None 6 | self.data = lambda: None 7 | self.eval = lambda: None 8 | self.sampling = lambda: None 9 | self.training = lambda: None 10 | 11 | self.training.continuous = True 12 | 13 | self.data.dataset = 'LSUN' 14 | self.data.image_size = 128 15 | self.data.random_flip = True 16 | self.data.uniform_dequantization = False 17 | self.data.centered = False 18 | self.data.num_channels = 3 19 | 20 | self.sampling.n_steps_each = 1 21 | self.sampling.noise_removal = True 22 | self.sampling.probability_flow = False 23 | self.sampling.snr = 0.075 24 | self.sampling.method = 'pc' 25 | self.sampling.predictor = 'reverse_diffusion' 26 | self.sampling.corrector = 'langevin' 27 | 28 | self.model.sigma_max = 378 29 | self.model.sigma_min = 0.01 30 | self.model.num_scales = 2000 31 | self.model.beta_min = 0.1 32 | self.model.beta_max = 20. 33 | self.model.dropout = 0. 34 | self.model.embedding_type = 'fourier' 35 | self.model.name = 'ncsnpp' 36 | self.model.scale_by_sigma = True 37 | self.model.ema_rate = 0.999 38 | self.model.normalization = 'GroupNorm' 39 | self.model.nonlinearity = 'swish' 40 | self.model.nf = 128 41 | # self.model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 42 | self.model.ch_mult = (1, 1, 2, 2, 2, 2) 43 | self.model.num_res_blocks = 2 44 | self.model.attn_resolutions = (16,) 45 | self.model.resamp_with_conv = True 46 | self.model.conditional = True 47 | self.model.fir = True 48 | self.model.fir_kernel = [1, 3, 3, 1] 49 | self.model.skip_rescale = True 50 | self.model.resblock_type = 'biggan' 51 | self.model.progressive = 'output_skip' 52 | self.model.progressive_input = 'input_skip' 53 | self.model.progressive_combine = 'sum' 54 | self.model.attention_type = 'ddpm' 55 | self.model.init_scale = 0. 
56 | self.model.fourier_scale = 16 57 | self.model.conv_size = 3 58 | 59 | -------------------------------------------------------------------------------- /src/utils/camera_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def rel_camera_ray_encoding(tform,im_size,focal): 5 | # create camera ray encoding 6 | # assumes square images with same focal length on all axes, assume principle = 0.5 7 | cam_center = tform[:3,-1] 8 | cam_rot = tform[:3,:3] 9 | 10 | # find max limit from focal length 11 | max_pos = 1/(2*focal) # our images are normalizedto -1,1 12 | 13 | # find rays for all pixels 14 | pix_size = 2/im_size 15 | x,y = np.meshgrid(range(im_size),range(im_size),indexing='xy') 16 | x = 2*x/(im_size-1) - 1 17 | y = 2*y/(im_size-1) - 1 18 | x *= max_pos # scale to match focal length 19 | y *= max_pos 20 | pix_grid = np.stack([x,-y],0) 21 | ray_grid = np.concatenate([pix_grid,-np.ones(shape=[1,im_size,im_size])],0) 22 | ray_grid_flat = ray_grid.reshape(3,-1) 23 | 24 | rays = np.matmul(cam_rot,ray_grid_flat) 25 | rays = rays/np.linalg.norm(rays,2,0) 26 | rays = rays.reshape(3,im_size,im_size) 27 | 28 | camera_center = np.tile(cam_center[:,None,None],(1,im_size,im_size)) 29 | camera_data = np.concatenate([rays,camera_center],0) 30 | 31 | camera_data = camera_data.astype(np.float32) 32 | return camera_data 33 | 34 | def abs_cameras_freq_encoding(pose_a,pose_b,focal_y): 35 | new_rel = np.matmul(np.linalg.inv(pose_a),pose_b) 36 | camera_data_a = rel_camera_ray_encoding(np.eye(4),128,focal_y) 37 | camera_data_b = rel_camera_ray_encoding(new_rel,128,focal_y) 38 | camera_data_a = torch.Tensor(camera_data_a).unsqueeze(0).cuda() 39 | camera_data_b = torch.Tensor(camera_data_b).unsqueeze(0).cuda() 40 | fourier_feats = torch.cat([freq_enc(camera_data_a),freq_enc(camera_data_b)],1) 41 | return fourier_feats 42 | 43 | def freq_enc(camera_data,n_frequencies=4,half_period=6): 44 | # encodeing does not repeat until [-half_period, half_period] 45 | n_in_channels = camera_data.shape[1] 46 | frequency_exponent = torch.arange(n_frequencies) 47 | frequency_multiplier = (2.0**frequency_exponent)/half_period 48 | frequency_multiplier = frequency_multiplier.tile(n_in_channels,1).T.reshape(-1)[None,:,None,None]*torch.pi 49 | camera_data_tiled = camera_data.tile(1,n_frequencies,1,1)*frequency_multiplier.to(camera_data.device) 50 | camera_data_sin = torch.sin(camera_data_tiled) 51 | camera_data_cos = torch.cos(camera_data_tiled) 52 | fourier_feats = torch.cat([camera_data,camera_data_sin,camera_data_cos],1) 53 | return fourier_feats 54 | -------------------------------------------------------------------------------- /src/models/score_sde/CrossAttnBlockpp.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | NIN = layers.NIN 7 | 8 | class CrossAttnBlockpp(nn.Module): 9 | """Channel-wise self-attention block. 
Modified from DDPM.""" 10 | 11 | def __init__(self, channels, cond_chans, n_heads=4, skip_rescale=False, init_scale=0.): 12 | super().__init__() 13 | self.n_heads = n_heads 14 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 15 | eps=1e-6) 16 | self.NIN_0 = NIN(channels+cond_chans, channels*n_heads) 17 | self.NIN_1 = NIN(channels+cond_chans, channels*n_heads) 18 | self.NIN_2 = NIN(channels+cond_chans, channels*n_heads) 19 | self.NIN_3 = NIN(channels*n_heads, channels, init_scale=init_scale) 20 | self.skip_rescale = skip_rescale 21 | 22 | def split(self,x): 23 | # splits views along batch dim to two tensors 24 | B, E, C, H, W = x.shape 25 | x_split = x.reshape(B//2,2,E,C,H,W) 26 | x_a = x_split[:,0,:,:,:,:] 27 | x_b = x_split[:,1,:,:,:,:] 28 | 29 | return x_a,x_b 30 | 31 | def stack_rays(self,q_cond,k_a_cond,k_b_cond): 32 | # stacks views along batch dim 33 | b,c,h,w = q_cond.shape 34 | q_stacked = torch.cat([q_cond[:,None,:,:,:],q_cond[:,None,:,:,:]],1).reshape(b*2,c,h,w) 35 | b,c,h,w = k_a_cond.shape 36 | k_stacked = torch.cat([k_a_cond[:,None,:,:,:],k_b_cond[:,None,:,:,:]],1).reshape(b*2,c,h,w) 37 | return q_stacked,k_stacked 38 | 39 | def forward(self, x, q_cond, k_a_cond, k_b_cond): 40 | B, C, H, W = x.shape 41 | q_cond_stacked, k_cond_stacked = self.stack_rays(q_cond,k_a_cond,k_b_cond) 42 | h = self.GroupNorm_0(x) 43 | q = self.NIN_0(torch.cat([h,q_cond_stacked],1)).reshape(B,self.n_heads,C,H,W) 44 | k = self.NIN_1(torch.cat([h,k_cond_stacked],1)).reshape(B,self.n_heads,C,H,W) 45 | v = self.NIN_2(torch.cat([h,k_cond_stacked],1)).reshape(B,self.n_heads,C,H,W) 46 | 47 | # split into two halves 48 | q_a,q_b = self.split(q) 49 | k_a,k_b = self.split(k) 50 | v_a,v_b = self.split(v) 51 | 52 | # cross for part a 53 | w_a = torch.einsum('bechw,becij->behwij', q_a, k_b) * (int(C) ** (-0.5)) 54 | w_a = torch.reshape(w_a, (B//2, self.n_heads, H, W, H * W)) 55 | w_a = F.softmax(w_a, dim=-1) 56 | w_a = torch.reshape(w_a, (B//2, self.n_heads, H, W, H, W)) 57 | h_a = torch.einsum('behwij,becij->bechw', w_a, v_b) 58 | 59 | # cross for part a 60 | w_b = torch.einsum('bechw,becij->behwij', q_b, k_a) * (int(C) ** (-0.5)) 61 | w_b = torch.reshape(w_b, (B//2, self.n_heads, H, W, H * W)) 62 | w_b = F.softmax(w_b, dim=-1) 63 | w_b = torch.reshape(w_b, (B//2, self.n_heads, H, W, H, W)) 64 | h_b = torch.einsum('behwij,becij->bechw', w_b, v_a) 65 | 66 | # recombine 67 | h = torch.cat([h_a[:,None,:,:,:,:],h_b[:,None,:,:,:,:]],1) 68 | h = h.reshape(B,self.n_heads*C,H,W) 69 | h = self.NIN_3(h) 70 | if not self.skip_rescale: 71 | return x + h 72 | else: 73 | return (x + h) / np.sqrt(2.) 
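# --- Editor's illustrative usage sketch; not part of the original file. The shapes and channel
# counts below are assumptions inferred from __init__/split/stack_rays/forward above: the two
# views of each pair occupy consecutive batch entries of x, while each conditioning map is
# supplied once per pair.
if __name__ == "__main__":
    n_pairs, channels, cond_chans, size = 3, 128, 54, 16        # hypothetical sizes
    attn = CrossAttnBlockpp(channels=channels, cond_chans=cond_chans, n_heads=4)
    x = torch.randn(2 * n_pairs, channels, size, size)          # views a/b interleaved along batch
    q_cond = torch.randn(n_pairs, cond_chans, size, size)       # query-ray conditioning, one per pair
    k_a_cond = torch.randn(n_pairs, cond_chans, size, size)     # key/value conditioning for view a
    k_b_cond = torch.randn(n_pairs, cond_chans, size, size)     # key/value conditioning for view b
    out = attn(x, q_cond, k_a_cond, k_b_cond)
    assert out.shape == x.shape                                 # residual attention preserves shape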
74 | -------------------------------------------------------------------------------- /src/models/score_sde/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /src/models/score_sde/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /src/models/vqgan/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import einsum 5 | 6 | 7 | class VectorQuantizer(nn.Module): 8 | """ 9 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 10 | ____________________________________________ 11 | Discretization bottleneck part of the VQ-VAE. 
12 | Inputs: 13 | - n_e : number of embeddings 14 | - e_dim : dimension of embedding 15 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 16 | _____________________________________________ 17 | """ 18 | 19 | def __init__(self, n_e, e_dim, beta): 20 | super(VectorQuantizer, self).__init__() 21 | self.n_e = n_e 22 | self.e_dim = e_dim 23 | self.beta = beta 24 | 25 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 27 | 28 | def forward(self, z): 29 | """ 30 | Inputs the output of the encoder network z and maps it to a discrete 31 | one-hot vector that is the index of the closest embedding vector e_j 32 | z (continuous) -> z_q (discrete) 33 | z.shape = (batch, channel, height, width) 34 | quantization pipeline: 35 | 1. get encoder input (B,C,H,W) 36 | 2. flatten input to (B*H*W,C) 37 | """ 38 | # reshape z -> (batch, height, width, channel) and flatten 39 | z = z.permute(0, 2, 3, 1).contiguous() 40 | z_flattened = z.view(-1, self.e_dim) 41 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 42 | 43 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 44 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 45 | torch.matmul(z_flattened, self.embedding.weight.t()) 46 | 47 | # find closest encodings 48 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) 49 | 50 | min_encodings = torch.zeros( 51 | min_encoding_indices.shape[0], self.n_e).to(z) 52 | min_encodings.scatter_(1, min_encoding_indices, 1) 53 | 54 | # get quantized latent vectors 55 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) 56 | 57 | # compute loss for embedding 58 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 59 | torch.mean((z_q - z.detach()) ** 2) 60 | 61 | # preserve gradients 62 | z_q = z + (z_q - z).detach() 63 | 64 | # perplexity 65 | e_mean = torch.mean(min_encodings, dim=0) 66 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) 67 | 68 | # reshape back to match original input shape 69 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 70 | 71 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 72 | 73 | def get_codebook_entry(self, indices, shape): 74 | # shape specifying (batch, height, width, channel) 75 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) 76 | min_encodings.scatter_(1, indices[:,None], 1) 77 | 78 | # get quantized latent vectors 79 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight) 80 | 81 | if shape is not None: 82 | z_q = z_q.view(shape) 83 | 84 | # reshape back to match original input shape 85 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 86 | 87 | return z_q 88 | -------------------------------------------------------------------------------- /src/datasets/umi_pos_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.transform as st 3 | 4 | def pos_rot_to_mat(pos, rot): 5 | shape = pos.shape[:-1] 6 | mat = np.zeros(shape + (4,4), dtype=pos.dtype) 7 | mat[...,:3,3] = pos 8 | mat[...,:3,:3] = rot.as_matrix() 9 | mat[...,3,3] = 1 10 | return mat 11 | 12 | def mat_to_pos_rot(mat): 13 | pos = (mat[...,:3,3].T / mat[...,3,3].T).T 14 | rot = st.Rotation.from_matrix(mat[...,:3,:3]) 15 | return pos, rot 16 | 17 | def pos_rot_to_pose(pos, rot): 18 | shape = pos.shape[:-1] 19 | pose = np.zeros(shape+(6,), dtype=pos.dtype) 20 | pose[...,:3] = pos 21 | pose[...,3:] = rot.as_rotvec() 22 | return pose 
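# --- Editor's illustrative sketch; not part of the original file. `_pose_roundtrip_demo` is a
# hypothetical helper showing the conventions used in this module: a 6D pose is
# [x, y, z, rx, ry, rz] (position followed by an axis-angle rotation vector), convertible to a
# 4x4 homogeneous matrix and back with the functions defined above.
def _pose_roundtrip_demo():
    pos = np.array([0.1, 0.0, 0.5])                      # arbitrary example position
    rot = st.Rotation.from_euler('z', 90, degrees=True)  # arbitrary example rotation
    mat = pos_rot_to_mat(pos, rot)                       # 4x4 homogeneous transform
    pose = pos_rot_to_pose(pos, rot)                     # 6D [position, rotation-vector] encoding
    pos2, rot2 = mat_to_pos_rot(mat)                     # recovers (pos, rot) up to numerical error
    return mat, pose, pos2, rot2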
23 | 24 | def pose_to_pos_rot(pose): 25 | pos = pose[...,:3] 26 | rot = st.Rotation.from_rotvec(pose[...,3:]) 27 | return pos, rot 28 | 29 | def pose_to_mat(pose): 30 | return pos_rot_to_mat(*pose_to_pos_rot(pose)) 31 | 32 | def mat_to_pose(mat): 33 | return pos_rot_to_pose(*mat_to_pos_rot(mat)) 34 | 35 | def transform_pose(tx, pose): 36 | """ 37 | tx: tx_new_old 38 | pose: tx_old_obj 39 | result: tx_new_obj 40 | """ 41 | pose_mat = pose_to_mat(pose) 42 | tf_pose_mat = tx @ pose_mat 43 | tf_pose = mat_to_pose(tf_pose_mat) 44 | return tf_pose 45 | 46 | def transform_point(tx, point): 47 | return point @ tx[:3,:3].T + tx[:3,3] 48 | 49 | def project_point(k, point): 50 | x = point @ k.T 51 | uv = x[...,:2] / x[...,[2]] 52 | return uv 53 | 54 | def apply_delta_pose(pose, delta_pose): 55 | new_pose = np.zeros_like(pose) 56 | 57 | # simple add for position 58 | new_pose[:3] = pose[:3] + delta_pose[:3] 59 | 60 | # matrix multiplication for rotation 61 | rot = st.Rotation.from_rotvec(pose[3:]) 62 | drot = st.Rotation.from_rotvec(delta_pose[3:]) 63 | new_pose[3:] = (drot * rot).as_rotvec() 64 | 65 | return new_pose 66 | 67 | def normalize(vec, tol=1e-7): 68 | return vec / np.maximum(np.linalg.norm(vec), tol) 69 | 70 | def rot_from_directions(from_vec, to_vec): 71 | from_vec = normalize(from_vec) 72 | to_vec = normalize(to_vec) 73 | axis = np.cross(from_vec, to_vec) 74 | axis = normalize(axis) 75 | angle = np.arccos(np.dot(from_vec, to_vec)) 76 | rotvec = axis * angle 77 | rot = st.Rotation.from_rotvec(rotvec) 78 | return rot 79 | 80 | def normalize(vec, eps=1e-12): 81 | norm = np.linalg.norm(vec, axis=-1) 82 | norm = np.maximum(norm, eps) 83 | out = (vec.T / norm).T 84 | return out 85 | 86 | def rot6d_to_mat(d6): 87 | a1, a2 = d6[..., :3], d6[..., 3:] 88 | b1 = normalize(a1) 89 | b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1 90 | b2 = normalize(b2) 91 | b3 = np.cross(b1, b2, axis=-1) 92 | out = np.stack((b1, b2, b3), axis=-2) 93 | return out 94 | 95 | def mat_to_rot6d(mat): 96 | batch_dim = mat.shape[:-2] 97 | out = mat[..., :2, :].copy().reshape(batch_dim + (6,)) 98 | return out 99 | 100 | def mat_to_pose10d(mat): 101 | pos = mat[...,:3,3] 102 | rotmat = mat[...,:3,:3] 103 | d6 = mat_to_rot6d(rotmat) 104 | d10 = np.concatenate([pos, d6], axis=-1) 105 | return d10 106 | 107 | def pose10d_to_mat(d10): 108 | pos = d10[...,:3] 109 | d6 = d10[...,3:] 110 | rotmat = rot6d_to_mat(d6) 111 | out = np.zeros(d10.shape[:-1]+(4,4), dtype=d10.dtype) 112 | out[...,:3,:3] = rotmat 113 | out[...,:3,3] = pos 114 | out[...,3,3] = 1 115 | return out 116 | -------------------------------------------------------------------------------- /src/datasets/DistributedSaveableSampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.utils.data.distributed import DistributedSampler 4 | 5 | class DistributedSaveableSampler(DistributedSampler): 6 | """Just like with the case with 7 | torch.utils.data.distributed.DistributedSampler you *MUST* call 8 | self.set_epoch(epoch:int) to ensure all replicates use the same 9 | random shuffling within each epoch if shuffle is True 10 | """ 11 | 12 | def __init__(self, *args, force_synchronization=False, **kwargs): 13 | """ 14 | Arguments: 15 | force_synchronization (boolean, optional): If it's true then after 16 | each yield we will force a synchronization so each process' 17 | _curr_idx will be the same, this guarantees correctness of the 18 | save in case there is no 
synchronization during training, but 19 | comes at a performance cost 20 | For the rest of the arguments please see: 21 | https://pytorch.org/docs/1.7.1/data.html?highlight=distributed%20sampler#torch.utils.data.distributed.DistributedSampler 22 | """ 23 | super().__init__(*args, **kwargs) 24 | self._curr_idx = 0 25 | self.force_synchronization = force_synchronization 26 | 27 | def __iter__(self): 28 | """Logic modified from 29 | https://pytorch.org/docs/1.7.1/_modules/torch/utils/data/distributed.html#DistributedSampler 30 | """ 31 | if self.shuffle: 32 | # deterministically shuffle based on epoch and seed 33 | g = torch.Generator() 34 | g.manual_seed(self.seed + self.epoch) 35 | indices = torch.randperm(len(self.dataset), 36 | generator=g).tolist() # type: ignore 37 | else: 38 | indices = list(range(len(self.dataset))) # type: ignore 39 | 40 | if not self.drop_last: 41 | # add extra samples to make it evenly divisible 42 | indices += indices[:(self.total_size - len(indices))] 43 | else: 44 | # remove tail of data to make it evenly divisible. 45 | indices = indices[:self.total_size] 46 | assert len(indices) == self.total_size 47 | 48 | while self._curr_idx + self.rank < self.total_size: 49 | to_yield = self.rank + self._curr_idx 50 | 51 | # we need to increment this before the yield because 52 | # there might be a save or preemption while we are yielding 53 | # so we must increment it before to save the right index 54 | self._curr_idx += self.num_replicas 55 | 56 | yield indices[to_yield] 57 | 58 | if self.force_synchronization: 59 | dist.barrier() 60 | self._curr_idx = 0 61 | 62 | def state_dict(self, dataloader_iter=None): 63 | prefetched_num = 0 64 | # in the case of multiworker dataloader, the helper worker could be 65 | # pre-fetching the data that is not consumed by the main dataloader. 66 | # we need to subtract the unconsumed part . 67 | if dataloader_iter is not None: 68 | if dataloader_iter._num_workers > 0: 69 | batch_size = dataloader_iter._index_sampler.batch_size 70 | prefetched_num = ( 71 | (dataloader_iter._send_idx - dataloader_iter._rcvd_idx) * 72 | batch_size) 73 | 74 | return { 75 | "index": self._curr_idx - (prefetched_num * self.num_replicas), 76 | "epoch": self.epoch, 77 | } 78 | 79 | def load_state_dict(self, state_dict): 80 | self._curr_idx = state_dict["index"] 81 | self.epoch = state_dict["epoch"] 82 | -------------------------------------------------------------------------------- /src/models/score_sde/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 
23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 47 | one_minus_decay = 1.0 - decay 48 | with torch.no_grad(): 49 | parameters = [p for p in parameters if p.requires_grad] 50 | for s_param, param in zip(self.shadow_params, parameters): 51 | s_param.sub_(one_minus_decay * (s_param - param)) 52 | 53 | def copy_to(self, parameters): 54 | """ 55 | Copy current parameters into given collection of parameters. 56 | 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | updated with the stored moving averages. 60 | """ 61 | parameters = [p for p in parameters if p.requires_grad] 62 | for s_param, param in zip(self.shadow_params, parameters): 63 | if param.requires_grad: 64 | param.data.copy_(s_param.data) 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 69 | 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | 91 | def state_dict(self): 92 | return dict(decay=self.decay, num_updates=self.num_updates, 93 | shadow_params=self.shadow_params) 94 | 95 | def load_state_dict(self, state_dict): 96 | self.decay = state_dict['decay'] 97 | self.num_updates = state_dict['num_updates'] 98 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /src/datasets/write_translations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | from PIL import Image 6 | 7 | def get_colmap_labels(folder, file='/sparse/0/images.txt', failed_frames = [], focal_length=None, data_type="vime", interval=1): 8 | """ 9 | Saves relative translations and rotations to labels.json 10 | :param f: Current folder 11 | :param file: File to retrieve poses from 12 | :return: 13 | """ 14 | file = folder + file 15 | if not os.path.exists(file): 16 | print("File " + file + " doesn't exist.") 17 | return None 18 | lines = [] 19 | with open(file, 'r') as fi: 20 | # First few header lines of images.txt 21 | for i in range(4): 22 | fi.readline() 23 | while True: 24 | line = fi.readline() 25 | if not line: 26 | break 27 | if ("png" in line) or ("jpg" in line): 28 | lines.append(line) 29 | 30 | lines.sort(key=lambda x: x.split(" ")[-1].strip()) 31 | 32 | dct = {} 33 | dct['poses'] = [] 34 | dct['dependencies'] = [None] 35 | dct['generation_order'] = [] 36 | dct['imgs'] = [] 37 | 38 | all_displacements = [] 39 | for l in range(0, len(lines)-1): 40 | curr = lines[l] 41 | next_line = lines[l+1] 42 | cid, cQW, cQX, cQY, cQZ, cTX, cTY, cTZ, _, cname = curr.split() 43 | nid, nQW, nQX, nQY, nQZ, nTX, nTY, nTZ, _, nname = next_line.split() 44 | if data_type == "vime": 45 | try: 46 | start_frame = int(cname.split(".")[0].split("frame")[-1]) 47 | end_frame = int(nname.split(".")[0].split("frame")[-1]) 48 | except: 49 | print("Format incorrect! {folder} {cname} {nname}") 50 | continue 51 | elif data_type == "apple": 52 | try: 53 | start_frame = int(cname.split(".")[0]) 54 | end_frame = int(nname.split(".")[0]) 55 | except: 56 | print("Format incorrect! 
{folder} {cname} {nname}") 57 | continue 58 | frame_in_between = end_frame - start_frame 59 | assert frame_in_between != 0 60 | 61 | q1 = np.array([cQX, cQY, cQZ, cQW]) 62 | q2 = np.array([nQX, nQY, nQZ, nQW]) 63 | 64 | rot1 = R.from_quat(q1) 65 | rot2 = R.from_quat(q2) 66 | 67 | t1 = np.array([float(cTX), float(cTY), float(cTZ)]) 68 | t2 = np.array([float(nTX), float(nTY), float(nTZ)]) 69 | 70 | t = t1 - rot1.as_matrix() @ (rot2.inv().as_matrix() @ t2) 71 | trans = (np.linalg.norm(t) / frame_in_between) * interval 72 | all_displacements.append(trans) 73 | 74 | max_displace = np.max(all_displacements) 75 | if max_displace == 0: 76 | return None 77 | dct['scale_factor'] = float(max_displace) 78 | count = 0 79 | height = None 80 | for l in range(len(lines)): 81 | curr = lines[l] 82 | cid, cQW, cQX, cQY, cQZ, cTX, cTY, cTZ, _, cname = curr.split() 83 | if len(failed_frames) > 0 and cname in failed_frames: 84 | continue 85 | 86 | if l == 0: 87 | height = Image.open(os.path.join(folder, "images", cname)).size[1] 88 | if data_type == "vime" and height != 1440: 89 | return None 90 | if data_type == "apple" and height != 1080: 91 | return None 92 | q1 = np.array([cQX, cQY, cQZ, cQW]) 93 | rot1 = R.from_quat(q1) 94 | 95 | t1 = np.array([float(cTX), float(cTY), float(cTZ)]) 96 | t1 = np.array([t1]) 97 | 98 | end_row = np.array([0.0, 0.0, 0.0, 1.0]) 99 | end_row = np.array([end_row]) 100 | 101 | orb_t1 = orb_to_blender(np.append(np.append(rot1.as_matrix(), t1.T, 1), end_row, 0)) 102 | # dct[cname] = (list(t), r) 103 | dct['imgs'].append("images/" + cname) 104 | dct['poses'].append(orb_t1.tolist()) 105 | dct['dependencies'].append([0]) 106 | count = count + 1 107 | dct['generation_order'].append(count) 108 | 109 | dct['focal_y'] = float(focal_length)/float(height) 110 | if len(dct['imgs']) == 0: 111 | return None 112 | return dct 113 | 114 | def orb_to_blender(orb_t): 115 | pre_conversion = np.array([ 116 | [1,0,0,0], 117 | [0,-1,0,0], 118 | [0,0,-1,0], 119 | [0,0,0,1], 120 | ]) 121 | conversion = np.array([ 122 | [1,0,0,0], 123 | [0,0,1,0], 124 | [0,-1,0,0], 125 | [0,0,0,1], 126 | ]) 127 | 128 | camera_local = np.linalg.inv(orb_t) 129 | orb_world = np.matmul(camera_local,pre_conversion) 130 | blender_world = np.matmul(conversion,orb_world) 131 | 132 | return blender_world -------------------------------------------------------------------------------- /src/scripts/mini_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join as pjoin 4 | import torch.multiprocessing as mp # don't chdir before we load this 5 | import argparse 6 | 7 | # add parent dir to path 8 | script_path = os.path.dirname(os.path.realpath(sys.argv[0])) 9 | src_path = os.path.abspath(os.path.join(script_path,'..')) 10 | sys.path.append(src_path) 11 | os.chdir(src_path) 12 | 13 | # args 14 | argParser = argparse.ArgumentParser(description="start/resume training") 15 | argParser.add_argument("-r","--rank",dest="rank",action="store",default=0,type=int) 16 | argParser.add_argument("--num_gpus_per_node",dest="n_gpus_per_node",action="store",default=4,type=int) 17 | argParser.add_argument("--num_nodes",dest="n_nodes",action="store",default=1,type=int) 18 | 19 | argParser.add_argument("--instance_data_path",type=str, help="path to save checkpoints and logs") 20 | argParser.add_argument("--batch_size",type=int) 21 | argParser.add_argument("--colmap_data_folders", type=str, nargs="+", default=[]) 22 | argParser.add_argument("--focal_lengths", 
type=float, nargs="+", default=[]) 23 | argParser.add_argument('--task', type=str, choices=['push', 'stack', 'pour', 'hang', 'umi']) 24 | argParser.add_argument('--sfm_method', type=str, choices=['colmap', 'grabber_orbslam', 'umi']) 25 | argParser.add_argument("--max_epoch",type=int, default=10000) 26 | 27 | argParser.add_argument("--n_kept_checkpoints",type=int, default=5) 28 | argParser.add_argument("--checkpoint_interval",type=int, default=500) 29 | argParser.add_argument("--checkpoint_retention_interval",type=int, default=200) 30 | 31 | 32 | cli_args = argParser.parse_args() 33 | 34 | from utils import * 35 | 36 | globals.instance_data_path = cli_args.instance_data_path 37 | globals.ckpt_path = os.path.join(cli_args.instance_data_path, "checkpoints") 38 | globals.batch_size = cli_args.batch_size 39 | globals.colmap_data_folders = cli_args.colmap_data_folders 40 | globals.focal_lengths = cli_args.focal_lengths 41 | globals.task = cli_args.task 42 | globals.sfm_method = cli_args.sfm_method 43 | globals.n_kept_checkpoints = cli_args.n_kept_checkpoints 44 | globals.checkpoint_interval = cli_args.checkpoint_interval 45 | globals.max_epoch = cli_args.max_epoch 46 | globals.checkpoint_retention_interval = cli_args.checkpoint_retention_interval 47 | 48 | import torch.distributed as dist 49 | import utils.trainers 50 | import socket 51 | import signal 52 | import psutil 53 | 54 | def worker(local_rank,node_rank,n_gpus_per_node,n_nodes): 55 | # set multiprocessing to fork for dataloader workers (might avoid shared memory issues) 56 | try: 57 | mp.set_start_method('spawn') # fork 58 | except: 59 | pass 60 | 61 | # initialize distributed training 62 | rank = node_rank*n_gpus_per_node + local_rank 63 | world_size = n_nodes*n_gpus_per_node 64 | 65 | dist.init_process_group( 66 | backend='nccl', 67 | init_method='env://', 68 | world_size=world_size, 69 | rank=rank 70 | ) 71 | 72 | trainer = utils.trainers.Score_sde_trainer(local_rank,node_rank,n_gpus_per_node,n_nodes) 73 | 74 | # load checkpoint if it exists 75 | ckpt_dir = globals.ckpt_path 76 | checkpoints = os.listdir(ckpt_dir) 77 | checkpoints = [x for x in checkpoints if x.endswith('.pth')] 78 | checkpoints.sort() 79 | if len(checkpoints) > 0: 80 | latest_checkpoint = pjoin(ckpt_dir,checkpoints[-1]) 81 | if rank == 0: tlog(f'Resuming from {latest_checkpoint}','note') 82 | trainer.load_checkpoint(latest_checkpoint) 83 | 84 | # attach signal handler for graceful termination 85 | def termination_handler(sig_num,frame): 86 | trainer.termination_requested = True 87 | signal.signal(signal.SIGTERM,termination_handler) 88 | 89 | # do training 90 | try: 91 | trainer.train() 92 | except Exception as e: 93 | # print which worker failed 94 | tlog(f'Exception on {socket.gethostname()}, rank {rank}: {e}','error') 95 | raise e 96 | 97 | if __name__ == '__main__': 98 | if cli_args.rank < 0: 99 | node_rank = int(os.environ['SLURM_PROCID']) 100 | else: 101 | node_rank = cli_args.rank 102 | tlog(f'{socket.gethostname()}: {node_rank}') 103 | if node_rank == 0: 104 | init_instance_dirs(['logs','checkpoints']) 105 | 106 | def main_exit_handler(sig_num,frame): 107 | parent = psutil.Process(os.getpid()) 108 | children = parent.children() 109 | for child in children: 110 | child.send_signal(sig_num) 111 | tlog(f'SIGTERM received, waiting to children to finish','note') 112 | signal.signal(signal.SIGTERM,main_exit_handler) 113 | 114 | n_gpus_per_node = cli_args.n_gpus_per_node 115 | n_nodes = cli_args.n_nodes 116 | # worker(0, node_rank, n_gpus_per_node, n_nodes) 117 | 
mp.spawn(worker,nprocs=n_gpus_per_node,args=(node_rank,n_gpus_per_node,n_nodes)) 118 | -------------------------------------------------------------------------------- /src/models/vqgan/vqgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from .diff import Encoder, Decoder 7 | from .quantize import VectorQuantizer 8 | 9 | 10 | class VQModel(nn.Module): 11 | def __init__(self, 12 | ddconfig, 13 | # lossconfig, 14 | n_embed, 15 | embed_dim, 16 | ckpt_path=None, 17 | ignore_keys=[], 18 | image_key="image", 19 | colorize_nlabels=None, 20 | monitor=None, 21 | batch_resize_range=None, 22 | scheduler_config=None, 23 | lr_g_factor=1.0, 24 | ): 25 | super().__init__() 26 | self.image_key = image_key 27 | self.encoder = Encoder(**ddconfig) 28 | self.decoder = Decoder(**ddconfig) 29 | # self.loss = instantiate_from_config(lossconfig) 30 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25) 31 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 32 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 33 | if colorize_nlabels is not None: 34 | assert type(colorize_nlabels)==int 35 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 36 | if monitor is not None: 37 | self.monitor = monitor 38 | self.batch_resize_range = batch_resize_range 39 | if self.batch_resize_range is not None: 40 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") 41 | if ckpt_path is not None: 42 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 43 | self.scheduler_config = scheduler_config 44 | self.lr_g_factor = lr_g_factor 45 | 46 | def init_from_ckpt(self, path, ignore_keys=list()): 47 | sd = torch.load(path, map_location="cpu")["state_dict"] 48 | keys = list(sd.keys()) 49 | for k in keys: 50 | for ik in ignore_keys: 51 | if k.startswith(ik): 52 | print("Deleting key {} from state_dict.".format(k)) 53 | del sd[k] 54 | self.load_state_dict(sd, strict=False) 55 | print(f"VQGAN Weights Restored from {path}") 56 | 57 | def encode(self, x): 58 | h = self.encoder(x) 59 | h = self.quant_conv(h) 60 | return h 61 | 62 | def decode(self, h): 63 | quant, emb_loss, info = self.quantize(h) 64 | quant = self.post_quant_conv(quant) 65 | dec = self.decoder(quant) 66 | return dec 67 | 68 | def decode_code(self, code_b): 69 | quant_b = self.quantize.embed_code(code_b) 70 | dec = self.decode(quant_b) 71 | return dec 72 | 73 | def forward(self, input, return_pred_indices=False): 74 | quant, diff, (_,_,ind) = self.encode(input) 75 | dec = self.decode(quant) 76 | if return_pred_indices: 77 | return dec, diff, ind 78 | return dec, diff 79 | 80 | def get_input(self, batch, k): 81 | x = batch[k] 82 | if len(x.shape) == 3: 83 | x = x[..., None] 84 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 85 | if self.batch_resize_range is not None: 86 | lower_size = self.batch_resize_range[0] 87 | upper_size = self.batch_resize_range[1] 88 | if self.global_step <= 4: 89 | # do the first few batches with max size to avoid later oom 90 | new_resize = upper_size 91 | else: 92 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) 93 | if new_resize != x.shape[2]: 94 | x = F.interpolate(x, size=new_resize, mode="bicubic") 95 | x = x.detach() 96 | return x 97 | 98 | def validation_step(self, batch, batch_idx): 99 | x = self.get_input(batch, 
self.image_key) 100 | xrec, qloss = self(x) 101 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, 102 | last_layer=self.get_last_layer(), split="val") 103 | 104 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, 105 | last_layer=self.get_last_layer(), split="val") 106 | rec_loss = log_dict_ae["val/rec_loss"] 107 | self.log("val/rec_loss", rec_loss, 108 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 109 | self.log("val/aeloss", aeloss, 110 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 111 | self.log_dict(log_dict_ae) 112 | self.log_dict(log_dict_disc) 113 | return self.log_dict 114 | 115 | def get_last_layer(self): 116 | return self.decoder.conv_out.weight 117 | 118 | -------------------------------------------------------------------------------- /src/utils/trainers/Base_trainer.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from os.path import join as pjoin 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | import os 6 | import time 7 | 8 | class Base_trainer: 9 | def __init__(self,compute_only=True): 10 | self.epoch = 0 # future: rename to epochs_completed 11 | self.total_iterations = 0 12 | self.max_epoch = 200 13 | self.n_kept_checkpoints = 0 14 | self.compute_only = compute_only 15 | self.last_checkpoint_time = time.time() 16 | self.checkpoint_interval = 0 17 | self.checkpoint_retention_interval = 0 18 | self.termination_requested = False 19 | self.print_iteration = 200 20 | 21 | if not self.compute_only: 22 | log_dir = os.path.join(globals.instance_data_path,f'logs/{time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())}') 23 | self.tb_writer = SummaryWriter(log_dir=log_dir) 24 | 25 | def train(self): 26 | for epochs_completed in range(self.epoch, self.max_epoch): 27 | self.train_epoch() 28 | self.epoch += 1 # epoch done, +1 for correct saving 29 | 30 | # validate 31 | self.validate() 32 | 33 | # checkpoint 34 | if not self.compute_only: 35 | self.tlog(f'Epoch {self.epoch} complete, checkpointing','note') 36 | self.save_checkpoint(f'{self.epoch:04d}-{0:08d}.pth') 37 | else: # load checkpoint to sync params 38 | expected_checkpoint = os.path.join(globals.ckpt_path, f"{self.epoch:04d}-{0:08d}.pth") 39 | #f'../instance-data/checkpoints/{self.epoch:04d}-{0:08d}.pth' 40 | while not os.path.exists(expected_checkpoint): 41 | time.sleep(5) 42 | self.load_checkpoint(expected_checkpoint) 43 | 44 | def train_epoch(self): 45 | raise NotImplementedError 46 | 47 | def validate(self): 48 | raise NotImplementedError 49 | 50 | def state_dict(self): 51 | raise NotImplementedError 52 | 53 | def load_state_dict(self,state_dict): 54 | raise NotImplementedError 55 | 56 | def tlog(self,text,mode='debug'): 57 | if not self.compute_only: 58 | tlog(text,mode) 59 | 60 | def check_termination_request(self,epoch_it,dataloader_iter=None): 61 | if self.termination_requested: 62 | if self.compute_only: exit() 63 | self.tlog('Termination requested, checkpointing','note') 64 | self.save_checkpoint(f'{self.epoch:04d}-{epoch_it:08d}.pth',dataloader_iter) 65 | exit() 66 | 67 | def maybe_save_checkpoint(self,epoch_it,dataloader_iter=None): 68 | # saves checkpoint if haven't saved in more than checkpoint_interval 69 | if self.compute_only: return 70 | if self.checkpoint_interval == 0: return 71 | if (time.time() - self.last_checkpoint_time) < self.checkpoint_interval: return 72 | self.tlog(f'Haven\'t checkpointed in over 
{self.checkpoint_interval} seconds, checkpointing','note')
73 |         self.save_checkpoint(f'{self.epoch:04d}-{epoch_it:08d}.pth',dataloader_iter)
74 | 
75 |     def save_checkpoint(self,checkpoint_filename,dataloader_iter=None):
76 |         checkpoint_root = globals.ckpt_path #'../instance-data/checkpoints'
77 |         # save model, atomic, saves as .part, then swaps when done
78 |         tmp_out_path = pjoin(checkpoint_root,checkpoint_filename+'.part')
79 |         out_path = pjoin(checkpoint_root,checkpoint_filename)
80 |         checkpoint = self.state_dict(dataloader_iter)
81 |         torch.save(checkpoint,tmp_out_path)
82 |         os.replace(tmp_out_path,out_path)
83 |         self.tlog(f'Checkpoint saved to: {out_path}','note')
84 | 
85 |         self.last_checkpoint_time = time.time()
86 | 
87 |         # don't clear checkpoints if n_kept_checkpoints == 0
88 |         if self.n_kept_checkpoints == 0: return
89 | 
90 |         # delete excess checkpoints
91 |         checkpoints = os.listdir(checkpoint_root)
92 |         checkpoints = [x for x in checkpoints if x.endswith('.pth')]
93 |         if self.checkpoint_retention_interval > 0: # exclude checkpoints matching the retention criteria from deletion
94 |             checkpoints_new = []
95 |             for x in checkpoints:
96 |                 xp1,xp2 = x.split(".")[0].split("-")
97 |                 xp1_int = int(xp1)
98 |                 if xp1_int % self.checkpoint_retention_interval > 0:
99 |                     checkpoints_new.append(x)
100 |                 elif int(xp2) > 0:  # mid-epoch checkpoints stay eligible for deletion
101 |                     checkpoints_new.append(x)
102 |             checkpoints = checkpoints_new
103 |             # checkpoints = [x for x in checkpoints if int(x[:4]) % self.checkpoint_retention_interval > 0 or int(x[5:-4]) > 0]
104 |         checkpoints.sort()
105 |         if len(checkpoints) > self.n_kept_checkpoints:
106 |             n_extra = len(checkpoints) - self.n_kept_checkpoints
107 |             for i in range(n_extra):
108 |                 os.remove(pjoin(checkpoint_root,checkpoints[i]))
109 |                 self.tlog(f'Cleared old checkpoint: {checkpoints[i]}','note')
110 | 
111 |     def load_checkpoint(self,checkpoint_path,map_location=None):
112 |         self.tlog(f'Checkpoint loaded from: {checkpoint_path}','note')
113 |         checkpoint = torch.load(checkpoint_path,map_location=map_location)
114 | 
115 |         self.load_state_dict(checkpoint)
116 | 
--------------------------------------------------------------------------------
/data/generate_colmap_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 | from autolab_core import RigidTransform
6 | import copy
7 | import argparse
8 | 
9 | def get_colmap_labels(folder, file='sparse/0/images.txt', interval = 1, data_type = "apple"):
10 |     image_txt_file = os.path.join(folder, file)
11 |     if not os.path.exists(image_txt_file):
12 |         if not os.path.exists(image_txt_file):
13 |             print(f"ERROR!!!! 
File {image_txt_file} doesn't exist.")
14 |         return
15 | 
16 |     # Jan-16 data are collected at a slightly higher frame rate
17 |     if "Jan-16-2024" in folder and interval > 1:
18 |         real_interval = int(np.ceil(interval * 1.5))
19 |     else:
20 |         real_interval = interval
21 | 
22 |     lines = []
23 |     with open(image_txt_file, 'r') as fi:
24 |         # First few header lines of images.txt
25 |         for i in range(4):
26 |             fi.readline()
27 |         while True:
28 |             line = fi.readline()
29 | 
30 |             if not line:
31 |                 break
32 |             if ("png" in line) or ("jpg" in line):
33 |                 # print(line.split(" ")[-1].strip())
34 |                 lines.append(line)
35 | 
36 |     lines.sort(key=lambda x: x.split(" ")[-1].strip())
37 |     num_lines = len(lines)
38 | 
39 |     # Sometimes the COLMAP output is wrong: it flips the left and right sides of the world.
40 |     # The only way to detect this is to check the trajectories manually.
41 |     # We add a dummy folder named "FLIP" under the trajectory folder to indicate this.
42 |     flip_traj = os.path.exists(os.path.join(folder, "FLIP"))
43 |     if flip_traj:
44 |         print("Flip: ", folder)
45 | 
46 |     trans_dict = {}
47 |     T_acc = RigidTransform(rotation=np.eye(3),
48 |                            translation=np.zeros(3),
49 |                            from_frame="cam_0",
50 |                            to_frame=f"cam_0")
51 |     trans_dict_next = {}
52 |     for l in range(0, num_lines-1):
53 |         curr = lines[l]
54 |         next_line = lines[l+1]
55 |         cid, cQW, cQX, cQY, cQZ, cTX, cTY, cTZ, _, cname = curr.split()
56 |         nid, nQW, nQX, nQY, nQZ, nTX, nTY, nTZ, _, nname = next_line.split()
57 | 
58 |         if data_type == "vime":
59 |             try:
60 |                 start_frame = int(cname.split(".")[0].split("frame")[-1])
61 |                 end_frame = int(nname.split(".")[0].split("frame")[-1])
62 |             except:
63 |                 print(f"Format incorrect! {folder} {cname} {nname}")
64 |                 continue
65 |         elif data_type == "apple":
66 |             try:
67 |                 start_frame = int(cname.split(".")[0])
68 |                 end_frame = int(nname.split(".")[0])
69 |             except:
70 |                 print(f"Format incorrect! Folder={folder} current frame={cname} next frame={nname}! 
Should be something like xxxxxx.jpg") 71 | continue 72 | 73 | if l == 0: 74 | trans_dict[0] = (start_frame, cname, copy.deepcopy(T_acc)) 75 | 76 | q1 = np.array([cQX, cQY, cQZ, cQW]) 77 | q2 = np.array([nQX, nQY, nQZ, nQW]) 78 | 79 | rot1 = R.from_quat(q1) 80 | rot2 = R.from_quat(q2) 81 | 82 | t1 = np.array([float(cTX), float(cTY), float(cTZ)]) 83 | t2 = np.array([float(nTX), float(nTY), float(nTZ)]) 84 | 85 | T_world_caml = RigidTransform(rotation=rot1.as_matrix(), 86 | translation=t1, 87 | from_frame="world", 88 | to_frame=f"cam_{l}") 89 | T_world_caml2 = RigidTransform(rotation=rot2.as_matrix(), 90 | translation=t2, 91 | from_frame="world", 92 | to_frame=f"cam_{l+1}") 93 | T_caml2_caml = T_world_caml * T_world_caml2.inverse() 94 | if flip_traj: 95 | T_caml2_caml.translation *= np.asarray([-1,-1,1]) 96 | trans_dict_next[l] = (start_frame, end_frame, cname, nname, T_caml2_caml) 97 | # for example, (0, 1, 00000.jpg, 00001.jpg, Transformation between the two cameras) 98 | 99 | for l in range(0, num_lines-1): 100 | start_frame, end_frame, cname, nname, T_caml2_caml = trans_dict_next[l] 101 | 102 | T_caml_cam0 = copy.deepcopy(T_acc * T_caml2_caml) 103 | T_acc = T_acc * T_caml2_caml 104 | 105 | trans_dict[l+1] = (end_frame, nname, T_caml_cam0) 106 | # (3, 00003.jpg, Transformation from cam_0 to cam_3) 107 | 108 | # T_cami_cami+k = T_cam0_cami+k * T_cami_cam0 109 | # Calculate transformation from cam_i to cam_i+k depending on the interval 110 | dct = {} 111 | all_displacements = [] 112 | for l in range(0, num_lines-real_interval): 113 | framel, frame_namel, T_caml_cam0 = trans_dict[l] 114 | framelk, frame_namelk, T_camlk_cam0 = trans_dict[l+real_interval] 115 | 116 | if framelk - framel > real_interval * 2: 117 | print(f"Error {folder} {frame_namel} {frame_namelk}") 118 | continue 119 | 120 | img_full_path = os.path.join(folder, "images", cname) 121 | if not os.path.exists(img_full_path): 122 | continue 123 | 124 | T_camlk_caml = T_caml_cam0.inverse() * T_camlk_cam0 125 | T_camlk_caml_t = T_camlk_caml.translation.astype(float) 126 | T_camlk_caml_r = T_camlk_caml.rotation 127 | T_camlk_caml_r = T_camlk_caml_r.tolist() 128 | 129 | dct[frame_namel] = (list(T_camlk_caml_t), T_camlk_caml_r, frame_namelk) 130 | all_displacements.append(np.linalg.norm(T_camlk_caml_t)) 131 | 132 | max_displace = np.max(all_displacements) 133 | 134 | if interval == 1: 135 | json_save_path = os.path.join(folder, f'labels.json') 136 | else: 137 | json_save_path = os.path.join(folder, f'labels_{interval}.json') 138 | if os.path.exists(json_save_path): 139 | os.remove(json_save_path) 140 | with open(json_save_path, 'w+') as fp: 141 | json.dump(dct, fp, indent=4, sort_keys=True) 142 | return dct, max_displace 143 | 144 | def main(): 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('--dir', required=True, help="directory to process") 147 | parser.add_argument('--recursive', action='store_true', help="if true, recursively process all subfolders") 148 | parser.add_argument('--interval', type=int, default=1) 149 | parser.add_argument('--txt', default='sparse/0/images.txt') 150 | args = parser.parse_args() 151 | 152 | if not args.recursive: 153 | get_colmap_labels(args.dir, args.txt, interval = args.interval) 154 | else: 155 | folders = sorted(os.listdir(args.dir)) 156 | for f in folders: 157 | if not f.startswith(".") and not (".txt" in f) and not (".sh" in f): 158 | get_colmap_labels(os.path.join(args.dir, f), interval = args.interval) 159 | 160 | 161 | if __name__ == "__main__": 162 | main() 163 | 
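A minimal sketch of how the labels.json (or labels_<interval>.json) written by get_colmap_labels() above can be read back; each entry maps a frame name to its relative translation, 3x3 rotation, and the name of the frame one interval later. The helper name load_colmap_labels is hypothetical and not part of this repository; a typical invocation of the script itself, per the argparse flags above, is: python generate_colmap_labels.py --dir <trajectory_root> --recursive --interval 2

import json
import numpy as np

def load_colmap_labels(path):
    # labels.json maps frame_name -> [translation xyz, 3x3 rotation (nested lists), next frame name],
    # exactly as written by json.dump(dct, ...) in get_colmap_labels() above.
    with open(path) as fp:
        raw = json.load(fp)
    labels = {}
    for frame_name, (translation, rotation, next_frame_name) in raw.items():
        labels[frame_name] = (np.asarray(translation, dtype=float),  # (3,)
                              np.asarray(rotation, dtype=float),     # (3, 3)
                              next_frame_name)
    return labels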
-------------------------------------------------------------------------------- /src/models/score_sde/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | from . import sde_lib 21 | import numpy as np 22 | 23 | 24 | _MODELS = {} 25 | 26 | 27 | def register_model(cls=None, *, name=None): 28 | """A decorator for registering model classes.""" 29 | 30 | def _register(cls): 31 | if name is None: 32 | local_name = cls.__name__ 33 | else: 34 | local_name = name 35 | if local_name in _MODELS: 36 | raise ValueError(f'Already registered model with name: {local_name}') 37 | _MODELS[local_name] = cls 38 | return cls 39 | 40 | if cls is None: 41 | return _register 42 | else: 43 | return _register(cls) 44 | 45 | 46 | def get_model(name): 47 | print(_MODELS) 48 | return _MODELS[name] 49 | 50 | 51 | def get_sigmas(config): 52 | """Get sigmas --- the set of noise levels for SMLD from config files. 53 | Args: 54 | config: A ConfigDict object parsed from the config file 55 | Returns: 56 | sigmas: a jax numpy arrary of noise levels 57 | """ 58 | sigmas = np.exp( 59 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) 60 | 61 | return sigmas 62 | 63 | 64 | def get_ddpm_params(config): 65 | """Get betas and alphas --- parameters used in the original DDPM paper.""" 66 | num_diffusion_timesteps = 1000 67 | # parameters need to be adapted if number of time steps differs from 1000 68 | beta_start = config.model.beta_min / config.model.num_scales 69 | beta_end = config.model.beta_max / config.model.num_scales 70 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 71 | 72 | alphas = 1. - betas 73 | alphas_cumprod = np.cumprod(alphas, axis=0) 74 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 75 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod) 76 | 77 | return { 78 | 'betas': betas, 79 | 'alphas': alphas, 80 | 'alphas_cumprod': alphas_cumprod, 81 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 82 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 83 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 84 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 85 | 'num_diffusion_timesteps': num_diffusion_timesteps 86 | } 87 | 88 | 89 | def create_model(config): 90 | """Create the score model.""" 91 | model_name = config.model.name 92 | score_model = get_model(model_name)(config) 93 | score_model = score_model.to(config.device) 94 | score_model = torch.nn.DataParallel(score_model) 95 | return score_model 96 | 97 | 98 | def get_model_fn(model, train=False): 99 | """Create a function to give the output of the score-based model. 100 | 101 | Args: 102 | model: The score model. 
103 | train: `True` for training and `False` for evaluation. 104 | 105 | Returns: 106 | A model function. 107 | """ 108 | 109 | def model_fn(x, cond_im, labels, std, conditioning): 110 | """Compute the output of the score-based model. 111 | 112 | Args: 113 | x: A mini-batch of input data. 114 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 115 | for different models. 116 | 117 | Returns: 118 | A tuple of (model output, new mutable states) 119 | """ 120 | if not train: 121 | model.eval() 122 | return model(x, cond_im, labels, std, conditioning) 123 | else: 124 | model.train() 125 | return model(x, cond_im, labels, std, conditioning) 126 | 127 | return model_fn 128 | 129 | 130 | def get_score_fn(sde, model, train=False, continuous=False): 131 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 132 | 133 | Args: 134 | sde: An `sde_lib.SDE` object that represents the forward SDE. 135 | model: A score model. 136 | train: `True` for training and `False` for evaluation. 137 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 138 | 139 | Returns: 140 | A score function. 141 | """ 142 | model_fn = get_model_fn(model, train=train) 143 | 144 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 145 | def score_fn(x, t): 146 | # Scale neural network output by standard deviation and flip sign 147 | if continuous or isinstance(sde, sde_lib.subVPSDE): 148 | # For VP-trained models, t=0 corresponds to the lowest noise level 149 | # The maximum value of time embedding is assumed to 999 for 150 | # continuously-trained models. 151 | labels = t * 999 152 | score = model_fn(x, labels) 153 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 154 | else: 155 | # For VP-trained models, t=0 corresponds to the lowest noise level 156 | labels = t * (sde.N - 1) 157 | score = model_fn(x, labels) 158 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 159 | 160 | score = -score / std[:, None, None, None] 161 | return score 162 | 163 | elif isinstance(sde, sde_lib.VESDE): 164 | def score_fn(x, cond_im, t, std, conditioning): 165 | if continuous: 166 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 167 | else: 168 | # For VE-trained models, t=0 corresponds to the highest noise level 169 | labels = sde.T - t 170 | labels *= sde.N - 1 171 | labels = torch.round(labels).long() 172 | 173 | score = model_fn(x, cond_im, labels, std, conditioning) 174 | return score 175 | 176 | else: 177 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 178 | 179 | return score_fn 180 | 181 | 182 | def to_flattened_numpy(x): 183 | """Flatten a torch tensor `x` and convert it to numpy.""" 184 | return x.detach().cpu().numpy().reshape((-1,)) 185 | 186 | 187 | def from_flattened_numpy(x, shape): 188 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 189 | return torch.from_numpy(x.reshape(shape)) 190 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,osx,python,c,c++,cuda 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,osx,python,c,c++,cuda 3 | 4 | ### C ### 5 | # Prerequisites 6 | *.d 7 | 8 | # Object files 9 | *.o 10 | *.ko 11 | *.obj 12 | *.elf 13 | 14 | # Linker output 
15 | *.ilk 16 | *.map 17 | *.exp 18 | 19 | # Precompiled Headers 20 | *.gch 21 | *.pch 22 | 23 | # Libraries 24 | *.lib 25 | *.a 26 | *.la 27 | *.lo 28 | 29 | # Shared objects (inc. Windows DLLs) 30 | *.dll 31 | *.so 32 | *.so.* 33 | *.dylib 34 | 35 | # Executables 36 | *.exe 37 | *.out 38 | *.app 39 | *.i*86 40 | *.x86_64 41 | *.hex 42 | 43 | # Debug files 44 | *.dSYM/ 45 | *.su 46 | *.idb 47 | *.pdb 48 | 49 | # Kernel Module Compile Results 50 | *.mod* 51 | *.cmd 52 | .tmp_versions/ 53 | modules.order 54 | Module.symvers 55 | Mkfile.old 56 | dkms.conf 57 | 58 | ### C++ ### 59 | # Prerequisites 60 | 61 | # Compiled Object files 62 | *.slo 63 | 64 | # Precompiled Headers 65 | 66 | # Compiled Dynamic libraries 67 | 68 | # Fortran module files 69 | *.mod 70 | *.smod 71 | 72 | # Compiled Static libraries 73 | *.lai 74 | 75 | # Executables 76 | 77 | ### CUDA ### 78 | *.i 79 | *.ii 80 | *.gpu 81 | *.ptx 82 | *.cubin 83 | *.fatbin 84 | 85 | ### Linux ### 86 | *~ 87 | 88 | # temporary files which can be created if a process still has a handle open of a deleted file 89 | .fuse_hidden* 90 | 91 | # KDE directory preferences 92 | .directory 93 | 94 | # Linux trash folder which might appear on any partition or disk 95 | .Trash-* 96 | 97 | # .nfs files are created when an open file is removed but is still being accessed 98 | .nfs* 99 | 100 | ### OSX ### 101 | # General 102 | .DS_Store 103 | .AppleDouble 104 | .LSOverride 105 | 106 | # Icon must end with two \r 107 | Icon 108 | 109 | 110 | # Thumbnails 111 | ._* 112 | 113 | # Files that might appear in the root of a volume 114 | .DocumentRevisions-V100 115 | .fseventsd 116 | .Spotlight-V100 117 | .TemporaryItems 118 | .Trashes 119 | .VolumeIcon.icns 120 | .com.apple.timemachine.donotpresent 121 | 122 | # Directories potentially created on remote AFP share 123 | .AppleDB 124 | .AppleDesktop 125 | Network Trash Folder 126 | Temporary Items 127 | .apdisk 128 | 129 | ### Python ### 130 | # Byte-compiled / optimized / DLL files 131 | __pycache__/ 132 | *.py[cod] 133 | *$py.class 134 | 135 | # C extensions 136 | 137 | # Distribution / packaging 138 | .Python 139 | build/ 140 | develop-eggs/ 141 | dist/ 142 | downloads/ 143 | eggs/ 144 | .eggs/ 145 | lib/ 146 | lib64/ 147 | parts/ 148 | sdist/ 149 | var/ 150 | wheels/ 151 | share/python-wheels/ 152 | *.egg-info/ 153 | .installed.cfg 154 | *.egg 155 | MANIFEST 156 | 157 | # PyInstaller 158 | # Usually these files are written by a python script from a template 159 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
160 | *.manifest 161 | *.spec 162 | 163 | # Installer logs 164 | pip-log.txt 165 | pip-delete-this-directory.txt 166 | 167 | # Unit test / coverage reports 168 | htmlcov/ 169 | .tox/ 170 | .nox/ 171 | .coverage 172 | .coverage.* 173 | .cache 174 | nosetests.xml 175 | coverage.xml 176 | *.cover 177 | *.py,cover 178 | .hypothesis/ 179 | .pytest_cache/ 180 | cover/ 181 | 182 | # Translations 183 | *.mo 184 | *.pot 185 | 186 | # Django stuff: 187 | *.log 188 | local_settings.py 189 | db.sqlite3 190 | db.sqlite3-journal 191 | 192 | # Flask stuff: 193 | instance/ 194 | .webassets-cache 195 | 196 | # Scrapy stuff: 197 | .scrapy 198 | 199 | # Sphinx documentation 200 | docs/_build/ 201 | 202 | # PyBuilder 203 | .pybuilder/ 204 | target/ 205 | 206 | # Jupyter Notebook 207 | .ipynb_checkpoints 208 | 209 | # IPython 210 | profile_default/ 211 | ipython_config.py 212 | 213 | # pyenv 214 | # For a library or package, you might want to ignore these files since the code is 215 | # intended to run in multiple environments; otherwise, check them in: 216 | # .python-version 217 | 218 | # pipenv 219 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 220 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 221 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 222 | # install all needed dependencies. 223 | #Pipfile.lock 224 | 225 | # poetry 226 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 227 | # This is especially recommended for binary packages to ensure reproducibility, and is more 228 | # commonly ignored for libraries. 229 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 230 | #poetry.lock 231 | 232 | # pdm 233 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 234 | #pdm.lock 235 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 236 | # in version control. 237 | # https://pdm.fming.dev/#use-with-ide 238 | .pdm.toml 239 | 240 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 241 | __pypackages__/ 242 | 243 | # Celery stuff 244 | celerybeat-schedule 245 | celerybeat.pid 246 | 247 | # SageMath parsed files 248 | *.sage.py 249 | 250 | # Environments 251 | .env 252 | .venv 253 | env/ 254 | venv/ 255 | ENV/ 256 | env.bak/ 257 | venv.bak/ 258 | 259 | # Spyder project settings 260 | .spyderproject 261 | .spyproject 262 | 263 | # Rope project settings 264 | .ropeproject 265 | 266 | # mkdocs documentation 267 | /site 268 | 269 | # mypy 270 | .mypy_cache/ 271 | .dmypy.json 272 | dmypy.json 273 | 274 | # Pyre type checker 275 | .pyre/ 276 | 277 | # pytype static type analyzer 278 | .pytype/ 279 | 280 | # Cython debug symbols 281 | cython_debug/ 282 | 283 | ### Vim ### 284 | # Swap 285 | [._]*.s[a-v][a-z] 286 | !*.svg # comment out if you don't need vector files 287 | [._]*.sw[a-p] 288 | [._]s[a-rt-v][a-z] 289 | [._]ss[a-gi-z] 290 | [._]sw[a-p] 291 | 292 | # Session 293 | Session.vim 294 | Sessionx.vim 295 | 296 | # Temporary 297 | .netrwhist 298 | *~ 299 | # Auto-generated tag files 300 | tags 301 | # Persistent undo 302 | [._]*.un~ 303 | 304 | # PyCharm 305 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 306 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 307 | # and can be added to the global gitignore or merged into this file. For a more nuclear 308 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 309 | #.idea/ 310 | 311 | # End of https://www.toptal.com/developers/gitignore/api/linux,osx,python,c,c++,cuda 312 | 313 | 314 | ### Project specific ### 315 | dataset-data 316 | instance-data 317 | vision 318 | -------------------------------------------------------------------------------- /src/models/score_sde/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | 
gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | 
out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /src/models/score_sde/ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """DDPM model. 18 | 19 | This code is the pytorch equivalent of: 20 | https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py 21 | """ 22 | import torch 23 | import torch.nn as nn 24 | import functools 25 | 26 | from . import utils, layers, normalization 27 | 28 | RefineBlock = layers.RefineBlock 29 | ResidualBlock = layers.ResidualBlock 30 | ResnetBlockDDPM = layers.ResnetBlockDDPM 31 | Upsample = layers.Upsample 32 | Downsample = layers.Downsample 33 | conv3x3 = layers.ddpm_conv3x3 34 | get_act = layers.get_act 35 | get_normalization = normalization.get_normalization 36 | default_initializer = layers.default_init 37 | 38 | 39 | @utils.register_model(name='ddpm') 40 | class DDPM(nn.Module): 41 | def __init__(self, config): 42 | super().__init__() 43 | self.act = act = get_act(config) 44 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) 45 | 46 | self.nf = nf = config.model.nf 47 | ch_mult = config.model.ch_mult 48 | self.num_res_blocks = num_res_blocks = config.model.num_res_blocks 49 | self.attn_resolutions = attn_resolutions = config.model.attn_resolutions 50 | dropout = config.model.dropout 51 | resamp_with_conv = config.model.resamp_with_conv 52 | self.num_resolutions = num_resolutions = len(ch_mult) 53 | self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] 54 | 55 | AttnBlock = functools.partial(layers.AttnBlock) 56 | self.conditional = conditional = config.model.conditional 57 | ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) 58 | if conditional: 59 | # Condition on noise levels. 
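# The two Linear layers assembled just below form the timestep-embedding MLP: forward() computes a
# sinusoidal embedding of width nf via layers.get_timestep_embedding(labels, self.nf) and maps it
# nf -> 4*nf -> 4*nf; the resulting temb is what every ResnetBlock above consumes (temb_dim=4 * nf).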
60 | modules = [nn.Linear(nf, nf * 4)] 61 | modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) 62 | nn.init.zeros_(modules[0].bias) 63 | modules.append(nn.Linear(nf * 4, nf * 4)) 64 | modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) 65 | nn.init.zeros_(modules[1].bias) 66 | 67 | self.centered = config.data.centered 68 | channels = config.data.num_channels 69 | 70 | # Downsampling block 71 | modules.append(conv3x3(channels, nf)) 72 | hs_c = [nf] 73 | in_ch = nf 74 | for i_level in range(num_resolutions): 75 | # Residual blocks for this resolution 76 | for i_block in range(num_res_blocks): 77 | out_ch = nf * ch_mult[i_level] 78 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 79 | in_ch = out_ch 80 | if all_resolutions[i_level] in attn_resolutions: 81 | modules.append(AttnBlock(channels=in_ch)) 82 | hs_c.append(in_ch) 83 | if i_level != num_resolutions - 1: 84 | modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) 85 | hs_c.append(in_ch) 86 | 87 | in_ch = hs_c[-1] 88 | modules.append(ResnetBlock(in_ch=in_ch)) 89 | modules.append(AttnBlock(channels=in_ch)) 90 | modules.append(ResnetBlock(in_ch=in_ch)) 91 | 92 | # Upsampling block 93 | for i_level in reversed(range(num_resolutions)): 94 | for i_block in range(num_res_blocks + 1): 95 | out_ch = nf * ch_mult[i_level] 96 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) 97 | in_ch = out_ch 98 | if all_resolutions[i_level] in attn_resolutions: 99 | modules.append(AttnBlock(channels=in_ch)) 100 | if i_level != 0: 101 | modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) 102 | 103 | assert not hs_c 104 | modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) 105 | modules.append(conv3x3(in_ch, channels, init_scale=0.)) 106 | self.all_modules = nn.ModuleList(modules) 107 | 108 | self.scale_by_sigma = config.model.scale_by_sigma 109 | 110 | def forward(self, x, labels): 111 | modules = self.all_modules 112 | m_idx = 0 113 | if self.conditional: 114 | # timestep/scale embedding 115 | timesteps = labels 116 | temb = layers.get_timestep_embedding(timesteps, self.nf) 117 | temb = modules[m_idx](temb) 118 | m_idx += 1 119 | temb = modules[m_idx](self.act(temb)) 120 | m_idx += 1 121 | else: 122 | temb = None 123 | 124 | if self.centered: 125 | # Input is in [-1, 1] 126 | h = x 127 | else: 128 | # Input is in [0, 1] 129 | h = 2 * x - 1. 
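# (i.e. non-centered inputs in [0, 1] are rescaled to [-1, 1] here, so both branches feed the
# network data in the same range)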
130 | 131 | # Downsampling block 132 | hs = [modules[m_idx](h)] 133 | m_idx += 1 134 | for i_level in range(self.num_resolutions): 135 | # Residual blocks for this resolution 136 | for i_block in range(self.num_res_blocks): 137 | h = modules[m_idx](hs[-1], temb) 138 | m_idx += 1 139 | if h.shape[-1] in self.attn_resolutions: 140 | h = modules[m_idx](h) 141 | m_idx += 1 142 | hs.append(h) 143 | if i_level != self.num_resolutions - 1: 144 | hs.append(modules[m_idx](hs[-1])) 145 | m_idx += 1 146 | 147 | h = hs[-1] 148 | h = modules[m_idx](h, temb) 149 | m_idx += 1 150 | h = modules[m_idx](h) 151 | m_idx += 1 152 | h = modules[m_idx](h, temb) 153 | m_idx += 1 154 | 155 | # Upsampling block 156 | for i_level in reversed(range(self.num_resolutions)): 157 | for i_block in range(self.num_res_blocks + 1): 158 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 159 | m_idx += 1 160 | if h.shape[-1] in self.attn_resolutions: 161 | h = modules[m_idx](h) 162 | m_idx += 1 163 | if i_level != 0: 164 | h = modules[m_idx](h) 165 | m_idx += 1 166 | 167 | assert not hs 168 | h = self.act(modules[m_idx](h)) 169 | m_idx += 1 170 | h = modules[m_idx](h) 171 | m_idx += 1 172 | assert m_idx == len(modules) 173 | 174 | if self.scale_by_sigma: 175 | # Divide the output by sigmas. Useful for training with the NCSN loss. 176 | # The DDPM loss scales the network output by sigma in the loss function, 177 | # so no need of doing it here. 178 | used_sigmas = self.sigmas[labels, None, None, None] 179 | h = h / used_sigmas 180 | 181 | return h 182 | -------------------------------------------------------------------------------- /src/scripts/generate_inference_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import pathlib 5 | import json 6 | import numpy as np 7 | from scipy.spatial.transform import Rotation as scir 8 | 9 | script_path = os.path.dirname(os.path.abspath(__file__)) 10 | src_path = os.path.abspath(os.path.join(script_path,'..')) 11 | sys.path.append(os.path.join(src_path, "datasets")) 12 | from write_translations import get_colmap_labels 13 | from slam_utils import read_pose 14 | 15 | def main(args): 16 | if args.suffix is not None: 17 | output_root = args.output_root + f"_{args.suffix}" 18 | else: 19 | output_root = args.output_root 20 | # create a folder to store all the generated images 21 | pathlib.Path(output_root+"/images").mkdir(parents=True, exist_ok=True) 22 | 23 | data = {} 24 | for pf, focal_length, img_h in zip(args.data_folders, args.focal_lengths, args.image_heights): 25 | for folder_name in os.listdir(pf): 26 | folder_path = os.path.join(pf, folder_name) 27 | if not os.path.isdir(folder_path): 28 | continue 29 | if args.sfm_method == "colmap": 30 | if os.path.exists(os.path.join(folder_path, "FAIL")): 31 | continue 32 | dct = get_colmap_labels(folder_path, focal_length=focal_length, data_type="apple") 33 | elif args.sfm_method == "grabber_orbslam": 34 | json_path = os.path.join(folder_path, "raw_labels.json") 35 | if not os.path.exists(json_path): 36 | print(f"Skipping {folder_path}, no raw_labels.json") 37 | continue 38 | dct = read_pose(json_path) 39 | dct["focal_y"] = float(focal_length/img_h) 40 | data[folder_path] = dct 41 | 42 | pre_conversion = np.array([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,1]]) 43 | 44 | if args.task == "stack": 45 | x_euler_range = [-10,10] 46 | y_euler_range = [-10,10] 47 | z_euler_range = [-10,15] 48 | magnitude_l, magnitude_u = 0.02,0.04 49 | elif args.task == 
"push": 50 | # NOTE: pushing data uses COLMAP, thus it doesn't recover real-world metric scale 51 | magnitude_l, magnitude_u = 0.2, 1.0 52 | args.sample_rotation = False 53 | elif args.task == "pour": 54 | x_euler_range = None 55 | y_euler_range = None 56 | z_euler_range = [-10,10] 57 | magnitude_l, magnitude_u = 0.02,0.04 58 | elif args.task == "hang": 59 | x_euler_range = None 60 | y_euler_range = None 61 | z_euler_range = [-10,10] 62 | magnitude_l, magnitude_u = 0.02,0.04 63 | else: 64 | raise NotImplementedError 65 | 66 | generation_data = [] 67 | folders = sorted(list(data.keys())) 68 | image_index = 0 69 | for folder in folders: 70 | folder_data = data[folder] 71 | num_imgs = len(folder_data['imgs']) 72 | 73 | gripper_state_file = os.path.join(folder, "gripper_state.json") 74 | if os.path.exists(gripper_state_file): 75 | gripper_state = json.load(open(gripper_state_file, 'r')) 76 | else: 77 | gripper_state = None 78 | 79 | for frame_idx in range(0, num_imgs, args.every_x_frame): 80 | full_img_path = os.path.join(folder, folder_data['imgs'][frame_idx]) 81 | 82 | # Skip frames where the grabber is opening/closing 83 | if gripper_state is not None: 84 | img_key = full_img_path.split("/")[-1] 85 | if img_key in gripper_state: 86 | if gripper_state[img_key] < 0: 87 | continue 88 | 89 | orig_img_copy_path = os.path.join(output_root, 'images', "%09d_o.png" % image_index) 90 | 91 | for i in range (args.mult): 92 | output_img_path = os.path.join(output_root, "images", "%09d.png" % image_index) 93 | 94 | direction = np.random.randn(3) 95 | direction /= np.linalg.norm(direction) 96 | # Sample a random magnitude between 0 and 5 97 | magnitude = np.random.uniform(magnitude_l,magnitude_u) 98 | # Create a translation vector with the random direction and magnitude 99 | translation = direction * magnitude 100 | transformation = np.eye(4) 101 | transformation[:3, 3] = translation 102 | 103 | if args.sample_rotation: 104 | if x_euler_range is not None: 105 | x_euler = np.random.uniform(*x_euler_range) 106 | else: 107 | x_euler = 0 108 | if y_euler_range is not None: 109 | 110 | y_euler = np.random.uniform(*y_euler_range) 111 | else: 112 | y_euler = 0 113 | if z_euler_range is not None: 114 | z_euler = np.random.uniform(*z_euler_range) 115 | else: 116 | z_euler = 0 117 | rot_euler = np.array([x_euler, y_euler, z_euler]) 118 | rot = scir.from_euler('xyz', rot_euler, degrees=True).as_matrix() 119 | transformation[:3,:3] = rot 120 | transformation = np.matmul(transformation, pre_conversion) 121 | transformation = np.matmul(np.linalg.inv(pre_conversion),transformation) 122 | 123 | this_data = { 124 | "img": full_img_path, 125 | "orig_img_copy_path" : orig_img_copy_path, 126 | "focal_y": folder_data["focal_y"], 127 | "output": output_img_path, 128 | "magnitude" : magnitude, 129 | "transformation": transformation.tolist() 130 | } 131 | if args.sample_rotation: 132 | this_data["rot_euler"] = rot_euler.astype(float).tolist() 133 | generation_data.append(this_data) 134 | image_index += 1 135 | 136 | print("Total number of images:", len(generation_data)) 137 | out_file = os.path.join(output_root, 'data.json') 138 | if os.path.exists(out_file): 139 | print(f"{out_file} file exists") 140 | else: 141 | with open(out_file, 'w') as f: json.dump(generation_data, f) 142 | print("Written to: ", out_file) 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser(description='') 146 | parser.add_argument('--task',default='push',choices=['push','stack','pour', 'hang'],type=str) 147 | 
parser.add_argument('--sfm_method', type=str, choices=['colmap', 'grabber_orbslam']) 148 | parser.add_argument("--data_folders", type=str, nargs="+", default=[]) 149 | parser.add_argument("--focal_lengths", type=float, nargs="+", default=[]) 150 | parser.add_argument("--image_heights", type=float, nargs="+", default=[]) 151 | parser.add_argument('--output_root', type=str, help="output folder") 152 | parser.add_argument('--suffix', default=None, help="suffix to add to output_root, if you want to generate multiple versions") 153 | parser.add_argument('--sample_rotation',action='store_true') 154 | parser.add_argument('--every_x_frame', type=int, default=10, help="Generate augmenting images of every every_x_frame-th frame. If the demonstrations are recorded at high frame-rate (e.g. above 5fps), nearby frames are very similar, and there are too many frames in each trajectory, so it is not necessary to generate augmenting samples for every frame.") 155 | parser.add_argument('--mult',type=int,default=3,help='Number of augmenting images to generate per input frame') 156 | 157 | args = parser.parse_args() 158 | 159 | main(args) -------------------------------------------------------------------------------- /src/utils/score_tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import pudb 4 | 5 | class Score_sde_model(torch.nn.Module): 6 | def __init__(self,score_model,sde,ray_downsampler,rays_require_downsample=True,rays_as_list=False): 7 | super().__init__() 8 | self.score_model = score_model 9 | self.sde = sde 10 | self.rsde = sde.reverse(self.score,probability_flow=False) 11 | 12 | self.ray_downsampler = ray_downsampler 13 | self.rays_require_downsample = rays_require_downsample 14 | self.rays_as_list = rays_as_list 15 | 16 | def score(self,x,cond_im,t, ff_ref, ff_a, ff_b): 17 | if self.rays_as_list: 18 | d_ff_ref = [self.ray_downsampler(rays) for rays in ff_ref] if self.rays_require_downsample else ff_ref 19 | d_ff_a = [self.ray_downsampler(rays) for rays in ff_a] if self.rays_require_downsample else ff_a 20 | d_ff_b = [self.ray_downsampler(rays) for rays in ff_b] if self.rays_require_downsample else ff_b 21 | else: 22 | d_ff_ref = self.ray_downsampler(ff_ref) if self.rays_require_downsample else ff_ref 23 | d_ff_a = self.ray_downsampler(ff_a) if self.rays_require_downsample else ff_a 24 | d_ff_b = self.ray_downsampler(ff_b) if self.rays_require_downsample else ff_b 25 | _, std = self.sde.marginal_prob(torch.zeros_like(x), t) 26 | cond_std = torch.ones_like(std)*0.01 # assume conditioning image has minimal noise 27 | score_a, score_b = self.score_model(x, cond_im, std, cond_std, d_ff_ref, d_ff_a, d_ff_b) # ignore second score 28 | return score_a 29 | 30 | def forward_diffusion(self,x,t): 31 | z = torch.randn_like(x) 32 | mean, std = self.sde.marginal_prob(x, t) 33 | perturbed_data = mean + std[:, None, None, None] * z 34 | return perturbed_data, z, std 35 | 36 | def t_uniform(self,batch_size,device=None,eps=1e-5): 37 | # eps prevents sampling exactly 0 38 | t = torch.rand(batch_size, device=device) * (self.sde.T - eps) + eps 39 | return t 40 | 41 | def reverse_diffusion_predictor(self, x, cond_im, t, ff_ref, ff_a, ff_b): 42 | f, G = self.rsde.discretize(x, cond_im, t, ff_ref, ff_a, ff_b) 43 | z = torch.randn_like(x) 44 | x_mean = x - f 45 | x = x_mean + G[:, None, None, None] * z 46 | return x, x_mean 47 | 48 | def langevin_corrector(self, x, cond_im, t, ff_ref, ff_a, ff_b): 49 | sde = self.sde 50 | n_steps = 1 51 | 
target_snr = 0.075 52 | 53 | # specific to VESDE 54 | alpha = torch.ones_like(t) 55 | 56 | for i in range(n_steps): 57 | grad = self.score(x, cond_im, t, ff_ref, ff_a, ff_b) 58 | noise = torch.randn_like(x) 59 | grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() 60 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 61 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha 62 | x_mean = x + step_size[:, None, None, None] * grad 63 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise 64 | 65 | return x, x_mean 66 | 67 | class Score_sde_monocular_model(torch.nn.Module): 68 | def __init__(self,score_model,sde): 69 | super().__init__() 70 | self.score_model = score_model 71 | self.sde = sde 72 | self.rsde = sde.reverse(self.score,probability_flow=False) 73 | 74 | def score(self,x,t): 75 | _, std = self.sde.marginal_prob(torch.zeros_like(x), t) 76 | score = self.score_model(x, std) 77 | return score 78 | 79 | def forward_diffusion(self,x,t): 80 | z = torch.randn_like(x) 81 | mean, std = self.sde.marginal_prob(x, t) 82 | perturbed_data = mean + std[:, None, None, None] * z 83 | return perturbed_data, z, std 84 | 85 | def t_uniform(self,batch_size,device=None,eps=1e-5): 86 | # eps prevents sampling exactly 0 87 | t = torch.rand(batch_size, device=device) * (self.sde.T - eps) + eps 88 | return t 89 | 90 | def reverse_diffusion_predictor(self, x, t): 91 | f, G = self.rsde.discretize(x, t) 92 | z = torch.randn_like(x) 93 | x_mean = x - f 94 | x = x_mean + G[:, None, None, None] * z 95 | return x, x_mean 96 | 97 | def langevin_corrector(self, x, t): 98 | sde = self.sde 99 | n_steps = 1 100 | target_snr = 0.075 101 | 102 | # specific to VESDE 103 | alpha = torch.ones_like(t) 104 | 105 | for i in range(n_steps): 106 | grad = self.score(x, t) 107 | noise = torch.randn_like(x) 108 | grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() 109 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 110 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha 111 | x_mean = x + step_size[:, None, None, None] * grad 112 | x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise 113 | 114 | return x, x_mean 115 | 116 | class Score_modifier(torch.nn.Module): 117 | def __init__(self,model,max_batch_size): 118 | super().__init__() 119 | self.model = model 120 | self.max_batch_size = max_batch_size 121 | 122 | def __call__(self,x,cond_ims,std,cond_std,ff_refs,ff_as,ff_bs): 123 | # assert len(ff_refs) == len(ff_as) == len(ff_bs) == len(cond_ims) 124 | 125 | # generate indices to batch data 126 | indices = [] 127 | for b in range(len(ff_refs)): 128 | for d in range(len(ff_refs[b])): 129 | indices.append([b,d]) 130 | 131 | # compute scores in batches 132 | independent_scores_a = [[] for _ in ff_refs] 133 | independent_scores_b = [[] for _ in ff_refs] 134 | for start in range(0,len(indices),self.max_batch_size): 135 | batch_indices = indices[start:start+self.max_batch_size] 136 | batch_x = torch.stack([x[b,...] 
for [b,_] in batch_indices],0) 137 | batch_cond_ims = torch.cat([cond_ims[b][n] for [b,n] in batch_indices],0) 138 | batch_ff_refs = torch.cat([ff_refs[b][n] for [b,n] in batch_indices],0) 139 | batch_ff_as = torch.cat([ff_as[b][n] for [b,n] in batch_indices],0) 140 | batch_ff_bs = torch.cat([ff_bs[b][n] for [b,n] in batch_indices],0) 141 | batch_std = torch.stack([std[b] for [b,_] in batch_indices],0) 142 | batch_cond_std = torch.stack([cond_std[b] for [b,_] in batch_indices],0) 143 | batch_score_a,batch_score_b = self.model(batch_x,batch_cond_ims,batch_std,batch_cond_std,batch_ff_refs,batch_ff_as,batch_ff_bs) 144 | for idx,[b,n] in enumerate(batch_indices): # unpack scores 145 | independent_scores_a[b].append(batch_score_a[idx,...]) 146 | independent_scores_b[b].append(batch_score_b[idx,...]) 147 | 148 | aggregated_score_a = [torch.stack(sss,0).mean(0) for sss in independent_scores_a] 149 | aggregated_score_b = [torch.stack(sss,0).mean(0) for sss in independent_scores_b] 150 | 151 | aggregated_score_a = torch.stack(aggregated_score_a,0) 152 | aggregated_score_b = torch.stack(aggregated_score_b,0) 153 | 154 | return aggregated_score_a,aggregated_score_b 155 | 156 | def train(self): self.model.train() 157 | def eval(self): self.model.eval() 158 | 159 | class Score_modifier_stochastic(torch.nn.Module): 160 | def __init__(self,model): 161 | super().__init__() 162 | self.model = model 163 | 164 | def __call__(self,x,cond_ims,std,cond_std,ff_refs,ff_as,ff_bs): 165 | assert len(ff_refs) == len(ff_as) == len(ff_bs) == len(cond_ims) 166 | independent_scores_a = [] 167 | independent_scores_b = [] 168 | n_conditioning_views = len(ff_refs) 169 | 170 | n = random.choice(list(range(n_conditioning_views))) 171 | cond_im = cond_ims[n] 172 | score_a,score_b = self.model(x,cond_im,std,cond_std,ff_refs[n],ff_as[n],ff_bs[n]) 173 | 174 | return score_a,score_b 175 | 176 | def train(self): self.model.train() 177 | def eval(self): self.model.eval() 178 | 179 | class Score_modifier_stochastic_sanity(torch.nn.Module): 180 | def __init__(self,model): 181 | super().__init__() 182 | self.model = model 183 | 184 | def __call__(self,x,cond_ims,std,cond_std,ff_refs,ff_as,ff_bs): 185 | assert len(ff_refs) == len(ff_as) == len(ff_bs) == len(cond_ims) 186 | independent_scores_a = [] 187 | independent_scores_b = [] 188 | n_conditioning_views = len(ff_refs) 189 | 190 | n = random.choice(list(range(n_conditioning_views))) 191 | # if n_conditioning_views == 1: 192 | # n = 0 193 | # else: 194 | # n = 1 195 | # n = 0 196 | cond_im = cond_ims[n] 197 | score_a,score_b = self.model(x,cond_im,std,cond_std,ff_refs[n],ff_as[n],ff_bs[n]) 198 | 199 | return score_a,score_b 200 | 201 | def train(self): self.model.train() 202 | def eval(self): self.model.eval() 203 | -------------------------------------------------------------------------------- /src/models/score_sde/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(config, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = config.model.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 | self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | 
super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = 
self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /src/models/score_sde/sde_lib.py: -------------------------------------------------------------------------------- 1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 2 | import abc 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class SDE(abc.ABC): 8 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 9 | 10 | def __init__(self, N): 11 | """Construct an SDE. 12 | 13 | Args: 14 | N: number of discretization time steps. 15 | """ 16 | super().__init__() 17 | self.N = N 18 | 19 | @property 20 | @abc.abstractmethod 21 | def T(self): 22 | """End time of the SDE.""" 23 | pass 24 | 25 | @abc.abstractmethod 26 | def sde(self, x, t): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def marginal_prob(self, x, t): 31 | """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" 32 | pass 33 | 34 | @abc.abstractmethod 35 | def prior_sampling(self, shape): 36 | """Generate one sample from the prior distribution, $p_T(x)$.""" 37 | pass 38 | 39 | @abc.abstractmethod 40 | def prior_logp(self, z): 41 | """Compute log-density of the prior distribution. 42 | 43 | Useful for computing the log-likelihood via probability flow ODE. 44 | 45 | Args: 46 | z: latent code 47 | Returns: 48 | log probability density 49 | """ 50 | pass 51 | 52 | def discretize(self, x, t): 53 | """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. 54 | 55 | Useful for reverse diffusion sampling and probabiliy flow sampling. 56 | Defaults to Euler-Maruyama discretization. 57 | 58 | Args: 59 | x: a torch tensor 60 | t: a torch float representing the time step (from 0 to `self.T`) 61 | 62 | Returns: 63 | f, G 64 | """ 65 | dt = 1 / self.N 66 | drift, diffusion = self.sde(x, t) 67 | f = drift * dt 68 | G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) 69 | return f, G 70 | 71 | def reverse(self, score_fn, probability_flow=False): 72 | """Create the reverse-time SDE/ODE. 73 | 74 | Args: 75 | score_fn: A time-dependent score-based model that takes x and t and returns the score. 76 | probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. 77 | """ 78 | N = self.N 79 | T = self.T 80 | sde_fn = self.sde 81 | discretize_fn = self.discretize 82 | 83 | # Build the class for reverse-time SDE. 84 | class RSDE(self.__class__): 85 | def __init__(self): 86 | self.N = N 87 | self.probability_flow = probability_flow 88 | 89 | @property 90 | def T(self): 91 | return T 92 | 93 | def sde(self, x, cond_im, t, ff_ref,ff_a,ff_b): 94 | """Create the drift and diffusion functions for the reverse SDE/ODE.""" 95 | drift, diffusion = sde_fn(x, t) 96 | score = score_fn(x, cond_im, t, ff_ref,ff_a,ff_b) 97 | drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) 98 | # Set the diffusion function to zero for ODEs. 99 | diffusion = 0. 
if self.probability_flow else diffusion 100 | return drift, diffusion 101 | 102 | def discretize(self, x, cond_im, t, ff_ref,ff_a,ff_b): 103 | """Create discretized iteration rules for the reverse diffusion sampler.""" 104 | f, G = discretize_fn(x, t) 105 | rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, cond_im, t, ff_ref,ff_a,ff_b) * (0.5 if self.probability_flow else 1.) 106 | rev_G = torch.zeros_like(G) if self.probability_flow else G 107 | return rev_f, rev_G 108 | 109 | return RSDE() 110 | 111 | 112 | class VPSDE(SDE): 113 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 114 | """Construct a Variance Preserving SDE. 115 | 116 | Args: 117 | beta_min: value of beta(0) 118 | beta_max: value of beta(1) 119 | N: number of discretization steps 120 | """ 121 | super().__init__(N) 122 | self.beta_0 = beta_min 123 | self.beta_1 = beta_max 124 | self.N = N 125 | self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N) 126 | self.alphas = 1. - self.discrete_betas 127 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 128 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 129 | self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod) 130 | 131 | @property 132 | def T(self): 133 | return 1 134 | 135 | def sde(self, x, t): 136 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 137 | drift = -0.5 * beta_t[:, None, None, None] * x 138 | diffusion = torch.sqrt(beta_t) 139 | return drift, diffusion 140 | 141 | def marginal_prob(self, x, t): 142 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 143 | mean = torch.exp(log_mean_coeff[:, None, None, None]) * x 144 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 145 | return mean, std 146 | 147 | def prior_sampling(self, shape): 148 | return torch.randn(*shape) 149 | 150 | def prior_logp(self, z): 151 | shape = z.shape 152 | N = np.prod(shape[1:]) 153 | logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. 154 | return logps 155 | 156 | def discretize(self, x, t): 157 | """DDPM discretization.""" 158 | timestep = (t * (self.N - 1) / self.T).long() 159 | beta = self.discrete_betas.to(x.device)[timestep] 160 | alpha = self.alphas.to(x.device)[timestep] 161 | sqrt_beta = torch.sqrt(beta) 162 | f = torch.sqrt(alpha)[:, None, None, None] * x - x 163 | G = sqrt_beta 164 | return f, G 165 | 166 | 167 | class subVPSDE(SDE): 168 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 169 | """Construct the sub-VP SDE that excels at likelihoods. 170 | 171 | Args: 172 | beta_min: value of beta(0) 173 | beta_max: value of beta(1) 174 | N: number of discretization steps 175 | """ 176 | super().__init__(N) 177 | self.beta_0 = beta_min 178 | self.beta_1 = beta_max 179 | self.N = N 180 | 181 | @property 182 | def T(self): 183 | return 1 184 | 185 | def sde(self, x, t): 186 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 187 | drift = -0.5 * beta_t[:, None, None, None] * x 188 | discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2) 189 | diffusion = torch.sqrt(beta_t * discount) 190 | return drift, diffusion 191 | 192 | def marginal_prob(self, x, t): 193 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 194 | mean = torch.exp(log_mean_coeff)[:, None, None, None] * x 195 | std = 1 - torch.exp(2. 
* log_mean_coeff) 196 | return mean, std 197 | 198 | def prior_sampling(self, shape): 199 | return torch.randn(*shape) 200 | 201 | def prior_logp(self, z): 202 | shape = z.shape 203 | N = np.prod(shape[1:]) 204 | return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. 205 | 206 | 207 | class VESDE(SDE): 208 | def __init__(self, sigma_min=0.01, sigma_max=50, N=1000): 209 | """Construct a Variance Exploding SDE. 210 | 211 | Args: 212 | sigma_min: smallest sigma. 213 | sigma_max: largest sigma. 214 | N: number of discretization steps 215 | """ 216 | super().__init__(N) 217 | self.sigma_min = sigma_min 218 | self.sigma_max = sigma_max 219 | self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) 220 | self.N = N 221 | 222 | @property 223 | def T(self): 224 | return 1 225 | 226 | def sde(self, x, t): 227 | sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 228 | drift = torch.zeros_like(x) 229 | diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), 230 | device=t.device)) 231 | return drift, diffusion 232 | 233 | def marginal_prob(self, x, t): 234 | std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 235 | mean = x 236 | return mean, std 237 | 238 | def prior_sampling(self, shape): 239 | return torch.randn(*shape) * self.sigma_max 240 | 241 | def prior_logp(self, z): 242 | shape = z.shape 243 | N = np.prod(shape[1:]) 244 | return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2) 245 | 246 | def discretize(self, x, t): 247 | """SMLD(NCSN) discretization.""" 248 | timestep = (t * (self.N - 1) / self.T).long() 249 | sigma = self.discrete_sigmas.to(t.device)[timestep] 250 | adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), 251 | self.discrete_sigmas[timestep - 1].to(t.device)) 252 | f = torch.zeros_like(x) 253 | G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) 254 | return f, G 255 | -------------------------------------------------------------------------------- /src/scripts/sample-imgs-multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import sys 4 | 5 | script_path = os.path.dirname(os.path.realpath(sys.argv[0])) 6 | src_path = os.path.abspath(os.path.join(script_path,'..')) 7 | sys.path.append(src_path) 8 | os.chdir(src_path) 9 | 10 | 11 | def chunks(lst, n): 12 | """Yield successive n-sized chunks from lst.""" 13 | for i in range(0, len(lst), n): 14 | yield lst[i:i + n] 15 | 16 | def chunks_num(lst, n): 17 | """Yield n evenly sized chunks from lst.""" 18 | low = len(lst)//n 19 | rem = len(lst)-(low*n) 20 | counts = [low]*n 21 | for i in range(rem): counts[i] += 1 22 | ptr = 0 23 | res = [] 24 | for count in counts: 25 | res.append(lst[ptr:ptr+count]) 26 | ptr += count 27 | return res 28 | 29 | def gpu_list(string): 30 | return [int(item) for item in string.split(',')] 31 | 32 | import argparse 33 | argParser = argparse.ArgumentParser(description="start/resume inference on a given json file") 34 | argParser.add_argument("config") 35 | argParser.add_argument("-g","--gpus",dest="gpus",action="store",default=[0],type=gpu_list, help="Comma-separated list of GPU IDs to use") 36 | argParser.add_argument("-s",dest="steps",action="store",default=500,type=int,help="Number of steps of diffusion to run for each sample") 37 | 
argParser.add_argument("-b","--batch_size",dest="batch_size",action="store",default=16,type=int,help="Batch size for sampling") 38 | argParser.add_argument("-v", "--verbose", action="store_true", help="Increase output verbosity") 39 | argParser.add_argument("--model", help="Path to model checkpoint",required=True) 40 | cli_args = argParser.parse_args() 41 | assert cli_args.model is not None 42 | 43 | from utils import * 44 | 45 | from models.score_sde import ncsnpp 46 | import models.score_sde.sde_lib as sde_lib 47 | from models.score_sde.ncsnpp_dual import NCSNpp_dual 48 | import torch 49 | import numpy as np 50 | from tqdm import tqdm 51 | from PIL import Image 52 | import torch.optim as optim 53 | import json 54 | from models.score_sde.configs.LDM import LDM_config 55 | import shutil 56 | from models.vqgan.vqgan import VQModel 57 | from models.vqgan.configs.vqgan_32_4 import vqgan_32_4_config 58 | from models.score_sde.ema import ExponentialMovingAverage 59 | import functools 60 | from models.score_sde.layerspp import ResnetBlockBigGANpp 61 | import torch.nn.functional as F 62 | import pudb 63 | from filelock import FileLock 64 | import time 65 | from cv2 import cv2 66 | # center crop image to shape 67 | def center_crop(img,shape): 68 | h,w = shape 69 | center = img.shape 70 | x = center[1]/2 - w/2 71 | y = center[0]/2 - h/2 72 | center_crop = img[int(y):int(y+h), int(x):int(x+w)] 73 | return center_crop 74 | 75 | def maximal_crop_to_shape(image,shape,interpolation=cv2.INTER_AREA): 76 | target_aspect = shape[1]/shape[0] 77 | input_aspect = image.shape[1]/image.shape[0] 78 | if input_aspect > target_aspect: 79 | center_crop_shape = (image.shape[0],int(image.shape[0] * target_aspect)) 80 | else: 81 | center_crop_shape = (int(image.shape[1] / target_aspect),image.shape[1]) 82 | cropped = center_crop(image,center_crop_shape) 83 | resized = cv2.resize(cropped, (shape[1],shape[0]),interpolation=interpolation) 84 | return resized 85 | 86 | def crop_like_dataset(frame_a): 87 | orig_w, orig_h = frame_a.size 88 | if orig_w == 256 and orig_h == 256: 89 | return np.asarray(frame_a) 90 | # ensure images have 360 height 91 | if orig_h != 360: 92 | new_w = round(frame_a.size[0] * (360/frame_a.size[1])) 93 | frame_a = np.asarray(frame_a.resize((new_w,360))) 94 | else: 95 | frame_a = np.asarray(frame_a) 96 | if orig_w != 360: # if the width is not 360 97 | # crop and downsample 98 | left_pos = (frame_a.shape[1]-360)//2 99 | frame_a_cropped = frame_a[:,left_pos:left_pos+360,:] 100 | else: 101 | frame_a_cropped = frame_a 102 | im_a = Image.fromarray(frame_a_cropped).resize((256,256)) 103 | return np.asarray(im_a) 104 | 105 | def inference(data,device): 106 | # vqgan 107 | vqgan = VQModel(**vqgan_32_4_config).to(device) 108 | vqgan.eval() 109 | 110 | # ========================= build model ========================= 111 | config = LDM_config() 112 | score_model = NCSNpp_dual(config) 113 | score_model.to(device) 114 | sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=cli_args.steps) 115 | 116 | # ray downsampler 117 | ResnetBlock = functools.partial(ResnetBlockBigGANpp, 118 | act=torch.nn.SiLU(), 119 | dropout=False, 120 | fir=True, 121 | fir_kernel=[1,3,3,1], 122 | init_scale=0, 123 | skip_rescale=True, 124 | temb_dim=None) 125 | ray_downsampler = torch.nn.Sequential( 126 | ResnetBlock(in_ch=56,out_ch=128,down=True), 127 | ResnetBlock(in_ch=128,out_ch=128,down=True)).to(device) 128 | score_sde = 
Score_sde_model(score_model,sde,ray_downsampler,rays_require_downsample=False,rays_as_list=True) 129 | 130 | checkpoint = torch.load(cli_args.model) 131 | adapted_state = {} 132 | for k,v in checkpoint['score_sde_model'].items(): 133 | key_parts = k.split('.') 134 | if key_parts[1] == 'module': 135 | key_parts.pop(1) 136 | new_key = '.'.join(key_parts) 137 | adapted_state[new_key] = v 138 | score_sde.load_state_dict(adapted_state) 139 | 140 | ema = ExponentialMovingAverage(score_sde.parameters(),decay=0.999) 141 | ema.load_state_dict(checkpoint['ema']) 142 | ema.copy_to(score_sde.parameters()) 143 | # substitute model with score modifier, used to make bulk sampling easier 144 | modifier = Score_modifier(score_sde.score_model,max_batch_size=cli_args.batch_size) 145 | score_sde.score_model = modifier 146 | 147 | tlog('Setup complete','note') 148 | batches = chunks(data, cli_args.batch_size) 149 | 150 | with torch.no_grad(): 151 | for batch in batches: 152 | if cli_args.verbose: 153 | for el in batch: 154 | print(device, el['img']) 155 | conditioning_ims = [] 156 | ff_refs = [] 157 | ff_as = [] 158 | ff_bs = [] 159 | for el in batch: 160 | im = Image.open(el['img']) 161 | im = crop_like_dataset(im) 162 | 163 | if not os.path.exists(el['orig_img_copy_path']): 164 | cv2.imwrite(el['orig_img_copy_path'],cv2.cvtColor(im, cv2.COLOR_RGB2BGR)) 165 | im = im[:,:,:3].astype(np.float32).transpose(2,0,1)/127.5 - 1 166 | im = torch.Tensor(im).unsqueeze(0).to(device) 167 | encoded_im = vqgan.encode(im) 168 | focal_y = el['focal_y'] 169 | tform_ref = np.eye(4) 170 | tform_a_relative = np.asarray(el['transformation']) 171 | tform_b_relative = np.linalg.inv(np.asarray(el['transformation'])) 172 | camera_enc_ref = rel_camera_ray_encoding(tform_ref,128,focal_y) 173 | camera_enc_a = rel_camera_ray_encoding(tform_a_relative,128,focal_y) 174 | camera_enc_b = rel_camera_ray_encoding(tform_b_relative,128,focal_y) 175 | camera_enc_ref = torch.Tensor(camera_enc_ref).unsqueeze(0) 176 | camera_enc_a = torch.Tensor(camera_enc_a).unsqueeze(0) 177 | camera_enc_b = torch.Tensor(camera_enc_b).unsqueeze(0) 178 | ff_ref = F.pad(freq_enc(camera_enc_ref),[0,0,0,0,1,1,0,0]) # pad, must be %4==0 for group norm 179 | ff_a = F.pad(freq_enc(camera_enc_a),[0,0,0,0,1,1,0,0]) 180 | ff_b = F.pad(freq_enc(camera_enc_b),[0,0,0,0,1,1,0,0]) 181 | conditioning_ims.append([encoded_im]) 182 | ff_refs.append([ray_downsampler(ff_ref.to(device))]) 183 | ff_as.append([ray_downsampler(ff_a.to(device))]) 184 | ff_bs.append([ray_downsampler(ff_b.to(device))]) 185 | 186 | # sampling loop 187 | sampling_shape = (len(conditioning_ims), 4, 32, 32) 188 | sampling_eps = 1e-5 189 | x = score_sde.sde.prior_sampling(sampling_shape).to(device) 190 | 191 | timesteps = torch.linspace(sde.T, sampling_eps, sde.N, device=device) 192 | for i in tqdm(range(0,sde.N)): 193 | t = timesteps[i] 194 | vec_t = torch.ones(sampling_shape[0], device=t.device) * t 195 | _, std = sde.marginal_prob(x, vec_t) 196 | x, x_mean = score_sde.reverse_diffusion_predictor(x,conditioning_ims,vec_t,ff_refs,ff_as,ff_bs) 197 | x, x_mean = score_sde.langevin_corrector(x,conditioning_ims,vec_t,ff_refs,ff_as,ff_bs) 198 | 199 | decoded = vqgan.decode(x_mean) 200 | intermediate_sample = (decoded/2+0.5) 201 | intermediate_sample = torch.clip(intermediate_sample.permute(0,2,3,1).cpu()* 255., 0, 255).type(torch.uint8).numpy() 202 | outputs = [x['output'] for x in batch] 203 | pathlib.Path(os.path.dirname(outputs[0])).mkdir(parents=True, exist_ok=True) 204 | for out, im in zip(outputs, 
intermediate_sample): 205 | im_out = Image.fromarray(im) 206 | im_out.save(out) 207 | 208 | if __name__ == '__main__': 209 | import torch.multiprocessing as mp 210 | data_path = cli_args.config 211 | import json 212 | data = json.load(open(data_path)) 213 | print("Total files: ", len(data)) 214 | import os 215 | data = [item for item in data if not os.path.isfile(item["output"])] 216 | print("Remaining files: ", len(data)) 217 | devices = [f'cuda:{i}' for i in cli_args.gpus] 218 | if len(devices) == 1: 219 | inference(data, devices[0]) 220 | else: 221 | mp.set_start_method('spawn') 222 | datas = chunks_num(data,len(devices)) 223 | ar = list(zip(datas,devices)) 224 | with mp.Pool(len(devices)) as p: 225 | p.starmap(inference, ar) 226 | -------------------------------------------------------------------------------- /src/models/score_sde/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | """Layers used for up-sampling or down-sampling images. 2 | 3 | Many functions are ported from https://github.com/NVlabs/stylegan2. 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from .op import upfirdn2d 11 | 12 | 13 | # Function ported from StyleGAN2 14 | def get_weight(module, 15 | shape, 16 | weight_var='weight', 17 | kernel_init=None): 18 | """Get/create weight tensor for a convolution or fully-connected layer.""" 19 | 20 | return module.param(weight_var, kernel_init, shape) 21 | 22 | 23 | class Conv2d(nn.Module): 24 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 25 | 26 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False, 27 | resample_kernel=(1, 3, 3, 1), 28 | use_bias=True, 29 | kernel_init=None): 30 | super().__init__() 31 | assert not (up and down) 32 | assert kernel >= 1 and kernel % 2 == 1 33 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) 34 | if kernel_init is not None: 35 | self.weight.data = kernel_init(self.weight.data.shape) 36 | if use_bias: 37 | self.bias = nn.Parameter(torch.zeros(out_ch)) 38 | 39 | self.up = up 40 | self.down = down 41 | self.resample_kernel = resample_kernel 42 | self.kernel = kernel 43 | self.use_bias = use_bias 44 | 45 | def forward(self, x): 46 | if self.up: 47 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) 48 | elif self.down: 49 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) 50 | else: 51 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) 52 | 53 | if self.use_bias: 54 | x = x + self.bias.reshape(1, -1, 1, 1) 55 | 56 | return x 57 | 58 | 59 | def naive_upsample_2d(x, factor=2): 60 | _N, C, H, W = x.shape 61 | x = torch.reshape(x, (-1, C, H, 1, W, 1)) 62 | x = x.repeat(1, 1, 1, factor, 1, factor) 63 | return torch.reshape(x, (-1, C, H * factor, W * factor)) 64 | 65 | 66 | def naive_downsample_2d(x, factor=2): 67 | _N, C, H, W = x.shape 68 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) 69 | return torch.mean(x, dim=(3, 5)) 70 | 71 | 72 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1): 73 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 74 | 75 | Padding is performed only once at the beginning, not between the 76 | operations. 77 | The fused op is considerably more efficient than performing the same 78 | calculation 79 | using standard TensorFlow ops. It supports gradients of arbitrary order. 80 | Args: 81 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 82 | C]`. 
83 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 84 | outChannels]`. Grouped convolution can be performed by `inChannels = 85 | x.shape[0] // numGroups`. 86 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 87 | (separable). The default is `[1] * factor`, which corresponds to 88 | nearest-neighbor upsampling. 89 | factor: Integer upsampling factor (default: 2). 90 | gain: Scaling factor for signal magnitude (default: 1.0). 91 | 92 | Returns: 93 | Tensor of the shape `[N, C, H * factor, W * factor]` or 94 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 95 | """ 96 | 97 | assert isinstance(factor, int) and factor >= 1 98 | 99 | # Check weight shape. 100 | assert len(w.shape) == 4 101 | convH = w.shape[2] 102 | convW = w.shape[3] 103 | inC = w.shape[1] 104 | outC = w.shape[0] 105 | 106 | assert convW == convH 107 | 108 | # Setup filter kernel. 109 | if k is None: 110 | k = [1] * factor 111 | k = _setup_kernel(k) * (gain * (factor ** 2)) 112 | p = (k.shape[0] - factor) - (convW - 1) 113 | 114 | stride = (factor, factor) 115 | 116 | # Determine data dimensions. 117 | stride = [1, 1, factor, factor] 118 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 119 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 120 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 121 | assert output_padding[0] >= 0 and output_padding[1] >= 0 122 | num_groups = _shape(x, 1) // inC 123 | 124 | # Transpose weights. 125 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 126 | w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) 127 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 128 | 129 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 130 | ## Original TF code. 131 | # x = tf.nn.conv2d_transpose( 132 | # x, 133 | # w, 134 | # output_shape=output_shape, 135 | # strides=stride, 136 | # padding='VALID', 137 | # data_format=data_format) 138 | ## JAX equivalent 139 | 140 | return upfirdn2d(x, torch.tensor(k, device=x.device), 141 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 142 | 143 | 144 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 145 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 146 | 147 | Padding is performed only once at the beginning, not between the operations. 148 | The fused op is considerably more efficient than performing the same 149 | calculation 150 | using standard TensorFlow ops. It supports gradients of arbitrary order. 151 | Args: 152 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 153 | C]`. 154 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 155 | outChannels]`. Grouped convolution can be performed by `inChannels = 156 | x.shape[0] // numGroups`. 157 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 158 | (separable). The default is `[1] * factor`, which corresponds to 159 | average pooling. 160 | factor: Integer downsampling factor (default: 2). 161 | gain: Scaling factor for signal magnitude (default: 1.0). 162 | 163 | Returns: 164 | Tensor of the shape `[N, C, H // factor, W // factor]` or 165 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 
166 | """ 167 | 168 | assert isinstance(factor, int) and factor >= 1 169 | _outC, _inC, convH, convW = w.shape 170 | assert convW == convH 171 | if k is None: 172 | k = [1] * factor 173 | k = _setup_kernel(k) * gain 174 | p = (k.shape[0] - factor) + (convW - 1) 175 | s = [factor, factor] 176 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 177 | pad=((p + 1) // 2, p // 2)) 178 | return F.conv2d(x, w, stride=s, padding=0) 179 | 180 | 181 | def _setup_kernel(k): 182 | k = np.asarray(k, dtype=np.float32) 183 | if k.ndim == 1: 184 | k = np.outer(k, k) 185 | k /= np.sum(k) 186 | assert k.ndim == 2 187 | assert k.shape[0] == k.shape[1] 188 | return k 189 | 190 | 191 | def _shape(x, dim): 192 | return x.shape[dim] 193 | 194 | 195 | def upsample_2d(x, k=None, factor=2, gain=1): 196 | r"""Upsample a batch of 2D images with the given filter. 197 | 198 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 199 | and upsamples each image with the given filter. The filter is normalized so 200 | that 201 | if the input pixels are constant, they will be scaled by the specified 202 | `gain`. 203 | Pixels outside the image are assumed to be zero, and the filter is padded 204 | with 205 | zeros so that its shape is a multiple of the upsampling factor. 206 | Args: 207 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 208 | C]`. 209 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 210 | (separable). The default is `[1] * factor`, which corresponds to 211 | nearest-neighbor upsampling. 212 | factor: Integer upsampling factor (default: 2). 213 | gain: Scaling factor for signal magnitude (default: 1.0). 214 | 215 | Returns: 216 | Tensor of the shape `[N, C, H * factor, W * factor]` 217 | """ 218 | assert isinstance(factor, int) and factor >= 1 219 | if k is None: 220 | k = [1] * factor 221 | k = _setup_kernel(k) * (gain * (factor ** 2)) 222 | p = k.shape[0] - factor 223 | return upfirdn2d(x, torch.tensor(k, device=x.device), 224 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 225 | 226 | 227 | def downsample_2d(x, k=None, factor=2, gain=1): 228 | r"""Downsample a batch of 2D images with the given filter. 229 | 230 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 231 | and downsamples each image with the given filter. The filter is normalized 232 | so that 233 | if the input pixels are constant, they will be scaled by the specified 234 | `gain`. 235 | Pixels outside the image are assumed to be zero, and the filter is padded 236 | with 237 | zeros so that its shape is a multiple of the downsampling factor. 238 | Args: 239 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 240 | C]`. 241 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 242 | (separable). The default is `[1] * factor`, which corresponds to 243 | average pooling. 244 | factor: Integer downsampling factor (default: 2). 245 | gain: Scaling factor for signal magnitude (default: 1.0). 
246 | 247 | Returns: 248 | Tensor of the shape `[N, C, H // factor, W // factor]` 249 | """ 250 | 251 | assert isinstance(factor, int) and factor >= 1 252 | if k is None: 253 | k = [1] * factor 254 | k = _setup_kernel(k) * gain 255 | p = k.shape[0] - factor 256 | return upfirdn2d(x, torch.tensor(k, device=x.device), 257 | down=factor, pad=((p + 1) // 2, p // 2)) 258 | -------------------------------------------------------------------------------- /src/utils/trainers/Score_sde_trainer.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from os.path import join as pjoin 3 | import time 4 | import datetime 5 | import torch 6 | from torchvision import datasets, transforms 7 | from datasets import * 8 | from datasets.colmap_dataset import * 9 | from torch.utils.data import DataLoader 10 | import torch.optim as optim 11 | from models.score_sde.ncsnpp_dual import NCSNpp_dual 12 | import models.score_sde.sde_lib as sde_lib 13 | from .Base_trainer import Base_trainer 14 | from models.score_sde.configs.LDM import LDM_config 15 | import numpy as np 16 | from models.vqgan.vqgan import VQModel 17 | from models.vqgan.configs.vqgan_32_4 import vqgan_32_4_config 18 | from models.score_sde.ema import ExponentialMovingAverage 19 | import functools 20 | from models.score_sde.layerspp import ResnetBlockBigGANpp 21 | import torch.nn.functional as F 22 | 23 | def vq_to_img(x, vqgan): 24 | with torch.no_grad(): 25 | decoded = vqgan.decode(x) 26 | intermediate_sample = (decoded/2+0.5) 27 | intermediate_sample = torch.clip(intermediate_sample.permute(0,2,3,1).cpu()* 255., 0, 255).type(torch.uint8).numpy() 28 | return intermediate_sample 29 | 30 | def loss_fn(score_sde, batch, cond_im, ff_ref, ff_a, ff_b): 31 | score_sde.train() 32 | eps=1e-5 33 | 34 | t = score_sde.t_uniform(batch.shape[0],batch.device) 35 | perturbed_data, z, std = score_sde.forward_diffusion(batch,t) 36 | score = score_sde.score(perturbed_data, cond_im, t, ff_ref, ff_a, ff_b) 37 | 38 | losses = torch.square(score * std[:, None, None, None] + z) 39 | losses = 0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1) 40 | 41 | loss = torch.mean(losses) 42 | return loss 43 | 44 | class Score_sde_trainer(Base_trainer): 45 | def __init__(self,local_rank,node_rank,n_gpus_per_node,n_nodes): 46 | # distributed helpers 47 | self.rank = node_rank*n_gpus_per_node + local_rank 48 | self.world_size = n_nodes*n_gpus_per_node 49 | self.gpu = local_rank 50 | compute_only = self.rank != 0 51 | super().__init__(compute_only) 52 | 53 | # configure checkpoint behaviour 54 | self.n_kept_checkpoints = globals.n_kept_checkpoints 55 | self.checkpoint_interval =globals.checkpoint_interval 56 | self.max_epoch = globals.max_epoch 57 | self.checkpoint_retention_interval = globals.checkpoint_retention_interval 58 | 59 | # vqgan 60 | self.vqgan = VQModel(**vqgan_32_4_config).to(f'cuda:{self.gpu}') 61 | self.vqgan.eval() 62 | 63 | # score sde configs 64 | self.config = LDM_config() 65 | config = self.config 66 | 67 | # main components, dataloaders, model, optimizer 68 | batch_size = 64//self.world_size 69 | batch_size = globals.batch_size//self.world_size 70 | 71 | if globals.sfm_method == "colmap": 72 | self.dataset = ColmapDataset(globals.colmap_data_folders, globals.focal_lengths) 73 | elif globals.sfm_method == "umi": 74 | self.dataset = UmiDatasetFromFolder(globals.colmap_data_folders, globals.focal_lengths) 75 | else: 76 | self.dataset = SlamDataset(globals.colmap_data_folders, globals.focal_lengths) 77 | 78 
| self.sampler = DistributedSaveableSampler(self.dataset,num_replicas=self.world_size,rank=self.rank,shuffle=True) 79 | self.train_loader = torch.utils.data.DataLoader(self.dataset,batch_size=batch_size,sampler=self.sampler,drop_last=True,num_workers=8,persistent_workers=True) # drop_last used to be True 80 | print("Data loader length: ", len(self.train_loader)) 81 | score_model = NCSNpp_dual(config) 82 | score_model = score_model.to(f'cuda:{self.gpu}') 83 | score_model = torch.nn.parallel.DistributedDataParallel(score_model,device_ids=[self.gpu]) 84 | sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) 85 | 86 | ResnetBlock = functools.partial(ResnetBlockBigGANpp, 87 | act=torch.nn.SiLU(), 88 | dropout=False, 89 | fir=True, 90 | fir_kernel=[1,3,3,1], 91 | init_scale=0, 92 | skip_rescale=True, 93 | temb_dim=None) 94 | ray_downsampler = torch.nn.Sequential( 95 | ResnetBlock(in_ch=56,out_ch=128,down=True).to(f'cuda:{self.gpu}'), 96 | ResnetBlock(in_ch=128,out_ch=128,down=True).to(f'cuda:{self.gpu}')) 97 | ray_downsampler = torch.nn.parallel.DistributedDataParallel(ray_downsampler,device_ids=[self.gpu]) 98 | 99 | self.score_sde = Score_sde_model(score_model,sde,ray_downsampler) 100 | 101 | # only 1 worker needs to keep ema 102 | if self.rank == 0: 103 | self.ema = ExponentialMovingAverage(self.score_sde.parameters(),decay=0.999) 104 | 105 | self.lr = 2e-4 106 | self.optimizer = optim.Adam(self.score_sde.parameters(), lr=self.lr) 107 | self.warmup = 5000 108 | self.grad_clip = 1. 109 | 110 | self.tlog('Setup complete','note') 111 | 112 | def train_epoch(self): 113 | iteration_start = time.time() 114 | self.sampler.set_epoch(self.epoch) # important! or else split will be the same every epoch 115 | dataloader_iter = iter(self.train_loader) # need this to save state 116 | 117 | # print(len(list(enumerate(dataloader_iter, start=1)))) 118 | for epoch_it, batch_data in enumerate(dataloader_iter,start=1): 119 | # print('at train epoch rn') 120 | self.total_iterations += 1 121 | 122 | # unpack data 123 | im_a = batch_data['im_a'] 124 | im_b = batch_data['im_b'] 125 | camera_enc_ref = batch_data['camera_enc_ref'] 126 | camera_enc_a = batch_data['camera_enc_a'] 127 | camera_enc_b = batch_data['camera_enc_b'] 128 | 129 | # move data to gpu 130 | im_a = im_a.to(f'cuda:{self.gpu}') 131 | im_b = im_b.to(f'cuda:{self.gpu}') 132 | camera_enc_ref = camera_enc_ref.to(f'cuda:{self.gpu}') 133 | camera_enc_a = camera_enc_a.to(f'cuda:{self.gpu}') 134 | camera_enc_b = camera_enc_b.to(f'cuda:{self.gpu}') 135 | 136 | # encode with vqgan 137 | with torch.no_grad(): 138 | encoded_a = self.vqgan.encode(im_a) 139 | encoded_b = self.vqgan.encode(im_b) 140 | 141 | # train 142 | self.optimizer.zero_grad() 143 | ff_ref = F.pad(freq_enc(camera_enc_ref),[0,0,0,0,1,1,0,0]) # pad, must be %4==0 for group norm 144 | ff_a = F.pad(freq_enc(camera_enc_a),[0,0,0,0,1,1,0,0]) 145 | ff_b = F.pad(freq_enc(camera_enc_b),[0,0,0,0,1,1,0,0]) 146 | loss = loss_fn(self.score_sde, encoded_a, encoded_b, ff_ref, ff_a, ff_b) 147 | loss.backward() 148 | 149 | # warmup and gradient clip 150 | if self.warmup > 0: 151 | for g in self.optimizer.param_groups: 152 | g['lr'] = self.lr * np.minimum(self.total_iterations / self.warmup, 1.0) 153 | if self.grad_clip >= 0: 154 | torch.nn.utils.clip_grad_norm_(self.score_sde.parameters(), max_norm=self.grad_clip) 155 | self.optimizer.step() 156 | 157 | # update ema 158 | if self.rank == 0: 159 | self.ema.update(self.score_sde.parameters()) 160 | 161 
| # print iteration details 162 | if not self.compute_only: 163 | # calc eta 164 | iteration_duration = time.time() - iteration_start 165 | its_per_sec = 1/iteration_duration 166 | remaining_its = self.max_epoch*len(self.train_loader) - self.total_iterations 167 | eta_sec = remaining_its * iteration_duration 168 | eta_min = eta_sec//60 169 | eta = str(datetime.timedelta(minutes=eta_min)) 170 | if self.total_iterations % self.print_iteration == 0: 171 | self.tlog(f'{self.total_iterations} | loss: {loss.item()} | it/s: {its_per_sec} | ETA: {eta}','iter') 172 | 173 | self.tb_writer.add_scalar('training/loss', loss.item(), self.total_iterations) 174 | 175 | # checkpoint if haven't checkpointed in a while 176 | self.maybe_save_checkpoint(epoch_it,dataloader_iter) 177 | 178 | # check if termination requested 179 | self.check_termination_request(epoch_it,dataloader_iter) 180 | 181 | # start time here to include data fetching 182 | iteration_start = time.time() 183 | 184 | 185 | if not self.compute_only: 186 | self.tb_writer.add_scalar('training/loss epoch', loss.item(), epoch_it) 187 | 188 | def validate(self): 189 | pass 190 | 191 | def state_dict(self,dataloader_iter=None): 192 | state_dict = { 193 | 'epoch': self.epoch, 194 | 'total_iterations': self.total_iterations, 195 | 'optimizer': self.optimizer.state_dict(), 196 | 'score_sde_model':self.score_sde.state_dict(), 197 | 'training_data_sampler':self.sampler.state_dict(dataloader_iter), 198 | 'ema': self.ema.state_dict() 199 | } 200 | return state_dict 201 | 202 | def load_state_dict(self,state_dict): 203 | self.epoch = state_dict['epoch'] 204 | self.total_iterations = state_dict['total_iterations'] 205 | self.score_sde.load_state_dict(state_dict['score_sde_model']) 206 | self.optimizer.load_state_dict(state_dict['optimizer']) 207 | self.sampler.load_state_dict(state_dict['training_data_sampler']) 208 | if self.rank == 0: 209 | self.ema.load_state_dict(state_dict['ema']) 210 | 211 | def load_checkpoint(self,checkpoint_path): 212 | super().load_checkpoint(checkpoint_path,map_location=f'cuda:{self.gpu}') 213 | -------------------------------------------------------------------------------- /src/models/score_sde/layerspp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers for defining NCSN++. 18 | """ 19 | from . import layers 20 | from . 
import up_or_down_sampling 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | 26 | conv1x1 = layers.ddpm_conv1x1 27 | conv3x3 = layers.ddpm_conv3x3 28 | NIN = layers.NIN 29 | default_init = layers.default_init 30 | 31 | 32 | class GaussianFourierProjection(nn.Module): 33 | """Gaussian Fourier embeddings for noise levels.""" 34 | 35 | def __init__(self, embedding_size=256, scale=1.0): 36 | super().__init__() 37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 38 | 39 | def forward(self, x): 40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 42 | 43 | 44 | class Combine(nn.Module): 45 | """Combine information from skip connections.""" 46 | 47 | def __init__(self, dim1, dim2, method='cat'): 48 | super().__init__() 49 | self.Conv_0 = conv1x1(dim1, dim2) 50 | self.method = method 51 | 52 | def forward(self, x, y): 53 | h = self.Conv_0(x) 54 | if self.method == 'cat': 55 | return torch.cat([h, y], dim=1) 56 | elif self.method == 'sum': 57 | return h + y 58 | else: 59 | raise ValueError(f'Method {self.method} not recognized.') 60 | 61 | 62 | class AttnBlockpp(nn.Module): 63 | """Channel-wise self-attention block. Modified from DDPM.""" 64 | 65 | def __init__(self, channels, n_heads=4, skip_rescale=False, init_scale=0.): 66 | super().__init__() 67 | self.n_heads = n_heads 68 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 69 | eps=1e-6) 70 | self.NIN_0 = NIN(channels, channels*n_heads) 71 | self.NIN_1 = NIN(channels, channels*n_heads) 72 | self.NIN_2 = NIN(channels, channels*n_heads) 73 | self.NIN_3 = NIN(channels*n_heads, channels, init_scale=init_scale) 74 | self.skip_rescale = skip_rescale 75 | 76 | def forward(self, x): 77 | B, C, H, W = x.shape 78 | h = self.GroupNorm_0(x) 79 | q = self.NIN_0(h).reshape(B,self.n_heads,C,H,W) 80 | k = self.NIN_1(h).reshape(B,self.n_heads,C,H,W) 81 | v = self.NIN_2(h).reshape(B,self.n_heads,C,H,W) 82 | 83 | w = torch.einsum('bechw,becij->behwij', q, k) * (int(C) ** (-0.5)) 84 | w = torch.reshape(w, (B, self.n_heads, H, W, H * W)) 85 | w = F.softmax(w, dim=-1) 86 | w = torch.reshape(w, (B, self.n_heads, H, W, H, W)) 87 | h = torch.einsum('behwij,becij->bechw', w, v) 88 | h = h.reshape(B, self.n_heads*C, H ,W) 89 | h = self.NIN_3(h) 90 | if not self.skip_rescale: 91 | return x + h 92 | else: 93 | return (x + h) / np.sqrt(2.) 
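# [Editor's note: usage sketch, not part of the original file.] A minimal
# smoke test for AttnBlockpp under assumed, hypothetical shapes. The block is
# shape-preserving; attention weights are formed per head over all H*W
# positions, so memory grows as O(B * n_heads * (H*W)^2), which is why such
# blocks are typically enabled only at small feature-map resolutions
# (cf. attn_resolutions in the configs).
#
#   block = AttnBlockpp(channels=128, n_heads=4, skip_rescale=True)
#   x = torch.randn(2, 128, 16, 16)   # (B, C, H, W), hypothetical
#   out = block(x)                    # -> (2, 128, 16, 16): x + attention, rescaled by 1/sqrt(2)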
94 | 95 | 96 | class Upsample(nn.Module): 97 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 98 | fir_kernel=(1, 3, 3, 1)): 99 | super().__init__() 100 | out_ch = out_ch if out_ch else in_ch 101 | if not fir: 102 | if with_conv: 103 | self.Conv_0 = conv3x3(in_ch, out_ch) 104 | else: 105 | if with_conv: 106 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 107 | kernel=3, up=True, 108 | resample_kernel=fir_kernel, 109 | use_bias=True, 110 | kernel_init=default_init()) 111 | self.fir = fir 112 | self.with_conv = with_conv 113 | self.fir_kernel = fir_kernel 114 | self.out_ch = out_ch 115 | 116 | def forward(self, x): 117 | B, C, H, W = x.shape 118 | if not self.fir: 119 | h = F.interpolate(x, (H * 2, W * 2), 'nearest') 120 | if self.with_conv: 121 | h = self.Conv_0(h) 122 | else: 123 | if not self.with_conv: 124 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 125 | else: 126 | h = self.Conv2d_0(x) 127 | 128 | return h 129 | 130 | 131 | class Downsample(nn.Module): 132 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 133 | fir_kernel=(1, 3, 3, 1)): 134 | super().__init__() 135 | out_ch = out_ch if out_ch else in_ch 136 | if not fir: 137 | if with_conv: 138 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 139 | else: 140 | if with_conv: 141 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 142 | kernel=3, down=True, 143 | resample_kernel=fir_kernel, 144 | use_bias=True, 145 | kernel_init=default_init()) 146 | self.fir = fir 147 | self.fir_kernel = fir_kernel 148 | self.with_conv = with_conv 149 | self.out_ch = out_ch 150 | 151 | def forward(self, x): 152 | B, C, H, W = x.shape 153 | if not self.fir: 154 | if self.with_conv: 155 | x = F.pad(x, (0, 1, 0, 1)) 156 | x = self.Conv_0(x) 157 | else: 158 | x = F.avg_pool2d(x, 2, stride=2) 159 | else: 160 | if not self.with_conv: 161 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 162 | else: 163 | x = self.Conv2d_0(x) 164 | 165 | return x 166 | 167 | 168 | class ResnetBlockDDPMpp(nn.Module): 169 | """ResBlock adapted from DDPM.""" 170 | 171 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 172 | dropout=0.1, skip_rescale=False, init_scale=0.): 173 | super().__init__() 174 | out_ch = out_ch if out_ch else in_ch 175 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 176 | self.Conv_0 = conv3x3(in_ch, out_ch) 177 | if temb_dim is not None: 178 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 179 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 180 | nn.init.zeros_(self.Dense_0.bias) 181 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 182 | self.Dropout_0 = nn.Dropout(dropout) 183 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 184 | if in_ch != out_ch: 185 | if conv_shortcut: 186 | self.Conv_2 = conv3x3(in_ch, out_ch) 187 | else: 188 | self.NIN_0 = NIN(in_ch, out_ch) 189 | 190 | self.skip_rescale = skip_rescale 191 | self.act = act 192 | self.out_ch = out_ch 193 | self.conv_shortcut = conv_shortcut 194 | 195 | def forward(self, x, temb=None): 196 | h = self.act(self.GroupNorm_0(x)) 197 | h = self.Conv_0(h) 198 | if temb is not None: 199 | h += self.Dense_0(self.act(temb))[:, :, None, None] 200 | h = self.act(self.GroupNorm_1(h)) 201 | h = self.Dropout_0(h) 202 | h = self.Conv_1(h) 203 | if x.shape[1] != self.out_ch: 204 | if self.conv_shortcut: 205 | x = self.Conv_2(x) 
206 | else: 207 | x = self.NIN_0(x) 208 | if not self.skip_rescale: 209 | return x + h 210 | else: 211 | return (x + h) / np.sqrt(2.) 212 | 213 | 214 | class ResnetBlockBigGANpp(nn.Module): 215 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 216 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 217 | skip_rescale=True, init_scale=0.,conditioning_dim=None): 218 | super().__init__() 219 | 220 | out_ch = out_ch if out_ch else in_ch 221 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 222 | self.up = up 223 | self.down = down 224 | self.fir = fir 225 | self.fir_kernel = fir_kernel 226 | 227 | self.Conv_0 = conv3x3(in_ch, out_ch) 228 | if temb_dim is not None: 229 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 230 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 231 | nn.init.zeros_(self.Dense_0.bias) 232 | 233 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 234 | self.Dropout_0 = nn.Dropout(dropout) 235 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 236 | if in_ch != out_ch or up or down: 237 | self.Conv_2 = conv1x1(in_ch, out_ch) 238 | 239 | self.skip_rescale = skip_rescale 240 | self.act = act 241 | self.in_ch = in_ch 242 | self.out_ch = out_ch 243 | 244 | if conditioning_dim is not None: 245 | self.cond_linear= nn.Conv2d(conditioning_dim, out_ch,1) 246 | self.cond_linear.weight.data = default_init()(self.cond_linear.weight.shape) 247 | nn.init.zeros_(self.cond_linear.bias) 248 | 249 | def forward(self, x, temb=None, conditioning=None): 250 | h = self.act(self.GroupNorm_0(x)) 251 | 252 | if self.up: 253 | if self.fir: 254 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 255 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 256 | else: 257 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 258 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 259 | elif self.down: 260 | if self.fir: 261 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 262 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 263 | else: 264 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 265 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 266 | 267 | h = self.Conv_0(h) 268 | # Add bias to each feature map conditioned on the time embedding 269 | if temb is not None: 270 | h += self.Dense_0(self.act(temb))[:, :, None, None] 271 | if conditioning is not None: 272 | h += self.cond_linear(self.act(conditioning)) 273 | h = self.act(self.GroupNorm_1(h)) 274 | h = self.Dropout_0(h) 275 | h = self.Conv_1(h) 276 | 277 | if self.in_ch != self.out_ch or self.up or self.down: 278 | x = self.Conv_2(x) 279 | 280 | if not self.skip_rescale: 281 | return x + h 282 | else: 283 | return (x + h) / np.sqrt(2.) 
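# [Editor's note: usage sketch, not part of the original file.] This mirrors
# how the block is configured for the camera-ray downsampler in
# Score_sde_trainer.py and sample-imgs-multi.py; the input tensor below is
# hypothetical. With down=True the input is FIR-downsampled before Conv_0, so
# the spatial resolution is halved, and the 1x1 Conv_2 shortcut maps in_ch -> out_ch.
#
#   block = ResnetBlockBigGANpp(act=torch.nn.SiLU(), in_ch=56, out_ch=128,
#                               down=True, dropout=False, fir=True,
#                               fir_kernel=[1, 3, 3, 1], init_scale=0,
#                               skip_rescale=True, temb_dim=None)
#   rays = torch.randn(1, 56, 128, 128)   # frequency-encoded camera rays (hypothetical batch)
#   out = block(rays)                     # -> (1, 128, 64, 64)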
284 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: dmd-diffusion 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _ipython_minor_entry_point=8.7.0=h3b92ee0_0 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_kmp_llvm 10 | - alsa-lib=1.2.3.2=h166bdaf_0 11 | - anyio=3.6.2=pyhd8ed1ab_0 12 | - aom=3.5.0=h27087fc_0 13 | - argon2-cffi=21.3.0=pyhd8ed1ab_0 14 | - argon2-cffi-bindings=21.2.0=py38h0a891b7_3 15 | - asttokens=2.2.1=pyhd8ed1ab_0 16 | - attrs=22.1.0=pyh71513ae_1 17 | - backcall=0.2.0=pyh9f0ad1d_0 18 | - backports=1.0=pyhd8ed1ab_3 19 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 20 | - beautifulsoup4=4.11.1=pyha770c72_0 21 | - binutils_impl_linux-64=2.36.1=h193b22a_2 22 | - binutils_linux-64=2.36=hf3e587d_10 23 | - blas=1.0=mkl 24 | - bleach=5.0.1=pyhd8ed1ab_0 25 | - blosc=1.21.2=hafa529b_0 26 | - brotli=1.0.9=h166bdaf_8 27 | - brotli-bin=1.0.9=h166bdaf_8 28 | - brotlipy=0.7.0=py38h0a891b7_1005 29 | - brunsli=0.1=h9c3ff4c_0 30 | - bzip2=1.0.8=h7f98852_4 31 | - c-ares=1.18.1=h7f98852_0 32 | - c-blosc2=2.6.0=hf91038e_0 33 | - ca-certificates=2022.12.7=ha878542_0 34 | - cairo=1.16.0=h6cf1ce9_1008 35 | - certifi=2022.12.7=pyhd8ed1ab_0 36 | - cffi=1.15.1=py38h4a40e3a_2 37 | - cfitsio=4.2.0=hd9d235c_0 38 | - charls=2.3.4=h9c3ff4c_0 39 | - click=8.1.3=unix_pyhd8ed1ab_2 40 | - cloudpickle=2.2.0=pyhd8ed1ab_0 41 | - colorama=0.4.6=pyhd8ed1ab_0 42 | - comm=0.1.2=pyhd8ed1ab_0 43 | - cryptography=38.0.4=py38h2b5fc30_0 44 | #- cudatoolkit=11.3.1=h9edb442_11 45 | - cudatoolkit-dev=11.3.1=py38h497a2fe_0 46 | - cupy=11.4.0=py38h405e1b6_0 47 | - cycler=0.11.0=pyhd8ed1ab_0 48 | - cytoolz=0.12.0=py38h0a891b7_1 49 | - dask-core=2022.12.0=pyhd8ed1ab_0 50 | - dav1d=1.0.0=h166bdaf_1 51 | - dbus=1.13.6=h5008d03_3 52 | - debugpy=1.6.4=py38hfa26641_0 53 | - decorator=5.1.1=pyhd8ed1ab_0 54 | - defusedxml=0.7.1=pyhd8ed1ab_0 55 | - entrypoints=0.4=pyhd8ed1ab_0 56 | - executing=1.2.0=pyhd8ed1ab_0 57 | - expat=2.5.0=h27087fc_0 58 | - fastrlock=0.8=py38hfa26641_3 59 | - ffmpeg=4.3=hf484d3e_0 60 | - filelock=3.8.2=pyhd8ed1ab_0 61 | - flit-core=3.8.0=pyhd8ed1ab_0 62 | - fontconfig=2.14.1=hc2a2eb6_0 63 | - fonttools=4.38.0=py38h0a891b7_1 64 | - freetype=2.12.1=hca18f0e_1 65 | - fsspec=2022.11.0=pyhd8ed1ab_0 66 | - gcc_impl_linux-64=8.5.0=hb55b52c_17 67 | - gcc_linux-64=8.5.0=h87d5063_10 68 | - gettext=0.21.1=h27087fc_0 69 | - giflib=5.2.1=h36c2ea0_2 70 | - gmp=6.2.1=h58526e2_0 71 | - gnutls=3.6.13=h85f3911_1 72 | - graphite2=1.3.13=h58526e2_1001 73 | - gst-plugins-base=1.18.5=hf529b03_3 74 | - gstreamer=1.18.5=h9f60fe5_3 75 | - gxx_impl_linux-64=8.5.0=hb55b52c_17 76 | - gxx_linux-64=8.5.0=h82b3ca4_10 77 | - harfbuzz=2.9.1=h83ec7ef_1 78 | - hdf5=1.10.6=nompi_h6a2412b_1114 79 | - huggingface_hub=0.11.1=pyhd8ed1ab_0 80 | - icu=68.2=h9c3ff4c_0 81 | - idna=3.4=pyhd8ed1ab_0 82 | - imagecodecs=2022.9.26=py38hf74bd01_4 83 | - imageio=2.22.4=pyhfa7a67d_1 84 | - importlib-metadata=5.1.0=pyha770c72_0 85 | - importlib_metadata=5.1.0=hd8ed1ab_0 86 | - importlib_resources=5.10.1=pyhd8ed1ab_0 87 | - ipykernel=6.19.2=pyh210e3f2_0 88 | - ipython=8.7.0=pyh41d4057_0 89 | - ipython_genutils=0.2.0=py_1 90 | - ipywidgets=8.0.3=pyhd8ed1ab_0 91 | - jasper=1.900.1=h07fcdf6_1006 92 | - jedi=0.18.2=pyhd8ed1ab_0 93 | - jinja2=3.1.2=pyhd8ed1ab_1 94 | - jpeg=9e=h166bdaf_2 95 | - jsonschema=4.17.3=pyhd8ed1ab_0 96 | - jxrlib=1.1=h7f98852_2 97 | - 
kernel-headers_linux-64=2.6.32=he073ed8_15 98 | - keyutils=1.6.1=h166bdaf_0 99 | - kiwisolver=1.4.4=py38h43d8883_1 100 | - krb5=1.19.3=h3790be6_0 101 | - lame=3.100=h166bdaf_1003 102 | - lcms2=2.14=h6ed2654_0 103 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 104 | - lerc=4.0.0=h27087fc_0 105 | - libaec=1.0.6=h9c3ff4c_0 106 | - libavif=0.11.1=h5cdd6b5_0 107 | - libblas=3.9.0=12_linux64_mkl 108 | - libbrotlicommon=1.0.9=h166bdaf_8 109 | - libbrotlidec=1.0.9=h166bdaf_8 110 | - libbrotlienc=1.0.9=h166bdaf_8 111 | - libcblas=3.9.0=12_linux64_mkl 112 | - libclang=11.1.0=default_ha53f305_1 113 | - libcurl=7.86.0=h7bff187_1 114 | - libdeflate=1.14=h166bdaf_0 115 | - libedit=3.1.20191231=he28a2e2_2 116 | - libev=4.33=h516909a_1 117 | - libevent=2.1.10=h9b69904_4 118 | - libffi=3.4.2=h7f98852_5 119 | - libgcc-devel_linux-64=8.5.0=h82e8279_17 120 | - libgcc-ng=12.2.0=h65d4601_19 121 | - libgfortran-ng=12.2.0=h69a702a_19 122 | - libgfortran5=12.2.0=h337968e_19 123 | - libglib=2.74.1=h606061b_1 124 | - libgomp=12.2.0=h65d4601_19 125 | - libiconv=1.17=h166bdaf_0 126 | - liblapack=3.9.0=12_linux64_mkl 127 | - liblapacke=3.9.0=12_linux64_mkl 128 | - libllvm11=11.1.0=he0ac6c6_5 129 | - libnghttp2=1.47.0=hdcd2b5c_1 130 | - libnsl=2.0.0=h7f98852_0 131 | - libogg=1.3.4=h7f98852_1 132 | - libopencv=4.4.0=py38_2 133 | - libopus=1.3.1=h7f98852_1 134 | - libpng=1.6.39=h753d276_0 135 | - libpq=13.8=hd77ab85_0 136 | - libsanitizer=8.5.0=h70fd0c9_17 137 | - libsodium=1.0.18=h36c2ea0_1 138 | - libsqlite=3.40.0=h753d276_0 139 | - libssh2=1.10.0=haa6b8db_3 140 | - libstdcxx-devel_linux-64=8.5.0=h82e8279_17 141 | - libstdcxx-ng=12.2.0=h46fd767_19 142 | - libtiff=4.4.0=h55922b4_4 143 | - libuuid=2.32.1=h7f98852_1000 144 | - libuv=1.44.2=h166bdaf_0 145 | - libvorbis=1.3.7=h9c3ff4c_0 146 | - libwebp=1.2.4=h522a892_0 147 | - libwebp-base=1.2.4=h166bdaf_0 148 | - libxcb=1.13=h7f98852_1004 149 | - libxkbcommon=1.0.3=he3ba5ed_0 150 | - libxml2=2.9.12=h72842e0_0 151 | - libzlib=1.2.13=h166bdaf_4 152 | - libzopfli=1.0.3=h9c3ff4c_0 153 | - llvm-openmp=15.0.6=he0ac6c6_0 154 | - locket=1.0.0=pyhd8ed1ab_0 155 | - lpips=0.1.3=pyhd8ed1ab_0 156 | - lz4-c=1.9.3=h9c3ff4c_1 157 | - markupsafe=2.1.1=py38h0a891b7_2 158 | - matplotlib=3.5.0=py38h06a4308_0 159 | - matplotlib-base=3.5.0=py38hf4fb855_0 160 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 161 | - mistune=2.0.4=pyhd8ed1ab_0 162 | - mkl=2021.4.0=h8d4b97c_729 163 | - mkl-service=2.4.0=py38h95df7f1_0 164 | - mkl_fft=1.3.1=py38h8666266_1 165 | - mkl_random=1.2.2=py38h1abd341_0 166 | - munkres=1.1.4=pyh9f0ad1d_0 167 | - mysql-common=8.0.31=haf5c9bc_0 168 | - mysql-libs=8.0.31=h28c427c_0 169 | - nbclassic=0.4.8=pyhd8ed1ab_0 170 | - nbclient=0.7.2=pyhd8ed1ab_0 171 | - nbconvert=7.2.6=pyhd8ed1ab_0 172 | - nbconvert-core=7.2.6=pyhd8ed1ab_0 173 | - nbconvert-pandoc=7.2.6=pyhd8ed1ab_0 174 | - nbformat=5.7.0=pyhd8ed1ab_0 175 | - nccl=2.14.3.1=h0800d71_0 176 | - ncurses=6.3=h27087fc_1 177 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 178 | - nettle=3.6=he412f7d_0 179 | - networkx=2.8.8=pyhd8ed1ab_0 180 | - ninja=1.10.2=h06a4308_5 181 | - ninja-base=1.10.2=hd09550d_5 182 | - notebook=6.5.2=pyha770c72_1 183 | - notebook-shim=0.2.2=pyhd8ed1ab_0 184 | - nspr=4.35=h27087fc_0 185 | - nss=3.82=he02c5a1_0 186 | - numpy=1.21.2=py38h20f2e39_0 187 | - numpy-base=1.21.2=py38h79a1101_0 188 | - olefile=0.46=pyh9f0ad1d_1 189 | - opencv=4.4.0=py38_2 190 | - openh264=2.1.1=h4ff587b_0 191 | - openjpeg=2.5.0=h7d73246_1 192 | - openssl=1.1.1s=h0b41bf4_1 193 | - packaging=22.0=pyhd8ed1ab_0 194 | - pandoc=2.19.2=h32600fe_1 195 | - 
pandocfilters=1.5.0=pyhd8ed1ab_0 196 | - parso=0.8.3=pyhd8ed1ab_0 197 | - partd=1.3.0=pyhd8ed1ab_0 198 | - pcre2=10.40=hc3806b6_0 199 | - pexpect=4.8.0=pyh1a96a4e_2 200 | - pickleshare=0.7.5=py_1003 201 | - pillow=8.4.0=py38h5aabda8_0 202 | - pip=21.2.4=py38h06a4308_0 203 | - pixman=0.40.0=h36c2ea0_0 204 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 205 | - platformdirs=2.6.0=pyhd8ed1ab_0 206 | - prometheus_client=0.15.0=pyhd8ed1ab_0 207 | - prompt-toolkit=3.0.36=pyha770c72_0 208 | - prompt_toolkit=3.0.36=hd8ed1ab_0 209 | - psutil=5.9.4=py38h0a891b7_0 210 | - pthread-stubs=0.4=h36c2ea0_1001 211 | - ptyprocess=0.7.0=pyhd3deb0d_0 212 | - pudb=2022.1=pyhd8ed1ab_1 213 | - pure_eval=0.2.2=pyhd8ed1ab_0 214 | - py-opencv=4.4.0=py38h23f93f0_2 215 | - pycparser=2.21=pyhd8ed1ab_0 216 | - pygments=2.13.0=pyhd8ed1ab_0 217 | - pyopenssl=22.1.0=pyhd8ed1ab_0 218 | - pyparsing=3.0.9=pyhd8ed1ab_0 219 | - pyqt=5.12.3=py38ha8c2ead_4 220 | - pyrsistent=0.19.2=py38h0a891b7_0 221 | - pysocks=1.7.1=pyha2e5f31_6 222 | - python=3.8.15=h257c98d_0_cpython 223 | - python-dateutil=2.8.2=pyhd8ed1ab_0 224 | - python-fastjsonschema=2.16.2=pyhd8ed1ab_0 225 | - python-json-logger=2.0.1=pyh9f0ad1d_0 226 | - python_abi=3.8=3_cp38 227 | - pytorch=1.10.2=py3.8_cuda11.3_cudnn8.2.0_0 228 | - pytorch-mutex=1.0=cuda 229 | - pywavelets=1.3.0=py38h26c90d9_2 230 | - pyyaml=6.0=py38h0a891b7_5 231 | - pyzmq=24.0.1=py38hfc09fa9_1 232 | - qt=5.12.9=hda022c4_4 233 | - qtconsole=5.4.0=pyhd8ed1ab_0 234 | - qtconsole-base=5.4.0=pyha770c72_0 235 | - qtpy=2.3.0=pyhd8ed1ab_0 236 | - readline=8.1.2=h0f457ee_0 237 | - scikit-image=0.19.3=py38h8f669ce_2 238 | - send2trash=1.8.0=pyhd8ed1ab_0 239 | - setuptools=59.5.0 240 | - six=1.16.0=pyh6c4a22f_0 241 | - snappy=1.1.9=hbd366e4_2 242 | - sniffio=1.3.0=pyhd8ed1ab_0 243 | - soupsieve=2.3.2.post1=pyhd8ed1ab_0 244 | - sqlite=3.40.0=h4ff8645_0 245 | - stack_data=0.6.2=pyhd8ed1ab_0 246 | - sysroot_linux-64=2.12=he073ed8_15 247 | - tbb=2021.7.0=h924138e_0 248 | - termcolor=1.1.0=py38h06a4308_1 249 | - terminado=0.17.1=pyh41d4057_0 250 | - tifffile=2022.10.10=pyhd8ed1ab_0 251 | - timm=0.6.12=pyhd8ed1ab_0 252 | - tinycss2=1.2.1=pyhd8ed1ab_0 253 | - tk=8.6.12=h27826a3_0 254 | - toolz=0.12.0=pyhd8ed1ab_0 255 | - tornado=6.2=py38h0a891b7_1 256 | - traitlets=5.7.1=pyhd8ed1ab_0 257 | - typing-extensions=4.4.0=hd8ed1ab_0 258 | - typing_extensions=4.4.0=pyha770c72_0 259 | - unicodedata2=15.0.0=py38h0a891b7_0 260 | - urwid=2.1.2=py38h0a891b7_7 261 | - urwid_readline=0.13=pyhd8ed1ab_0 262 | - wcwidth=0.2.5=pyh9f0ad1d_2 263 | - webencodings=0.5.1=py_1 264 | - websocket-client=1.4.2=pyhd8ed1ab_0 265 | - wheel=0.38.4=pyhd8ed1ab_0 266 | - widgetsnbextension=4.0.4=pyhd8ed1ab_0 267 | - xorg-kbproto=1.0.7=h7f98852_1002 268 | - xorg-libice=1.0.10=h7f98852_0 269 | - xorg-libsm=1.2.3=hd9c2040_1000 270 | - xorg-libx11=1.7.2=h7f98852_0 271 | - xorg-libxau=1.0.9=h7f98852_0 272 | - xorg-libxdmcp=1.1.3=h7f98852_0 273 | - xorg-libxext=1.3.4=h7f98852_1 274 | - xorg-libxrender=0.9.10=h7f98852_1003 275 | - xorg-renderproto=0.11.1=h7f98852_1002 276 | - xorg-xextproto=7.3.0=h7f98852_1002 277 | - xorg-xproto=7.0.31=h7f98852_1007 278 | - xz=5.2.6=h166bdaf_0 279 | - yaml=0.2.5=h7f98852_2 280 | - zeromq=4.3.4=h9c3ff4c_1 281 | - zfp=1.0.0=h27087fc_3 282 | - zipp=3.11.0=pyhd8ed1ab_0 283 | - zlib=1.2.13=h166bdaf_4 284 | - zlib-ng=2.0.6=h166bdaf_0 285 | - zstd=1.5.2=h6239696_4 286 | - pip: 287 | - absl-py==1.3.0 288 | - av==10.0.0 289 | - autolab_core==1.1.1 290 | - cachetools==5.2.0 291 | - charset-normalizer==2.0.12 292 | - google-auth==2.15.0 
293 | - google-auth-oauthlib==0.4.6 294 | - grpcio==1.51.1 295 | - markdown==3.4.1 296 | - numpy-quaternion==2022.4.3 297 | - nvidia-ml-py==11.515.75 298 | - nvitop==0.11.0 299 | - oauthlib==3.2.2 300 | - opencv-python==4.5.5.62 301 | - protobuf==3.20.* 302 | - pyasn1==0.4.8 303 | - pyasn1-modules==0.2.8 304 | - pyqt5-sip==4.19.18 305 | - pyqtchart==5.12 306 | - pyqtwebengine==5.12.1 307 | - quaternion==3.5.2.post4 308 | - requests==2.27.1 309 | - requests-oauthlib==1.3.1 310 | - rsa==4.9 311 | - scipy==1.8.0 312 | - tb-nightly==2.9.0a20220217 313 | - tensorboard-data-server==0.6.1 314 | - tensorboard-plugin-wit==1.8.1 315 | - tqdm==4.62.3 316 | - urllib3==1.26.8 317 | - werkzeug==3.0.3 318 | - pytransform3d==3.5.0 319 | - HTML4Vision==0.4.3 320 | -------------------------------------------------------------------------------- /data/generate_slam_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import pickle 5 | import pathlib 6 | import cv2 7 | import av 8 | from collections import defaultdict 9 | from autolab_core import RigidTransform 10 | from tqdm import tqdm 11 | 12 | import numpy as np 13 | import argparse 14 | import concurrent.futures 15 | import multiprocessing 16 | 17 | black_list = { 18 | "demo_EXAMPLE1": [(0,97)], 19 | "demo_EXAMPLE2": [(0,60)], 20 | } 21 | 22 | def parse_fisheye_intrinsics(): 23 | json_data = { 24 | "final_reproj_error": 0.2916398582648, 25 | "fps": 59.94005994005994, 26 | "image_height": 2028, 27 | "image_width": 2704, 28 | "intrinsic_type": "FISHEYE", 29 | "intrinsics": { 30 | "aspect_ratio": 1.0029788958491257, 31 | "focal_length": 796.8544625226342, 32 | "principal_pt_x": 1354.4265245977356, 33 | "principal_pt_y": 1011.4847310011687, 34 | "radial_distortion_1": -0.02196117964405394, 35 | "radial_distortion_2": -0.018959717016668237, 36 | "radial_distortion_3": 0.001693880829392453, 37 | "radial_distortion_4": -0.00016807228608000285, 38 | "skew": 0.0, 39 | }, 40 | "nr_calib_images": 59, 41 | "stabelized": False, 42 | } 43 | assert json_data['intrinsic_type'] == 'FISHEYE' 44 | intr_data = json_data['intrinsics'] 45 | 46 | # img size 47 | h = json_data['image_height'] 48 | w = json_data['image_width'] 49 | 50 | # pinhole parameters 51 | f = intr_data['focal_length'] 52 | px = intr_data['principal_pt_x'] 53 | py = intr_data['principal_pt_y'] 54 | 55 | # Kannala-Brandt non-linear parameters for distortion 56 | kb8 = [ 57 | intr_data['radial_distortion_1'], 58 | intr_data['radial_distortion_2'], 59 | intr_data['radial_distortion_3'], 60 | intr_data['radial_distortion_4'] 61 | ] 62 | 63 | opencv_intr_dict = { 64 | 'DIM': np.array([w, h], dtype=np.int64), 65 | 'K': np.array([ 66 | [f, 0, px], 67 | [0, f, py], 68 | [0, 0, 1] 69 | ], dtype=np.float64), 70 | 'D': np.array([kb8]).T 71 | } 72 | return opencv_intr_dict 73 | 74 | def check_black_list(frame, range_list): 75 | for start,end in range_list: 76 | if frame >= start and frame <= end: 77 | return True 78 | 79 | class FisheyeRectConverter: 80 | def __init__(self, K, D, out_size, out_fov): 81 | out_size = np.array(out_size) 82 | # vertical fov 83 | out_f = (out_size[1] / 2) / np.tan(out_fov[1]/180*np.pi/2) 84 | out_fx = (out_size[0] / 2) / np.tan(out_fov[0]/180*np.pi/2) 85 | out_K = np.array([ 86 | [out_fx, 0, out_size[0]/2], 87 | [0, out_f, out_size[1]/2], 88 | [0, 0, 1] 89 | ], dtype=np.float32) 90 | map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, D, np.eye(3), out_K, out_size, cv2.CV_16SC2) 91 | 92 | 
self.map1 = map1 93 | self.map2 = map2 94 | 95 | def inverse_forward(self, img): 96 | distor_img = cv2.remap(img, 97 | self.map2, self.map1, 98 | interpolation=cv2.INTER_AREA, 99 | borderMode=cv2.BORDER_CONSTANT) 100 | return distor_img 101 | 102 | def forward(self, img): 103 | rect_img = cv2.remap(img, 104 | self.map1, self.map2, 105 | interpolation=cv2.INTER_AREA, 106 | borderMode=cv2.BORDER_CONSTANT) 107 | return rect_img 108 | 109 | def to_relative_pose_json(cam_pose_slam, start_frame, end_frame, output_folder, interval = 12): 110 | basename = os.path.basename(output_folder) 111 | black_list_frames = black_list.get(basename, []) 112 | dct = {} 113 | for i, frame_idx in enumerate(range(start_frame, end_frame+1)): 114 | goal_frame = frame_idx + interval 115 | if goal_frame >= end_frame: 116 | break 117 | pose1 = cam_pose_slam[i] 118 | pose2 = cam_pose_slam[i+interval] 119 | t1 = pose1[:3,-1] 120 | rot1 = pose1[:3,:3] 121 | t2 = pose2[:3,-1] 122 | rot2 = pose2[:3,:3] 123 | T_slam_cam1 = RigidTransform(rotation=rot1, 124 | translation=t1, 125 | from_frame=f"cam_{frame_idx}", 126 | to_frame=f"slam") 127 | T_slam_cam2 = RigidTransform(rotation=rot2, 128 | translation=t2, 129 | from_frame=f"cam_{goal_frame}", 130 | to_frame=f"slam") 131 | T_cam1_cam2 = T_slam_cam1.inverse() * T_slam_cam2 132 | 133 | T_cam1_cam2_t = T_cam1_cam2.translation.astype(float) 134 | T_cam1_cam2_r = T_cam1_cam2.rotation 135 | T_cam1_cam2_r = T_cam1_cam2_r.tolist() 136 | 137 | frame_name = f"{frame_idx:05}.jpg" 138 | frame_namelk = f"{goal_frame:05}.jpg" 139 | if len(black_list_frames) > 0: 140 | if check_black_list(frame_idx, black_list_frames): 141 | continue 142 | dct[frame_name] = (list(T_cam1_cam2_t), T_cam1_cam2_r, frame_namelk) 143 | 144 | json_save_path = os.path.join(output_folder, f'labels_{interval}.json') 145 | if os.path.exists(json_save_path): 146 | os.remove(json_save_path) 147 | with open(json_save_path, 'w+') as fp: 148 | json.dump(dct, fp, indent=4, sort_keys=True) 149 | 150 | def to_absolute_pose_json(cam_pose_slam, start_frame, end_frame, output_folder): 151 | basename = os.path.basename(output_folder) 152 | black_list_frames = black_list.get(basename, []) # Have to manually check if the ORB-SLAM output is correct. 153 | # One way is to plot time vs. x/y/z of all the extracted trajectories, usually the outliers are incorrect camera poses. 
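    # For example (illustrative sketch only, assuming matplotlib is imported as plt):
    #     xyz = np.stack([pose[:3, -1] for pose in cam_pose_slam])  # (N, 3) camera positions
    #     plt.plot(xyz)  # sudden jumps / outliers usually mark bad poses to black-list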
154 | dct = {} 155 | for i, frame_idx in enumerate(range(start_frame, end_frame)): 156 | pose1 = cam_pose_slam[i] 157 | frame_name = f"{frame_idx:05}.jpg" 158 | if len(black_list_frames) > 0: 159 | if check_black_list(frame_idx, black_list_frames): 160 | continue 161 | dct[frame_name] = pose1.tolist() 162 | 163 | json_save_path = os.path.join(output_folder, f'raw_labels.json') 164 | if os.path.exists(json_save_path): 165 | os.remove(json_save_path) 166 | print(f"Save to {json_save_path}") 167 | with open(json_save_path, 'w+') as fp: 168 | json.dump(dct, fp, indent=4, sort_keys=True) 169 | 170 | def main(args): 171 | num_workers = multiprocessing.cpu_count() 172 | cv2.setNumThreads(1) 173 | 174 | # intr_path = pathlib.Path("/data11/zhang401/umi_redwood/example/calibration/gopro_intrinsics_2_7k.json") 175 | opencv_intr_dict = parse_fisheye_intrinsics() 176 | out_res = (args.out_res[0],args.out_res[1]) 177 | out_fov = (args.out_fov[0],args.out_fov[1]) 178 | 179 | ipath = pathlib.Path(args.input_dir) 180 | demos_path = ipath.joinpath('demos') 181 | plan_path = ipath.joinpath('dataset_plan.pkl') 182 | plan = pickle.load(plan_path.open('rb')) 183 | 184 | fisheye_converter = FisheyeRectConverter( 185 | opencv_intr_dict['K'], 186 | opencv_intr_dict['D'], 187 | out_size=out_res, 188 | out_fov=out_fov 189 | ) 190 | 191 | def save_video_frames(mp4_path, output_image_dir, frame_start, frame_end): 192 | with av.open(mp4_path) as container: 193 | in_stream = container.streams.video[0] 194 | # in_stream.thread_type = "AUTO" 195 | in_stream.thread_count = 1 196 | for frame_idx, frame in enumerate(container.decode(in_stream)): 197 | if frame_idx < frame_start: 198 | continue 199 | elif frame_idx < frame_end: 200 | img = frame.to_ndarray(format='rgb24') 201 | img = fisheye_converter.forward(img) 202 | img_path = os.path.join(output_image_dir, f"{frame_idx:05}.jpg") 203 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 204 | 205 | 206 | with tqdm(total=len(plan)) as pbar: 207 | with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: 208 | futures = set() 209 | for plan_episode in plan: 210 | if len(futures) >= num_workers: 211 | completed, futures = concurrent.futures.wait(futures, 212 | return_when=concurrent.futures.FIRST_COMPLETED) 213 | pbar.update(len(completed)) 214 | 215 | grippers = plan_episode['grippers'] 216 | n_grippers = len(grippers) 217 | cameras = plan_episode['cameras'] 218 | n_cameras = len(cameras) 219 | assert n_grippers == 1, "Only support one gripper implemented" 220 | assert n_cameras == 1, "Only support one camera implemented" 221 | 222 | camera = cameras[0] 223 | video_path_rel = camera['video_path'] 224 | video_path = demos_path.joinpath(video_path_rel).absolute() 225 | if not os.path.exists(video_path): 226 | print(f"Video {video_path} does not exist") 227 | continue 228 | 229 | video_start, video_end = camera['video_start_end'] 230 | mp4_path = str(video_path) 231 | video_name = mp4_path.split('/')[-2] 232 | if len(args.process_these_folders) > 0 and video_name not in args.process_these_folders: 233 | continue 234 | output_image_dir = os.path.join(args.output_dir, video_name, "images") 235 | os.makedirs(output_image_dir, exist_ok=True) 236 | if not os.path.exists(output_image_dir) or args.save_frame: 237 | futures.add(executor.submit(save_video_frames, 238 | mp4_path, output_image_dir, video_start, video_end)) 239 | 240 | cam_pose_slam = grippers[0]["cam_pose_slam"] 241 | if args.save_absolute_poses: 242 | to_absolute_pose_json(cam_pose_slam, 
video_start, video_end, os.path.join(args.output_dir, video_name)) 243 | if args.save_relative_poses: 244 | for interval in args.intervals: 245 | to_relative_pose_json(cam_pose_slam, video_start, video_end, os.path.join(args.output_dir, video_name), interval=interval) 246 | 247 | completed, futures = concurrent.futures.wait(futures) 248 | pbar.update(len(completed)) 249 | 250 | print([x.result() for x in completed]) 251 | 252 | if __name__=='__main__': 253 | parser = argparse.ArgumentParser(description='diffusion label') 254 | parser.add_argument('--output_dir', type=str, help="path to save the task data") 255 | parser.add_argument('--input_dir', type=str, help="path to the folder containing dataset_plan.pkl") 256 | parser.add_argument('--intervals', type=int, nargs='+', default=[12], help="intervals to calculate relative poses") 257 | parser.add_argument('--save_frame',action='store_true') 258 | parser.add_argument('--save_absolute_poses',action='store_true') 259 | parser.add_argument('--save_relative_poses',action='store_true') 260 | parser.add_argument('--out_fov', nargs='+', type=float, default=[69, 69]) 261 | parser.add_argument('--out_res', nargs='+', type=int, default=[360, 360]) 262 | parser.add_argument('--process_these_folders', nargs='+', type=str, default=[], help="extract frames/save camera poses for these folders only") 263 | args = parser.parse_args() 264 | 265 | main(args) -------------------------------------------------------------------------------- /src/models/score_sde/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template <typename scalar_t> 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < 
p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = 
out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | 
block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /src/models/score_sde/ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | 18 | from . 
import utils, layers, layerspp, normalization 19 | import torch.nn as nn 20 | import functools 21 | import torch 22 | import numpy as np 23 | 24 | ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp 25 | ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp 26 | Combine = layerspp.Combine 27 | conv3x3 = layerspp.conv3x3 28 | conv1x1 = layerspp.conv1x1 29 | get_act = layers.get_act 30 | get_normalization = normalization.get_normalization 31 | default_initializer = layers.default_init 32 | 33 | 34 | @utils.register_model(name='ncsnpp') 35 | class NCSNpp(nn.Module): 36 | """NCSN++ model""" 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | self.config = config 41 | self.act = act = get_act(config) 42 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) 43 | 44 | self.nf = nf = config.model.nf 45 | ch_mult = config.model.ch_mult 46 | self.num_res_blocks = num_res_blocks = config.model.num_res_blocks 47 | self.attn_resolutions = attn_resolutions = config.model.attn_resolutions 48 | dropout = config.model.dropout 49 | resamp_with_conv = config.model.resamp_with_conv 50 | self.num_resolutions = num_resolutions = len(ch_mult) 51 | self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] 52 | 53 | self.conditional = conditional = config.model.conditional # noise-conditional 54 | fir = config.model.fir 55 | fir_kernel = config.model.fir_kernel 56 | self.skip_rescale = skip_rescale = config.model.skip_rescale 57 | self.resblock_type = resblock_type = config.model.resblock_type.lower() 58 | self.progressive = progressive = config.model.progressive.lower() 59 | self.progressive_input = progressive_input = config.model.progressive_input.lower() 60 | self.embedding_type = embedding_type = config.model.embedding_type.lower() 61 | init_scale = config.model.init_scale 62 | assert progressive in ['none', 'output_skip', 'residual'] 63 | assert progressive_input in ['none', 'input_skip', 'residual'] 64 | assert embedding_type in ['fourier', 'positional'] 65 | combine_method = config.model.progressive_combine.lower() 66 | combiner = functools.partial(Combine, method=combine_method) 67 | 68 | AttnBlock = functools.partial(layerspp.AttnBlockpp, 69 | init_scale=init_scale, 70 | skip_rescale=skip_rescale) 71 | 72 | Upsample = functools.partial(layerspp.Upsample, 73 | with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 74 | # Downsampling block 75 | Downsample = functools.partial(layerspp.Downsample, 76 | with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 77 | 78 | 79 | self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 80 | self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 81 | 82 | ResnetBlock = functools.partial(ResnetBlockBigGAN, 83 | act=act, 84 | dropout=dropout, 85 | fir=fir, 86 | fir_kernel=fir_kernel, 87 | init_scale=init_scale, 88 | skip_rescale=skip_rescale, 89 | temb_dim=nf * 4) 90 | 91 | 92 | # ============================================== input layer for embedding the time/noise 93 | 94 | # timestep/noise_level embedding; only for continuous training 95 | # Gaussian Fourier features embeddings. 96 | assert config.training.continuous, "Fourier features are only used for continuous training." 
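        # GaussianFourierProjection(embedding_size=nf) returns concatenated sin/cos
        # features, i.e. a 2*nf-dimensional embedding, which is why embed_dim is set
        # to 2 * nf below.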
97 | self.temb_modules = torch.nn.ModuleList() 98 | self.temb_modules.append(layerspp.GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale)) 99 | embed_dim = 2 * nf 100 | 101 | # ============================================== time embedding pre processing 102 | if conditional: 103 | self.temb_modules.append(nn.Linear(embed_dim, nf * 4)) 104 | self.temb_modules[-1].weight.data = default_initializer()(self.temb_modules[-1].weight.shape) 105 | nn.init.zeros_(self.temb_modules[-1].bias) 106 | self.temb_modules.append(nn.Linear(nf * 4, nf * 4)) 107 | self.temb_modules[-1].weight.data = default_initializer()(self.temb_modules[-1].weight.shape) 108 | nn.init.zeros_(self.temb_modules[-1].bias) 109 | 110 | channels = config.data.num_channels + 4 # also a hack 111 | if progressive_input != 'none': 112 | input_pyramid_ch = channels 113 | 114 | # ============================================== first layer for input 115 | self.input_conv = conv3x3(8+128, nf) # expanded input for conditioning 116 | hs_c = [nf] 117 | 118 | # ============================================== lookup table for conditioning dimensions 119 | conditioning_dim_table = [128,128,256,256] 120 | 121 | encoder_modules = torch.nn.ModuleList() 122 | 123 | in_ch = nf 124 | for i_level in range(num_resolutions): 125 | level_modules = torch.nn.ModuleDict({ 126 | 'resblocks': torch.nn.ModuleList(), 127 | 'attn': None, 128 | 'downres': None, 129 | 'combiner': None 130 | }) 131 | # Residual blocks for this resolution 132 | for i_block in range(num_res_blocks): 133 | out_ch = nf * ch_mult[i_level] 134 | cond_dim = conditioning_dim_table[i_level] 135 | level_modules['resblocks'].append(ResnetBlock(in_ch=in_ch, out_ch=out_ch, conditioning_dim=cond_dim)) 136 | in_ch = out_ch 137 | 138 | if all_resolutions[i_level] in attn_resolutions: 139 | level_modules['attn'] = AttnBlock(channels=in_ch) 140 | hs_c.append(in_ch) 141 | 142 | # if not last level of downsampling 143 | if i_level != num_resolutions - 1: 144 | cond_dim = conditioning_dim_table[i_level+1] 145 | level_modules['downres'] = ResnetBlock(down=True, in_ch=in_ch, conditioning_dim=cond_dim) 146 | 147 | level_modules['combiner'] = combiner(dim1=input_pyramid_ch, dim2=in_ch) 148 | if combine_method == 'cat': 149 | in_ch *= 2 150 | hs_c.append(in_ch) 151 | 152 | encoder_modules.append(level_modules) 153 | self.encoder_modules = encoder_modules 154 | 155 | # CENTRAL BLOCK 156 | in_ch = hs_c[-1] 157 | cond_dim = conditioning_dim_table[-1] 158 | central_block = torch.nn.ModuleList() 159 | central_block.append(ResnetBlock(in_ch=in_ch, conditioning_dim=cond_dim)) 160 | central_block.append(AttnBlock(channels=in_ch)) 161 | central_block.append(ResnetBlock(in_ch=in_ch, conditioning_dim=cond_dim)) 162 | self.central_block = central_block 163 | 164 | pyramid_ch = 0 165 | # Upsampling block 166 | decoder_modules = torch.nn.ModuleList([None]*num_resolutions) 167 | for i_level in reversed(range(num_resolutions)): 168 | layer_modules = torch.nn.ModuleDict({ 169 | 'resblocks': torch.nn.ModuleList(), 170 | 'attn': None, 171 | 'group_norm': None, 172 | 'conv3x3': None, 173 | 'final_res': None, 174 | }) 175 | for i_block in range(num_res_blocks + 1): 176 | out_ch = nf * ch_mult[i_level] 177 | layer_modules['resblocks'].append(ResnetBlock(in_ch=in_ch + hs_c.pop(),out_ch=out_ch)) 178 | in_ch = out_ch 179 | 180 | if all_resolutions[i_level] in attn_resolutions: 181 | layer_modules['attn'] = AttnBlock(channels=in_ch) 182 | 183 | layer_modules['group_norm'] = nn.GroupNorm(num_groups=min(in_ch 
// 4, 32),num_channels=in_ch, eps=1e-6) 184 | if i_level == num_resolutions - 1: 185 | layer_modules['conv3x3'] = conv3x3(in_ch, channels-4, init_scale=init_scale) 186 | else: 187 | layer_modules['conv3x3'] = conv3x3(in_ch, channels-4, bias=True, init_scale=init_scale) 188 | # HACK 189 | pyramid_ch = channels 190 | 191 | if i_level != 0: 192 | layer_modules['final_res'] = ResnetBlock(in_ch=in_ch, up=True) 193 | 194 | decoder_modules[i_level] = layer_modules 195 | self.decoder_modules = decoder_modules 196 | 197 | assert not hs_c 198 | 199 | # ================== my modules 200 | self.conditioning_conv_0 = nn.Conv2d(128,128, stride=1, bias=True, padding=1, kernel_size=3) 201 | self.conditioning_conv_1 = nn.Conv2d(128,128, stride=2, bias=True, padding=1, kernel_size=3) 202 | self.conditioning_conv_2 = nn.Conv2d(128, 256, stride=2, bias=True, padding=1, kernel_size=3) 203 | self.conditioning_conv_3 = nn.Conv2d(256, 256, stride=2, bias=True, padding=1, kernel_size=3) 204 | 205 | def forward(self, x, cond_im, time_cond, conditioning): 206 | # in this model, time cond is just the std of noise 207 | # adjust std of cond_im to match statistics of current x 208 | cond_im = (cond_im+0.14)*time_cond[:,None,None,None]*8.8 209 | 210 | # timestep/noise_level embedding; only for continuous training 211 | # Gaussian Fourier features embeddings. 212 | used_sigmas = time_cond 213 | temb = self.temb_modules[0](torch.log(used_sigmas)) 214 | 215 | # time embedding preprocessing 216 | temb = self.temb_modules[1](temb) 217 | temb = self.temb_modules[2](self.act(temb)) 218 | 219 | if not self.config.data.centered: 220 | # If input data is in [0, 1] 221 | x = 2 * x - 1. 222 | x = torch.cat([x,cond_im],1) 223 | 224 | # ========== create conditioning pyramid ========== 225 | cond_0 = self.conditioning_conv_0(self.act(conditioning)) 226 | cond_1 = self.conditioning_conv_1(self.act(cond_0)) 227 | cond_2 = self.conditioning_conv_2(self.act(cond_1)) 228 | cond_3 = self.conditioning_conv_3(self.act(cond_2)) 229 | conditioning_stack = { 230 | 32: cond_0, 231 | 16: cond_1, 232 | 8: cond_2, 233 | 4: cond_3, 234 | } 235 | # ============================================= 236 | 237 | # Downsampling block 238 | input_pyramid = None 239 | if self.progressive_input != 'none': 240 | input_pyramid = x 241 | 242 | # ============================================ 243 | x = torch.cat([x,conditioning],1) 244 | # ============================================ 245 | 246 | hs = [self.input_conv(x)]; print(f'input layer output: {hs[-1].shape}') 247 | for i_level in range(self.num_resolutions): 248 | level_modules = self.encoder_modules[i_level]; print(f'\n======== Encoder level: {i_level}') 249 | # Residual blocks for this resolution 250 | for resblock in level_modules['resblocks']: 251 | layer_cond = conditioning_stack[hs[-1].shape[2]] 252 | h = resblock(hs[-1], temb, layer_cond); print(f'Resblock output: {h.shape}') 253 | if level_modules['attn'] is not None: 254 | h = level_modules['attn'](h); print(f'Attn output: {h.shape}') 255 | hs.append(h); print('Push Skip') 256 | 257 | if level_modules['downres'] is not None: 258 | print(f'===== level post processing') 259 | layer_cond = conditioning_stack[hs[-1].shape[2]//2] 260 | h = level_modules['downres'](hs[-1], temb, layer_cond); print(f'Resnet output: {h.shape}') 261 | 262 | input_pyramid = self.pyramid_downsample(input_pyramid) 263 | h = level_modules['combiner'](input_pyramid, h); print(f'combiner output: {h.shape}') 264 | 265 | hs.append(h); print('Push Skip') 266 | 267 | print(f'\n======== 
central block') 268 | h = hs[-1] 269 | layer_cond = conditioning_stack[h.shape[2]] 270 | h = self.central_block[0](h, temb, layer_cond) 271 | h = self.central_block[1](h) 272 | h = self.central_block[2](h, temb, layer_cond) 273 | print(f'Resnet output: {h.shape}') 274 | 275 | pyramid = None 276 | 277 | # Upsampling block 278 | for i_level in reversed(range(self.num_resolutions)): 279 | level_modules = self.decoder_modules[i_level]; print(f'\n======== Decoder level: {i_level}') 280 | for resblock in level_modules['resblocks']: 281 | h = resblock(torch.cat([h, hs.pop()], dim=1), temb); print(f'Resblock (cat skip) output: {h.shape}') 282 | 283 | if level_modules['attn'] is not None: 284 | h = level_modules['attn'](h); print(f'Attn output: {h.shape}') 285 | 286 | print(f'===== level post processing') 287 | pyramid_h = self.act(level_modules['group_norm'](h)) 288 | print(f'GroupNorm output: {pyramid_h.shape}') 289 | pyramid_h = level_modules['conv3x3'](pyramid_h) 290 | print(f'Conv3x3 output: {pyramid_h.shape}') 291 | 292 | if i_level == self.num_resolutions - 1: 293 | pyramid = pyramid_h; print(f'Init pyramid') 294 | else: 295 | pyramid = self.pyramid_upsample(pyramid) 296 | pyramid = pyramid + pyramid_h; print(f'Sum pyramid') 297 | 298 | if level_modules['final_res'] is not None: 299 | h = level_modules['final_res'](h, temb); print(f'Final resnet output: {h.shape}') 300 | 301 | assert not hs # ensure we used all the skip connections 302 | 303 | h = pyramid 304 | 305 | if self.config.model.scale_by_sigma: 306 | print('Scale Sigma') 307 | used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) 308 | h = h / used_sigmas 309 | 310 | return h 311 | --------------------------------------------------------------------------------