├── models
│   ├── networks
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── losses.py
│   │   ├── networks.py
│   │   └── spectral_norm.py
│   ├── __init__.py
│   ├── latent_object_model.py
│   └── base_model.py
├── util
│   ├── visualizer
│   │   ├── __init__.py
│   │   ├── wandb_visualizer.py
│   │   ├── terminal_visualizer.py
│   │   └── base_visualizer.py
│   ├── __init__.py
│   └── util.py
├── assets
│   └── teaser.gif
├── .gitignore
├── options
│   ├── __init__.py
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── requirements.txt
├── configs
│   ├── eval.yaml
│   └── config.yaml
├── LICENSE
├── prepare_datasets.sh
├── data
│   ├── nocs_hdf5_dataset.py
│   ├── __init__.py
│   └── base_dataset.py
├── README.md
├── nocs
│   ├── eval.py
│   └── aligning.py
├── train.py
└── eval.py

/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/util/visualizer/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuchen-ethz/neural_object_fitting/HEAD/assets/teaser.gif
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | datasets
3 | results
4 | wandb
5 | .vscode
6 | 
7 | **.pyc
8 | **/__pycache__/
9 | /.idea/
10 | **.npy
11 | **.pkl
12 | **.csv
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and basic options (used in both training and test)."""
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml==5.4.1
2 | opencv-python==4.2.0.34
3 | envyaml==0.1912
4 | scipy==1.2.1
5 | h5py==2.8.0
6 | imageio==2.9.0
7 | matplotlib==3.4.1
8 | moviepy
9 | wandb
10 | tqdm
11 | 
12 | 
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | basic:
2 |   checkpoints_dir: ./checkpoints
3 |   dataroot: ./datasets/test
4 |   gpu_ids: '0'
5 |   project_name: neural_object_fitting
6 |   run_name: fitting
7 | fitting:
8 |   lambda_reg: 1
9 |   n_init: 32
10 |   n_iter: 50
11 | misc:
12 |   load_suffix: latest
13 |   verbose: false
14 |   visualizers:
15 |   - terminal
16 |   - wandb
17 | model:
18 |   init_gain: 0.02
19 |   init_type: normal
20 |   input_nc: 3
21 |   model: latent_object
22 |   output_nc: 3
23 | models:
24 |   batch_size_vis: 8
25 |   use_VAE: true
26 |   z_dim: 16
27 | test:
28 |   target_size: 64
29 |   num_agent: 1
30 |   id_agent: 0
31 |   results_dir: ./results
32 |   skip: 1
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | basic:
2 |   checkpoints_dir: ./checkpoints
3 |   dataroot: ./datasets/train
4 |   gpu_ids: '0'
5 |   project_name: neural_object_fitting
6 | data:
7 |   batch_size: 256
8 |   crop_size: 64
9 |   dataset_mode: nocs_hdf5
10 |   load_size: 64
11 |   max_dataset_size: .inf
12 |   no_flip: true
13 |   num_threads: 0
14 |   preprocess: resize_and_crop
15 |   serial_batches: false
16 | models:
17 |   batch_size_vis: 8
18 |   lambda_KL: 0.01
19 |   lambda_recon: 10.0
20 |   use_VAE: true
21 |   z_dim: 16
22 | log:
23 |   display_freq: 102400
24 |   print_freq: 1
25 | misc:
26 |   load_suffix: latest
27 |   verbose: false
28 |   visualizers:
29 |   - terminal
30 |   - wandb
31 | model:
32 |   init_gain: 0.02
33 |   init_type: normal
34 |   input_nc: 3
35 |   model: latent_object
36 |   output_nc: 3
37 | save:
38 |   epoch_count: 1
39 |   save_by_iter: false
40 |   save_epoch_freq: 5
41 |   save_latest_freq: 102400
42 | train:
43 |   lr: 0.003
44 |   lr_decay_iters: 50
45 |   lr_policy: linear
46 |   n_views: 2592
47 |   niter: 100
48 |   niter_decay: 100
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2020 ETH Zurich, Xu Chen, Zijian Dong
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/prepare_datasets.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | 
3 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/checkpoints.zip
4 | unzip checkpoints.zip
5 | rm checkpoints.zip
6 | 
7 | mkdir -p datasets/test
8 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/nocs_det.zip -P ./datasets/test/
9 | unzip ./datasets/test/nocs_det.zip -d ./datasets/test/
10 | rm ./datasets/test/nocs_det.zip
11 | 
12 | wget http://download.cs.stanford.edu/orion/nocs/real_test.zip -P ./datasets/test/
13 | unzip ./datasets/test/real_test.zip -d ./datasets/test/
14 | rm ./datasets/test/real_test.zip
15 | 
16 | mkdir -p datasets/train
17 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/bottle.hdf5 -P ./datasets/train/
18 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/bowl.hdf5 -P ./datasets/train/
19 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/camera.hdf5 -P ./datasets/train/
20 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/can.hdf5 -P ./datasets/train/
21 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/laptop.hdf5 -P ./datasets/train/
22 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/mug.hdf5 -P ./datasets/train/
23 | 
24 | 
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | 13 | test_args = parser.add_argument_group('test') 14 | 15 | test_args.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 16 | 17 | test_args.add_argument('--target_size', type=int, default=64, help='resize the test images to this size') 18 | test_args.add_argument('--vis', action='store_true', help='visualize the fitting results') 19 | 20 | test_args.add_argument('--num_agent', type=int, default=10, help='number of evaluation agents running in parallel') 21 | test_args.add_argument('--id_agent', type=int, default=0, help='the id of current agents') 22 | 23 | test_args.add_argument('--test_name', type=str, default='fitting', help='test name') 24 | test_args.add_argument('--skip', type=int, default=1, help='evaluate every n-th sample') 25 | 26 | # rewrite devalue values 27 | test_args.set_defaults(model='test') 28 | # To avoid cropping, the load_size should be the same as crop_size 29 | test_args.set_defaults(load_size=parser.get_default('crop_size')) 30 | 31 | self.isTrain = False 32 | return parser 33 | -------------------------------------------------------------------------------- /models/networks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def init_variable(dim, n_init, device, mode='random',range=[0,1],value=1): 5 | 6 | shape = (n_init,dim) 7 | var = torch.ones( shape,requires_grad=True,device=device,dtype=torch.float) 8 | if mode == 'random': 9 | var.data = torch.rand(shape,device=device) * (range[1]-range[0]) + range[0] 10 | elif mode == 'linspace': 11 | var.data = torch.linspace(range[0],range[1],steps=n_init,device=device).unsqueeze(-1) 12 | elif mode == 'constant': 13 | var.data =value*var.data 14 | else: 15 | raise NotImplementedError 16 | return var 17 | 18 | def grid_sample(image, grid, mode='bilinear',padding_mode='constant',padding_value=1): 19 | image_out = F.grid_sample(image, grid, mode=mode, padding_mode='border') 20 | if padding_mode == 'constant': 21 | out_of_bound = grid[:, :, :, 0] > 1 22 | out_of_bound += grid[:, :, :, 0] < -1 23 | out_of_bound += grid[:, :, :, 1] > 1 24 | out_of_bound += grid[:, :, :, 1] < -1 25 | out_of_bound = out_of_bound.unsqueeze(1).expand(image_out.shape) 26 | image_out[out_of_bound] = padding_value 27 | return image_out 28 | 29 | def warping_grid(angle, transx, transy, scale, image_shape): 30 | cosz = torch.cos(angle) 31 | sinz = torch.sin(angle) 32 | affine_mat = torch.cat( [cosz, -sinz, transx, 33 | sinz, cosz, transy], dim=1).view(image_shape[0], 2, 3) 34 | scale = scale.view(-1,1,1).expand(affine_mat.shape) 35 | return F.affine_grid(size=image_shape, theta=scale*affine_mat) 36 | 37 | def set_axis(ax): 38 | ax.clear() 39 | ax.xaxis.set_visible(False) 40 | ax.spines['right'].set_visible(False) 41 | ax.spines['top'].set_visible(False) 42 | ax.grid(axis='y') -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visualization parameters 13 | log_args = parser.add_argument_group('log') 14 | log_args.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen') 15 | log_args.add_argument('--print_freq', type=int, default=1, help='frequency of showing training results on console') 16 | # network saving and loading parameters 17 | save_args = parser.add_argument_group('save') 18 | save_args.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 19 | save_args.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 20 | save_args.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 21 | save_args.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 22 | # training parameters 23 | train_args = parser.add_argument_group('train') 24 | train_args.add_argument('--niter', type=int, default=15, help='# of iter at starting learning rate') 25 | train_args.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') 26 | train_args.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 27 | train_args.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') 28 | train_args.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 29 | train_args.add_argument('--n_views', type=int, default=2592, help='number of training views per sample') 30 | self.isTrain = True 31 | return parser 32 | -------------------------------------------------------------------------------- /util/visualizer/wandb_visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import wandb 5 | 6 | from util import util 7 | from util.visualizer.base_visualizer import BaseVisualizer 8 | 9 | 10 | class WandbVisualizer(BaseVisualizer): 11 | """This class includes several functions that can display/save images and print/save logging information. 12 | 13 | It uses a Python library 'wandb' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 
14 | """ 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser): 18 | return parser 19 | 20 | def __init__(self, opt): 21 | self.opt = opt # cache the option 22 | config_file = os.path.join(opt.checkpoints_dir,opt.project_name, opt.exp_name, opt.run_name, 'config.yaml') 23 | import yaml 24 | with open(config_file, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) 25 | wandb.init(project=opt.project_name, name=opt.run_name, group=opt.exp_name, config=config) 26 | 27 | def update_state(self,epochs,iters,times): 28 | self.epochs = epochs 29 | self.iters = iters 30 | self.times = times 31 | 32 | def display_current_results(self, visuals): 33 | from torchvision.utils import make_grid 34 | visual_wandb = {} 35 | for key, image_tensor in visuals.items(): 36 | visual_wandb[key] = wandb.Image(util.tensor2im(make_grid(image_tensor).unsqueeze(0))) 37 | wandb.log(visual_wandb,step=self.iters) 38 | 39 | def display_current_videos(self, visuals): 40 | from torchvision.utils import make_grid 41 | video_wandb = {} 42 | for label, visual in visuals.items(): 43 | frames = [] 44 | for frame in visual: 45 | image = util.tensor2im(make_grid(frame).unsqueeze(0)) 46 | frames.append(image) 47 | gif = np.stack(frames, axis=0) 48 | gif = np.transpose(gif, (0, 3, 1, 2)) 49 | 50 | video_wandb[label] = wandb.Video(gif, fps=20) 51 | print('hello') 52 | 53 | wandb.log(video_wandb,step=self.iters) 54 | 55 | def plot_current_losses(self, losses): 56 | wandb.log(losses,step=self.iters) -------------------------------------------------------------------------------- /data/nocs_hdf5_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from data.base_dataset import BaseDataset, get_transform 3 | import random 4 | import numpy as np 5 | import h5py 6 | from PIL import Image 7 | import os 8 | class NOCSHDF5Dataset(BaseDataset): 9 | 10 | 11 | def __init__(self, opt): 12 | """Initialize this dataset class. 13 | 14 | Parameters: 15 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 16 | """ 17 | BaseDataset.__init__(self, opt) 18 | 19 | input_nc = self.opt.output_nc 20 | self.transform = get_transform(opt, grayscale=(input_nc == 1),method=Image.BILINEAR) 21 | 22 | hdf5_file = h5py.File(os.path.join(opt.dataroot,opt.category+'.hdf5'),'r',swmr=True) 23 | self.images = hdf5_file['images'] 24 | self.poses = hdf5_file['poses'][...] 25 | 26 | self.dataset_size = self.poses.shape[0] 27 | self.num_view = opt.n_views 28 | self.num_model = self.dataset_size // self.num_view 29 | 30 | def __getitem__(self, index): 31 | """Return a data point and its metadata information. 
32 | 
33 |         Parameters:
34 |             index - - a random integer for data indexing
35 | 
36 |         Returns a dictionary with two randomly drawn views of the same object model:
37 |             A, B (tensor) - - two rendered views of one object model
38 |             A_pose, B_pose (array) - - the (elevation, azimuth, 0) pose of each view
39 |         """
40 |         model_id = random.randint(0, self.num_model-1)
41 | 
42 |         image_id = random.randint(0, self.num_view-1)
43 |         id = model_id*self.num_view + image_id
44 |         A_img = np.copy(self.images[id, :, :, :])
45 |         elev, azi = np.copy(self.poses[id, :])
46 |         if A_img.shape[2] == 4:
47 |             A_mask = A_img[:, :, -1] == 0
48 |             A_img[A_mask, :3] = 255
49 |             A_img = A_img[:, :, :3]
50 | 
51 |         A = self.transform(Image.fromarray(A_img.astype(np.uint8)))
52 |         A_pose = np.array([elev, azi, 0]).astype(np.float32)
53 | 
54 |         image_id = random.randint(0, self.num_view-1)
55 |         id = model_id*self.num_view + image_id
56 |         B_img = np.copy(self.images[id, :, :, :])
57 |         elev, azi = np.copy(self.poses[id, :])
58 |         if B_img.shape[2] == 4:
59 |             B_mask = B_img[:, :, -1] == 0
60 |             B_img[B_mask, :3] = 255
61 |             B_img = B_img[:, :, :3]
62 | 
63 |         B = self.transform(Image.fromarray(B_img.astype(np.uint8)))
64 |         B_pose = np.array([elev, azi, 0]).astype(np.float32)
65 | 
66 |         return {'A': A, 'A_pose': A_pose,
67 |                 'B': B, 'B_pose': B_pose,}
68 |     def __len__(self):
69 |         """Return the total number of images in the dataset."""
70 |         return self.dataset_size
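71 | 
72 | # Expected HDF5 layout, as inferred from __init__ above (a sketch, not an authoritative spec):
73 | #   images: uint8 array of shape (num_model * n_views, H, W, 3 or 4); an alpha channel, if present, marks the background
74 | #   poses:  float array of shape (num_model * n_views, 2), storing (elevation, azimuth) per rendered view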
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | 
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
4 | You need to implement the following five functions:
5 |     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 | 
11 | In the function <__init__>, you need to define four lists:
12 |     -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 |     -- self.model_names (str list): define networks used in our training.
14 |     -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16 | 
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 | 
21 | import importlib
22 | 
23 | from models.base_model import BaseModel
24 | 
25 | 
26 | def find_model_using_name(model_name):
27 |     """Import the module "models/[model_name]_model.py".
28 | 
29 |     In the file, the class called DatasetNameModel() will
30 |     be instantiated. It has to be a subclass of BaseModel,
31 |     and it is case-insensitive.
32 |     """
33 |     model_filename = "models." + model_name + "_model"
34 |     modellib = importlib.import_module(model_filename)
35 |     model = None
36 |     target_model_name = model_name.replace('_', '') + 'model'
37 |     for name, cls in modellib.__dict__.items():
38 |         if name.lower() == target_model_name.lower() \
39 |            and issubclass(cls, BaseModel):
40 |             model = cls
41 | 
42 |     if model is None:
43 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
44 |         exit(0)
45 | 
46 |     return model
47 | 
48 | 
49 | def get_option_setter(model_name):
50 |     """Return the static method <modify_commandline_options> of the model class."""
51 |     model_class = find_model_using_name(model_name)
52 |     return model_class.modify_commandline_options
53 | 
54 | 
55 | def create_model(opt):
56 |     """Create a model given the option.
57 | 
58 |     This function wraps the model class.
59 |     This is the main interface between this package and 'train.py'/'test.py'
60 | 
61 |     Example:
62 |         >>> from models import create_model
63 |         >>> model = create_model(opt)
64 |     """
65 |     model = find_model_using_name(opt.model)
66 |     instance = model(opt)
67 |     print("model [%s] was created" % type(instance).__name__)
68 |     return instance
69 | 
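70 | 
71 | # Minimal skeleton of such a subclass, following the package docstring above
72 | # (an illustrative sketch of a hypothetical 'dummy' model; method bodies are placeholders):
73 | #
74 | #     from models.base_model import BaseModel
75 | #
76 | #     class DummyModel(BaseModel):
77 | #         def __init__(self, opt):
78 | #             BaseModel.__init__(self, opt)
79 | #             self.loss_names, self.model_names, self.visual_names, self.optimizers = [], [], [], []
80 | #         def set_input(self, input):
81 | #             pass  # unpack data from the dataloader and apply preprocessing
82 | #         def forward(self):
83 | #             pass  # produce intermediate results
84 | #         def optimize_parameters(self):
85 | #             pass  # compute losses, gradients, and update network weights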
--------------------------------------------------------------------------------
/util/visualizer/terminal_visualizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | from util import util
5 | from util.visualizer.base_visualizer import BaseVisualizer
6 | 
7 | 
8 | class TerminalVisualizer(BaseVisualizer):
9 |     """This class saves training results to disk: images to a web directory and losses to a text log file.
10 |     """
11 | 
12 |     @staticmethod
13 |     def modify_commandline_options(parser):
14 |         return parser
15 | 
16 |     def __init__(self, opt):
17 |         """Initialize the Visualizer class
18 |         """
19 |         self.opt = opt  # cache the option
20 |         self.name = opt.exp_name
21 |         self.win_size = opt.crop_size
22 |         self.epoch = -1
23 |         self.web_dir = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name, 'web')
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         print('create web directory %s...' % self.web_dir)
26 |         util.mkdirs([self.web_dir, self.img_dir])
27 | 
28 |         # create a logging file to store training losses
29 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name, 'loss_log.txt')
30 |         with open(self.log_name, "a") as log_file:
31 |             now = time.strftime("%c")
32 |             log_file.write('================ Training Loss (%s) ================\n' % now)
33 | 
34 |     def update_state(self, epochs, iters, times):
35 |         self.epochs = epochs
36 |         self.iters = iters
37 |         self.times = times
38 | 
39 |     def display_current_results(self, visuals):
40 | 
41 |         # save images to the disk
42 |         for label, image in visuals.items():
43 |             image_numpy = util.tensor2im(image)
44 |             img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (self.epochs, label))
45 |             util.save_image(image_numpy, img_path)
46 | 
47 |         # # update website
48 |         # webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
49 |         # for n in range(self.epochs, 0, -1):
50 |         #     webpage.add_header('epoch [%d]' % n)
51 |         #     ims, txts, links = [], [], []
52 | 
53 |         #     for label, image_numpy in visuals.items():
54 |         #         image_numpy = util.tensor2im(image)
55 |         #         img_path = 'epoch%.3d_%s.png' % (n, label)
56 |         #         ims.append(img_path)
57 |         #         txts.append(label)
58 |         #         links.append(img_path)
59 |         #     webpage.add_images(ims, txts, links, width=self.win_size)
60 |         # webpage.save()
61 | 
62 |     def display_current_videos(self, visuals):
63 |         import imageio
64 |         from torchvision.utils import make_grid
65 |         for label, visual in visuals.items():
66 |             frames = []
67 |             path = os.path.join(self.web_dir, 'epoch%.3d_%s.gif' % (self.epochs, label))
68 |             for frame in visual:
69 |                 image = util.tensor2im(make_grid(frame).unsqueeze(0))
70 |                 frames.append(image)
71 |             imageio.mimsave(path, frames)
72 | 
73 |     # losses: same format as |losses| of plot_current_losses
74 |     def plot_current_losses(self, losses):
75 | 
76 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (self.epochs, self.iters, self.times['comp'], self.times['data'])
77 |         for k, v in losses.items():
78 |             message += '%s: %.3f ' % (k, v)
79 | 
80 |         print(message)  # print the message
81 |         with open(self.log_name, "a") as log_file:
82 |             log_file.write('%s\n' % message)  # save the message
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Category Level Object Pose Estimation via Neural Analysis-by-Synthesis
2 | [Paper](https://arxiv.org/pdf/2008.08145.pdf) | [Project Page](https://ait.ethz.ch/projects/2020/neural-object-fitting/)
3 | 
4 | 
5 | 
6 | This repository contains the code for the paper [Category Level Object Pose Estimation via Neural Analysis-by-Synthesis](https://arxiv.org/pdf/2008.08145.pdf).
7 | 
8 | # Installation
9 | This code has been tested on Ubuntu 18.04 with Python 3.7.7.
10 | ```
11 | conda create -n neural_object_fitting python=3.7
12 | conda activate neural_object_fitting
13 | conda install pytorch=1.1.0 torchvision=0.3.0 cudatoolkit=10.0 -c pytorch
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | # Preparation
18 | Run the following command to download the datasets and checkpoints. This script downloads the real test dataset from [NOCS](https://github.com/hughw19/NOCS_CVPR2019#datasets), our synthetic training set, and our pre-trained models. Our method uses 2D detections and segmentations predicted by NOCS, so this script also downloads those pre-computed results. You can also obtain these predictions by running NOCS following the instructions in the [original repo](https://github.com/hughw19/NOCS_CVPR2019).
19 | 
20 | ```
21 | sh prepare_datasets.sh
22 | ```
23 | 
24 | # Evaluation
25 | ## Estimate Pose
26 | Run the following command to estimate poses.
27 | 
28 | ```
29 | python eval.py --config configs/eval.yaml
30 | ```
31 | 
32 | Note: a full evaluation can take ~2 days on a single 1080Ti GPU. You can split the evaluation into multiple batches which can then be run in parallel on a cluster. For example, the following commands split the job into two parts, with each agent handling one.
33 | ```
34 | python eval.py --config configs/eval.yaml --num_agent 2 --id_agent 0
35 | python eval.py --config configs/eval.yaml --num_agent 2 --id_agent 1
36 | ```
37 | 
38 | You can also run the evaluation on a subset of the data first for a sanity check. This can be done with the `--skip` option, e.g. with `--skip 10` only every 10th sample is evaluated.
39 | ```
40 | python eval.py --config configs/eval.yaml --skip 10
41 | ```
42 | 
43 | To visualize the fitting procedure, add `--vis` to the command.
44 | ```
45 | python eval.py --config configs/eval.yaml --vis
46 | ```
47 | 
48 | ## Compute Score
49 | Run the following command to evaluate the estimated poses and draw the plots.
50 | ```
51 | python nocs/eval.py
52 | ```
53 | 
54 | This should produce results similar to [this one](https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/results.pkl). Note that there might be a small variance due to the randomness of initialization.
55 | 
56 | # Training
57 | Run the following command to train the model for the laptop category. Modify `--category` to train models for other categories. Each of the provided checkpoints was trained for 24h on a single 1080Ti GPU.
58 | 
59 | ```
60 | python train.py --config configs/config.yaml --category laptop
61 | ```
62 | 
63 | Note: we use [Weights & Biases](https://wandb.ai/site) for experiment logging, which requires (free) registration. To avoid registration, you can deactivate Weights & Biases by removing `wandb` from `visualizers` in `configs/config.yaml`.
64 | 
65 | 
66 | # Citation
67 | If you find this repository useful, please consider citing our paper.
68 | ```
69 | @inproceedings{chen2020category,
70 |   title={Category Level Object Pose Estimation via Neural Analysis-by-Synthesis},
71 |   author={Chen, Xu and Dong, Zijian and Song, Jie and Geiger, Andreas and Hilliges, Otmar},
72 |   year={2020},
73 |   booktitle={European Conference on Computer Vision (ECCV)},
74 | }
75 | ```
76 | 
77 | # Acknowledgement
78 | This code is based on [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git).
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | 
4 | import os
5 | 
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | 
10 | 
11 | def tensor2im(input_image, imtype=np.uint8):
12 |     """Converts a Tensor array into a numpy image array.
13 | 
14 |     Parameters:
15 |         input_image (tensor) -- the input image tensor array
16 |         imtype (type) -- the desired type of the converted numpy array
17 |     """
18 |     if not isinstance(input_image, np.ndarray):
19 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
20 |             image_tensor = input_image.data
21 |         else:
22 |             return input_image
23 |         image_numpy = image_tensor[0].clamp(-1, 1).cpu().float().numpy()  # convert it into a numpy array
24 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
25 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
27 |     else:  # if it is a numpy array, do nothing
28 |         image_numpy = input_image
29 |     return image_numpy.astype(imtype)
30 | 
31 | 
32 | def diagnose_network(net, name='network'):
33 |     """Calculate and print the mean of average absolute(gradients)
34 | 
35 |     Parameters:
36 |         net (torch network) -- Torch network
37 |         name (str) -- the name of the network
38 |     """
39 |     mean = 0.0
40 |     count = 0
41 |     for param in net.parameters():
42 |         if param.grad is not None:
43 |             mean += torch.mean(torch.abs(param.grad.data))
44 |             count += 1
45 |     if count > 0:
46 |         mean = mean / count
47 |     print(name)
48 |     print(mean)
49 | 
50 | 
51 | def save_image(image_numpy, image_path):
52 |     """Save a numpy image to the disk
53 | 
54 |     Parameters:
55 |         image_numpy (numpy array) -- input numpy array
56 |         image_path (str) -- the path of the image
57 |     """
58 |     image_pil = Image.fromarray(image_numpy)
59 |     image_pil.save(image_path)
60 | 
61 | 
62 | def print_numpy(x, val=True, shp=False):
63 |     """Print the mean, min, max, median, std, and size of a numpy array
64 | 
65 |     Parameters:
66 |         val (bool) -- if print the values of the numpy array
67 |         shp (bool) -- if print the shape of the numpy array
68 |     """
69 |     x = x.astype(np.float64)
70 |     if shp:
71 |         print('shape,', x.shape)
72 |     if val:
73 |         x = x.flatten()
74 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
75 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
76 | 
77 | 
78 | def mkdirs(paths):
79 |     """create empty directories if they don't exist
80 | 
81 |     Parameters:
82 |         paths (str list) -- a list of directory paths
83 |     """
84 |     if isinstance(paths, list) and not isinstance(paths, str):
85 |         for path in paths:
86 |             mkdir(path)
87 |     else:
88 |         mkdir(paths)
89 | 
90 | 
91 | def mkdir(path):
92 |     """create a single empty directory if it didn't exist
93 | 
94 |     Parameters:
95 |         path (str) -- a single directory path
96 |     """
97 |     if not os.path.exists(path):
98 |         os.makedirs(path)
99 | 
100 | def sort_str_by_num(l):
101 |     """ Sort the given list in the way that humans expect.
102 |     """
103 | 
104 |     def alphanum_key(s):
105 |         """ Turn a string into a list of string and number chunks.
106 |             "z23a" -> ["z", 23, "a"]
107 |         """
108 |         def tryint(s):
109 |             try: return int(s)
110 |             except ValueError: return s
111 |         import re
112 |         return [tryint(c) for c in re.split('([0-9]+)', s)]
113 | 
114 |     l.sort(key=alphanum_key)
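115 | 
116 | # e.g. sort_str_by_num(['epoch10', 'epoch2', 'epoch1']) sorts the list in place to ['epoch1', 'epoch2', 'epoch10']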
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 | 
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' that inherits from BaseDataset.
4 | You need to implement four functions:
5 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 |     -- <__len__>: return the size of dataset.
7 |     -- <__getitem__>: get a data point from data loader.
8 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 | 
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | 
15 | import torch.utils.data
16 | 
17 | from data.base_dataset import BaseDataset
18 | 
19 | 
20 | def find_dataset_using_name(dataset_name):
21 |     """Import the module "data/[dataset_name]_dataset.py".
22 | 
23 |     In the file, the class called DatasetNameDataset() will
24 |     be instantiated. It has to be a subclass of BaseDataset,
25 |     and it is case-insensitive.
26 |     """
27 |     dataset_filename = "data." + dataset_name + "_dataset"
28 |     datasetlib = importlib.import_module(dataset_filename)
29 | 
30 |     dataset = None
31 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
32 |     for name, cls in datasetlib.__dict__.items():
33 |         if name.lower() == target_dataset_name.lower() \
34 |            and issubclass(cls, BaseDataset):
35 |             dataset = cls
36 | 
37 |     if dataset is None:
38 |         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
39 | 
40 |     return dataset
41 | 
42 | 
43 | def get_option_setter(dataset_name):
44 |     """Return the static method <modify_commandline_options> of the dataset class."""
45 |     dataset_class = find_dataset_using_name(dataset_name)
46 |     return dataset_class.modify_commandline_options
47 | 
48 | 
49 | def create_dataset(opt):
50 |     """Create a dataset given the option.
51 | 
52 |     This function wraps the class CustomDatasetDataLoader.
53 |     This is the main interface between this package and 'train.py'/'test.py'
54 | 
55 |     Example:
56 |         >>> from data import create_dataset
57 |         >>> dataset = create_dataset(opt)
58 |     """
59 |     data_loader = CustomDatasetDataLoader(opt)
60 |     dataset = data_loader.load_data()
61 |     return dataset
62 | 
63 | 
64 | class CustomDatasetDataLoader():
65 |     """Wrapper class of Dataset class that performs multi-threaded data loading"""
66 | 
67 |     def __init__(self, opt):
68 |         """Initialize this class
69 | 
70 |         Step 1: create a dataset instance given the name [dataset_mode]
71 |         Step 2: create a multi-threaded data loader.
72 | """ 73 | self.opt = opt 74 | dataset_class = find_dataset_using_name(opt.dataset_mode) 75 | self.dataset = dataset_class(opt) 76 | print("dataset [%s] was created with [%d] samples" % (type(self.dataset).__name__, len(self.dataset))) 77 | self.dataloader = torch.utils.data.DataLoader( 78 | self.dataset, 79 | drop_last=True, 80 | batch_size=opt.batch_size, 81 | shuffle=not opt.serial_batches, 82 | num_workers=int(opt.num_threads), 83 | pin_memory=True) 84 | 85 | def load_data(self): 86 | return self 87 | 88 | def __len__(self): 89 | """Return the number of data in the dataset""" 90 | return min(len(self.dataset), self.opt.max_dataset_size) 91 | 92 | def __iter__(self): 93 | """Return a batch of data""" 94 | for i, data in enumerate(self.dataloader): 95 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 96 | break 97 | yield data 98 | -------------------------------------------------------------------------------- /nocs/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Normalized Object Coordinate Space for Category-Level 6D Object Pose and Size Estimation 3 | Detection and evaluation 4 | 5 | Modified based on Mask R-CNN(https://github.com/matterport/Mask_RCNN) 6 | Written by He Wang 7 | """ 8 | 9 | import os 10 | import argparse 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--num_eval', type=int, default=-1) 13 | parser.add_argument('--result_path', type=str, default='./results/neural_object_fitting/fitting') 14 | 15 | args = parser.parse_args() 16 | 17 | num_eval = args.num_eval 18 | 19 | import glob 20 | import numpy as np 21 | import utils as utils 22 | import _pickle as cPickle 23 | import matplotlib as mpl 24 | mpl.use('Agg') 25 | 26 | if __name__ == '__main__': 27 | 28 | 29 | # real classes 30 | coco_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 31 | 'bus', 'train', 'truck', 'boat', 'traffic light', 32 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 33 | 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 34 | 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 35 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 36 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 37 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 38 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 39 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 40 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 41 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 42 | 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 43 | 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 44 | 'teddy bear', 'hair drier', 'toothbrush'] 45 | 46 | 47 | synset_names = ['BG', #0 48 | 'bottle', #1 49 | 'bowl', #2 50 | 'camera', #3 51 | 'can', #4 52 | 'laptop',#5 53 | 'mug'#6 54 | ] 55 | 56 | class_map = { 57 | 'bottle': 'bottle', 58 | 'bowl':'bowl', 59 | 'cup':'mug', 60 | 'laptop': 'laptop', 61 | } 62 | 63 | 64 | coco_cls_ids = [] 65 | for coco_cls in class_map: 66 | ind = coco_names.index(coco_cls) 67 | coco_cls_ids.append(ind) 68 | 69 | result_pkl_list = glob.glob(os.path.join(args.result_path, 'results_*.pkl')) 70 | result_pkl_list = sorted(result_pkl_list)[:num_eval] 71 | assert len(result_pkl_list) 72 | 73 | final_results = [] 74 | for pkl_path in result_pkl_list: 75 | with open(pkl_path, 'rb') as f: 76 | result = cPickle.load(f) 77 | if not 'gt_handle_visibility' in result: 78 | 
78 |                 result['gt_handle_visibility'] = np.ones_like(result['gt_class_ids'])
79 |                 print('can\'t find gt_handle_visibility in the pkl.')
80 |             else:
81 |                 assert len(result['gt_handle_visibility']) == len(result['gt_class_ids']), "{} {}".format(result['gt_handle_visibility'], result['gt_class_ids'])
82 | 
83 | 
84 |         if isinstance(result, list):
85 |             final_results += result
86 |         elif isinstance(result, dict):
87 |             final_results.append(result)
88 |         else:
89 |             assert False
90 | 
91 |     aps = utils.compute_degree_cm_mAP(final_results, synset_names, args.result_path,
92 |                                       degree_thresholds=range(0, 61, 1),
93 |                                       shift_thresholds=np.linspace(0, 1, 31)*15,
94 |                                       iou_3d_thresholds=np.linspace(0, 1, 101),
95 |                                       iou_pose_thres=0.1,
96 |                                       use_matches_for_pose=True)
97 | 
98 | 
--------------------------------------------------------------------------------
/models/networks/losses.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | import torchvision.models as models
7 | from math import exp
8 | 
9 | class PerceptualLoss(nn.Module):
10 | 
11 |     def __init__(self, type='l2', reduce=True, final_layer=14):
12 |         super(PerceptualLoss, self).__init__()
13 |         self.model = self.contentFunc(final_layer=final_layer)
14 |         self.model.eval()
15 |         self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
16 |         self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
17 |         self.type = type
18 |         if type == 'l1':
19 |             self.criterion = torch.nn.L1Loss(reduce=reduce)
20 |         elif type == 'l2':
21 |             self.criterion = torch.nn.MSELoss(reduce=reduce)
22 |         elif type == 'both':
23 |             self.criterion1 = torch.nn.L1Loss(reduce=reduce)
24 |             self.criterion2 = torch.nn.MSELoss(reduce=reduce)
25 |         else:
26 |             raise NotImplementedError
27 | 
28 |     def normalize(self, tensor):
29 |         tensor = (tensor+1)*0.5
30 |         tensor_norm = (tensor-self.mean.expand(tensor.shape))/self.std.expand(tensor.shape)
31 |         return tensor_norm
32 | 
33 |     def contentFunc(self, final_layer=14):
34 |         cnn = models.vgg19(pretrained=True).features
35 |         cnn = cnn.cuda()
36 |         model = nn.Sequential()
37 |         model = model.cuda()
38 |         for i, layer in enumerate(list(cnn)):
39 |             model.add_module(str(i), layer)
40 |             if i == final_layer:
41 |                 break
42 |         return model
43 | 
44 |     def forward(self, fakeIm, realIm):
45 |         f_fake = self.model.forward(self.normalize(fakeIm))
46 |         f_real = self.model.forward(self.normalize(realIm))
47 |         if self.type == 'both':
48 |             loss = self.criterion1(f_fake, f_real.detach()) + self.criterion2(f_fake, f_real.detach())
49 |         else:
50 |             loss = self.criterion(f_fake, f_real.detach())
51 |         return loss
52 | 
53 | def gaussian(window_size, sigma):
54 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
55 |     return gauss / gauss.sum()
56 | 
57 | 
58 | def create_window(window_size, channel):
59 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
60 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
61 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
62 |     return window
63 | 
64 | 
65 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
66 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
67 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
68 | 
69 |     mu1_sq = mu1.pow(2)
70 |     mu2_sq = mu2.pow(2)
71 |     mu1_mu2 = mu1 * mu2
72 | 
73 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
74 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
75 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
76 | 
77 |     C1 = 0.01 ** 2
78 |     C2 = 0.03 ** 2
79 | 
80 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
81 | 
82 |     if size_average:
83 |         return ssim_map.mean()
84 |     else:
85 |         return ssim_map.mean(1).mean(1).mean(1)
86 | 
87 | 
88 | class SSIM(torch.nn.Module):
89 |     def __init__(self, window_size=11, reduce=True, negative=False):
90 |         super(SSIM, self).__init__()
91 |         self.window_size = window_size
92 |         self.reduce = reduce
93 |         self.channel = 1
94 |         self.window = create_window(window_size, self.channel)
95 |         self.negative = negative
96 | 
97 |     def forward(self, img1, img2):
98 |         (_, channel, _, _) = img1.size()
99 | 
100 |         if channel == self.channel and self.window.data.type() == img1.data.type():
101 |             window = self.window
102 |         else:
103 |             window = create_window(self.window_size, channel)
104 | 
105 |             if img1.is_cuda:
106 |                 window = window.cuda(img1.get_device())
107 |             window = window.type_as(img1)
108 | 
109 |             self.window = window
110 |             self.channel = channel
111 |         if self.negative:
112 |             return -_ssim(img1, img2, window, self.window_size, channel, self.reduce)
113 |         else:
114 |             return _ssim(img1, img2, window, self.window_size, channel, self.reduce)
115 | 
116 | 
117 | def ssim(img1, img2, window_size=11, reduce=True):
118 |     (_, channel, _, _) = img1.size()
119 |     window = create_window(window_size, channel)
120 | 
121 |     if img1.is_cuda:
122 |         window = window.cuda(img1.get_device())
123 |     window = window.type_as(img1)
124 | 
125 |     return _ssim(img1, img2, window, window_size, channel, reduce)
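126 | 
127 | # Example usage (illustrative): SSIM is a similarity score (higher is better), so for
128 | # use as a minimization objective instantiate it with negative=True:
129 | #   criterion = SSIM(negative=True)
130 | #   loss = criterion(fake, real)  # fake, real: (N, C, H, W) tensors in the same value range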
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 | 
3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
6 | 
7 | It first creates model, dataset, and visualizer given the option.
8 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the losses, and saves models.
9 | 
10 | Example:
11 |     Train a CycleGAN model:
12 |         python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
13 |     Train a pix2pix model:
14 |         python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
15 | 
16 | See options/base_options.py and options/train_options.py for more training options.
17 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
18 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
19 | """
20 | import time
21 | from options.train_options import TrainOptions
22 | from data import create_dataset
23 | from models import create_model
24 | from util.visualizer.base_visualizer import BaseVisualizer as Visualizer
25 | import copy
26 | 
27 | 
28 | if __name__ == '__main__':
29 |     opt = TrainOptions().parse()  # get training options
30 |     print('------------- Creating Dataset ----------------')
31 |     dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
32 |     opt.dataset_size = len(dataset)  # get the number of images in the dataset.
33 |     print('-----------------------------------------------\n')
34 | 
35 |     print('-------------- Creating Model -----------------')
36 |     model = create_model(opt)  # create a model given opt.model and other options
37 |     iter_start = model.setup(opt)  # regular setup: load and print networks; create schedulers
38 |     print('train from [Iter %d]' % (iter_start))
39 |     print('-----------------------------------------------\n')
40 | 
41 |     print('------------ Creating Visualizer --------------')
42 |     visualizer = Visualizer(opt)  # create a visualizer that displays/saves images and plots
43 |     print('-----------------------------------------------\n')
44 | 
45 | 
46 |     print('--------------- Start Training -----------------')
47 |     total_iters = 0  # the total number of training iterations
48 |     for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
49 |         epoch_start_time = time.time()  # timer for entire epoch
50 |         iter_data_time = time.time()  # timer for data loading per iteration
51 |         epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
52 |         model.update_learning_rate()  # update learning rates at the beginning of every epoch. TODO: moved from the end to the beginning of the epoch; check the consequences
53 | 
54 |         for i, data in enumerate(dataset):  # inner loop within one epoch
55 | 
56 |             total_iters += opt.batch_size
57 |             epoch_iter += opt.batch_size
58 | 
59 |             if total_iters < iter_start: continue  # skip until the starting iteration
60 | 
61 |             times = {}  # recording computation time
62 |             iter_start_time = time.time()  # timer for computation per iteration
63 |             times['data'] = iter_start_time - iter_data_time
64 | 
65 |             model.set_input(data)  # unpack data from dataset and apply preprocessing
66 |             model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
67 |             times['comp'] = time.time() - iter_start_time
68 | 
69 |             if total_iters % opt.display_freq == 0:  # display images and save them to disk
70 |                 model.compute_visuals()
71 |                 visualizer.update_state(epoch, total_iters, times=times)
72 |                 visualizer.display_current_results(model.get_current_visuals())
73 |                 visualizer.display_current_videos(model.get_current_videos())
74 | 
75 |             if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
76 |                 losses = model.get_current_losses()
77 |                 visualizer.update_state(epoch, total_iters, times=times)
78 |                 visualizer.plot_current_losses(losses)
79 | 
80 |             if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
81 |                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
82 |                 model.save('latest', iter=total_iters)
83 | 
84 |             iter_data_time = time.time()
85 | 
86 |         if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
87 |             print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
88 |             model.save("%d" % total_iters, iter=total_iters)
89 | 
90 |     print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
--------------------------------------------------------------------------------
/util/visualizer/base_visualizer.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | 
4 | import numpy as np
5 | 
6 | from util import util
7 | 
8 | from cv2 import INTER_CUBIC, resize  # INTER_CUBIC is needed by save_images below
9 | 
10 | class BaseVisualizer():
11 |     @staticmethod
12 |     def modify_commandline_options(parser):
13 |         opt, _ = parser.parse_known_args()
14 | 
15 |         for vis_name in opt.visualizers:
16 |             vis_filename = "util.visualizer." + vis_name + "_visualizer"
17 |             vislib = importlib.import_module(vis_filename)
18 |             vis = None
19 |             target_vis_name = vis_name + 'visualizer'
20 |             for name, cls in vislib.__dict__.items():
21 |                 if name.lower() == target_vis_name.lower() \
22 |                    and issubclass(cls, BaseVisualizer):
23 |                     vis = cls
24 | 
25 |             if vis is None:
26 |                 print(
27 |                     "In %s.py, there should be a subclass of BaseVisualizer with class name that matches %s in lowercase." % (
28 |                         vis_filename, target_vis_name))
29 |                 exit(0)
30 | 
31 |             parser = vis.modify_commandline_options(parser)
32 | 
33 |         return parser
34 | 
35 |     def __init__(self, opt):
36 |         self.visualizer_list = []
37 | 
38 |         for vis_name in opt.visualizers:
39 |             vis_filename = "util.visualizer." + vis_name + "_visualizer"
40 |             vislib = importlib.import_module(vis_filename)
41 |             vis = None
42 |             target_vis_name = vis_name + 'visualizer'
43 |             for name, cls in vislib.__dict__.items():
44 |                 if name.lower() == target_vis_name.lower() \
45 |                    and issubclass(cls, BaseVisualizer):
46 |                     vis = cls
47 | 
48 |             if vis is None:
49 |                 print(
50 |                     "In %s.py, there should be a subclass of BaseVisualizer with class name that matches %s in lowercase." % (
51 |                         vis_filename, target_vis_name))
52 |                 exit(0)
53 | 
54 |             self.visualizer_list.append(vis(opt))
55 | 
56 |     def update_state(self, epochs, iters, times):
57 |         for visualizer in self.visualizer_list:
58 |             visualizer.update_state(epochs, iters, times)
59 | 
60 |     def display_current_results(self, visuals):
61 |         """Display the current results on all registered visualizers.
62 | 
63 |         Parameters:
64 |             visuals (OrderedDict) - - dictionary of images to display or save
65 |         """
66 |         for visualizer in self.visualizer_list:
67 |             visualizer.display_current_results(visuals)
68 | 
69 |     def display_current_videos(self, visuals):
70 |         """Display the current videos on all registered visualizers.
71 | 
72 |         Parameters:
73 |             visuals (OrderedDict) - - dictionary of image sequences (videos) to display or save
74 |         """
75 |         for visualizer in self.visualizer_list:
76 |             visualizer.display_current_videos(visuals)
77 | 
78 |     def plot_current_losses(self, losses):
79 |         """Print the current losses on all registered visualizers; they may also save the losses to disk.
80 | 
81 |         Parameters:
82 |             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
83 |         """
84 |         for visualizer in self.visualizer_list:
85 |             visualizer.plot_current_losses(losses)
86 | 
87 | 
88 | 
89 | # TODO merge image saver for test and training time
90 | def save_images(webpage, visuals, name, aspect_ratio=1.0, width=256):
91 |     """Save images to the disk.
92 | 
93 |     Parameters:
94 |         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
95 |         visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
96 |         name (str) -- the string used to create image paths
97 |         aspect_ratio (float) -- the aspect ratio of saved images
98 |         width (int) -- the images will be resized to width x width
99 | 
100 |     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
111 | """ 112 | image_dir = webpage.get_image_dir() 113 | 114 | ims, txts, links = [], [], [] 115 | 116 | for label, im_data in visuals.items(): 117 | im = util.tensor2im(im_data) 118 | image_name = '%s_%s.png' % (name, label) 119 | save_path = os.path.join(image_dir, image_name) 120 | h, w, _ = im.shape 121 | if aspect_ratio > 1.0: 122 | im = resize(im, (h, int(w * aspect_ratio)), interpolation='bicubic') 123 | if aspect_ratio < 1.0: 124 | im = resize(im, (int(h / aspect_ratio), w), interpolation='bicubic') 125 | util.save_image(im, save_path) 126 | 127 | ims.append(image_name) 128 | txts.append(label) 129 | links.append(image_name) 130 | webpage.add_images(ims, txts, links, width=width) 131 | 132 | 133 | def get_img_from_fig(fig, dpi=180): 134 | import io 135 | import cv2 136 | buf = io.BytesIO() 137 | fig.savefig(buf, format="png", dpi=180) 138 | buf.seek(0) 139 | img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) 140 | buf.close() 141 | img = cv2.imdecode(img_arr, 1) 142 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 143 | 144 | return img 145 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | 14 | class BaseDataset(data.Dataset, ABC): 15 | """This class is an abstract base class (ABC) for datasets. 16 | 17 | To create a subclass, you need to implement the following four functions: 18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 19 | -- <__len__>: return the size of dataset. 20 | -- <__getitem__>: get a data point. 21 | -- : (optionally) add dataset-specific options and set default options. 22 | """ 23 | 24 | def __init__(self, opt): 25 | """Initialize the class; save the options in the class 26 | 27 | Parameters: 28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 29 | """ 30 | self.opt = opt 31 | self.root = opt.dataroot 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
60 | """ 61 | pass 62 | 63 | 64 | def get_params(opt, size): 65 | w, h = size 66 | new_h = h 67 | new_w = w 68 | if opt.preprocess == 'resize_and_crop': 69 | new_h = new_w = opt.load_size 70 | elif opt.preprocess == 'scale_width_and_crop': 71 | new_w = opt.load_size 72 | new_h = opt.load_size * h // w 73 | 74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 76 | 77 | flip = random.random() > 0.5 78 | 79 | return {'crop_pos': (x, y), 'flip': flip} 80 | 81 | 82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 83 | transform_list = [] 84 | if grayscale: 85 | transform_list.append(transforms.Grayscale(1)) 86 | if 'resize' in opt.preprocess: 87 | osize = [opt.load_size, opt.load_size] 88 | transform_list.append(transforms.Resize(osize, method)) 89 | elif 'scale_width' in opt.preprocess: 90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 91 | 92 | if 'crop' in opt.preprocess: 93 | if params is None: 94 | transform_list.append(transforms.CenterCrop(opt.crop_size)) 95 | else: 96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 97 | 98 | if opt.preprocess == 'none': 99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 100 | 101 | if not opt.no_flip: 102 | if params is None: 103 | transform_list.append(transforms.RandomHorizontalFlip()) 104 | elif params['flip']: 105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 106 | 107 | if convert: 108 | transform_list += [transforms.ToTensor()] 109 | if grayscale: 110 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 111 | else: 112 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 113 | 114 | # add small random noise 115 | transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())) 116 | return transforms.Compose(transform_list) 117 | 118 | 119 | def __make_power_2(img, base, method=Image.BICUBIC): 120 | ow, oh = img.size 121 | h = int(round(oh / base) * base) 122 | w = int(round(ow / base) * base) 123 | if (h == oh) and (w == ow): 124 | return img 125 | 126 | __print_size_warning(ow, oh, w, h) 127 | return img.resize((w, h), method) 128 | 129 | 130 | def __scale_width(img, target_width, method=Image.BICUBIC): 131 | ow, oh = img.size 132 | if (ow == target_width): 133 | return img 134 | w = target_width 135 | h = int(target_width * oh / ow) 136 | return img.resize((w, h), method) 137 | 138 | 139 | def __crop(img, pos, size): 140 | ow, oh = img.size 141 | x1, y1 = pos 142 | tw = th = size 143 | if (ow > tw or oh > th): 144 | return img.crop((x1, y1, x1 + tw, y1 + th)) 145 | return img 146 | 147 | 148 | def __flip(img, flip): 149 | if flip: 150 | return img.transpose(Image.FLIP_LEFT_RIGHT) 151 | return img 152 | 153 | 154 | def __print_size_warning(ow, oh, w, h): 155 | """Print warning information about image size(only print once)""" 156 | if not hasattr(__print_size_warning, 'has_printed'): 157 | print("The image size needs to be a multiple of 4. " 158 | "The loaded image size was (%d, %d), so it was adjusted to " 159 | "(%d, %d). 
This adjustment will be done to all images " 160 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 161 | __print_size_warning.has_printed = True 162 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from models import create_model 4 | import torch 5 | import numpy as np 6 | import tqdm 7 | import pickle 8 | import cv2 9 | import torch.nn.functional as F 10 | from scipy.spatial.transform import Rotation as scipy_rot 11 | 12 | 13 | def voting(loss_record, state_record, thresh=20, topk=1, iter=-1, category='laptop'): 14 | state_rank = get_topk_angle(loss_record, state_record,topk=topk,iter=iter) 15 | 16 | def compare_angle(angle1, angle2): 17 | 18 | R1 = scipy_rot.from_euler('yxz', angle1, degrees=True).as_dcm()[:3, :3] 19 | R2 = scipy_rot.from_euler('yxz', angle2, degrees=True).as_dcm()[:3, :3] 20 | 21 | R1 = R1[:3, :3] / np.cbrt(np.linalg.det(R1[:3, :3])) 22 | R2 = R2[:3, :3] / np.cbrt(np.linalg.det(R2[:3, :3])) 23 | 24 | if category in ['bottle', 'can', 'bowl']: ## symmetric when rotating around y-axis 25 | y = np.array([0, 1, 0]) 26 | y1 = R1 @ y 27 | y2 = R2 @ y 28 | rot_error = np.arccos(y1.dot(y2) / (np.linalg.norm(y1) * np.linalg.norm(y2))) 29 | else: 30 | R = R1 @ R2.transpose() 31 | rot_error = np.arccos((np.trace(R) - 1) / 2) 32 | 33 | return rot_error * 180 / np.pi 34 | 35 | ids_inliars_best = [] 36 | for index1, state1 in enumerate(state_rank): 37 | ids_inliars = [index1] 38 | for index2, state2 in enumerate(state_rank): 39 | if compare_angle(state1[:3], state2[:3]) <= thresh: 40 | ids_inliars.append(index2) 41 | if len(ids_inliars) > len(ids_inliars_best): 42 | ids_inliars_best = ids_inliars.copy() 43 | 44 | return state_rank[np.array(ids_inliars_best).min(),:] 45 | 46 | def get_topk_angle(loss_record,state_record,topk=1,iter=-1): 47 | recon_error = loss_record[:,iter,:].sum(-1) 48 | ranking_sample = [r[0] for r in sorted(enumerate(recon_error), key=lambda r: r[1])] 49 | return state_record[ranking_sample[:topk],iter,:] 50 | 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | opt = TestOptions().parse() # get test options 56 | opt.num_threads = 1 57 | opt.serial_batches = True 58 | opt.no_flip = True 59 | 60 | # https://github.com/hughw19/NOCS_CVPR2019/blob/14dbce775c3c7c45bb7b19269bd53d68efb8f73f/detect_eval.py#L172 61 | intrinsics = np.array([[591.0125, 0, 322.525], [0, 590.16775, 244.11084], [0, 0, 1]]) 62 | 63 | # Rendering parameters 64 | focal_lengh_render = 70. 
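# The constants above are used further below to turn the fitted 2D similarity
# transform into a metric 3D translation: the fitted scale states[5] roughly
# maps the s-pixel crop onto the image_size_render-pixel render, giving (as
# computed for pred_RTs_ours)
#   z = image_size_render / (s / states[5]) * (fx + fy) / 2 / focal_lengh_render * mean_scale
# and the crop centre (u, v) is then back-projected with the pinhole model.
# A standalone sketch of that back-projection (hypothetical helper, not used
# by this script):
#
# def backproject(u, v, z, K):
#     return np.array([(u - K[0, 2]) / K[0, 0] * z,
#                      (v - K[1, 2]) / K[1, 1] * z,
#                      z])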
65 | image_size_render = 64 66 | 67 | # Average scales from the synthetic training set CAMERA 68 | mean_scales = np.array([0.34, 0.21, 0.19, 0.15, 0.46, 0.17]) 69 | categories = ['bottle','bowl','camera','can','laptop','mug'] 70 | 71 | 72 | output_folder = os.path.join(opt.results_dir,opt.project_name,opt.test_name) 73 | if not os.path.exists(output_folder): 74 | os.makedirs(output_folder) 75 | 76 | models = [] 77 | for cat in categories: 78 | 79 | opt.category = cat 80 | opt.exp_name = cat 81 | 82 | model = create_model(opt) 83 | model.setup(opt) 84 | model.eval() 85 | 86 | models.append(model) 87 | 88 | nocs_list = sorted(os.listdir( os.path.join(opt.dataroot,'nocs_det')))[::opt.skip] 89 | 90 | interval = len(nocs_list)//(opt.num_agent-1) if opt.num_agent > 1 else len(nocs_list) 91 | task_range = nocs_list[interval*opt.id_agent:min(interval*(opt.id_agent+1), len(nocs_list))] 92 | 93 | for file_name in tqdm.tqdm(task_range): 94 | 95 | file_path = os.path.join(opt.dataroot,'nocs_det', file_name) 96 | pose_file = pickle.load(open(file_path, 'rb'), encoding='utf-8') 97 | 98 | image_name = pose_file['image_path'].replace('data/real/test', opt.dataroot+'/real_test/')+'_color.png' 99 | image = cv2.imread(image_name)[:,:,::-1] 100 | 101 | masks = pose_file['pred_mask'] 102 | bboxes = pose_file['pred_bboxes'] 103 | 104 | pose_file['pred_RTs_ours'] = np.zeros_like(pose_file['pred_RTs']) 105 | 106 | for id, class_pred in enumerate(pose_file['pred_class_ids']): 107 | bbox = bboxes[id] 108 | image_mask = image.copy() 109 | image_mask[masks[:,:,id]==0,:] = 255 110 | image_mask = image_mask[bbox[0]:bbox[2],bbox[1]:bbox[3],:] 111 | 112 | A = (torch.from_numpy(image_mask.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2) /255) * 2 - 1 113 | 114 | _, c, h, w = A.shape 115 | s = max( h, w) + 30 116 | A = F.pad(A,[(s - w)//2, (s - w) - (s - w)//2, 117 | (s - h)//2, (s - h) - (s - h)//2],value=1) 118 | A = F.interpolate(A,size=opt.target_size,mode='bilinear') 119 | 120 | state_history, loss_history, image_history = models[class_pred-1].fitting(A) 121 | if opt.vis: 122 | # Use NOCS's prediction as reference for visualizing pose error as the resuls are not matched to GT's order. 
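# (Detections in the pickle follow the detector's own order, so the matching
# GT pose is unknown at this point; the NOCS pred_RTs entry serves as the
# reference pose instead.)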
123 | models[class_pred-1].visulize_fitting(A,torch.tensor(pose_file['pred_RTs'][id]).float().unsqueeze(0),state_history,loss_history,image_history) 124 | 125 | states = voting(loss_history,state_history,category=categories[class_pred-1],topk=5,thresh=10) 126 | pose_file['pred_RTs_ours'][id][:3,:3] = scipy_rot.from_euler('yxz', states[:3], degrees=True).as_dcm()[:3, :3] 127 | 128 | 129 | angle = -states[2] / 180 * np.pi 130 | mat = np.array([[states[5]*np.cos(angle), -states[5]*np.sin(angle), states[5]*states[3]], 131 | [states[5]*np.sin(angle), states[5]*np.cos(angle), states[5]*states[4]], 132 | [ 0, 0, 1]]) 133 | 134 | mat_inv = np.linalg.inv(mat) 135 | u = (bbox[1] + bbox[3])/2 + mat_inv[0,2]*s/2 136 | v = (bbox[0] + bbox[2])/2 + mat_inv[1,2]*s/2 137 | 138 | z = image_size_render/(s/states[5]) * (intrinsics[0,0]+intrinsics[1,1])/2 /focal_lengh_render * mean_scales[class_pred-1] 139 | 140 | pose_file['pred_RTs_ours'][id][2, 3] = z 141 | pose_file['pred_RTs_ours'][id][0, 3] = (u - intrinsics[0,2])/intrinsics[0,0]*z 142 | pose_file['pred_RTs_ours'][id][1, 3] = (v - intrinsics[1,2])/intrinsics[1,1]*z 143 | pose_file['pred_RTs_ours'][id][3, 3] = 1 144 | 145 | f = open(os.path.join(output_folder,file_name),'wb') 146 | pickle.dump(pose_file,f,-1) -------------------------------------------------------------------------------- /nocs/aligning.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Normalized Object Coordinate Space for Category-Level 6D Object Pose and Size Estimation 3 | RANSAC for Similarity Transformation Estimation 4 | 5 | Written by Srinath Sridhar 6 | ''' 7 | 8 | import numpy as np 9 | import cv2 10 | import itertools 11 | 12 | def estimateSimilarityTransform(source: np.array, target: np.array, verbose=False): 13 | SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])])) 14 | TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])])) 15 | 16 | # Auto-parameter selection based on source-target heuristics 17 | TargetNorm = np.mean(np.linalg.norm(target, axis=1)) 18 | SourceNorm = np.mean(np.linalg.norm(source, axis=1)) 19 | RatioTS = (TargetNorm / SourceNorm) 20 | RatioST = (SourceNorm / TargetNorm) 21 | PassT = RatioST if(RatioST>RatioTS) else RatioTS 22 | StopT = PassT / 100 23 | nIter = 100 24 | if verbose: 25 | print('Pass threshold: ', PassT) 26 | print('Stop threshold: ', StopT) 27 | print('Number of iterations: ', nIter) 28 | 29 | SourceInliersHom, TargetInliersHom, BestInlierRatio = getRANSACInliers(SourceHom, TargetHom, MaxIterations=nIter, PassThreshold=PassT, StopThreshold=StopT) 30 | 31 | if(BestInlierRatio < 0.1): 32 | print('[ WARN ] - Something is wrong. 
Small BestInlierRatio: ', BestInlierRatio) 33 | return None, None, None, None 34 | 35 | Scales, Rotation, Translation, OutTransform = estimateSimilarityUmeyama(SourceInliersHom, TargetInliersHom) 36 | 37 | if verbose: 38 | print('BestInlierRatio:', BestInlierRatio) 39 | print('Rotation:\n', Rotation) 40 | print('Translation:\n', Translation) 41 | print('Scales:', Scales) 42 | 43 | return Scales, Rotation, Translation, OutTransform 44 | 45 | def estimateRestrictedAffineTransform(source: np.array, target: np.array, verbose=False): 46 | SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])])) 47 | TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])])) 48 | 49 | RetVal, AffineTrans, Inliers = cv2.estimateAffine3D(source, target) 50 | # We assume no shear in the affine matrix and decompose into rotation, non-uniform scales, and translation 51 | Translation = AffineTrans[:3, 3] 52 | NUScaleRotMat = AffineTrans[:3, :3] 53 | # NUScaleRotMat should be the matrix SR, where S is a diagonal scale matrix and R is the rotation matrix (equivalently RS) 54 | # Let us do the SVD of NUScaleRotMat to obtain R1*S*R2 and then R = R1 * R2 55 | R1, ScalesSorted, R2 = np.linalg.svd(NUScaleRotMat, full_matrices=True) 56 | 57 | if verbose: 58 | print('-----------------------------------------------------------------------') 59 | # Now, the scales are sort in ascending order which is painful because we don't know the x, y, z scales 60 | # Let's figure that out by evaluating all 6 possible permutations of the scales 61 | ScalePermutations = list(itertools.permutations(ScalesSorted)) 62 | MinResidual = 1e8 63 | Scales = ScalePermutations[0] 64 | OutTransform = np.identity(4) 65 | Rotation = np.identity(3) 66 | for ScaleCand in ScalePermutations: 67 | CurrScale = np.asarray(ScaleCand) 68 | CurrTransform = np.identity(4) 69 | CurrRotation = (np.diag(1 / CurrScale) @ NUScaleRotMat).transpose() 70 | CurrTransform[:3, :3] = np.diag(CurrScale) @ CurrRotation 71 | CurrTransform[:3, 3] = Translation 72 | # Residual = evaluateModel(CurrTransform, SourceHom, TargetHom) 73 | Residual = evaluateModelNonHom(source, target, CurrScale,CurrRotation, Translation) 74 | if verbose: 75 | # print('CurrTransform:\n', CurrTransform) 76 | print('CurrScale:', CurrScale) 77 | print('Residual:', Residual) 78 | print('AltRes:', evaluateModelNoThresh(CurrTransform, SourceHom, TargetHom)) 79 | if Residual < MinResidual: 80 | MinResidual = Residual 81 | Scales = CurrScale 82 | Rotation = CurrRotation 83 | OutTransform = CurrTransform 84 | 85 | if verbose: 86 | print('Best Scale:', Scales) 87 | 88 | if verbose: 89 | print('Affine Scales:', Scales) 90 | print('Affine Translation:', Translation) 91 | print('Affine Rotation:\n', Rotation) 92 | print('-----------------------------------------------------------------------') 93 | 94 | return Scales, Rotation, Translation, OutTransform 95 | 96 | def getRANSACInliers(SourceHom, TargetHom, MaxIterations=100, PassThreshold=200, StopThreshold=1): 97 | BestResidual = 1e10 98 | BestInlierRatio = 0 99 | BestInlierIdx = np.arange(SourceHom.shape[1]) 100 | for i in range(0, MaxIterations): 101 | # Pick 5 random (but corresponding) points from source and target 102 | RandIdx = np.random.randint(SourceHom.shape[1], size=5) 103 | _, _, _, OutTransform = estimateSimilarityUmeyama(SourceHom[:, RandIdx], TargetHom[:, RandIdx]) 104 | Residual, InlierRatio, InlierIdx = evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold) 105 | if Residual < BestResidual: 106 | 
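# new best hypothesis: record its residual, inlier ratio and inlier set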
BestResidual = Residual 107 | BestInlierRatio = InlierRatio 108 | BestInlierIdx = InlierIdx 109 | if BestResidual < StopThreshold: 110 | break 111 | 112 | # print('Iteration: ', i) 113 | # print('Residual: ', Residual) 114 | # print('Inlier ratio: ', InlierRatio) 115 | 116 | return SourceHom[:, BestInlierIdx], TargetHom[:, BestInlierIdx], BestInlierRatio 117 | 118 | def evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold): 119 | Diff = TargetHom - np.matmul(OutTransform, SourceHom) 120 | ResidualVec = np.linalg.norm(Diff[:3, :], axis=0) 121 | Residual = np.linalg.norm(ResidualVec) 122 | InlierIdx = np.where(ResidualVec < PassThreshold) 123 | nInliers = np.count_nonzero(InlierIdx) 124 | InlierRatio = nInliers / SourceHom.shape[1] 125 | return Residual, InlierRatio, InlierIdx[0] 126 | 127 | def evaluateModelNoThresh(OutTransform, SourceHom, TargetHom): 128 | Diff = TargetHom - np.matmul(OutTransform, SourceHom) 129 | ResidualVec = np.linalg.norm(Diff[:3, :], axis=0) 130 | Residual = np.linalg.norm(ResidualVec) 131 | return Residual 132 | 133 | def evaluateModelNonHom(source, target, Scales, Rotation, Translation): 134 | RepTrans = np.tile(Translation, (source.shape[0], 1)) 135 | TransSource = (np.diag(Scales) @ Rotation @ source.transpose() + RepTrans.transpose()).transpose() 136 | Diff = target - TransSource 137 | ResidualVec = np.linalg.norm(Diff, axis=0) 138 | Residual = np.linalg.norm(ResidualVec) 139 | return Residual 140 | 141 | def testNonUniformScale(SourceHom, TargetHom): 142 | OutTransform = np.matmul(TargetHom, np.linalg.pinv(SourceHom)) 143 | ScaledRotation = OutTransform[:3, :3] 144 | Translation = OutTransform[:3, 3] 145 | Sx = np.linalg.norm(ScaledRotation[0, :]) 146 | Sy = np.linalg.norm(ScaledRotation[1, :]) 147 | Sz = np.linalg.norm(ScaledRotation[2, :]) 148 | Rotation = np.vstack([ScaledRotation[0, :] / Sx, ScaledRotation[1, :] / Sy, ScaledRotation[2, :] / Sz]) 149 | print('Rotation matrix norm:', np.linalg.norm(Rotation)) 150 | Scales = np.array([Sx, Sy, Sz]) 151 | 152 | # # Check 153 | # Diff = TargetHom - np.matmul(OutTransform, SourceHom) 154 | # Residual = np.linalg.norm(Diff[:3, :], axis=0) 155 | return Scales, Rotation, Translation, OutTransform 156 | 157 | def estimateSimilarityUmeyama(SourceHom, TargetHom): 158 | # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf 159 | SourceCentroid = np.mean(SourceHom[:3, :], axis=1) 160 | TargetCentroid = np.mean(TargetHom[:3, :], axis=1) 161 | nPoints = SourceHom.shape[1] 162 | 163 | CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 164 | CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose() 165 | 166 | CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints 167 | 168 | if np.isnan(CovMatrix).any(): 169 | print('nPoints:', nPoints) 170 | print(SourceHom.shape) 171 | print(TargetHom.shape) 172 | raise RuntimeError('There are NANs in the input.') 173 | 174 | U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True) 175 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 176 | if d: 177 | D[-1] = -D[-1] 178 | U[:, -1] = -U[:, -1] 179 | 180 | Rotation = np.matmul(U, Vh).T # Transpose is the one that works 181 | 182 | varP = np.var(SourceHom[:3, :], axis=1).sum() 183 | ScaleFact = 1/varP * np.sum(D) # scale factor 184 | Scales = np.array([ScaleFact, ScaleFact, ScaleFact]) 185 | ScaleMatrix = np.diag(Scales) 186 | 187 | Translation = TargetHom[:3, :].mean(axis=1) - SourceHom[:3, 
:].mean(axis=1).dot(ScaleFact*Rotation) 188 | 189 | OutTransform = np.identity(4) 190 | OutTransform[:3, :3] = ScaleMatrix @ Rotation 191 | OutTransform[:3, 3] = Translation 192 | 193 | # # Check 194 | # Diff = TargetHom - np.matmul(OutTransform, SourceHom) 195 | # Residual = np.linalg.norm(Diff[:3, :], axis=0) 196 | return Scales, Rotation, Translation, OutTransform 197 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import torch 6 | 7 | import data 8 | import models 9 | from util import util 10 | from util.visualizer.base_visualizer import BaseVisualizer as Visualizer 11 | 12 | 13 | class BaseOptions(): 14 | """This class defines options used during both training and test time. 15 | 16 | It also implements several helper functions such as parsing, printing, and saving the options. 17 | It also gathers additional options defined in functions in both dataset class and model class. 18 | """ 19 | 20 | def __init__(self): 21 | """Reset the class; indicates the class hasn't been initailized""" 22 | self.initialized = False 23 | 24 | def initialize(self, parser): 25 | """Define the common options that are used in both training and test.""" 26 | parser.add_argument('--config', type=str) 27 | # basic parameters 28 | basic_args = parser.add_argument_group('basic') 29 | basic_args.add_argument('--project_name', type=str, default='project template',help='project name, use project folder name by default') 30 | basic_args.add_argument('--dataroot', type=str,help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 31 | basic_args.add_argument('--run_name', type=str, default='', help='id of the experiment run, specified as string format, e.g. lr={lr} or string. Using current datetime by default') 32 | basic_args.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 33 | basic_args.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 34 | # model parameters 35 | model_args = parser.add_argument_group('model') 36 | model_args.add_argument('--model', type=str, default='latent_object', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') 37 | model_args.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 38 | model_args.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 39 | model_args.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 40 | model_args.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 41 | # dataset parameters 42 | data_args = parser.add_argument_group('data') 43 | data_args.add_argument('--dataset_mode', type=str, default='nocs_real', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]') 44 | data_args.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') 45 | data_args.add_argument('--batch_size', type=int, default=1, help='input batch size') 46 | data_args.add_argument('--load_size', type=int, default=64, help='scale images to this size') 47 | data_args.add_argument('--crop_size', type=int, default=64, help='then crop to this size') 48 | data_args.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 49 | data_args.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 50 | data_args.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 51 | data_args.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 52 | data_args.add_argument('--keep_last', action='store_true', help='drop the last batch of the dataset to keep batch size consistent.') 53 | # additional parameters 54 | misc_args = parser.add_argument_group('misc') 55 | misc_args.add_argument('--load_suffix', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 56 | misc_args.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 57 | misc_args.add_argument('--visualizers', nargs='+', type=str, default=['terminal', 'wandb'], help='visualizers to use. local | wandb') 58 | self.initialized = True 59 | return parser 60 | 61 | def gather_options(self): 62 | """Initialize our parser with basic options(only once). 63 | Add additional model-specific and dataset-specific options. 64 | These options are defined in the function 65 | in model and dataset classes. 66 | """ 67 | if not self.initialized: # check if it has been initialized 68 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | parser = self.initialize(parser) 70 | 71 | 72 | # get the basic options 73 | opt, _ = parser.parse_known_args() 74 | if opt.config is not None: opt = self.load_options(opt) 75 | # modify model-related parser options 76 | model_name = opt.model 77 | model_option_setter = models.get_option_setter(model_name) 78 | parser = model_option_setter(parser, self.isTrain) 79 | opt, args = parser.parse_known_args() # parse again with new defaults 80 | if opt.config is not None: opt = self.load_options(opt) 81 | 82 | # modify dataset-related parser options 83 | dataset_name = opt.dataset_mode 84 | dataset_option_setter = data.get_option_setter(dataset_name) 85 | parser = dataset_option_setter(parser, self.isTrain) 86 | 87 | # modify visualization-related parser options 88 | parser = Visualizer.modify_commandline_options(parser) 89 | 90 | # save and return the parser 91 | self.parser = parser 92 | opt = parser.parse_args() 93 | if opt.config is not None: opt = self.load_options(opt) 94 | 95 | opt.exp_name = opt.category 96 | 97 | return opt 98 | 99 | def print_options(self, opt): 100 | """Print and save options 101 | 102 | It will print both current options and default values(if different). 
103 | It will save options into a text file / [checkpoints_dir] / opt.txt 104 | """ 105 | message = '' 106 | message += '----------------- Options ---------------\n' 107 | for k, v in sorted(vars(opt).items()): 108 | comment = '' 109 | default = self.parser.get_default(k) 110 | if v != default: 111 | comment = '\t[default: %s]' % str(default) 112 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 113 | message += '----------------- End -------------------' 114 | print(message) 115 | 116 | def save_options(self,opt): 117 | output_dict = {} 118 | for group in self.parser._action_groups: 119 | if group.title in ['positional arguments', 'optional arguments']: continue 120 | output_dict[group.title] = {a.dest: getattr(opt, a.dest, None) for a in group._group_actions} 121 | 122 | import yaml 123 | 124 | if self.isTrain: 125 | output_path = os.path.join(opt.checkpoints_dir,opt.project_name,opt.exp_name, opt.run_name,'config.yaml') 126 | else: 127 | output_path = os.path.join(opt.results_dir, opt.project_name, opt.test_name, 'config.yaml') 128 | 129 | util.mkdirs(os.path.dirname(output_path)) 130 | with open(output_path, 'w') as f: 131 | yaml.dump(output_dict,f,default_flow_style=False, sort_keys=True) 132 | 133 | def load_options(self,opt): 134 | assert(opt.config is not None) 135 | from envyaml import EnvYAML 136 | 137 | args_usr = [ arg[2:] for arg in sys.argv if '--' in arg] 138 | config = EnvYAML(opt.config,include_environment=False) 139 | for name in config.keys(): 140 | # make sure yaml won't overwrite cmd input and the arg is defined 141 | basename = name.split('.')[-1] 142 | if basename not in args_usr and hasattr(opt,basename): 143 | setattr(opt, basename, config[name]) 144 | 145 | return opt 146 | 147 | def parse(self): 148 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 149 | opt = self.gather_options() 150 | 151 | opt.isTrain = self.isTrain # train or test 152 | 153 | # process opt.run_name 154 | if opt.run_name != '': 155 | opt.run_name = opt.run_name.format(**vars(opt)) 156 | else: 157 | from datetime import datetime 158 | opt.run_name = datetime.now().strftime("%d-%m-%Y %H:%M:%S") 159 | 160 | self.save_options(opt) 161 | 162 | if opt.verbose: 163 | self.print_options(opt) 164 | 165 | # set gpu ids 166 | str_ids = opt.gpu_ids.split(',') 167 | opt.gpu_ids = [] 168 | for str_id in str_ids: 169 | id = int(str_id) 170 | if id >= 0: 171 | opt.gpu_ids.append(id) 172 | if len(opt.gpu_ids) > 0: 173 | torch.cuda.set_device(opt.gpu_ids[0]) 174 | 175 | self.opt = opt 176 | return self.opt -------------------------------------------------------------------------------- /models/latent_object_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import torch 4 | from scipy.spatial.transform import Rotation as scipy_rot 5 | 6 | from models.networks import networks 7 | 8 | from .base_model import BaseModel 9 | 10 | class LatentObjectModel(BaseModel): 11 | 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train=True): 14 | 15 | models_args = parser.add_argument_group('models') 16 | 17 | models_args.add_argument('--z_dim', type=int, default=16, help='dimension of z') 18 | models_args.add_argument('--batch_size_vis', type=int, default=8, help='number of visualization samples') 19 | models_args.add_argument('--use_VAE', action='store_true', default=True, help='use KL divergence') 20 | models_args.add_argument('--category', type=str, default='laptop', 
help='object category')
21 | if is_train:
22 | models_args.add_argument('--lambda_recon', type=float, default=10., help='weight for reconstruction loss')
23 | models_args.add_argument('--lambda_KL', type=float, default=0.01, help='weight for the KL divergence')
24 | else:
25 | fitting_args = parser.add_argument_group('fitting')
26 | fitting_args.set_defaults(dataset_mode='nocs_hdf5', batch_size=1, no_flip=True, preprocess=' ')
27 | fitting_args.add_argument('--n_iter', type=int, default=50, help='number of optimization iterations')
28 | fitting_args.add_argument('--n_init', type=int, default=32, help='number of initializations')
29 | fitting_args.add_argument('--lambda_reg', type=float, default=1, help='weight for the latent code regularization')
30 |
31 | return parser
32 |
33 | def __init__(self, opt):
34 |
35 | BaseModel.__init__(self, opt)
36 | self.use_VAE = opt.use_VAE
37 |
38 | self.loss_names = ['G_recon']
39 | if self.opt.use_VAE: self.loss_names += ['KL']
40 |
41 | self.visual_names = ['real_A','real_B','fake_B']
42 |
43 | self.video_names = ['anim_azim','anim_elev']
44 |
45 | self.model_names = ['G','E']
46 |
47 | self.optimizer_names = ['G']
48 |
49 | # define networks (both generator and discriminator)
50 | self.netG = networks.Generator(opt.z_dim).to(self.device)
51 | networks.init_net(self.netG, init_type=self.opt.init_type, init_gain=self.opt.init_gain,gpu_ids=self.gpu_ids)
52 |
53 | output_dim = opt.z_dim *2 if self.use_VAE else opt.z_dim
54 | self.netE = networks.Encoder(3, opt.crop_size, output_dim).to(self.device)
55 | self.netE = networks.add_SN(self.netE)
56 | networks.init_net(self.netE, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids)
57 |
58 | if self.isTrain:
59 | self.criterion_recon = torch.nn.L1Loss().to(self.device)
60 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG.parameters(),self.netE.parameters()), lr=opt.lr, betas=(0.5,0.999))
61 | self.optimizers.append(self.optimizer_G)
62 |
63 | # define the prior distribution
64 | mu = torch.zeros(opt.z_dim, device=self.device)
65 | scale = torch.ones(opt.z_dim, device=self.device)
66 | self.z_dist = torch.distributions.Normal(mu, scale)
67 |
68 | self.batch_size_vis = opt.batch_size_vis
69 |
70 | def set_input(self, input):
71 | self.real_A = input['A'].to(self.device)
72 | self.real_B = input['B'].to(self.device)
73 | self.theta = input['B_pose'].to(self.device)
74 |
75 | def forward(self):
76 | if not self.use_VAE:
77 | self.z = self.netE(self.real_A)
78 | else:
79 | b = self.real_A.shape[0]
80 | output = self.netE(self.real_A)
81 | self.mu, self.logvar = output[:,:self.opt.z_dim],output[:,self.opt.z_dim:]
82 | std = self.logvar.mul(0.5).exp_()
83 | self.z_sample = self.z_dist.sample((b,))
84 | eps = self.z_sample
85 | self.z = eps.mul(std).add_(self.mu)  # reparameterization: z = mu + std * eps
86 |
87 | self.fake_B = self.netG(self.z,self.theta)
88 | return self.fake_B
89 |
90 | def backward(self):
91 | # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, weighted by lambda_KL
92 | self.loss_KL = (1 + self.logvar - self.mu.pow(2) - self.logvar.exp()).mean() * (-0.5 * self.opt.lambda_KL)
93 | self.loss_G_recon = self.criterion_recon(self.fake_B, self.real_B)
94 | self.loss_G = self.loss_G_recon * self.opt.lambda_recon + self.loss_KL
95 | self.loss_G.backward()
96 |
97 | def optimize_parameters(self):
98 | self.train()
99 | self.forward()
100 | self.optimizer_G.zero_grad()
101 | self.backward()
102 | self.optimizer_G.step()
103 |
104 | def compute_visuals(self):
105 | self.netG.eval()
106 | self.real_A = self.real_A[:self.batch_size_vis,...]
107 | self.real_B = self.real_B[:self.batch_size_vis,...]
107 | self.fake_B = self.fake_B[:self.batch_size_vis,...] 108 | 109 | self.z_vis = self.netE(self.real_A)[:,:self.opt.z_dim] 110 | 111 | with torch.no_grad(): 112 | self.anim_azim = [] 113 | elev = 0 114 | for azim in range(-180,180,3): 115 | theta = torch.zeros((self.batch_size_vis,3)).to(self.device) 116 | theta[:,0],theta[:,1] = elev,azim 117 | frame = self.netG(self.z_vis,theta).detach().data 118 | self.anim_azim.append(frame) 119 | self.anim_elev= [] 120 | azim = 0 121 | for elev in range(-90,90,3): 122 | theta = torch.zeros((self.batch_size_vis, 3)).to(self.device) 123 | theta[:, 0], theta[:, 1] = elev, azim 124 | frame = self.netG(self.z_vis, theta).detach().data 125 | self.anim_elev.append(frame) 126 | 127 | def fitting(self, real_B): 128 | import tqdm 129 | import torch.optim as optim 130 | import torch.nn.functional as F 131 | 132 | from models.networks.utils import grid_sample, warping_grid, init_variable 133 | from models.networks.losses import PerceptualLoss 134 | 135 | real_B = real_B.to(self.device).repeat((self.opt.n_init, 1, 1, 1)) 136 | real_B = real_B[:, [2, 1, 0], :, :] 137 | 138 | ay = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='linspace', range=[-1/2,1/2]) 139 | ax = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=1/4) 140 | az = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 141 | s = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=1) 142 | tx = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 143 | ty = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 144 | z = init_variable(dim=self.opt.z_dim, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 145 | 146 | latent = self.netE(F.interpolate(real_B, size=self.opt.crop_size, mode='nearest')) 147 | if self.opt.use_VAE: 148 | mu, logvar = latent[:, :self.opt.z_dim], latent[:, self.opt.z_dim:] 149 | std = logvar.mul(0.5).exp_() 150 | eps = self.z_dist.sample((self.opt.n_init,)) 151 | z.data = eps.mul(std).add_(mu) 152 | else: 153 | z.data = latent 154 | 155 | variable_dict = [ 156 | {'params': z, 'lr': 3e-1}, 157 | {'params': ax, 'lr': 1e-2}, 158 | {'params': ay, 'lr': 3e-2}, 159 | {'params': az, 'lr': 1e-2}, 160 | {'params': tx, 'lr': 3e-2}, 161 | {'params': ty, 'lr': 3e-2}, 162 | {'params': s, 'lr': 3e-2}, 163 | ] 164 | optimizer = optim.Adam(variable_dict,betas=(0.5,0.999)) 165 | 166 | losses = [('VGG', 1, PerceptualLoss(reduce=False))] 167 | reg_creterion = torch.nn.MSELoss(reduce=False) 168 | 169 | loss_history = np.zeros( (self.opt.n_init,self.opt.n_iter,len(losses)+1)) 170 | state_history = np.zeros( (self.opt.n_init,self.opt.n_iter,6 + self.opt.z_dim)) 171 | image_history = [] 172 | 173 | for iter in tqdm.tqdm(range(self.opt.n_iter)): 174 | 175 | optimizer.zero_grad() 176 | 177 | angle = 180 * torch.cat([ax, ay, torch.zeros_like(ay)], dim=1) 178 | fake_B = self.netG(z,angle) 179 | 180 | grid = warping_grid(az * np.pi, tx, ty, s, fake_B.shape) 181 | fake_B = grid_sample(fake_B, grid) 182 | 183 | fake_B_upsampled = F.interpolate(fake_B, size=real_B.shape[-1], mode='bilinear') 184 | 185 | error_all = 0 186 | for l, (name, weight, criterion)in enumerate(losses): 187 | error = weight * criterion(fake_B_upsampled, real_B).view(self.opt.n_init,-1).mean(1) 188 | loss_history[:,iter,l] = error.data.cpu().numpy() 189 | error_all = error_all + error 190 | 191 | error = 
self.opt.lambda_reg * reg_creterion(z,torch.zeros_like(z)).view(self.opt.n_init,-1).mean(1) 192 | loss_history[:, iter, l+1] = error.data.cpu().numpy() 193 | error_all = error_all + error 194 | 195 | error_all.mean().backward() 196 | 197 | optimizer.step() 198 | image_history.append(fake_B) 199 | 200 | state_history[:, iter, :3] = 180*torch.cat([-ay-0.5, ax+1, -az],dim=-1).data.cpu().numpy() 201 | state_history[:, iter, 3:] = torch.cat([tx, ty, s, z],dim=-1).data.cpu().numpy() 202 | 203 | return state_history, loss_history, image_history 204 | 205 | def visulize_fitting(self, real_B, RT_gt, state_history, loss_history, image_history): 206 | import matplotlib.pyplot as plt 207 | from util.util import tensor2im 208 | from models.networks.utils import set_axis 209 | import matplotlib 210 | matplotlib.use('TkAgg') 211 | 212 | RT_gt = RT_gt.numpy()[0] 213 | R_gt = RT_gt[:3, :3] 214 | real_B_img = tensor2im(real_B) 215 | 216 | n_init, n_iter, n_loss = loss_history.shape 217 | 218 | fig, axes = plt.subplots(nrows=loss_history.shape[2] + 2, ncols=n_init + 1, sharey='row') 219 | axes[0, -1].clear();axes[0, -1].axis('off') 220 | axes[0, -1].imshow(real_B_img) 221 | plt.ion() 222 | 223 | plots = axes.copy() 224 | for row in range(axes.shape[0]): 225 | for col in range(axes.shape[1]): 226 | if row == 0: 227 | axes[row, col].axis('off') 228 | plots[row, col] = axes[row, col].imshow(real_B_img) 229 | elif col < n_init: 230 | if row < n_loss+1: 231 | set_axis(axes[row, col]) 232 | plots[row, col] = axes[row, col].plot(np.arange(n_iter),loss_history[col, :,row-1]) 233 | else: 234 | plots[row, col] = axes[row, col].plot(np.arange(n_iter),60*np.ones(n_iter)) 235 | axes[row, col].set_ylim([0,60]) 236 | 237 | errors = np.zeros((n_init,n_iter)) 238 | 239 | for iter in range(n_iter): 240 | for init in range(n_init): 241 | pose = state_history[init,iter,:3] 242 | R_pd = scipy_rot.from_euler('yxz', pose, degrees=True).as_dcm()[:3, :3] 243 | 244 | R_pd = R_pd[:3, :3]/np.cbrt(np.linalg.det(R_pd[:3, :3])) 245 | R_gt = R_gt[:3, :3]/np.cbrt(np.linalg.det(R_gt[:3, :3])) 246 | 247 | R = R_pd @ R_gt.transpose() 248 | errors[init,iter] = np.arccos((np.trace(R) - 1)/2) * 180/np.pi 249 | 250 | ranking = [r[0] for r in sorted(enumerate(loss_history[:, iter,:].mean(-1)), key=lambda r: r[1])] 251 | 252 | for r, b in enumerate(ranking[::-1]): 253 | plots[0, r].set_data(tensor2im(image_history[iter][b].unsqueeze(0))) 254 | for l in range(loss_history.shape[2]): 255 | plots[l + 1, r][0].set_data(np.arange(iter),loss_history[b, :iter,l]) 256 | plots[-1, r][0].set_data(np.arange(iter), errors[b, :iter]) 257 | 258 | plt.draw() 259 | 260 | plt.pause(0.01) 261 | plt.close(fig) -------------------------------------------------------------------------------- /models/networks/networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import init 4 | import torch.utils.data 5 | import torch.utils.data.distributed 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.optim import lr_scheduler 9 | 10 | from models.networks import spectral_norm 11 | 12 | class Generator(nn.Module): 13 | def __init__(self, z_dim, euler_seq='zyx', **kwargs): 14 | super().__init__() 15 | self.shape_code = nn.Parameter(0.02*torch.randn(1,512,4,4,4),requires_grad=True) 16 | # Upsampling 3D 17 | self.enc_1 = nn.Sequential(*[nn.ConvTranspose3d(512,128,kernel_size=4,stride=2, padding=1)]) 18 | self.enc_2 = nn.Sequential(*[nn.ConvTranspose3d(128,64, 
kernel_size=4, stride=2, padding=1)]) 19 | # Projection 20 | self.proj = nn.Sequential(*[nn.ConvTranspose2d(64*16,64*16, kernel_size=1,stride=1)]) 21 | # Upsampling 2D 22 | self.enc_3 = nn.Sequential(*[nn.ConvTranspose2d(64*16,64*4,kernel_size=4,stride=2, padding=1)]) 23 | self.enc_4 = nn.Sequential(*[nn.ConvTranspose2d(64*4,64,kernel_size=4,stride=2, padding=1)]) 24 | self.enc_5 = nn.Sequential(*[nn.ConvTranspose2d(64,3,kernel_size=3,stride=1,padding=1)]) 25 | # MLP for AdaIN 26 | self.mlp0 = LinearBlock(z_dim,512*2,activation='relu') 27 | self.mlp1 = LinearBlock(z_dim,128*2,activation='relu') 28 | self.mlp2 = LinearBlock(z_dim,64*2,activation='relu') 29 | self.mlp3 = LinearBlock(z_dim,256*2,activation='relu') 30 | self.mlp4 = LinearBlock(z_dim,64*2,activation='relu') 31 | 32 | self.euler_seq = euler_seq 33 | 34 | def forward(self, z, angle, a=None, debug=False): 35 | b,_ = z.size() 36 | angle = angle / 180. * np.pi 37 | # Upsampling 3D 38 | h0 = self.shape_code.expand(b, 512, 4, 4, 4).clone() 39 | a0 = self.mlp0(z) 40 | h0 = actvn( adaIN(h0,a0) ) 41 | 42 | h1 = self.enc_1(h0) 43 | a1 = self.mlp1(z) 44 | h1 = actvn( adaIN(h1,a1) ) 45 | 46 | h2 = self.enc_2(h1) 47 | a2 = self.mlp2(z) 48 | h2 = actvn(adaIN(h2, a2)) 49 | 50 | # Rotation 51 | h2_rot = rot(h2,angle,euler_seq=self.euler_seq,padding="border") 52 | b,c,d,h,w = h2_rot.size() 53 | h2_2d = h2_rot.contiguous().view(b,c*d,h,w) 54 | h2_2d = actvn(self.proj(h2_2d)) 55 | # Upsampling 2D 56 | h3 = self.enc_3(h2_2d) 57 | a3 = self.mlp3(z) 58 | h3 = actvn(adaIN(h3, a3)) 59 | 60 | h4 = self.enc_4(h3) 61 | a4 = self.mlp4(z) 62 | h4 = actvn(adaIN(h4, a4)) 63 | 64 | h5 = self.enc_5(h4) 65 | return F.tanh(h5) 66 | def actvn(x): 67 | out = F.leaky_relu(x, 2e-1) 68 | return out 69 | 70 | class Encoder(nn.Module): 71 | def __init__(self,in_dim=3, in_size=64, z_dim=128): 72 | super().__init__() 73 | self.model = nn.Sequential(*[ 74 | nn.Conv2d( in_dim, 64, 3, 2, 1), nn.LeakyReLU(0.2), 75 | nn.Conv2d( 64, 128, 3, 2, 1), nn.LeakyReLU(0.2), nn.InstanceNorm2d(128), 76 | nn.Conv2d(128, 256, 3, 2, 1), nn.LeakyReLU(0.2), nn.InstanceNorm2d(256), 77 | nn.Conv2d(256, 512, 3, 2, 1), nn.LeakyReLU(0.2), nn.InstanceNorm2d(512), 78 | ]) 79 | self.enc_out = nn.Sequential(*[ 80 | nn.Linear((in_size//16)**2*512,128), nn.LeakyReLU(0.2), 81 | nn.Linear(128,z_dim), nn.Tanh() 82 | ]) 83 | def forward(self, x): 84 | b,c,h,w = x.shape 85 | x = self.model.forward(x).view(b,(h//16)**2*512) 86 | enc = self.enc_out(x) 87 | return enc 88 | 89 | class LinearBlock(nn.Module): 90 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 91 | super(LinearBlock, self).__init__() 92 | use_bias = True 93 | # initialize fully connected layer 94 | 95 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 96 | 97 | # initialize normalization 98 | norm_dim = output_dim 99 | if norm == 'bn': 100 | self.norm = nn.BatchNorm1d(norm_dim) 101 | elif norm == 'in': 102 | self.norm = nn.InstanceNorm1d(norm_dim) 103 | elif norm == 'none' or norm == 'sn': 104 | self.norm = None 105 | else: 106 | assert 0, "Unsupported normalization: {}".format(norm) 107 | 108 | # initialize activation 109 | if activation == 'relu': 110 | self.activation = nn.ReLU(inplace=True) 111 | elif activation == 'lrelu': 112 | self.activation = nn.LeakyReLU(0.2, inplace=True) 113 | elif activation == 'prelu': 114 | self.activation = nn.PReLU() 115 | elif activation == 'selu': 116 | self.activation = nn.SELU(inplace=True) 117 | elif activation == 'tanh': 118 | self.activation = nn.Tanh() 119 | elif 
activation == 'none': 120 | self.activation = None 121 | else: 122 | assert 0, "Unsupported activation: {}".format(activation) 123 | 124 | def forward(self, x): 125 | out = self.fc(x) 126 | if self.norm: 127 | out = self.norm(out) 128 | if self.activation: 129 | out = self.activation(out) 130 | return out 131 | 132 | def rot(x,angle,euler_seq='xyz',padding='zeros'): 133 | b,c,d,h,w = x.shape 134 | grid = set_id_grid(x) 135 | grid_flat = grid.reshape(b, 3, -1) 136 | grid_rot_flat = euler2mat(angle,euler_seq=euler_seq).bmm(grid_flat) 137 | grid_rot = grid_rot_flat.reshape(b,3,d,h,w) 138 | x_rot = F.grid_sample(x,grid_rot.permute(0,2,3,4,1),padding_mode=padding,mode='bilinear') 139 | return x_rot 140 | 141 | def euler2mat(angle, euler_seq='xyz' ): 142 | """Convert euler angles to rotation matrix. 143 | 144 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 145 | 146 | Args: 147 | angle: rotation angle along 3 axis (in radians) -- size = [B, 3] 148 | Returns: 149 | Rotation matrix corresponding to the euler angles -- size = [B, 3, 3] 150 | """ 151 | B = angle.size(0) 152 | x, y, z = angle[:,0], angle[:,1], angle[:,2] 153 | 154 | zeros = z.detach()*0 155 | ones = zeros.detach()+1 156 | 157 | cosz = torch.cos(z) 158 | sinz = torch.sin(z) 159 | zmat = torch.stack([cosz, -sinz, zeros, 160 | sinz, cosz, zeros, 161 | zeros, zeros, ones], dim=1).reshape(B, 3, 3) 162 | 163 | cosy = torch.cos(y) 164 | siny = torch.sin(y) 165 | ymat = torch.stack([cosy, zeros, siny, 166 | zeros, ones, zeros, 167 | -siny, zeros, cosy], dim=1).reshape(B, 3, 3) 168 | 169 | cosx = torch.cos(x) 170 | sinx = torch.sin(x) 171 | xmat = torch.stack([ones, zeros, zeros, 172 | zeros, cosx, -sinx, 173 | zeros, sinx, cosx], dim=1).reshape(B, 3, 3) 174 | 175 | if euler_seq == 'xyz': 176 | rotMat = xmat.bmm(ymat).bmm(zmat) 177 | elif euler_seq == 'zyx': 178 | rotMat = zmat.bmm(ymat).bmm(xmat) 179 | return rotMat 180 | 181 | 182 | def set_id_grid(x): 183 | b, c, d, h, w = x.shape 184 | z_range = (torch.linspace(-1,1,steps=d)).view(1, d, 1, 1).expand(1, d, h, w).type_as(x) # [1, H, W, D] 185 | y_range = (torch.linspace(-1,1,steps=h)).view(1, 1, h, 1).expand(1, d, h, w).type_as(x) # [1, H, W, D] 186 | x_range = (torch.linspace(-1,1,steps=w)).view(1, 1, 1, w).expand(1, d, h, w).type_as(x) # [1, H, W, D] 187 | grid = torch.cat((x_range, y_range, z_range), dim=0)[None,...] # x,y,z 188 | grid = grid.expand(b,3,d,h,w) 189 | return grid 190 | 191 | def calc_mean_std(feat, eps=1e-5): 192 | # eps is a small value added to the variance to avoid divide-by-zero. 
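# The statistics are taken per sample and per channel over all spatial
# (and, for 5-D feature volumes, depth) locations. adaIN below then applies,
# channel-wise,
#   out = (x - mean(x)) / std(x) * std_style + mean_style
# with (mean_style, std_style) coming from the style vector predicted by the
# generator's MLPs.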
193 | size = feat.size() 194 | assert (len(size) == 4 or len(size) == 5) 195 | N, C = size[:2] 196 | 197 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 198 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 199 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 200 | 201 | if len(size)==5: 202 | feat_std = feat_std.unsqueeze(-1) 203 | feat_mean = feat_mean.unsqueeze(-1) 204 | 205 | return feat_mean, feat_std 206 | 207 | 208 | def adaIN(content_feat, style_mean_std): 209 | assert(content_feat.size(1) == style_mean_std.size(1)/2) 210 | size = content_feat.size() 211 | b,c = style_mean_std.size() 212 | style_mean, style_std = style_mean_std[:,:c//2],style_mean_std[:,c//2:] 213 | 214 | style_mean = style_mean.unsqueeze(-1).unsqueeze(-1) 215 | style_std = style_std.unsqueeze(-1).unsqueeze(-1) 216 | if len(size)==5: 217 | style_mean = style_mean.unsqueeze(-1) 218 | style_std = style_std.unsqueeze(-1) 219 | content_mean, content_std = calc_mean_std(content_feat) 220 | 221 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 222 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 223 | 224 | def add_SN(m): 225 | for name, c in m.named_children(): 226 | m.add_module(name, add_SN(c)) 227 | if isinstance(m, (nn.Conv2d, nn.Linear)): 228 | return spectral_norm.spectral_norm(m)#nn.utils.spectral_norm(m) 229 | else: 230 | return m 231 | 232 | def init_weights(net, init_type='normal', init_gain=0.02): 233 | """Initialize network weights. 234 | 235 | Parameters: 236 | net (network) -- network to be initialized 237 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 238 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 239 | 240 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 241 | work better for some applications. Feel free to try yourself. 242 | """ 243 | def init_func(m): # define the initialization function 244 | classname = m.__class__.__name__ 245 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 246 | if init_type == 'normal': 247 | init.normal_(m.weight.data, 0.0, init_gain) 248 | elif init_type == 'xavier': 249 | init.xavier_normal_(m.weight.data, gain=init_gain) 250 | elif init_type == 'kaiming': 251 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 252 | elif init_type == 'orthogonal': 253 | init.orthogonal_(m.weight.data, gain=init_gain) 254 | else: 255 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 256 | if hasattr(m, 'bias') and m.bias is not None: 257 | init.constant_(m.bias.data, 0.0) 258 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 259 | init.normal_(m.weight.data, 1.0, init_gain) 260 | init.constant_(m.bias.data, 0.0) 261 | 262 | print('initialize network with %s' % init_type) 263 | net.apply(init_func) # apply the initialization function 264 | 265 | 266 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 267 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 268 | Parameters: 269 | net (network) -- the network to be initialized 270 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 271 | gain (float) -- scaling factor for normal, xavier and orthogonal. 
272 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 273 | 274 | Return an initialized network. 275 | """ 276 | if len(gpu_ids) > 0: 277 | assert(torch.cuda.is_available()) 278 | net.to(gpu_ids[0]) 279 | # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 280 | if init_type is not None: 281 | init_weights(net, init_type, init_gain=init_gain) 282 | return net 283 | 284 | 285 | def get_scheduler(optimizer, opt): 286 | """Return a learning rate scheduler 287 | Parameters: 288 | optimizer -- the optimizer of the network 289 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  290 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 291 | For 'linear', we keep the same learning rate for the first epochs 292 | and linearly decay the rate to zero over the next epochs. 293 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 294 | See https://pytorch.org/docs/stable/optim.html for more details. 295 | """ 296 | if opt.lr_policy == 'linear': 297 | def lambda_rule(epoch): 298 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 299 | return lr_l 300 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 301 | elif opt.lr_policy == 'step': 302 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 303 | elif opt.lr_policy == 'plateau': 304 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 305 | elif opt.lr_policy == 'cosine': 306 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 307 | else: 308 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 309 | return scheduler -------------------------------------------------------------------------------- /models/networks/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spectral Normalization from https://arxiv.org/abs/1802.05957 3 | """ 4 | import torch 5 | from torch.nn.functional import normalize 6 | 7 | 8 | class SpectralNorm(object): 9 | # Invariant before and after each forward call: 10 | # u = normalize(W @ v) 11 | # NB: At initialization, this invariant is not enforced 12 | 13 | _version = 1 14 | # At version 1: 15 | # made `W` not a buffer, 16 | # added `v` as a buffer, and 17 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. 18 | 19 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 20 | self.name = name 21 | self.dim = dim 22 | if n_power_iterations <= 0: 23 | raise ValueError('Expected n_power_iterations to be positive, but ' 24 | 'got n_power_iterations={}'.format(n_power_iterations)) 25 | self.n_power_iterations = n_power_iterations 26 | self.eps = eps 27 | 28 | def reshape_weight_to_matrix(self, weight): 29 | weight_mat = weight 30 | if self.dim != 0: 31 | # permute dim to front 32 | weight_mat = weight_mat.permute(self.dim, 33 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 34 | height = weight_mat.size(0) 35 | return weight_mat.reshape(height, -1) 36 | 37 | def compute_weight(self, module, do_power_iteration): 38 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are 39 | # updated in power iteration **in-place**. 
This is very important 40 | # because in `DataParallel` forward, the vectors (being buffers) are 41 | # broadcast from the parallelized module to each module replica, 42 | # which is a new module object created on the fly. And each replica 43 | # runs its own spectral norm power iteration. So simply assigning 44 | # the updated vectors to the module this function runs on will cause 45 | # the update to be lost forever. And the next time the parallelized 46 | # module is replicated, the same randomly initialized vectors are 47 | # broadcast and used! 48 | # 49 | # Therefore, to make the change propagate back, we rely on two 50 | # important behaviors (also enforced via tests): 51 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor 52 | # is already on correct device; and it makes sure that the 53 | # parallelized module is already on `device[0]`. 54 | # 2. If the out tensor in `out=` kwarg has correct shape, it will 55 | # just fill in the values. 56 | # Therefore, since the same power iteration is performed on all 57 | # devices, simply updating the tensors in-place will make sure that 58 | # the module replica on `device[0]` will update the _u vector on the 59 | # parallized module (by shared storage). 60 | # 61 | # However, after we update `u` and `v` in-place, we need to **clone** 62 | # them before using them to normalize the weight. This is to support 63 | # backproping through two forward passes, e.g., the common pattern in 64 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will 65 | # complain that variables needed to do backward for the first forward 66 | # (i.e., the `u` and `v` vectors) are changed in the second forward. 67 | weight = getattr(module, self.name + '_orig') 68 | u = getattr(module, self.name + '_u') 69 | v = getattr(module, self.name + '_v') 70 | weight_mat = self.reshape_weight_to_matrix(weight) 71 | 72 | if do_power_iteration: 73 | with torch.no_grad(): 74 | for _ in range(self.n_power_iterations): 75 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 76 | # are the first left and right singular vectors. 77 | # This power iteration produces approximations of `u` and `v`. 78 | v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v) 79 | u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) 80 | if self.n_power_iterations > 0: 81 | # See above on why we need to clone 82 | u = u.clone() 83 | v = v.clone() 84 | 85 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 86 | weight = weight / sigma 87 | return weight 88 | 89 | def remove(self, module): 90 | with torch.no_grad(): 91 | weight = self.compute_weight(module, do_power_iteration=False) 92 | delattr(module, self.name) 93 | delattr(module, self.name + '_u') 94 | delattr(module, self.name + '_v') 95 | delattr(module, self.name + '_orig') 96 | module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) 97 | 98 | def __call__(self, module, inputs): 99 | setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training)) 100 | 101 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma): 102 | # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` 103 | # (the invariant at top of this class) and `u @ W @ v = sigma`. 104 | # This uses pinverse in case W^T W is not invertible. 
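# i.e. solve the least-squares system `W v = u` via the normal equations,
#   v = (W^T W)^+ W^T u,
# and then rescale `v` so that `u^T W v` matches the target sigma.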
105 | v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) 106 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) 107 | 108 | @staticmethod 109 | def apply(module, name, n_power_iterations, dim, eps): 110 | for k, hook in module._forward_pre_hooks.items(): 111 | if isinstance(hook, SpectralNorm) and hook.name == name: 112 | raise RuntimeError("Cannot register two spectral_norm hooks on " 113 | "the same parameter {}".format(name)) 114 | 115 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 116 | weight = module._parameters[name] 117 | 118 | with torch.no_grad(): 119 | weight_mat = fn.reshape_weight_to_matrix(weight) 120 | 121 | h, w = weight_mat.size() 122 | # randomly initialize `u` and `v` 123 | u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) 124 | v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) 125 | 126 | delattr(module, fn.name) 127 | module.register_parameter(fn.name + "_orig", weight) 128 | # We still need to assign weight back as fn.name because all sorts of 129 | # things may assume that it exists, e.g., when initializing weights. 130 | # However, we can't directly assign as it could be an nn.Parameter and 131 | # gets added as a parameter. Instead, we register weight.data as a plain 132 | # attribute. 133 | setattr(module, fn.name, weight.data) 134 | module.register_buffer(fn.name + "_u", u) 135 | module.register_buffer(fn.name + "_v", v) 136 | 137 | module.register_forward_pre_hook(fn) 138 | 139 | module._register_state_dict_hook(SpectralNormStateDictHook(fn)) 140 | module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) 141 | return fn 142 | 143 | 144 | # This is a top level class because Py2 pickle doesn't like inner class nor an 145 | # instancemethod. 146 | class SpectralNormLoadStateDictPreHook(object): 147 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 148 | def __init__(self, fn): 149 | self.fn = fn 150 | 151 | # For state_dict with version None, (assuming that it has gone through at 152 | # least one training forward), we have 153 | # 154 | # u = normalize(W_orig @ v) 155 | # W = W_orig / sigma, where sigma = u @ W_orig @ v 156 | # 157 | # To compute `v`, we solve `W_orig @ x = u`, and let 158 | # v = x / (u @ W_orig @ x) * (W / W_orig). 159 | def __call__(self, state_dict, prefix, local_metadata, strict, 160 | missing_keys, unexpected_keys, error_msgs): 161 | fn = self.fn 162 | version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None) 163 | if version is None or version < 1: 164 | weight_key = prefix + fn.name 165 | if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \ 166 | weight_key not in state_dict: 167 | # Detect if it is the updated state dict and just missing metadata. 168 | # This could happen if the users are crafting a state dict themselves, 169 | # so we just pretend that this is the newest. 
170 | return 171 | has_missing_keys = False 172 | for suffix in ('_orig', '', '_u'): 173 | key = weight_key + suffix 174 | if key not in state_dict: 175 | has_missing_keys = True 176 | if strict: 177 | missing_keys.append(key) 178 | if has_missing_keys: 179 | return 180 | with torch.no_grad(): 181 | weight_orig = state_dict[weight_key + '_orig'] 182 | weight = state_dict.pop(weight_key) 183 | sigma = (weight_orig / weight).mean() 184 | weight_mat = fn.reshape_weight_to_matrix(weight_orig) 185 | u = state_dict[weight_key + '_u'] 186 | v = fn._solve_v_and_rescale(weight_mat, u, sigma) 187 | state_dict[weight_key + '_v'] = v 188 | 189 | 190 | 191 | # This is a top level class because Py2 pickle doesn't like inner class nor an 192 | # instancemethod. 193 | class SpectralNormStateDictHook(object): 194 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 195 | def __init__(self, fn): 196 | self.fn = fn 197 | 198 | def __call__(self, module, state_dict, prefix, local_metadata): 199 | if 'spectral_norm' not in local_metadata: 200 | local_metadata['spectral_norm'] = {} 201 | key = self.fn.name + '.version' 202 | if key in local_metadata['spectral_norm']: 203 | raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key)) 204 | local_metadata['spectral_norm'][key] = self.fn._version 205 | 206 | 207 | def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): 208 | r"""Applies spectral normalization to a parameter in the given module. 209 | 210 | .. math:: 211 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, 212 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 213 | 214 | Spectral normalization stabilizes the training of discriminators (critics) 215 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor 216 | with spectral norm :math:`\sigma` of the weight matrix calculated using 217 | power iteration method. If the dimension of the weight tensor is greater 218 | than 2, it is reshaped to 2D in power iteration method to get spectral 219 | norm. This is implemented via a hook that calculates spectral norm and 220 | rescales weight before every :meth:`~Module.forward` call. 221 | 222 | See `Spectral Normalization for Generative Adversarial Networks`_ . 223 | 224 | .. 
225 | 
226 |     Args:
227 |         module (nn.Module): containing module
228 |         name (str, optional): name of weight parameter
229 |         n_power_iterations (int, optional): number of power iterations to
230 |             calculate spectral norm
231 |         eps (float, optional): epsilon for numerical stability in
232 |             calculating norms
233 |         dim (int, optional): dimension corresponding to number of outputs,
234 |             the default is ``0``, except for modules that are instances of
235 |             ConvTranspose{1,2,3}d, when it is ``1``
236 | 
237 |     Returns:
238 |         The original module with the spectral norm hook
239 | 
240 |     Example::
241 | 
242 |         >>> m = spectral_norm(nn.Linear(20, 40))
243 |         >>> m
244 |         Linear(in_features=20, out_features=40, bias=True)
245 |         >>> m.weight_u.size()
246 |         torch.Size([40])
247 | 
248 |     """
249 |     if dim is None:
250 |         if isinstance(module, (torch.nn.ConvTranspose1d,
251 |                                torch.nn.ConvTranspose2d,
252 |                                torch.nn.ConvTranspose3d)):
253 |             dim = 1
254 |         else:
255 |             dim = 0
256 |     SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
257 |     return module
258 | 
259 | 
260 | def remove_spectral_norm(module, name='weight'):
261 |     r"""Removes the spectral normalization reparameterization from a module.
262 | 
263 |     Args:
264 |         module (Module): containing module
265 |         name (str, optional): name of weight parameter
266 | 
267 |     Example:
268 |         >>> m = spectral_norm(nn.Linear(40, 10))
269 |         >>> remove_spectral_norm(m)
270 |     """
271 |     for k, hook in module._forward_pre_hooks.items():
272 |         if isinstance(hook, SpectralNorm) and hook.name == name:
273 |             hook.remove(module)
274 |             del module._forward_pre_hooks[k]
275 |             return module
276 | 
277 |     raise ValueError("spectral_norm of '{}' not found in {}".format(
278 |         name, module))
279 | 
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | from collections import OrderedDict
4 | 
5 | import torch
6 | 
7 | from models.networks import networks
8 | 
9 | 
10 | class BaseModel(ABC):
11 |     """This class is an abstract base class (ABC) for models.
12 |     To create a subclass, you need to implement the following five functions:
13 |         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
14 |         -- <set_input>: unpack data from dataset and apply preprocessing.
15 |         -- <forward>: produce intermediate results.
16 |         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
17 |         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
18 |     """
19 | 
20 |     def __init__(self, opt):
21 |         """Initialize the BaseModel class.
22 | 
23 |         Parameters:
24 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
25 | 
26 |         When creating your custom class, you need to implement your own initialization.
27 |         In this function, you should first call <BaseModel.__init__(self, opt)>.
28 |         Then, you need to define four lists:
29 |             -- self.loss_names (str list): specify the training losses that you want to plot and save.
30 |             -- self.model_names (str list): define networks used in our training.
31 |             -- self.visual_names (str list): specify the images that you want to display and save.
32 |             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
33 |         """
34 |         self.opt = opt
35 |         self.gpu_ids = opt.gpu_ids
36 |         self.isTrain = opt.isTrain
37 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
38 |         # if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
39 |         #     torch.backends.cudnn.benchmark = True
40 |         self.loss_names = []
41 |         self.model_names = []
42 |         self.optimizer_names = []
43 |         self.visual_names = []
44 |         self.optimizers = []
45 |         self.image_paths = []
46 |         self.metric = 0  # used for learning rate policy 'plateau'
47 | 
48 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name)  # save all the checkpoints to save_dir
49 |         self.net_dict = {name: getattr(self, 'net' + name) for name in self.model_names}
50 |         self.optimizer_dict = {name: getattr(self, 'optimizer_' + name) for name in self.model_names} if opt.isTrain else {}
51 | 
52 |     @staticmethod
53 |     def modify_commandline_options(parser, is_train):
54 |         """Add new model-specific options, and rewrite default values for existing options.
55 | 
56 |         Parameters:
57 |             parser -- original option parser
58 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
59 | 
60 |         Returns:
61 |             the modified parser.
62 |         """
63 |         return parser
64 | 
65 |     @abstractmethod
66 |     def set_input(self, input):
67 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
68 | 
69 |         Parameters:
70 |             input (dict): includes the data itself and its metadata information.
71 |         """
72 |         pass
73 | 
74 |     @abstractmethod
75 |     def forward(self, vis=False):
76 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
77 |         pass
78 | 
79 |     @abstractmethod
80 |     def optimize_parameters(self):
81 |         """Calculate losses, gradients, and update network weights; called in every training iteration"""
82 |         pass
83 | 
84 |     def setup(self, opt):
85 |         """Load and print networks; create schedulers
86 | 
87 |         Parameters:
88 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
89 |         """
90 |         if self.isTrain:
91 |             self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]  # schedulers are stepped once per epoch in update_learning_rate()
92 | 
93 |         # get the latest checkpoints
94 |         import glob
95 | 
96 |         checkpoint_folder = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name)
97 | 
98 |         search_pattern = '*.pth'
99 | 
100 |         checkpoint_list = glob.glob(os.path.join(checkpoint_folder, search_pattern))
101 |         print(checkpoint_folder)
102 |         iter_start = 0
103 |         if len(checkpoint_list) > 0:
104 |             load_suffix = self.opt.load_suffix
105 |             self.load_networks(load_suffix)
106 |             if self.isTrain:
107 |                 self.load_optimizers(load_suffix)
108 |             iter_start = self.load_states(load_suffix)
109 | 
110 |         self.print_networks(opt.verbose)
111 | 
112 |         return iter_start
113 | 
114 |     def eval(self):
115 |         """Make models eval mode during test time"""
116 |         for name in self.model_names:
117 |             if isinstance(name, str):
118 |                 net = getattr(self, 'net' + name)
119 |                 net.eval()
120 | 
121 |     def train(self):
122 |         """Make models train mode during training time"""
123 |         for name in self.model_names:
124 |             if isinstance(name, str):
125 |                 net = getattr(self, 'net' + name)
126 |                 net.train()
127 | 
128 |     def test(self):
129 |         """Forward function used in test time.
130 | 
131 |         This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
132 |         It also calls <compute_visuals> to produce additional visualization results.
133 |         """
134 |         with torch.no_grad():
135 |             self.forward()
136 |             self.compute_visuals()
137 | 
138 |     def compute_visuals(self):
139 |         """Calculate additional output images for visualization"""
140 |         pass
141 | 
142 |     def get_image_paths(self):
143 |         """Return image paths that are used to load current data"""
144 |         return self.image_paths
145 | 
146 |     def update_learning_rate(self):
147 |         """Update learning rates for all the networks; called at the end of every epoch"""
148 |         for scheduler in self.schedulers:
149 |             if self.opt.lr_policy == 'plateau':
150 |                 scheduler.step(self.metric)
151 |             else:
152 |                 scheduler.step()
153 | 
154 |         # lr = self.optimizers[0].param_groups[0]['lr']
155 |         # print('learning rate = %.7f' % lr)
156 | 
157 |     def get_current_visuals(self):
158 |         """Return visualization images. train.py will display these images and save them via the active visualizers"""
159 |         visual_ret = OrderedDict()
160 |         for name in self.visual_names:
161 |             if isinstance(name, str):
162 |                 visual_ret[name] = getattr(self, name)
163 |         return visual_ret
164 | 
165 |     def get_current_videos(self):
166 |         """Return visualization videos. train.py will display these videos and save them via the active visualizers"""
167 |         visual_ret = OrderedDict()
168 |         for name in self.video_names:
169 |             if isinstance(name, str):
170 |                 visual_ret[name] = getattr(self, name)
171 |         return visual_ret
172 | 
173 |     def get_current_losses(self):
174 |         """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
175 |         errors_ret = OrderedDict()
176 |         for name in self.loss_names:
177 |             if isinstance(name, str):
178 |                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
179 |         return errors_ret
180 | 
181 |     def save(self, suffix, iter):
182 |         """Save all the networks, optimizers and states to the disk.
183 | 
184 |         Parameters:
185 |             suffix (str) -- checkpoint suffix used in the file names; iter (int) -- current iteration, forwarded to save_states
186 |         """
187 |         self.save_optimizers(suffix)
188 |         self.save_networks(suffix)
189 |         self.save_states(suffix, iter)
190 | 
191 |     def save_networks(self, suffix):
192 |         """Save all the networks to the disk.
193 | 
194 |         Parameters:
195 |             suffix (str) -- checkpoint suffix; used in the file name 'model_%s.pth' % suffix
196 |         """
197 |         save_name = 'model_%s.pth' % suffix
198 |         save_path = os.path.join(self.save_dir, save_name)
199 |         outdict = {}
200 |         for name in self.model_names:
201 |             if isinstance(name, str):
202 |                 net_name = 'net' + name
203 |                 outdict[net_name] = getattr(self, net_name).state_dict()
204 | 
205 |         torch.save(outdict, save_path)
206 | 
207 |     def save_optimizers(self, suffix):
208 |         """Save all the optimizers to the disk.
209 | 
210 |         Parameters:
211 |             suffix (str) -- checkpoint suffix; used in the file name 'optimizer_%s.pth' % suffix
212 |         """
213 |         save_name = 'optimizer_%s.pth' % suffix
214 |         save_path = os.path.join(self.save_dir, save_name)
215 |         output_dict = {}
216 |         for name in self.optimizer_names:
217 |             if isinstance(name, str):
218 |                 optimizer_name = 'optimizer_' + name
219 |                 output_dict[optimizer_name] = getattr(self, optimizer_name).state_dict()
220 |         torch.save(output_dict, save_path)
221 | 
222 |     def save_states(self, suffix, iter):
223 |         """Save all the states (epoch, iter) to the disk.
224 | 
225 |         Parameters:
226 |             suffix (str) -- checkpoint suffix; iter (int) -- current iteration, written to 'states_%s.txt' % suffix
227 |         """
228 |         save_name = 'states_%s.txt' % suffix
229 |         save_path = os.path.join(self.save_dir, save_name)
230 |         import numpy as np
231 |         states = np.array([iter])
232 |         np.savetxt(save_path, states)
233 | 
234 |     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
235 |         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
236 |         key = keys[i]
237 |         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
238 |             if module.__class__.__name__.startswith('InstanceNorm') and \
239 |                     (key == 'running_mean' or key == 'running_var'):
240 |                 if getattr(module, key) is None:
241 |                     state_dict.pop('.'.join(keys))
242 |             if module.__class__.__name__.startswith('InstanceNorm') and \
243 |                     (key == 'num_batches_tracked'):
244 |                 state_dict.pop('.'.join(keys))
245 |         else:
246 |             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
247 | 
248 |     def load_networks(self, suffix):
249 |         """Load all the networks from the disk.
250 | 
251 |         Parameters:
252 |             suffix (str) -- checkpoint suffix; used in the file name 'model_%s.pth' % suffix
253 |         """
254 |         load_path = os.path.join(self.save_dir, 'model_%s.pth' % (suffix))
255 |         # if you are using PyTorch newer than 0.4 (e.g., built from
256 |         # GitHub source), you can remove str() on self.device
257 |         try:
258 |             out_dict = torch.load(load_path, map_location=str(self.device))
259 |             for name in self.model_names:
260 |                 if isinstance(name, str):
261 |                     net_name = 'net' + name
262 |                     net = getattr(self, net_name)
263 | 
264 | 
265 |                     state_dict = out_dict[net_name]
266 |                     if hasattr(state_dict, '_metadata'):
267 |                         del state_dict._metadata
268 |                     # patch InstanceNorm checkpoints prior to 0.4
269 |                     for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
270 |                         self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
271 | 
272 |                     net.load_state_dict(state_dict, strict=True)
273 |                     print('[%s] loaded from [%s]' % (net_name, load_path))
274 | 
275 | 
276 |         except Exception:
277 |             print('no checkpoints for the network found, parameters will be initialized')
278 | 
279 | 
280 | 
281 |     def load_optimizers(self, suffix):
282 |         """Load all the optimizers from the disk.
283 | 
284 |         Parameters:
285 |             suffix (str) -- checkpoint suffix; used in the file name 'optimizer_%s.pth' % suffix
286 |         """
287 |         load_path = os.path.join(self.save_dir, 'optimizer_%s.pth' % (suffix))
288 |         # if you are using PyTorch newer than 0.4 (e.g., built from
289 |         # GitHub source), you can remove str() on self.device
290 |         try:
291 |             out_dict = torch.load(load_path, map_location=str(self.device))
292 |             for name in self.optimizer_names:
293 |                 if isinstance(name, str):
294 |                     optimizer_name = 'optimizer_' + name
295 |                     optimizer = getattr(self, optimizer_name)
296 |                     optimizer.load_state_dict(out_dict[optimizer_name])
297 |                     print('optimizer loaded from [%s]' % load_path)
298 |         except Exception:
299 |             print('no checkpoints for the optimizer found, parameters will be initialized')
300 | 
301 | 
302 |     def load_states(self, suffix):
303 |         """Load all the states (epoch, iterations) from the disk.
304 | 
305 |         Parameters:
306 |             suffix (str) -- checkpoint suffix; used in the file name 'states_%s.txt' % suffix
307 |         """
308 |         load_path = os.path.join(self.save_dir, 'states_%s.txt' % (suffix))
309 |         import numpy as np
310 |         try:
311 |             states = np.loadtxt(load_path)
312 |             print('states loaded from [%s]' % load_path)
313 |             return int(states)
314 |         except Exception:
315 |             print('no states found, start from epoch 1, iter 0')
316 |             return 0
317 | 
318 |     def print_networks(self, verbose):
319 |         """Print the total number of parameters in the network and (if verbose) network architecture
320 | 
321 |         Parameters:
322 |             verbose (bool) -- if verbose: print the network architecture
323 |         """
324 |         for name in self.model_names:
325 |             if isinstance(name, str):
326 |                 net = getattr(self, 'net' + name)
327 |                 num_params = 0
328 |                 for param in net.parameters():
329 |                     num_params += param.numel()
330 |                 if verbose:
331 |                     print(net)
332 |                 print('[Network %s] has [%.3f M] parameters' % (name, num_params / 1e6))
333 | 
334 |     def set_requires_grad(self, nets, requires_grad=False):
335 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
336 |         Parameters:
337 |             nets (network list) -- a list of networks
338 |             requires_grad (bool) -- whether the networks require gradients or not
339 |         """
340 |         if not isinstance(nets, list):
341 |             nets = [nets]
342 |         for net in nets:
343 |             if net is not None:
344 |                 for param in net.parameters():
345 |                     param.requires_grad_(requires_grad)
346 | 
--------------------------------------------------------------------------------
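
A minimal usage sketch for the spectral_norm helper above. This is illustrative only, not a file in this repository; it assumes spectral_norm.py is importable as models.networks.spectral_norm, and the weight_u / weight_orig attribute names follow the register_buffer / register_parameter calls in SpectralNorm.apply.

    import torch
    import torch.nn as nn

    from models.networks.spectral_norm import spectral_norm, remove_spectral_norm

    # Wrap a discriminator-style conv layer. The pre-forward hook reshapes the
    # 4-D weight to a (64, 3*4*4) matrix (dim=0 for Conv2d) and runs one power
    # iteration per forward call to track the largest singular value sigma.
    conv = spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))

    print(conv.weight_u.size())     # torch.Size([64]): left singular vector estimate
    print(conv.weight_orig.size())  # the unconstrained parameter that is optimized

    x = torch.randn(8, 3, 64, 64)
    y = conv(x)  # weight is recomputed as weight_orig / sigma before the call

    # Once training is done, the reparameterization can be baked out so the
    # module carries a plain weight parameter again.
    remove_spectral_norm(conv)

Because `u` and `v` are registered as buffers, the power-iteration state persists across forward calls and is serialized with the module, which is exactly the state that the state-dict hooks above migrate between checkpoint versions.

Similarly, a sketch of what BaseModel.save() leaves on disk, with file and key names taken from save_networks / save_optimizers / save_states above; the model name 'Decoder' and the run directory are hypothetical examples, not names from this repository:

    import os
    import torch

    # For a model with self.model_names = ['Decoder'] and suffix 'latest',
    # save() writes, inside <checkpoints_dir>/<project_name>/<exp_name>/<run_name>:
    #   model_latest.pth      -- {'netDecoder': state_dict}       (save_networks)
    #   optimizer_latest.pth  -- {'optimizer_Decoder': state_dict} (save_optimizers)
    #   states_latest.txt     -- a single number: the iteration count (save_states)
    save_dir = './checkpoints/neural_object_fitting/exp/run'  # hypothetical run directory
    ckpt = torch.load(os.path.join(save_dir, 'model_latest.pth'), map_location='cpu')
    print(list(ckpt.keys()))  # e.g. ['netDecoder']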