├── examples
│   ├── gan_shapes.png
│   ├── gan_generator_voxels_chairs.to
│   ├── gan_generator_voxels_sofas.to
│   └── gan_generator_voxels_airplanes.to
├── rendering
│   ├── depth_fragment.glsl
│   ├── depth_vertex.glsl
│   ├── vertex.glsl
│   ├── math.py
│   ├── shader.py
│   ├── fragment.glsl
│   ├── binary_voxels_to_mesh.py
│   ├── raymarching.py
│   └── __init__.py
├── .gitignore
├── model
│   ├── classifier.py
│   ├── __init__.py
│   ├── progressive_gan.py
│   ├── gan.py
│   ├── point_sdf_net.py
│   ├── autoencoder.py
│   └── sdf_net.py
├── demo_gan.py
├── demo_autoencoder.py
├── demo_training.py
├── demo_data_preparation.py
├── shapenet_metadata.py
├── train_point_gan.py
├── util.py
├── datasets.py
├── metrics.py
├── train_sdf_autodecoder.py
├── train_wgan.py
├── demo_sdf_net.py
├── prepare_data.py
├── train_gan.py
├── train_autoencoder.py
├── train_point_gan_ref.py
├── train_hybrid_wgan.py
├── train_hybrid_gan.py
├── demo_latent_space.py
├── README.md
├── prepare_shapenet_dataset.py
├── train_hybrid_progressive_gan.py
└── create_plot.py

/examples/gan_shapes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marian42/shapegan/HEAD/examples/gan_shapes.png
--------------------------------------------------------------------------------
/rendering/depth_fragment.glsl:
--------------------------------------------------------------------------------
1 | #version 330 core
2 | 
3 | void main()
4 | {
5 |     gl_FragColor = vec4(1.0); // legacy built-in: needs a compatibility context, since core GLSL 3.30 removed gl_FragColor
6 | }
--------------------------------------------------------------------------------
/examples/gan_generator_voxels_chairs.to:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marian42/shapegan/HEAD/examples/gan_generator_voxels_chairs.to
--------------------------------------------------------------------------------
/examples/gan_generator_voxels_sofas.to:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marian42/shapegan/HEAD/examples/gan_generator_voxels_sofas.to
--------------------------------------------------------------------------------
/examples/gan_generator_voxels_airplanes.to:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marian42/shapegan/HEAD/examples/gan_generator_voxels_airplanes.to
--------------------------------------------------------------------------------
/rendering/depth_vertex.glsl:
--------------------------------------------------------------------------------
1 | #version 330 core
2 | in vec3 position;
3 | 
4 | uniform mat4 VP;
5 | 
6 | void main()
7 | {
8 |     gl_Position = VP * vec4(position, 1.0);
9 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.pyc
3 | 
4 | \.mypy_cache/**
5 | .vscode/**
6 | *.to
7 | 
8 | *.gz
9 | 
10 | screenshots/
11 | images/
12 | plots/
13 | data/shapenet**
14 | 
15 | *.pdf
16 | 
17 | *.csv
18 | 
19 | *.npy
20 | *.txt
21 | 
22 | *.mp4
23 | 
24 | 
25 | data/
26 | 
27 | generated_objects/
--------------------------------------------------------------------------------
/rendering/vertex.glsl:
--------------------------------------------------------------------------------
1 | varying out vec3 normal;
2 | varying out vec3 position;
3 | 
4 | varying out vec4 shadowPosition;
5 | varying out vec3 lightPosition;
6 | 
7 | uniform mat4 VP;
8 | uniform mat4 lightVP;
9 | uniform float yOffset;
10 | 
11 | 
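// Note: this shader reads the legacy built-ins gl_Vertex and gl_Normal below,
// so it has to be compiled in a context that supports the compatibility profile.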
void main() { 12 | vec3 vertexWithOffset = gl_Vertex + vec3(0.0, yOffset, 0.0); 13 | gl_Position = VP * vec4(vertexWithOffset, 1.0); 14 | position = gl_Position.xyz; 15 | 16 | shadowPosition = lightVP * vec4(vertexWithOffset, 1.0); 17 | lightPosition = (VP * inverse(lightVP) * vec4(0.0, 0.0, -1.0, 1.0)).xyz; 18 | 19 | normal = (VP * vec4(gl_Normal, 0.0)).xyz; 20 | } -------------------------------------------------------------------------------- /rendering/math.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from scipy.spatial.transform import Rotation 4 | 5 | PROJECTION_MATRIX = np.array( 6 | [[ 1.73205081, 0, 0, 0, ], 7 | [ 0, 1.73205081, 0, 0, ], 8 | [ 0, 0, -1.02020202, -0.2020202, ], 9 | [ 0, 0, -1, 0, ]], dtype=float) 10 | 11 | def get_rotation_matrix(angle, axis='y'): 12 | rotation = Rotation.from_euler(axis, angle, degrees=True) 13 | matrix = np.identity(4) 14 | matrix[:3, :3] = rotation.as_dcm() 15 | return matrix 16 | 17 | def get_camera_transform(camera_distance, rotation_y, rotation_x=0, project=False): 18 | camera_transform = np.identity(4) 19 | camera_transform[2, 3] = -camera_distance 20 | camera_transform = np.matmul(camera_transform, get_rotation_matrix(rotation_x, axis='x')) 21 | camera_transform = np.matmul(camera_transform, get_rotation_matrix(rotation_y, axis='y')) 22 | 23 | if project: 24 | camera_transform = np.matmul(PROJECTION_MATRIX, camera_transform) 25 | return camera_transform -------------------------------------------------------------------------------- /model/classifier.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | 3 | class Classifier(SavableModule): 4 | def __init__(self, label_count): 5 | super(Classifier, self).__init__(filename="classifier.to") 6 | 7 | self.layers = nn.Sequential( 8 | nn.Conv3d(in_channels = 1, out_channels = 12, kernel_size = 5), 9 | nn.ReLU(inplace=True), 10 | nn.MaxPool3d(2), 11 | 12 | nn.Conv3d(in_channels = 12, out_channels = 16, kernel_size = 5), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool3d(2), 15 | 16 | nn.Conv3d(in_channels = 16, out_channels = 32, kernel_size = 5), 17 | nn.ReLU(inplace=True), 18 | 19 | Lambda(lambda x: x.view(x.shape[0], -1)), 20 | 21 | nn.Linear(in_features = 32, out_features = label_count), 22 | nn.Softmax(dim=1) 23 | ) 24 | 25 | self.cuda() 26 | 27 | def forward(self, x): 28 | if len(x.shape) == 3: 29 | x = x.unsqueeze(dim = 0) # add dimension for batch 30 | if len(x.shape) == 4: 31 | x = x.unsqueeze(dim = 1) # add dimension for channels 32 | 33 | return self.layers(x) -------------------------------------------------------------------------------- /demo_gan.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | import torch 3 | import time 4 | import numpy as np 5 | import sys 6 | 7 | from rendering import MeshRenderer 8 | from model.gan import Generator, LATENT_CODE_SIZE 9 | from util import device, standard_normal_distribution 10 | 11 | generator = Generator() 12 | if "wgan" in sys.argv: 13 | generator.filename = "wgan-generator.to" 14 | generator.load() 15 | generator.eval() 16 | 17 | viewer = MeshRenderer() 18 | 19 | STEPS = 20 20 | 21 | TRANSITION_TIME = 0.4 22 | WAIT_TIME = 0.8 23 | 24 | def get_random(): 25 | return standard_normal_distribution.sample(sample_shape=(LATENT_CODE_SIZE,)).to(device) 26 | 27 | previous_model = None 28 | next_model = get_random() 29 | 30 | for epoch in count(): 31 | try: 32 
| previous_model = next_model 33 | next_model = get_random() 34 | 35 | for step in range(STEPS + 1): 36 | progress = step / STEPS 37 | model = None 38 | if step < STEPS: 39 | model = previous_model * (1 - progress) + next_model * progress 40 | else: 41 | model = next_model 42 | 43 | viewer.set_voxels(generator(model).squeeze().detach().cpu()) 44 | time.sleep(TRANSITION_TIME / STEPS) 45 | 46 | time.sleep(WAIT_TIME) 47 | 48 | except KeyboardInterrupt: 49 | viewer.stop() 50 | break -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Sequential, Linear, ReLU, BatchNorm1d 4 | 5 | import os 6 | 7 | MODEL_PATH = "models" 8 | CHECKPOINT_PATH = os.path.join(MODEL_PATH, 'checkpoints') 9 | LATENT_CODES_FILENAME = os.path.join(MODEL_PATH, "sdf_net_latent_codes.to") 10 | LATENT_CODE_SIZE = 128 11 | 12 | class Lambda(nn.Module): 13 | def __init__(self, function): 14 | super(Lambda, self).__init__() 15 | self.function = function 16 | 17 | def forward(self, x): 18 | return self.function(x) 19 | 20 | class SavableModule(nn.Module): 21 | def __init__(self, filename): 22 | super(SavableModule, self).__init__() 23 | self.filename = filename 24 | 25 | def get_filename(self, epoch=None, filename=None): 26 | if filename is None: 27 | filename = self.filename 28 | if epoch is None: 29 | return os.path.join(MODEL_PATH, filename) 30 | else: 31 | filename = filename.split('.') 32 | filename[-2] += '-epoch-{:05d}'.format(epoch) 33 | filename = '.'.join(filename) 34 | return os.path.join(CHECKPOINT_PATH, filename) 35 | 36 | 37 | def load(self, epoch=None): 38 | self.load_state_dict(torch.load(self.get_filename(epoch=epoch)), strict=False) 39 | 40 | def save(self, epoch=None): 41 | if epoch is not None and not os.path.exists(CHECKPOINT_PATH): 42 | os.mkdir(CHECKPOINT_PATH) 43 | torch.save(self.state_dict(), self.get_filename(epoch=epoch)) 44 | 45 | @property 46 | def device(self): 47 | return next(self.parameters()).device -------------------------------------------------------------------------------- /demo_autoencoder.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | 7 | import time 8 | import random 9 | import numpy as np 10 | import sys 11 | 12 | from rendering import MeshRenderer 13 | from model.autoencoder import Autoencoder, LATENT_CODE_SIZE 14 | from util import device 15 | from datasets import VoxelDataset 16 | 17 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 18 | 19 | autoencoder = Autoencoder(is_variational='classic' not in sys.argv) 20 | autoencoder.load() 21 | autoencoder.eval() 22 | 23 | viewer = MeshRenderer() 24 | 25 | STEPS = 40 26 | 27 | SHAPE = (LATENT_CODE_SIZE, ) 28 | 29 | TRANSITION_TIME = 1.2 30 | WAIT_TIME = 1.2 31 | 32 | SAMPLE_FROM_LATENT_DISTRIBUTION = 'sample' in sys.argv 33 | 34 | def get_latent_distribution(): 35 | print("Calculating latent distribution...") 36 | indices = random.sample(list(range(len(dataset))), min(1000, len(dataset))) 37 | voxels = torch.stack([dataset[i] for i in indices]).to(device) 38 | with torch.no_grad(): 39 | codes = autoencoder.encode(voxels) 40 | latent_codes_flattened = codes.detach().cpu().numpy().reshape(-1) 41 | mean = np.mean(latent_codes_flattened) 42 | variance = np.var(latent_codes_flattened) ** 
0.5 43 | print('Latent distribution: µ = {:.3f}, σ = {:.3f}'.format(mean, variance)) 44 | return torch.distributions.normal.Normal(mean, variance) 45 | 46 | if SAMPLE_FROM_LATENT_DISTRIBUTION: 47 | latent_distribution = get_latent_distribution() 48 | 49 | def get_random(): 50 | if SAMPLE_FROM_LATENT_DISTRIBUTION: 51 | return latent_distribution.sample(sample_shape=SHAPE).to(device) 52 | else: 53 | index = random.randint(0, len(dataset) -1) 54 | return autoencoder.encode(dataset[index].to(device)) 55 | 56 | 57 | previous_model = None 58 | next_model = get_random() 59 | 60 | for epoch in count(): 61 | try: 62 | previous_model = next_model 63 | next_model = get_random() 64 | 65 | start = time.perf_counter() 66 | end = start + TRANSITION_TIME 67 | while time.perf_counter() < end: 68 | progress = min((time.perf_counter() - start) / TRANSITION_TIME, 1.0) 69 | model = previous_model * (1 - progress) + next_model * progress 70 | voxels = autoencoder.decode(model).detach().cpu() 71 | viewer.set_voxels(voxels) 72 | 73 | time.sleep(WAIT_TIME) 74 | 75 | except KeyboardInterrupt: 76 | viewer.stop() 77 | break -------------------------------------------------------------------------------- /model/progressive_gan.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from util import standard_normal_distribution 3 | 4 | RESOLUTIONS = [8, 16, 32, 64] 5 | FEATURE_COUNTS = [128, 64, 32, 1] 6 | FINAL_LAYER_FEATURES = 256 7 | 8 | # works like fromRGB in the Progressive GAN paper 9 | def from_SDF(x, iteration): 10 | resolution = RESOLUTIONS[iteration] 11 | target_feature_count = FEATURE_COUNTS[iteration] 12 | 13 | x = x.reshape((-1, 1, resolution, resolution, resolution)) 14 | batch_size = x.shape[0] 15 | x = torch.cat((x, torch.zeros((batch_size, target_feature_count - 1, resolution, resolution, resolution), device=x.device)), dim=1) 16 | return x 17 | 18 | class Discriminator(SavableModule): 19 | def __init__(self): 20 | self.iteration = 0 21 | self.filename_base="hybrid_progressive_gan_discriminator_{:d}.to" 22 | super(Discriminator, self).__init__(filename=self.filename_base.format(self.iteration)) 23 | 24 | self.fade_in_progress = 1 25 | 26 | self.head = nn.Sequential( 27 | Lambda(lambda x: x.reshape(-1, 64*FINAL_LAYER_FEATURES)), 28 | nn.Linear(64*FINAL_LAYER_FEATURES, 128), 29 | nn.LeakyReLU(negative_slope=0.2), 30 | nn.Linear(128, 1) 31 | ) 32 | 33 | self.optional_layers = nn.ModuleList() 34 | for i in range(len(FEATURE_COUNTS)): 35 | in_channels = FEATURE_COUNTS[i] 36 | out_channels = FEATURE_COUNTS[i-1] if i > 0 else FINAL_LAYER_FEATURES 37 | submodule = nn.Sequential( 38 | nn.Conv3d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1), 39 | nn.LeakyReLU(negative_slope=0.2) 40 | ) 41 | self.optional_layers.append(submodule) 42 | self.add_module('optional_layer_{:d}'.format(i), submodule) 43 | 44 | def forward(self, x): 45 | x_in = x 46 | x = from_SDF(x, self.iteration) 47 | x = self.optional_layers[self.iteration](x) 48 | if (self.fade_in_progress < 1.0) and self.iteration > 0: 49 | x2 = from_SDF(x_in[:, ::2, ::2, ::2], self.iteration - 1) 50 | x = self.fade_in_progress * x + (1.0 - self.fade_in_progress) * x2 51 | 52 | i = self.iteration - 1 53 | while i >= 0: 54 | x = self.optional_layers[i](x) 55 | i -= 1 56 | 57 | return self.head(x).squeeze() 58 | 59 | def set_iteration(self, value): 60 | self.iteration = value 61 | self.filename = self.filename_base.format(self.iteration) 62 | 
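# Usage sketch (an assumption for illustration; the training driver,
# train_hybrid_progressive_gan.py in the tree above, is not included in this
# dump): resolution grows by advancing the iteration and blending the newly
# added layer in via fade_in_progress, as in the Progressive GAN paper:
#
#     discriminator = Discriminator()
#     discriminator.set_iteration(1)      # grow from 8^3 to 16^3 voxels
#     for step in range(fade_in_steps):   # fade_in_steps is hypothetical
#         discriminator.fade_in_progress = step / fade_in_steps
#         ...                             # regular GAN updates at this resolution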
--------------------------------------------------------------------------------
/demo_training.py:
--------------------------------------------------------------------------------
1 | from mesh_to_sdf import sample_sdf_near_surface
2 | import numpy as np
3 | import trimesh
4 | import torch
5 | from util import device, ensure_directory
6 | 
7 | from model.sdf_net import SDFNet
8 | from rendering import MeshRenderer
9 | import sys
10 | import cv2
11 | LATENT_CODE_SIZE = 0
12 | 
13 | MODEL_PATH = 'examples/chair.obj'
14 | 
15 | mesh = trimesh.load(MODEL_PATH)
16 | points, sdf = sample_sdf_near_surface(mesh)
17 | 
18 | save_images = 'save' in sys.argv
19 | 
20 | if save_images:
21 |     viewer = MeshRenderer(start_thread=False, size=1080)
22 |     ensure_directory('images')
23 | else:
24 |     viewer = MeshRenderer()
25 | 
26 | points = torch.tensor(points, dtype=torch.float32, device=device)
27 | sdf = torch.tensor(sdf, dtype=torch.float32, device=device)
28 | sdf.clamp_(-0.1, 0.1)
29 | 
30 | sdf_net = SDFNet(latent_code_size=LATENT_CODE_SIZE).to(device)
31 | optimizer = torch.optim.Adam(sdf_net.parameters(), lr=1e-5)
32 | 
33 | BATCH_SIZE = 20000
34 | latent_code = torch.zeros((BATCH_SIZE, LATENT_CODE_SIZE), device=device)
35 | indices = torch.zeros(BATCH_SIZE, dtype=torch.int64, device=device)
36 | 
37 | positive_indices = (sdf > 0).nonzero().squeeze().cpu().numpy()
38 | negative_indices = (sdf < 0).nonzero().squeeze().cpu().numpy()
39 | 
40 | step = 0
41 | error_targets = np.logspace(np.log10(0.02), np.log10(0.0005), num=500)
42 | image_index = 0
43 | 
44 | while True:
45 |     try:
46 |         indices[:BATCH_SIZE//2] = torch.tensor(np.random.choice(positive_indices, BATCH_SIZE//2), device=device)
47 |         indices[BATCH_SIZE//2:] = torch.tensor(np.random.choice(negative_indices, BATCH_SIZE//2), device=device)
48 | 
49 |         sdf_net.zero_grad()
50 |         predicted_sdf = sdf_net(points[indices, :], latent_code)
51 |         batch_sdf = sdf[indices]
52 |         loss = torch.mean(torch.abs(predicted_sdf - batch_sdf))
53 |         loss.backward()
54 |         optimizer.step()
55 | 
56 |         if loss.item() < error_targets[min(image_index, len(error_targets) - 1)]:  # clamp the index so the 500-entry target list is not overrun
57 |             try:
58 |                 viewer.set_mesh(sdf_net.get_mesh(latent_code[0, :], voxel_resolution=64, raise_on_empty=True))
59 |                 if save_images:
60 |                     image = viewer.get_image(flip_red_blue=True)
61 |                     cv2.imwrite("images/frame-{:05d}.png".format(image_index), image)
62 |                 image_index += 1
63 |             except ValueError:
64 |                 pass
65 |         step += 1
66 |         print('Step {:04d}, Image {:04d}, loss: {:.6f}, target: {:.6f}'.format(step, image_index, loss.item(), error_targets[min(image_index, len(error_targets) - 1)]))
67 |     except KeyboardInterrupt:
68 |         viewer.stop()
69 |         break
70 | 
--------------------------------------------------------------------------------
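A run hint for the demo above (a sketch based on the `sys.argv` check and the
`cv2.imwrite` frame naming; the ffmpeg line is the one printed by
demo_sdf_net.py further down):

    python demo_training.py save
    ffmpeg -framerate 30 -i images/frame-%05d.png -c:v libx264 -profile:v high -crf 19 -pix_fmt yuv420p video.mp4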
/rendering/shader.py:
--------------------------------------------------------------------------------
1 | from pygame.locals import *
2 | from OpenGL.GL import shaders
3 | import sys
4 | 
5 | from OpenGL.GL import *
6 | from OpenGL.GLU import *
7 | 
8 | class Shader(object):
9 |     def initShader(self, vertex_shader_source, fragment_shader_source):
10 |         self.program = glCreateProgram()
11 | 
12 |         self.vs = glCreateShader(GL_VERTEX_SHADER)
13 |         glShaderSource(self.vs, [vertex_shader_source])
14 |         glCompileShader(self.vs); glAttachShader(self.program, self.vs)  # the vertex shader must be compiled before the program is linked
15 | 
16 |         self.fs = glCreateShader(GL_FRAGMENT_SHADER)
17 |         glShaderSource(self.fs, [fragment_shader_source])
18 |         glCompileShader(self.fs)
19 |         glAttachShader(self.program, self.fs)
20 | 
21 |         glLinkProgram(self.program)
22 | 
23 |         self.vp_location = None
24 |         self.light_vp_location = None
25 |         self.shadow_texture_location = None
26 |         self.is_floor_location = None
27 |         self.y_offset_location = None
28 |         self.color_location = None
29 | 
30 |         try:
31 |             glUseProgram(self.program)
32 |         except GLError:
33 |             err = glGetProgramInfoLog(self.program)
34 |             print(err.decode("utf-8"))
35 |             sys.exit()
36 | 
37 |     def set_light_vp_matrix(self, light_vp_matrix):
38 |         if self.light_vp_location is None:
39 |             self.light_vp_location = glGetUniformLocation(self.program, 'lightVP')
40 | 
41 |         glUniformMatrix4fv(self.light_vp_location, 1, GL_TRUE, light_vp_matrix)
42 | 
43 |     def set_vp_matrix(self, vp_matrix):
44 |         if self.vp_location is None:
45 |             self.vp_location = glGetUniformLocation(self.program, 'VP')
46 | 
47 |         glUniformMatrix4fv(self.vp_location, 1, GL_TRUE, vp_matrix)
48 | 
49 |     def set_shadow_texture(self, texture):
50 |         if self.shadow_texture_location is None:
51 |             self.shadow_texture_location = glGetUniformLocation(self.program, 'shadow_map')
52 |         glUniform1iv(self.shadow_texture_location, 1, texture)  # signature is glUniform1iv(location, count, value): no transpose flag
53 | 
54 |     def set_floor(self, is_floor):
55 |         if self.is_floor_location is None:
56 |             self.is_floor_location = glGetUniformLocation(self.program, 'isFloor')
57 |         glUniform1fv(self.is_floor_location, 1, 1.0 if is_floor else 0.0)
58 | 
59 |     def set_color(self, color):
60 |         if self.color_location is None:
61 |             self.color_location = glGetUniformLocation(self.program, 'albedo')
62 |         glUniform3fv(self.color_location, 1, color)
63 | 
64 |     def set_y_offset(self, value):
65 |         if self.y_offset_location is None:
66 |             self.y_offset_location = glGetUniformLocation(self.program, 'yOffset')
67 |         glUniform1fv(self.y_offset_location, 1, value)
68 | 
69 | 
70 |     def use(self):
71 |         glUseProgram(self.program)
--------------------------------------------------------------------------------
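A minimal usage sketch for the Shader class above (assumes an active OpenGL
context created elsewhere, e.g. by pygame; `vp_matrix` stands for a 4x4 numpy
array and is passed row-major, because the setters upload with transpose=GL_TRUE):

    shader = Shader()
    shader.initShader(open('rendering/vertex.glsl').read(),
                      open('rendering/fragment.glsl').read())
    shader.use()
    shader.set_vp_matrix(vp_matrix)
    shader.set_color((0.8, 0.1, 0.1))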
/rendering/fragment.glsl:
--------------------------------------------------------------------------------
1 | in vec3 normal;
2 | in vec3 position;
3 | 
4 | in vec4 shadowPosition;
5 | in vec3 lightPosition;
6 | 
7 | uniform sampler2D shadow_map;
8 | 
9 | const float ambient = 0.5;
10 | const float diffuse = 0.5;
11 | const float specular = 0.3;
12 | 
13 | uniform float isFloor;
14 | uniform vec3 albedo;
15 | 
16 | float isInShadow(vec2 uv, float reference_depth) {
17 |     return reference_depth > texture(shadow_map, uv.xy).r ? 1.0 : 0.0;
18 | }
19 | 
20 | float texture2DShadowLerp(vec2 uv, float reference_depth, vec2 shadowTextureSize) {
21 |     vec2 texelSize = vec2(1.0) / shadowTextureSize;
22 |     vec2 f = fract(uv * shadowTextureSize + 0.5);
23 |     vec2 centroidUV = floor(uv*shadowTextureSize + 0.5)/shadowTextureSize;
24 | 
25 |     float lb = isInShadow(centroidUV+texelSize * vec2(0.0, 0.0), reference_depth);
26 |     float lt = isInShadow(centroidUV+texelSize * vec2(0.0, 1.0), reference_depth);
27 |     float rb = isInShadow(centroidUV+texelSize * vec2(1.0, 0.0), reference_depth);
28 |     float rt = isInShadow(centroidUV+texelSize * vec2(1.0, 1.0), reference_depth);
29 |     float a = mix(lb, lt, f.y);
30 |     float b = mix(rb, rt, f.y);
31 |     return mix(a, b, f.x);
32 | }
33 | 
34 | float getShadow(vec4 shadowPosition, float lightDotNormal){
35 |     vec3 shadow_coords = shadowPosition.xyz / shadowPosition.w;
36 |     shadow_coords = shadow_coords * 0.5 + 0.5;
37 | 
38 |     if (shadow_coords.z > 1.0) {
39 |         return 0.0;
40 |     }
41 | 
42 |     float bias = max(0.002 * (1.0 - lightDotNormal), 0.001) / shadowPosition.w;
43 |     float reference_depth = (shadow_coords.z - bias);
44 |     vec2 shadowTextureSize = textureSize(shadow_map, 0);
45 | 
46 |     float result = 0.0;
47 |     for(int x = -1; x <= 1; x++){
48 |         for(int y= -1; y <= 1; y++){
49 |             vec2 offset = vec2(x, y) / shadowTextureSize;
50 |             result += texture2DShadowLerp(shadow_coords.xy + offset, reference_depth, shadowTextureSize);
51 |         }
52 |     }
53 |     return clamp(result / 9.0, 0.0, 1.0);
54 | }
55 | 
56 | void main() {
57 |     vec3 normal = normalize(normal); // local copy: shader inputs are read-only; the initializer still refers to the input
58 |     vec3 viewDirection = normalize(-position);
59 |     vec3 lightDirection = normalize(lightPosition - position);
60 |     vec3 reflectDirection = -normalize(reflect(lightDirection, normal));
61 |     float lightDotNormal = clamp(dot(normal, lightDirection), 0.0, 1.0);
62 | 
63 |     float shadow = getShadow(shadowPosition, lightDotNormal);
64 |     float rimLight = pow(1.0 - clamp(-normal.z, 0.0, 1.0), 4) * 0.3;
65 | 
66 |     vec3 color = albedo * ambient
67 |         + albedo * diffuse * lightDotNormal * (1.0 - shadow)
68 |         + vec3(1.0) * specular * pow(max(0.0, dot(reflectDirection, viewDirection)), 20) * (1.0 - shadow)
69 |         + vec3(1.0) * rimLight;
70 | 
71 |     if (isFloor == 1.0) {
72 |         color = mix(vec3(1.0), vec3(0.8) * ambient, shadow);
73 |     }
74 | 
75 |     gl_FragColor = vec4(color.rgb, 1.0);
76 | }
--------------------------------------------------------------------------------
/rendering/binary_voxels_to_mesh.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | # Creates a cube for every occupied (negative) voxel
4 | def create_binary_voxel_mesh(voxels_array, threshold = 0.0):
5 |     voxels = np.pad(voxels_array, 1, mode = 'constant')
6 |     mask = voxels < threshold
7 | 
8 |     # X: faces where an occupied voxel borders an empty one along the x axis
9 |     x, y, z = np.where(mask[:-1,:,:] & ~mask[1:,:,:])
10 |     vertices = [
11 |         x + 1, y, z,
12 |         x + 1, y + 1, z,
13 |         x + 1, y, z + 1,
14 | 
15 |         x + 1, y + 1, z,
16 |         x + 1, y + 1, z + 1,
17 |         x + 1, y, z + 1
18 |     ]
19 |     vertex_arrays = [np.array(vertices).transpose().flatten()]
20 |     normals = [np.tile(np.array([1, 0, 0]), 6 * x.shape[0])]
21 | 
22 | 
23 |     x, y, z = np.where(~mask[:-1,:,:] & mask[1:,:,:])
24 |     vertices = [
25 |         x + 1, y + 1, z,
26 |         x + 1, y, z,
27 |         x + 1, y, z + 1,
28 | 
29 |         x + 1, y, z + 1,
30 |         x + 1, y + 1, z + 1,
31 |         x + 1, y + 1, z
32 |     ]
33 | 
34 |     vertex_arrays.append(np.array(vertices).transpose().flatten())
35 |     normals.append(np.tile(np.array([-1, 0, 0]), 6 * x.shape[0]))
36 | 
37 |     # Y
38 |     x, y, z = np.where(mask[:,:-1,:] & ~mask[:,1:,:])
39 | 
vertices = [ 40 | x + 1, y + 1, z, 41 | x, y + 1, z, 42 | x, y + 1, z + 1, 43 | 44 | x + 1, y + 1, z + 1, 45 | x + 1, y + 1, z, 46 | x, y + 1, z + 1 47 | ] 48 | vertex_arrays.append(np.array(vertices).transpose().flatten()) 49 | normals.append(np.tile(np.array([0, 1, 0]), 6 * x.shape[0])) 50 | 51 | x, y, z = np.where(~mask[:,:-1,:] & mask[:,1:,:]) 52 | vertices = [ 53 | x, y + 1, z, 54 | x + 1, y + 1, z, 55 | x, y + 1, z + 1, 56 | 57 | x + 1, y + 1, z, 58 | x + 1, y + 1, z + 1, 59 | x, y + 1, z + 1 60 | ] 61 | vertex_arrays.append(np.array(vertices).transpose().flatten()) 62 | normals.append(np.tile(np.array([0, -1, 0]), 6 * x.shape[0])) 63 | 64 | # Z 65 | x, y, z = np.where(mask[:,:,:-1] & ~mask[:,:,1:]) 66 | vertices = [ 67 | x, y, z + 1, 68 | x + 1, y, z + 1, 69 | x, y + 1, z + 1, 70 | 71 | x + 1, y, z + 1, 72 | x + 1, y + 1, z + 1, 73 | x, y + 1, z + 1 74 | ] 75 | vertex_arrays.append(np.array(vertices).transpose().flatten()) 76 | normals.append(np.tile(np.array([0, 0, 1]), 6 * x.shape[0])) 77 | 78 | x, y, z = np.where(~mask[:,:,:-1] & mask[:,:,1:]) 79 | vertices = [ 80 | x + 1, y, z + 1, 81 | x, y, z + 1, 82 | x, y + 1, z + 1, 83 | 84 | x + 1, y + 1, z + 1, 85 | x + 1, y, z + 1, 86 | x, y + 1, z + 1 87 | ] 88 | vertex_arrays.append(np.array(vertices).transpose().flatten()) 89 | normals.append(np.tile(np.array([0, 0, -1]), 6 * x.shape[0])) 90 | 91 | return np.concatenate(vertex_arrays).astype(np.float32), np.concatenate(normals).astype(np.float32) -------------------------------------------------------------------------------- /model/gan.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from util import standard_normal_distribution 3 | 4 | class Generator(SavableModule): 5 | def __init__(self): 6 | super(Generator, self).__init__(filename="generator.to") 7 | 8 | self.layers = nn.Sequential( 9 | nn.ConvTranspose3d(in_channels = LATENT_CODE_SIZE, out_channels = 256, kernel_size = 4, stride = 1), 10 | nn.BatchNorm3d(256), 11 | nn.LeakyReLU(negative_slope = 0.2), 12 | 13 | nn.ConvTranspose3d(in_channels = 256, out_channels = 128, kernel_size = 4, stride = 2, padding = 1), 14 | nn.BatchNorm3d(128), 15 | nn.LeakyReLU(negative_slope = 0.2), 16 | 17 | nn.ConvTranspose3d(in_channels = 128, out_channels = 64, kernel_size = 4, stride = 2, padding = 1), 18 | nn.BatchNorm3d(64), 19 | nn.LeakyReLU(negative_slope = 0.2), 20 | 21 | nn.ConvTranspose3d(in_channels = 64, out_channels = 1, kernel_size = 4, stride = 2, padding = 1), 22 | nn.Tanh() 23 | ) 24 | 25 | self.cuda() 26 | 27 | def forward(self, x): 28 | x = x.reshape((-1, LATENT_CODE_SIZE, 1, 1, 1)) 29 | return self.layers(x) 30 | 31 | def generate(self, sample_size = 1): 32 | shape = torch.Size((sample_size, LATENT_CODE_SIZE)) 33 | x = standard_normal_distribution.sample(shape).to(self.device) 34 | return self(x) 35 | 36 | def copy_autoencoder_weights(self, autoencoder): 37 | def copy(source, destination): 38 | destination.load_state_dict(source.state_dict(), strict=False) 39 | 40 | raise Exception("Not implemented.") 41 | 42 | 43 | class Discriminator(SavableModule): 44 | def __init__(self): 45 | super(Discriminator, self).__init__(filename="discriminator.to") 46 | 47 | self.use_sigmoid = True 48 | self.layers = nn.Sequential( 49 | nn.Conv3d(in_channels = 1, out_channels = 64, kernel_size = 4, stride = 2, padding = 1), 50 | nn.LeakyReLU(negative_slope = 0.2), 51 | nn.Conv3d(in_channels = 64, out_channels = 128, kernel_size = 4, stride = 2, padding = 1), 52 | 
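            # feature-map sizes for 32^3 inputs: 32^3 -> 16^3 -> 8^3 -> 4^3 -> 1x1x1 score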
nn.LeakyReLU(negative_slope = 0.2), 53 | nn.Conv3d(in_channels = 128, out_channels = 256, kernel_size = 4, stride = 2, padding = 1), 54 | nn.LeakyReLU(negative_slope = 0.2), 55 | nn.Conv3d(in_channels = 256, out_channels = 1, kernel_size = 4, stride = 1), 56 | Lambda(lambda x: torch.sigmoid(x) if self.use_sigmoid else x) 57 | ) 58 | 59 | self.cuda() 60 | 61 | def forward(self, x): 62 | if (len(x.shape) < 5): 63 | x = x.unsqueeze(dim = 1) # add dimension for channels 64 | 65 | return self.layers(x).squeeze() 66 | 67 | def clip_weights(self, value): 68 | for parameter in self.parameters(): 69 | parameter.data.clamp_(-value, value) 70 | -------------------------------------------------------------------------------- /demo_data_preparation.py: -------------------------------------------------------------------------------- 1 | from mesh_to_sdf import scale_to_unit_sphere, get_surface_point_cloud 2 | from mesh_to_sdf.pyrender_wrapper import render_normal_and_depth_buffers 3 | from mesh_to_sdf.scan import get_camera_transform 4 | import pyrender 5 | import trimesh 6 | import skimage.measure 7 | import numpy as np 8 | import math 9 | from matplotlib import pyplot as plt 10 | from util import show_sdf_point_cloud 11 | 12 | MODEL_PATH = 'examples/chair.obj' 13 | 14 | def show_image(image, grayscale=False): 15 | from matplotlib import pyplot as plt 16 | plt.axis('off') 17 | if grayscale: 18 | plt.gray() 19 | plt.tight_layout() 20 | plt.imshow(image, interpolation='nearest') 21 | plt.show() 22 | 23 | mesh = trimesh.load(MODEL_PATH) 24 | mesh = scale_to_unit_sphere(mesh) 25 | 26 | scene = pyrender.Scene() 27 | scene.add(pyrender.Mesh.from_trimesh(mesh, smooth=False)) 28 | print("Now showing the input model as a triangle mesh.\nClose the window to continue.") 29 | pyrender.Viewer(scene, use_raymond_lighting=True) 30 | 31 | camera_transform = get_camera_transform(math.radians(-140), math.radians(-20)) 32 | camera = pyrender.PerspectiveCamera(yfov=2 * math.asin(1.0 / 2) * 0.97, aspectRatio=1.0, znear = 2 - 1.0, zfar = 2 + 1.0) 33 | normal_buffer, depth_buffer = render_normal_and_depth_buffers(mesh, camera, camera_transform, 1080) 34 | print("Now showing the normal map of a render of the mesh.\nClose the window to continue.") 35 | show_image(normal_buffer) 36 | 37 | print("Now showing the depth map of a render of the mesh.\nClose the window to continue.") 38 | show_image(depth_buffer, grayscale=True) 39 | 40 | surface_point_cloud = get_surface_point_cloud(mesh) 41 | 42 | print("Now showing the surface point cloud with normals.\nClose the window to continue.") 43 | surface_point_cloud.show() 44 | 45 | print('Calculating...') 46 | resolution = 800 47 | slice_position = 0.35 48 | points = np.meshgrid( 49 | np.linspace(slice_position, slice_position, 1), 50 | np.linspace(1, -1, resolution), 51 | np.linspace(-1, 1, resolution) 52 | ) 53 | 54 | points = np.stack(points).reshape(3, -1).transpose() 55 | sdf = surface_point_cloud.get_sdf_in_batches(points).reshape(1, resolution, resolution)[0, :, :] 56 | clip = 0.2 57 | sdf = np.clip(sdf, -clip, clip) / clip 58 | 59 | image = np.ones((resolution, resolution, 3)) 60 | image[:,:,:2][sdf > 0] = (1.0 - sdf[sdf > 0])[:, np.newaxis] 61 | image[:,:,1:][sdf < 0] = (1.0 + sdf[sdf < 0])[:, np.newaxis] 62 | image[np.abs(sdf) < 0.02, :] = 0 63 | print("Now showing a slice through the SDF of the model.\nClose the window to continue.") 64 | show_image(image) 65 | 66 | voxels = surface_point_cloud.get_voxels(voxel_resolution=64) 67 | vertices, faces, normals, _ = 
skimage.measure.marching_cubes_lewiner(voxels, level=0)  # deprecated in scikit-image 0.17; newer versions use marching_cubes(..., method='lewiner')
68 | mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals)
69 | print("Now showing a voxel volume reconstructed with Marching Cubes.\nClose the window to continue.")
70 | mesh.show()
71 | 
72 | points, sdf = surface_point_cloud.sample_sdf_near_surface(number_of_points=150000)
73 | 
74 | print("Now showing a point cloud of non-uniformly sampled SDF data. Negative distances are red, positive distances are blue.")
75 | show_sdf_point_cloud(points, sdf)
76 | 
--------------------------------------------------------------------------------
/shapenet_metadata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | 
4 | DATASET_DIRECTORY = "data/shapenet/"
5 | MIN_SAMPLES_PER_CATEGORY = 2000
6 | 
7 | from util import device
8 | 
9 | class ShapenetCategory():
10 |     def __init__(self, name, id, count):
11 |         self.name = name
12 |         self.id = id
13 |         self.is_root = True
14 |         self.children = []
15 |         self.count = count
16 |         self.label = None
17 | 
18 |     def print(self, depth = 0):
19 |         print(' ' * depth + self.name + '({:d})'.format(self.count))
20 |         for child in self.children:
21 |             child.print(depth = depth + 1)
22 | 
23 |     def get_directory(self):
24 |         return os.path.join(DATASET_DIRECTORY, str(self.id).rjust(8, '0'))
25 | 
26 | class ShapenetMetadata():
27 |     def __init__(self):
28 |         self.clip_sdf = True
29 |         self.rescale_sdf = True
30 | 
31 |         self.load_categories()
32 |         self.labels = None
33 | 
34 |     def load_categories(self):
35 |         taxonomy_filename = os.path.join(DATASET_DIRECTORY, "taxonomy.json")
36 |         if not os.path.isfile(taxonomy_filename):
37 |             taxonomy_filename = 'examples/shapenet_taxonomy.json'
38 |         file_content = open(taxonomy_filename).read()
39 |         taxonomy = json.loads(file_content)
40 |         categories = dict()
41 |         for item in taxonomy:
42 |             id = int(item['synsetId'])
43 |             category = ShapenetCategory(item['name'], id, item['numInstances'])
44 |             categories[id] = category
45 | 
46 |         for item in taxonomy:
47 |             id = int(item['synsetId'])
48 |             category = categories[id]
49 |             for str_id in item["children"]:
50 |                 child_id = int(str_id)
51 |                 category.children.append(categories[child_id])
52 |                 categories[child_id].is_root = False
53 | 
54 |         self.categories = [item for item in categories.values() if item.is_root and item.count >= MIN_SAMPLES_PER_CATEGORY]
55 |         self.categories = sorted(self.categories, key=lambda item: item.id)
56 |         self.categories_by_id = {item.id : item for item in self.categories}
57 |         self.label_count = len(self.categories)
58 |         for i in range(len(self.categories)):
59 |             self.categories[i].label = i
60 | 
61 |     def get_color(self, label):
62 |         if label == 2:
63 |             return (0.9, 0.1, 0.14) # red
64 |         elif label == 1:
65 |             return (0.8, 0.7, 0.1) # yellow
66 |         elif label == 6:
67 |             return (0.05, 0.5, 0.05) # green
68 |         elif label == 5:
69 |             return (0.1, 0.2, 0.9) # blue
70 |         elif label == 4:
71 |             return (0.46, 0.1, 0.9) # purple
72 |         elif label == 3:
73 |             return (0.9, 0.1, 0.673) # pink
74 |         elif label == 0:
75 |             return (0.01, 0.6, 0.9) # cyan
76 |         else:
77 |             return (0.7, 0.7, 0.7)
78 | 
79 | shapenet = ShapenetMetadata()
80 | 
81 | if __name__ == "__main__":
82 |     for category in sorted(shapenet.categories, key=lambda c: -c.count):
83 |         print('{:d}: {:s} - {:d}'.format(
84 |             category.label,
85 |             category.name,
86 |             category.count))
--------------------------------------------------------------------------------
/train_point_gan.py:
-------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.optim import RMSprop 7 | 8 | from datasets import PointDataset 9 | from model.point_sdf_net import PointNet, SDFGenerator 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--category', type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | LATENT_SIZE = 128 16 | GRADIENT_PENALITY = 10 17 | HIDDEN_SIZE = 256 18 | NUM_LAYERS = 8 19 | NORM = True 20 | 21 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 22 | G = SDFGenerator(LATENT_SIZE, HIDDEN_SIZE, NUM_LAYERS, NORM, dropout=0.0) 23 | D = PointNet(out_channels=1) 24 | G, D = G.to(device), D.to(device) 25 | G_optimizer = RMSprop(G.parameters(), lr=0.0001) 26 | D_optimizer = RMSprop(D.parameters(), lr=0.0001) 27 | 28 | root = osp.join(f'data/{args.category}') 29 | dataset = PointDataset.from_split(root, split='train') 30 | 31 | configuration = [ # num_points, batch_size, epochs 32 | (1024, 32, 300), 33 | (2048, 32, 300), 34 | (4096, 32, 300), 35 | (8192, 24, 300), 36 | (16384, 12, 300), 37 | (32768, 6, 900), 38 | ] 39 | 40 | num_steps = 0 41 | for num_points, batch_size, epochs in configuration: 42 | dataset.num_points = num_points 43 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=6) 44 | 45 | for epoch in range(1, epochs + 1): 46 | total_loss = 0 47 | for uniform, _ in loader: 48 | num_steps += 1 49 | 50 | uniform = uniform.to(device) 51 | u_pos, u_dist = uniform[..., :3], uniform[..., 3:] 52 | 53 | D_optimizer.zero_grad() 54 | 55 | z = torch.randn(uniform.size(0), LATENT_SIZE, device=device) 56 | fake = G(u_pos, z) 57 | out_real = D(u_pos, u_dist) 58 | out_fake = D(u_pos, fake) 59 | D_loss = out_fake.mean() - out_real.mean() 60 | 61 | alpha = torch.rand((uniform.size(0), 1, 1), device=device) 62 | interpolated = alpha * u_dist + (1 - alpha) * fake 63 | interpolated.requires_grad_(True) 64 | out = D(u_pos, interpolated) 65 | 66 | grad = torch.autograd.grad(out, interpolated, 67 | grad_outputs=torch.ones_like(out), 68 | create_graph=True, retain_graph=True, 69 | only_inputs=True)[0] 70 | grad_norm = grad.view(grad.size(0), -1).norm(dim=-1, p=2) 71 | gp = GRADIENT_PENALITY * ((grad_norm - 1).pow(2).mean()) 72 | 73 | loss = D_loss + gp 74 | loss.backward() 75 | D_optimizer.step() 76 | 77 | if num_steps % 5 == 0: 78 | G_optimizer.zero_grad() 79 | z = torch.randn(uniform.size(0), LATENT_SIZE, device=device) 80 | fake = G(u_pos, z) 81 | out_fake = D(u_pos, fake) 82 | loss = -out_fake.mean() 83 | loss.backward() 84 | G_optimizer.step() 85 | 86 | total_loss += D_loss.abs().item() 87 | 88 | print('Num points: {}, Epoch: {:03d}, Loss: {:.6f}'.format( 89 | num_points, epoch, total_loss / len(loader))) 90 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 3 | standard_normal_distribution = torch.distributions.normal.Normal(0, 1) 4 | import numpy as np 5 | import os 6 | 7 | def ensure_directory(directory): 8 | if not os.path.exists(directory): 9 | os.makedirs(directory) 10 | 11 | ensure_directory('plots') 12 | ensure_directory('models') 13 | ensure_directory('data') 14 | 15 | CHARACTERS = ' `.-:/+osyhdmm###############' 16 | 17 | def create_text_slice(voxels): 18 | 
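    # Draws one axis-aligned slice of an SDF volume as ASCII art: negative values
    # (inside the surface) map to dense characters, positive values to sparse ones.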
voxel_resolution = voxels.shape[-1] 19 | center = voxels.shape[-1] // 4 20 | data = voxels[center, :, :] 21 | data = torch.clamp(data * -0.5 + 0.5, 0, 1) * (len(CHARACTERS) - 1) 22 | data = data.type(torch.int).cpu() 23 | lines = ['|' + ''.join([CHARACTERS[i] for i in line]) + '|' for line in data] 24 | result = [] 25 | for i in range(voxel_resolution): 26 | if len(result) < i / 2.2: 27 | result.append(lines[i]) 28 | frame = '+' + '—' * voxel_resolution + '+\n' 29 | return frame + '\n'.join(reversed(result)) + '\n' + frame 30 | 31 | 32 | def get_points_in_unit_sphere(n, device): 33 | x = torch.rand(int(n * 2.5), 3, device=device) * 2 - 1 34 | mask = (torch.norm(x, dim=1) < 1).nonzero().squeeze() 35 | mask = mask[:n] 36 | x = x[mask, :] 37 | if x.shape[0] < n: 38 | print("Warning: Did not find enough points.") 39 | return x 40 | 41 | def crop_image(image, background=255): 42 | mask = image[:, :] != background 43 | coords = np.array(np.nonzero(mask)) 44 | 45 | if coords.size != 0: 46 | top_left = np.min(coords, axis=1) 47 | bottom_right = np.max(coords, axis=1) 48 | else: 49 | top_left = np.array((0, 0)) 50 | bottom_right = np.array(image.shape) 51 | print("Warning: Image contains only background pixels.") 52 | 53 | half_size = int(max(bottom_right[0] - top_left[0], bottom_right[1] - top_left[1]) / 2) 54 | center = ((top_left + bottom_right) / 2).astype(int) 55 | center = (min(max(half_size, center[0]), image.shape[0] - half_size), min(max(half_size, center[1]), image.shape[1] - half_size)) 56 | if half_size > 100: 57 | image = image[center[0] - half_size : center[0] + half_size, center[1] - half_size : center[1] + half_size] 58 | return image 59 | 60 | def get_voxel_coordinates(resolution = 32, size=1, center=0, return_torch_tensor=False): 61 | if type(center) == int: 62 | center = (center, center, center) 63 | points = np.meshgrid( 64 | np.linspace(center[0] - size, center[0] + size, resolution), 65 | np.linspace(center[1] - size, center[1] + size, resolution), 66 | np.linspace(center[2] - size, center[2] + size, resolution) 67 | ) 68 | points = np.stack(points) 69 | points = np.swapaxes(points, 1, 2) 70 | points = points.reshape(3, -1).transpose() 71 | if return_torch_tensor: 72 | return torch.tensor(points, dtype=torch.float32, device=device) 73 | else: 74 | return points.astype(np.float32) 75 | 76 | def show_sdf_point_cloud(points, sdf): 77 | import pyrender 78 | colors = np.zeros(points.shape) 79 | colors[sdf < 0, 2] = 1 80 | colors[sdf > 0, 0] = 1 81 | cloud = pyrender.Mesh.from_points(points, colors=colors) 82 | 83 | scene = pyrender.Scene() 84 | scene.add(cloud) 85 | viewer = pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2) -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import os 4 | import numpy as np 5 | 6 | 7 | class VoxelDataset(Dataset): 8 | def __init__(self, files, clamp=0.1, rescale_sdf=True): 9 | self.files = files 10 | self.clamp = clamp 11 | self.rescale_sdf = rescale_sdf 12 | 13 | def __len__(self): 14 | return len(self.files) 15 | 16 | def __getitem__(self, index): 17 | array = np.load(self.files[index]) 18 | result = torch.from_numpy(array) 19 | if self.clamp is not None: 20 | result.clamp_(-self.clamp, self.clamp) 21 | if self.rescale_sdf: 22 | result /= self.clamp 23 | return result 24 | 25 | @staticmethod 26 | def glob(pattern): 27 | import glob 28 | files = 
glob.glob(pattern, recursive=True) 29 | if len(files) == 0: 30 | raise Exception( 31 | 'No files found for glob pattern {:s}.'.format(pattern)) 32 | return VoxelDataset(sorted(files)) 33 | 34 | @staticmethod 35 | def from_split(pattern, split_file_name): 36 | split_file = open(split_file_name, 'r') 37 | ids = split_file.readlines() 38 | files = [pattern.format(id.strip()) for id in ids] 39 | files = [file for file in files if os.path.exists(file)] 40 | return VoxelDataset(files) 41 | 42 | def show(self): 43 | from rendering import MeshRenderer 44 | import time 45 | from tqdm import tqdm 46 | 47 | viewer = MeshRenderer() 48 | for item in tqdm(self): 49 | viewer.set_voxels(item.numpy()) 50 | time.sleep(0.5) 51 | 52 | 53 | class PointDataset(Dataset): 54 | def __init__(self, root, filenames, num_points=1024, transform=None): 55 | self.root = os.path.expanduser(os.path.join(os.path.normpath(root))) 56 | self.filenames = filenames 57 | self.num_points = num_points 58 | assert 0 < self.num_points <= 64**3 59 | self.transform = transform 60 | 61 | def __len__(self): 62 | return len(self.filenames) 63 | 64 | def __getitem__(self, idx): 65 | name = self.filenames[idx] 66 | 67 | uniform = os.path.join(self.root, 'uniform', f'{name}.npy') 68 | uniform = torch.from_numpy(np.load(uniform)) 69 | 70 | surface = os.path.join(self.root, 'surface', f'{name}.npy') 71 | surface = torch.from_numpy(np.load(surface)) 72 | 73 | # Sample a subset of points. 74 | sample = np.random.choice(uniform.size(0), self.num_points) 75 | uniform, surface = uniform[sample], surface[sample] 76 | 77 | data = (uniform, surface) 78 | 79 | if self.transform is not None: 80 | data = self.transform(data) 81 | 82 | return data 83 | 84 | @staticmethod 85 | def from_split(root, split, num_points=1024, transform=None): 86 | with open(os.path.join(root, f'{split}.txt'), 'r') as f: 87 | filenames = f.read().split('\n') 88 | if filenames[-1] == '': 89 | filenames = filenames[:-1] 90 | return PointDataset(root, filenames, num_points, transform) 91 | 92 | 93 | if __name__ == '__main__': 94 | # dataset = VoxelDataset.glob('data/chairs/voxels_64/') 95 | dataset = VoxelDataset.from_split( 96 | 'data/chairs/voxels_{:d}/{{:s}}.npy'.format(64), 97 | 'data/chairs/train.txt') 98 | dataset.show() 99 | 100 | dataset = PointDataset.from_split('data/chairs', 'train') 101 | -------------------------------------------------------------------------------- /model/point_sdf_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear, Sequential, ReLU, LayerNorm 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from torch_scatter import scatter_max 7 | except ImportError: 8 | scatter_max = None 9 | 10 | 11 | class PointNet(torch.nn.Module): 12 | def __init__(self, out_channels): 13 | super(PointNet, self).__init__() 14 | 15 | self.nn1 = Sequential( 16 | Linear(4, 64), 17 | ReLU(), 18 | Linear(64, 128), 19 | ReLU(), 20 | Linear(128, 256), 21 | ReLU(), 22 | Linear(256, 512), 23 | ) 24 | 25 | self.nn2 = Sequential( 26 | Linear(512, 256), 27 | ReLU(), 28 | Linear(256, 128), 29 | ReLU(), 30 | Linear(128, out_channels), 31 | ) 32 | 33 | def forward(self, pos, dist, batch=None): 34 | dist = dist.unsqueeze(-1) if dist.size(-1) != 1 else dist 35 | 36 | x = torch.cat([pos, dist], dim=-1) 37 | 38 | x = self.nn1(x) 39 | 40 | if batch is None: 41 | x = x.max(dim=-2)[0] 42 | else: 43 | x = scatter_max(x, batch, dim=-2)[0] 44 | 45 | x = self.nn2(x) 46 | 47 | return x 48 | 49 | 50 | class 
SDFGenerator(torch.nn.Module): 51 | def __init__(self, latent_channels, hidden_channels, num_layers, norm=True, 52 | dropout=0.0): 53 | super(SDFGenerator, self).__init__() 54 | 55 | self.layers1 = None 56 | self.layers2 = None 57 | 58 | assert num_layers % 2 == 0 59 | 60 | self.latent_channels = latent_channels 61 | self.hidden_channels = hidden_channels 62 | self.num_layers = num_layers 63 | self.norm = norm 64 | self.dropout = dropout 65 | 66 | in_channels = 3 67 | out_channels = hidden_channels 68 | 69 | self.lins = torch.nn.ModuleList() 70 | self.norms = torch.nn.ModuleList() 71 | for i in range(num_layers): 72 | self.lins.append(Linear(in_channels, out_channels)) 73 | self.norms.append(LayerNorm(out_channels)) 74 | 75 | if i == (num_layers // 2) - 1: 76 | in_channels = hidden_channels + 3 77 | else: 78 | in_channels = hidden_channels 79 | 80 | if i == num_layers - 2: 81 | out_channels = 1 82 | 83 | self.z_lin1 = Linear(latent_channels, hidden_channels) 84 | self.z_lin2 = Linear(latent_channels, hidden_channels) 85 | 86 | def forward(self, pos, z): 87 | # pos: [batch_size, num_points, 3] 88 | # z: [batch_size, latent_channels] 89 | 90 | pos = pos.unsqueeze(0) if pos.dim() == 2 else pos 91 | 92 | assert pos.dim() == 3 93 | assert pos.size(-1) == 3 94 | 95 | z = z.unsqueeze(0) if z.dim() == 1 else z 96 | assert z.dim() == 2 97 | assert z.size(-1) == self.latent_channels 98 | 99 | assert pos.size(0) == z.size(0) 100 | 101 | x = pos 102 | for i, (lin, norm) in enumerate(zip(self.lins, self.norms)): 103 | if i == self.num_layers // 2: 104 | x = torch.cat([x, pos], dim=-1) 105 | 106 | x = lin(x) 107 | 108 | if i == 0: 109 | x = self.z_lin1(z).unsqueeze(1) + x 110 | 111 | if i == self.num_layers // 2: 112 | x = self.z_lin2(z).unsqueeze(1) + x 113 | 114 | if i < self.num_layers - 1: 115 | x = norm(x) if self.norm else x 116 | x = F.relu(x) 117 | x = F.dropout(x, p=self.dropout, training=self.training) 118 | 119 | return x 120 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from model.sdf_net import SDFNet, LATENT_CODE_SIZE 2 | import numpy as np 3 | from util import device, standard_normal_distribution 4 | from tqdm import tqdm 5 | import sys 6 | import torch 7 | import skimage.measure 8 | import trimesh 9 | 10 | LEVEL = 0 11 | 12 | def rescale_point_cloud(point_cloud, method=None): 13 | if method == 'half_unit_sphere': 14 | point_cloud /= np.linalg.norm(point_cloud, axis=1).max() * 2 15 | elif method == 'half_unit_cube': 16 | point_cloud /= np.abs(point_cloud).max() * 2 17 | 18 | def sample_point_clouds(sdf_net, sample_count, point_cloud_size, voxel_resolution=128, rescale='half_unit_sphere', latent_codes=None): 19 | result = np.zeros((sample_count, point_cloud_size, 3)) 20 | if latent_codes is None: 21 | latent_codes = standard_normal_distribution.sample((sample_count, LATENT_CODE_SIZE)).to(device) 22 | for i in tqdm(range(sample_count)): 23 | try: 24 | point_cloud = sdf_net.get_uniform_surface_points(latent_codes[i, :], point_count=point_cloud_size, voxel_resolution=voxel_resolution, sphere_only=False, level=LEVEL) 25 | rescale_point_cloud(point_cloud, method=rescale) 26 | result[i, :, :] = point_cloud 27 | except AttributeError: 28 | print("Warning: Empty mesh.") 29 | return result 30 | 31 | def sample_from_voxels(voxels, point_cloud_size, rescale='half_unit_sphere'): 32 | result = np.zeros((voxels.shape[0], point_cloud_size, 3)) 33 | size = 2 34 | 
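    # For each voxel grid: pad the borders with positive (outside) values so the
    # surface closes, run marching cubes, then sample points from the resulting mesh.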
voxel_resolution = voxels.shape[1] 35 | for i in tqdm(range(voxels.shape[0])): 36 | voxels_current = voxels[i, :, :, :] 37 | voxels_current = np.pad(voxels_current, 1, mode='constant', constant_values=1) 38 | 39 | vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(voxels_current, level=0, spacing=(size / voxel_resolution, size / voxel_resolution, size / voxel_resolution)) 40 | vertices -= size / 2 41 | mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) 42 | point_cloud = mesh.sample(point_cloud_size) 43 | rescale_point_cloud(point_cloud, method=rescale) 44 | result[i, :, :] = point_cloud 45 | return result 46 | 47 | 48 | if 'sample' in sys.argv: 49 | sdf_net = SDFNet() 50 | sdf_net.filename = 'hybrid_gan_generator.to' 51 | sdf_net.load() 52 | sdf_net.eval() 53 | 54 | clouds = sample_point_clouds(sdf_net, 1000, 2048, voxel_resolution=32) 55 | np.save('data/generated_point_cloud_sample.npy', clouds) 56 | 57 | if 'checkpoints' in sys.argv: 58 | import glob 59 | from tqdm import tqdm 60 | torch.manual_seed(1234) 61 | files = glob.glob('models/checkpoints/hybrid_progressive_gan_generator_2-epoch-*.to', recursive=True) 62 | latent_codes = standard_normal_distribution.sample((50, LATENT_CODE_SIZE)).to(device) 63 | for filename in tqdm(files): 64 | epoch_id = filename[61:-3] 65 | sdf_net = SDFNet() 66 | sdf_net.filename = filename[7:] 67 | sdf_net.load() 68 | sdf_net.eval() 69 | 70 | clouds = sample_point_clouds(sdf_net, 50, 2048, voxel_resolution=64, latent_codes=latent_codes) 71 | np.save('data/chairs/results/voxels_{:s}.npy'.format(epoch_id), clouds) 72 | 73 | 74 | if 'dataset' in sys.argv: 75 | from datasets import VoxelDataset 76 | dataset = VoxelDataset.from_split('data/airplanes/voxels_64/{:s}.npy', 'data/airplanes/val.txt') 77 | from torch.utils.data import DataLoader 78 | voxels = next(iter(DataLoader(dataset, batch_size=100, shuffle=True))) 79 | print(voxels.shape) 80 | clouds = sample_from_voxels(voxels, 2048) 81 | np.save('data/dataset_airplanes_point_cloud_sample.npy', clouds) 82 | 83 | 84 | if 'test' in sys.argv: 85 | import pyrender 86 | data = np.load('data/dataset_point_cloud_sample.npy') 87 | for i in range(data.shape[0]): 88 | points = data[i, :, :] 89 | scene = pyrender.Scene() 90 | scene.add(pyrender.Mesh.from_points(points)) 91 | pyrender.Viewer(scene, use_raymond_lighting=True, point_size=8) -------------------------------------------------------------------------------- /train_sdf_autodecoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | import numpy as np 7 | from itertools import count 8 | import time 9 | import random 10 | from tqdm import tqdm 11 | import sys 12 | 13 | from model.sdf_net import SDFNet, LATENT_CODE_SIZE, LATENT_CODES_FILENAME 14 | from util import device 15 | 16 | if "nogui" not in sys.argv: 17 | from rendering import MeshRenderer 18 | viewer = MeshRenderer() 19 | 20 | POINTCLOUD_SIZE = 200000 21 | 22 | points = torch.load('data/sdf_points.to').to(device) 23 | sdf = torch.load('data/sdf_values.to').to(device) 24 | 25 | MODEL_COUNT = points.shape[0] // POINTCLOUD_SIZE 26 | BATCH_SIZE = 20000 27 | SDF_CUTOFF = 0.1 28 | sdf.clamp_(-SDF_CUTOFF, SDF_CUTOFF) 29 | signs = sdf.cpu().numpy() > 0 30 | 31 | SIGMA = 0.01 32 | 33 | LOG_FILE_NAME = "plots/sdf_net_training.csv" 34 | 35 | sdf_net = SDFNet() 36 | if "continue" in sys.argv: 37 | sdf_net.load() 38 | 
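    # Auto-decoder training (DeepSDF-style): there is no encoder; one latent code
    # per training shape is stored and optimized jointly with the network weights.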
    latent_codes = torch.load(LATENT_CODES_FILENAME).to(device)
39 | else:
40 |     normal_distribution = torch.distributions.normal.Normal(0, 0.0001)
41 |     latent_codes = normal_distribution.sample((MODEL_COUNT, LATENT_CODE_SIZE)).to(device)
42 | latent_codes.requires_grad = True
43 | 
44 | network_optimizer = optim.Adam(sdf_net.parameters(), lr=1e-5)
45 | latent_code_optimizer = optim.Adam([latent_codes], lr=1e-5)
46 | criterion = nn.MSELoss()
47 | 
48 | first_epoch = 0
49 | if 'continue' in sys.argv:
50 |     log_file_contents = open(LOG_FILE_NAME, 'r').readlines()
51 |     first_epoch = len(log_file_contents)
52 | 
53 | log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w")
54 | 
55 | def create_batches():
56 |     indices_positive = np.nonzero(signs)[0]
57 |     indices_negative = np.nonzero(~signs)[0]
58 |     if indices_negative.shape[0] > indices_positive.shape[0]:
59 |         np.random.shuffle(indices_negative)
60 |         indices_negative = indices_negative[:indices_positive.shape[0]]
61 |     else:
62 |         np.random.shuffle(indices_positive)
63 |         indices_positive = indices_positive[:indices_negative.shape[0]]
64 |     indices = np.concatenate((indices_negative, indices_positive))
65 |     np.random.shuffle(indices)
66 |     batch_count = int(indices.shape[0] / BATCH_SIZE)
67 |     for i in range(batch_count - 1):
68 |         yield indices[i * BATCH_SIZE:(i+1)*BATCH_SIZE]
69 |     yield indices[(batch_count - 1) * BATCH_SIZE:]
70 | 
71 | def train():
72 |     for epoch in count(start=first_epoch):
73 |         epoch_start_time = time.time()
74 |         loss_values = []
75 |         batch_index = 0
76 |         for batch in tqdm(list(create_batches())):
77 |             indices = torch.tensor(batch, device = device)
78 |             model_indices = indices // POINTCLOUD_SIZE  # floor division: maps each point index to its shape index (true division yields floats, which cannot be used as indices)
79 | 
80 |             batch_latent_codes = latent_codes[model_indices, :]
81 |             batch_points = points[indices, :]
82 |             batch_sdf = sdf[indices]
83 | 
84 |             sdf_net.zero_grad()
85 |             if latent_codes.grad is not None:
86 |                 latent_codes.grad.data.zero_()
87 |             output = sdf_net.forward(batch_points, batch_latent_codes)
88 |             loss = torch.mean(torch.abs(output - batch_sdf)) + SIGMA * torch.mean(torch.pow(batch_latent_codes, 2))
89 |             loss.backward()
90 |             network_optimizer.step()
91 |             latent_code_optimizer.step()
92 |             loss_values.append(loss.item())
93 | 
94 |             if batch_index % 400 == 0 and "nogui" not in sys.argv:
95 |                 try:
96 |                     viewer.set_mesh(sdf_net.get_mesh(latent_codes[random.randrange(MODEL_COUNT), :]))
97 |                 except ValueError:
98 |                     pass
99 | 
100 |             batch_index += 1
101 | 
102 |         variance = np.var(latent_codes.detach().reshape(-1).cpu().numpy()) ** 0.5  # standard deviation of the latent codes (logged as sigma below)
103 |         epoch_duration = time.time() - epoch_start_time
104 | 
105 |         print("Epoch {:d}, {:.1f}s. 
Loss: {:.8f}".format(epoch, epoch_duration, np.mean(loss_values))) 106 | 107 | sdf_net.save() 108 | torch.save(latent_codes, LATENT_CODES_FILENAME) 109 | 110 | sdf_net.save(epoch=epoch) 111 | torch.save(latent_codes, sdf_net.get_filename(epoch=epoch, filename='sdf_net_latent_codes.to')) 112 | 113 | log_file.write('{:d} {:.1f} {:.6f} {:.6f}\n'.format(epoch, epoch_duration, np.mean(loss_values), variance)) 114 | log_file.flush() 115 | 116 | train() -------------------------------------------------------------------------------- /train_wgan.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import numpy as np 7 | 8 | import random 9 | import time 10 | import sys 11 | from collections import deque 12 | 13 | from model.gan import Generator, Discriminator 14 | from util import device 15 | 16 | from util import create_text_slice 17 | from datasets import VoxelDataset 18 | from torch.utils.data import DataLoader 19 | 20 | show_viewer = "nogui" not in sys.argv 21 | 22 | if show_viewer: 23 | from rendering import MeshRenderer 24 | viewer = MeshRenderer() 25 | 26 | generator = Generator() 27 | generator.filename = "wgan-generator.to" 28 | 29 | critic = Discriminator() 30 | critic.filename = "wgan-critic.to" 31 | critic.use_sigmoid = False 32 | 33 | if "continue" in sys.argv: 34 | generator.load() 35 | critic.load() 36 | 37 | LEARN_RATE = 0.00005 38 | BATCH_SIZE = 64 39 | CRITIC_UPDATES_PER_GENERATOR_UPDATE = 5 40 | CRITIC_WEIGHT_LIMIT = 0.01 41 | 42 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 43 | data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8) 44 | 45 | generator_optimizer = optim.RMSprop(generator.parameters(), lr=LEARN_RATE) 46 | critic_optimizer = optim.RMSprop(critic.parameters(), lr=LEARN_RATE) 47 | 48 | log_file = open("plots/wgan_training.csv", "a" if "continue" in sys.argv else "w") 49 | 50 | def train(): 51 | history_fake = deque(maxlen=50) 52 | history_real = deque(maxlen=50) 53 | 54 | for epoch in count(): 55 | batch_index = 0 56 | epoch_start_time = time.time() 57 | for batch in data_loader: 58 | try: 59 | # train critic 60 | current_batch_size = batch.shape[0] # equals BATCH_SIZE for all batches except the last one 61 | 62 | generator.zero_grad() 63 | critic.zero_grad() 64 | 65 | fake_sample = generator.generate(sample_size = current_batch_size).detach() 66 | fake_critic_output = critic(fake_sample) 67 | valid_critic_output = critic(batch.to(device)) 68 | critic_loss = torch.mean(fake_critic_output) - torch.mean(valid_critic_output) 69 | critic_loss.backward() 70 | critic_optimizer.step() 71 | critic.clip_weights(CRITIC_WEIGHT_LIMIT) 72 | 73 | # train generator 74 | if batch_index % CRITIC_UPDATES_PER_GENERATOR_UPDATE == 0: 75 | generator.zero_grad() 76 | critic.zero_grad() 77 | 78 | fake_sample = generator.generate(sample_size = BATCH_SIZE) 79 | if show_viewer: 80 | viewer.set_voxels(fake_sample[0, :, :, :].squeeze().detach().cpu().numpy()) 81 | fake_critic_output = critic(fake_sample) 82 | generator_loss = -torch.mean(fake_critic_output) 83 | generator_loss.backward() 84 | generator_optimizer.step() 85 | 86 | history_fake.append(torch.mean(fake_critic_output).item()) 87 | history_real.append(torch.mean(valid_critic_output).item()) 88 | if "verbose" in sys.argv: 89 | print("epoch " + str(epoch) + ", batch " + str(batch_index) \ 90 | + ": fake value: " + 
'{0:.1f}'.format(history_fake[-1]) \ 91 | + ", valid value: " + '{0:.1f}'.format(history_real[-1])) 92 | batch_index += 1 93 | except KeyboardInterrupt: 94 | if show_viewer: 95 | viewer.stop() 96 | return 97 | 98 | generator.save() 99 | critic.save() 100 | 101 | if epoch % 20 == 0: 102 | generator.save(epoch=epoch) 103 | critic.save(epoch=epoch) 104 | 105 | if "show_slice" in sys.argv: 106 | voxels = generator.generate().squeeze() 107 | print(create_text_slice(voxels)) 108 | 109 | epoch_duration = time.time() - epoch_start_time 110 | fake_prediction = np.mean(history_fake) 111 | valid_prediction = np.mean(history_real) 112 | print('Epoch {:d} ({:.1f}s), critic values: {:.2f}, {:.2f}'.format( 113 | epoch, epoch_duration, fake_prediction, valid_prediction)) 114 | log_file.write("{:d} {:.1f} {:.2f} {:.2f}\n".format( 115 | epoch, epoch_duration, fake_prediction, valid_prediction)) 116 | log_file.flush() 117 | 118 | 119 | train() 120 | -------------------------------------------------------------------------------- /demo_sdf_net.py: -------------------------------------------------------------------------------- 1 | from model.sdf_net import SDFNet, LATENT_CODE_SIZE, LATENT_CODES_FILENAME 2 | from util import device, standard_normal_distribution, ensure_directory 3 | import scipy.interpolate 4 | import numpy as np 5 | from rendering import MeshRenderer 6 | import time 7 | import torch 8 | from tqdm import tqdm 9 | import cv2 10 | import random 11 | import sys 12 | 13 | SAMPLE_COUNT = 30 # Number of distinct objects to generate and interpolate between 14 | TRANSITION_FRAMES = 60 15 | 16 | ROTATE_MODEL = False 17 | USE_HYBRID_GAN = True 18 | 19 | SURFACE_LEVEL = 0.04 if USE_HYBRID_GAN else 0.011 20 | 21 | sdf_net = SDFNet() 22 | if USE_HYBRID_GAN: 23 | sdf_net.filename = 'hybrid_progressive_gan_generator_3.to' 24 | sdf_net.load() 25 | sdf_net.eval() 26 | 27 | if USE_HYBRID_GAN: 28 | codes = standard_normal_distribution.sample((SAMPLE_COUNT + 1, LATENT_CODE_SIZE)).numpy() 29 | else: 30 | latent_codes = torch.load(LATENT_CODES_FILENAME).detach().cpu().numpy() 31 | indices = random.sample(list(range(latent_codes.shape[0])), SAMPLE_COUNT + 1) 32 | codes = latent_codes[indices, :] 33 | 34 | codes[0, :] = codes[-1, :] # Make animation periodic 35 | spline = scipy.interpolate.CubicSpline(np.arange(SAMPLE_COUNT + 1), codes, axis=0, bc_type='periodic') 36 | 37 | def create_image_sequence(): 38 | ensure_directory('images') 39 | frame_index = 0 40 | viewer = MeshRenderer(size=1080, start_thread=False) 41 | progress_bar = tqdm(total=SAMPLE_COUNT * TRANSITION_FRAMES) 42 | 43 | for sample_index in range(SAMPLE_COUNT): 44 | for step in range(TRANSITION_FRAMES): 45 | code = torch.tensor(spline(float(sample_index) + step / TRANSITION_FRAMES), dtype=torch.float32, device=device) 46 | if ROTATE_MODEL: 47 | viewer.rotation = (147 + frame_index / (SAMPLE_COUNT * TRANSITION_FRAMES) * 360 * 6, 40) 48 | viewer.set_mesh(sdf_net.get_mesh(code, voxel_resolution=128, sphere_only=False, level=SURFACE_LEVEL)) 49 | image = viewer.get_image(flip_red_blue=True) 50 | cv2.imwrite("images/frame-{:05d}.png".format(frame_index), image) 51 | frame_index += 1 52 | progress_bar.update() 53 | 54 | print("\n\nUse this command to create a video:\n") 55 | print('ffmpeg -framerate 30 -i images/frame-%05d.png -c:v libx264 -profile:v high -crf 19 -pix_fmt yuv420p video.mp4') 56 | 57 | def show_models(): 58 | TRANSITION_TIME = 2 59 | viewer = MeshRenderer() 60 | 61 | while True: 62 | for sample_index in range(SAMPLE_COUNT): 63 | try: 64 | start 
= time.perf_counter() 65 | end = start + TRANSITION_TIME 66 | while time.perf_counter() < end: 67 | progress = min((time.perf_counter() - start) / TRANSITION_TIME, 1.0) 68 | if ROTATE_MODEL: 69 | viewer.rotation = (147 + (sample_index + progress) / SAMPLE_COUNT * 360 * 6, 40) 70 | code = torch.tensor(spline(float(sample_index) + progress), dtype=torch.float32, device=device) 71 | viewer.set_mesh(sdf_net.get_mesh(code, voxel_resolution=64, sphere_only=False, level=SURFACE_LEVEL)) 72 | 73 | except KeyboardInterrupt: 74 | viewer.stop() 75 | return 76 | 77 | def create_objects(): 78 | from util import ensure_directory 79 | from rendering.raymarching import render_image 80 | from rendering.math import get_rotation_matrix 81 | import os 82 | ensure_directory('generated_objects/') 83 | image_filename = 'generated_objects/chair-{:03d}.png' 84 | mesh_filename = 'generated_objects/chair-{:03d}.stl' 85 | index = 0 86 | while True: 87 | if os.path.exists(image_filename.format(index)) or os.path.exists(mesh_filename.format(index)): 88 | index += 1 89 | continue 90 | latent_code = standard_normal_distribution.sample((LATENT_CODE_SIZE,)).to(device) 91 | image = render_image(sdf_net, latent_code, resolution=128, sdf_offset=-SURFACE_LEVEL, ssaa=2, radius=1.4, color=(0.7, 0.7, 0.7)) 92 | image.save(image_filename.format(index)) 93 | mesh = sdf_net.get_mesh(latent_code, voxel_resolution=256, sphere_only=False, level=SURFACE_LEVEL) 94 | mesh.apply_transform(get_rotation_matrix(90, 'x')) 95 | mesh.apply_translation((0, 0, -np.min(mesh.vertices[:, 2]))) 96 | mesh.export(mesh_filename.format(index)) 97 | print("Created mesh for index {:d}".format(index)) 98 | index += 1 99 | 100 | 101 | if 'save' in sys.argv: 102 | create_image_sequence() 103 | elif 'create_objects' in sys.argv: 104 | create_objects() 105 | else: 106 | show_models() -------------------------------------------------------------------------------- /model/autoencoder.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from util import standard_normal_distribution 3 | 4 | AUTOENCODER_MODEL_COMPLEXITY_MULTIPLIER = 24 5 | amcm = AUTOENCODER_MODEL_COMPLEXITY_MULTIPLIER 6 | 7 | class Autoencoder(SavableModule): 8 | def __init__(self, is_variational = True): 9 | super(Autoencoder, self).__init__(filename="autoencoder-{:d}.to".format(LATENT_CODE_SIZE)) 10 | 11 | self.is_variational = is_variational 12 | if is_variational: 13 | self.filename = 'variational-' + self.filename 14 | 15 | self.encoder = nn.Sequential( 16 | nn.Conv3d(in_channels = 1, out_channels = 1 * amcm, kernel_size = 4, stride = 2, padding = 1), 17 | nn.BatchNorm3d(1 * amcm), 18 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 19 | 20 | nn.Conv3d(in_channels = 1 * amcm, out_channels = 2 * amcm, kernel_size = 4, stride = 2, padding = 1), 21 | nn.BatchNorm3d(2 * amcm), 22 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 23 | 24 | nn.Conv3d(in_channels = 2 * amcm, out_channels = 4 * amcm, kernel_size = 4, stride = 2, padding = 1), 25 | nn.BatchNorm3d(4 * amcm), 26 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 27 | 28 | nn.Conv3d(in_channels = 4 * amcm, out_channels = LATENT_CODE_SIZE * 2, kernel_size = 4, stride = 1), 29 | nn.BatchNorm3d(LATENT_CODE_SIZE * 2), 30 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 31 | 32 | Lambda(lambda x: x.reshape(x.shape[0], -1)), 33 | 34 | nn.Linear(in_features = LATENT_CODE_SIZE * 2, out_features=LATENT_CODE_SIZE) 35 | ) 36 | 37 | if is_variational: 38 | self.encoder.add_module('vae-bn', 
nn.BatchNorm1d(LATENT_CODE_SIZE)) 39 | self.encoder.add_module('vae-lr', nn.LeakyReLU(negative_slope=0.2, inplace=True)) 40 | 41 | self.encode_mean = nn.Linear(in_features=LATENT_CODE_SIZE, out_features=LATENT_CODE_SIZE) 42 | self.encode_log_variance = nn.Linear(in_features=LATENT_CODE_SIZE, out_features=LATENT_CODE_SIZE) 43 | 44 | self.decoder = nn.Sequential( 45 | nn.Linear(in_features = LATENT_CODE_SIZE, out_features=LATENT_CODE_SIZE * 2), 46 | nn.BatchNorm1d(LATENT_CODE_SIZE * 2), 47 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 48 | 49 | Lambda(lambda x: x.reshape(-1, LATENT_CODE_SIZE * 2, 1, 1, 1)), 50 | 51 | nn.ConvTranspose3d(in_channels = LATENT_CODE_SIZE * 2, out_channels = 4 * amcm, kernel_size = 4, stride = 1), 52 | nn.BatchNorm3d(4 * amcm), 53 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 54 | 55 | nn.ConvTranspose3d(in_channels = 4 * amcm, out_channels = 2 * amcm, kernel_size = 4, stride = 2, padding = 1), 56 | nn.BatchNorm3d(2 * amcm), 57 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 58 | 59 | nn.ConvTranspose3d(in_channels = 2 * amcm, out_channels = 1 * amcm, kernel_size = 4, stride = 2, padding = 1), 60 | nn.BatchNorm3d(1 * amcm), 61 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 62 | 63 | nn.ConvTranspose3d(in_channels = 1 * amcm, out_channels = 1, kernel_size = 4, stride = 2, padding = 1) 64 | ) 65 | self.cuda() 66 | 67 | def encode(self, x, return_mean_and_log_variance = False): 68 | x = x.reshape((-1, 1, 32, 32, 32)) 69 | x = self.encoder(x) 70 | 71 | if not self.is_variational: 72 | return x 73 | 74 | mean = self.encode_mean(x).squeeze() 75 | 76 | if self.training or return_mean_and_log_variance: 77 | log_variance = self.encode_log_variance(x).squeeze() 78 | standard_deviation = torch.exp(log_variance * 0.5) 79 | eps = standard_normal_distribution.sample(mean.shape).to(x.device) 80 | 81 | if self.training: 82 | x = mean + standard_deviation * eps 83 | else: 84 | x = mean 85 | 86 | if return_mean_and_log_variance: 87 | return x, mean, log_variance 88 | else: 89 | return x 90 | 91 | def decode(self, x): 92 | if len(x.shape) == 1: 93 | x = x.unsqueeze(dim = 0) # add dimension for channels 94 | x = self.decoder(x) 95 | return x.squeeze() 96 | 97 | def forward(self, x): 98 | if not self.is_variational: 99 | z = self.encode(x) 100 | x = self.decode(z) 101 | return x 102 | 103 | z, mean, log_variance = self.encode(x, return_mean_and_log_variance = True) 104 | x = self.decode(z) 105 | return x, mean, log_variance -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import trimesh 3 | from tqdm import tqdm 4 | import numpy as np 5 | from mesh_to_sdf import get_surface_point_cloud, scale_to_unit_sphere, BadMeshException 6 | from util import ensure_directory 7 | from multiprocessing import Pool 8 | from rendering.math import get_rotation_matrix 9 | 10 | DIRECTORY_MODELS = 'data/meshes/' 11 | MODEL_EXTENSION = '.stl' 12 | DIRECTORY_SDF = 'data/sdf/' 13 | 14 | CREATE_VOXELS = True 15 | VOXEL_RESOLUTION = 32 16 | 17 | CREATE_SDF_CLOUDS = True 18 | SDF_CLOUD_SAMPLE_SIZE = 200000 19 | 20 | ROTATION = None # get_rotation_matrix(90, axis='x') 21 | 22 | def get_model_files(): 23 | for directory, _, files in os.walk(DIRECTORY_MODELS): 24 | for filename in files: 25 | if filename.endswith(MODEL_EXTENSION): 26 | yield os.path.join(directory, filename) 27 | 28 | def get_npy_filename(model_filename, qualifier=''): 29 | return 
DIRECTORY_SDF + model_filename[len(DIRECTORY_MODELS):-len(MODEL_EXTENSION)] + qualifier + '.npy' 30 | 31 | def get_voxel_filename(model_filename): 32 | return get_npy_filename(model_filename, '-voxels-{:d}'.format(VOXEL_RESOLUTION)) 33 | 34 | def get_sdf_cloud_filename(model_filename): 35 | return get_npy_filename(model_filename, '-sdf') 36 | 37 | def get_bad_mesh_filename(model_filename): 38 | return DIRECTORY_SDF + model_filename[len(DIRECTORY_MODELS):-len(MODEL_EXTENSION)] + '.badmesh' 39 | 40 | def mark_bad_mesh(model_filename): 41 | filename = get_bad_mesh_filename(model_filename) 42 | ensure_directory(os.path.dirname(filename)) 43 | open(filename, 'w').close() 44 | 45 | def is_bad_mesh(model_filename): 46 | return os.path.exists(get_bad_mesh_filename(model_filename)) 47 | 48 | def process_model_file(filename): 49 | voxels_filename = get_voxel_filename(filename) 50 | sdf_cloud_filename = get_sdf_cloud_filename(filename) 51 | 52 | if is_bad_mesh(filename): 53 | return 54 | if not (CREATE_VOXELS and not os.path.isfile(voxels_filename) or CREATE_SDF_CLOUDS and not os.path.isfile(sdf_cloud_filename)): 55 | return 56 | 57 | mesh = trimesh.load(filename) 58 | if ROTATION is not None: 59 | mesh.apply_transform(ROTATION) 60 | mesh = scale_to_unit_sphere(mesh) 61 | 62 | surface_point_cloud = get_surface_point_cloud(mesh) 63 | if CREATE_SDF_CLOUDS: 64 | try: 65 | points, sdf = surface_point_cloud.sample_sdf_near_surface(number_of_points=SDF_CLOUD_SAMPLE_SIZE, sign_method='depth', min_size=0.015) 66 | combined = np.concatenate((points, sdf[:, np.newaxis]), axis=1) 67 | ensure_directory(os.path.dirname(sdf_cloud_filename)) 68 | np.save(sdf_cloud_filename, combined) 69 | except BadMeshException: 70 | tqdm.write("Skipping bad mesh. ({:s})".format(filename)) 71 | mark_bad_mesh(filename) 72 | return 73 | 74 | if CREATE_VOXELS: 75 | try: 76 | voxels = surface_point_cloud.get_voxels(voxel_resolution=VOXEL_RESOLUTION, use_depth_buffer=True) 77 | ensure_directory(os.path.dirname(voxels_filename)) 78 | np.save(voxels_filename, voxels) 79 | except BadMeshException: 80 | tqdm.write("Skipping bad mesh. 
({:s})".format(filename)) 81 | mark_bad_mesh(filename) 82 | return 83 | 84 | 85 | def process_model_files(): 86 | ensure_directory(DIRECTORY_SDF) 87 | files = list(get_model_files()) 88 | 89 | worker_count = os.cpu_count() // 2 90 | print("Using {:d} processes.".format(worker_count)) 91 | pool = Pool(worker_count) 92 | 93 | progress = tqdm(total=len(files)) 94 | def on_complete(*_): 95 | progress.update() 96 | 97 | for filename in files: 98 | pool.apply_async(process_model_file, args=(filename,), callback=on_complete) 99 | pool.close() 100 | pool.join() 101 | 102 | def combine_pointcloud_files(): 103 | import torch 104 | print("Combining SDF point clouds...") 105 | npy_files = sorted([get_sdf_cloud_filename(f) for f in get_model_files()]) 106 | npy_files = [f for f in npy_files if os.path.exists(f)] 107 | 108 | N = len(npy_files) 109 | points = torch.zeros((N * SDF_CLOUD_SAMPLE_SIZE, 3)) 110 | sdf = torch.zeros((N * SDF_CLOUD_SAMPLE_SIZE)) 111 | position = 0 112 | 113 | for npy_filename in tqdm(npy_files): 114 | numpy_array = np.load(npy_filename) 115 | points[position * SDF_CLOUD_SAMPLE_SIZE:(position + 1) * SDF_CLOUD_SAMPLE_SIZE, :] = torch.tensor(numpy_array[:, :3]) 116 | sdf[position * SDF_CLOUD_SAMPLE_SIZE:(position + 1) * SDF_CLOUD_SAMPLE_SIZE] = torch.tensor(numpy_array[:, 3]) 117 | del numpy_array 118 | position += 1 119 | 120 | print("Saving combined SDF clouds...") 121 | torch.save(points, os.path.join('data', 'sdf_points.to')) 122 | torch.save(sdf, os.path.join('data', 'sdf_values.to')) 123 | 124 | if __name__ == '__main__': 125 | process_model_files() 126 | if CREATE_SDF_CLOUDS: 127 | combine_pointcloud_files() -------------------------------------------------------------------------------- /train_gan.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import numpy as np 7 | 8 | import random 9 | import time 10 | import sys 11 | from collections import deque 12 | 13 | from model.gan import Generator, Discriminator 14 | 15 | from util import create_text_slice, device 16 | from datasets import VoxelDataset 17 | from torch.utils.data import DataLoader 18 | 19 | generator = Generator() 20 | discriminator = Discriminator() 21 | 22 | if "continue" in sys.argv: 23 | generator.load() 24 | discriminator.load() 25 | 26 | log_file = open("plots/gan_training.csv", "a" if "continue" in sys.argv else "w") 27 | 28 | generator_optimizer = optim.Adam(generator.parameters(), lr=0.001) 29 | 30 | discriminator_criterion = torch.nn.functional.binary_cross_entropy 31 | discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.00001) 32 | 33 | show_viewer = "nogui" not in sys.argv 34 | 35 | if show_viewer: 36 | from rendering import MeshRenderer 37 | viewer = MeshRenderer() 38 | 39 | BATCH_SIZE = 64 40 | 41 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 42 | data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8) 43 | 44 | valid_target_default = torch.ones(BATCH_SIZE, requires_grad=False).to(device) 45 | fake_target_default = torch.zeros(BATCH_SIZE, requires_grad=False).to(device) 46 | 47 | def train(): 48 | history_fake = deque(maxlen=50) 49 | history_real = deque(maxlen=50) 50 | 51 | for epoch in count(): 52 | batch_index = 0 53 | epoch_start_time = time.time() 54 | for batch in data_loader: 55 | try: 56 | 57 | # train generator 58 | generator_optimizer.zero_grad() 59 | 60 | fake_sample = 
generator.generate(sample_size = BATCH_SIZE) 61 | if show_viewer: 62 | viewer.set_voxels(fake_sample[0, :, :, :].squeeze().detach().cpu().numpy()) 63 | 64 | fake_discriminator_output = discriminator(fake_sample) 65 | fake_loss = -torch.mean(torch.log(fake_discriminator_output)) 66 | fake_loss.backward() 67 | generator_optimizer.step() 68 | 69 | 70 | # train discriminator 71 | current_batch_size = batch.shape[0] # equals BATCH_SIZE for all batches except the last one 72 | fake_target = fake_target_default if current_batch_size == BATCH_SIZE else torch.zeros(current_batch_size, requires_grad=False).to(device) 73 | valid_target = valid_target_default if current_batch_size == BATCH_SIZE else torch.ones(current_batch_size, requires_grad=False).to(device) 74 | 75 | discriminator_optimizer.zero_grad() 76 | fake_sample = generator.generate(sample_size = current_batch_size).detach() 77 | discriminator_output_fake = discriminator(fake_sample) 78 | fake_loss = discriminator_criterion(discriminator_output_fake, fake_target) 79 | fake_loss.backward() 80 | discriminator_optimizer.step() 81 | 82 | discriminator_optimizer.zero_grad() 83 | discriminator_output_valid = discriminator(batch.to(device)) 84 | valid_loss = discriminator_criterion(discriminator_output_valid, valid_target) 85 | valid_loss.backward() 86 | discriminator_optimizer.step() 87 | 88 | history_fake.append(torch.mean(discriminator_output_fake).item()) 89 | history_real.append(torch.mean(discriminator_output_valid).item()) 90 | batch_index += 1 91 | 92 | if "verbose" in sys.argv: 93 | print("Epoch " + str(epoch) + ", batch " + str(batch_index) + 94 | ": prediction on fake samples: " + '{0:.4f}'.format(history_fake[-1]) + 95 | ", prediction on valid samples: " + '{0:.4f}'.format(history_real[-1])) 96 | except KeyboardInterrupt: 97 | if show_viewer: 98 | viewer.stop() 99 | return 100 | 101 | generator.save() 102 | discriminator.save() 103 | 104 | if epoch % 20 == 0: 105 | generator.save(epoch=epoch) 106 | discriminator.save(epoch=epoch) 107 | 108 | if "show_slice" in sys.argv: 109 | voxels = generator.generate().squeeze() 110 | print(create_text_slice(voxels)) 111 | 112 | prediction_fake = np.mean(history_fake) 113 | prediction_real = np.mean(history_real) 114 | print('Epoch {:d} ({:.1f}s), prediction on fake: {:.4f}, prediction on real: {:.4f}'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real)) 115 | log_file.write('{:d} {:.1f} {:.4f} {:.4f}\n'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real)) 116 | log_file.flush() 117 | 118 | 119 | train() 120 | log_file.close() 121 | -------------------------------------------------------------------------------- /train_autoencoder.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from datasets import VoxelDataset 7 | from torch.utils.data import DataLoader 8 | 9 | import random 10 | random.seed(0) 11 | torch.manual_seed(0) 12 | 13 | import numpy as np 14 | import sys 15 | import time 16 | from tqdm import tqdm 17 | 18 | from model.autoencoder import Autoencoder 19 | from collections import deque 20 | from util import create_text_slice, device 21 | 22 | BATCH_SIZE = 32 23 | 24 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 25 | data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8) 26 | 27 | VIEWER_UPDATE_STEP = 20 28 | 29 | IS_VARIATIONAL = 'classic' 
not in sys.argv 30 | 31 | autoencoder = Autoencoder(is_variational=IS_VARIATIONAL) 32 | if "continue" in sys.argv: 33 | autoencoder.load() 34 | 35 | optimizer = optim.Adam(autoencoder.parameters(), lr=0.00005) 36 | 37 | show_viewer = "nogui" not in sys.argv 38 | 39 | if show_viewer: 40 | from rendering import MeshRenderer 41 | viewer = MeshRenderer() 42 | 43 | reconstruction_error_history = deque(maxlen = BATCH_SIZE) 44 | kld_error_history = deque(maxlen = BATCH_SIZE) 45 | 46 | criterion = nn.functional.mse_loss 47 | 48 | log_file = open("plots/{:s}autoencoder_training.csv".format('variational_' if autoencoder.is_variational else ''), "a" if "continue" in sys.argv else "w") 49 | 50 | def voxel_difference(input, target): 51 | wrong_signs = (input * target) < 0 52 | return torch.sum(wrong_signs).item() / wrong_signs.nelement() 53 | 54 | def kld_loss(mean, log_variance): 55 | return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp()) / mean.nelement() 56 | 57 | def get_reconstruction_loss(input, target): 58 | difference = input - target 59 | wrong_signs = target < 0 60 | difference[wrong_signs] *= 32 61 | 62 | return torch.mean(torch.abs(difference)) 63 | 64 | def test(epoch_index, epoch_time, test_set): 65 | with torch.no_grad(): 66 | autoencoder.eval() 67 | 68 | if IS_VARIATIONAL: 69 | output, mean, log_variance = autoencoder(test_set) 70 | kld = kld_loss(mean, log_variance) 71 | else: 72 | output = autoencoder(test_set) 73 | kld = 0 74 | 75 | reconstruction_loss = criterion(output, test_set) 76 | 77 | voxel_diff = voxel_difference(output, test_set) 78 | 79 | if "show_slice" in sys.argv: 80 | print(create_text_slice(output[0, :, :, :])) 81 | 82 | print("Epoch {:d} ({:.1f}s): ".format(epoch_index, epoch_time) + 83 | "Reconstruction loss: {:.4f}, ".format(reconstruction_loss) + 84 | "Voxel diff: {:.4f}, ".format(voxel_diff) + 85 | "KLD loss: {:4f}, ".format(kld) + 86 | "training loss: {:4f}, ".format(np.mean(reconstruction_error_history)) 87 | ) 88 | 89 | log_file.write('{:d} {:.1f} {:.6f} {:.6f} {:.6f}\n'.format(epoch_index, epoch_time, reconstruction_loss, kld, voxel_diff)) 90 | log_file.flush() 91 | 92 | def train(): 93 | for epoch in count(): 94 | batch_index = 0 95 | epoch_start_time = time.time() 96 | for batch in tqdm(data_loader, desc='Epoch {:d}'.format(epoch)): 97 | try: 98 | batch = batch.to(device) 99 | 100 | autoencoder.zero_grad() 101 | autoencoder.train() 102 | if IS_VARIATIONAL: 103 | output, mean, log_variance = autoencoder(batch) 104 | kld = kld_loss(mean, log_variance) 105 | else: 106 | output = autoencoder(batch) 107 | kld = 0 108 | 109 | reconstruction_loss = get_reconstruction_loss(output, batch) 110 | 111 | loss = reconstruction_loss + kld 112 | 113 | reconstruction_error_history.append(reconstruction_loss.item()) 114 | kld_error_history.append(kld.item() if IS_VARIATIONAL else 0) 115 | 116 | loss.backward() 117 | optimizer.step() 118 | 119 | if show_viewer and batch_index == 0: 120 | viewer.set_voxels(output[0, :, :, :].squeeze().detach().cpu().numpy()) 121 | 122 | if show_viewer and (batch_index + 1) % VIEWER_UPDATE_STEP == 0 and 'verbose' in sys.argv: 123 | viewer.set_voxels(output[0, :, :, :].squeeze().detach().cpu().numpy()) 124 | print("epoch " + str(epoch) + ", batch " + str(batch_index) \ 125 | + ', reconstruction loss: {0:.4f}'.format(reconstruction_loss.item()) \ 126 | + ' (average: {0:.4f}), '.format(np.mean(reconstruction_error_history)) \ 127 | + 'KLD loss: {0:.4f}'.format(np.mean(kld_error_history))) 128 | batch_index += 1 129 | except 
KeyboardInterrupt: 130 | if show_viewer: 131 | viewer.stop() 132 | return 133 | autoencoder.save() 134 | if epoch % 20 == 0: 135 | autoencoder.save(epoch=epoch) 136 | #test(epoch, time.time() - epoch_start_time, test_set) 137 | print("Epoch {:d} ({:.1f}s): reconstruction loss: {:.4f}, KLD loss: {:.4f}".format( 138 | epoch, 139 | time.time() - epoch_start_time, 140 | np.mean(reconstruction_error_history), 141 | np.mean(kld_error_history))) 142 | 143 | train() -------------------------------------------------------------------------------- /train_point_gan_ref.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.optim import RMSprop 7 | 8 | from datasets import PointDataset 9 | from model.point_sdf_net import PointNet, SDFGenerator 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--category', type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | LATENT_SIZE = 128 16 | GRADIENT_PENALITY = 10 17 | HIDDEN_SIZE = 256 18 | NUM_LAYERS = 8 19 | NORM = True 20 | THRESHOLD = 0.1 21 | 22 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | G = SDFGenerator(LATENT_SIZE, HIDDEN_SIZE, NUM_LAYERS, NORM, dropout=0.0) 24 | D = PointNet(out_channels=1) 25 | G, D = G.to(device), D.to(device) 26 | 27 | root = osp.join(f'data/{args.category}') 28 | dataset = PointDataset.from_split(root, split='train') 29 | 30 | 31 | def generate_batch(u_pos, u_dist, s_pos, s_dist): 32 | u_batch = torch.arange(u_pos.size(0), device=u_pos.device) 33 | u_batch = u_batch.view(-1, 1).repeat(1, u_pos.size(1)) 34 | 35 | mask = u_dist.abs().squeeze(-1) < THRESHOLD 36 | 37 | s_pos = s_pos[mask].view(-1, 3) 38 | s_dist = s_dist[mask].view(-1, 1) 39 | s_batch = u_batch[mask].view(-1) 40 | 41 | mask = mask | (torch.rand(mask.size(), device=mask.device) < 0.15) 42 | 43 | u_pos = u_pos[mask].view(-1, 3) 44 | u_dist = u_dist[mask].view(-1, 1) 45 | u_batch = u_batch[mask].view(-1) 46 | 47 | return ( 48 | torch.cat([u_pos, s_pos], dim=0), 49 | torch.cat([u_dist, s_dist], dim=0), 50 | torch.cat([u_batch, s_batch], dim=0), 51 | ) 52 | 53 | 54 | class RefinementGenerator(torch.nn.Module): 55 | def __init__(self, generator): 56 | super(RefinementGenerator, self).__init__() 57 | self.generator = generator 58 | 59 | def forward(self, u_pos, z): 60 | u_pos.requires_grad_(True) 61 | u_dist = self.generator(u_pos, z) 62 | 63 | grad = torch.autograd.grad(u_dist, u_pos, 64 | grad_outputs=torch.ones_like(u_dist), 65 | retain_graph=True, only_inputs=True)[0] 66 | s_pos = u_pos - u_dist * grad 67 | s_pos = s_pos + 0.0025 * torch.randn_like(s_pos) 68 | s_dist = self.generator(s_pos, z) 69 | 70 | return u_pos, u_dist, s_pos, s_dist 71 | 72 | 73 | # TODO: Load G and D from `train_point_gan.py`. 
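# The RefinementGenerator defined above wraps the pretrained SDF generator: it
# evaluates the SDF at uniform points, steps each point along the SDF gradient
# onto the surface (u_pos - u_dist * grad), perturbs the result slightly and
# re-evaluates it, so the discriminator sees both uniform and near-surface samples.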
74 | # G.load_state_dict(torch.load(..., map_location=device)) 75 | # D.load_state_dict(torch.load(..., map_location=device)) 76 | ref_G = RefinementGenerator(G).to(device) 77 | G_optimizer = RMSprop(ref_G.parameters(), lr=0.0001) 78 | D_optimizer = RMSprop(D.parameters(), lr=0.0001) 79 | 80 | configuration = [ # num_points, batch_size, epochs 81 | (8192, 16, 60), 82 | (16384, 8, 60), 83 | ] 84 | 85 | num_steps = 0 86 | for num_points, batch_size, epochs in configuration: 87 | dataset.num_points = num_points 88 | loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=6) 89 | 90 | for epoch in range(1, epochs + 1): 91 | total_loss = 0 92 | for uniform, surface in loader: 93 | num_steps += 1 94 | 95 | uniform, surface = uniform.to(device), surface.to(device) 96 | u_pos, u_dist = uniform[..., :3], uniform[..., 3:] 97 | s_pos, s_dist = surface[..., :3], surface[..., 3:] 98 | 99 | D_optimizer.zero_grad() 100 | 101 | z = torch.randn(uniform.size(0), LATENT_SIZE, device=device) 102 | fake_u_pos, fake_u_dist, fake_s_pos, fake_s_dist = ref_G(u_pos, z) 103 | fake_pos, fake_dist, fake_batch = generate_batch( 104 | fake_u_pos, fake_u_dist, fake_s_pos, fake_s_dist) 105 | 106 | real_pos, real_dist, real_batch = generate_batch( 107 | u_pos, u_dist, s_pos, s_dist) 108 | 109 | out_real = D(real_pos, real_dist, real_batch) 110 | out_fake = D(fake_pos, fake_dist, fake_batch) 111 | D_loss = out_fake.mean() - out_real.mean() 112 | 113 | alpha = torch.rand((uniform.size(0), 1, 1), device=device) 114 | interpolated = alpha * u_dist + (1 - alpha) * fake_u_dist 115 | interpolated.requires_grad_(True) 116 | out = D(u_pos, interpolated) 117 | 118 | grad = torch.autograd.grad(out, interpolated, 119 | grad_outputs=torch.ones_like(out), 120 | create_graph=True, retain_graph=True, 121 | only_inputs=True)[0] 122 | grad_norm = grad.view(grad.size(0), -1).norm(dim=-1, p=2) 123 | gp = GRADIENT_PENALITY * ((grad_norm - 1).pow(2).mean()) 124 | 125 | loss = D_loss + gp 126 | loss.backward() 127 | D_optimizer.step() 128 | 129 | if num_steps % 5 == 0: 130 | G_optimizer.zero_grad() 131 | z = torch.randn(uniform.size(0), LATENT_SIZE, device=device) 132 | fake = ref_G(u_pos, z) 133 | fake_u_pos, fake_u_dist, fake_s_pos, fake_s_dist = fake 134 | fake_pos, fake_dist, fake_batch = generate_batch( 135 | fake_u_pos, fake_u_dist, fake_s_pos, fake_s_dist) 136 | out_fake = D(fake_pos, fake_dist, fake_batch) 137 | loss = -out_fake.mean() 138 | loss.backward() 139 | G_optimizer.step() 140 | 141 | total_loss += D_loss.abs().item() 142 | 143 | print('Num points: {}, Epoch: {:03d}, Loss: {:.6f}'.format( 144 | num_points, epoch, total_loss / len(loader))) 145 | -------------------------------------------------------------------------------- /train_hybrid_wgan.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import numpy as np 7 | 8 | import random 9 | import time 10 | import sys 11 | from collections import deque 12 | from tqdm import tqdm 13 | 14 | from model.sdf_net import SDFNet 15 | from model.gan import Discriminator, LATENT_CODE_SIZE 16 | from util import create_text_slice, device, standard_normal_distribution 17 | 18 | VOXEL_RESOLUTION = 32 19 | SDF_CLIPPING = 0.1 20 | from util import create_text_slice,get_voxel_coordinates 21 | 22 | from datasets import VoxelDataset 23 | from torch.utils.data import DataLoader 24 | 25 | LEARN_RATE = 0.00001 26 | BATCH_SIZE = 8 27 | 
CRITIC_UPDATES_PER_GENERATOR_UPDATE = 5 28 | CRITIC_WEIGHT_LIMIT = 0.01 29 | 30 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 31 | dataset.rescale_sdf = False 32 | data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8) 33 | 34 | generator = SDFNet() 35 | generator.filename = 'hybrid_wgan_generator.to' 36 | 37 | critic = Discriminator() 38 | critic.filename = 'hybrid_wgan_critic.to' 39 | critic.use_sigmoid = False 40 | 41 | if "continue" in sys.argv: 42 | generator.load() 43 | critic.load() 44 | 45 | LOG_FILE_NAME = "plots/hybrid_wgan_training.csv" 46 | first_epoch = 0 47 | if 'continue' in sys.argv: 48 | log_file_contents = open(LOG_FILE_NAME, 'r').readlines() 49 | first_epoch = len(log_file_contents) 50 | 51 | log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w") 52 | 53 | generator_optimizer = optim.Adam(generator.parameters(), lr=LEARN_RATE) 54 | 55 | critic_criterion = torch.nn.functional.binary_cross_entropy 56 | critic_optimizer = optim.RMSprop(critic.parameters(), lr=LEARN_RATE) 57 | 58 | show_viewer = "nogui" not in sys.argv 59 | 60 | if show_viewer: 61 | from rendering import MeshRenderer 62 | viewer = MeshRenderer() 63 | 64 | valid_target = torch.ones(BATCH_SIZE, requires_grad=False).to(device) 65 | fake_target = torch.zeros(BATCH_SIZE, requires_grad=False).to(device) 66 | 67 | def sample_latent_codes(): 68 | latent_codes = standard_normal_distribution.sample(sample_shape=[BATCH_SIZE, LATENT_CODE_SIZE]).to(device) 69 | latent_codes = latent_codes.repeat((1, 1, VOXEL_RESOLUTION**3)).reshape(-1, LATENT_CODE_SIZE) 70 | return latent_codes 71 | 72 | grid_points = get_voxel_coordinates(VOXEL_RESOLUTION, return_torch_tensor=True).repeat((BATCH_SIZE, 1)) 73 | history_fake = deque(maxlen=50) 74 | history_real = deque(maxlen=50) 75 | 76 | def train(): 77 | for epoch in count(start=first_epoch): 78 | batch_index = 0 79 | epoch_start_time = time.time() 80 | for batch in tqdm(data_loader, desc='Epoch {:d}'.format(epoch)): 81 | try: 82 | # train critic 83 | critic_optimizer.zero_grad() 84 | latent_codes = sample_latent_codes() 85 | fake_sample = generator(grid_points, latent_codes) 86 | fake_sample = fake_sample.reshape(-1, VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 87 | 88 | critic_output_fake = critic(fake_sample) 89 | critic_output_valid = critic(batch.to(device)) 90 | 91 | critic_loss = torch.mean(critic_output_fake) - torch.mean(critic_output_valid) 92 | critic_loss.backward() 93 | critic_optimizer.step() 94 | critic.clip_weights(CRITIC_WEIGHT_LIMIT) 95 | 96 | # train generator 97 | if batch_index % CRITIC_UPDATES_PER_GENERATOR_UPDATE == 0: 98 | generator_optimizer.zero_grad() 99 | critic.zero_grad() 100 | 101 | latent_codes = sample_latent_codes() 102 | fake_sample = generator(grid_points, latent_codes) 103 | fake_sample = fake_sample.reshape(-1, VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 104 | if batch_index % 20 == 0 and show_viewer: 105 | viewer.set_voxels(fake_sample[0, :, :, :].squeeze().detach().cpu().numpy()) 106 | if batch_index % 20 == 0 and "show_slice" in sys.argv: 107 | print(create_text_slice(fake_sample[0, :, :, :] / SDF_CLIPPING)) 108 | 109 | critic_output_fake = critic(fake_sample) 110 | # TODO an incorrect loss function was used here as pointed out in issue #2 111 | # This hasn't been tested yet after fixing the loss function 112 | # The incorrect loss function was: fake_loss = torch.mean(-torch.log(critic_output_fake)) 113 | fake_loss = torch.mean(-critic_output_fake) 114 | 
fake_loss.backward() 115 | generator_optimizer.step() 116 | 117 | history_fake.append(torch.mean(critic_output_fake).item()) 118 | history_real.append(torch.mean(critic_output_valid).item()) 119 | 120 | if "verbose" in sys.argv and batch_index % 20 == 0: 121 | print("Epoch " + str(epoch) + ", batch " + str(batch_index) + 122 | ": prediction on fake samples: " + '{0:.4f}'.format(history_fake[-1]) + 123 | ", prediction on valid samples: " + '{0:.4f}'.format(history_real[-1])) 124 | 125 | batch_index += 1 126 | except KeyboardInterrupt: 127 | if show_viewer: 128 | viewer.stop() 129 | return 130 | 131 | generator.save() 132 | critic.save() 133 | 134 | generator.save(epoch=epoch) 135 | critic.save(epoch=epoch) 136 | 137 | prediction_fake = np.mean(history_fake) 138 | prediction_real = np.mean(history_real) 139 | print('Epoch {:d} ({:.1f}s), prediction on fake: {:.4f}, prediction on real: {:.4f}'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real)) 140 | log_file.write('{:d} {:.1f} {:.4f} {:.4f}\n'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real)) 141 | log_file.flush() 142 | 143 | 144 | train() 145 | log_file.close() 146 | -------------------------------------------------------------------------------- /train_hybrid_gan.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import numpy as np 7 | 8 | import random 9 | import time 10 | import sys 11 | from collections import deque 12 | from tqdm import tqdm 13 | 14 | from model.sdf_net import SDFNet 15 | from model.gan import Discriminator, LATENT_CODE_SIZE 16 | from util import create_text_slice, device, standard_normal_distribution, get_voxel_coordinates 17 | 18 | VOXEL_RESOLUTION = 32 19 | SDF_CLIPPING = 0.1 20 | from util import create_text_slice 21 | 22 | from datasets import VoxelDataset 23 | from torch.utils.data import DataLoader 24 | 25 | generator = SDFNet() 26 | generator.filename = 'hybrid_gan_generator.to' 27 | 28 | discriminator = Discriminator() 29 | discriminator.filename = 'hybrid_gan_discriminator.to' 30 | 31 | if "continue" in sys.argv: 32 | generator.load() 33 | discriminator.load() 34 | 35 | LOG_FILE_NAME = "plots/hybrid_gan_training.csv" 36 | first_epoch = 0 37 | if 'continue' in sys.argv: 38 | log_file_contents = open(LOG_FILE_NAME, 'r').readlines() 39 | first_epoch = len(log_file_contents) 40 | 41 | log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w") 42 | 43 | generator_optimizer = optim.Adam(generator.parameters(), lr=0.001) 44 | 45 | discriminator_criterion = torch.nn.functional.binary_cross_entropy 46 | discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.00001) 47 | 48 | show_viewer = "nogui" not in sys.argv 49 | 50 | if show_viewer: 51 | from rendering import MeshRenderer 52 | viewer = MeshRenderer() 53 | 54 | BATCH_SIZE = 8 55 | 56 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 57 | dataset.rescale_sdf = False 58 | data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8) 59 | 60 | valid_target_default = torch.ones(BATCH_SIZE, requires_grad=False).to(device) 61 | fake_target_default = torch.zeros(BATCH_SIZE, requires_grad=False).to(device) 62 | 63 | def sample_latent_codes(current_batch_size): 64 | latent_codes = standard_normal_distribution.sample(sample_shape=[current_batch_size, LATENT_CODE_SIZE]).to(device) 65 | latent_codes = 
latent_codes.repeat((1, 1, grid_points.shape[0])).reshape(-1, LATENT_CODE_SIZE) 66 | return latent_codes 67 | 68 | grid_points = get_voxel_coordinates(VOXEL_RESOLUTION, return_torch_tensor=True) 69 | history_fake = deque(maxlen=50) 70 | history_real = deque(maxlen=50) 71 | 72 | def train(): 73 | for epoch in count(start=first_epoch): 74 | batch_index = 0 75 | epoch_start_time = time.time() 76 | for batch in tqdm(data_loader, desc='Epoch {:d}'.format(epoch)): 77 | try: 78 | current_batch_size = batch.shape[0] # equals BATCH_SIZE for all batches except the last one 79 | batch_grid_points = grid_points.repeat((current_batch_size, 1)) 80 | 81 | # train generator 82 | generator_optimizer.zero_grad() 83 | 84 | latent_codes = sample_latent_codes(current_batch_size) 85 | fake_sample = generator(batch_grid_points, latent_codes) 86 | fake_sample = fake_sample.reshape(-1, VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 87 | if batch_index % 20 == 0 and show_viewer: 88 | viewer.set_voxels(fake_sample[0, :, :, :].squeeze().detach().cpu().numpy()) 89 | if batch_index % 20 == 0 and "show_slice" in sys.argv: 90 | print(create_text_slice(fake_sample[0, :, :, :] / SDF_CLIPPING)) 91 | 92 | fake_discriminator_output = discriminator(fake_sample) 93 | fake_loss = torch.mean(-torch.log(fake_discriminator_output)) 94 | fake_loss.backward() 95 | generator_optimizer.step() 96 | 97 | 98 | # train discriminator on fake samples 99 | fake_target = fake_target_default if current_batch_size == BATCH_SIZE else torch.zeros(current_batch_size, requires_grad=False).to(device) 100 | valid_target = valid_target_default if current_batch_size == BATCH_SIZE else torch.ones(current_batch_size, requires_grad=False).to(device) 101 | 102 | discriminator_optimizer.zero_grad() 103 | latent_codes = sample_latent_codes(current_batch_size) 104 | fake_sample = generator(batch_grid_points, latent_codes) 105 | fake_sample = fake_sample.reshape(-1, VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 106 | discriminator_output_fake = discriminator(fake_sample) 107 | fake_loss = discriminator_criterion(discriminator_output_fake, fake_target) 108 | fake_loss.backward() 109 | discriminator_optimizer.step() 110 | 111 | # train discriminator on real samples 112 | discriminator_optimizer.zero_grad() 113 | discriminator_output_valid = discriminator(batch.to(device)) 114 | valid_loss = discriminator_criterion(discriminator_output_valid, valid_target) 115 | valid_loss.backward() 116 | discriminator_optimizer.step() 117 | 118 | history_fake.append(torch.mean(discriminator_output_fake).item()) 119 | history_real.append(torch.mean(discriminator_output_valid).item()) 120 | batch_index += 1 121 | 122 | if "verbose" in sys.argv: 123 | print("Epoch " + str(epoch) + ", batch " + str(batch_index) + 124 | ": prediction on fake samples: " + '{0:.4f}'.format(history_fake[-1]) + 125 | ", prediction on valid samples: " + '{0:.4f}'.format(history_real[-1])) 126 | except KeyboardInterrupt: 127 | if show_viewer: 128 | viewer.stop() 129 | return 130 | 131 | prediction_fake = np.mean(history_fake) 132 | prediction_real = np.mean(history_real) 133 | 134 | print('Epoch {:d} ({:.1f}s), prediction on fake: {:.4f}, prediction on real: {:.4f}'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real)) 135 | 136 | if abs(prediction_fake - prediction_real) > 0.1: 137 | print("Network diverged.") 138 | exit() 139 | 140 | generator.save() 141 | discriminator.save() 142 | 143 | generator.save(epoch=epoch) 144 | discriminator.save(epoch=epoch) 145 | 146 
| if "show_slice" in sys.argv: 147 | latent_code = sample_latent_codes(1) 148 | voxels = generator(grid_points, latent_code) 149 | voxels = voxels.reshape(VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 150 | print(create_text_slice(voxels / SDF_CLIPPING)) 151 | 152 | log_file.write('{:d} {:.1f} {:.4f} {:.4f}\n'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real)) 153 | log_file.flush() 154 | 155 | 156 | train() 157 | log_file.close() 158 | -------------------------------------------------------------------------------- /demo_latent_space.py: -------------------------------------------------------------------------------- 1 | from util import device, ensure_directory 2 | import scipy.interpolate 3 | import numpy as np 4 | from rendering import MeshRenderer 5 | import torch 6 | from tqdm import tqdm 7 | import cv2 8 | import random 9 | import matplotlib.pyplot as plt 10 | from sklearn.manifold import TSNE 11 | from matplotlib.offsetbox import Bbox 12 | from sklearn.cluster import KMeans 13 | 14 | SAMPLE_COUNT = 30 # Number of distinct objects to generate and interpolate between 15 | TRANSITION_FRAMES = 60 16 | 17 | USE_VAE = False 18 | 19 | SURFACE_LEVEL = 0.011 20 | 21 | FRAMES = SAMPLE_COUNT * TRANSITION_FRAMES 22 | progress = np.arange(FRAMES, dtype=float) / TRANSITION_FRAMES 23 | 24 | 25 | if USE_VAE: 26 | from model.autoencoder import Autoencoder, LATENT_CODE_SIZE 27 | vae = Autoencoder() 28 | vae.load() 29 | vae.eval() 30 | print("Calculating latent codes...") 31 | 32 | 33 | from datasets import VoxelDataset 34 | from torch.utils.data import DataLoader 35 | 36 | dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') 37 | dataloader = DataLoader(dataset, batch_size=1000, num_workers=8) 38 | 39 | latent_codes = torch.zeros((len(dataset), LATENT_CODE_SIZE)) 40 | 41 | with torch.no_grad(): 42 | position = 0 43 | for batch in tqdm(dataloader): 44 | latent_codes[position:position + batch.shape[0], :] = vae.encode(batch.to(device)).detach().cpu() 45 | latent_codes = latent_codes.numpy() 46 | else: 47 | from model.sdf_net import SDFNet, LATENT_CODES_FILENAME 48 | latent_codes = torch.load(LATENT_CODES_FILENAME).detach().cpu().numpy() 49 | 50 | sdf_net = SDFNet() 51 | sdf_net.load() 52 | sdf_net.eval() 53 | 54 | from shapenet_metadata import shapenet 55 | raise NotImplementedError('A labels tensor needs to be supplied here.') 56 | labels = None 57 | 58 | print("Calculating embedding...") 59 | tsne = TSNE(n_components=2) 60 | latent_codes_embedded = tsne.fit_transform(latent_codes) 61 | print("Calculating clusters...") 62 | kmeans = KMeans(n_clusters=SAMPLE_COUNT) 63 | 64 | indices = np.zeros(SAMPLE_COUNT, dtype=int) 65 | kmeans_clusters = kmeans.fit_predict(latent_codes_embedded) 66 | for i in range(SAMPLE_COUNT): 67 | center = kmeans.cluster_centers_[i, :] 68 | cluster_classes = labels[kmeans_clusters == i] 69 | cluster_class = np.bincount(cluster_classes).argmax() 70 | dist = np.linalg.norm(latent_codes_embedded - center[np.newaxis, :], axis=1) 71 | dist[labels != cluster_class] = float('inf') 72 | indices[i] = np.argmin(dist) 73 | 74 | def try_find_shortest_roundtrip(indices): 75 | best_order = indices 76 | best_distance = None 77 | for _ in range(5000): 78 | candiate = best_order.copy() 79 | a = random.randint(0, SAMPLE_COUNT-1) 80 | b = random.randint(0, SAMPLE_COUNT-1) 81 | candiate[a] = best_order[b] 82 | candiate[b] = best_order[a] 83 | dist = np.sum(np.linalg.norm(latent_codes_embedded[candiate, :] - latent_codes_embedded[np.roll(candiate, 1), 
:], axis=1)).item() 84 | if best_distance is None or dist < best_distance: 85 | best_distance = dist 86 | best_order = candiate 87 | 88 | return best_order, best_distance 89 | 90 | def find_shortest_roundtrip(indices): 91 | best_order, best_distance = try_find_shortest_roundtrip(indices) 92 | 93 | for _ in tqdm(range(100)): 94 | np.random.shuffle(indices) 95 | order, distance = try_find_shortest_roundtrip(indices) 96 | if distance < best_distance: 97 | best_order = order 98 | return best_order 99 | 100 | print("Calculating trip...") 101 | indices = find_shortest_roundtrip(indices) 102 | indices = np.concatenate((indices, indices[0][np.newaxis])) 103 | 104 | SIZE = latent_codes.shape[0] 105 | 106 | stop_latent_codes = latent_codes[indices, :] 107 | 108 | colors = np.zeros((labels.shape[0], 3)) 109 | for i in range(labels.shape[0]): 110 | colors[i, :] = shapenet.get_color(labels[i]) 111 | 112 | spline = scipy.interpolate.CubicSpline(np.arange(SAMPLE_COUNT + 1), stop_latent_codes, axis=0, bc_type='periodic') 113 | frame_latent_codes = spline(progress) 114 | 115 | color_spline = scipy.interpolate.CubicSpline(np.arange(SAMPLE_COUNT + 1), colors[indices, :], axis=0, bc_type='periodic') 116 | frame_colors = color_spline(progress) 117 | frame_colors = np.clip(frame_colors, 0, 1) 118 | 119 | frame_colors = np.zeros((progress.shape[0], 3)) 120 | for i in range(SAMPLE_COUNT): 121 | frame_colors[i*TRANSITION_FRAMES:(i+1)*TRANSITION_FRAMES, :] = np.linspace(colors[indices[i]], colors[indices[i+1]], num=TRANSITION_FRAMES) 122 | 123 | embedded_spline = scipy.interpolate.CubicSpline(np.arange(SAMPLE_COUNT + 1), latent_codes_embedded[indices, :], axis=0, bc_type='periodic') 124 | frame_latent_codes_embedded = embedded_spline(progress) 125 | frame_latent_codes_embedded[0, :] = frame_latent_codes_embedded[-1, :] 126 | 127 | width, height = 40, 40 128 | 129 | PLOT_FILE_NAME = 'tsne.png' 130 | ensure_directory('images') 131 | 132 | margin = 2 133 | range_x = (latent_codes_embedded[:, 0].min() - margin, latent_codes_embedded[:, 0].max() + margin) 134 | range_y = (latent_codes_embedded[:, 1].min() - margin, latent_codes_embedded[:, 1].max() + margin) 135 | 136 | plt.ioff() 137 | 138 | def create_plot(index, resolution=1080, filename=PLOT_FILE_NAME, dpi=100): 139 | frame_color = frame_colors[index, :] 140 | frame_color = (frame_color[0], frame_color[1], frame_color[2], 1.0) 141 | 142 | size_inches = resolution / dpi 143 | 144 | fig, ax = plt.subplots(1, figsize=(size_inches, size_inches), dpi=dpi) 145 | ax.set_position([0, 0, 1, 1]) 146 | plt.axis('off') 147 | ax.set_xlim(range_x) 148 | ax.set_ylim(range_y) 149 | 150 | ax.plot(frame_latent_codes_embedded[:, 0], frame_latent_codes_embedded[:, 1], c=(0.2, 0.2, 0.2, 1.0), zorder=1, linewidth=2) 151 | ax.scatter(latent_codes_embedded[:, 0], latent_codes_embedded[:, 1], c=colors[:SIZE], s = 10, zorder=0) 152 | ax.scatter(frame_latent_codes_embedded[index, 0], frame_latent_codes_embedded[index, 1], facecolors=frame_color, s = 200, linewidths=2, edgecolors=(0.1, 0.1, 0.1, 1.0), zorder=2) 153 | ax.scatter(latent_codes_embedded[indices, 0], latent_codes_embedded[indices, 1], facecolors=colors[indices, :], s = 140, linewidths=1, edgecolors=(0.1, 0.1, 0.1, 1.0), zorder=3) 154 | 155 | fig.savefig(filename, bbox_inches=Bbox([[0, 0], [size_inches, size_inches]]), dpi=dpi) 156 | plt.close(fig) 157 | 158 | frame_latent_codes = torch.tensor(frame_latent_codes, dtype=torch.float32, device=device) 159 | 160 | print("Rendering...") 161 | viewer = MeshRenderer(size=1080, 
start_thread=False) 162 | 163 | def render_frame(frame_index): 164 | viewer.model_color = frame_colors[frame_index, :] 165 | with torch.no_grad(): 166 | if USE_VAE: 167 | viewer.set_voxels(vae.decode(frame_latent_codes[frame_index, :])) 168 | else: 169 | viewer.set_mesh(sdf_net.get_mesh(frame_latent_codes[frame_index, :], voxel_resolution=128, sphere_only=True, level=SURFACE_LEVEL)) 170 | image_mesh = viewer.get_image(flip_red_blue=True) 171 | 172 | create_plot(frame_index) 173 | image_tsne = plt.imread(PLOT_FILE_NAME)[:, :, [2, 1, 0]] * 255 174 | 175 | image = np.concatenate((image_mesh, image_tsne), axis=1) 176 | 177 | cv2.imwrite("images/frame-{:05d}.png".format(frame_index), image) 178 | 179 | 180 | for frame_index in tqdm(range(SAMPLE_COUNT * TRANSITION_FRAMES)): 181 | render_frame(frame_index) 182 | frame_index += 1 183 | 184 | print("\n\nUse this command to create a video:\n") 185 | print('ffmpeg -framerate 30 -i images/frame-%05d.png -c:v libx264 -profile:v high -crf 19 -pix_fmt yuv420p video.mp4') -------------------------------------------------------------------------------- /model/sdf_net.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | import trimesh 3 | import skimage.measure 4 | from util import get_points_in_unit_sphere, get_voxel_coordinates 5 | import numpy as np 6 | 7 | class SDFVoxelizationHelperData(): 8 | def __init__(self, device, voxel_resolution, sphere_only=True): 9 | sample_points = get_voxel_coordinates(voxel_resolution) 10 | 11 | if sphere_only: 12 | unit_sphere_mask = np.linalg.norm(sample_points, axis=1) < 1.1 13 | sample_points = sample_points[unit_sphere_mask, :] 14 | self.unit_sphere_mask = unit_sphere_mask.reshape(voxel_resolution, voxel_resolution, voxel_resolution) 15 | 16 | self.sample_points = torch.tensor(sample_points, device=device) 17 | self.point_count = self.sample_points.shape[0] 18 | 19 | sdf_voxelization_helper = dict() 20 | 21 | SDF_NET_BREADTH = 256 22 | 23 | class SDFNet(SavableModule): 24 | def __init__(self, latent_code_size=LATENT_CODE_SIZE, device='cuda'): 25 | super(SDFNet, self).__init__(filename="sdf_net.to") 26 | self.layers1 = nn.Sequential( 27 | nn.Linear(in_features = 3 + latent_code_size, out_features = SDF_NET_BREADTH), 28 | nn.ReLU(inplace=True), 29 | 30 | nn.Linear(in_features = SDF_NET_BREADTH, out_features = SDF_NET_BREADTH), 31 | nn.ReLU(inplace=True), 32 | 33 | nn.Linear(in_features = SDF_NET_BREADTH, out_features = SDF_NET_BREADTH), 34 | nn.ReLU(inplace=True), 35 | 36 | nn.Linear(in_features = SDF_NET_BREADTH, out_features = SDF_NET_BREADTH), 37 | nn.ReLU(inplace=True) 38 | ) 39 | 40 | self.layers2 = nn.Sequential( 41 | nn.Linear(in_features = SDF_NET_BREADTH + latent_code_size + 3, out_features = SDF_NET_BREADTH), 42 | nn.ReLU(inplace=True), 43 | 44 | nn.Linear(in_features = SDF_NET_BREADTH, out_features = SDF_NET_BREADTH), 45 | nn.ReLU(inplace=True), 46 | 47 | nn.Linear(in_features = SDF_NET_BREADTH, out_features = SDF_NET_BREADTH), 48 | nn.ReLU(inplace=True), 49 | 50 | nn.Linear(in_features = SDF_NET_BREADTH, out_features = 1), 51 | nn.Tanh() 52 | ) 53 | 54 | self.to(device) 55 | 56 | def forward(self, points, latent_codes): 57 | input = torch.cat((points, latent_codes), dim=1) 58 | x = self.layers1(input) 59 | x = torch.cat((x, input), dim=1) 60 | x = self.layers2(x) 61 | return x.squeeze() 62 | 63 | def evaluate_in_batches(self, points, latent_code, batch_size=100000, return_cpu_tensor=True): 64 | latent_codes = latent_code.repeat(batch_size, 1) 65 | 
with torch.no_grad():
66 |             batch_count = points.shape[0] // batch_size
67 |             if return_cpu_tensor:
68 |                 result = torch.zeros((points.shape[0]))
69 |             else:
70 |                 result = torch.zeros((points.shape[0]), device=points.device)
71 |             for i in range(batch_count):
72 |                 result[batch_size * i:batch_size * (i+1)] = self(points[batch_size * i:batch_size * (i+1), :], latent_codes)
73 |             remainder = points.shape[0] - batch_size * batch_count
74 |             result[batch_size * batch_count:] = self(points[batch_size * batch_count:, :], latent_codes[:remainder, :])
75 |             return result
76 | 
77 |     def get_voxels(self, latent_code, voxel_resolution, sphere_only=True, pad=True):
78 |         if not (voxel_resolution, sphere_only) in sdf_voxelization_helper:
79 |             helper_data = SDFVoxelizationHelperData(self.device, voxel_resolution, sphere_only)
80 |             sdf_voxelization_helper[(voxel_resolution, sphere_only)] = helper_data
81 |         else:
82 |             helper_data = sdf_voxelization_helper[(voxel_resolution, sphere_only)]
83 | 
84 |         with torch.no_grad():
85 |             distances = self.evaluate_in_batches(helper_data.sample_points, latent_code).numpy()
86 | 
87 |         if sphere_only:
88 |             voxels = np.ones((voxel_resolution, voxel_resolution, voxel_resolution), dtype=np.float32)
89 |             voxels[helper_data.unit_sphere_mask] = distances
90 |         else:
91 |             voxels = distances.reshape(voxel_resolution, voxel_resolution, voxel_resolution)
92 |         if pad:
93 |             voxels = np.pad(voxels, 1, mode='constant', constant_values=1)
94 | 
95 |         return voxels
96 | 
97 |     def get_mesh(self, latent_code, voxel_resolution = 64, sphere_only = True, raise_on_empty=False, level=0):
98 |         size = 2
99 | 
100 |         voxels = self.get_voxels(latent_code, voxel_resolution=voxel_resolution, sphere_only=sphere_only, pad=False)
101 |         voxels = np.pad(voxels, 1, mode='constant', constant_values=1)  # pad exactly once; get_voxels would otherwise pad a second time by default
102 |         try:
103 |             vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(voxels, level=level, spacing=(size / voxel_resolution, size / voxel_resolution, size / voxel_resolution))  # renamed to marching_cubes in scikit-image >= 0.19
104 |         except ValueError as value_error:
105 |             if raise_on_empty:
106 |                 raise value_error
107 |             else:
108 |                 return None
109 | 
110 |         vertices -= size / 2
111 |         mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals)
112 |         return mesh
113 | 
114 |     def get_uniform_surface_points(self, latent_code, point_count=1000, voxel_resolution=64, sphere_only=True, level=0):
115 |         mesh = self.get_mesh(latent_code, voxel_resolution=voxel_resolution, sphere_only=sphere_only, level=level)
116 |         return mesh.sample(point_count)  # note: get_mesh returns None for empty volumes
117 | 
118 |     def get_normals(self, latent_code, points):
119 |         if latent_code.requires_grad or points.requires_grad:
120 |             raise Exception('get_normals may only be called with tensors that don\'t require grad.')
121 | 
122 |         points.requires_grad = True
123 |         latent_codes = latent_code.repeat(points.shape[0], 1)
124 |         sdf = self(points, latent_codes)
125 |         sdf.backward(torch.ones(sdf.shape[0], device=self.device))
126 |         normals = points.grad
127 |         normals /= torch.norm(normals, dim=1).unsqueeze(dim=1)
128 |         return normals
129 | 
130 |     def get_surface_points(self, latent_code, sample_size=100000, sdf_cutoff=0.1, return_normals=False, use_unit_sphere=True):
131 |         if use_unit_sphere:
132 |             points = get_points_in_unit_sphere(n=sample_size, device=self.device) * 1.1
133 |         else:
134 |             points = torch.rand((sample_size, 3), device=self.device) * 2.2 - 1
135 |         points.requires_grad = True
136 |         latent_codes = latent_code.repeat(points.shape[0], 1)
137 | 
138 |         sdf = self(points, latent_codes)
139 | 
140 |         sdf.backward(torch.ones((sdf.shape[0]), device=self.device))
141 |         normals = points.grad
142 |         normals /= torch.norm(normals, dim=1).unsqueeze(dim=1)
143 |         points.requires_grad = False
144 | 
145 |         # Move points towards surface by the amount given by the signed distance
146 |         points -= normals * sdf.unsqueeze(dim=1)
147 | 
148 |         # Discard points with truncated SDF values
149 |         mask = (torch.abs(sdf) < sdf_cutoff) & torch.all(torch.isfinite(points), dim=1)
150 |         points = points[mask, :]
151 |         normals = normals[mask, :]
152 | 
153 |         if return_normals:
154 |             return points, normals
155 |         else:
156 |             return points
157 | 
158 |     def get_surface_points_in_batches(self, latent_code, amount = 1000):
159 |         result = torch.zeros((amount, 3), device=self.device)
160 |         position = 0
161 |         iteration_limit = 20
162 |         while position < amount and iteration_limit > 0:
163 |             points = self.get_surface_points(latent_code, sample_size=amount * 6)
164 |             amount_used = min(amount - position, points.shape[0])
165 |             result[position:position+amount_used, :] = points[:amount_used, :]
166 |             position += amount_used
167 |             iteration_limit -= 1
168 |         return result
169 | 
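The surface sampling helpers above can be used on their own, for example for visualization or evaluation. A minimal sketch, assuming a trained `sdf_net.to` checkpoint and the learned latent code file exist, and that a CUDA device is available (the models in this repository are hard-wired to the GPU):

    import torch
    from model.sdf_net import SDFNet, LATENT_CODES_FILENAME

    sdf_net = SDFNet()
    sdf_net.load()
    sdf_net.eval()
    latent_codes = torch.load(LATENT_CODES_FILENAME).detach()

    # Project random points onto the zero level set and keep their surface normals.
    points, normals = sdf_net.get_surface_points(latent_codes[0, :], sample_size=50000, return_normals=True)
    print(points.shape[0], 'surface points extracted')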
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Generative Adversarial Networks and Autoencoders for 3D Shapes
2 | 
3 | ![Shapes generated with our proposed GAN architecture and reconstructed using Marching Cubes](https://raw.githubusercontent.com/marian42/shapegan/master/examples/gan_shapes.png)
4 | 
5 | This repository provides code for the paper "[Adversarial Generation of Continuous Implicit Shape
6 | Representations](https://arxiv.org/abs/2002.00349)" and for my master thesis about generative machine learning models for 3D shapes.
7 | It contains:
8 | 
9 | - the networks proposed in the paper (GANs with a DeepSDF network as the generator and a 3D CNN or Pointnet as the discriminator)
10 | - an autoencoder, variational autoencoder and GANs for SDF voxel volumes using [3D CNNs](http://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf)
11 | - an implementation of [the DeepSDF autodecoder](https://arxiv.org/pdf/1901.05103.pdf) that learns implicit function representations of 3D shapes
12 | - a GAN that uses a DeepSDF network as the generator and a 3D CNN as the discriminator ("Hybrid GAN", as proposed in the paper, but without progressive growing and without gradient penalty)
13 | - a data preparation pipeline that can prepare SDF datasets from triangle meshes, such as the [Shapenet dataset](https://www.shapenet.org/) (based on my [mesh_to_sdf](https://github.com/marian42/mesh_to_sdf) project)
14 | - a [ray marching](http://jamie-wong.com/2016/07/15/ray-marching-signed-distance-functions/) renderer to render signed distance fields given by a neural network, as well as a classic rasterized renderer to render triangle meshes reconstructed with Marching Cubes
15 | - tools to visualize the results
16 | 
17 | Note that although the code provided here works, most of the scripts need some configuration to work for a specific task.
18 | 
19 | This project uses two different ways to represent 3D shapes.
20 | These representations are *voxel volumes* and *implicit functions*.
21 | Both use [signed distances](https://en.wikipedia.org/wiki/Signed_distance_function), as the sketch below illustrates.
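To make the distinction concrete, here is a small self-contained sketch (illustrative only, not taken from this codebase) that represents the same sphere both ways:

    import numpy as np

    # Signed distance to a sphere of radius 0.5: negative inside, positive outside.
    def sphere_sdf(points):
        return np.linalg.norm(points, axis=-1) - 0.5

    # Implicit representation: the SDF can be queried at any continuous position.
    print(sphere_sdf(np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])))  # [-0.5  0.5]

    # Voxel representation: the same SDF sampled once on a fixed 32^3 grid.
    axis = np.linspace(-1, 1, 32)
    grid = np.stack(np.meshgrid(axis, axis, axis, indexing='ij'), axis=-1)
    voxels = sphere_sdf(grid)  # array of shape (32, 32, 32)

The voxel volume is cheap to feed to a 3D CNN but fixed in resolution, while the implicit function can be evaluated at arbitrary points, which is what the DeepSDF-style networks in this repository learn to do.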
22 | 
23 | For both representations, there are networks that learn latent embeddings and then reconstruct objects from latent codes.
24 | These are the *autoencoder* and *variational autoencoder* for voxel volumes and the [*autodecoder* for the DeepSDF network](https://arxiv.org/pdf/1901.05103.pdf).
25 | 
26 | In addition, for both representations, there are *generative adversarial networks* that learn to generate novel objects from random latent codes.
27 | The GANs come in a classic and a Wasserstein flavor.
28 | 
29 | 
30 | # Reproducing the paper
31 | 
32 | This section explains how to reproduce the paper "Adversarial Generation of Continuous Implicit Shape Representations".
33 | 
34 | ## Data preparation
35 | 
36 | To train the model, the meshes in the Shapenet dataset need to be voxelized for the voxel-based approach and converted to SDF point clouds for the point-based approach.
37 | 
38 | We provide readily prepared datasets for the Chairs, Airplanes and Sofas categories of Shapenet as a [download](https://ls7-data.cs.tu-dortmund.de/shape_net/ShapeNet_SDF.tar.gz).
39 | The size of that dataset is 71 GB.
40 | 
41 | To prepare the data yourself, follow these steps:
42 | 
43 | 1. Install the `mesh_to_sdf` pip module.
44 | 2. Download the Shapenet files to the `data/shapenet/` directory or create an equivalent symlink.
45 | 3. Review the settings at the top of `prepare_shapenet_dataset.py`.
46 | The default settings are configured for reproducing the GAN paper, so you shouldn't need to change anything.
47 | You can change the dataset category that will be prepared; the default is the chairs category.
48 | You can disable preparation of either the voxel or point datasets if you don't need both.
49 | 4. Run `prepare_shapenet_dataset.py`.
50 | You can stop and resume this script and it will continue where it left off.
51 | 
52 | ## Training
53 | 
54 | ### Voxel-based discriminator
55 | 
56 | To train the GAN with the 3D CNN discriminator, run
57 | 
58 | python3 train_hybrid_progressive_gan.py iteration=0
59 | python3 train_hybrid_progressive_gan.py iteration=1
60 | python3 train_hybrid_progressive_gan.py iteration=2
61 | python3 train_hybrid_progressive_gan.py iteration=3
62 | 
63 | This runs the four steps of progressive growing.
64 | Each iteration starts from the result of the previous iteration, or, if the `continue` parameter is supplied, from the most recent result of the current iteration.
65 | Add the `nogui` parameter to disable the model viewer during training.
66 | This parameter should be used when the script is run remotely.
67 | 
68 | ### Point-based discriminator
69 | 
70 | TODO
71 | 
72 | Note that the pointnet-based approach currently has a separate implementation of the generator and doesn't work with the visualization scripts provided here.
73 | The two implementations will be merged soon so that the demos work.
74 | 
75 | ## Use pretrained generator models
76 | 
77 | In the `examples` directory, you find network parameters for the GAN generators trained on chairs, airplanes and sofas with the 3D CNN discriminator.
78 | You can use these by loading the generator from these files, e.g. in `demo_sdf_net.py` you can change `sdf_net.filename` accordingly (see the example at the end of this README).
79 | 
80 | TODO: Examples for the pointnet-based GANs will be added soon.
81 | 
82 | # Running other 3D deep learning models
83 | 
84 | ## Data preparation
85 | 
86 | Two data preparation scripts are available: `prepare_shapenet_dataset.py` is configured to work specifically with the Shapenet dataset.
87 | `prepare_data.py` can be used with any folder of 3D meshes.
88 | Both need to be configured depending on what data you want to prepare.
89 | Most of the time, not all types of data need to be prepared.
90 | For the DeepSDF network, you need SDF clouds.
91 | For the remaining networks, you need voxels of resolution 32.
92 | The "uniform" and "surface" datasets, as well as the voxels of other resolutions, are only needed for the GAN paper (see the section above).
93 | 
94 | ## Training
95 | 
96 | Run any of the scripts that start with `train_` to train the networks.
97 | The `train_autoencoder.py` script trains the variational autoencoder, unless the `classic` argument is supplied.
98 | All training scripts take these command line arguments:
99 | - `continue` to load existing parameters
100 | - `nogui` to not show the model viewer, which is useful for VMs
101 | - `show_slice` to show a text representation of the learned shape
102 | 
103 | Progress is saved after each epoch.
104 | There is no stopping criterion.
105 | The longer you train, the better the result.
106 | You should have at least 8 GB of GPU RAM available.
107 | Use a datacenter GPU; training on a desktop GPU will take several days to get good results.
108 | The classifiers take the least time to train and the GANs take the most time.
109 | 
110 | ## Visualization
111 | 
112 | To visualize the results, run any of the scripts starting with `demo_`.
113 | They might need to be configured depending on the model that was trained and the visualizations needed.
114 | The `create_plot.py` script contains code to generate figures for my thesis.
115 | 
116 | ## Using the pretrained DeepSDF model and recreating the latent space traversal animation
117 | 
118 | This section explains how to get a DeepSDF network model that was pre-trained on the Shapenet dataset and how to use it to recreate [this latent space traversal animation](https://twitter.com/marian42_/status/1188969971898048512).
119 | 
120 | Some network parameters have changed since this model was trained.
121 | If you're training a new model, you can use the parameters on the master branch and it will work as well.
122 | To be compatible with the pretrained model, you'll need the changes in the [`pretrained-deepsdf-shapenet`](https://github.com/marian42/shapegan/tree/pretrained-deepsdf-shapenet) branch.
123 | 
124 | To generate the latent space animation, follow these steps:
125 | 
126 | 1. Switch to the [`pretrained-deepsdf-shapenet`](https://github.com/marian42/shapegan/tree/pretrained-deepsdf-shapenet) branch.
127 | 
128 | 2. Move the contents of the `examples/deepsdf-shapenet-pretrained` directory to the project root directory.
129 | The scripts will look for the .to files in `/models` and `/data` relative to the project root.
130 | 
131 | 3. Run `python3 demo_latent_space.py`.
132 | This takes about 40 minutes on my machine.
133 | To make it faster, you can lower the values of `SAMPLE_COUNT` and `TRANSITION_FRAMES` in `demo_latent_space.py`.
134 | 
135 | 4. To render a video file from the frames, run `ffmpeg -framerate 30 -i images/frame-%05d.png -c:v libx264 -profile:v high -crf 19 -pix_fmt yuv420p video.mp4`.
136 | 
137 | Note that after completing steps 1 and 2, you can run `python3 demo_sdf_net.py` to show a real-time latent space interpolation.
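## Example: loading a pretrained generator

As a concrete version of the "Use pretrained generator models" section above, the following sketch (not part of the repository) loads one of the checkpoints from the `examples` directory and reconstructs a mesh from a random latent code. It assumes that `SDFNet.load()` uses the path stored in `sdf_net.filename` as-is; if your setup resolves filenames relative to the `models` directory, copy the `.to` file there first and set `filename` to the bare file name.

    import torch
    from model import LATENT_CODE_SIZE
    from model.sdf_net import SDFNet
    from util import device, standard_normal_distribution

    sdf_net = SDFNet()
    sdf_net.filename = 'examples/gan_generator_voxels_chairs.to'  # or sofas / airplanes
    sdf_net.load()
    sdf_net.eval()

    # Sample a random latent code and extract a mesh with Marching Cubes.
    latent_code = standard_normal_distribution.sample((LATENT_CODE_SIZE,)).to(device)
    mesh = sdf_net.get_mesh(latent_code, voxel_resolution=64)  # trimesh.Trimesh, or None if the SDF has no surface
    if mesh is not None:
        mesh.export('generated_chair.obj')

The same pattern works for the sofa and airplane checkpoints in the `examples` directory.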
-------------------------------------------------------------------------------- /rendering/raymarching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import math 5 | from tqdm import tqdm 6 | from PIL import Image 7 | import os 8 | 9 | from model.sdf_net import SDFNet, LATENT_CODES_FILENAME 10 | from util import device, ensure_directory 11 | from rendering.math import get_camera_transform 12 | from scipy.spatial.transform import Rotation 13 | 14 | BATCH_SIZE = 100000 15 | 16 | def get_default_coordinates(): 17 | camera_transform = get_camera_transform(2.2, 147, 20) 18 | camera_position = np.matmul(np.linalg.inv(camera_transform), np.array([0, 0, 0, 1]))[:3] 19 | light_matrix = get_camera_transform(6, 164, 50) 20 | light_position = np.matmul(np.linalg.inv(light_matrix), np.array([0, 0, 0, 1]))[:3] 21 | return camera_position, light_position 22 | 23 | camera_position, light_position = get_default_coordinates() 24 | 25 | def get_normals(sdf_net, points, latent_code): 26 | batch_count = points.shape[0] // BATCH_SIZE 27 | result = torch.zeros((points.shape[0], 3), device=points.device) 28 | for i in range(batch_count): 29 | result[BATCH_SIZE * i:BATCH_SIZE * (i+1), :] = sdf_net.get_normals(latent_code, points[BATCH_SIZE * i:BATCH_SIZE * (i+1), :]) 30 | 31 | if points.shape[0] > BATCH_SIZE * batch_count: 32 | result[BATCH_SIZE * batch_count:, :] = sdf_net.get_normals(latent_code, points[BATCH_SIZE * batch_count:, :]) 33 | return result 34 | 35 | 36 | def get_shadows(sdf_net, points, light_position, latent_code, threshold = 0.001, sdf_offset=0, radius=1.0): 37 | ray_directions = light_position[np.newaxis, :] - points 38 | ray_directions /= np.linalg.norm(ray_directions, axis=1)[:, np.newaxis] 39 | ray_directions_t = torch.tensor(ray_directions, device=device, dtype=torch.float32) 40 | points = torch.tensor(points, device=device, dtype=torch.float32) 41 | 42 | points += ray_directions_t * 0.1 43 | 44 | indices = torch.arange(points.shape[0]) 45 | shadows = torch.zeros(points.shape[0]) 46 | 47 | for i in tqdm(range(200)): 48 | test_points = points[indices, :] 49 | sdf = sdf_net.evaluate_in_batches(test_points, latent_code, return_cpu_tensor=False) + sdf_offset 50 | sdf = torch.clamp_(sdf, -0.1, 0.1) 51 | points[indices, :] += ray_directions_t[indices, :] * sdf.unsqueeze(1) 52 | 53 | hits = (sdf > 0) & (sdf < threshold) 54 | shadows[indices[hits]] = 1 55 | indices = indices[~hits] 56 | 57 | misses = points[indices, 1] > radius 58 | indices = indices[~misses] 59 | 60 | if indices.shape[0] < 2: 61 | break 62 | 63 | shadows[indices] = 1 64 | return shadows.cpu().numpy() 65 | 66 | 67 | def render_image(sdf_net, latent_code, resolution=800, threshold=0.0005, sdf_offset=0, iterations=1000, ssaa=2, radius=1.0, crop=False, color=(0.8, 0.1, 0.1), vertical_cutoff=None): 68 | camera_forward = camera_position / np.linalg.norm(camera_position) * -1 69 | camera_distance = np.linalg.norm(camera_position).item() 70 | up = np.array([0, 1, 0]) 71 | camera_right = np.cross(camera_forward, up) 72 | camera_right /= np.linalg.norm(camera_right) 73 | camera_up = np.cross(camera_forward, camera_right) 74 | camera_up /= np.linalg.norm(camera_up) 75 | 76 | screenspace_points = np.meshgrid( 77 | np.linspace(-1, 1, resolution * ssaa), 78 | np.linspace(-1, 1, resolution * ssaa), 79 | ) 80 | screenspace_points = np.stack(screenspace_points) 81 | screenspace_points = screenspace_points.reshape(2, -1).transpose() 82 | 83 | points = 
np.tile(camera_position, (screenspace_points.shape[0], 1)) 84 | points = points.astype(np.float32) 85 | 86 | focal_distance = 1.0 / math.tan(math.asin(radius / camera_distance)) 87 | ray_directions = screenspace_points[:, 0] * camera_right[:, np.newaxis] \ 88 | + screenspace_points[:, 1] * camera_up[:, np.newaxis] \ 89 | + focal_distance * camera_forward[:, np.newaxis] 90 | ray_directions = ray_directions.transpose().astype(np.float32) 91 | ray_directions /= np.linalg.norm(ray_directions, axis=1)[:, np.newaxis] 92 | 93 | b = np.einsum('ij,ij->i', points, ray_directions) * 2 94 | c = np.dot(camera_position, camera_position) - radius * radius 95 | distance_to_sphere = (-b - np.sqrt(np.power(b, 2) - 4 * c)) / 2 96 | indices = np.argwhere(np.isfinite(distance_to_sphere)).reshape(-1) 97 | 98 | points[indices] += ray_directions[indices] * distance_to_sphere[indices, np.newaxis] 99 | 100 | points = torch.tensor(points, device=device, dtype=torch.float32) 101 | ray_directions_t = torch.tensor(ray_directions, device=device, dtype=torch.float32) 102 | 103 | indices = torch.tensor(indices, device=device, dtype=torch.int64) 104 | model_mask = torch.zeros(points.shape[0], dtype=torch.uint8) 105 | 106 | for i in tqdm(range(iterations)): 107 | test_points = points[indices, :] 108 | sdf = sdf_net.evaluate_in_batches(test_points, latent_code, return_cpu_tensor=False) + sdf_offset 109 | torch.clamp_(sdf, -0.02, 0.02) 110 | points[indices, :] += ray_directions_t[indices, :] * sdf.unsqueeze(1) 111 | 112 | hits = (sdf > 0) & (sdf < threshold) 113 | model_mask[indices[hits]] = 1 114 | indices = indices[~hits] 115 | 116 | misses = torch.norm(points[indices, :], dim=1) > radius 117 | indices = indices[~misses] 118 | 119 | if indices.shape[0] < 2: 120 | break 121 | 122 | model_mask[indices] = 1 123 | 124 | if vertical_cutoff is not None: 125 | model_mask[points[:, 1] > vertical_cutoff] = 0 126 | model_mask[points[:, 1] < -vertical_cutoff] = 0 127 | 128 | normal = get_normals(sdf_net, points[model_mask], latent_code).cpu().numpy() 129 | 130 | model_mask = model_mask.cpu().numpy().astype(bool) 131 | points = points.cpu().numpy() 132 | model_points = points[model_mask] 133 | 134 | seen_by_light = 1.0 - get_shadows(sdf_net, model_points, light_position, latent_code, radius=radius, sdf_offset=sdf_offset) 135 | 136 | light_direction = light_position[np.newaxis, :] - model_points 137 | light_direction /= np.linalg.norm(light_direction, axis=1)[:, np.newaxis] 138 | 139 | diffuse = np.einsum('ij,ij->i', light_direction, normal) 140 | diffuse = np.clip(diffuse, 0, 1) * seen_by_light 141 | 142 | reflect = light_direction - np.einsum('ij,ij->i', light_direction, normal)[:, np.newaxis] * normal * 2 143 | reflect /= np.linalg.norm(reflect, axis=1)[:, np.newaxis] 144 | specular = np.einsum('ij,ij->i', reflect, ray_directions[model_mask, :]) 145 | specular = np.clip(specular, 0.0, 1.0) 146 | specular = np.power(specular, 20) * seen_by_light 147 | rim_light = -np.einsum('ij,ij->i', normal, ray_directions[model_mask, :]) 148 | rim_light = 1.0 - np.clip(rim_light, 0, 1) 149 | rim_light = np.power(rim_light, 4) * 0.3 150 | 151 | color = np.array(color)[np.newaxis, :] * (diffuse * 0.5 + 0.5)[:, np.newaxis] 152 | color += (specular * 0.3 + rim_light)[:, np.newaxis] 153 | 154 | color = np.clip(color, 0, 1) 155 | 156 | ground_points = ray_directions[:, 1] < 0 157 | ground_points[model_mask] = 0 158 | ground_points = np.argwhere(ground_points).reshape(-1) 159 | ground_plane = np.min(model_points[:, 1]).item() 160 | points[ground_points, 
:] -= ray_directions[ground_points, :] * ((points[ground_points, 1] - ground_plane) / ray_directions[ground_points, 1])[:, np.newaxis] 161 | ground_points = ground_points[np.linalg.norm(points[ground_points, ::2], axis=1) < 3] 162 | 163 | ground_shadows = get_shadows(sdf_net, points[ground_points, :], light_position, latent_code, sdf_offset=sdf_offset) 164 | 165 | pixels = np.ones((points.shape[0], 3)) 166 | pixels[model_mask] = color 167 | pixels[ground_points] -= ((1.0 - 0.65) * ground_shadows)[:, np.newaxis] 168 | pixels = pixels.reshape((resolution * ssaa, resolution * ssaa, 3)) 169 | 170 | if crop: 171 | from util import crop_image 172 | pixels = crop_image(pixels, background=1) 173 | 174 | image = Image.fromarray(np.uint8(pixels * 255) , 'RGB') 175 | 176 | if ssaa != 1: 177 | image = image.resize((resolution, resolution), Image.ANTIALIAS) 178 | 179 | return image 180 | 181 | 182 | def render_image_for_index(sdf_net, latent_codes, index, crop=False, resolution=800): 183 | ensure_directory('screenshots') 184 | FILENAME = 'screenshots/raymarching-examples/image-{:d}-{:d}.png' 185 | filename = FILENAME.format(index, resolution) 186 | 187 | if os.path.isfile(filename): 188 | return Image.open(filename) 189 | 190 | img = render_image(sdf_net, latent_codes[index], resolution=resolution, crop=crop) 191 | img.save(filename) 192 | return img -------------------------------------------------------------------------------- /prepare_shapenet_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | # Enable this when running on a computer without a screen: 3 | # os.environ['PYOPENGL_PLATFORM'] = 'egl' 4 | import trimesh 5 | from tqdm import tqdm 6 | import numpy as np 7 | from util import ensure_directory 8 | from multiprocessing import Pool 9 | import traceback 10 | from mesh_to_sdf import get_surface_point_cloud,scale_to_unit_cube, scale_to_unit_sphere, BadMeshException 11 | 12 | DATASET_NAME = 'chairs' 13 | DIRECTORY_MODELS = 'data/shapenet/03001627' 14 | MODEL_EXTENSION = '.obj' 15 | DIRECTORY_VOXELS = 'data/{:s}/voxels_{{:d}}/'.format(DATASET_NAME) 16 | DIRECTORY_UNIFORM = 'data/{:s}/uniform/'.format(DATASET_NAME) 17 | DIRECTORY_SURFACE = 'data/{:s}/surface/'.format(DATASET_NAME) 18 | DIRECTORY_SDF_CLOUD = 'data/{:s}/cloud/'.format(DATASET_NAME) 19 | DIRECTORY_BAD_MESHES = 'data/{:s}/bad_meshes/'.format(DATASET_NAME) 20 | 21 | # Voxel resolutions to create. 22 | # Set to [] if no voxels are needed. 
23 | # Set to [32] for all models except for the progressively growing DeepSDF/Voxel GAN
24 | VOXEL_RESOLUTIONS = [8, 16, 32, 64]
25 | 
26 | CREATE_SDF_CLOUDS = False # For DeepSDF autodecoder, contains uniformly and non-uniformly sampled points as proposed in the DeepSDF paper
27 | CREATE_UNIFORM_AND_SURFACE = True # Uniformly sampled points for the Pointnet-based GAN and surface point clouds for the pointnet-based GAN with refinement
28 | 
29 | SDF_POINT_CLOUD_SIZE = 200000 # For DeepSDF point clouds (CREATE_SDF_CLOUDS)
30 | POINT_CLOUD_SAMPLE_SIZE = 64**3 # For uniform and surface points (CREATE_UNIFORM_AND_SURFACE)
31 | 
32 | # Options for virtual scans used to generate SDFs
33 | USE_DEPTH_BUFFER = True
34 | SCAN_COUNT = 50
35 | SCAN_RESOLUTION = 1024
36 | 
37 | def get_model_files():
38 | for directory, _, files in os.walk(DIRECTORY_MODELS):
39 | for filename in files:
40 | if filename.endswith(MODEL_EXTENSION):
41 | yield os.path.join(directory, filename)
42 | 
43 | def get_hash(filename):
44 | return filename.split('/')[-3]
45 | 
46 | def get_voxel_filename(model_filename, resolution):
47 | return os.path.join(DIRECTORY_VOXELS.format(resolution), get_hash(model_filename) + '.npy')
48 | 
49 | def get_uniform_filename(model_filename):
50 | return os.path.join(DIRECTORY_UNIFORM, get_hash(model_filename) + '.npy')
51 | 
52 | def get_surface_filename(model_filename):
53 | return os.path.join(DIRECTORY_SURFACE, get_hash(model_filename) + '.npy')
54 | 
55 | def get_sdf_cloud_filename(model_filename):
56 | return os.path.join(DIRECTORY_SDF_CLOUD, get_hash(model_filename) + '.npy')
57 | 
58 | def get_bad_mesh_filename(model_filename):
59 | return os.path.join(DIRECTORY_BAD_MESHES, get_hash(model_filename))
60 | 
61 | def mark_bad_mesh(model_filename):
62 | filename = get_bad_mesh_filename(model_filename)
63 | ensure_directory(os.path.dirname(filename))
64 | open(filename, 'w').close()
65 | 
66 | def is_bad_mesh(model_filename):
67 | return os.path.exists(get_bad_mesh_filename(model_filename))
68 | 
69 | def get_uniform_and_surface_points(surface_point_cloud, number_of_points = 200000):
70 | unit_sphere_points = np.random.uniform(-1, 1, size=(number_of_points * 2, 3)).astype(np.float32)
71 | unit_sphere_points = unit_sphere_points[np.linalg.norm(unit_sphere_points, axis=1) < 1]
72 | uniform_points = unit_sphere_points[:number_of_points, :]
73 | 
74 | distances, indices = surface_point_cloud.kd_tree.query(uniform_points)
75 | uniform_sdf = distances.astype(np.float32).reshape(-1) * -1
76 | uniform_sdf[surface_point_cloud.is_outside(uniform_points)] *= -1
77 | 
78 | surface_points = surface_point_cloud.points[indices[:, 0], :]
79 | near_surface_points = surface_points + np.random.normal(scale=0.0025, size=surface_points.shape).astype(np.float32)
80 | near_surface_sdf = surface_point_cloud.get_sdf(near_surface_points, use_depth_buffer=USE_DEPTH_BUFFER)
81 | 
82 | model_size = np.count_nonzero(uniform_sdf < 0) / number_of_points
83 | if model_size < 0.01:
84 | raise BadMeshException()
85 | 
86 | return uniform_points, uniform_sdf, near_surface_points, near_surface_sdf
87 | 
88 | def process_model_file(filename):
89 | try:
90 | if is_bad_mesh(filename):
91 | return
92 | 
93 | mesh = trimesh.load(filename)
94 | 
95 | voxel_filenames = [get_voxel_filename(filename, resolution) for resolution in VOXEL_RESOLUTIONS]
96 | if not all(os.path.exists(f) for f in voxel_filenames):
97 | mesh_unit_cube = scale_to_unit_cube(mesh)
98 | surface_point_cloud = get_surface_point_cloud(mesh_unit_cube, bounding_radius=3**0.5, 
scan_count=SCAN_COUNT, scan_resolution=SCAN_RESOLUTION) 99 | try: 100 | for resolution in VOXEL_RESOLUTIONS: 101 | voxels = surface_point_cloud.get_voxels(resolution, use_depth_buffer=USE_DEPTH_BUFFER, check_result=True) 102 | np.save(get_voxel_filename(filename, resolution), voxels) 103 | del voxels 104 | 105 | except BadMeshException: 106 | tqdm.write("Skipping bad mesh. ({:s})".format(get_hash(filename))) 107 | mark_bad_mesh(filename) 108 | return 109 | del mesh_unit_cube, surface_point_cloud 110 | 111 | 112 | create_uniform_and_surface = CREATE_UNIFORM_AND_SURFACE and (not os.path.exists(get_uniform_filename(filename)) or not os.path.exists(get_surface_filename(filename))) 113 | create_sdf_clouds = CREATE_SDF_CLOUDS and not os.path.exists(get_sdf_cloud_filename(filename)) 114 | 115 | if create_uniform_and_surface or create_sdf_clouds: 116 | mesh_unit_sphere = scale_to_unit_sphere(mesh) 117 | surface_point_cloud = get_surface_point_cloud(mesh_unit_sphere, bounding_radius=1, scan_count=SCAN_COUNT, scan_resolution=SCAN_RESOLUTION) 118 | try: 119 | if create_uniform_and_surface: 120 | uniform_points, uniform_sdf, near_surface_points, near_surface_sdf = get_uniform_and_surface_points(surface_point_cloud, number_of_points=POINT_CLOUD_SAMPLE_SIZE) 121 | 122 | combined_uniform = np.concatenate((uniform_points, uniform_sdf[:, np.newaxis]), axis=1) 123 | np.save(get_uniform_filename(filename), combined_uniform) 124 | 125 | combined_surface = np.concatenate((near_surface_points, near_surface_sdf[:, np.newaxis]), axis=1) 126 | np.save(get_surface_filename(filename), combined_surface) 127 | 128 | if create_sdf_clouds: 129 | sdf_points, sdf_values = surface_point_cloud.sample_sdf_near_surface(use_scans=True, sign_method='depth' if USE_DEPTH_BUFFER else 'normal', number_of_points=SDF_POINT_CLOUD_SIZE, min_size=0.015) 130 | combined = np.concatenate((sdf_points, sdf_values[:, np.newaxis]), axis=1) 131 | np.save(get_sdf_cloud_filename(filename), combined) 132 | except BadMeshException: 133 | tqdm.write("Skipping bad mesh. 
({:s})".format(get_hash(filename))) 134 | mark_bad_mesh(filename) 135 | return 136 | del mesh_unit_sphere, surface_point_cloud 137 | 138 | except: 139 | traceback.print_exc() 140 | 141 | 142 | def process_model_files(): 143 | for res in VOXEL_RESOLUTIONS: 144 | ensure_directory(DIRECTORY_VOXELS.format(res)) 145 | if CREATE_UNIFORM_AND_SURFACE: 146 | ensure_directory(DIRECTORY_UNIFORM) 147 | ensure_directory(DIRECTORY_SURFACE) 148 | if CREATE_SDF_CLOUDS: 149 | ensure_directory(DIRECTORY_SDF_CLOUD) 150 | ensure_directory(DIRECTORY_BAD_MESHES) 151 | 152 | files = list(get_model_files()) 153 | 154 | worker_count = os.cpu_count() // 2 155 | print("Using {:d} processes.".format(worker_count)) 156 | pool = Pool(worker_count) 157 | 158 | progress = tqdm(total=len(files)) 159 | def on_complete(*_): 160 | progress.update() 161 | 162 | for filename in files: 163 | pool.apply_async(process_model_file, args=(filename,), callback=on_complete) 164 | pool.close() 165 | pool.join() 166 | 167 | def combine_sdf_clouds(): 168 | import torch 169 | print("Combining SDF point clouds...") 170 | 171 | files = list(sorted(get_model_files())) 172 | files = [f for f in files if os.path.exists(get_sdf_cloud_filename(f))] 173 | 174 | N = len(files) 175 | points = torch.zeros((N * SDF_POINT_CLOUD_SIZE, 3)) 176 | sdf = torch.zeros((N * SDF_POINT_CLOUD_SIZE)) 177 | position = 0 178 | 179 | for file_name in tqdm(files): 180 | numpy_array = np.load(get_sdf_cloud_filename(file_name)) 181 | points[position * SDF_POINT_CLOUD_SIZE:(position + 1) * SDF_POINT_CLOUD_SIZE, :] = torch.tensor(numpy_array[:, :3]) 182 | sdf[position * SDF_POINT_CLOUD_SIZE:(position + 1) * SDF_POINT_CLOUD_SIZE] = torch.tensor(numpy_array[:, 3]) 183 | del numpy_array 184 | position += 1 185 | 186 | print("Saving combined SDF clouds...") 187 | torch.save(points, os.path.join('data', 'sdf_points.to')) 188 | torch.save(sdf, os.path.join('data', 'sdf_values.to')) 189 | 190 | if __name__ == '__main__': 191 | process_model_files() 192 | if CREATE_SDF_CLOUDS: 193 | combine_sdf_clouds() -------------------------------------------------------------------------------- /train_hybrid_progressive_gan.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import numpy as np 7 | import torch.autograd as autograd 8 | 9 | import random 10 | import time 11 | import sys 12 | from collections import deque 13 | from tqdm import tqdm 14 | 15 | from model.sdf_net import SDFNet 16 | from model.progressive_gan import Discriminator, LATENT_CODE_SIZE, RESOLUTIONS 17 | from util import create_text_slice, device, standard_normal_distribution, get_voxel_coordinates 18 | 19 | SDF_CLIPPING = 0.1 20 | from util import create_text_slice 21 | from datasets import VoxelDataset 22 | from torch.utils.data import DataLoader 23 | 24 | def get_parameter(name, default): 25 | for arg in sys.argv: 26 | if arg.startswith(name + '='): 27 | return arg[len(name) + 1:] 28 | return default 29 | 30 | 31 | ITERATION = int(get_parameter('iteration', 0)) 32 | # Continue with model parameters that were previously trained at the SAME iteration 33 | # Otherwise, it will use the model parameters of the previous iteration or initialize randomly at iteration 0 34 | CONTINUE = "continue" in sys.argv 35 | 36 | FADE_IN_EPOCHS = 10 37 | BATCH_SIZE = 16 38 | GRADIENT_PENALTY_WEIGHT = 10 39 | NUMBER_OF_EPOCHS = int(get_parameter('epochs', 250)) 40 | 41 | VOXEL_RESOLUTION = 
RESOLUTIONS[ITERATION] 42 | 43 | dataset = VoxelDataset.from_split('data/chairs/voxels_{:d}/{{:s}}.npy'.format(VOXEL_RESOLUTION), 'data/chairs/train.txt') 44 | data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4) 45 | 46 | def get_generator_filename(iteration): 47 | return 'hybrid_progressive_gan_generator_{:d}.to'.format(iteration) 48 | 49 | generator = SDFNet(device='cpu') 50 | discriminator = Discriminator() 51 | if not CONTINUE and ITERATION > 0: 52 | generator.filename = get_generator_filename(ITERATION - 1) 53 | generator.load() 54 | discriminator.set_iteration(ITERATION - 1) 55 | discriminator.load() 56 | discriminator.set_iteration(ITERATION) 57 | generator.filename = get_generator_filename(ITERATION) 58 | if CONTINUE: 59 | generator.load() 60 | discriminator.load() 61 | 62 | if torch.cuda.device_count() > 1: 63 | print("Using dataparallel with {:d} GPUs.".format(torch.cuda.device_count())) 64 | generator_parallel = nn.DataParallel(generator) 65 | discriminator_parallel = nn.DataParallel(discriminator) 66 | else: 67 | generator_parallel = generator 68 | discriminator_parallel = discriminator 69 | 70 | generator_parallel.to(device) 71 | discriminator_parallel.to(device) 72 | 73 | LOG_FILE_NAME = "plots/hybrid_gan_training_{:d}.csv".format(ITERATION) 74 | first_epoch = 0 75 | if 'continue' in sys.argv: 76 | log_file_contents = open(LOG_FILE_NAME, 'r').readlines() 77 | first_epoch = len(log_file_contents) 78 | 79 | log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w") 80 | 81 | generator_optimizer = optim.RMSprop(generator_parallel.parameters(), lr=0.0001) 82 | discriminator_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001) 83 | 84 | show_viewer = "nogui" not in sys.argv 85 | 86 | if show_viewer: 87 | from rendering import MeshRenderer 88 | viewer = MeshRenderer() 89 | 90 | def sample_latent_codes(current_batch_size): 91 | latent_codes = standard_normal_distribution.sample(sample_shape=[current_batch_size, LATENT_CODE_SIZE]).to(device) 92 | latent_codes = latent_codes.repeat((1, 1, grid_points.shape[0])).reshape(-1, LATENT_CODE_SIZE) 93 | return latent_codes 94 | 95 | grid_points = get_voxel_coordinates(VOXEL_RESOLUTION, return_torch_tensor=True) 96 | grid_points_default_batch = grid_points.repeat((BATCH_SIZE, 1)) 97 | 98 | history_fake = deque(maxlen=50) 99 | history_real = deque(maxlen=50) 100 | history_gradient_penalty = deque(maxlen=50) 101 | 102 | def get_gradient_penalty(real_sample, fake_sample): 103 | alpha = torch.rand((real_sample.shape[0], 1, 1, 1), device=device).expand(real_sample.shape) 104 | 105 | interpolated_sample = alpha * real_sample + ((1 - alpha) * fake_sample) 106 | interpolated_sample.requires_grad = True 107 | 108 | discriminator_output = discriminator_parallel(interpolated_sample) 109 | 110 | gradients = autograd.grad(outputs=discriminator_output, inputs=interpolated_sample, grad_outputs=torch.ones(discriminator_output.shape).to(device), create_graph=True, retain_graph=True, only_inputs=True)[0] 111 | return ((gradients.norm(2, dim=(1,2,3)) - 1) ** 2).mean() * GRADIENT_PENALTY_WEIGHT 112 | 113 | def train(): 114 | progress = tqdm(total=NUMBER_OF_EPOCHS * (len(dataset) // BATCH_SIZE + 1), initial=first_epoch * (len(dataset) // BATCH_SIZE + 1)) 115 | 116 | for epoch in range(first_epoch, NUMBER_OF_EPOCHS): 117 | progress.desc = 'Epoch {:d}/{:d} ({:d}³)'.format(epoch, NUMBER_OF_EPOCHS, VOXEL_RESOLUTION) 118 | batch_index = 0 119 | epoch_start_time = time.time() 120 | for valid_sample in 
data_loader: 121 | try: 122 | if valid_sample.shape[0] == 1: # Skip final batch if it contains only one object 123 | continue 124 | valid_sample = valid_sample.to(device) 125 | current_batch_size = valid_sample.shape[0] 126 | if current_batch_size == BATCH_SIZE: 127 | batch_grid_points = grid_points_default_batch 128 | else: 129 | batch_grid_points = grid_points.repeat((current_batch_size, 1)) 130 | 131 | if not CONTINUE and ITERATION > 0: 132 | discriminator.fade_in_progress = (epoch + batch_index / (len(dataset) / BATCH_SIZE)) / FADE_IN_EPOCHS 133 | 134 | # train generator 135 | if batch_index % 5 == 0: 136 | generator_optimizer.zero_grad() 137 | 138 | latent_codes = sample_latent_codes(current_batch_size) 139 | fake_sample = generator_parallel(batch_grid_points, latent_codes) 140 | fake_sample = fake_sample.reshape(-1, VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 141 | if batch_index % 50 == 0 and show_viewer: 142 | viewer.set_voxels(fake_sample[0, :, :, :].squeeze().detach().cpu().numpy()) 143 | if batch_index % 50 == 0 and "show_slice" in sys.argv: 144 | tqdm.write(create_text_slice(fake_sample[0, :, :, :] / SDF_CLIPPING)) 145 | 146 | fake_discriminator_output = discriminator_parallel(fake_sample) 147 | fake_loss = -fake_discriminator_output.mean() 148 | fake_loss.backward() 149 | generator_optimizer.step() 150 | 151 | 152 | # train discriminator on fake samples 153 | discriminator_optimizer.zero_grad() 154 | latent_codes = sample_latent_codes(current_batch_size) 155 | fake_sample = generator_parallel(batch_grid_points, latent_codes) 156 | fake_sample = fake_sample.reshape(-1, VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 157 | discriminator_output_fake = discriminator_parallel(fake_sample) 158 | 159 | # train discriminator on real samples 160 | discriminator_output_valid = discriminator_parallel(valid_sample) 161 | 162 | gradient_penalty = get_gradient_penalty(valid_sample.detach(), fake_sample.detach()) 163 | loss = discriminator_output_fake.mean() - discriminator_output_valid.mean() + gradient_penalty 164 | loss.backward() 165 | 166 | discriminator_optimizer.step() 167 | 168 | history_fake.append(discriminator_output_fake.mean().item()) 169 | history_real.append(discriminator_output_valid.mean().item()) 170 | history_gradient_penalty.append(gradient_penalty.item()) 171 | batch_index += 1 172 | 173 | if "verbose" in sys.argv and batch_index % 50 == 0: 174 | tqdm.write("Epoch " + str(epoch) + ", batch " + str(batch_index) + 175 | ": D(x'): " + '{0:.4f}'.format(history_fake[-1]) + 176 | ", D(x): " + '{0:.4f}'.format(history_real[-1]) + 177 | ", loss: " + '{0:.4f}'.format(history_real[-1] - history_fake[-1]) + 178 | ", gradient penalty: " + '{0:.4f}'.format(gradient_penalty.item())) 179 | progress.update() 180 | except KeyboardInterrupt: 181 | if show_viewer: 182 | viewer.stop() 183 | return 184 | 185 | prediction_fake = np.mean(history_fake) 186 | prediction_real = np.mean(history_real) 187 | recent_gradient_penalty = np.mean(history_gradient_penalty) 188 | 189 | tqdm.write('Epoch {:d} ({:.1f}s), D(x\'): {:.4f}, D(x): {:.4f}, loss: {:4f}, gradient penalty: {:.4f}'.format( 190 | epoch, 191 | time.time() - epoch_start_time, 192 | prediction_fake, 193 | prediction_real, 194 | prediction_real - prediction_fake, 195 | recent_gradient_penalty)) 196 | 197 | generator.save() 198 | discriminator.save() 199 | 200 | if epoch % 10 == 0: 201 | generator.save(epoch=epoch) 202 | discriminator.save(epoch=epoch) 203 | 204 | if "show_slice" in sys.argv: 205 | latent_code = 
sample_latent_codes(1) 206 | slice_voxels = generator_parallel(grid_points, latent_code) 207 | slice_voxels = slice_voxels.reshape(VOXEL_RESOLUTION, VOXEL_RESOLUTION, VOXEL_RESOLUTION) 208 | tqdm.write(create_text_slice(slice_voxels / SDF_CLIPPING)) 209 | 210 | log_file.write('{:d} {:.1f} {:.4f} {:.4f} {:.4f}\n'.format(epoch, time.time() - epoch_start_time, prediction_fake, prediction_real, recent_gradient_penalty)) 211 | log_file.flush() 212 | 213 | 214 | train() 215 | log_file.close() 216 | -------------------------------------------------------------------------------- /rendering/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 3 | import pygame 4 | from pygame.locals import * 5 | from OpenGL.arrays import vbo 6 | import pygame.image 7 | 8 | from OpenGL.GL import * 9 | from OpenGL.GLU import * 10 | 11 | import numpy as np 12 | 13 | from rendering.binary_voxels_to_mesh import create_binary_voxel_mesh 14 | from rendering.shader import Shader 15 | 16 | import cv2 17 | import skimage.measure 18 | 19 | from threading import Thread, Lock 20 | import torch 21 | import trimesh 22 | import cv2 23 | 24 | from util import crop_image, ensure_directory 25 | from rendering.math import get_camera_transform 26 | 27 | CLAMP_TO_EDGE = 33071 28 | SHADOW_TEXTURE_SIZE = 1024 29 | 30 | DEFAULT_ROTATION = (147, 20) 31 | 32 | def create_shadow_texture(): 33 | texture_id = glGenTextures(1) 34 | glBindTexture(GL_TEXTURE_2D, texture_id) 35 | 36 | glTexImage2D( 37 | GL_TEXTURE_2D, 0, GL_DEPTH_COMPONENT, SHADOW_TEXTURE_SIZE, SHADOW_TEXTURE_SIZE, 0, GL_DEPTH_COMPONENT, 38 | GL_FLOAT, None 39 | ) 40 | 41 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) 42 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) 43 | 44 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, CLAMP_TO_EDGE) 45 | glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, CLAMP_TO_EDGE) 46 | glTexParameterfv( 47 | GL_TEXTURE_2D, GL_TEXTURE_BORDER_COLOR, 48 | np.ones(4).astype(np.float32) 49 | ) 50 | 51 | glBindTexture(GL_TEXTURE_2D, 0) 52 | return texture_id 53 | 54 | class MeshRenderer(): 55 | def __init__(self, size = 800, start_thread = True, background_color = (1, 1, 1, 1)): 56 | self.size = size 57 | 58 | self.mouse = None 59 | self.rotation = list(DEFAULT_ROTATION) 60 | 61 | self.vertex_buffer = None 62 | self.normal_buffer = None 63 | 64 | self.model_size = 1 65 | 66 | self.request_render = False 67 | 68 | self.running = True 69 | 70 | self.window = None 71 | 72 | self.background_color = background_color 73 | self.model_color = (0.8, 0.1, 0.1) 74 | 75 | self.shadow_framebuffer = None 76 | self.shadow_texture = None 77 | 78 | self.floor_vertices = None 79 | self.floor_normals = None 80 | 81 | self.ground_level = -1 82 | 83 | self.render_lock = Lock() 84 | 85 | self.dataset_directories = None 86 | 87 | if start_thread: 88 | thread = Thread(target = self._run) 89 | thread.start() 90 | else: 91 | self._initialize_opengl() 92 | 93 | def _update_buffers(self, vertices, normals): 94 | self.render_lock.acquire() 95 | if self.vertex_buffer is None: 96 | self.vertex_buffer = vbo.VBO(vertices) 97 | else: 98 | self.vertex_buffer.set_array(vertices) 99 | 100 | if self.normal_buffer is None: 101 | self.normal_buffer = vbo.VBO(normals) 102 | else: 103 | self.normal_buffer.set_array(normals) 104 | 105 | self.vertex_buffer_size = vertices.shape[0] 106 | self.request_render = True 107 | self.render_lock.release() 108 | 
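# set_voxels() below accepts either a signed distance volume
# (use_marching_cubes=True, the default), which is padded with positive
# "outside" values and triangulated with Marching Cubes at the given iso-level,
# or a binary occupancy volume (use_marching_cubes=False), which is meshed
# directly with create_binary_voxel_mesh(). Both paths end in _update_buffers(),
# which hands the triangle data to the render thread.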
109 | 110 | def set_voxels(self, voxels, use_marching_cubes=True, shade_smooth=False, pad=True, level=0): 111 | if use_marching_cubes: 112 | if type(voxels) is torch.Tensor: 113 | if len(voxels.shape) > 3: 114 | voxels = voxels.squeeze() 115 | voxels = voxels.cpu().numpy() 116 | voxel_resolution = voxels.shape[1] 117 | if pad: 118 | voxels = np.pad(voxels, 1, mode='constant', constant_values=1) 119 | try: 120 | vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(voxels, level=level, spacing=(2.0 / voxel_resolution, 2.0 / voxel_resolution, 2.0 / voxel_resolution)) 121 | vertices = vertices[faces, :].astype(np.float32) - 1 122 | self.ground_level = np.min(vertices[:, 1]).item() 123 | 124 | if shade_smooth: 125 | normals = normals[faces, :].astype(np.float32) 126 | else: 127 | normals = np.cross(vertices[:, 1, :] - vertices[:, 0, :], vertices[:, 2, :] - vertices[:, 0, :]) 128 | normals = np.repeat(normals, 3, axis=0) 129 | 130 | self._update_buffers(vertices.reshape((-1)), normals.reshape((-1))) 131 | self.model_size = 1.4 132 | except ValueError: 133 | pass # Voxel array contains no sign change 134 | else: 135 | vertices, normals = create_binary_voxel_mesh(voxels) 136 | vertices -= (voxels.shape[0] + 1) / 2 137 | vertices /= voxels.shape[0] + 1 138 | self._update_buffers(vertices, normals) 139 | self.model_size = max([voxels.shape[0] + 1, voxels.shape[1] + 1, voxels.shape[2] + 1]) 140 | self.model_size = 0.75 141 | self.ground_level = np.min(vertices[1::3]).item() 142 | 143 | def set_mesh(self, mesh, smooth=False, center_and_scale=False): 144 | if mesh is None: 145 | return 146 | 147 | vertices = np.array(mesh.triangles, dtype=np.float32).reshape(-1, 3) 148 | 149 | if center_and_scale: 150 | vertices -= mesh.bounding_box.centroid[np.newaxis, :] 151 | vertices /= np.max(np.linalg.norm(vertices, axis=1)) 152 | 153 | self.ground_level = np.min(vertices[:, 1]).item() 154 | vertices = vertices.reshape((-1)) 155 | 156 | if smooth: 157 | normals = mesh.vertex_normals[mesh.faces.reshape(-1)].astype(np.float32) * -1 158 | else: 159 | normals = np.repeat(mesh.face_normals, 3, axis=0).astype(np.float32) 160 | 161 | self._update_buffers(vertices, normals) 162 | self.model_size = 1.08 163 | 164 | def _poll_mouse(self): 165 | left_mouse, _, right_mouse = pygame.mouse.get_pressed() 166 | pressed = left_mouse == 1 or right_mouse == 1 167 | current_mouse = pygame.mouse.get_pos() 168 | if self.mouse is not None and pressed: 169 | movement = (current_mouse[0] - self.mouse[0], current_mouse[1] - self.mouse[1]) 170 | self.rotation = [self.rotation[0] + movement[0], max(-90, min(90, self.rotation[1] + movement[1]))] 171 | self.mouse = current_mouse 172 | return pressed 173 | 174 | def _render_shadow_texture(self, light_vp_matrix): 175 | glBindFramebuffer(GL_FRAMEBUFFER, self.shadow_framebuffer) 176 | glBindTexture(GL_TEXTURE_2D, self.shadow_texture) 177 | glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_TEXTURE_2D, self.shadow_texture, 0) 178 | glActiveTexture(GL_TEXTURE0) 179 | glBindTexture(GL_TEXTURE_2D, self.shadow_texture) 180 | glDrawBuffer(GL_NONE) 181 | glReadBuffer(GL_NONE) 182 | 183 | glClear(GL_DEPTH_BUFFER_BIT) 184 | glViewport(0, 0, SHADOW_TEXTURE_SIZE, SHADOW_TEXTURE_SIZE) 185 | glEnable(GL_DEPTH_TEST) 186 | glDepthMask(GL_TRUE) 187 | glDepthFunc(GL_LESS) 188 | glDepthRange(0.0, 1.0) 189 | glDisable(GL_CULL_FACE) 190 | glDisable(GL_BLEND) 191 | 192 | self.depth_shader.use() 193 | self.depth_shader.set_vp_matrix(light_vp_matrix) 194 | self._draw_mesh(use_normals=False) 
195 | 196 | glBindFramebuffer(GL_FRAMEBUFFER, 0) 197 | 198 | def _draw_mesh(self, use_normals=True): 199 | if self.vertex_buffer is None or self.normal_buffer is None: 200 | return 201 | 202 | glEnableClientState(GL_VERTEX_ARRAY) 203 | self.vertex_buffer.bind() 204 | glVertexPointer(3, GL_FLOAT, 0, self.vertex_buffer) 205 | 206 | if use_normals: 207 | glEnableClientState(GL_NORMAL_ARRAY) 208 | self.normal_buffer.bind() 209 | glNormalPointer(GL_FLOAT, 0, self.normal_buffer) 210 | 211 | glDrawArrays(GL_TRIANGLES, 0, self.vertex_buffer_size) 212 | 213 | def _draw_floor(self): 214 | self.shader.set_y_offset(self.ground_level) 215 | 216 | glEnableClientState(GL_VERTEX_ARRAY) 217 | self.floor_vertices.bind() 218 | glVertexPointer(3, GL_FLOAT, 0, self.floor_vertices) 219 | 220 | glEnableClientState(GL_NORMAL_ARRAY) 221 | self.floor_normals.bind() 222 | glNormalPointer(GL_FLOAT, 0, self.floor_normals) 223 | 224 | glDrawArrays(GL_TRIANGLES, 0, 6) 225 | 226 | def _render(self): 227 | self.request_render = False 228 | self.render_lock.acquire() 229 | 230 | light_vp_matrix = get_camera_transform(6, self.rotation[0], 50, project=True) 231 | self._render_shadow_texture(light_vp_matrix) 232 | 233 | self.shader.use() 234 | self.shader.set_floor(False) 235 | self.shader.set_color(self.model_color) 236 | self.shader.set_y_offset(0) 237 | camera_vp_matrix = get_camera_transform(self.model_size * 2, self.rotation[0], self.rotation[1], project=True) 238 | self.shader.set_vp_matrix(camera_vp_matrix) 239 | self.shader.set_light_vp_matrix(light_vp_matrix) 240 | 241 | glClearColor(*self.background_color) 242 | glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) 243 | 244 | glClearDepth(1.0) 245 | glDepthMask(GL_TRUE) 246 | glDepthFunc(GL_LESS) 247 | glDepthRange(0.0, 1.0) 248 | glEnable(GL_CULL_FACE) 249 | glEnable(GL_DEPTH_TEST) 250 | glViewport(0, 0, self.size, self.size) 251 | 252 | glActiveTexture(GL_TEXTURE1) 253 | glBindTexture(GL_TEXTURE_2D, self.shadow_texture) 254 | self.shader.set_shadow_texture(1) 255 | 256 | self._draw_mesh() 257 | self.shader.set_floor(True) 258 | self._draw_floor() 259 | self.render_lock.release() 260 | 261 | def _initialize_opengl(self): 262 | pygame.init() 263 | pygame.display.set_caption('Model Viewer') 264 | pygame.display.gl_set_attribute(pygame.GL_MULTISAMPLEBUFFERS, 1) 265 | pygame.display.gl_set_attribute(pygame.GL_MULTISAMPLESAMPLES, 4) 266 | self.window = pygame.display.set_mode((self.size, self.size), pygame.OPENGLBLIT) 267 | 268 | self.shader = Shader() 269 | self.shader.initShader(open('rendering/vertex.glsl').read(), open('rendering/fragment.glsl').read()) 270 | 271 | self.shadow_framebuffer = glGenFramebuffers(1) 272 | self.shadow_texture = create_shadow_texture() 273 | 274 | self.depth_shader = Shader() 275 | self.depth_shader.initShader(open('rendering/depth_vertex.glsl').read(), open('rendering/depth_fragment.glsl').read()) 276 | 277 | self.prepare_floor() 278 | 279 | def prepare_floor(self): 280 | size = 6 281 | mesh = trimesh.Trimesh([ 282 | [-size, 0, -size], 283 | [-size, 0, +size], 284 | [+size, 0, +size], 285 | [-size, 0, -size], 286 | [+size, 0, +size], 287 | [+size, 0, -size] 288 | ], faces=[[0, 1, 2], [3, 4, 5]]) 289 | 290 | vertices = np.array(mesh.triangles, dtype=np.float32).reshape(-1, 3) 291 | vertices = vertices.reshape((-1)) 292 | normals = np.repeat(mesh.face_normals, 3, axis=0).astype(np.float32) 293 | 294 | self.floor_vertices = vbo.VBO(vertices) 295 | self.floor_normals = vbo.VBO(normals) 296 | 297 | def _run(self): 298 | self._initialize_opengl() 
299 | self._render() 300 | 301 | while self.running: 302 | for event in pygame.event.get(): 303 | if event.type == pygame.QUIT: 304 | pygame.quit() 305 | return 306 | 307 | if event.type == pygame.KEYDOWN: 308 | if pygame.key.get_pressed()[pygame.K_F12]: 309 | self.save_screenshot() 310 | if pygame.key.get_pressed()[pygame.K_r]: 311 | self.rotation = list(DEFAULT_ROTATION) 312 | self.request_render = True 313 | 314 | if self._poll_mouse() or self.request_render: 315 | self._render() 316 | pygame.display.flip() 317 | 318 | pygame.time.wait(10) 319 | 320 | self.delete_buffers() 321 | 322 | def delete_buffers(self): 323 | for buffer in [self.normal_buffer, self.vertex_buffer]: 324 | if buffer is not None: 325 | buffer.delete() 326 | 327 | def stop(self): 328 | self.running = False 329 | 330 | def get_image(self, crop=False, output_size=None, greyscale=False, flip_red_blue=False): 331 | if self.request_render: 332 | self._render() 333 | if output_size is None: 334 | output_size = self.size 335 | 336 | string_image = pygame.image.tostring(self.window, 'RGB') 337 | image = pygame.image.fromstring(string_image, (self.size, self.size), 'RGB') 338 | if greyscale: 339 | array = np.transpose(pygame.surfarray.array3d(image)[:, :, 0]) 340 | else: 341 | array = np.transpose(pygame.surfarray.array3d(image)[:, :, (2, 1, 0) if flip_red_blue else slice(None)], (1, 0, 2)) 342 | 343 | if crop: 344 | array = crop_image(array) 345 | 346 | if output_size != self.size: 347 | array = cv2.resize(array, dsize=(output_size, output_size), interpolation=cv2.INTER_CUBIC) 348 | 349 | return array 350 | 351 | def save_screenshot(self): 352 | ensure_directory('screenshots') 353 | FILENAME_FORMAT = "screenshots/{:04d}.png" 354 | 355 | index = 0 356 | while os.path.isfile(FILENAME_FORMAT.format(index)): 357 | index += 1 358 | filename = FILENAME_FORMAT.format(index) 359 | image = self.get_image(flip_red_blue=True) 360 | cv2.imwrite(filename, image) 361 | print("Screenshot saved to " + filename + ".") -------------------------------------------------------------------------------- /create_plot.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file contains code to create various plots and figures. 3 | It is currently not maintained and not all of the figures work. 4 | But it can serve as inspiration for further development. 
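Figures are selected via command line arguments, e.g. `python3 create_plot.py autoencoder_hist`
or `python3 create_plot.py autoencoder classic`; each `if "..." in sys.argv:` block
below creates one figure in the plots/ directory.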
5 | ''' 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import torch 10 | import sys 11 | import os 12 | from tqdm import tqdm 13 | 14 | from model import LATENT_CODE_SIZE, LATENT_CODES_FILENAME 15 | import random 16 | from util import device 17 | 18 | class ImageGrid(): 19 | def __init__(self, width, height=1, cell_width = 3, cell_height = None, margin=0.2, create_viewer=True, crop=True): 20 | print("Plotting...") 21 | self.width = width 22 | self.height = height 23 | cell_height = cell_height if cell_height is not None else cell_width 24 | 25 | self.figure, self.axes = plt.subplots(height, width, 26 | figsize=(width * cell_width, height * cell_height), 27 | gridspec_kw={'left': 0, 'right': 1, 'top': 1, 'bottom': 0, 'wspace': margin, 'hspace': margin}) 28 | self.figure.patch.set_visible(False) 29 | 30 | self.crop = crop 31 | if create_viewer: 32 | from rendering import MeshRenderer 33 | self.viewer = MeshRenderer(start_thread=False) 34 | else: 35 | self.viewer = None 36 | 37 | def set_image(self, image, x = 0, y = 0): 38 | cell = self.axes[y, x] if self.height > 1 and self.width > 1 else self.axes[x + y] 39 | cell.imshow(image) 40 | cell.axis('off') 41 | cell.patch.set_visible(False) 42 | 43 | def set_voxels(self, voxels, x = 0, y = 0, color=None): 44 | if color is not None: 45 | self.viewer.model_color = color 46 | self.viewer.set_voxels(voxels) 47 | image = self.viewer.get_image(crop=self.crop) 48 | self.set_image(image, x, y) 49 | 50 | def save(self, filename): 51 | plt.axis('off') 52 | extent = self.figure.get_window_extent().transformed(self.figure.dpi_scale_trans.inverted()) 53 | plt.savefig(filename, bbox_inches=extent, dpi=400) 54 | if self.viewer is not None: 55 | self.viewer.delete_buffers() 56 | 57 | def load_autoencoder(is_variational=False): 58 | from model.autoencoder import Autoencoder 59 | autoencoder = Autoencoder(is_variational=is_variational) 60 | autoencoder.load() 61 | autoencoder.eval() 62 | return autoencoder 63 | 64 | def load_generator(is_wgan=False): 65 | from model.gan import Generator 66 | generator = Generator() 67 | if is_wgan: 68 | generator.filename = "wgan-generator.to" 69 | generator.load() 70 | generator.eval() 71 | return generator 72 | 73 | def load_sdf_net(filename=None, return_latent_codes = False): 74 | from model.sdf_net import SDFNet, LATENT_CODES_FILENAME 75 | sdf_net = SDFNet() 76 | if filename is not None: 77 | sdf_net.filename = filename 78 | sdf_net.load() 79 | sdf_net.eval() 80 | 81 | if return_latent_codes: 82 | latent_codes = torch.load(LATENT_CODES_FILENAME).to(device) 83 | latent_codes.requires_grad = False 84 | return sdf_net, latent_codes 85 | else: 86 | return sdf_net 87 | 88 | def create_tsne_plot(codes, voxels = None, labels = None, filename = "plot.pdf", indices=None): 89 | from sklearn.manifold import TSNE 90 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox 91 | 92 | width, height = 40, 52 93 | 94 | print("Calculating t-sne embedding...") 95 | tsne = TSNE(n_components=2) 96 | embedded = tsne.fit_transform(codes) 97 | 98 | print("Plotting...") 99 | fig, ax = plt.subplots() 100 | plt.axis('off') 101 | margin = 0.0128 102 | plt.margins(margin * height / width, margin) 103 | 104 | x = embedded[:, 0] 105 | y = embedded[:, 1] 106 | x = np.interp(x, (x.min(), x.max()), (0, 1)) 107 | y = np.interp(y, (y.min(), y.max()), (0, 1)) 108 | 109 | ax.scatter(x, y, c=labels, s = 40, cmap='Set1') 110 | fig.set_size_inches(width, height) 111 | 112 | if voxels is not None: 113 | print("Creating images...") 114 | from 
rendering import MeshRenderer
115 | viewer = MeshRenderer(start_thread=False)
116 | for i in tqdm(range(voxels.shape[0])):
117 | viewer.set_voxels(voxels[i, :, :, :].cpu().numpy())
118 | viewer.model_color = dataset.get_color(labels[i])
119 | image = viewer.get_image(crop=True, output_size=128)
120 | box = AnnotationBbox(OffsetImage(image, zoom = 0.5, cmap='gray'), (x[i], y[i]), frameon=True)
121 | ax.add_artist(box)
122 | 
123 | if indices is not None:
124 | print("Creating images...")
125 | dataset_directories = open('data/models.txt', 'r').readlines()
126 | from rendering import MeshRenderer
127 | viewer = MeshRenderer(start_thread=False)
128 | import trimesh
129 | import logging
130 | logging.getLogger('trimesh').setLevel(1000000)
131 | for i in tqdm(range(len(indices))):
132 | mesh = trimesh.load(os.path.join(dataset_directories[indices[i]].strip(), 'model_normalized.obj'))
133 | viewer.set_mesh(mesh, center_and_scale=True)
134 | viewer.model_color = dataset.get_color(labels[i])
135 | image = viewer.get_image(crop=True, output_size=128)
136 | box = AnnotationBbox(OffsetImage(image, zoom = 0.5, cmap='gray'), (x[i], y[i]), frameon=True)
137 | ax.add_artist(box)
138 | 
139 | print("Saving PDF...")
140 | 
141 | extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
142 | plt.savefig(filename, bbox_inches=extent, dpi=200)
143 | 
144 | if "color-test" in sys.argv:
145 | from dataset import dataset as dataset
146 | dataset.load_voxels('cpu')
147 | dataset.load_labels()
148 | voxels = dataset.voxels
149 | 
150 | COUNT = dataset.label_count
151 | 
152 | plot = ImageGrid(COUNT)
153 | 
154 | for label in tqdm(range(COUNT)):
155 | objects = (dataset.labels == label).nonzero()
156 | index = objects[random.randint(0, objects.shape[0] - 1)].item()
157 | tensor = voxels[index, :, :, :].clone()
158 | plot.set_voxels(tensor, label, color=dataset.get_color(label))
159 | 
160 | plot.save("plots/test.pdf")
161 | 
162 | if "autoencoder-classes" in sys.argv:
163 | from dataset import dataset as dataset
164 | dataset.load_voxels(device)
165 | dataset.load_labels()
166 | 
167 | COUNT = dataset.label_count
168 | 
169 | vae = load_autoencoder(is_variational=True)
170 | indices = []
171 | for label in tqdm(range(COUNT)):
172 | objects = (dataset.labels == label).nonzero()
173 | indices.append(objects[random.randint(0, objects.shape[0] - 1)].item())
174 | voxels = dataset.voxels[indices, :, :, :]
175 | 
176 | print("Generating codes...")
177 | with torch.no_grad():
178 | codes_vae = vae.encode(voxels)
179 | reconstructed_vae = vae.decode(codes_vae).cpu().numpy()
180 | 
181 | plot = ImageGrid(COUNT, 2)
182 | 
183 | for i in range(COUNT):
184 | plot.set_voxels(voxels[i, :, :, :], i, 0, color=dataset.get_color(i))
185 | plot.set_voxels(reconstructed_vae[i, :, :, :], i, 1)
186 | 
187 | plot.save("plots/vae-reconstruction-classes.pdf")
188 | 
189 | if "autodecoder-classes" in sys.argv:
190 | from dataset import dataset as dataset
191 | dataset.load_labels(device='cpu')
192 | from rendering.raymarching import render_image
193 | from rendering import MeshRenderer
194 | import logging, trimesh # trimesh is needed for trimesh.load below
195 | logging.getLogger('trimesh').setLevel(1000000)
196 | 
197 | viewer = MeshRenderer(start_thread=False)
198 | 
199 | COUNT = dataset.label_count
200 | 
201 | sdf_net, latent_codes = load_sdf_net(return_latent_codes=True)
202 | indices = []
203 | for label in range(COUNT):
204 | objects = (dataset.labels == label).nonzero()
205 | indices.append(objects[random.randint(0, objects.shape[0] - 1)].item())
206 | 
207 | latent_codes = 
latent_codes[indices, :]
208 | 
209 | plot = ImageGrid(COUNT, 2, create_viewer=False)
210 | dataset_directories = open('data/models.txt', 'r').readlines()
211 | 
212 | for i in range(COUNT):
213 | mesh = trimesh.load(os.path.join(dataset_directories[indices[i]].strip(), 'model_normalized.obj'))
214 | viewer.set_mesh(mesh, center_and_scale=True)
215 | viewer.model_color = dataset.get_color(i)
216 | image = viewer.get_image(crop=True)
217 | plot.set_image(image, i, 0)
218 | 
219 | image = render_image(sdf_net, latent_codes[i, :], color=dataset.get_color(i), crop=True)
220 | plot.set_image(image, i, 1)
221 | viewer.delete_buffers()
222 | plot.save("plots/deepsdf-reconstruction-classes.pdf")
223 | 
224 | if "autoencoder" in sys.argv:
225 | from dataset import dataset as dataset
226 | dataset.load_voxels(device)
227 | dataset.load_labels(device)
228 | 
229 | indices = random.sample(list(range(dataset.size)), 1000)
230 | voxels = dataset.voxels[indices, :, :, :]
231 | autoencoder = load_autoencoder(is_variational='classic' not in sys.argv)
232 | print("Generating codes...")
233 | with torch.no_grad():
234 | codes = autoencoder.encode(voxels).cpu().numpy()
235 | create_tsne_plot(codes, voxels, dataset.labels[indices].cpu().numpy(), "plots/{:s}autoencoder-tsne.pdf".format('' if 'classic' in sys.argv else 'variational-'))
236 | 
237 | if "autodecoder_tsne" in sys.argv:
238 | from dataset import dataset as dataset
239 | dataset.load_labels('cpu')
240 | 
241 | from model.sdf_net import LATENT_CODES_FILENAME
242 | latent_codes = torch.load(LATENT_CODES_FILENAME).detach().cpu().numpy()
243 | 
244 | indices = random.sample(range(latent_codes.shape[0]), 1000)
245 | latent_codes = latent_codes[indices, :]
246 | labels = dataset.labels[indices]
247 | 
248 | create_tsne_plot(latent_codes, labels=labels, filename="plots/deepsdf-tsne.pdf", indices=indices)
249 | 
250 | 
251 | if "autoencoder_hist" in sys.argv:
252 | import scipy.stats
253 | from dataset import dataset as dataset
254 | dataset.load_voxels(device)
255 | is_variational = 'classic' not in sys.argv
256 | 
257 | x_range = 4 if is_variational else 1
258 | 
259 | indices = random.sample(list(range(dataset.size)), min(5000, dataset.size))
260 | voxels = dataset.voxels[indices, :, :, :]
261 | autoencoder = load_autoencoder(is_variational=is_variational)
262 | print("Generating codes...")
263 | with torch.no_grad():
264 | autoencoder.train()
265 | codes = autoencoder.encode(voxels).cpu().numpy()
266 | 
267 | print("Plotting...")
268 | plt.hist(codes[:, ::4], bins=100, range=(-x_range, x_range), histtype='step', density=1, color=['#1f77b4' for _ in range(0, codes.shape[1], 4)])
269 | plt.xlabel("$\mathbf{z}^{(i)}$")
270 | plt.ylabel("relative abundance")
271 | plt.savefig("plots/{:s}autoencoder-histogram.pdf".format('variational-' if is_variational else ''), bbox_inches='tight')
272 | codes = codes.flatten()
273 | plt.clf()
274 | x = np.linspace(-x_range, x_range, 500)
275 | y = scipy.stats.norm.pdf(x, 0, 1)
276 | if is_variational:
277 | plt.plot(x, y, color='green')
278 | plt.hist(codes, bins=100, range=(-x_range, x_range), density=1)
279 | plt.xlabel("$\mathbf{z}$")
280 | plt.ylabel("relative abundance")
281 | plt.savefig("plots/{:s}autoencoder-histogram-combined.pdf".format('variational-' if is_variational else ''), bbox_inches='tight')
282 | 
283 | if "autodecoder_hist" in sys.argv:
284 | import scipy.stats
285 | codes = torch.load(LATENT_CODES_FILENAME).cpu().detach().numpy()
286 | 
287 | x_range = 0.42
288 | 
289 | 
print("Plotting...") 290 | plt.hist(codes[:, ::4], bins=100, range=(-x_range, x_range), histtype='step', density=1, color=['#1f77b4' for _ in range(0, codes.shape[1], 4)]) 291 | plt.xlabel("$\mathbf{z}^{(i)}$") 292 | plt.ylabel("relative abundance") 293 | plt.savefig("plots/autodecoder-histogram.pdf", bbox_inches='tight') 294 | codes = codes.flatten() 295 | plt.clf() 296 | x = np.linspace(-x_range, x_range, 500) 297 | y = scipy.stats.norm.pdf(x, 0, 1) 298 | plt.hist(codes, bins=100, range=(-x_range, x_range), density=1) 299 | plt.xlabel("$\mathbf{z}$") 300 | plt.ylabel("relative abundance") 301 | plt.savefig("plots/autodecoder-histogram-combined.pdf", bbox_inches='tight') 302 | 303 | if "autoencoder_examples" in sys.argv: 304 | from dataset import dataset as dataset 305 | dataset.load_voxels(device) 306 | 307 | from rendering import MeshRenderer 308 | viewer = MeshRenderer(start_thread=False) 309 | 310 | indices = random.sample(list(range(dataset.size)), 20) 311 | voxels = dataset.voxels[indices, :, :, :] 312 | autoencoder = load_autoencoder() 313 | print("Generating codes...") 314 | with torch.no_grad(): 315 | codes = autoencoder.encode(voxels) 316 | reconstructed = autoencoder.decode(codes).cpu().numpy() 317 | codes = codes.cpu().numpy() 318 | 319 | print("Plotting...") 320 | fig, axs = plt.subplots(len(indices), 3, figsize=(10, 32)) 321 | for i in range(len(indices)): 322 | viewer.set_voxels(voxels[i, :, :, :].cpu().numpy()) 323 | image = viewer.get_image(output_size=512) 324 | axs[i, 0].imshow(image, cmap='gray') 325 | axs[i, 0].axis('off') 326 | 327 | axs[i, 1].bar(range(codes.shape[1]), codes[i, :]) 328 | axs[i, 1].set_ylim((-3, 3)) 329 | 330 | viewer.set_voxels(reconstructed[i, :, :, :]) 331 | image = viewer.get_image(output_size=512) 332 | axs[i, 2].imshow(image, cmap='gray') 333 | axs[i, 2].axis('off') 334 | plt.savefig("plots/autoencoder-examples.pdf", bbox_inches='tight', dpi=400) 335 | 336 | if "autoencoder_examples_2" in sys.argv: 337 | from dataset import dataset as dataset 338 | dataset.load_voxels(device) 339 | 340 | indices = random.sample(list(range(dataset.size)), 5) 341 | voxels = dataset.voxels[indices, :, :, :] 342 | ae = load_autoencoder(is_variational=False) 343 | vae = load_autoencoder(is_variational=True) 344 | 345 | print("Generating codes...") 346 | with torch.no_grad(): 347 | codes_ae = ae.encode(voxels) 348 | reconstructed_ae = ae.decode(codes_ae).cpu().numpy() 349 | codes_vae = vae.encode(voxels) 350 | reconstructed_vae = vae.decode(codes_vae).cpu().numpy() 351 | 352 | plot = ImageGrid(len(indices), 3) 353 | 354 | for i in range(len(indices)): 355 | plot.set_voxels(voxels[i, :, :, :], i, 0) 356 | plot.set_voxels(reconstructed_ae[i, :, :, :], i, 1) 357 | plot.set_voxels(reconstructed_vae[i, :, :, :], i, 2) 358 | 359 | plot.save("plots/ae-vae-examples.pdf") 360 | 361 | if "autoencoder_generate" in sys.argv: 362 | from dataset import dataset as dataset 363 | dataset.load_voxels(device) 364 | from sklearn.metrics import pairwise_distances 365 | 366 | SAMPLES = 5 367 | 368 | voxels = dataset.voxels 369 | ae = load_autoencoder(is_variational=False) 370 | vae = load_autoencoder(is_variational=True) 371 | print("Generating codes...") 372 | with torch.no_grad(): 373 | codes_ae = ae.encode(voxels).cpu().numpy() 374 | codes_vae = vae.encode(voxels).cpu().numpy() 375 | codes_ae_flattented = codes_ae.reshape(-1) 376 | codes_vae_flattented = codes_vae.reshape(-1) 377 | 378 | ae_distribution = torch.distributions.normal.Normal( 379 | np.mean(codes_ae_flattented), 380 | 
    ae_distribution = torch.distributions.normal.Normal(
        np.mean(codes_ae_flattened),
        np.var(codes_ae_flattened) ** 0.5
    )
    vae_distribution = torch.distributions.normal.Normal(
        np.mean(codes_vae_flattened),
        np.var(codes_vae_flattened) ** 0.5
    )

    samples_ae = ae_distribution.sample([SAMPLES, LATENT_CODE_SIZE]).to(device)
    samples_vae = vae_distribution.sample([SAMPLES, LATENT_CODE_SIZE]).to(device)
    with torch.no_grad():
        reconstructed_ae = ae.decode(samples_ae).cpu().numpy()
        reconstructed_vae = vae.decode(samples_vae).cpu().numpy()

    distances_ae = pairwise_distances(codes_ae, samples_ae.cpu().numpy(), metric='cosine')
    indices_ae = np.argmin(distances_ae, axis=0)
    reference_codes_ae = torch.tensor(codes_ae[indices_ae, :], device=device)
    with torch.no_grad():
        reconstructed_references_ae = ae.decode(reference_codes_ae).cpu().numpy()

    distances_vae = pairwise_distances(codes_vae, samples_vae.cpu().numpy(), metric='cosine')
    indices_vae = np.argmin(distances_vae, axis=0)
    reference_codes_vae = torch.tensor(codes_vae[indices_vae, :], device=device)
    with torch.no_grad():
        reconstructed_references_vae = vae.decode(reference_codes_vae).cpu().numpy()

    plot = ImageGrid(SAMPLES, 4)

    for i in range(SAMPLES):
        plot.set_voxels(reconstructed_ae[i, :, :, :], i, 0)
        plot.set_voxels(reconstructed_references_ae[i, :, :, :], i, 1)
        plot.set_voxels(reconstructed_vae[i, :, :, :], i, 2)
        plot.set_voxels(reconstructed_references_vae[i, :, :, :], i, 3)

    plot.save("plots/ae-vae-samples.pdf")

if "autoencoder_interpolation" in sys.argv:
    from dataset import dataset as dataset
    dataset.load_voxels(device)
    voxels = dataset.voxels

    STEPS = 6

    indices = random.sample(list(range(dataset.size)), 2)
    print(indices)

    ae = load_autoencoder(is_variational=False)
    vae = load_autoencoder(is_variational=True)

    print("Generating codes...")
    with torch.no_grad():
        codes_ae = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device)
        codes_start_end = ae.encode(voxels[indices, :, :, :])
        code_start = codes_start_end[0, :]
        code_end = codes_start_end[1, :]
        for i in range(STEPS):
            codes_ae[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1)
        reconstructed_ae = ae.decode(codes_ae)

        codes_vae = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device)
        codes_start_end = vae.encode(voxels[indices, :, :, :])
        code_start = codes_start_end[0, :]
        code_end = codes_start_end[1, :]
        for i in range(STEPS):
            codes_vae[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1)
        reconstructed_vae = vae.decode(codes_vae)

    plot = ImageGrid(STEPS, 2)

    for i in range(STEPS):
        plot.set_voxels(reconstructed_ae[i, :, :, :], i, 0)
        plot.set_voxels(reconstructed_vae[i, :, :, :], i, 1)

    plot.save("plots/ae-vae-interpolation.pdf")

if "autoencoder_interpolation_2" in sys.argv:
    from dataset import dataset as dataset
    dataset.load_voxels(device)
    voxels = dataset.voxels

    STEPS = 6

    indices = random.sample(list(range(dataset.size)), 2)
    print(indices)

    vae = load_autoencoder(is_variational=True)

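    # Intermediate shapes are again obtained by linear interpolation in latent
    # space: code_i = (1 - t) * code_start + t * code_end with t = i / (STEPS - 1),
    # so the first and last steps reproduce the two encoded shapes exactly.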
    print("Generating codes...")
    with torch.no_grad():
        codes_vae = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device)
        codes_start_end = vae.encode(voxels[indices, :, :, :])
        code_start = codes_start_end[0, :]
        code_end = codes_start_end[1, :]
        for i in range(STEPS):
            codes_vae[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1)
        reconstructed_vae = vae.decode(codes_vae)

    plot = ImageGrid(STEPS)

    for i in range(STEPS):
        plot.set_voxels(reconstructed_vae[i, :, :, :], i)

    plot.save("plots/vae-interpolation.pdf")

if "gan_tsne" in sys.argv:
    generator = load_generator(is_wgan='wgan' in sys.argv)
    from util import standard_normal_distribution

    shape = torch.Size([500, LATENT_CODE_SIZE])
    x = standard_normal_distribution.sample(shape).to(device)
    with torch.no_grad():
        voxels = generator(x).squeeze()
    codes = x.squeeze().cpu().numpy()
    filename = "plots/wgan-images.pdf" if 'wgan' in sys.argv else "plots/gan-images.pdf"
    create_tsne_plot(codes, voxels, labels=None, filename=filename)

if "gan_examples" in sys.argv:
    generator = load_generator(is_wgan='wgan' in sys.argv)

    COUNT = 5
    with torch.no_grad():
        voxels = generator.generate(sample_size=COUNT)

    plot = ImageGrid(COUNT)
    for i in range(COUNT):
        plot.set_voxels(voxels[i, :, :, :], i)

    filename = "plots/wgan-examples.pdf" if 'wgan' in sys.argv else "plots/gan-examples.pdf"
    plot.save(filename)

if "gan_interpolation" in sys.argv:
    from util import standard_normal_distribution

    STEPS = 6

    generator = load_generator(is_wgan='wgan' in sys.argv)

    print("Generating codes...")
    with torch.no_grad():
        codes = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device)
        codes_start_end = standard_normal_distribution.sample((2, LATENT_CODE_SIZE))
        code_start = codes_start_end[0, :]
        code_end = codes_start_end[1, :]
        for i in range(STEPS):
            codes[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1)
        voxels = generator(codes)

    plot = ImageGrid(STEPS)
    for i in range(STEPS):
        plot.set_voxels(voxels[i, :, :, :], i)

    filename = "plots/wgan-interpolation.pdf" if 'wgan' in sys.argv else "plots/gan-interpolation.pdf"
    plot.save(filename)

def get_moving_average(data, window_size):
    moving_average = []
    for i in range(data.shape[0] - window_size):
        moving_average.append(np.mean(data[i:i+window_size]))

    return np.arange(window_size / 2, data.shape[0] - window_size / 2, dtype=int), moving_average

if "wgan_training" in sys.argv:
    data = np.genfromtxt('plots/wgan_training.csv', delimiter=' ')

    plt.ylim((-400, 1000))
    plt.plot(data[:, 4], label="Assessment of real objects")
    plt.plot(data[:, 3], label="Assessment of fake objects")

    plt.xlabel('Epoch')
    plt.ylabel('Critic output')
    plt.legend()
    plt.savefig("plots/wgan-training-critic.pdf", bbox_inches='tight')

if "sdf_training" in sys.argv:
    data = np.genfromtxt('plots/sdf_net_training.csv', delimiter=' ')

    plt.clf()
    plt.plot(np.arange(1, data.shape[0] + 1), data[:, 2], linestyle='-', linewidth=0.5, color='grey')
    plt.plot(np.arange(1, data.shape[0] + 1), data[:, 2], 'x')

    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.savefig("plots/deepsdf-training-loss.pdf", bbox_inches='tight')

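# A minimal sketch of how get_moving_average (defined above) could be used to
# smooth the noisy critic curves plotted by the wgan_training block; the
# window size of 20 is an arbitrary choice for illustration:
#
#   epochs, smoothed = get_moving_average(data[:, 4], 20)
#   plt.plot(epochs, smoothed, label="Assessment of real objects (smoothed)")
#
# The returned indices are offset by half a window so that the smoothed curve
# stays centered over the raw values.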
def create_autoencoder_training_plot(data_file, title, plot_file):
    if not os.path.isfile(data_file):
        return

    data = np.genfromtxt(data_file, delimiter=' ')

    #plt.yscale('log')
    max_reconstruction_loss = np.max(data[:, 2])
    reconstruction_loss = data[:, 2] / max_reconstruction_loss
    kld_loss = data[:, 3] / max_reconstruction_loss
    voxel_error = data[:, 4] / np.max(data[:, 4])
    #plt.axhline(y=data[-1, 2], color='black', linewidth=1)
    plt.plot(reconstruction_loss, label='Reconstruction loss ({:.3f})'.format(data[-1, 2]))
    #plt.plot(kld_loss, label='KLD loss ({:.3f})'.format(data[-1, 3]))
    plt.plot(voxel_error, label='Voxel error ({:.3f})'.format(data[-1, 4]))

    plt.xlabel('Epoch')
    plt.yticks([])
    plt.title(title)
    plt.legend(loc='center right')
    plt.savefig(plot_file, bbox_inches='tight')
    plt.clf()

def create_autoencoder_training_plot_latex():
    data = np.genfromtxt('plots/variational_autoencoder_training.csv', delimiter=' ')

    plt.plot(data[:, 2], label='Reconstruction loss')
    plt.plot(data[:, 3], label='KLD loss')
    plt.xlabel('Epoch')
    plt.legend()
    plt.ylabel('Loss')
    plt.savefig('plots/vae-training-loss.pdf', bbox_inches='tight')

    plt.clf()
    plt.plot(data[:, 4])
    plt.xlabel('Epoch')
    plt.ylabel('Voxel error')
    plt.savefig('plots/vae-training-error.pdf', bbox_inches='tight')

    plt.clf()

if "autoencoder_training" in sys.argv:
    if 'latex' in sys.argv:
        create_autoencoder_training_plot_latex()
    else:
        create_autoencoder_training_plot('plots/autoencoder_training.csv', 'Autoencoder Training', 'plots/autoencoder-training.pdf')
        create_autoencoder_training_plot('plots/variational_autoencoder_training.csv', 'Variational Autoencoder Training', 'plots/variational-autoencoder-training.pdf')

if "sdf_slice" in sys.argv:
    from mesh_to_sdf import mesh_to_sdf, scale_to_unit_sphere
    import trimesh
    import cv2

    model_filename = 'data/shapenet/03001627/6ae8076b0f9c74199c2009e4fd70d135/models/model_normalized.obj'

    print("Loading mesh...")
    mesh = trimesh.load(model_filename)
    mesh = scale_to_unit_sphere(mesh)

    resolution = 1280
    slice_position = 0.0
    clip = 0.1
    points = np.meshgrid(
        np.linspace(slice_position, slice_position, 1),
        np.linspace(1, -1, resolution),
        np.linspace(-1, 1, resolution)
    )

    points = np.stack(points)
    points = points.reshape(3, -1).transpose()

    print("Calculating SDF values...")
    sdf = mesh_to_sdf(mesh, points)
    sdf = sdf.reshape(1, resolution, resolution)
    sdf = sdf[0, :, :]
    sdf = np.clip(sdf, -clip, clip) / clip

    print("Creating image...")
    image = np.ones((resolution, resolution, 3))
    image[:, :, :2][sdf > 0] = (1.0 - sdf[sdf > 0])[:, np.newaxis]  # outside the surface: fades from white to red (BGR)
    image[:, :, 1:][sdf < 0] = (1.0 + sdf[sdf < 0])[:, np.newaxis]  # inside the surface: fades from white to blue (BGR)
    mask = np.abs(sdf) < 0.03
    image[mask, :] = 0  # draw the surface itself in black
    image *= 255
    cv2.imwrite("plots/sdf_example.png", image)

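# The histogram below counts, for every shape in the dataset, the number of
# voxels with a negative SDF value (i.e. voxels inside the surface), giving a
# rough measure of how much volume the shapes occupy.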
plt.savefig("plots/voxel-occupancy-histogram.pdf") 659 | 660 | if "model_images" in sys.argv: 661 | from rendering import MeshRenderer 662 | viewer = MeshRenderer(start_thread=False) 663 | import trimesh 664 | import cv2 665 | import logging 666 | 667 | logging.getLogger('trimesh').setLevel(1000000) 668 | 669 | filenames = open('data/sdf-clouds.txt', 'r').read().split("\n") 670 | index = 0 671 | 672 | for filename in tqdm(filenames): 673 | model_filename = filename.replace('sdf-pointcloud.npy', 'model_normalized.obj') 674 | image_filename = 'screenshots/sdf_meshes/{:d}.png'.format(index) 675 | index += 1 676 | if os.path.isfile(image_filename): 677 | continue 678 | 679 | mesh = trimesh.load(model_filename) 680 | viewer.set_mesh(mesh, center_and_scale=True) 681 | image = viewer.get_image(crop=False, output_size=viewer.size, greyscale=False) 682 | cv2.imwrite(image_filename, image) 683 | 684 | if "wgan-results" in sys.argv: 685 | from util import crop_image 686 | 687 | COUNT = 5 688 | 689 | plot = ImageGrid(COUNT, create_viewer=False) 690 | 691 | for i in range(COUNT): 692 | image = plt.imread('screenshots/wgan/{:d}.png'.format(i)) 693 | plot.set_image(crop_image(image, background=1), i) 694 | 695 | plot.save('plots/wgan-results.pdf') 696 | 697 | if 'sdf_net_reconstruction' in sys.argv: 698 | from rendering.raymarching import render_image_for_index 699 | from PIL import Image 700 | from util import crop_image 701 | sdf_net, latent_codes = load_sdf_net(return_latent_codes=True) 702 | 703 | COUNT = 5 704 | MESH_FILENAME = 'screenshots/sdf_meshes/{:d}.png' 705 | 706 | indices = random.sample(range(latent_codes.shape[0]), COUNT) 707 | print(indices) 708 | 709 | plot = ImageGrid(COUNT, 2, create_viewer=False) 710 | 711 | for i in range(COUNT): 712 | mesh = Image.open(MESH_FILENAME.format(indices[i])) 713 | mesh = np.array(mesh) 714 | mesh = crop_image(mesh) 715 | plot.set_image(mesh, i, 0) 716 | 717 | image = render_image_for_index(sdf_net, latent_codes, indices[i], crop=True) 718 | plot.set_image(image, i, 1) 719 | 720 | plot.save('plots/deepsdf-reconstruction.pdf') 721 | 722 | if "sdf_net_interpolation" in sys.argv: 723 | from rendering.raymarching import render_image_for_index, render_image 724 | sdf_net, latent_codes = load_sdf_net(return_latent_codes=True) 725 | 726 | STEPS = 6 727 | 728 | indices = random.sample(list(range(latent_codes.shape[0])), 2) 729 | print(indices) 730 | code_start = latent_codes[indices[0], :] 731 | code_end = latent_codes[indices[1], :] 732 | 733 | print("Generating codes...") 734 | with torch.no_grad(): 735 | codes = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device) 736 | for i in range(STEPS): 737 | codes[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1) 738 | 739 | plot = ImageGrid(STEPS, create_viewer=False) 740 | 741 | for i in range(STEPS): 742 | plot.set_image(render_image(sdf_net, codes[i, :], crop=True), i) 743 | 744 | plot.save("plots/deepsdf-interpolation.pdf") 745 | 746 | if "sdf_net_sample" in sys.argv: 747 | from rendering.raymarching import render_image 748 | sdf_net, latent_codes = load_sdf_net(return_latent_codes=True) 749 | latent_codes_flattened = latent_codes.detach().reshape(-1).cpu().numpy() 750 | 751 | COUNT = 5 752 | 753 | mean, variance = np.mean(latent_codes_flattened), np.var(latent_codes_flattened) ** 0.5 754 | print("mean: ", mean) 755 | print("variance: ", variance) 756 | distribution = torch.distributions.normal.Normal(mean, variance) 757 | codes = distribution.sample([COUNT, LATENT_CODE_SIZE]).to(device) 
if "sdf_net_sample" in sys.argv:
    from rendering.raymarching import render_image
    sdf_net, latent_codes = load_sdf_net(return_latent_codes=True)
    latent_codes_flattened = latent_codes.detach().reshape(-1).cpu().numpy()

    COUNT = 5

    mean, standard_deviation = np.mean(latent_codes_flattened), np.var(latent_codes_flattened) ** 0.5
    print("mean: ", mean)
    print("standard deviation: ", standard_deviation)
    distribution = torch.distributions.normal.Normal(mean, standard_deviation)
    codes = distribution.sample([COUNT, LATENT_CODE_SIZE]).to(device)

    plot = ImageGrid(COUNT, create_viewer=False)

    for i in range(COUNT):
        plot.set_image(render_image(sdf_net, codes[i, :], crop=True), i)

    plot.save("plots/deepsdf-samples.pdf")

if "hybrid_gan" in sys.argv:
    from rendering.raymarching import render_image
    from util import standard_normal_distribution
    generator = load_sdf_net(filename='hybrid_gan_generator.to')

    COUNT = 5

    codes = standard_normal_distribution.sample([COUNT, LATENT_CODE_SIZE]).to(device)

    plot = ImageGrid(COUNT, create_viewer=False)

    for i in range(COUNT):
        plot.set_image(render_image(generator, codes[i, :], radius=1.6, crop=True, sdf_offset=-0.045, vertical_cutoff=1), i)

    plot.save("plots/hybrid-gan-samples.pdf")


if "hybrid_gan_interpolation" in sys.argv:
    from rendering.raymarching import render_image_for_index, render_image
    from util import standard_normal_distribution
    import cv2
    sdf_net = load_sdf_net(filename='hybrid_gan_generator.to')

    OPTIONS = 10

    codes = standard_normal_distribution.sample([OPTIONS, LATENT_CODE_SIZE]).to(device)
    for i in range(OPTIONS):
        image = render_image(sdf_net, codes[i, :], resolution=200, radius=1.6, sdf_offset=-0.045, vertical_cutoff=1, crop=True)
        image.save('plots/option-{:d}.png'.format(i))

    STEPS = 6

    code_start = codes[int(input('Enter index for starting shape: ')), :]
    code_end = codes[int(input('Enter index for ending shape: ')), :]

    with torch.no_grad():
        codes = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device)
        for i in range(STEPS):
            codes[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1)

    plot = ImageGrid(STEPS, create_viewer=False)

    for i in range(STEPS):
        plot.set_image(render_image(sdf_net, codes[i, :], crop=True, radius=1.6, sdf_offset=-0.045, vertical_cutoff=1), i)

    plot.save("plots/hybrid-gan-interpolation.pdf")

if "hybrid_gan_upscaling" in sys.argv:
    from rendering.raymarching import render_image_for_index, render_image
    from util import standard_normal_distribution
    sdf_net = load_sdf_net(filename='hybrid_gan_generator.to')

    code = standard_normal_distribution.sample([LATENT_CODE_SIZE]).to(device)

    plot = ImageGrid(4)

    voxels_32 = sdf_net.get_voxels(code, 32, sphere_only=False)
    plot.set_voxels(voxels_32, 0)
    voxels_32 = voxels_32[1:-2, 1:-2, 1:-2]

    import scipy.ndimage
    voxels_upscaled = scipy.ndimage.zoom(voxels_32, 4)
    voxels_upscaled = np.pad(voxels_upscaled, 1, mode='constant', constant_values=1)
    plot.set_voxels(voxels_upscaled, 1)

    voxels_128 = sdf_net.get_voxels(code, 128, sphere_only=False)
    plot.set_voxels(voxels_128, 2)

    plot.set_image(render_image(sdf_net, code, radius=1.6, crop=True, vertical_cutoff=1, sdf_offset=-0.045), 3)

    plot.save("plots/hybrid-gan-upscaling.pdf")

if "shapenet-errors" in sys.argv:
    from PIL import Image
    from util import crop_image
    plot = ImageGrid(6, create_viewer=False)

    for i in range(6):
        image = Image.open('screenshots/errors/error-{:d}.png'.format(i+1))
        image = np.array(image)
        image = crop_image(image)
        plot.set_image(image, i)

    plot.save("plots/errors.pdf")

from model import CHECKPOINT_PATH

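# The two blocks below visualize training progress: they load evenly spaced
# training checkpoints and reconstruct the same shape with each of them.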
if 'vae_checkpoints' in sys.argv:
    COUNT = 5
    checkpoints = os.listdir(CHECKPOINT_PATH)
    checkpoints = [i for i in checkpoints if i.startswith('variational-autoencoder-64-')]
    checkpoints = sorted(checkpoints)
    checkpoints = checkpoints[:6]
    checkpoints = [checkpoints[i * (len(checkpoints) - 1) // (COUNT - 1)] for i in range(COUNT)]
    print('\n'.join(checkpoints))

    from dataset import dataset
    dataset.load_voxels(device=device)

    MODEL_INDEX = random.randint(0, dataset.voxels.shape[0] - 1)
    print(MODEL_INDEX)
    model = dataset.voxels[(MODEL_INDEX, MODEL_INDEX), :, :, :]  # a batch containing the same shape twice

    from model.autoencoder import Autoencoder
    vae = Autoencoder()
    vae.eval()

    plot = ImageGrid(COUNT)
    with torch.no_grad():
        for i in range(COUNT):
            vae.load_state_dict(torch.load(os.path.join(CHECKPOINT_PATH, checkpoints[i])))
            reconstructed, _, _ = vae(model)
            plot.set_voxels(reconstructed[0, :, :, :], i)

    plot.save('plots/vae-checkpoints.pdf')

if 'sdf_checkpoints' in sys.argv:
    from rendering.raymarching import render_image
    COUNT = 5
    checkpoints = os.listdir(CHECKPOINT_PATH)
    checkpoints_network = [i for i in checkpoints if i.startswith('sdf_net-epoch-')]
    checkpoints_latent_codes = [i for i in checkpoints if i.startswith('sdf_net_latent_codes-epoch-')]
    checkpoints_network = sorted(checkpoints_network)
    checkpoints_latent_codes = sorted(checkpoints_latent_codes)
    indices = [i * (len(checkpoints_network) - 1) // (COUNT - 1) for i in range(COUNT)]

    checkpoints_network = [checkpoints_network[i] for i in indices]
    checkpoints_latent_codes = [checkpoints_latent_codes[i] for i in indices]
    print('\n'.join(checkpoints_network))

    MODEL_INDEX = 1000
    print(MODEL_INDEX)

    from model.sdf_net import SDFNet
    sdf_net = SDFNet()
    sdf_net.eval()

    plot = ImageGrid(COUNT, create_viewer=False)
    for i in range(COUNT):
        sdf_net.load_state_dict(torch.load(os.path.join(CHECKPOINT_PATH, checkpoints_network[i])))
        latent_codes = torch.load(os.path.join(CHECKPOINT_PATH, checkpoints_latent_codes[i])).detach()
        latent_code = latent_codes[MODEL_INDEX, :]
        plot.set_image(render_image(sdf_net, latent_code, crop=True), i)

    plot.save('plots/deepsdf-checkpoints.pdf')


if "deepsdf-interpolation-stl" in sys.argv:
    from rendering.raymarching import render_image_for_index, render_image
    sdf_net, latent_codes = load_sdf_net(return_latent_codes=True)

    STEPS = 5

    indices = random.sample(list(range(latent_codes.shape[0])), 2)
    print(indices)
    code_start = latent_codes[indices[0], :]
    code_end = latent_codes[indices[1], :]

    print("Generating codes...")
    with torch.no_grad():
        codes = torch.zeros([STEPS, LATENT_CODE_SIZE], device=device)
        for i in range(STEPS):
            codes[i, :] = code_start * (1.0 - i / (STEPS - 1)) + code_end * i / (STEPS - 1)

    for i in range(STEPS):
        print(i)
        mesh = sdf_net.get_mesh(codes[i, :], voxel_resolution=256, sphere_only=False)
        mesh.export('plots/mesh-{:d}.stl'.format(i))
--------------------------------------------------------------------------------