├── .gitignore ├── ECCV'22 poster_v2.pdf ├── LICENSE ├── README.md ├── config.py ├── configs ├── config.yaml ├── default.yaml ├── eval.yaml └── infer.yaml ├── data ├── .DS_Store ├── __init__.py ├── base_dataset.py ├── carla_dataset.py └── nocs_hdf5_dataset.py ├── eccv_poster.pdf ├── environment.yml ├── images └── teaser_video.gif ├── main.py ├── models ├── __init__.py ├── base_model.py ├── latent_object_model.py ├── networks │ ├── __init__.py │ ├── losses.py │ ├── networks.py │ ├── spectral_norm.py │ └── utils.py └── nocs_gym.py ├── nocs ├── eval.py └── utils.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── prepare_datasets.sh ├── runner.py ├── sac_2 ├── SAC.py ├── __init__.py ├── model.py └── model_encoder │ ├── image_to_latent.py │ ├── latent_optimizer.py │ └── losses.py └── utils ├── __init__.py ├── loss.py ├── util.py ├── utils.py └── visualizer ├── __init__.py ├── base_visualizer.py ├── terminal_visualizer.py └── wandb_visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | __pycache__ 3 | results/ 4 | checkpoints/ 5 | datasets/ 6 | wandb/ 7 | *_model/ 8 | neural_object_fitting/ 9 | pretrained_models/ 10 | record -------------------------------------------------------------------------------- /ECCV'22 poster_v2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/ECCV'22 poster_v2.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jiaxin Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ECCV'22] A Visual Navigation Perspective for Category-Level Object Pose Estimation 2 | 3 | This is the official repository for the ECCV 2022 paper ["A Visual Navigation Perspective for Category-Level Object Pose Estimation"](https://arxiv.org/abs/2203.13572). 
4 | ![](https://github.com/wrld/visual_navigation_pose_estimation/blob/main/images/teaser_video.gif) 5 | ## System Environment 6 | 7 | You can use Anaconda to create the environment: 8 | 9 | ``` shell 10 | conda env create -f environment.yml 11 | conda activate visual_nav 12 | ``` 13 | ## Datasets 14 | We use [neural_object_fitting](https://github.com/xuchen-ethz/neural_object_fitting) as our image generator for NOCS; its checkpoints and datasets can be downloaded with: 15 | 16 | ``` shell 17 | sh prepare_datasets.sh 18 | ``` 19 | 20 | ## Download Pre-trained Model 21 | 22 | Download the pretrained models [here](https://drive.google.com/drive/folders/1WFB1fJNyJgWUdyxqKHrpsUXuUpmImhcm?usp=sharing) and put them into `./pretrained_model/`. 23 | 24 | ## Example Usage 25 | 26 | ### Training on Synthetic Dataset 27 | 28 | Run the following command to train on a specific category (can / bottle / bowl / mug / laptop / camera): 29 | ``` 30 | python main.py --dataset [category] --name [running name] 31 | ``` 32 | 33 | The saved models and evaluation results can be found in `./results/`. 34 | ### Training Visualization 35 | 36 | To visualize the training process, you can run: 37 | 38 | ``` shell 39 | # use wandb to visualize the training loss and states 40 | python main.py --dataset [category] --log True --log_interval 50 41 | ``` 42 | Then open the wandb link to monitor the training process. 43 | ### Training Options 44 | 45 | There are several settings you can change via the arguments below; a combined end-to-end example is given at the end of this README: 46 | 47 | | Argument | What it controls | Default | 48 | | ------------------- | ----------------------------------------------- | -------------------- | 49 | | --batch_size | The input batch size | 50 | 50 | | --lr | The learning rate for training | 0.00003 | 51 | | --pretrain | Continue training from a pretrained model | None | 52 | | --save_interval | Model save interval | 1000 | 53 | | --episode_nums | Maximum number of episodes | 50000 | 54 | 55 | ### Evaluation 56 | 57 | To evaluate on the synthetic dataset with a pretrained model, run the following command: 58 | ``` shell 59 | python main.py --dataset [category] --eval 1 --pretrain [path] --gd_optimize True 60 | ``` 61 | The evaluation results will be reported with plots. 62 | 63 | To evaluate on the real (NOCS) dataset, run the following command: 64 | ``` shell 65 | python main.py --dataset [category] --eval 2 --pretrain [path] --gd_optimize True 66 | ``` 67 | 68 | To calculate the score: 69 | ``` shell 70 | python nocs/eval.py --dataset [category] 71 | ``` 72 | The evaluation results of the specified category will be reported. 73 | 74 | ## Acknowledgement 75 | Our code is based on [neural_object_fitting](https://github.com/xuchen-ethz/neural_object_fitting) and [pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic). 
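## Example Workflow

As a quick reference, the commands above can be chained into a single end-to-end run. The sketch below uses the laptop category; the running name `demo` is a placeholder, and the checkpoint path follows the `./results/[running name]/` layout written by `main.py`:

``` shell
# train on the laptop category (checkpoints and results are written to ./results/demo/)
python main.py --dataset laptop --name demo --log True --log_interval 50

# evaluate the trained policy on the synthetic dataset, with gradient-descent refinement
python main.py --dataset laptop --eval 1 --pretrain ./results/demo/best_model.pt --gd_optimize True

# evaluate on the real NOCS data, then compute the category score
python main.py --dataset laptop --eval 2 --pretrain ./results/demo/best_model.pt --gd_optimize True
python nocs/eval.py --dataset laptop
```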
76 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | def load_options(): 5 | """ configs for training""" 6 | parser = argparse.ArgumentParser( 7 | description='visual navigation pose estimation') 8 | # basic parameters 9 | basic_args = parser.add_argument_group('basic') 10 | 11 | basic_args.add_argument('--seed', type=int, default=123456, metavar='N', 12 | help='random seed (default: 123456)') 13 | basic_args.add_argument('--batch_size', type=int, default=50, metavar='N', 14 | help='batch size (default: 50)') 15 | basic_args.add_argument('--episode_nums', type=int, default=50000, metavar='N', 16 | help='maximum episodes number') 17 | basic_args.add_argument('--cuda', type=bool, default=True, 18 | help='run on CUDA ') 19 | basic_args.add_argument('--save_folder', type=str, default="./results/", 20 | help='folder to save files and logs') 21 | basic_args.add_argument('--name', type=str, default="pose_estimation", 22 | help='running task name') 23 | basic_args.add_argument('--resume', type=bool, default=False, 24 | help='hidden size (default: 256)') 25 | basic_args.add_argument('--pretrain', type=str, default=None, 26 | help='pretrain model path') 27 | basic_args.add_argument('--save_interval', type=int, default=1000, 28 | help='model save interval') 29 | basic_args.add_argument('--log', type=bool, default=False, 30 | help='wandb log') 31 | basic_args.add_argument('--log_interval', type=int, default=30, 32 | help='wandb logging interval') 33 | basic_args.add_argument('--dataset', type=str, default='laptop', 34 | help='datasets category (can/bottle/bowl/laptop/mug/camera)') 35 | # policy parameters 36 | policy_args = parser.add_argument_group('policy') 37 | 38 | policy_args.add_argument('--policy', default="Gaussian", 39 | help='Policy Type: Gaussian | Deterministic (default: Gaussian)') 40 | policy_args.add_argument('--eval', type=int, default=0, 41 | help='Evaluates a policy a policy every 10 episode (0: no evaluation, 1: eval on synthetic data, 2: eval on nocs data)') 42 | policy_args.add_argument('--gamma', type=float, default=0.8, metavar='G', 43 | help='discount factor for reward (default: 0.99)') 44 | policy_args.add_argument('--tau', type=float, default=0.005, metavar='G', 45 | help='target smoothing coefficiet(τ)') 46 | policy_args.add_argument('--lr', type=float, default=0.00003, metavar='G', 47 | help='learning rate') 48 | policy_args.add_argument('--alpha', type=float, default=0.2, metavar='G', 49 | help='Temperature parameter α determines the relative importance of the entropy\ 50 | term against the reward (default: 0.2)') 51 | policy_args.add_argument('--automatic_entropy_tuning', type=bool, default=True, metavar='G', 52 | help='Automaically adjust α (default: False)') 53 | policy_args.add_argument('--max_step', type=int, default=5, metavar='N', 54 | help='maximum training steps') 55 | policy_args.add_argument('--target_update_interval', type=int, default=1, metavar='N', 56 | help='Value target update per no. 
of updates per step (default: 1)') 57 | policy_args.add_argument('--hidden_size', type=int, default=256, metavar='N', 58 | help='hidden size (default: 256)') 59 | policy_args.add_argument('--updates_per_step', type=int, default=1, metavar='N', 60 | help='model updates per simulator step (default: 1)') 61 | policy_args.add_argument('--demo_episodes', type=int, default=100, metavar='N', 62 | help='Steps sampling random actions (default: 27000)') 63 | policy_args.add_argument('--replay_size', type=int, default=5000, metavar='N', 64 | help='size of replay buffer (default: 27000000)') 65 | policy_args.add_argument('--gd_optimize', type=bool, default=False, 66 | help='use GD to optimize pose after IL') 67 | policy_args.add_argument('--use_encoder', type=bool, default=False, 68 | help='use encoder to initialize latent code') 69 | policy_args.add_argument('--image_size', type=int, default=64, 70 | help='processing image size') 71 | # nocs image generator params 72 | args = parser.parse_args() 73 | return args, parser 74 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | basic: 2 | checkpoints_dir: ./checkpoints 3 | dataroot: ./ 4 | gpu_ids: '0' 5 | project_name: neural_object_fitting 6 | data: 7 | batch_size: 256 8 | crop_size: 64 9 | dataset_mode: carla 10 | load_size: 64 11 | max_dataset_size: .inf 12 | no_flip: true 13 | num_threads: 0 14 | preprocess: resize_and_crop 15 | serial_batches: false 16 | models: 17 | batch_size_vis: 8 18 | lambda_KL: 0.01 19 | lambda_recon: 10.0 20 | use_VAE: true 21 | z_dim: 16 22 | log: 23 | display_freq: 102400 24 | print_freq: 1 25 | misc: 26 | load_suffix: latest 27 | verbose: false 28 | visualizers: 29 | - terminal 30 | - wandb 31 | model: 32 | init_gain: 0.02 33 | init_type: normal 34 | input_nc: 3 35 | model: latent_object 36 | output_nc: 3 37 | save: 38 | epoch_count: 1 39 | save_by_iter: false 40 | save_epoch_freq: 5 41 | save_latest_freq: 102400 42 | train: 43 | lr: 0.003 44 | lr_decay_iters: 50 45 | lr_policy: linear 46 | n_views: 2592 47 | niter: 100 48 | niter_decay: 100 49 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | expname: default 2 | data: 3 | datadir: data/carla 4 | type: carla 5 | imsize: 64 6 | white_bkgd: True 7 | near: 1. 8 | far: 6. 9 | radius: 3.4 # set according to near and far plane 10 | fov: 90. 11 | orthographic: False 12 | umin: 0. # 0 deg, convert to degree via 360. * u 13 | umax: 1. # 360 deg, convert to degree via 360. * u 14 | vmin: 0. # 0 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi 15 | vmax: 0.45642212862617093 # 85 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi 16 | nerf: 17 | i_embed: 0 18 | use_viewdirs: True 19 | multires: 10 20 | multires_views: 4 21 | N_samples: 64 22 | N_importance: 0 23 | netdepth: 8 24 | netwidth: 256 25 | netdepth_fine: 8 26 | netwidth_fine: 256 27 | perturb: 1. 28 | raw_noise_std: 1. 29 | decrease_noise: True 30 | z_dist: 31 | type: gauss 32 | dim: 256 33 | dim_appearance: 128 # This dimension is subtracted from "dim" 34 | ray_sampler: 35 | min_scale: 0.25 36 | max_scale: 1. 
37 | scale_anneal: 0.0025 # no effect if scale_anneal<0, else the minimum scale decreases exponentially until converge to min_scale 38 | N_samples: 1024 # 32*32, patchsize 39 | discriminator: 40 | ndf: 64 41 | hflip: False # Randomly flip discriminator input horizontally 42 | training: 43 | outdir: ./results 44 | model_file: model.pt 45 | monitoring: tensorboard 46 | use_amp: False # Use automated mixed precision 47 | nworkers: 6 48 | batch_size: 4 49 | chunk: 32768 # 1024*32 50 | netchunk: 65536 # 1024*64 51 | lr_g: 0.0005 52 | lr_d: 0.0001 53 | lr_anneal: 0.5 54 | lr_anneal_every: 50000,100000,200000 55 | equalize_lr: False 56 | gan_type: standard 57 | reg_type: real 58 | reg_param: 10. 59 | optimizer: rmsprop 60 | n_test_samples_with_same_shape_code: 4 61 | take_model_average: true 62 | model_average_beta: 0.999 63 | model_average_reinit: false 64 | restart_every: -1 65 | save_best: fid 66 | fid_every: 5000 # Valid for FID and KID 67 | print_every: 10 68 | sample_every: 500 69 | save_every: 900 70 | backup_every: 50000 71 | video_every: 10000 72 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | basic: 2 | checkpoints_dir: ./checkpoints 3 | dataroot: ./datasets/test 4 | gpu_ids: '0' 5 | project_name: neural_object_fitting 6 | run_name: fitting 7 | data: 8 | batch_size: 1 9 | crop_size: 64 10 | dataset_mode: carla 11 | load_size: 64 12 | max_dataset_size: .inf 13 | no_flip: true 14 | num_threads: 0 15 | preprocess: resize_and_crop 16 | serial_batches: false 17 | fitting: 18 | lambda_reg: 1 19 | n_init: 32 20 | n_iter: 50 21 | misc: 22 | load_suffix: latest 23 | verbose: false 24 | visualizers: 25 | - terminal 26 | - wandb 27 | model: 28 | init_gain: 0.02 29 | init_type: normal 30 | input_nc: 3 31 | model: latent_object 32 | output_nc: 3 33 | models: 34 | batch_size_vis: 8 35 | use_VAE: true 36 | z_dim: 16 37 | test: 38 | target_size: 64 39 | num_agent: 1 40 | id_agent: 0 41 | results_dir: ./results 42 | skip: 20 -------------------------------------------------------------------------------- /configs/infer.yaml: -------------------------------------------------------------------------------- 1 | basic: 2 | checkpoints_dir: ./checkpoints 3 | dataroot: ./datasets/train 4 | gpu_ids: '0' 5 | project_name: neural_object_fitting 6 | run_name: fitting 7 | data: 8 | batch_size: 1 9 | crop_size: 64 10 | dataset_mode: nocs_hdf5 11 | load_size: 64 12 | max_dataset_size: .inf 13 | no_flip: true 14 | num_threads: 0 15 | preprocess: resize_and_crop 16 | serial_batches: false 17 | fitting: 18 | lambda_reg: 1 19 | n_init: 1 20 | n_iter: 50 21 | misc: 22 | load_suffix: latest 23 | verbose: false 24 | visualizers: 25 | - terminal 26 | - wandb 27 | model: 28 | init_gain: 0.02 29 | init_type: normal 30 | input_nc: 3 31 | model: latent_object 32 | output_nc: 3 33 | models: 34 | batch_size_vis: 8 35 | use_VAE: true 36 | z_dim: 16 37 | test: 38 | target_size: 64 39 | num_agent: 1 40 | id_agent: 0 41 | results_dir: ./results 42 | skip: 20 -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/data/.DS_Store -------------------------------------------------------------------------------- /data/__init__.py: 
-------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | 15 | import torch.utils.data 16 | 17 | from data.base_dataset import BaseDataset 18 | 19 | 20 | def find_dataset_using_name(dataset_name): 21 | """Import the module "data/[dataset_name]_dataset.py". 22 | 23 | In the file, the class called DatasetNameDataset() will 24 | be instantiated. It has to be a subclass of BaseDataset, 25 | and it is case-insensitive. 26 | """ 27 | dataset_filename = "data." + dataset_name + "_dataset" 28 | datasetlib = importlib.import_module(dataset_filename) 29 | 30 | dataset = None 31 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 32 | for name, cls in datasetlib.__dict__.items(): 33 | if name.lower() == target_dataset_name.lower() \ 34 | and issubclass(cls, BaseDataset): 35 | dataset = cls 36 | 37 | if dataset is None: 38 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 39 | 40 | return dataset 41 | 42 | 43 | def get_option_setter(dataset_name): 44 | """Return the static method of the dataset class.""" 45 | dataset_class = find_dataset_using_name(dataset_name) 46 | return dataset_class.modify_commandline_options 47 | 48 | 49 | def create_dataset(opt): 50 | """Create a dataset given the option. 51 | 52 | This function wraps the class CustomDatasetDataLoader. 53 | This is the main interface between this package and 'train.py'/'test.py' 54 | 55 | Example: 56 | >>> from data import create_dataset 57 | >>> dataset = create_dataset(opt) 58 | """ 59 | data_loader = CustomDatasetDataLoader(opt) 60 | dataset = data_loader.load_data() 61 | return dataset 62 | 63 | 64 | class CustomDatasetDataLoader(): 65 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 66 | 67 | def __init__(self, opt): 68 | """Initialize this class 69 | 70 | Step 1: create a dataset instance given the name [dataset_mode] 71 | Step 2: create a multi-threaded data loader. 
72 | """ 73 | self.opt = opt 74 | dataset_class = find_dataset_using_name(opt.dataset_mode) 75 | self.dataset = dataset_class(opt) 76 | print("dataset [%s] was created with [%d] samples" % (type(self.dataset).__name__, len(self.dataset))) 77 | self.dataloader = torch.utils.data.DataLoader( 78 | self.dataset, 79 | drop_last=True, 80 | batch_size=opt.batch_size, 81 | shuffle=not opt.serial_batches, 82 | num_workers=int(opt.num_threads), 83 | pin_memory=True) 84 | 85 | def load_data(self): 86 | return self 87 | 88 | def __len__(self): 89 | """Return the number of data in the dataset""" 90 | return min(len(self.dataset), self.opt.max_dataset_size) 91 | 92 | def __iter__(self): 93 | """Return a batch of data""" 94 | for i, data in enumerate(self.dataloader): 95 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 96 | break 97 | yield data 98 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | 14 | class BaseDataset(data.Dataset, ABC): 15 | """This class is an abstract base class (ABC) for datasets. 16 | 17 | To create a subclass, you need to implement the following four functions: 18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 19 | -- <__len__>: return the size of dataset. 20 | -- <__getitem__>: get a data point. 21 | -- : (optionally) add dataset-specific options and set default options. 22 | """ 23 | 24 | def __init__(self, opt): 25 | """Initialize the class; save the options in the class 26 | 27 | Parameters: 28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 29 | """ 30 | self.opt = opt 31 | self.root = opt.dataroot 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
60 | """ 61 | pass 62 | 63 | 64 | def get_params(opt, size): 65 | w, h = size 66 | new_h = h 67 | new_w = w 68 | if opt.preprocess == 'resize_and_crop': 69 | new_h = new_w = opt.load_size 70 | elif opt.preprocess == 'scale_width_and_crop': 71 | new_w = opt.load_size 72 | new_h = opt.load_size * h // w 73 | 74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 76 | 77 | flip = random.random() > 0.5 78 | 79 | return {'crop_pos': (x, y), 'flip': flip} 80 | 81 | 82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 83 | transform_list = [] 84 | if grayscale: 85 | transform_list.append(transforms.Grayscale(1)) 86 | if 'resize' in opt.preprocess: 87 | osize = [opt.load_size, opt.load_size] 88 | transform_list.append(transforms.Resize(osize, method)) 89 | elif 'scale_width' in opt.preprocess: 90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 91 | 92 | if 'crop' in opt.preprocess: 93 | if params is None: 94 | transform_list.append(transforms.CenterCrop(opt.crop_size)) 95 | else: 96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 97 | 98 | if opt.preprocess == 'none': 99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 100 | 101 | if not opt.no_flip: 102 | if params is None: 103 | transform_list.append(transforms.RandomHorizontalFlip()) 104 | elif params['flip']: 105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 106 | 107 | if convert: 108 | transform_list += [transforms.ToTensor()] 109 | if grayscale: 110 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 111 | else: 112 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 113 | 114 | # add small random noise 115 | transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())) 116 | return transforms.Compose(transform_list) 117 | 118 | 119 | def __make_power_2(img, base, method=Image.BICUBIC): 120 | ow, oh = img.size 121 | h = int(round(oh / base) * base) 122 | w = int(round(ow / base) * base) 123 | if (h == oh) and (w == ow): 124 | return img 125 | 126 | __print_size_warning(ow, oh, w, h) 127 | return img.resize((w, h), method) 128 | 129 | 130 | def __scale_width(img, target_width, method=Image.BICUBIC): 131 | ow, oh = img.size 132 | if (ow == target_width): 133 | return img 134 | w = target_width 135 | h = int(target_width * oh / ow) 136 | return img.resize((w, h), method) 137 | 138 | 139 | def __crop(img, pos, size): 140 | ow, oh = img.size 141 | x1, y1 = pos 142 | tw = th = size 143 | if (ow > tw or oh > th): 144 | return img.crop((x1, y1, x1 + tw, y1 + th)) 145 | return img 146 | 147 | 148 | def __flip(img, flip): 149 | if flip: 150 | return img.transpose(Image.FLIP_LEFT_RIGHT) 151 | return img 152 | 153 | 154 | def __print_size_warning(ow, oh, w, h): 155 | """Print warning information about image size(only print once)""" 156 | if not hasattr(__print_size_warning, 'has_printed'): 157 | print("The image size needs to be a multiple of 4. " 158 | "The loaded image size was (%d, %d), so it was adjusted to " 159 | "(%d, %d). 
This adjustment will be done to all images " 160 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 161 | __print_size_warning.has_printed = True 162 | -------------------------------------------------------------------------------- /data/carla_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from data.base_dataset import BaseDataset, get_transform 3 | import random 4 | import numpy as np 5 | import h5py 6 | from PIL import Image 7 | import os 8 | class CARLADATASET(BaseDataset): 9 | 10 | 11 | def __init__(self, opt): 12 | """Initialize this dataset class. 13 | 14 | Parameters: 15 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 16 | """ 17 | BaseDataset.__init__(self, opt) 18 | 19 | input_nc = self.opt.output_nc 20 | self.transform = get_transform(opt, grayscale=(input_nc == 1),method=Image.BILINEAR) 21 | 22 | hdf5_file = h5py.File(os.path.join(opt.dataroot,opt.category+'.hdf5'),'r',swmr=True) 23 | self.images = hdf5_file['images'] 24 | self.poses = hdf5_file['poses'][...] 25 | print(self.poses.shape) 26 | print(self.images.shape) 27 | self.dataset_size = self.poses.shape[0] 28 | self.num_view = opt.n_views 29 | self.num_model = self.dataset_size 30 | 31 | def __getitem__(self, index): 32 | """Return a data point and its metadata information. 33 | 34 | Parameters: 35 | index - - a random integer for data indexing 36 | 37 | Returns a dictionary that contains A and A_paths 38 | A(tensor) - - an image in one domain 39 | A_paths(str) - - the path of the image 40 | """ 41 | model_id = random.randint(0,self.num_model-1) 42 | 43 | id = model_id 44 | A_img = np.copy(self.images[id,:,:,:]) 45 | if A_img.shape[2] == 4: 46 | A_mask = A_img[:,:,-1] == 0 47 | A_img[A_mask,:3] = 255 48 | A_img = A_img[:,:,:3] 49 | A_pose = np.floor(np.copy(self.poses[id,:])).astype(np.float32) 50 | # print(A_pose) 51 | A_pose[1] = A_pose[2] 52 | A_pose[2] = 0 53 | 54 | 55 | A = self.transform(Image.fromarray(A_img.astype(np.uint8))) 56 | 57 | 58 | B = A.clone() 59 | B_pose = np.copy(A_pose).astype(np.float32) 60 | 61 | return {'A': A, 'A_pose': A_pose, 62 | 'B': B, 'B_pose': B_pose,} 63 | def __len__(self): 64 | """Return the total number of images in the dataset.""" 65 | return self.dataset_size -------------------------------------------------------------------------------- /data/nocs_hdf5_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from data.base_dataset import BaseDataset, get_transform 3 | import random 4 | import numpy as np 5 | import h5py 6 | from PIL import Image 7 | import os 8 | from torchvision.transforms import InterpolationMode 9 | class NOCSHDF5Dataset(BaseDataset): 10 | 11 | 12 | def __init__(self, opt): 13 | """Initialize this dataset class. 14 | 15 | Parameters: 16 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 17 | """ 18 | BaseDataset.__init__(self, opt) 19 | 20 | input_nc = self.opt.output_nc 21 | self.transform = get_transform(opt, grayscale=(input_nc == 1),method=InterpolationMode.BILINEAR) 22 | 23 | hdf5_file = h5py.File(os.path.join(opt.dataroot,opt.category+'.hdf5'),'r',swmr=True) 24 | self.images = hdf5_file['images'] 25 | self.poses = hdf5_file['poses'][...] 
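# note: the HDF5 file is opened with swmr=True; `images` is kept as a lazy dataset handle and copied per item in __getitem__, while `poses` is read fully into memory up front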
26 | 27 | self.dataset_size = self.poses.shape[0] 28 | self.num_view = opt.n_views 29 | self.num_model = self.dataset_size // self.num_view 30 | 31 | def __getitem__(self, index): 32 | """Return a data point and its metadata information. 33 | 34 | Parameters: 35 | index - - a random integer for data indexing 36 | 37 | Returns a dictionary that contains A and A_paths 38 | A(tensor) - - an image in one domain 39 | A_paths(str) - - the path of the image 40 | """ 41 | model_id = random.randint(0,self.num_model-1) 42 | 43 | image_id = random.randint(0,self.num_view-1) 44 | id = model_id*self.num_view + image_id 45 | A_img = np.copy(self.images[id,:,:,:]) 46 | elev,azi = np.copy(self.poses[id,:]) 47 | if A_img.shape[2] == 4: 48 | A_mask = A_img[:,:,-1] == 0 49 | A_img[A_mask,:3] = 255 50 | A_img = A_img[:,:,:3] 51 | 52 | A = self.transform(Image.fromarray(A_img.astype(np.uint8))) 53 | A_pose = np.array([elev,azi,0]).astype(np.float32) 54 | 55 | image_id = random.randint(0,self.num_view-1) 56 | id = model_id*self.num_view + image_id 57 | B_img = np.copy(self.images[id,:,:,:]) 58 | elev,azi = np.copy(self.poses[id,:]) 59 | if B_img.shape[2] == 4: 60 | B_mask = B_img[:,:,-1] == 0 61 | B_img[B_mask,:3] = 255 62 | B_img = B_img[:,:,:3] 63 | 64 | B = self.transform(Image.fromarray(B_img.astype(np.uint8))) 65 | B_pose = np.array([elev,azi,0]).astype(np.float32) 66 | 67 | return {'A': A, 'A_pose': A_pose, 68 | 'B': B, 'B_pose': B_pose,} 69 | def __len__(self): 70 | """Return the total number of images in the dataset.""" 71 | return self.dataset_size -------------------------------------------------------------------------------- /eccv_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/eccv_poster.pdf -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: visual_nav 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - _pytorch_select=0.1=cpu_0 9 | - blas=1.0=mkl 10 | - ca-certificates=2022.4.26=h06a4308_0 11 | - certifi=2022.5.18.1=py37h06a4308_0 12 | - cffi=1.15.0=py37hd667e15_1 13 | - cudatoolkit=10.1.243=h6bb024c_0 14 | - freetype=2.11.0=h70c0345_0 15 | - giflib=5.2.1=h7b6447c_0 16 | - intel-openmp=2021.4.0=h06a4308_3561 17 | - jpeg=9e=h7f8727e_0 18 | - lcms2=2.12=h3be6417_0 19 | - ld_impl_linux-64=2.38=h1181459_1 20 | - libffi=3.3=he6710b0_2 21 | - libgcc-ng=11.2.0=h1234567_0 22 | - libgomp=11.2.0=h1234567_0 23 | - libpng=1.6.37=hbc83047_0 24 | - libstdcxx-ng=11.2.0=h1234567_0 25 | - libtiff=4.2.0=h2818925_1 26 | - libuv=1.40.0=h7b6447c_0 27 | - libwebp=1.2.2=h55f646e_0 28 | - libwebp-base=1.2.2=h7f8727e_0 29 | - lz4-c=1.9.3=h295c915_1 30 | - mkl=2021.4.0=h06a4308_640 31 | - mkl-service=2.4.0=py37h7f8727e_0 32 | - mkl_fft=1.3.1=py37hd3c417c_0 33 | - mkl_random=1.2.2=py37h51133e4_0 34 | - ncurses=6.3=h7f8727e_2 35 | - ninja=1.10.2=h06a4308_5 36 | - ninja-base=1.10.2=hd09550d_5 37 | - numpy=1.21.5=py37he7a7128_2 38 | - numpy-base=1.21.5=py37hf524024_2 39 | - openssl=1.1.1o=h7f8727e_0 40 | - pillow=9.0.1=py37h22f2fdc_0 41 | - pip=21.2.2=py37h06a4308_0 42 | - pycparser=2.21=pyhd3eb1b0_0 43 | - python=3.7.13=h12debd9_0 44 | - pytorch=1.7.1=py3.7_cuda10.1.243_cudnn7.6.3_0 45 | - readline=8.1.2=h7f8727e_1 46 | - setuptools=61.2.0=py37h06a4308_0 47 | 
- six=1.16.0=pyhd3eb1b0_1 48 | - sqlite=3.38.3=hc218d9a_0 49 | - tk=8.6.11=h1ccaba5_1 50 | - typing_extensions=4.1.1=pyh06a4308_0 51 | - wheel=0.37.1=pyhd3eb1b0_0 52 | - xz=5.2.5=h7f8727e_1 53 | - zlib=1.2.12=h7f8727e_2 54 | - zstd=1.5.2=ha4553b6_0 55 | - pip: 56 | - absl-py==1.0.0 57 | - cachetools==5.1.0 58 | - charset-normalizer==2.0.12 59 | - click==8.1.3 60 | - cloudpickle==2.1.0 61 | - cycler==0.11.0 62 | - dataclasses==0.6 63 | - decorator==4.4.2 64 | - docker-pycreds==0.4.0 65 | - envyaml==1.10.211231 66 | - fonttools==4.33.3 67 | - future==0.18.2 68 | - fvcore==0.1.5.post20220512 69 | - gitdb==4.0.9 70 | - gitpython==3.1.27 71 | - google-auth==2.6.6 72 | - google-auth-oauthlib==0.4.6 73 | - grpcio==1.46.3 74 | - gym==0.24.0 75 | - gym-notices==0.0.6 76 | - h5py==3.7.0 77 | - idna==3.3 78 | - imageio==2.19.2 79 | - imageio-ffmpeg==0.4.7 80 | - importlib-metadata==4.11.4 81 | - install==1.3.5 82 | - iopath==0.1.9 83 | - kiwisolver==1.4.2 84 | - markdown==3.3.7 85 | - matplotlib==3.5.2 86 | - moviepy==1.0.3 87 | - oauthlib==3.2.0 88 | - opencv-python==4.5.5.64 89 | - packaging==21.3 90 | - pathtools==0.1.2 91 | - portalocker==2.4.0 92 | - proglog==0.1.10 93 | - promise==2.3 94 | - protobuf==3.20.1 95 | - psutil==5.9.1 96 | - pyasn1==0.4.8 97 | - pyasn1-modules==0.2.8 98 | - pyparsing==3.0.9 99 | - python-dateutil==2.8.2 100 | - python-graphviz==0.20 101 | - pytorch3d==0.3.0 102 | - pyyaml==6.0 103 | - requests==2.27.1 104 | - requests-oauthlib==1.3.1 105 | - rsa==4.8 106 | - scipy==1.5.2 107 | - sentry-sdk==1.5.12 108 | - setproctitle==1.2.3 109 | - shortuuid==1.0.9 110 | - smmap==5.0.0 111 | - tabulate==0.8.9 112 | - tensorboard==2.9.0 113 | - tensorboard-data-server==0.6.1 114 | - tensorboard-plugin-wit==1.8.1 115 | - termcolor==1.1.0 116 | - torch==1.8.1+cu111 117 | - torchaudio==0.8.1 118 | - torchvision==0.9.1+cu111 119 | - torchviz==0.0.2 120 | - tqdm==4.64.0 121 | - typing-extensions==4.2.0 122 | - urllib3==1.26.9 123 | - wandb==0.12.17 124 | - werkzeug==2.1.2 125 | - yacs==0.1.8 126 | - zipp==3.8.0 -------------------------------------------------------------------------------- /images/teaser_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/images/teaser_video.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import os 5 | import torch 6 | import numpy as np 7 | import tqdm 8 | import wandb 9 | from config import load_options 10 | from runner import eval_nocs, train_one_epoch, eval_one_epoch 11 | from utils.loss import SetCriterion 12 | from models.nocs_gym import nocs_gym 13 | from sac_2.SAC import SAC 14 | def main(args, parser): 15 | # Initialization save & log 16 | runner_name = args.name 17 | save_path = os.path.join(args.save_folder, runner_name) 18 | os.makedirs(save_path, exist_ok=True) 19 | 20 | device = torch.device('cuda') 21 | if args.log == True: 22 | run = wandb.init(project=args.name, group=args.name) 23 | run.config.data = save_path 24 | run.name = run.id 25 | 26 | # Initialize criterions 27 | criterion = SetCriterion(args) 28 | 29 | # Initialize datasets 30 | print("========load nocs image generator==========") 31 | env = nocs_gym(args, parser, criterion) 32 | # Initialize agent 33 | agent = SAC(args.image_size, 
env.action_space, args, save_path, device) 34 | if args.pretrain is not None: 35 | agent.policy.load_state_dict(torch.load(args.pretrain)) 36 | 37 | min_test = np.infty 38 | args.save_path = save_path 39 | # Evaluation on synthetic dataset 40 | if args.eval==1: 41 | eval_one_epoch(args, agent, env, episodes=args.batch_size) 42 | return 43 | # Evaluation on real dataset 44 | elif args.eval ==2: 45 | eval_nocs(args, agent, env) 46 | return 47 | 48 | for i_episode in tqdm.tqdm(range(1, args.episode_nums)): 49 | 50 | train_one_epoch(args, agent, env, i_episode) 51 | 52 | if i_episode % args.save_interval == 0: 53 | avg_loss = eval_one_epoch(args, agent, env, i_episode) 54 | path = os.path.join(save_path, "checkpoint_" + str(i_episode) + ".pt") 55 | torch.save(agent.policy.state_dict(), path) 56 | path = os.path.join(save_path, "latest.pt") 57 | torch.save(agent.policy.state_dict(), path) 58 | 59 | if avg_loss < min_test: 60 | min_test = avg_loss 61 | path = os.path.join(save_path, "best_model.pt") 62 | torch.save(agent.policy.state_dict(), path) 63 | print("save best model on ", i_episode) 64 | 65 | if __name__ == '__main__': 66 | args, parser = load_options() 67 | torch.manual_seed(args.seed) 68 | np.random.seed(args.seed) 69 | main(args, parser) 70 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | 23 | from models.base_model import BaseModel 24 | 25 | 26 | def find_model_using_name(model_name): 27 | """Import the module "models/[model_name]_model.py". 28 | 29 | In the file, the class called DatasetNameModel() will 30 | be instantiated. It has to be a subclass of BaseModel, 31 | and it is case-insensitive. 32 | """ 33 | model_filename = "models." 
+ model_name + "_model" 34 | modellib = importlib.import_module(model_filename) 35 | model = None 36 | target_model_name = model_name.replace('_', '') + 'model' 37 | for name, cls in modellib.__dict__.items(): 38 | if name.lower() == target_model_name.lower() \ 39 | and issubclass(cls, BaseModel): 40 | model = cls 41 | 42 | if model is None: 43 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 44 | exit(0) 45 | 46 | return model 47 | 48 | 49 | def get_option_setter(model_name): 50 | """Return the static method of the model class.""" 51 | model_class = find_model_using_name(model_name) 52 | return model_class.modify_commandline_options 53 | 54 | 55 | def create_model(opt): 56 | """Create a model given the option. 57 | 58 | This function warps the class CustomDatasetDataLoader. 59 | This is the main interface between this package and 'train.py'/'test.py' 60 | 61 | Example: 62 | >>> from models import create_model 63 | >>> model = create_model(opt) 64 | """ 65 | model = find_model_using_name(opt.model) 66 | instance = model(opt) 67 | print("model [%s] was created" % type(instance).__name__) 68 | return instance 69 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from models.networks import networks 8 | 9 | 10 | class BaseModel(ABC): 11 | """This class is an abstract base class (ABC) for models. 12 | To create a subclass, you need to implement the following five functions: 13 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 14 | -- : unpack data from dataset and apply preprocessing. 15 | -- : produce intermediate results. 16 | -- : calculate losses, gradients, and update network weights. 17 | -- : (optionally) add model-specific options and set default options. 18 | """ 19 | 20 | def __init__(self, opt): 21 | """Initialize the BaseModel class. 22 | 23 | Parameters: 24 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 25 | 26 | When creating your custom class, you need to implement your own initialization. 27 | In this fucntion, you should first call 28 | Then, you need to define four lists: 29 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 30 | -- self.model_names (str list): specify the images that you want to display and save. 31 | -- self.visual_names (str list): define networks used in our training. 32 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 33 | """ 34 | self.opt = opt 35 | self.gpu_ids = opt.gpu_ids 36 | self.isTrain = opt.isTrain 37 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 38 | # if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 
39 | # torch.backends.cudnn.benchmark = True 40 | self.loss_names = [] 41 | self.model_names = [] 42 | self.optimizer_names = [] 43 | self.visual_names = [] 44 | self.optimizers = [] 45 | self.image_paths = [] 46 | self.metric = 0 # used for learning rate policy 'plateau' 47 | 48 | self.save_dir = os.path.join(opt.checkpoints_dir,opt.project_name, opt.exp_name, opt.run_name) # save all the checkpoints to save_dir 49 | self.net_dict = { name: getattr(self, 'net' + name) for name in self.model_names} 50 | self.optimizer_dict = {name: getattr(self, 'optimizer_' + name) for name in self.model_names} if opt.isTrain else {} 51 | 52 | @staticmethod 53 | def modify_commandline_options(parser, is_train): 54 | """Add new model-specific options, and rewrite default values for existing options. 55 | 56 | Parameters: 57 | parser -- original option parser 58 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 59 | 60 | Returns: 61 | the modified parser. 62 | """ 63 | return parser 64 | 65 | @abstractmethod 66 | def set_input(self, input): 67 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 68 | 69 | Parameters: 70 | input (dict): includes the data itself and its metadata information. 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def forward(self,vis=False): 76 | """Run forward pass; called by both functions and .""" 77 | pass 78 | 79 | @abstractmethod 80 | def optimize_parameters(self): 81 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 82 | pass 83 | 84 | def setup(self, opt): 85 | """Load and print networks; create schedulers 86 | 87 | Parameters: 88 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 89 | """ 90 | if self.isTrain: 91 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] #TODO: why need this? 92 | 93 | # get the latest checkpoints 94 | import glob 95 | 96 | checkpoint_folder = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name) 97 | 98 | search_pattern = '*.pth' 99 | 100 | checkpoint_list = glob.glob(os.path.join(checkpoint_folder, search_pattern)) 101 | print("!!!!!!!!!!!!!!!", checkpoint_folder) 102 | iter_start = 0 103 | if len(checkpoint_list) > 0: 104 | load_suffix = self.opt.load_suffix 105 | self.load_networks(load_suffix) 106 | if self.isTrain: 107 | self.load_optimizers(load_suffix) 108 | iter_start = self.load_states(load_suffix) 109 | 110 | self.print_networks(opt.verbose) 111 | 112 | return iter_start 113 | 114 | def eval(self): 115 | """Make models eval mode during test time""" 116 | for name in self.model_names: 117 | if isinstance(name, str): 118 | net = getattr(self, 'net' + name) 119 | net.eval() 120 | 121 | def train(self): 122 | """Make models eval mode during test time""" 123 | for name in self.model_names: 124 | if isinstance(name, str): 125 | net = getattr(self, 'net' + name) 126 | net.train() 127 | 128 | def test(self): 129 | """Forward function used in test time. 
130 | 131 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 132 | It also calls to produce additional visualization results 133 | """ 134 | with torch.no_grad(): 135 | self.forward() 136 | self.compute_visuals() 137 | 138 | def compute_visuals(self): 139 | """Calculate additional output images for visdom and HTML visualization""" 140 | pass 141 | 142 | def get_image_paths(self): 143 | """ Return image paths that are used to load current data""" 144 | return self.image_paths 145 | 146 | def update_learning_rate(self): 147 | """Update learning rates for all the networks; called at the end of every epoch""" 148 | for scheduler in self.schedulers: 149 | if self.opt.lr_policy == 'plateau': 150 | scheduler.step(self.metric) 151 | else: 152 | scheduler.step() 153 | 154 | # lr = self.optimizers[0].param_groups[0]['lr'] 155 | # print('learning rate = %.7f' % lr) 156 | 157 | def get_current_visuals(self): 158 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 159 | visual_ret = OrderedDict() 160 | for name in self.visual_names: 161 | if isinstance(name, str): 162 | visual_ret[name] = getattr(self, name) 163 | return visual_ret 164 | 165 | def get_current_videos(self): 166 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 167 | visual_ret = OrderedDict() 168 | for name in self.video_names: 169 | if isinstance(name, str): 170 | visual_ret[name] = getattr(self, name) 171 | return visual_ret 172 | 173 | def get_current_losses(self): 174 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 175 | errors_ret = OrderedDict() 176 | for name in self.loss_names: 177 | if isinstance(name, str): 178 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 179 | return errors_ret 180 | 181 | def save(self, suffix,iter): 182 | """ Save all the networks, optimizers and states to the disk. 183 | 184 | Parameters: 185 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 186 | """ 187 | self.save_optimizers(suffix) 188 | self.save_networks(suffix) 189 | self.save_states(suffix,iter) 190 | 191 | def save_networks(self, suffix): 192 | """Save all the networks to the disk. 193 | 194 | Parameters: 195 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 196 | """ 197 | save_name = 'model_%s.pth' % suffix 198 | save_path = os.path.join(self.save_dir, save_name) 199 | outdict = {} 200 | for name in self.model_names: 201 | if isinstance(name, str): 202 | net_name = 'net' + name 203 | outdict[net_name] = getattr(self, net_name).state_dict() 204 | 205 | torch.save(outdict, save_path) 206 | 207 | def save_optimizers(self,suffix): 208 | """Save all the optimizers to the disk. 209 | 210 | Parameters: 211 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 212 | """ 213 | save_name = 'optimizer_%s.pth' % suffix 214 | save_path = os.path.join(self.save_dir, save_name) 215 | output_dict = {} 216 | for name in self.optimizer_names: 217 | if isinstance(name, str): 218 | optimizer_name = 'optimizer_' + name 219 | output_dict[optimizer_name] = getattr(self, optimizer_name).state_dict() 220 | torch.save(output_dict, save_path) 221 | 222 | def save_states(self, suffix, iter): 223 | """Save all the states (epoch, iter) to the disk. 
224 | 225 | Parameters: 226 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 227 | """ 228 | save_name = 'states_%s.txt' % suffix 229 | save_path = os.path.join(self.save_dir, save_name) 230 | import numpy as np 231 | states = np.array([iter]) 232 | np.savetxt(save_path, states) 233 | 234 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 235 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 236 | key = keys[i] 237 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 238 | if module.__class__.__name__.startswith('InstanceNorm') and \ 239 | (key == 'running_mean' or key == 'running_var'): 240 | if getattr(module, key) is None: 241 | state_dict.pop('.'.join(keys)) 242 | if module.__class__.__name__.startswith('InstanceNorm') and \ 243 | (key == 'num_batches_tracked'): 244 | state_dict.pop('.'.join(keys)) 245 | else: 246 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 247 | 248 | def load_networks(self, suffix): 249 | """Load all the networks from the disk. 250 | 251 | Parameters: 252 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 253 | """ 254 | load_path = os.path.join(self.save_dir, 'model_%s.pth' % (suffix)) 255 | # if you are using PyTorch newer than 0.4 (e.g., built from 256 | # GitHub source), you can remove str() on self.device 257 | try: 258 | out_dict = torch.load(load_path,map_location=str(self.device)) 259 | for name in self.model_names: 260 | if isinstance(name, str): 261 | net_name = 'net' + name 262 | net = getattr(self, net_name) 263 | 264 | 265 | state_dict = out_dict[net_name] 266 | if hasattr(state_dict, '_metadata'): 267 | del state_dict._metadata 268 | # patch InstanceNorm checkpoints prior to 0.4 269 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 270 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 271 | 272 | net.load_state_dict(state_dict, strict=True) 273 | print('[%s] loaded from [%s]' % (net_name,load_path)) 274 | 275 | 276 | except Exception: 277 | print('no checkpoints for the network found, parameters will be initialized') 278 | 279 | 280 | 281 | def load_optimizers(self, suffix): 282 | """Load all the optimizers from the disk. 283 | 284 | Parameters: 285 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 286 | """ 287 | load_path = os.path.join(self.save_dir, 'optimizer_%s.pth' % (suffix)) 288 | # if you are using PyTorch newer than 0.4 (e.g., built from 289 | # GitHub source), you can remove str() on self.device 290 | try: 291 | out_dict = torch.load(load_path,map_location=str(self.device)) 292 | for name in self.optimizer_names: 293 | if isinstance(name, str): 294 | optimizer_name = 'optimizer_' + name 295 | optimizer = getattr(self, optimizer_name) 296 | optimizer.load_state_dict(out_dict[optimizer_name]) 297 | print('optimizer loaded from [%s]' % load_path) 298 | except Exception: 299 | print('no checkpoints for the optimizer found, parameters will be initialized') 300 | 301 | 302 | def load_states(self, suffix): 303 | """Load all the states (epoch, iterations) from the disk. 
304 | 305 | Parameters: 306 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 307 | """ 308 | load_path = os.path.join(self.save_dir, 'states_%s.txt' % (suffix)) 309 | import numpy as np 310 | try: 311 | out_dict = np.loadtxt(load_path) 312 | print('states loaded from [%s]' % load_path) 313 | return out_dict 314 | except Exception: 315 | print('no states found, start from epoch 1, iter 0') 316 | return 0 317 | 318 | def print_networks(self, verbose): 319 | """Print the total number of parameters in the network and (if verbose) network architecture 320 | 321 | Parameters: 322 | verbose (bool) -- if verbose: print the network architecture 323 | """ 324 | for name in self.model_names: 325 | if isinstance(name, str): 326 | net = getattr(self, 'net' + name) 327 | num_params = 0 328 | for param in net.parameters(): 329 | num_params += param.numel() 330 | if verbose: 331 | print(net) 332 | print('[Network %s] has [%.3f M] parameters' % (name, num_params / 1e6)) 333 | 334 | def set_requires_grad(self, nets, requires_grad=False): 335 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 336 | Parameters: 337 | nets (network list) -- a list of networks 338 | requires_grad (bool) -- whether the networks require gradients or not 339 | """ 340 | if not isinstance(nets, list): 341 | nets = [nets] 342 | for net in nets: 343 | if net is not None: 344 | for param in net.parameters(): 345 | param.requires_grad_(requires_grad) 346 | -------------------------------------------------------------------------------- /models/latent_object_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | from numpy.lib.type_check import real 4 | import torch 5 | from scipy.spatial.transform import Rotation as scipy_rot 6 | 7 | from models.networks import networks 8 | 9 | from .base_model import BaseModel 10 | 11 | class LatentObjectModel(BaseModel): 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train=True): 15 | 16 | models_args = parser.add_argument_group('models') 17 | 18 | models_args.add_argument('--z_dim', type=int, default=16, help='dimension of z') 19 | models_args.add_argument('--batch_size_vis', type=int, default=8, help='number of visualization samples') 20 | models_args.add_argument('--use_VAE', action='store_true', default=True, help='use KL divergence') 21 | models_args.add_argument('--category', type=str, default='laptop', help='object category') 22 | if is_train: 23 | models_args.add_argument('--lambda_recon', type=float, default=10., help='weight for reconstruction loss') 24 | models_args.add_argument('--lambda_KL', type=float, default=0.01, help='weight for the KL divergence') 25 | else: 26 | fitting_args = parser.add_argument_group('fitting') 27 | fitting_args.set_defaults(dataset_mode='nocs_hdf5', batch_size=1, no_flip=True, preprocess=' ') 28 | fitting_args.add_argument('--n_iter', type=int, default=50, help='number of optimization iterations') 29 | fitting_args.add_argument('--n_init', type=int, default=32, help='number of initializations') 30 | fitting_args.add_argument('--lambda_reg', type=float, default=1, help='weight for the KL divergence') 31 | 32 | return parser 33 | 34 | def __init__(self, opt): 35 | 36 | BaseModel.__init__(self, opt) 37 | self.use_VAE = opt.use_VAE 38 | 39 | self.loss_names = ['G_recon'] 40 | if self.opt.use_VAE > 0: self.loss_names += ['KL'] 41 | 42 | self.visual_names = ['real_A','real_B','fake_B'] 43 | 44 
| self.video_names = ['anim_azim','anim_elev'] 45 | 46 | self.model_names = ['G','E'] 47 | 48 | self.optimizer_names = ['G'] 49 | 50 | # define networks (both generator and discriminator) 51 | self.netG = networks.Generator(opt.z_dim).to(self.device) 52 | networks.init_net(self.netG, init_type=self.opt.init_type, init_gain=self.opt.init_gain,gpu_ids=self.gpu_ids) 53 | 54 | output_dim = opt.z_dim *2 if self.use_VAE else opt.z_dim 55 | self.netE = networks.Encoder(3, opt.crop_size, output_dim).to(self.device) 56 | self.netE = networks.add_SN(self.netE) 57 | networks.init_net(self.netE, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids) 58 | 59 | if self.isTrain: 60 | self.criterion_recon = torch.nn.L1Loss().to(self.device) 61 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG.parameters(),self.netE.parameters()), lr=opt.lr, betas=(0.5,0.999)) 62 | self.optimizers.append(self.optimizer_G) 63 | 64 | # define the prior distribution` 65 | mu = torch.zeros(opt.z_dim, device=self.device) 66 | scale = torch.ones(opt.z_dim, device=self.device) 67 | self.z_dist = torch.distributions.Normal(mu, scale) 68 | 69 | self.batch_size_vis = opt.batch_size_vis 70 | 71 | def set_input(self, input): 72 | self.real_A = input['A'].to(self.device) 73 | self.real_B = input['B'].to(self.device) 74 | self.theta = input['B_pose'].to(self.device) 75 | 76 | def forward(self): 77 | if self.use_VAE == 0: 78 | self.z = self.netE(self.real_A) 79 | else: 80 | b = self.real_A.shape[0] 81 | output = self.netE(self.real_A) 82 | self.mu, self.logvar = output[:,:self.opt.z_dim],output[:,self.opt.z_dim:] 83 | std = self.logvar.mul(0.5).exp_() 84 | self.z_sample = self.z_dist.sample((b,)) 85 | eps = self.z_sample 86 | self.z = eps.mul(std).add_(self.mu) 87 | 88 | self.fake_B = self.netG(self.z,self.theta) 89 | return self.fake_B 90 | 91 | def backward(self): 92 | self.loss_KL = (1 + self.logvar - self.mu.pow(2) - self.logvar.exp()).mean() * (-0.5 * self.opt.lambda_KL) 93 | self.loss_G_recon = self.criterion_recon(self.fake_B, self.real_B) 94 | self.loss_G = self.loss_G_recon * self.opt.lambda_recon 95 | self.loss_G.backward() 96 | 97 | def optimize_parameters(self): 98 | self.train() 99 | self.forward() 100 | self.optimizer_G.zero_grad() 101 | self.backward() 102 | self.optimizer_G.step() 103 | 104 | def compute_visuals(self): 105 | self.netG.eval() 106 | self.real_A = self.real_A[:self.batch_size_vis,...] 107 | self.real_B = self.real_B[:self.batch_size_vis,...] 108 | self.fake_B = self.fake_B[:self.batch_size_vis,...] 
109 | 110 | self.z_vis = self.netE(self.real_A)[:,:self.opt.z_dim] 111 | 112 | with torch.no_grad(): 113 | self.anim_azim = [] 114 | elev = 0 115 | for azim in range(-180,180,3): 116 | theta = torch.zeros((self.batch_size_vis,3)).to(self.device) 117 | theta[:,0],theta[:,1] = elev,azim 118 | frame = self.netG(self.z_vis,theta).detach().data 119 | self.anim_azim.append(frame) 120 | self.anim_elev= [] 121 | azim = 0 122 | for elev in range(-90,90,3): 123 | theta = torch.zeros((self.batch_size_vis, 3)).to(self.device) 124 | theta[:, 0], theta[:, 1] = elev, azim 125 | frame = self.netG(self.z_vis, theta).detach().data 126 | self.anim_elev.append(frame) 127 | 128 | def fitting(self, real_B): 129 | import tqdm 130 | import torch.optim as optim 131 | import torch.nn.functional as F 132 | 133 | from models.networks.utils import grid_sample, warping_grid, init_variable 134 | from models.networks.losses import PerceptualLoss 135 | 136 | real_B = real_B.to(self.device).repeat((self.opt.n_init, 1, 1, 1)) 137 | real_B = real_B[:, [2, 1, 0], :, :] 138 | 139 | ay = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='linspace', range=[-1/2,1/2]) 140 | ax = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=1/4) 141 | az = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 142 | s = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=1) 143 | tx = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 144 | ty = init_variable(dim=1, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 145 | z = init_variable(dim=self.opt.z_dim, n_init=self.opt.n_init, device=self.device, mode='constant', value=0) 146 | 147 | latent = self.netE(F.interpolate(real_B, size=self.opt.crop_size, mode='nearest')) 148 | if self.opt.use_VAE: 149 | mu, logvar = latent[:, :self.opt.z_dim], latent[:, self.opt.z_dim:] 150 | std = logvar.mul(0.5).exp_() 151 | eps = self.z_dist.sample((self.opt.n_init,)) 152 | z.data = eps.mul(std).add_(mu) 153 | else: 154 | z.data = latent 155 | 156 | variable_dict = [ 157 | {'params': z, 'lr': 3e-1}, 158 | {'params': ax, 'lr': 1e-2}, 159 | {'params': ay, 'lr': 3e-2}, 160 | {'params': az, 'lr': 1e-2}, 161 | {'params': tx, 'lr': 3e-2}, 162 | {'params': ty, 'lr': 3e-2}, 163 | {'params': s, 'lr': 3e-2}, 164 | ] 165 | optimizer = optim.Adam(variable_dict,betas=(0.5,0.999)) 166 | 167 | losses = [('VGG', 1, PerceptualLoss(reduce=False))] 168 | reg_creterion = torch.nn.MSELoss(reduce=False) 169 | 170 | loss_history = np.zeros( (self.opt.n_init,self.opt.n_iter,len(losses)+1)) 171 | state_history = np.zeros( (self.opt.n_init,self.opt.n_iter,6 + self.opt.z_dim)) 172 | image_history = torch.tensor(()) 173 | 174 | for iter in tqdm.tqdm(range(self.opt.n_iter)): 175 | 176 | optimizer.zero_grad() 177 | 178 | angle = 180 * torch.cat([ax, ay, torch.zeros_like(ay)], dim=1) 179 | fake_B = self.netG(z,angle) 180 | 181 | grid = warping_grid(az * np.pi, tx, ty, s, fake_B.shape) 182 | fake_B = grid_sample(fake_B, grid) 183 | 184 | fake_B_upsampled = F.interpolate(fake_B, size=real_B.shape[-1], mode='bilinear') 185 | 186 | error_all = 0 187 | for l, (name, weight, criterion)in enumerate(losses): 188 | error = weight * criterion(fake_B_upsampled, real_B).view(self.opt.n_init,-1).mean(1) 189 | loss_history[:,iter,l] = error.data.cpu().numpy() 190 | error_all = error_all + error 191 | 192 | error = self.opt.lambda_reg * 
reg_creterion(z,torch.zeros_like(z)).view(self.opt.n_init,-1).mean(1) 193 | loss_history[:, iter, l+1] = error.data.cpu().numpy() 194 | error_all = error_all + error 195 | 196 | error_all.mean().backward() 197 | 198 | optimizer.step() 199 | 200 | image_history = torch.cat([image_history, fake_B.cpu()], dim=0) 201 | 202 | state_history[:, iter, :3] = 180*torch.cat([-ay-0.5, ax+1, -az],dim=-1).data.cpu().numpy() 203 | state_history[:, iter, 3:] = torch.cat([tx, ty, s, z],dim=-1).data.cpu().numpy() 204 | # print(criterion(fake_B_upsampled[0:1, :, :, :], real_B[0:1, :, :, :]).view( 205 | # 1, -1).mean(1)) 206 | # print(fake_B_upsampled.shape , real_B.shape) 207 | image_history = torch.cat([image_history, real_B.cpu()], dim=0) 208 | return state_history, loss_history, image_history 209 | 210 | def visulize_fitting(self, real_B, RT_gt, state_history, loss_history, image_history): 211 | import matplotlib.pyplot as plt 212 | from util.util import tensor2im 213 | from models.networks.utils import set_axis 214 | import matplotlib 215 | matplotlib.use('TkAgg') 216 | 217 | RT_gt = RT_gt.numpy()[0] 218 | R_gt = RT_gt[:3, :3] 219 | real_B_img = tensor2im(real_B) 220 | 221 | n_init, n_iter, n_loss = loss_history.shape 222 | 223 | fig, axes = plt.subplots(nrows=loss_history.shape[2] + 2, ncols=n_init + 1, sharey='row') 224 | axes[0, -1].clear();axes[0, -1].axis('off') 225 | axes[0, -1].imshow(real_B_img) 226 | plt.ion() 227 | 228 | plots = axes.copy() 229 | for row in range(axes.shape[0]): 230 | for col in range(axes.shape[1]): 231 | if row == 0: 232 | axes[row, col].axis('off') 233 | plots[row, col] = axes[row, col].imshow(real_B_img) 234 | elif col < n_init: 235 | if row < n_loss+1: 236 | set_axis(axes[row, col]) 237 | plots[row, col] = axes[row, col].plot(np.arange(n_iter),loss_history[col, :,row-1]) 238 | else: 239 | plots[row, col] = axes[row, col].plot(np.arange(n_iter),60*np.ones(n_iter)) 240 | axes[row, col].set_ylim([0,60]) 241 | 242 | errors = np.zeros((n_init,n_iter)) 243 | 244 | for iter in range(n_iter): 245 | for init in range(n_init): 246 | pose = state_history[init,iter,:3] 247 | R_pd = scipy_rot.from_euler('yxz', pose, degrees=True).as_dcm()[:3, :3] 248 | 249 | R_pd = R_pd[:3, :3]/np.cbrt(np.linalg.det(R_pd[:3, :3])) 250 | R_gt = R_gt[:3, :3]/np.cbrt(np.linalg.det(R_gt[:3, :3])) 251 | 252 | R = R_pd @ R_gt.transpose() 253 | errors[init,iter] = np.arccos((np.trace(R) - 1)/2) * 180/np.pi 254 | 255 | ranking = [r[0] for r in sorted(enumerate(loss_history[:, iter,:].mean(-1)), key=lambda r: r[1])] 256 | 257 | for r, b in enumerate(ranking[::-1]): 258 | plots[0, r].set_data(tensor2im(image_history[iter][b].unsqueeze(0))) 259 | for l in range(loss_history.shape[2]): 260 | plots[l + 1, r][0].set_data(np.arange(iter),loss_history[b, :iter,l]) 261 | plots[-1, r][0].set_data(np.arange(iter), errors[b, :iter]) 262 | 263 | plt.draw() 264 | 265 | plt.pause(0.01) 266 | plt.close(fig) -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/models/networks/__init__.py -------------------------------------------------------------------------------- /models/networks/losses.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from torch.autograd 
import Variable 6 | import torchvision.models as models 7 | from math import exp 8 | 9 | class PerceptualLoss(nn.Module): 10 | 11 | def __init__(self,type='l2',reduce=True,final_layer=14): 12 | super(PerceptualLoss, self).__init__() 13 | self.model = self.contentFunc(final_layer=final_layer) 14 | self.model.eval() 15 | self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1).cuda() 16 | self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1).cuda() 17 | self.type = type 18 | if type == 'l1': 19 | self.criterion = torch.nn.L1Loss(reduce=reduce) 20 | elif type == 'l2': 21 | self.criterion = torch.nn.MSELoss(reduce=reduce) 22 | elif type == 'both': 23 | self.criterion1 = torch.nn.L1Loss(reduce=reduce) 24 | self.criterion2 = torch.nn.MSELoss(reduce=reduce) 25 | else: 26 | raise NotImplementedError 27 | 28 | def normalize(self, tensor): 29 | tensor = (tensor+1)*0.5 30 | tensor_norm = (tensor-self.mean.expand(tensor.shape))/self.std.expand(tensor.shape) 31 | return tensor_norm 32 | 33 | def contentFunc(self,final_layer=14): 34 | cnn = models.vgg19(pretrained=True).features 35 | cnn = cnn.cuda() 36 | model = nn.Sequential() 37 | model = model.cuda() 38 | for i, layer in enumerate(list(cnn)): 39 | model.add_module(str(i), layer) 40 | if i == final_layer: 41 | break 42 | return model 43 | 44 | def forward(self, fakeIm, realIm): 45 | f_fake = self.model.forward(self.normalize(fakeIm)) 46 | f_real = self.model.forward(self.normalize(realIm)) 47 | if self.type == 'both': 48 | loss = self.criterion1(f_fake, f_real.detach())+self.criterion2(f_fake, f_real.detach()) 49 | else: 50 | loss = self.criterion(f_fake, f_real.detach()) 51 | return loss 52 | 53 | def gaussian(window_size, sigma): 54 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 55 | return gauss / gauss.sum() 56 | 57 | 58 | def create_window(window_size, channel): 59 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 60 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 61 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 62 | return window 63 | 64 | 65 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 66 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 67 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 68 | 69 | mu1_sq = mu1.pow(2) 70 | mu2_sq = mu2.pow(2) 71 | mu1_mu2 = mu1 * mu2 72 | 73 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 74 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 75 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 76 | 77 | C1 = 0.01 ** 2 78 | C2 = 0.03 ** 2 79 | 80 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 81 | 82 | if size_average: 83 | return ssim_map.mean() 84 | else: 85 | return ssim_map.mean(1).mean(1).mean(1) 86 | 87 | 88 | class SSIM(torch.nn.Module): 89 | def __init__(self, window_size=11, reduce=True,negative=False): 90 | super(SSIM, self).__init__() 91 | self.window_size = window_size 92 | self.reduce = reduce 93 | self.channel = 1 94 | self.window = create_window(window_size, self.channel) 95 | self.negative = negative 96 | 97 | def forward(self, img1, img2): 98 | (_, channel, _, _) = img1.size() 99 | 100 | if channel == self.channel and self.window.data.type() == img1.data.type(): 101 | 
window = self.window 102 | else: 103 | window = create_window(self.window_size, channel) 104 | 105 | if img1.is_cuda: 106 | window = window.cuda(img1.get_device()) 107 | window = window.type_as(img1) 108 | 109 | self.window = window 110 | self.channel = channel 111 | if self.negative: 112 | return -_ssim(img1, img2, window, self.window_size, channel, self.reduce) 113 | else: 114 | return _ssim(img1, img2, window, self.window_size, channel, self.reduce) 115 | 116 | 117 | def ssim(img1, img2, window_size=11, reduce=True): 118 | (_, channel, _, _) = img1.size() 119 | window = create_window(window_size, channel) 120 | 121 | if img1.is_cuda: 122 | window = window.cuda(img1.get_device()) 123 | window = window.type_as(img1) 124 | 125 | return _ssim(img1, img2, window, window_size, channel, reduce) -------------------------------------------------------------------------------- /models/networks/networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import init 4 | import torch.utils.data 5 | import torch.utils.data.distributed 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.optim import lr_scheduler 9 | 10 | from models.networks import spectral_norm 11 | 12 | class Generator(nn.Module): 13 | def __init__(self, z_dim, euler_seq='zyx', **kwargs): 14 | super().__init__() 15 | self.shape_code = nn.Parameter(0.02*torch.randn(1,512,4,4,4),requires_grad=True) 16 | # Upsampling 3D 17 | self.enc_1 = nn.Sequential(*[nn.ConvTranspose3d(512,128,kernel_size=4,stride=2, padding=1)]) 18 | self.enc_2 = nn.Sequential(*[nn.ConvTranspose3d(128,64, kernel_size=4, stride=2, padding=1)]) 19 | # Projection 20 | self.proj = nn.Sequential(*[nn.ConvTranspose2d(64*16,64*16, kernel_size=1,stride=1)]) 21 | # Upsampling 2D 22 | self.enc_3 = nn.Sequential(*[nn.ConvTranspose2d(64*16,64*4,kernel_size=4,stride=2, padding=1)]) 23 | self.enc_4 = nn.Sequential(*[nn.ConvTranspose2d(64*4,64,kernel_size=4,stride=2, padding=1)]) 24 | self.enc_5 = nn.Sequential(*[nn.ConvTranspose2d(64,3,kernel_size=3,stride=1,padding=1)]) 25 | # MLP for AdaIN 26 | self.mlp0 = LinearBlock(z_dim,512*2,activation='relu') 27 | self.mlp1 = LinearBlock(z_dim,128*2,activation='relu') 28 | self.mlp2 = LinearBlock(z_dim,64*2,activation='relu') 29 | self.mlp3 = LinearBlock(z_dim,256*2,activation='relu') 30 | self.mlp4 = LinearBlock(z_dim,64*2,activation='relu') 31 | 32 | self.euler_seq = euler_seq 33 | 34 | def forward(self, z, angle, a=None, debug=False): 35 | b,_ = z.size() 36 | angle = angle / 180. 
* np.pi 37 | # Upsampling 3D 38 | h0 = self.shape_code.expand(b, 512, 4, 4, 4).clone() 39 | a0 = self.mlp0(z) 40 | h0 = actvn( adaIN(h0,a0) ) 41 | 42 | h1 = self.enc_1(h0) 43 | a1 = self.mlp1(z) 44 | h1 = actvn( adaIN(h1,a1) ) 45 | 46 | h2 = self.enc_2(h1) 47 | a2 = self.mlp2(z) 48 | h2 = actvn(adaIN(h2, a2)) 49 | 50 | # Rotation 51 | h2_rot = rot(h2,angle,euler_seq=self.euler_seq,padding="border") 52 | b,c,d,h,w = h2_rot.size() 53 | h2_2d = h2_rot.contiguous().view(b,c*d,h,w) 54 | h2_2d = actvn(self.proj(h2_2d)) 55 | # Upsampling 2D 56 | h3 = self.enc_3(h2_2d) 57 | a3 = self.mlp3(z) 58 | h3 = actvn(adaIN(h3, a3)) 59 | 60 | h4 = self.enc_4(h3) 61 | a4 = self.mlp4(z) 62 | h4 = actvn(adaIN(h4, a4)) 63 | 64 | h5 = self.enc_5(h4) 65 | return torch.tanh(h5) 66 | def actvn(x): 67 | out = F.leaky_relu(x, 2e-1) 68 | return out 69 | 70 | class Encoder(nn.Module): 71 | def __init__(self,in_dim=3, in_size=64, z_dim=128): 72 | super().__init__() 73 | self.model = nn.Sequential(*[ 74 | nn.Conv2d( in_dim, 64, 3, 2, 1), nn.LeakyReLU(0.2), 75 | nn.Conv2d( 64, 128, 3, 2, 1), nn.LeakyReLU(0.2), nn.InstanceNorm2d(128), 76 | nn.Conv2d(128, 256, 3, 2, 1), nn.LeakyReLU(0.2), nn.InstanceNorm2d(256), 77 | nn.Conv2d(256, 512, 3, 2, 1), nn.LeakyReLU(0.2), nn.InstanceNorm2d(512), 78 | ]) 79 | self.enc_out = nn.Sequential(*[ 80 | nn.Linear((in_size//16)**2*512,128), nn.LeakyReLU(0.2), 81 | nn.Linear(128,z_dim), nn.Tanh() 82 | ]) 83 | def forward(self, x): 84 | b,c,h,w = x.shape 85 | x = self.model.forward(x).view(b,(h//16)**2*512) 86 | enc = self.enc_out(x) 87 | return enc 88 | 89 | class LinearBlock(nn.Module): 90 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 91 | super(LinearBlock, self).__init__() 92 | use_bias = True 93 | # initialize fully connected layer 94 | 95 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 96 | 97 | # initialize normalization 98 | norm_dim = output_dim 99 | if norm == 'bn': 100 | self.norm = nn.BatchNorm1d(norm_dim) 101 | elif norm == 'in': 102 | self.norm = nn.InstanceNorm1d(norm_dim) 103 | elif norm == 'none' or norm == 'sn': 104 | self.norm = None 105 | else: 106 | assert 0, "Unsupported normalization: {}".format(norm) 107 | 108 | # initialize activation 109 | if activation == 'relu': 110 | self.activation = nn.ReLU(inplace=True) 111 | elif activation == 'lrelu': 112 | self.activation = nn.LeakyReLU(0.2, inplace=True) 113 | elif activation == 'prelu': 114 | self.activation = nn.PReLU() 115 | elif activation == 'selu': 116 | self.activation = nn.SELU(inplace=True) 117 | elif activation == 'tanh': 118 | self.activation = nn.Tanh() 119 | elif activation == 'none': 120 | self.activation = None 121 | else: 122 | assert 0, "Unsupported activation: {}".format(activation) 123 | 124 | def forward(self, x): 125 | out = self.fc(x) 126 | if self.norm: 127 | out = self.norm(out) 128 | if self.activation: 129 | out = self.activation(out) 130 | return out 131 | 132 | def rot(x,angle,euler_seq='xyz',padding='zeros'): 133 | b,c,d,h,w = x.shape 134 | grid = set_id_grid(x) 135 | grid_flat = grid.reshape(b, 3, -1) 136 | grid_rot_flat = euler2mat(angle,euler_seq=euler_seq).bmm(grid_flat) 137 | grid_rot = grid_rot_flat.reshape(b,3,d,h,w) 138 | x_rot = F.grid_sample(x,grid_rot.permute(0,2,3,4,1),padding_mode=padding,mode='bilinear', align_corners=True) 139 | return x_rot 140 | 141 | def euler2mat(angle, euler_seq='xyz' ): 142 | """Convert euler angles to rotation matrix. 
143 | 144 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 145 | 146 | Args: 147 | angle: rotation angle along 3 axis (in radians) -- size = [B, 3] 148 | Returns: 149 | Rotation matrix corresponding to the euler angles -- size = [B, 3, 3] 150 | """ 151 | B = angle.size(0) 152 | x, y, z = angle[:,0], angle[:,1], angle[:,2] 153 | 154 | zeros = z.detach()*0 155 | ones = zeros.detach()+1 156 | 157 | cosz = torch.cos(z) 158 | sinz = torch.sin(z) 159 | zmat = torch.stack([cosz, -sinz, zeros, 160 | sinz, cosz, zeros, 161 | zeros, zeros, ones], dim=1).reshape(B, 3, 3) 162 | 163 | cosy = torch.cos(y) 164 | siny = torch.sin(y) 165 | ymat = torch.stack([cosy, zeros, siny, 166 | zeros, ones, zeros, 167 | -siny, zeros, cosy], dim=1).reshape(B, 3, 3) 168 | 169 | cosx = torch.cos(x) 170 | sinx = torch.sin(x) 171 | xmat = torch.stack([ones, zeros, zeros, 172 | zeros, cosx, -sinx, 173 | zeros, sinx, cosx], dim=1).reshape(B, 3, 3) 174 | 175 | if euler_seq == 'xyz': 176 | rotMat = xmat.bmm(ymat).bmm(zmat) 177 | elif euler_seq == 'zyx': 178 | rotMat = zmat.bmm(ymat).bmm(xmat) 179 | return rotMat 180 | 181 | 182 | def set_id_grid(x): 183 | b, c, d, h, w = x.shape 184 | z_range = (torch.linspace(-1,1,steps=d)).view(1, d, 1, 1).expand(1, d, h, w).type_as(x) # [1, H, W, D] 185 | y_range = (torch.linspace(-1,1,steps=h)).view(1, 1, h, 1).expand(1, d, h, w).type_as(x) # [1, H, W, D] 186 | x_range = (torch.linspace(-1,1,steps=w)).view(1, 1, 1, w).expand(1, d, h, w).type_as(x) # [1, H, W, D] 187 | grid = torch.cat((x_range, y_range, z_range), dim=0)[None,...] # x,y,z 188 | grid = grid.expand(b,3,d,h,w) 189 | return grid 190 | 191 | def calc_mean_std(feat, eps=1e-5): 192 | # eps is a small value added to the variance to avoid divide-by-zero. 193 | size = feat.size() 194 | assert (len(size) == 4 or len(size) == 5) 195 | N, C = size[:2] 196 | 197 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 198 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 199 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 200 | 201 | if len(size)==5: 202 | feat_std = feat_std.unsqueeze(-1) 203 | feat_mean = feat_mean.unsqueeze(-1) 204 | 205 | return feat_mean, feat_std 206 | 207 | 208 | def adaIN(content_feat, style_mean_std): 209 | assert(content_feat.size(1) == style_mean_std.size(1)/2) 210 | size = content_feat.size() 211 | b,c = style_mean_std.size() 212 | style_mean, style_std = style_mean_std[:,:c//2],style_mean_std[:,c//2:] 213 | 214 | style_mean = style_mean.unsqueeze(-1).unsqueeze(-1) 215 | style_std = style_std.unsqueeze(-1).unsqueeze(-1) 216 | if len(size)==5: 217 | style_mean = style_mean.unsqueeze(-1) 218 | style_std = style_std.unsqueeze(-1) 219 | content_mean, content_std = calc_mean_std(content_feat) 220 | 221 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 222 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 223 | 224 | def add_SN(m): 225 | for name, c in m.named_children(): 226 | m.add_module(name, add_SN(c)) 227 | if isinstance(m, (nn.Conv2d, nn.Linear)): 228 | return spectral_norm.spectral_norm(m)#nn.utils.spectral_norm(m) 229 | else: 230 | return m 231 | 232 | def init_weights(net, init_type='normal', init_gain=0.02): 233 | """Initialize network weights. 234 | 235 | Parameters: 236 | net (network) -- network to be initialized 237 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 238 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 
239 | 240 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 241 | work better for some applications. Feel free to try yourself. 242 | """ 243 | def init_func(m): # define the initialization function 244 | classname = m.__class__.__name__ 245 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 246 | if init_type == 'normal': 247 | init.normal_(m.weight.data, 0.0, init_gain) 248 | elif init_type == 'xavier': 249 | init.xavier_normal_(m.weight.data, gain=init_gain) 250 | elif init_type == 'kaiming': 251 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 252 | elif init_type == 'orthogonal': 253 | init.orthogonal_(m.weight.data, gain=init_gain) 254 | else: 255 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 256 | if hasattr(m, 'bias') and m.bias is not None: 257 | init.constant_(m.bias.data, 0.0) 258 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 259 | init.normal_(m.weight.data, 1.0, init_gain) 260 | init.constant_(m.bias.data, 0.0) 261 | 262 | print('initialize network with %s' % init_type) 263 | net.apply(init_func) # apply the initialization function 264 | 265 | 266 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 267 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 268 | Parameters: 269 | net (network) -- the network to be initialized 270 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 271 | gain (float) -- scaling factor for normal, xavier and orthogonal. 272 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 273 | 274 | Return an initialized network. 275 | """ 276 | if len(gpu_ids) > 0: 277 | assert(torch.cuda.is_available()) 278 | net.to(gpu_ids[0]) 279 | # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 280 | if init_type is not None: 281 | init_weights(net, init_type, init_gain=init_gain) 282 | return net 283 | 284 | 285 | def get_scheduler(optimizer, opt): 286 | """Return a learning rate scheduler 287 | Parameters: 288 | optimizer -- the optimizer of the network 289 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  290 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 291 | For 'linear', we keep the same learning rate for the first epochs 292 | and linearly decay the rate to zero over the next epochs. 293 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 294 | See https://pytorch.org/docs/stable/optim.html for more details. 
295 | """ 296 | if opt.lr_policy == 'linear': 297 | def lambda_rule(epoch): 298 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 299 | return lr_l 300 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 301 | elif opt.lr_policy == 'step': 302 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 303 | elif opt.lr_policy == 'plateau': 304 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 305 | elif opt.lr_policy == 'cosine': 306 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 307 | else: 308 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 309 | return scheduler -------------------------------------------------------------------------------- /models/networks/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spectral Normalization from https://arxiv.org/abs/1802.05957 3 | """ 4 | import torch 5 | from torch.nn.functional import normalize 6 | 7 | 8 | class SpectralNorm(object): 9 | # Invariant before and after each forward call: 10 | # u = normalize(W @ v) 11 | # NB: At initialization, this invariant is not enforced 12 | 13 | _version = 1 14 | # At version 1: 15 | # made `W` not a buffer, 16 | # added `v` as a buffer, and 17 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. 18 | 19 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 20 | self.name = name 21 | self.dim = dim 22 | if n_power_iterations <= 0: 23 | raise ValueError('Expected n_power_iterations to be positive, but ' 24 | 'got n_power_iterations={}'.format(n_power_iterations)) 25 | self.n_power_iterations = n_power_iterations 26 | self.eps = eps 27 | 28 | def reshape_weight_to_matrix(self, weight): 29 | weight_mat = weight 30 | if self.dim != 0: 31 | # permute dim to front 32 | weight_mat = weight_mat.permute(self.dim, 33 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 34 | height = weight_mat.size(0) 35 | return weight_mat.reshape(height, -1) 36 | 37 | def compute_weight(self, module, do_power_iteration): 38 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are 39 | # updated in power iteration **in-place**. This is very important 40 | # because in `DataParallel` forward, the vectors (being buffers) are 41 | # broadcast from the parallelized module to each module replica, 42 | # which is a new module object created on the fly. And each replica 43 | # runs its own spectral norm power iteration. So simply assigning 44 | # the updated vectors to the module this function runs on will cause 45 | # the update to be lost forever. And the next time the parallelized 46 | # module is replicated, the same randomly initialized vectors are 47 | # broadcast and used! 48 | # 49 | # Therefore, to make the change propagate back, we rely on two 50 | # important behaviors (also enforced via tests): 51 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor 52 | # is already on correct device; and it makes sure that the 53 | # parallelized module is already on `device[0]`. 54 | # 2. If the out tensor in `out=` kwarg has correct shape, it will 55 | # just fill in the values. 
56 | # Therefore, since the same power iteration is performed on all 57 | # devices, simply updating the tensors in-place will make sure that 58 | # the module replica on `device[0]` will update the _u vector on the 59 | # parallized module (by shared storage). 60 | # 61 | # However, after we update `u` and `v` in-place, we need to **clone** 62 | # them before using them to normalize the weight. This is to support 63 | # backproping through two forward passes, e.g., the common pattern in 64 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will 65 | # complain that variables needed to do backward for the first forward 66 | # (i.e., the `u` and `v` vectors) are changed in the second forward. 67 | weight = getattr(module, self.name + '_orig') 68 | u = getattr(module, self.name + '_u') 69 | v = getattr(module, self.name + '_v') 70 | weight_mat = self.reshape_weight_to_matrix(weight) 71 | 72 | if do_power_iteration: 73 | with torch.no_grad(): 74 | for _ in range(self.n_power_iterations): 75 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 76 | # are the first left and right singular vectors. 77 | # This power iteration produces approximations of `u` and `v`. 78 | v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v) 79 | u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) 80 | if self.n_power_iterations > 0: 81 | # See above on why we need to clone 82 | u = u.clone() 83 | v = v.clone() 84 | 85 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 86 | weight = weight / sigma 87 | return weight 88 | 89 | def remove(self, module): 90 | with torch.no_grad(): 91 | weight = self.compute_weight(module, do_power_iteration=False) 92 | delattr(module, self.name) 93 | delattr(module, self.name + '_u') 94 | delattr(module, self.name + '_v') 95 | delattr(module, self.name + '_orig') 96 | module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) 97 | 98 | def __call__(self, module, inputs): 99 | setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training)) 100 | 101 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma): 102 | # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` 103 | # (the invariant at top of this class) and `u @ W @ v = sigma`. 104 | # This uses pinverse in case W^T W is not invertible. 105 | v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) 106 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) 107 | 108 | @staticmethod 109 | def apply(module, name, n_power_iterations, dim, eps): 110 | for k, hook in module._forward_pre_hooks.items(): 111 | if isinstance(hook, SpectralNorm) and hook.name == name: 112 | raise RuntimeError("Cannot register two spectral_norm hooks on " 113 | "the same parameter {}".format(name)) 114 | 115 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 116 | weight = module._parameters[name] 117 | 118 | with torch.no_grad(): 119 | weight_mat = fn.reshape_weight_to_matrix(weight) 120 | 121 | h, w = weight_mat.size() 122 | # randomly initialize `u` and `v` 123 | u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) 124 | v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) 125 | 126 | delattr(module, fn.name) 127 | module.register_parameter(fn.name + "_orig", weight) 128 | # We still need to assign weight back as fn.name because all sorts of 129 | # things may assume that it exists, e.g., when initializing weights. 
130 | # However, we can't directly assign as it could be an nn.Parameter and 131 | # gets added as a parameter. Instead, we register weight.data as a plain 132 | # attribute. 133 | setattr(module, fn.name, weight.data) 134 | module.register_buffer(fn.name + "_u", u) 135 | module.register_buffer(fn.name + "_v", v) 136 | 137 | module.register_forward_pre_hook(fn) 138 | 139 | module._register_state_dict_hook(SpectralNormStateDictHook(fn)) 140 | module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) 141 | return fn 142 | 143 | 144 | # This is a top level class because Py2 pickle doesn't like inner class nor an 145 | # instancemethod. 146 | class SpectralNormLoadStateDictPreHook(object): 147 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 148 | def __init__(self, fn): 149 | self.fn = fn 150 | 151 | # For state_dict with version None, (assuming that it has gone through at 152 | # least one training forward), we have 153 | # 154 | # u = normalize(W_orig @ v) 155 | # W = W_orig / sigma, where sigma = u @ W_orig @ v 156 | # 157 | # To compute `v`, we solve `W_orig @ x = u`, and let 158 | # v = x / (u @ W_orig @ x) * (W / W_orig). 159 | def __call__(self, state_dict, prefix, local_metadata, strict, 160 | missing_keys, unexpected_keys, error_msgs): 161 | fn = self.fn 162 | version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None) 163 | if version is None or version < 1: 164 | weight_key = prefix + fn.name 165 | if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \ 166 | weight_key not in state_dict: 167 | # Detect if it is the updated state dict and just missing metadata. 168 | # This could happen if the users are crafting a state dict themselves, 169 | # so we just pretend that this is the newest. 170 | return 171 | has_missing_keys = False 172 | for suffix in ('_orig', '', '_u'): 173 | key = weight_key + suffix 174 | if key not in state_dict: 175 | has_missing_keys = True 176 | if strict: 177 | missing_keys.append(key) 178 | if has_missing_keys: 179 | return 180 | with torch.no_grad(): 181 | weight_orig = state_dict[weight_key + '_orig'] 182 | weight = state_dict.pop(weight_key) 183 | sigma = (weight_orig / weight).mean() 184 | weight_mat = fn.reshape_weight_to_matrix(weight_orig) 185 | u = state_dict[weight_key + '_u'] 186 | v = fn._solve_v_and_rescale(weight_mat, u, sigma) 187 | state_dict[weight_key + '_v'] = v 188 | 189 | 190 | 191 | # This is a top level class because Py2 pickle doesn't like inner class nor an 192 | # instancemethod. 193 | class SpectralNormStateDictHook(object): 194 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 195 | def __init__(self, fn): 196 | self.fn = fn 197 | 198 | def __call__(self, module, state_dict, prefix, local_metadata): 199 | if 'spectral_norm' not in local_metadata: 200 | local_metadata['spectral_norm'] = {} 201 | key = self.fn.name + '.version' 202 | if key in local_metadata['spectral_norm']: 203 | raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key)) 204 | local_metadata['spectral_norm'][key] = self.fn._version 205 | 206 | 207 | def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): 208 | r"""Applies spectral normalization to a parameter in the given module. 209 | 210 | .. 
math:: 211 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, 212 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 213 | 214 | Spectral normalization stabilizes the training of discriminators (critics) 215 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor 216 | with spectral norm :math:`\sigma` of the weight matrix calculated using 217 | power iteration method. If the dimension of the weight tensor is greater 218 | than 2, it is reshaped to 2D in power iteration method to get spectral 219 | norm. This is implemented via a hook that calculates spectral norm and 220 | rescales weight before every :meth:`~Module.forward` call. 221 | 222 | See `Spectral Normalization for Generative Adversarial Networks`_ . 223 | 224 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 225 | 226 | Args: 227 | module (nn.Module): containing module 228 | name (str, optional): name of weight parameter 229 | n_power_iterations (int, optional): number of power iterations to 230 | calculate spectral norm 231 | eps (float, optional): epsilon for numerical stability in 232 | calculating norms 233 | dim (int, optional): dimension corresponding to number of outputs, 234 | the default is ``0``, except for modules that are instances of 235 | ConvTranspose{1,2,3}d, when it is ``1`` 236 | 237 | Returns: 238 | The original module with the spectral norm hook 239 | 240 | Example:: 241 | 242 | >>> m = spectral_norm(nn.Linear(20, 40)) 243 | >>> m 244 | Linear(in_features=20, out_features=40, bias=True) 245 | >>> m.weight_u.size() 246 | torch.Size([40]) 247 | 248 | """ 249 | if dim is None: 250 | if isinstance(module, (torch.nn.ConvTranspose1d, 251 | torch.nn.ConvTranspose2d, 252 | torch.nn.ConvTranspose3d)): 253 | dim = 1 254 | else: 255 | dim = 0 256 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 257 | return module 258 | 259 | 260 | def remove_spectral_norm(module, name='weight'): 261 | r"""Removes the spectral normalization reparameterization from a module. 
262 | 263 | Args: 264 | module (Module): containing module 265 | name (str, optional): name of weight parameter 266 | 267 | Example: 268 | >>> m = spectral_norm(nn.Linear(40, 10)) 269 | >>> remove_spectral_norm(m) 270 | """ 271 | for k, hook in module._forward_pre_hooks.items(): 272 | if isinstance(hook, SpectralNorm) and hook.name == name: 273 | hook.remove(module) 274 | del module._forward_pre_hooks[k] 275 | return module 276 | 277 | raise ValueError("spectral_norm of '{}' not found in {}".format( 278 | name, module)) 279 | -------------------------------------------------------------------------------- /models/networks/utils.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import reduction 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | import torch 6 | import torchvision.models.vgg as models 7 | from torchvision import transforms 8 | import numpy as np 9 | from scipy.linalg import logm, norm 10 | from torch.nn import functional as F 11 | from PIL import Image 12 | from models.networks.networks import euler2mat 13 | from torchviz import make_dot 14 | import pytorch3d.transforms as T 15 | def init_variable(dim, n_init, device, mode='random', range=[0, 1], value=1): 16 | 17 | shape = (n_init, dim) 18 | var = torch.ones(shape, requires_grad=True, 19 | device=device, dtype=torch.float) 20 | if mode == 'random': 21 | var.data = torch.rand(shape, device=device) * \ 22 | (range[1]-range[0]) + range[0] 23 | elif mode == 'linspace': 24 | var.data = torch.linspace( 25 | range[0], range[1], steps=n_init, device=device).unsqueeze(-1) 26 | elif mode == 'constant': 27 | var.data = value*var.data 28 | else: 29 | raise NotImplementedError 30 | return var 31 | 32 | 33 | def grid_sample(image, grid, mode='bilinear', padding_mode='constant', padding_value=1): 34 | image_out = F.grid_sample(image, grid, mode=mode, padding_mode='border', align_corners=True) 35 | if padding_mode == 'constant': 36 | out_of_bound = grid[:, :, :, 0] > 1 37 | out_of_bound += grid[:, :, :, 0] < -1 38 | out_of_bound += grid[:, :, :, 1] > 1 39 | out_of_bound += grid[:, :, :, 1] < -1 40 | out_of_bound = out_of_bound.unsqueeze(1).expand(image_out.shape) 41 | image_out[out_of_bound] = padding_value 42 | return image_out 43 | 44 | 45 | def warping_grid(angle, transx, transy, scale, image_shape): 46 | cosz = torch.cos(angle) 47 | sinz = torch.sin(angle) 48 | affine_mat = torch.cat([cosz, -sinz, transx, 49 | sinz, cosz, transy], dim=1).view(image_shape[0], 2, 3) 50 | scale = scale.view(-1, 1, 1).expand(affine_mat.shape) 51 | return F.affine_grid(size=image_shape, theta=scale*affine_mat, align_corners=True) 52 | 53 | 54 | def set_axis(ax): 55 | ax.clear() 56 | ax.xaxis.set_visible(False) 57 | ax.spines['right'].set_visible(False) 58 | ax.spines['top'].set_visible(False) 59 | ax.grid(axis='y') 60 | 61 | 62 | def compute_angle_loss(angle, gt_angle): 63 | compute_loss = torch.nn.L1Loss() 64 | angle_loss = compute_loss(angle, gt_angle) 65 | angle_loss = angle_loss % 360 66 | if angle_loss >= 180 and angle_loss <= 360: 67 | angle_loss = 360 - angle_loss 68 | # print("angle", angle_loss) 69 | return angle_loss.detach().cpu().numpy() 70 | 71 | # def compute_pose_error(angle, gt_angle): 72 | # R_1 = euler2mat(angle) 73 | # R_2 = euler2mat(gt_angle) 74 | # # print("matriox", R_1.shape, R_2.shape) 75 | # R_12 = torch.bmm(R_1, torch.transpose(R_2, 1, 2)) 76 | # # print("R12", R_12.shape) 77 | # loss = (torch.einsum('bii->b', R_12)-1)/2 78 | # epsilon = 1e-6 79 | # 
loss = torch.clamp(loss, min=-1+epsilon, max=1-epsilon) 80 | # # print("loss 11", loss) 81 | # angle_loss = torch.acos(loss).mean() 82 | # # print("loss 2", angle_loss, angle_loss*180.0/np.pi) 83 | # # make_dot(angle_loss).render("angle_loss", format="png") 84 | # return angle_loss 85 | 86 | def compute_pose_error(angle, gt_angle): 87 | R_1 = T.euler_angles_to_matrix(angle, "XYZ") 88 | R_2 = T.euler_angles_to_matrix(gt_angle, "XYZ") 89 | # print("matriox", R_1.shape, R_2.shape) 90 | angle = T.so3_relative_angle(R_1, R_2, True).mean() 91 | # print("angle", angle) 92 | 93 | return angle 94 | 95 | def euler_to_quaternion(input): 96 | roll = input[:, 0:1] 97 | pitch = input[:, 1:2] 98 | yaw = input[:, 2:3] 99 | qx = torch.sin(roll/2) * torch.cos(pitch/2) * torch.cos(yaw/2) - torch.cos(roll/2) * torch.sin(pitch/2) * torch.sin(yaw/2) 100 | qy = torch.cos(roll/2) * torch.sin(pitch/2) * torch.cos(yaw/2) + torch.sin(roll/2) * torch.cos(pitch/2) * torch.sin(yaw/2) 101 | qz = torch.cos(roll/2) * torch.cos(pitch/2) * torch.sin(yaw/2) - torch.sin(roll/2) * torch.sin(pitch/2) * torch.cos(yaw/2) 102 | qw = torch.cos(roll/2) * torch.cos(pitch/2) * torch.cos(yaw/2) + torch.sin(roll/2) * torch.sin(pitch/2) * torch.sin(yaw/2) 103 | 104 | return torch.stack([qx, qy, qz, qw], dim=1) 105 | 106 | def compute_quat_loss(input, gt, symm=0, reduction='mean'): 107 | if symm==1: 108 | input[:,1] = 0 109 | gt[:,1] = 0 110 | compute_loss = torch.nn.MSELoss(reduction=reduction) 111 | return compute_loss(euler_to_quaternion(input), euler_to_quaternion(gt)) 112 | def compute_vgg_loss(image, gt_image): 113 | device = torch.device("cuda:0") 114 | vgg = models.vgg16(pretrained=True).to(device) 115 | normalization_mean = [0.485, 0.456, 0.406] 116 | normalization_std = [0.229, 0.224, 0.225] 117 | loader = transforms.Compose( 118 | [transforms.Normalize(mean=normalization_mean, std=normalization_std)]) 119 | vgg_features_gt = vgg(loader(gt_image).to(device)) 120 | vgg_features_image = vgg(loader(image).to(device)) 121 | compute_loss = torch.nn.MSELoss() 122 | return compute_loss(vgg_features_gt, vgg_features_image) 123 | 124 | 125 | def compute_l1_loss(image, gt_image, reduction='mean'): 126 | compute_loss = torch.nn.L1Loss(reduction=reduction) 127 | return compute_loss(image, gt_image) 128 | 129 | 130 | def compute_l2_loss(image, gt_image, reduction='mean'): 131 | compute_loss = torch.nn.MSELoss(reduction=reduction) 132 | return compute_loss(image, gt_image) 133 | 134 | 135 | def compute_pose_loss(R, R_gt, mode=1): 136 | if mode == 0: 137 | R, R_gt = map(np.matrix, [R, R_gt]) 138 | _logRR, errest = logm(R.transpose()*R_gt, disp=False) 139 | loss = norm(_logRR, 'fro') / np.sqrt(2) 140 | 141 | elif mode == 1: 142 | # print(R.shape, R_gt.shape) 143 | R, R_gt = map(np.matrix, [R, R_gt]) 144 | # Do clipping to [-1,1]. 145 | # For a few cases, (tr(R)-1)/2 can be a little bit less/greater than -1/1. 146 | logR_F = np.clip((np.trace(R*R_gt.transpose())-1.)/2., -1, 1) 147 | loss = np.arccos(logR_F) 148 | # print("poseloss", loss) 149 | return loss 150 | 151 | 152 | def compute_RotMats(a, e, t, degree=True): 153 | # print("a e t", a.shape, e.shape, t.shape) 154 | batch = a.shape[0] 155 | Rz = np.zeros((batch, 3, 3), dtype=np.float32) 156 | Rx = np.zeros((batch, 3, 3), dtype=np.float32) 157 | Rz2 = np.zeros((batch, 3, 3), dtype=np.float32) 158 | Rz[:, 2, 2] = 1 159 | Rx[:, 0, 0] = 1 160 | Rz2[:, 2, 2] = 1 161 | # 162 | R = np.zeros((batch, 3, 3), dtype=np.float32) 163 | if degree: 164 | a = a * np.pi / 180. 165 | e = e * np.pi / 180. 
166 | t = t * np.pi / 180. 167 | a = -a 168 | e = np.pi/2.+e 169 | t = -t 170 | # 171 | sin_a, cos_a = np.sin(a), np.cos(a) 172 | sin_e, cos_e = np.sin(e), np.cos(e) 173 | sin_t, cos_t = np.sin(t), np.cos(t) 174 | 175 | # =========================== 176 | # rotation matrix 177 | # =========================== 178 | """ 179 | # [Transposed] 180 | Rz = np.matrix( [[ cos(a), sin(a), 0 ], # model rotate by a 181 | [ -sin(a), cos(a), 0 ], 182 | [ 0, 0, 1 ]] ) 183 | # [Transposed] 184 | Rx = np.matrix( [[ 1, 0, 0 ], # model rotate by e 185 | [ 0, cos(e), sin(e) ], 186 | [ 0, -sin(e), cos(e) ]] ) 187 | # [Transposed] 188 | Rz2= np.matrix( [[ cos(t), sin(t), 0 ], # camera rotate by t (in-plane rotation) 189 | [-sin(t), cos(t), 0 ], 190 | [ 0, 0, 1 ]] ) 191 | R = Rz2*Rx*Rz 192 | """ 193 | 194 | # Original matrix (None-transposed.) 195 | # No need to set back to zero? 196 | Rz[:, 0, 0], Rz[:, 0, 1] = cos_a, -sin_a 197 | Rz[:, 1, 0], Rz[:, 1, 1] = sin_a, cos_a 198 | # 199 | Rx[:, 1, 1], Rx[:, 1, 2] = cos_e, -sin_e 200 | Rx[:, 2, 1], Rx[:, 2, 2] = sin_e, cos_e 201 | # 202 | Rz2[:, 0, 0], Rz2[:, 0, 1] = cos_t, -sin_t 203 | Rz2[:, 1, 0], Rz2[:, 1, 1] = sin_t, cos_t 204 | # R = Rz2*Rx*Rz 205 | R[:] = np.einsum("nij,njk,nkl->nil", Rz2, Rx, Rz) 206 | 207 | # Return the original matrix without transpose! 208 | return R 209 | 210 | 211 | def compute_dis_loss(d_outs, target): 212 | 213 | d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs 214 | loss = 0 215 | 216 | for d_out in d_outs: 217 | 218 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 219 | loss += F.binary_cross_entropy_with_logits(d_out, targets) 220 | return loss / len(d_outs) 221 | -------------------------------------------------------------------------------- /models/nocs_gym.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse.linalg.dsolve.linsolve import factorized 2 | from options.test_options import TestOptions 3 | import torch 4 | import numpy as np 5 | import tqdm 6 | import pickle 7 | import cv2 8 | import torch.nn.functional as F 9 | from scipy.spatial.transform import Rotation as scipy_rot 10 | import gym 11 | from gym import spaces 12 | from models.networks.utils import grid_sample, warping_grid, init_variable 13 | from models.networks.losses import PerceptualLoss 14 | from options.test_options import TestOptions 15 | from data import create_dataset 16 | from collections import OrderedDict 17 | from torchvision import transforms 18 | import PIL 19 | import random 20 | from models import create_model 21 | from models.networks.utils import grid_sample, warping_grid, init_variable 22 | import tqdm 23 | import torch.optim as optim 24 | import torch.nn.functional as F 25 | from utils import loss 26 | from utils.utils import is_between 27 | from models.networks.utils import grid_sample, warping_grid, init_variable 28 | from models.networks.losses import PerceptualLoss 29 | 30 | class nocs_gym(): 31 | def __init__(self, args, parser, criterion): 32 | super(nocs_gym, self).__init__() 33 | random.seed(args.seed) 34 | np.random.seed(args.seed) 35 | opt = TestOptions().parse(parser) # get test options 36 | opt.num_threads = 1 37 | opt.serial_batches = True 38 | opt.no_flip = True 39 | opt.n_views = 2592 40 | print('------------- Creating Dataset ----------------') 41 | self.opt = opt 42 | opt.category = args.dataset 43 | opt.exp_name = args.dataset 44 | self.category = args.dataset 45 | self.categories = ['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug'] 46 | 47 | 
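The `symm` flag set just below reflects that bottles, cans and bowls are rotationally symmetric about their vertical axis, so the azimuth component of the pose is unobservable and is dropped before predicted and ground-truth poses are compared (see `compute_quat_loss` with `symm=1` in `models/networks/utils.py`). A small illustrative sketch of that masking follows; `mask_symmetry_axis` is a hypothetical helper name, not repository code.

``` python
import torch

def mask_symmetry_axis(angles: torch.Tensor, symmetric: bool) -> torch.Tensor:
    """Zero the azimuth (index 1 in this codebase's (ax, ay, az) layout) for
    rotationally symmetric categories before comparing poses."""
    out = angles.clone()
    if symmetric:
        out[:, 1] = 0.0
    return out

if __name__ == "__main__":
    pred = torch.tensor([[0.25,  0.40, 0.0]])
    gt   = torch.tensor([[0.25, -0.30, 0.0]])
    diff = (mask_symmetry_axis(pred, True) - mask_symmetry_axis(gt, True)).abs().sum()
    print(diff.item())  # 0.0 -> the two poses are equivalent for a bottle / can / bowl
```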
if self.category in ['bottle', 'can', 'bowl']: 48 | self.symm = True 49 | elif self.category in ['laptop', 'mug', 'camera']: 50 | self.symm = False 51 | self.dataset = create_dataset(opt).dataset 52 | self.n_samples = len(self.dataset) 53 | self.limits = { 54 | 'ax': [0, 1 / 2], 55 | 'ay': [-1 / 2, 1 / 2], 56 | 'az': [-50/180.0, 50/180.0], 57 | 's': [0.9, 1.2], 58 | 'tx': [-0.2, 0.2], 59 | 'ty': [-0.2, 0.2], 60 | 'z': [-3, 3] 61 | } 62 | print('-------------- Creating Model -----------------') 63 | self.current_step = 0 64 | self.max_step = args.max_step 65 | model = create_model(opt) 66 | model.setup(opt) 67 | model.eval() 68 | self.model = model 69 | self.criterion = criterion 70 | self.action_space = torch.zeros(22) 71 | self.use_encoder = False 72 | self.loss_dict = [ 73 | 'rot_loss', 74 | 'trans_loss', 75 | 'latent_loss' 76 | ] 77 | self.eval_dict = ['rot_distance', 78 | 'trans_distance'] 79 | self.step_dict = {"obs_image": None, 80 | "action": None} 81 | 82 | def warp_image(self, 83 | img, 84 | az, 85 | s=torch.tensor([[1]]), 86 | tx=torch.tensor([[0]]), 87 | ty=torch.tensor([[0]]), grad=False): 88 | az = az * np.pi 89 | grid = warping_grid(az, tx, ty, s, img.shape) 90 | img = grid_sample(img, grid) 91 | if grad == False: 92 | img = img.detach() 93 | return img 94 | 95 | def random_state(self, num, aug=False, bright=False): 96 | for i in range(num): 97 | choose = False 98 | while choose is False: 99 | random_index = int(np.random.random() * self.n_samples) 100 | data = self.dataset[random_index] 101 | if is_between(180, data['B_pose'][1], 360): 102 | data['B_pose'][1] -= 360 103 | data['B_pose'] = data['B_pose'] / 180.0 104 | 105 | if is_between(self.limits['ax'][0], data['B_pose'][0], self.limits['ax'][1]): 106 | if self.symm == True: 107 | choose = True 108 | if self.symm == False and is_between( 109 | self.limits['ay'][0], data['B_pose'][1], self.limits['ay'][1]): 110 | choose = True 111 | gt_image, self.gt_info.data[i, :3] = data['B'].unsqueeze( 112 | 0).to(self.model.device), torch.from_numpy(data['B_pose']).to(self.model.device) 113 | self.gt_info.data[i, 2:6] = torch.tensor([ 114 | np.random.uniform(low=self.limits['az'][0], 115 | high=self.limits['az'][1]), 116 | np.random.uniform(low=self.limits['s'][0], 117 | high=self.limits['s'][1]), 118 | np.random.uniform(low=self.limits['tx'][0], 119 | high=self.limits['tx'][1]), 120 | np.random.uniform(low=self.limits['ty'][0], 121 | high=self.limits['ty'][1]) 122 | ]) 123 | gt_image = self.warp_image(gt_image, 124 | self.gt_info[i:i+1, 2:3], 125 | s=self.gt_info[i:i+1, 3:4], 126 | tx=self.gt_info[i:i+1, 4:5], 127 | ty=self.gt_info[i:i+1, 5:6]) 128 | self.gt_images = torch.cat([self.gt_images, gt_image], dim=0) 129 | 130 | 131 | def reset(self, real=None, aug=False, batch_size=1): 132 | self.current_step = 0 133 | choose = batch_size 134 | self.batch_size = batch_size 135 | self.gt_info = torch.zeros(self.batch_size, 22) 136 | self.gt_images = torch.tensor(()) 137 | if real is not None: 138 | real = real[:, [2, 1, 0], :, :] 139 | self.gt_images = real.clone() 140 | else: 141 | self.random_state(self.batch_size) 142 | self.ax = init_variable(dim=1, 143 | n_init=batch_size, 144 | device=self.model.device, 145 | mode='constant', 146 | value=1.0/4.0) 147 | 148 | self.ay = init_variable(dim=1, 149 | n_init=batch_size, 150 | device=self.model.device, 151 | mode='constant', 152 | value=0) 153 | 154 | self.az = init_variable(dim=1, 155 | n_init=batch_size, 156 | device=self.model.device, 157 | mode='constant', 158 | value=0) 159 | 160 | 
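`warp_image()` above composes the in-plane rotation `az`, scale `s` and translation `(tx, ty)` into a single 2x3 affine matrix and resamples the rendered image with a sampling grid, mirroring `warping_grid`/`grid_sample` in `models/networks/utils.py`. The following standalone sketch shows that warp; `similarity_warp` is a hypothetical name, the tensor shapes are assumed, and `az` is taken directly in radians.

``` python
import torch
import torch.nn.functional as F

def similarity_warp(img: torch.Tensor, az: torch.Tensor, s: torch.Tensor,
                    tx: torch.Tensor, ty: torch.Tensor) -> torch.Tensor:
    """img: (B, C, H, W); az in radians; s, tx, ty: (B, 1) tensors."""
    cos, sin = torch.cos(az), torch.sin(az)
    # Per-sample 2x3 affine matrix [[cos, -sin, tx], [sin, cos, ty]], scaled by s.
    theta = torch.cat([cos, -sin, tx, sin, cos, ty], dim=1).view(-1, 2, 3)
    theta = s.view(-1, 1, 1) * theta
    grid = F.affine_grid(theta, size=img.shape, align_corners=True)
    return F.grid_sample(img, grid, mode='bilinear', padding_mode='border', align_corners=True)

if __name__ == "__main__":
    img = torch.rand(1, 3, 64, 64)
    out = similarity_warp(img, az=torch.tensor([[0.3]]), s=torch.tensor([[1.1]]),
                          tx=torch.tensor([[0.05]]), ty=torch.tensor([[-0.05]]))
    print(out.shape)  # torch.Size([1, 3, 64, 64])
```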
self.s = init_variable(dim=1, 161 | n_init=batch_size, 162 | device=self.model.device, 163 | mode='constant', 164 | value=np.random.uniform( 165 | low=self.limits['s'][0], 166 | high=self.limits['s'][1])) 167 | self.tx = init_variable(dim=1, 168 | n_init=batch_size, 169 | device=self.model.device, 170 | mode='constant', 171 | value=np.random.uniform( 172 | low=self.limits['tx'][0], 173 | high=self.limits['tx'][1])) 174 | self.ty = init_variable(dim=1, 175 | n_init=batch_size, 176 | device=self.model.device, 177 | mode='constant', 178 | value=np.random.uniform( 179 | low=self.limits['ty'][0], 180 | high=self.limits['ty'][1])) 181 | self.z = init_variable(dim=16, 182 | n_init=batch_size, 183 | device=self.model.device, 184 | mode='constant', 185 | value=0) 186 | latent = self.model.netE( 187 | F.interpolate(self.gt_images, 188 | size=self.model.opt.crop_size, 189 | mode='nearest')) 190 | if self.model.opt.use_VAE: 191 | mu, logvar = latent[:, :self.model.opt. 192 | z_dim], latent[:, self.model.opt.z_dim:] 193 | std = logvar.mul(0.5).exp_() 194 | eps = self.model.z_dist.sample((1, )) 195 | self.gt_info.data[:, 6:] = eps.mul(std).add_(mu) 196 | if self.use_encoder: 197 | self.z.data = self.gt_info.data[:, 6:] 198 | 199 | angle = 180 * torch.cat( 200 | [self.ax, self.ay, torch.zeros_like(self.ay)], dim=1) 201 | fake_B = self.model.netG(self.z, angle) 202 | fake_B = self.warp_image(fake_B, self.az, self.s, self.tx, self.ty) 203 | state = torch.cat([self.gt_images.unsqueeze(1), 204 | fake_B.unsqueeze(1)], 205 | dim=1).detach() 206 | return state 207 | 208 | def sample_action(self, random=False): 209 | if random == True: 210 | factor = np.random.random() 211 | else: 212 | factor = 1.0 213 | action = factor * (self.gt_info - torch.cat( 214 | [self.ax, self.ay, self.az, self.s, self.tx, self.ty, self.z], 215 | dim=1)) 216 | action = torch.clamp(action, min=-1, max=1) 217 | return action 218 | 219 | def evaluate(self): 220 | targets = { 221 | 'rot_distance': self.gt_info[:, 0:3], 222 | 'trans_distance': self.gt_info[:, 3:6], 223 | } 224 | rot = torch.cat([self.ax, self.ay, self.az], dim=1) 225 | trans = torch.cat([self.s, self.tx, self.ty], dim=1) 226 | outputs = { 227 | 'rot_distance': rot, 228 | 'trans_distance': trans, 229 | } 230 | loss_dict = self.criterion(outputs, targets, self.eval_dict) 231 | return loss_dict 232 | 233 | def calc_loss(self, action, eval=False): 234 | action_gt = self.sample_action() 235 | targets = { 236 | 'rot_loss': action_gt[:, 0:3], 237 | 'trans_loss': action_gt[:, 3:6], 238 | 'latent_loss': action_gt[:, 6:22] 239 | } 240 | outputs = { 241 | 'rot_loss': action[:, 0:3], 242 | 'trans_loss': action[:, 3:6], 243 | 'latent_loss': action[:, 6:22] 244 | } 245 | if eval == False: 246 | loss_dict = self.criterion(outputs, targets, self.loss_dict) 247 | weight_dict = self.criterion.weight_dict 248 | 249 | self.losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() 250 | if k in weight_dict) 251 | self.log_losses = { 252 | "train/" + k: (loss_dict[k] * weight_dict[k]) 253 | for k in loss_dict.keys() if k in weight_dict 254 | } 255 | self.log_losses["train/loss"] = self.losses 256 | 257 | else: 258 | loss_dict = self.criterion(outputs, targets, self.loss_dict) 259 | weight_dict = self.criterion.weight_dict 260 | 261 | self.losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() 262 | if k in weight_dict) 263 | self.log_losses = { 264 | "test/" + k: (loss_dict[k] * weight_dict[k]) 265 | for k in loss_dict.keys() if k in weight_dict 266 | } 267 | 
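`sample_action()` and `calc_loss()` above define the environment's supervision: the target ("expert") action is the clamped difference between the ground-truth state vector and the current one, and the predicted action is regressed against it chunk by chunk (rotation, scale/translation, latent) with per-term weights from `criterion.weight_dict`. The toy sketch below uses a plain MSE stand-in for the criterion; both helpers are hypothetical, not repository code.

``` python
import torch
import torch.nn.functional as F

def expert_action(gt_state: torch.Tensor, cur_state: torch.Tensor) -> torch.Tensor:
    # Clamped difference, as in sample_action(): step toward the GT state, at most one unit per entry.
    return torch.clamp(gt_state - cur_state, min=-1.0, max=1.0)

def weighted_action_loss(pred: torch.Tensor, target: torch.Tensor,
                         weights=(1.0, 1.0, 1.0)) -> torch.Tensor:
    # 22-D layout: [0:3] rotation (ax, ay, az), [3:6] scale/translation (s, tx, ty), [6:22] latent z.
    chunks = [(slice(0, 3), weights[0]), (slice(3, 6), weights[1]), (slice(6, 22), weights[2])]
    return sum(w * F.mse_loss(pred[:, sl], target[:, sl]) for sl, w in chunks)

if __name__ == "__main__":
    gt, cur = torch.rand(4, 22), torch.zeros(4, 22)
    target = expert_action(gt, cur)
    pred = torch.zeros_like(target)
    print(weighted_action_loss(pred, target).item())
```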
self.log_losses["test/loss"] = self.losses 268 | 269 | def step(self, action, eval=False): 270 | decay_factor = 1 271 | self.current_step += 1 272 | self.calc_loss(action, eval) 273 | self.ax.data += action[:, 0:1] *1/2*decay_factor 274 | self.ay.data += action[:, 1:2] *1/2* decay_factor 275 | self.az.data += action[:, 2:3] *1/2* decay_factor 276 | self.s.data += action[:, 3:4] * decay_factor 277 | self.tx.data += action[:, 4:5] * decay_factor 278 | self.ty.data += action[:, 5:6] * decay_factor 279 | self.z.data += action[:, 6:22] * decay_factor 280 | self.ax.data = torch.clamp(self.ax.data, 281 | min=self.limits['ax'][0], 282 | max=self.limits['ax'][1]) 283 | self.ay.data = torch.clamp(self.ay.data, 284 | min=self.limits['ay'][0], 285 | max=self.limits['ay'][1]) 286 | self.az.data = torch.clamp(self.az.data, 287 | min=self.limits['az'][0], 288 | max=self.limits['az'][1]) 289 | self.s.data = torch.clamp(self.s.data, 290 | min=self.limits['s'][0], 291 | max=self.limits['s'][1]) 292 | self.tx.data = torch.clamp(self.tx.data, 293 | min=self.limits['tx'][0], 294 | max=self.limits['tx'][1]) 295 | self.ty.data = torch.clamp(self.ty.data, 296 | min=self.limits['ty'][0], 297 | max=self.limits['ty'][1]) 298 | self.z.data = torch.clamp(self.z.data, 299 | min=self.limits['z'][0], 300 | max=self.limits['z'][1]) 301 | 302 | done = False 303 | angle = 180 * \ 304 | torch.cat([self.ax, self.ay, 305 | torch.zeros_like(self.ay)], dim=1) 306 | fake_B = self.model.netG(self.z, angle) 307 | fake_B = self.warp_image(fake_B, self.az, self.s, self.tx, 308 | self.ty) 309 | state = torch.cat([self.gt_images.unsqueeze(1), 310 | fake_B.unsqueeze(1)], 311 | dim=1).detach() 312 | 313 | return state 314 | 315 | def optimize(self, iter=10): 316 | real_B = self.gt_images 317 | variable_dict = [ 318 | { 319 | 'params': self.z, 320 | 'lr': 3e-1 321 | }, 322 | { 323 | 'params': self.ax, 324 | 'lr': 1e-2 325 | }, 326 | { 327 | 'params': self.ay, 328 | 'lr': 3e-2 329 | }, 330 | { 331 | 'params': self.az, 332 | 'lr': 1e-2 333 | }, 334 | { 335 | 'params': self.tx, 336 | 'lr': 3e-2 337 | }, 338 | { 339 | 'params': self.ty, 340 | 'lr': 3e-2 341 | }, 342 | { 343 | 'params': self.s, 344 | 'lr': 3e-2 345 | }, 346 | ] 347 | 348 | optimizer = optim.Adam(variable_dict, betas=(0.5, 0.999)) 349 | 350 | losses = [('VGG', 1, PerceptualLoss(reduce=False))] 351 | 352 | reg_creterion = torch.nn.MSELoss(reduce=False) 353 | self.opt.n_iter = iter 354 | loss_history = np.zeros((self.gt_images.shape[0], self.opt.n_iter, len(losses) + 1)) 355 | state_history = np.zeros((self.gt_images.shape[0], self.opt.n_iter, 6 + self.opt.z_dim)) 356 | image_history = [] 357 | from torchviz import make_dot 358 | for iter in range(self.opt.n_iter): 359 | 360 | optimizer.zero_grad() 361 | 362 | angle = 180 * torch.cat([self.ax, self.ay, torch.zeros_like(self.az)], dim=1) 363 | 364 | fake_B = self.model.netG(self.z, angle) 365 | # g = make_dot(fake_B) 366 | # g.view() 367 | fake_B = self.warp_image(fake_B, self.az, self.s, self.tx, 368 | self.ty, grad=True) 369 | 370 | fake_B_upsampled = F.interpolate(fake_B, 371 | size=real_B.shape[-1], 372 | mode='bilinear') 373 | 374 | error_all = 0 375 | for l, (name, weight, criterion) in enumerate(losses): 376 | error = weight * \ 377 | criterion(fake_B_upsampled, real_B).view( 378 | 1, -1).mean(1) 379 | loss_history[:, iter, l] = error.data.cpu().numpy() 380 | error_all = error_all + error 381 | 382 | error = self.opt.lambda_reg * \ 383 | reg_creterion(self.z, torch.zeros_like(self.z)).view( 384 | 1, -1).mean(1) 385 | 
loss_history[:, iter, l + 1] = error.data.cpu().numpy() 386 | error_all = error_all + error 387 | error_all.backward() 388 | optimizer.step() 389 | image_history.append(fake_B) 390 | 391 | state_history[:, iter, :3] = 180 * \ 392 | torch.cat([-self.ay-0.5, self.ax+1, -self.az], dim=-1).data.cpu().numpy() 393 | state_history[:, iter, 3:] = torch.cat([self.tx, self.ty, self.s, self.z], 394 | dim=-1).data.cpu().numpy() 395 | return state_history, loss_history, image_history -------------------------------------------------------------------------------- /nocs/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Normalized Object Coordinate Space for Category-Level 6D Object Pose and Size Estimation 3 | Detection and evaluation 4 | 5 | Modified based on Mask R-CNN(https://github.com/matterport/Mask_RCNN) 6 | Written by He Wang 7 | """ 8 | 9 | import os 10 | import argparse 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--num_eval', type=int, default=-1) 13 | parser.add_argument('--result_path', type=str, default='./results/') 14 | parser.add_argument('--name', type=str, default='pose_estimation') 15 | parser.add_argument('--dataset', type=str, default='bowl') 16 | 17 | 18 | args = parser.parse_args() 19 | args.result_path = os.path.join(args.result_path, args.name) 20 | num_eval = args.num_eval 21 | 22 | import glob 23 | import numpy as np 24 | import utils as utils 25 | import _pickle as cPickle 26 | import matplotlib as mpl 27 | mpl.use('Agg') 28 | 29 | if __name__ == '__main__': 30 | 31 | 32 | # real classes 33 | coco_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 34 | 'bus', 'train', 'truck', 'boat', 'traffic light', 35 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 36 | 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 37 | 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 38 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 39 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 40 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 41 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 42 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 43 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 44 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 45 | 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 46 | 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 47 | 'teddy bear', 'hair drier', 'toothbrush'] 48 | 49 | 50 | synset_names = ['BG', #0 51 | 'bottle', #1 52 | 'bowl', #2 53 | 'camera', #3 54 | 'can', #4 55 | 'laptop',#5 56 | 'mug'#6 57 | ] 58 | 59 | class_map = { 60 | 'bottle': 'bottle', 61 | 'bowl':'bowl', 62 | 'cup':'mug', 63 | 'laptop': 'laptop', 64 | } 65 | 66 | 67 | coco_cls_ids = [] 68 | for coco_cls in class_map: 69 | ind = coco_names.index(coco_cls) 70 | coco_cls_ids.append(ind) 71 | 72 | result_pkl_list = glob.glob(os.path.join(args.result_path, 'results_*.pkl')) 73 | result_pkl_list = sorted(result_pkl_list)[:num_eval] 74 | assert len(result_pkl_list) 75 | 76 | final_results = [] 77 | for pkl_path in result_pkl_list: 78 | # print(pkl_path) 79 | if os.path.getsize(pkl_path) > 0: 80 | with open(pkl_path, 'rb') as f: 81 | result = cPickle.load(f) 82 | if not 'gt_handle_visibility' in result: 83 | result['gt_handle_visibility'] = np.ones_like(result['gt_class_ids']) 84 | print('can\'t find gt_handle_visibility in the pkl.') 85 | else: 86 | assert 
len(result['gt_handle_visibility']) == len(result['gt_class_ids']), "{} {}".format(result['gt_handle_visibility'], result['gt_class_ids']) 87 | 88 | 89 | if type(result) is list: 90 | final_results += result 91 | elif type(result) is dict: 92 | final_results.append(result) 93 | else: 94 | assert False 95 | 96 | aps = utils.compute_degree_cm_mAP(args.name, final_results, synset_names, args.result_path, args.dataset, 97 | degree_thresholds = range(0, 61, 1),#range(0, 61, 1), 98 | shift_thresholds= np.linspace(0, 1, 31)*15, #np.linspace(0, 1, 31)*15, 99 | iou_3d_thresholds=np.linspace(0, 1, 101), 100 | iou_pose_thres=0.1, 101 | use_matches_for_pose=True) 102 | 103 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import torch 6 | 7 | import data 8 | import models 9 | from utils import util 10 | from utils.visualizer.base_visualizer import BaseVisualizer as Visualizer 11 | 12 | 13 | class BaseOptions(): 14 | """This class defines options used during both training and test time. 15 | 16 | It also implements several helper functions such as parsing, printing, and saving the options. 17 | It also gathers additional options defined in functions in both dataset class and model class. 18 | """ 19 | 20 | def __init__(self): 21 | """Reset the class; indicates the class hasn't been initailized""" 22 | self.initialized = False 23 | 24 | def initialize(self, parser): 25 | """Define the common options that are used in both training and test.""" 26 | parser.add_argument('--config', type=str, default='./configs/infer.yaml') 27 | # basic parameters 28 | basic_args = parser.add_argument_group('nocs_basic') 29 | basic_args.add_argument('--project_name', type=str, default='project template',help='project name, use project folder name by default') 30 | basic_args.add_argument('--dataroot', type=str,help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 31 | basic_args.add_argument('--run_name', type=str, default='', help='id of the experiment run, specified as string format, e.g. lr={lr} or string. Using current datetime by default') 32 | basic_args.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 33 | basic_args.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 34 | # model parameters 35 | model_args = parser.add_argument_group('nocs_model') 36 | model_args.add_argument('--model', type=str, default='latent_object', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]') 37 | model_args.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 38 | model_args.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 39 | model_args.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 40 | model_args.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 41 | # dataset parameters 42 | data_args = parser.add_argument_group('nocs_data') 43 | data_args.add_argument('--dataset_mode', type=str, default='nocs_hdf5', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') 44 | data_args.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') 45 | # data_args.add_argument('--batch_size', type=int, default=1, help='input batch size') 46 | data_args.add_argument('--load_size', type=int, default=64, help='scale images to this size') 47 | data_args.add_argument('--crop_size', type=int, default=64, help='then crop to this size') 48 | data_args.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 49 | data_args.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 50 | data_args.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 51 | data_args.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 52 | data_args.add_argument('--keep_last', action='store_true', help='drop the last batch of the dataset to keep batch size consistent.') 53 | # additional parameters 54 | misc_args = parser.add_argument_group('nocs_misc') 55 | misc_args.add_argument('--load_suffix', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 56 | misc_args.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 57 | misc_args.add_argument('--visualizers', nargs='+', type=str, default=['terminal', 'wandb'], help='visualizers to use. local | wandb') 58 | self.initialized = True 59 | return parser 60 | 61 | def gather_options(self, parent_parser): 62 | """Initialize our parser with basic options(only once). 63 | Add additional model-specific and dataset-specific options. 64 | These options are defined in the function 65 | in model and dataset classes. 
66 | """ 67 | if not self.initialized: # check if it has been initialized 68 | # parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | # parser = parent_parser.add_argument_group(title="nocs image generator") 70 | # subparsers = parent_parser.add_subparsers() 71 | # parser = parent_parser.add_parser('nocs', help='nocs help') 72 | parser = self.initialize(parent_parser) 73 | 74 | 75 | # get the basic options 76 | opt, _ = parser.parse_known_args() 77 | if opt.config is not None: opt = self.load_options(opt) 78 | # modify model-related parser options 79 | model_name = opt.model 80 | model_option_setter = models.get_option_setter(model_name) 81 | parser = model_option_setter(parser, self.isTrain) 82 | opt, args = parser.parse_known_args() # parse again with new defaults 83 | if opt.config is not None: opt = self.load_options(opt) 84 | 85 | # modify dataset-related parser options 86 | dataset_name = opt.dataset_mode 87 | dataset_option_setter = data.get_option_setter(dataset_name) 88 | parser = dataset_option_setter(parser, self.isTrain) 89 | 90 | # modify visualization-related parser options 91 | parser = Visualizer.modify_commandline_options(parser) 92 | 93 | # save and return the parser 94 | self.parser = parser 95 | opt = parser.parse_args() 96 | if opt.config is not None: opt = self.load_options(opt) 97 | 98 | opt.exp_name = opt.category 99 | 100 | return opt 101 | 102 | def print_options(self, opt): 103 | """Print and save options 104 | 105 | It will print both current options and default values(if different). 106 | It will save options into a text file / [checkpoints_dir] / opt.txt 107 | """ 108 | message = '' 109 | message += '----------------- Options ---------------\n' 110 | for k, v in sorted(vars(opt).items()): 111 | comment = '' 112 | default = self.parser.get_default(k) 113 | if v != default: 114 | comment = '\t[default: %s]' % str(default) 115 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 116 | message += '----------------- End -------------------' 117 | print(message) 118 | 119 | def save_options(self,opt): 120 | output_dict = {} 121 | for group in self.parser._action_groups: 122 | if group.title in ['positional arguments', 'optional arguments']: continue 123 | output_dict[group.title] = {a.dest: getattr(opt, a.dest, None) for a in group._group_actions} 124 | 125 | import yaml 126 | 127 | if self.isTrain: 128 | output_path = os.path.join(opt.checkpoints_dir,opt.project_name,opt.exp_name, opt.run_name,'config.yaml') 129 | else: 130 | output_path = os.path.join(opt.results_dir, opt.project_name, opt.test_name, 'config.yaml') 131 | 132 | util.mkdirs(os.path.dirname(output_path)) 133 | with open(output_path, 'w') as f: 134 | yaml.dump(output_dict,f,default_flow_style=False, sort_keys=True) 135 | 136 | def load_options(self,opt): 137 | assert(opt.config is not None) 138 | from envyaml import EnvYAML 139 | 140 | args_usr = [ arg[2:] for arg in sys.argv if '--' in arg] 141 | config = EnvYAML(opt.config,include_environment=False) 142 | for name in config.keys(): 143 | # make sure yaml won't overwrite cmd input and the arg is defined 144 | basename = name.split('.')[-1] 145 | if basename not in args_usr and hasattr(opt,basename): 146 | setattr(opt, basename, config[name]) 147 | 148 | return opt 149 | 150 | def parse(self, parent_parser): 151 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 152 | opt = self.gather_options(parent_parser) 153 | 154 | opt.isTrain = self.isTrain # train or test 
155 | 156 | # process opt.run_name 157 | if opt.run_name != '': 158 | opt.run_name = opt.run_name.format(**vars(opt)) 159 | else: 160 | from datetime import datetime 161 | opt.run_name = datetime.now().strftime("%d-%m-%Y %H:%M:%S") 162 | 163 | self.save_options(opt) 164 | 165 | if opt.verbose: 166 | self.print_options(opt) 167 | 168 | # set gpu ids 169 | str_ids = opt.gpu_ids.split(',') 170 | opt.gpu_ids = [] 171 | for str_id in str_ids: 172 | id = int(str_id) 173 | if id >= 0: 174 | opt.gpu_ids.append(id) 175 | if len(opt.gpu_ids) > 0: 176 | torch.cuda.set_device(opt.gpu_ids[0]) 177 | 178 | self.opt = opt 179 | return self.opt -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | It also includes shared options defined in BaseOptions. 7 | """ 8 | 9 | def initialize(self, parser): 10 | parser = BaseOptions.initialize(self, parser) # define shared options 11 | 12 | test_args = parser.add_argument_group('test') 13 | 14 | test_args.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 15 | 16 | test_args.add_argument('--target_size', type=int, default=64, help='resize the test images to this size') 17 | test_args.add_argument('--vis', action='store_true', help='visualize the fitting results') 18 | 19 | test_args.add_argument('--num_agent', type=int, default=10, help='number of evaluation agents running in parallel') 20 | test_args.add_argument('--id_agent', type=int, default=0, help='the id of current agents') 21 | 22 | test_args.add_argument('--test_name', type=str, default='fitting', help='test name') 23 | test_args.add_argument('--skip', type=int, default=1, help='evaluate every n-th sample') 24 | 25 | # rewrite devalue values 26 | test_args.set_defaults(model='test') 27 | # To avoid cropping, the load_size should be the same as crop_size 28 | test_args.set_defaults(load_size=parser.get_default('crop_size')) 29 | self.isTrain = False 30 | return parser -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visualization parameters 13 | log_args = parser.add_argument_group('log') 14 | log_args.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen') 15 | log_args.add_argument('--print_freq', type=int, default=1, help='frequency of showing training results on console') 16 | # network saving and loading parameters 17 | save_args = parser.add_argument_group('save') 18 | save_args.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 19 | save_args.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 20 | save_args.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 21 | save_args.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 22 | # training parameters 23 | train_args = parser.add_argument_group('train') 24 | train_args.add_argument('--niter', type=int, default=15, help='# of iter at starting learning rate') 25 | train_args.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') 26 | train_args.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 27 | train_args.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') 28 | train_args.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 29 | train_args.add_argument('--n_views', type=int, default=2592, help='number of training views per sample') 30 | log_args.add_argument('--name', type=str, default="optimize", 31 | help='hidden size (default: 256)') 32 | log_args.add_argument('--resume', type=bool, default=False, 33 | help='hidden size (default: 256)') 34 | log_args.add_argument('--eval', type=bool, default=False, 35 | help='hidden size (default: 256)') 36 | self.isTrain = True 37 | return parser 38 | -------------------------------------------------------------------------------- /prepare_datasets.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | unzip 3 | 4 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/checkpoints.zip 5 | unzip checkpoints.zip 6 | rm checkpoints.zip 7 | 8 | mkdir -p datasets/test 9 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/nocs_det.zip -P ./datasets/test/ 10 | unzip ./datasets/test/nocs_det.zip -d ./datasets/test/ 11 | rm ./datasets/test/nocs_det.zip 12 | 13 | wget http://download.cs.stanford.edu/orion/nocs/real_test.zip -P ./datasets/test/ 14 | unzip ./datasets/test/real_test.zip -d ./datasets/test/ 15 | rm ./datasets/test/real_test.zip 16 | 17 | mkdir -p datasets/train 18 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/bottle.hdf5 -P ./datasets/train/ 19 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/bowl.hdf5 -P ./datasets/train/ 20 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/camera.hdf5 -P ./datasets/train/ 21 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/can.hdf5 -P ./datasets/train/ 22 | wget https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/laptop.hdf5 -P ./datasets/train/ 23 | wget 
https://dataset.ait.ethz.ch/downloads/IJNQ4hZGrB/datasets/train/mug.hdf5 -P ./datasets/train/ 24 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torchvision.utils import save_image, make_grid 5 | import os 6 | import torch 7 | import numpy as np 8 | import tqdm 9 | import pickle 10 | from PIL import Image 11 | import cv2 12 | import torch.nn.functional as F 13 | from scipy.spatial.transform import Rotation as scipy_rot 14 | import wandb 15 | from PIL import Image 16 | from utils.utils import show_AP 17 | def train_one_epoch(args, agent, env, i_episode): 18 | 19 | state = env.reset(batch_size=args.batch_size) 20 | for i in range(args.max_step): 21 | agent.policy_optim.zero_grad() 22 | # sample action from gaussian policy 23 | action = agent.select_action(state) 24 | state = env.step(action, eval=False) 25 | loss = env.losses 26 | loss.backward() 27 | agent.policy_optim.step() 28 | 29 | if args.log == True and i_episode % args.log_interval == 0: 30 | # evaluation for the rotation error and translation error 31 | eval_loss = env.evaluate() 32 | eval_log = { 33 | "train/" + k: (eval_loss[k]).mean() 34 | for k in eval_loss.keys() 35 | } 36 | state = state.reshape(-1, state.shape[2], state.shape[3], state.shape[4]) 37 | grid_image = make_grid(state, nrow=int(np.sqrt(state.shape[0]))) 38 | ndarr = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 39 | img = Image.fromarray(ndarr) 40 | # Log to wandb 41 | img = wandb.Image(img) 42 | wandb.log(env.log_losses, step=i_episode) 43 | wandb.log(eval_log, step=i_episode) 44 | wandb.log({"train/state": img}, step=i_episode) 45 | 46 | # Evaluation on synthetic dataset 47 | def eval_one_epoch(args, agent, env, i_episode=0, episodes=10): 48 | total_loss = 0 49 | state_history = [] 50 | 51 | state = env.reset(batch_size = episodes) 52 | for i in range(args.max_step): 53 | state_history.append(state[:, 1, :, :, :].data) 54 | action = agent.select_action(state, True) 55 | state = env.step(action, eval=True).clone() 56 | 57 | total_loss += env.losses 58 | state_history.append(state[:, 1, :, :, :].data) 59 | # Start GD optimization after IL 60 | if args.gd_optimize == True: 61 | _, _, image_history = env.optimize(iter=10) 62 | state_history.append(image_history[len(image_history)-1]) 63 | state_history.append(env.gt_images) 64 | 65 | eval_loss = env.evaluate() 66 | avg_loss = total_loss / episodes 67 | state_history = torch.stack(state_history).reshape(-1, 3, 64, 64) 68 | grid_image = make_grid(state_history, nrow=episodes) 69 | ndarr = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 70 | img = Image.fromarray(ndarr) 71 | print("===========Test Results=============") 72 | print("Rotation error: {} degree, \nTranslation error: {} cm".format(eval_loss['rot_distance'].mean(), eval_loss['trans_distance'].mean())) 73 | img.save(os.path.join(args.save_path, "eval_" + str(episodes) + ".png")) 74 | show_AP(np.array(eval_loss['rot_distance'].detach().cpu()), args.name, args.save_path) 75 | if args.log == True: 76 | eval_log = { 77 | "test/" + k: (eval_loss[k]).mean() 78 | for k in eval_loss.keys() 79 | } 80 | img = wandb.Image(img) 81 | wandb.log(env.log_losses, step=i_episode) 82 | wandb.log({"test/avg loss": avg_loss}, step=i_episode) 83 | wandb.log(eval_log, step=i_episode) 84 | 
wandb.log({"test/state": img}, step=i_episode) 85 | return avg_loss 86 | 87 | # Evaluation on real dataset 88 | def eval_nocs(args, agent, env): 89 | intrinsics = np.array( 90 | [[591.0125, 0, 322.525], [0, 590.16775, 244.11084], [0, 0, 1]]) 91 | # Rendering parameters 92 | focal_lengh_render = 70. 93 | image_size_render = 64 94 | 95 | # Average scales from the synthetic training set CAMERA 96 | mean_scales = np.array([0.34, 0.21, 0.19, 0.15, 0.46, 0.17]) 97 | 98 | env.opt.dataroot = './datasets/test/' 99 | nocs_list = sorted(os.listdir(os.path.join( 100 | './datasets/test/', 'nocs_det')))[::env.opt.skip] 101 | 102 | interval = len(nocs_list)//(env.opt.num_agent - 103 | 1) if env.opt.num_agent > 1 else len(nocs_list) 104 | task_range = nocs_list[interval*env.opt.id_agent:min( 105 | interval*(env.opt.id_agent+1), len(nocs_list))] 106 | 107 | output_folder = args.save_path 108 | print("Starting evaluation on nocs for ", env.category) 109 | image_all = [] 110 | eval_index = 0 111 | for file_name in tqdm.tqdm(task_range): 112 | file_path = os.path.join(env.opt.dataroot, 'nocs_det', file_name) 113 | pose_file = pickle.load(open(file_path, 'rb'), encoding='utf-8') 114 | 115 | image_name = pose_file['image_path'].replace( 116 | 'data/real/test', env.opt.dataroot+'/real_test/')+'_color.png' 117 | image = cv2.imread(image_name)[:, :, ::-1] 118 | 119 | masks = pose_file['pred_mask'] 120 | bboxes = pose_file['pred_bboxes'] 121 | 122 | pose_file['pred_RTs_ours'] = np.zeros_like(pose_file['pred_RTs']) 123 | class_pred = env.categories.index(env.opt.category)+1 124 | if class_pred in pose_file['pred_class_ids']: 125 | label = [i for i, e in enumerate( 126 | pose_file['pred_class_ids']) if e == class_pred] 127 | 128 | for id in label: 129 | eval_index += 1 130 | bbox = bboxes[id] 131 | image_mask = image.copy() 132 | image_mask[masks[:, :, id] == 0, :] = 255 133 | image_mask = image_mask[bbox[0]:bbox[2], bbox[1]:bbox[3], :] 134 | # A = transforms.ToTensor()(Image.fromarray(cv2.cvtColor(image_mask,cv2.COLOR_BGR2RGB))).unsqueeze(0) 135 | 136 | A = (torch.from_numpy(image_mask.astype(np.float32)).cuda( 137 | ).unsqueeze(0).permute(0, 3, 1, 2) / 255) * 2 - 1 138 | _, c, h, w = A.shape 139 | s = max(h, w) + 30 140 | A = F.pad(A, [(s - w)//2, (s - w) - (s - w)//2, 141 | (s - h)//2, (s - h) - (s - h)//2], value=1) 142 | A = F.interpolate( 143 | A, size=env.opt.target_size, mode='bilinear') 144 | A = A.to(env.model.device) 145 | 146 | state = env.reset(real=A.clone()) 147 | 148 | for i in range(args.max_step): 149 | image_all.append(state[:, 1, :, :, :]) 150 | action = agent.select_action(state, True) 151 | state = env.step(action, eval=True) 152 | image_all.append(state[:, 1, :, :, :]) 153 | if args.gd_optimize == True: 154 | state_history, loss_history, image_history = env.optimize(iter=10) 155 | image_all.append(image_history[len(image_history)-1]) 156 | states = state_history[-1, -1, :] 157 | else: 158 | states_angle = 180 * \ 159 | torch.cat([-env.ay-0.5, env.ax+1, -env.az], dim=-1) 160 | states_other = torch.cat([env.tx, env.ty, env.s, env.z], dim=-1) 161 | states = torch.cat([states_angle, states_other], dim=-1).data.squeeze(0).cpu().numpy() 162 | 163 | 164 | image_all.append(env.gt_images) 165 | 166 | pose_file['pred_RTs_ours'][id][:3, :3] = scipy_rot.from_euler( 167 | 'yxz', states[:3], degrees=True).as_dcm()[:3, :3] 168 | angle = -states[2] / 180 * np.pi 169 | mat = np.array([[states[5]*np.cos(angle), -states[5]*np.sin(angle), states[5]*states[3]], 170 | [states[5]*np.sin(angle), states[5]*np.cos( 171 
| angle), states[5]*states[4]], 172 | [0, 0, 1]]) 173 | 174 | mat_inv = np.linalg.inv(mat) 175 | u = (bbox[1] + bbox[3])/2 + mat_inv[0, 2]*s/2 176 | v = (bbox[0] + bbox[2])/2 + mat_inv[1, 2]*s/2 177 | 178 | z = image_size_render/(s/states[5]) * ( 179 | intrinsics[0, 0]+intrinsics[1, 1])/2 / focal_lengh_render * mean_scales[class_pred-1] 180 | 181 | pose_file['pred_RTs_ours'][id][2, 3] = z 182 | pose_file['pred_RTs_ours'][id][0, 3] = ( 183 | u - intrinsics[0, 2])/intrinsics[0, 0]*z 184 | pose_file['pred_RTs_ours'][id][1, 3] = ( 185 | v - intrinsics[1, 2])/intrinsics[1, 1]*z 186 | pose_file['pred_RTs_ours'][id][3, 3] = 1 187 | 188 | 189 | 190 | 191 | f = open(os.path.join(output_folder, file_name), 'wb') 192 | pickle.dump(pose_file, f, -1) 193 | 194 | image_all = torch.stack(image_all).reshape(-1, 3, 64, 64) 195 | grid_image = make_grid(image_all, nrow=eval_index) 196 | ndarr = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 197 | img = Image.fromarray(ndarr) 198 | img.save(os.path.join(args.save_path, "eval_" + env.category + ".png")) 199 | -------------------------------------------------------------------------------- /sac_2/SAC.py: -------------------------------------------------------------------------------- 1 | # Adaptive from https://github.com/pranz24/pytorch-soft-actor-critic 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.optim import Adam 6 | from utils.utils import soft_update, hard_update 7 | from sac_2.model import GaussianPolicy, QNetwork, DeterministicPolicy 8 | 9 | class SAC(object): 10 | def __init__(self, num_inputs, action_space, args, eval_path, device): 11 | 12 | self.gamma = args.gamma 13 | self.tau = args.tau 14 | self.alpha = args.alpha 15 | self.eval_path = eval_path 16 | self.policy_type = args.policy 17 | self.target_update_interval = args.target_update_interval 18 | self.automatic_entropy_tuning = args.automatic_entropy_tuning 19 | 20 | self.device = device 21 | 22 | self.critic = QNetwork(num_inputs, num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device) 23 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr) 24 | 25 | self.critic_target = QNetwork(num_inputs, num_inputs, action_space.shape[0], args.hidden_size).to(self.device) 26 | hard_update(self.critic_target, self.critic) 27 | 28 | if self.policy_type == "Gaussian": 29 | # Target Entropy = −dim(A) (e.g. 
, -6 for HalfCheetah-v2) as given in the paper 30 | if self.automatic_entropy_tuning is True: 31 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item() 32 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) 33 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr) 34 | 35 | self.policy = GaussianPolicy(num_inputs, num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) 36 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) 37 | 38 | else: 39 | self.alpha = 0 40 | self.automatic_entropy_tuning = False 41 | self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device) 42 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr) 43 | 44 | def select_action(self, state, evaluate=False): 45 | if evaluate is False: 46 | action, _, _ = self.policy.sample(state) 47 | else: 48 | _, _, action = self.policy.sample(state) 49 | return action 50 | 51 | def freeze_policy_angle(self, freeze=True): 52 | for net in self.policy.angle_net: 53 | for param in net.parameters(): 54 | param.requires_grad = not freeze 55 | 56 | def freeze_policy_trans(self, freeze=True): 57 | for net in self.policy.trans_net: 58 | for param in net.parameters(): 59 | param.requires_grad = not freeze 60 | 61 | def update_parameters(self, memory, batch_size, updates): 62 | # Sample a batch from memory 63 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size) 64 | state_batch = torch.FloatTensor(state_batch).to(self.device).squeeze(1) 65 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device).squeeze(1) 66 | action_batch = torch.FloatTensor(action_batch).to(self.device).squeeze(1) 67 | reward_batch = torch.FloatTensor(reward_batch).to(self.device) 68 | mask_batch = torch.FloatTensor(mask_batch).to(self.device) 69 | with torch.no_grad(): 70 | # print(np.shape(next_state_batch)) 71 | next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch) 72 | qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action) 73 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi 74 | next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target) 75 | qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step 76 | # print(reward_batch.shape, mask_batch.shape, min_qf_next_target.shape) 77 | qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 78 | qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2] 79 | qf_loss = qf1_loss + qf2_loss 80 | self.critic_optim.zero_grad() 81 | qf_loss.backward() 82 | self.critic_optim.step() 83 | 84 | pi, log_pi, _ = self.policy.sample(state_batch) 85 | 86 | qf1_pi, qf2_pi = self.critic(state_batch, pi) 87 | min_qf_pi = torch.min(qf1_pi, qf2_pi) 88 | 89 | policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))] 90 | 91 | self.policy_optim.zero_grad() 92 | policy_loss.backward() 93 | self.policy_optim.step() 94 | # g = make_dot 95 | if self.automatic_entropy_tuning: 96 | alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() 97 | 98 | self.alpha_optim.zero_grad() 99 | alpha_loss.backward() 100 | self.alpha_optim.step() 101 | 102 | 
self.alpha = self.log_alpha.exp() 103 | alpha_tlogs = self.alpha.clone() # For TensorboardX logs 104 | else: 105 | alpha_loss = torch.tensor(0.).to(self.device) 106 | alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs 107 | 108 | 109 | if updates % self.target_update_interval == 0: 110 | soft_update(self.critic_target, self.critic, self.tau) 111 | 112 | return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item() 113 | 114 | # Save model parameters 115 | def save_model(self, env_name, name, suffix="", actor_path=None, critic_path=None): 116 | if not os.path.exists(name + '/'): 117 | os.makedirs(name + '/') 118 | 119 | if actor_path is None: 120 | actor_path = name + '/' + "sac_actor_{}_{}".format(env_name, suffix) 121 | if critic_path is None: 122 | critic_path = name + '/' + "sac_critic_{}_{}".format(env_name, suffix) 123 | print('Saving models to {} and {}'.format(actor_path, critic_path)) 124 | torch.save(self.policy.state_dict(), actor_path) 125 | torch.save(self.critic.state_dict(), critic_path) 126 | 127 | 128 | # Load model parameters 129 | def load_model(self, actor_path, critic_path): 130 | print('Loading models from {} and {}'.format(actor_path, critic_path)) 131 | if actor_path is not None: 132 | pretrained_dict = torch.load(actor_path) 133 | model_dict = self.policy.state_dict() 134 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 135 | model_dict.update(pretrained_dict) 136 | self.policy.load_state_dict(model_dict) 137 | 138 | if critic_path is not None: 139 | pretrained_dict = torch.load(critic_path) 140 | model_dict = self.critic.state_dict() 141 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 142 | model_dict.update(pretrained_dict) 143 | self.critic.load_state_dict(model_dict) 144 | 145 | -------------------------------------------------------------------------------- /sac_2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/sac_2/__init__.py -------------------------------------------------------------------------------- /sac_2/model.py: -------------------------------------------------------------------------------- 1 | # Adaptive from https://github.com/pranz24/pytorch-soft-actor-critic 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal 6 | LOG_SIG_MAX = 2 7 | LOG_SIG_MIN = -20 8 | epsilon = 1e-6 9 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 10 | 11 | # Initialize Policy weights 12 | def weights_init_(m): 13 | if isinstance(m, nn.Linear): 14 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 15 | torch.nn.init.constant_(m.bias, 0) 16 | 17 | 18 | class ValueNetwork(nn.Module): 19 | def __init__(self, num_inputs, hidden_dim): 20 | super(ValueNetwork, self).__init__() 21 | 22 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 23 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 24 | self.linear3 = nn.Linear(hidden_dim, 1) 25 | 26 | self.apply(weights_init_) 27 | 28 | def forward(self, state): 29 | x = F.relu(self.linear1(state)) 30 | x = F.relu(self.linear2(x)) 31 | x = self.linear3(x) 32 | return x 33 | 34 | class QNetwork(nn.Module): 35 | def __init__(self, num_inputs, state_len, num_actions, hidden_dim): 36 | super(QNetwork, self).__init__() 37 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2) 38 | self.bn1 = 
nn.BatchNorm2d(16) 39 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2) 40 | self.bn2 = nn.BatchNorm2d(32) 41 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2) 42 | self.bn3 = nn.BatchNorm2d(32) 43 | 44 | # Number of Linear input connections depends on output of conv2d layers 45 | # and therefore the input image size, so compute it. 46 | def conv2d_size_out(size, kernel_size=3, stride=2): 47 | return (size - (kernel_size - 1) - 1) // stride + 1 48 | convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(num_inputs))) 49 | convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(num_inputs))) 50 | linear_input_size = convw * convh * 32 51 | self.head = nn.Linear(linear_input_size, state_len//2) 52 | # Q1 architecture 53 | self.linear1 = nn.Linear(num_inputs + 3, hidden_dim) 54 | self.linear1_ = nn.Linear(num_inputs + 19, hidden_dim) 55 | self.linear2_ = nn.Linear(2*hidden_dim, hidden_dim) 56 | self.linear3 = nn.Linear(hidden_dim, 1) 57 | 58 | # Q2 architecture 59 | self.linear4 = nn.Linear(num_inputs + 3, hidden_dim) 60 | self.linear4_ = nn.Linear(num_inputs + 19, hidden_dim) 61 | self.linear5_ = nn.Linear(2*hidden_dim, hidden_dim) 62 | self.linear6 = nn.Linear(hidden_dim, 1) 63 | 64 | self.apply(weights_init_) 65 | 66 | def forward(self, input, action): 67 | #gt images 68 | in_1 = input[:, 0, :, :, :] 69 | # observed images 70 | in_2 = input[:, 1, :, :, :] 71 | 72 | x_1 = F.relu(self.bn1(self.conv1(in_1))) 73 | x_1 = F.relu(self.bn2(self.conv2(x_1))) 74 | x_1 = F.relu(self.bn3(self.conv3(x_1))) 75 | x_1 = torch.flatten(x_1,1) 76 | x_1 = self.head(x_1) 77 | x_2 = F.relu(self.bn1(self.conv1(in_2))) 78 | x_2 = F.relu(self.bn2(self.conv2(x_2))) 79 | x_2 = F.relu(self.bn3(self.conv3(x_2))) 80 | x_2 = torch.flatten(x_2,1) 81 | x_2 = self.head(x_2) 82 | state = torch.cat([x_1, x_2], dim=1) 83 | 84 | rot_embed = torch.cat([state, action[:, 0:3]], 1) 85 | trans_embed = torch.cat([state, action[:, 3:22]], 1) 86 | 87 | x1 = F.relu(self.linear1(rot_embed)) 88 | x1_ = F.relu(self.linear1_(trans_embed)) 89 | x1 = torch.cat([x1, x1_], dim=1) 90 | x1 = F.relu(self.linear2_(x1)) 91 | 92 | x1 = self.linear3(x1) 93 | 94 | x2 = F.relu(self.linear4(rot_embed)) 95 | x2_ = F.relu(self.linear4_(trans_embed)) 96 | x2 = torch.cat([x2, x2_], dim=1) 97 | x2 = F.relu(self.linear5_(x2)) 98 | x2 = self.linear6(x2) 99 | 100 | return x1, x2 101 | 102 | 103 | class GaussianPolicy(nn.Module): 104 | def __init__(self, num_inputs, state_len, num_actions, hidden_dim, action_space=None): 105 | super(GaussianPolicy, self).__init__() 106 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2) 107 | self.bn1 = nn.BatchNorm2d(16) 108 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2) 109 | self.bn2 = nn.BatchNorm2d(32) 110 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2) 111 | self.bn3 = nn.BatchNorm2d(32) 112 | 113 | self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=2) 114 | self.bn1_1 = nn.BatchNorm2d(16) 115 | self.conv2_1 = nn.Conv2d(16, 32, kernel_size=3, stride=2) 116 | self.bn2_1 = nn.BatchNorm2d(32) 117 | self.conv3_1 = nn.Conv2d(32, 32, kernel_size=3, stride=2) 118 | self.bn3_1 = nn.BatchNorm2d(32) 119 | 120 | # Number of Linear input connections depends on output of conv2d layers 121 | # and therefore the input image size, so compute it. 
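# Illustrative sizing note (an assumed example, not part of the original file): with the default
# 64x64 input images, three kernel-3 / stride-2 convolutions reduce the resolution 64 -> 31 -> 15 -> 7,
# so linear_input_size below works out to 7 * 7 * 32 = 1568; conv2d_size_out computes this for any input size.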
122 | def conv2d_size_out(size, kernel_size=3, stride=2): 123 | return (size - (kernel_size - 1) - 1) // stride + 1 124 | convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(num_inputs))) 125 | convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(num_inputs))) 126 | linear_input_size = convw * convh * 32 127 | self.head = nn.Linear(linear_input_size, state_len//2) 128 | self.head_1 = nn.Linear(linear_input_size, state_len//2) 129 | self.linear1 = nn.Linear(state_len, hidden_dim) 130 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 131 | 132 | self.mean_linear = nn.Linear(hidden_dim, 2) 133 | self.log_std_linear = nn.Linear(hidden_dim, 2) 134 | self.linear3 = nn.Linear(state_len, hidden_dim) 135 | self.linear4 = nn.Linear(hidden_dim, hidden_dim) 136 | self.mean_linear_1 = nn.Linear(hidden_dim, 1) 137 | self.log_std_linear_1 = nn.Linear(hidden_dim, 1) 138 | self.mean_linear_ = nn.Linear(hidden_dim, 3) 139 | self.log_std_linear_ = nn.Linear(hidden_dim, 3) 140 | self.linear5 = nn.Linear(state_len, hidden_dim) 141 | self.linear6 = nn.Linear(hidden_dim, hidden_dim) 142 | self.linear7 = nn.Linear(state_len, hidden_dim) 143 | self.linear8 = nn.Linear(hidden_dim, hidden_dim) 144 | self.mean_linear_2 = nn.Linear(hidden_dim, num_actions - 6) 145 | self.log_std_linear_2 = nn.Linear(hidden_dim, num_actions - 6) 146 | self.apply(weights_init_) 147 | self.angle_net = [ 148 | self.conv1, self.conv2, self.conv3, self.bn1, self.bn2, self.bn3, self.head, self.linear1, self.linear2, 149 | self.linear3, self.linear4, self.mean_linear,self.mean_linear_1, self.log_std_linear, self.log_std_linear_1] 150 | self.trans_net = [ 151 | self.conv1_1, self.conv2_1, self.conv3_1, self.bn1_1, self.bn2_1, self.bn3_1, self.head_1, self.linear5, self.linear6, 152 | self.linear7, self.linear8, self.mean_linear_,self.mean_linear_2, self.log_std_linear_, self.log_std_linear_2] 153 | 154 | self.action_scale = torch.tensor(1.) 155 | self.action_bias = torch.tensor(0.) 
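# Design note: the layers above are grouped into `angle_net` (the rotation heads for ax/ay and az) and
# `trans_net` (the translation, scale and latent-code heads) so that SAC.freeze_policy_angle() and
# SAC.freeze_policy_trans() can toggle requires_grad for each branch independently; action_scale = 1 and
# action_bias = 0 keep the tanh-squashed actions produced in sample() within [-1, 1].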
156 | 157 | def forward(self, input): 158 | # gt images 159 | in_1 = input[:, 0, :, :, :] 160 | # observed images 161 | in_2 = input[:, 1, :, :, :] 162 | 163 | x_1 = F.relu(self.bn1(self.conv1(in_1))) 164 | x_1 = F.relu(self.bn2(self.conv2(x_1))) 165 | x_1 = F.relu(self.bn3(self.conv3(x_1))) 166 | x_1 = torch.flatten(x_1,1) 167 | x_1 = self.head(x_1) 168 | x_2 = F.relu(self.bn1(self.conv1(in_2))) 169 | x_2 = F.relu(self.bn2(self.conv2(x_2))) 170 | x_2 = F.relu(self.bn3(self.conv3(x_2))) 171 | x_2 = torch.flatten(x_2,1) 172 | x_2 = self.head(x_2) 173 | state = torch.cat([x_1, x_2], dim=1) 174 | 175 | x_1 = F.relu(self.bn1_1(self.conv1_1(in_1))) 176 | x_1 = F.relu(self.bn2_1(self.conv2_1(x_1))) 177 | x_1 = F.relu(self.bn3_1(self.conv3_1(x_1))) 178 | x_1 = torch.flatten(x_1,1) 179 | x_1 = self.head_1(x_1) 180 | x_2 = F.relu(self.bn1_1(self.conv1_1(in_2))) 181 | x_2 = F.relu(self.bn2_1(self.conv2_1(x_2))) 182 | x_2 = F.relu(self.bn3_1(self.conv3_1(x_2))) 183 | x_2 = torch.flatten(x_2,1) 184 | x_2 = self.head_1(x_2) 185 | state_1 = torch.cat([x_1, x_2], dim=1) 186 | 187 | 188 | # angle 189 | x = F.relu(self.linear1(state)) 190 | x = F.relu(self.linear2(x)) 191 | mean = self.mean_linear(x) 192 | log_std = self.log_std_linear(x) 193 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 194 | # az 195 | x = F.relu(self.linear3(state)) 196 | x = F.relu(self.linear4(x)) 197 | mean_1 = self.mean_linear_1(x) 198 | log_std_1 = self.log_std_linear_1(x) 199 | log_std_1 = torch.clamp(log_std_1, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 200 | #sxy 201 | x = F.relu(self.linear5(state_1)) 202 | x = F.relu(self.linear6(x)) 203 | mean_ = self.mean_linear_(x) 204 | log_std_ = self.log_std_linear_(x) 205 | log_std_ = torch.clamp(log_std_, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 206 | #z 207 | x = F.relu(self.linear7(state_1)) 208 | x = F.relu(self.linear8(x)) 209 | mean_2 = self.mean_linear_2(x) 210 | log_std_2 = self.log_std_linear_2(x) 211 | log_std_2 = torch.clamp(log_std_2, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 212 | mean = torch.cat([mean,mean_1, mean_, mean_2], dim=1) 213 | log_std = torch.cat([log_std, log_std_1, log_std_, log_std_2], dim=1) 214 | return mean, log_std 215 | 216 | def sample(self, state): 217 | mean, log_std = self.forward(state) 218 | std = log_std.exp() 219 | normal = Normal(mean, std) 220 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 221 | y_t = torch.tanh(x_t) 222 | action = y_t * self.action_scale + self.action_bias 223 | log_prob = normal.log_prob(x_t) 224 | # Enforcing Action Bound 225 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) 226 | log_prob = log_prob.sum(1, keepdim=True) 227 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 228 | return action, log_prob, mean 229 | 230 | def to(self, device): 231 | self.action_scale = self.action_scale.to(device) 232 | self.action_bias = self.action_bias.to(device) 233 | return super(GaussianPolicy, self).to(device) 234 | 235 | class DeterministicPolicy(nn.Module): 236 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 237 | super(DeterministicPolicy, self).__init__() 238 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 239 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 240 | 241 | self.mean = nn.Linear(hidden_dim, num_actions) 242 | self.noise = torch.Tensor(num_actions) 243 | 244 | self.apply(weights_init_) 245 | 246 | # action rescaling 247 | if action_space is None: 248 | self.action_scale = 1. 249 | self.action_bias = 0. 
250 | else: 251 | self.action_scale = torch.FloatTensor( 252 | (action_space.high - action_space.low) / 2.) 253 | self.action_bias = torch.FloatTensor( 254 | (action_space.high + action_space.low) / 2.) 255 | 256 | def forward(self, state): 257 | x = F.relu(self.linear1(state)) 258 | x = F.relu(self.linear2(x)) 259 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias 260 | return mean 261 | 262 | def sample(self, state): 263 | mean = self.forward(state) 264 | noise = self.noise.normal_(0., std=0.1) 265 | noise = noise.clamp(-0.25, 0.25) 266 | action = mean + noise 267 | return action, torch.tensor(0.), mean 268 | 269 | def to(self, device): 270 | self.action_scale = self.action_scale.to(device) 271 | self.action_bias = self.action_bias.to(device) 272 | self.noise = self.noise.to(device) 273 | return super(DeterministicPolicy, self).to(device) 274 | 275 | 276 | class DQN(nn.Module): 277 | def __init__(self, h, w, outputs): 278 | super(DQN, self).__init__() 279 | self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2) 280 | self.bn1 = nn.BatchNorm2d(16) 281 | self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2) 282 | self.bn2 = nn.BatchNorm2d(32) 283 | self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2) 284 | self.bn3 = nn.BatchNorm2d(32) 285 | 286 | # Number of Linear input connections depends on output of conv2d layers 287 | # and therefore the input image size, so compute it. 288 | def conv2d_size_out(size, kernel_size=5, stride=2): 289 | return (size - (kernel_size - 1) - 1) // stride + 1 290 | convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w))) 291 | convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h))) 292 | linear_input_size = convw * convh * 32 293 | self.head = nn.Linear(linear_input_size, outputs) 294 | 295 | # Called with either one element to determine next action, or a batch 296 | # during optimization. Returns tensor([[left0exp,right0exp]...]). 
297 | def forward(self, input): 298 | x_1 = F.relu(self.bn1(self.conv1(x_1))) 299 | x_1 = F.relu(self.bn2(self.conv2(x_1))) 300 | x_1 = F.relu(self.bn3(self.conv3(x_1))) 301 | x_1 = self.head(x_1.view(x_1.size(0), -1)) 302 | x_2 = F.relu(self.bn1(self.conv1(x_2))) 303 | x_2 = F.relu(self.bn2(self.conv2(x_2))) 304 | x_2 = F.relu(self.bn3(self.conv3(x_2))) 305 | x_2 = self.head(x_2.view(x_2.size(0), -1)) 306 | state = torch.cat([x_1, x_2], dim=1) 307 | return state 308 | -------------------------------------------------------------------------------- /sac_2/model_encoder/image_to_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet50 3 | from PIL import Image 4 | import numpy as np 5 | class ImageToLatent(torch.nn.Module): 6 | def __init__(self, image_size=256): 7 | super().__init__() 8 | 9 | self.image_size = image_size 10 | self.activation = torch.nn.ELU() 11 | 12 | self.resnet = list(resnet50(pretrained=True).children())[:-2] 13 | self.resnet = torch.nn.Sequential(*self.resnet) 14 | self.conv2d = torch.nn.Conv2d(2048, 256, kernel_size=1) 15 | self.flatten = torch.nn.Flatten() 16 | self.dense1 = torch.nn.Linear(4096, 512) 17 | self.dense2 = torch.nn.Linear(512, 256) 18 | 19 | def forward(self, image): 20 | x = self.resnet(image) 21 | x = self.conv2d(x) 22 | x = self.flatten(x) 23 | # print(x.size()) 24 | x = self.dense1(x) 25 | x = self.dense2(x) 26 | x = x.view((-1, 1, 256)) 27 | 28 | return x 29 | 30 | class ImageLatentDataset(torch.utils.data.Dataset): 31 | def __init__(self, filenames, dlatents, image_size=256, transforms = None): 32 | self.filenames = filenames 33 | self.dlatents = dlatents 34 | self.image_size = image_size 35 | self.transforms = transforms 36 | 37 | def __len__(self): 38 | return len(self.filenames) 39 | 40 | def __getitem__(self, index): 41 | filename = self.filenames[index] 42 | dlatent = self.dlatents[index] 43 | 44 | image = self.load_image(filename) 45 | image = Image.fromarray(np.uint8(image)) 46 | 47 | if self.transforms: 48 | image = self.transforms(image) 49 | 50 | return image, dlatent 51 | 52 | def load_image(self, filename): 53 | image = np.asarray(Image.open(filename)) 54 | 55 | return image 56 | -------------------------------------------------------------------------------- /sac_2/model_encoder/latent_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torchvision.models import vgg16 3 | import torch 4 | 5 | class PostSynthesisProcessing(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | self.min_value = -1 10 | self.max_value = 1 11 | 12 | def forward(self, synthesized_image): 13 | synthesized_image = (synthesized_image - self.min_value) * torch.tensor(255).float() / (self.max_value - self.min_value) 14 | synthesized_image = torch.clamp(synthesized_image + 0.5, min=0, max=255) 15 | 16 | return synthesized_image 17 | 18 | class VGGProcessing(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.image_size = 256 23 | self.mean = torch.tensor([0.485, 0.456, 0.406], device="cuda").view(-1, 1, 1) 24 | self.std = torch.tensor([0.229, 0.224, 0.225], device="cuda").view(-1, 1, 1) 25 | 26 | def forward(self, image): 27 | image = image / torch.tensor(255).float() 28 | image = F.adaptive_avg_pool2d(image, self.image_size) 29 | 30 | image = (image - self.mean) / self.std 31 | 32 | return image 33 | 34 | 35 | class 
LatentOptimizer(torch.nn.Module): 36 | def __init__(self, synthesizer, layer=12): 37 | super().__init__() 38 | 39 | self.synthesizer = synthesizer.cuda().eval() 40 | self.post_synthesis_processing = PostSynthesisProcessing() 41 | self.vgg_processing = VGGProcessing() 42 | self.vgg16 = vgg16(pretrained=True).features[:layer].cuda().eval() 43 | 44 | 45 | def forward(self, dlatents): 46 | generated_image = self.synthesizer(dlatents) 47 | generated_image = self.post_synthesis_processing(generated_image) 48 | generated_image = self.vgg_processing(generated_image) 49 | features = self.vgg16(generated_image) 50 | 51 | return features 52 | -------------------------------------------------------------------------------- /sac_2/model_encoder/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LatentLoss(torch.nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self.l1_loss = L1Loss() 7 | self.log_cosh_loss = LogCoshLoss() 8 | self.l2_loss = torch.nn.MSELoss() 9 | 10 | def forward(self, real_features, generated_features, average_dlatents = None, dlatents = None): 11 | # Take a look at: 12 | # https://github.com/pbaylies/stylegan-encoder/blob/master/encoder/perceptual_model.py 13 | # For additional losses and practical scaling factors. 14 | 15 | loss = 0 16 | # Possible TODO: Add more feature based loss functions to create better optimized latents. 17 | 18 | # Modify scaling factors or disable losses to get best result (Image dependent). 19 | 20 | # VGG16 Feature Loss 21 | # Absolute vs MSE Loss 22 | # loss += 1 * self.l1_loss(real_features, generated_features) 23 | loss += 1 * self.l2_loss(real_features, generated_features) 24 | 25 | # Pixel Loss 26 | # loss += 1.5 * self.log_cosh_loss(real_image, generated_image) 27 | 28 | # Dlatent Loss - Forces latents to stay near the space the model uses for faces. 
29 | if average_dlatents is not None and dlatents is not None: 30 | loss += 1 * 512 * self.l1_loss(average_dlatents, dlatents) 31 | 32 | return loss 33 | 34 | class LogCoshLoss(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | 38 | def forward(self, true, pred): 39 | loss = true - pred 40 | return torch.mean(torch.log(torch.cosh(loss + 1e-12))) 41 | 42 | class L1Loss(torch.nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | def forward(self, true, pred): 47 | return torch.mean(torch.abs(true - pred)) 48 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes useful helper functions.""" 2 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torchvision.models.vgg as models 4 | from torchvision import transforms 5 | import numpy as np 6 | from scipy.linalg import logm, norm 7 | import torch.nn as nn 8 | def trans_error(input, targets): 9 | targets = targets[:, 0:1] * targets[:, 1:3] 10 | input = input[:, 0:1] * input[:, 1:3] 11 | shift = torch.norm(targets-input, dim=1) * 100 12 | return shift 13 | 14 | def compute_RotMats(rot, degree=False): 15 | # print("a e t", a.shape, e.shape, t.shape) 16 | a, e, t = rot[:, 0], rot[:, 1], rot[:, 2] 17 | batch = rot.shape[0] 18 | Rz = torch.zeros((batch, 3, 3), dtype=torch.float32) 19 | Rx = torch.zeros((batch, 3, 3), dtype=torch.float32) 20 | Rz2 = torch.zeros((batch, 3, 3), dtype=torch.float32) 21 | Rz[:, 2, 2] = 1 22 | Rx[:, 0, 0] = 1 23 | Rz2[:, 2, 2] = 1 24 | # 25 | R = torch.zeros((batch, 3, 3), dtype=torch.float32) 26 | if degree: 27 | a = a * np.pi / 180. 28 | e = e * np.pi / 180. 29 | t = t * np.pi / 180. 30 | a = -a 31 | e = np.pi/2.+e 32 | t = -t 33 | # 34 | sin_a, cos_a = torch.sin(a), torch.cos(a) 35 | sin_e, cos_e = torch.sin(e), torch.cos(e) 36 | sin_t, cos_t = torch.sin(t), torch.cos(t) 37 | 38 | # =========================== 39 | # rotation matrix 40 | # =========================== 41 | """ 42 | # [Transposed] 43 | Rz = np.matrix( [[ cos(a), sin(a), 0 ], # model rotate by a 44 | [ -sin(a), cos(a), 0 ], 45 | [ 0, 0, 1 ]] ) 46 | # [Transposed] 47 | Rx = np.matrix( [[ 1, 0, 0 ], # model rotate by e 48 | [ 0, cos(e), sin(e) ], 49 | [ 0, -sin(e), cos(e) ]] ) 50 | # [Transposed] 51 | Rz2= np.matrix( [[ cos(t), sin(t), 0 ], # camera rotate by t (in-plane rotation) 52 | [-sin(t), cos(t), 0 ], 53 | [ 0, 0, 1 ]] ) 54 | R = Rz2*Rx*Rz 55 | """ 56 | 57 | # Original matrix (None-transposed.) 58 | # No need to set back to zero? 59 | Rz[:, 0, 0], Rz[:, 0, 1] = cos_a, -sin_a 60 | Rz[:, 1, 0], Rz[:, 1, 1] = sin_a, cos_a 61 | # 62 | Rx[:, 1, 1], Rx[:, 1, 2] = cos_e, -sin_e 63 | Rx[:, 2, 1], Rx[:, 2, 2] = sin_e, cos_e 64 | # 65 | Rz2[:, 0, 0], Rz2[:, 0, 1] = cos_t, -sin_t 66 | Rz2[:, 1, 0], Rz2[:, 1, 1] = sin_t, cos_t 67 | # R = Rz2*Rx*Rz 68 | R[:] = torch.einsum("nij,njk,nkl->nil", Rz2, Rx, Rz) 69 | 70 | # Return the original matrix without transpose! 
71 | return R 72 | 73 | def rot_error(pose, pose_gt, reduction=False): 74 | R_pds = compute_RotMats(pose) 75 | R_gts = compute_RotMats(pose_gt) 76 | errors = [] 77 | for i in range(R_gts.shape[0]): 78 | R = R_pds[i] @ torch.transpose(R_gts[i], 0, 1) 79 | error = (torch.trace(R) - 1)/2 80 | error = torch.clamp(error, min=-1, max=1) 81 | error = torch.arccos(error) * 180/np.pi 82 | errors.append(error) 83 | if reduction: 84 | errors = sum(errors) / pose.shape[0] 85 | else: 86 | errors = torch.tensor(errors) 87 | return errors 88 | 89 | 90 | 91 | def euler_to_quaternion(input): 92 | roll = input[:, 0:1] 93 | pitch = input[:, 1:2] 94 | yaw = input[:, 2:3] 95 | qx = torch.sin(roll/2) * torch.cos(pitch/2) * torch.cos(yaw/2) - torch.cos(roll/2) * torch.sin(pitch/2) * torch.sin(yaw/2) 96 | qy = torch.cos(roll/2) * torch.sin(pitch/2) * torch.cos(yaw/2) + torch.sin(roll/2) * torch.cos(pitch/2) * torch.sin(yaw/2) 97 | qz = torch.cos(roll/2) * torch.cos(pitch/2) * torch.sin(yaw/2) - torch.sin(roll/2) * torch.sin(pitch/2) * torch.cos(yaw/2) 98 | qw = torch.cos(roll/2) * torch.cos(pitch/2) * torch.cos(yaw/2) + torch.sin(roll/2) * torch.sin(pitch/2) * torch.sin(yaw/2) 99 | 100 | return torch.stack([qx, qy, qz, qw], dim=1) 101 | 102 | class SetCriterion(nn.Module): 103 | """ This class computes the loss for DETR. 104 | The process happens in two steps: 105 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 106 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 107 | """ 108 | 109 | def __init__(self, args): 110 | """ Create the criterion. 111 | Parameters: 112 | num_classes: number of object categories, omitting the special no-object category 113 | matcher: module able to compute a matching between targets and proposals 114 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 115 | eos_coef: relative classification weight applied to the no-object category 116 | losses: list of all the losses to be applied. See get_loss for list of available losses. 117 | """ 118 | super().__init__() 119 | self.weight_dict = { 120 | 'rot_loss': 10, 121 | 'trans_loss': 5, 122 | 'latent_loss': 1, 123 | 'rot_distance': 0, 124 | 'trans_distance': 0 125 | } 126 | self.loss_map = { 127 | 'rot_loss': self.compute_quat_loss, 128 | 'trans_loss': self.compute_l2_loss, 129 | 'latent_loss': self.compute_l2_loss, 130 | 'rot_distance': self.compute_rot_loss, 131 | 'trans_distance': self.compute_trans_loss, 132 | } 133 | self.reduction= 'mean' 134 | self.category = args.dataset 135 | self.categories = ['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug'] 136 | if self.category in ['bottle', 'can', 'bowl']: 137 | self.symm = True 138 | elif self.category in ['laptop', 'mug', 'camera']: 139 | self.symm = False 140 | self.l2_loss = torch.nn.MSELoss(reduction=self.reduction) 141 | self.l1_loss = torch.nn.L1Loss(reduction=self.reduction) 142 | self.vgg = models.vgg16(pretrained=True) 143 | 144 | 145 | def forward(self, outputs, targets, update_losses): 146 | losses = {} 147 | for loss in update_losses: 148 | losses.update(self.get_loss(loss, outputs[loss], targets[loss])) 149 | 150 | return losses 151 | 152 | def get_loss(self, loss, outputs, targets): 153 | 154 | assert loss in self.loss_map, f'do you really want to compute {loss} loss?' 
155 | return self.loss_map[loss](outputs, targets, name=loss) 156 | 157 | def compute_vgg_loss(self, input, targets, name): 158 | 159 | normalization_mean = [0.485, 0.456, 0.406] 160 | normalization_std = [0.229, 0.224, 0.225] 161 | loader = transforms.Compose( 162 | [transforms.Normalize(mean=normalization_mean, std=normalization_std)]) 163 | vgg_features_gt = self.vgg(loader(targets)) 164 | vgg_features_image = self.vgg(loader(input)) 165 | loss = self.l2_loss(vgg_features_gt, vgg_features_image) 166 | losses = {name: loss} 167 | 168 | return losses 169 | 170 | 171 | def compute_l1_loss(self, input, targets, name): 172 | loss = self.l1_loss(input, targets) 173 | losses = {name: loss} 174 | return losses 175 | 176 | def compute_quat_loss(self, input, gt, name): 177 | input = input * np.pi 178 | gt = gt * np.pi 179 | if self.symm == True: 180 | input[:, 1] = 0 181 | gt[:, 1] = 0 182 | loss = self.l2_loss(euler_to_quaternion(input), euler_to_quaternion(gt)) 183 | losses = {name: loss} 184 | return losses 185 | 186 | def compute_l2_loss(self, input, targets, name): 187 | loss = self.l2_loss(input, targets) 188 | losses = {name: loss} 189 | return losses 190 | 191 | 192 | def compute_rot_loss(self, input, targets, name): 193 | input = input * np.pi 194 | targets = targets * np.pi 195 | if self.symm == True: 196 | input[:, 1] = 0 197 | targets[:, 1] = 0 198 | loss = rot_error(input, targets) 199 | losses = {name: loss} 200 | return losses 201 | 202 | def compute_trans_loss(self, T, T_gt, name): 203 | loss = trans_error(T, T_gt) 204 | losses = {name: loss} 205 | return losses 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def tensor2im(input_image, imtype=np.uint8): 12 | """"Converts a Tensor array into a numpy image array. 
-------------------------------------------------------------------------------- /utils/util.py: --------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | 
4 | import os
5 | 
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | 
10 | 
11 | def tensor2im(input_image, imtype=np.uint8):
12 |     """Converts a Tensor array into a numpy image array.
13 | 
14 |     Parameters:
15 |         input_image (tensor) -- the input image tensor array
16 |         imtype (type) -- the desired type of the converted numpy array
17 |     """
18 |     if not isinstance(input_image, np.ndarray):
19 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
20 |             image_tensor = input_image.data
21 |         else:
22 |             return input_image
23 |         image_numpy = image_tensor[0].clamp(-1, 1).cpu().float().numpy()  # convert it into a numpy array
24 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
25 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
27 |     else:  # if it is a numpy array, do nothing
28 |         image_numpy = input_image
29 |     return image_numpy.astype(imtype)
30 | 
31 | 
32 | def diagnose_network(net, name='network'):
33 |     """Calculate and print the mean of the average absolute gradients
34 | 
35 |     Parameters:
36 |         net (torch network) -- Torch network
37 |         name (str) -- the name of the network
38 |     """
39 |     mean = 0.0
40 |     count = 0
41 |     for param in net.parameters():
42 |         if param.grad is not None:
43 |             mean += torch.mean(torch.abs(param.grad.data))
44 |             count += 1
45 |     if count > 0:
46 |         mean = mean / count
47 |     print(name)
48 |     print(mean)
49 | 
50 | 
51 | def save_image(image_numpy, image_path):
52 |     """Save a numpy image to the disk
53 | 
54 |     Parameters:
55 |         image_numpy (numpy array) -- input numpy array
56 |         image_path (str) -- the path of the image
57 |     """
58 |     image_pil = Image.fromarray(image_numpy)
59 |     image_pil.save(image_path)
60 | 
61 | 
62 | def print_numpy(x, val=True, shp=False):
63 |     """Print the mean, min, max, median, std, and size of a numpy array
64 | 
65 |     Parameters:
66 |         val (bool) -- whether to print the values of the numpy array
67 |         shp (bool) -- whether to print the shape of the numpy array
68 |     """
69 |     x = x.astype(np.float64)
70 |     if shp:
71 |         print('shape,', x.shape)
72 |     if val:
73 |         x = x.flatten()
74 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
75 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
76 | 
77 | 
78 | def mkdirs(paths):
79 |     """Create directories if they don't exist
80 | 
81 |     Parameters:
82 |         paths (str list) -- a list of directory paths
83 |     """
84 |     if isinstance(paths, list) and not isinstance(paths, str):
85 |         for path in paths:
86 |             mkdir(path)
87 |     else:
88 |         mkdir(paths)
89 | 
90 | 
91 | def mkdir(path):
92 |     """Create a single directory if it doesn't exist
93 | 
94 |     Parameters:
95 |         path (str) -- a single directory path
96 |     """
97 |     if not os.path.exists(path):
98 |         os.makedirs(path)
99 | 
100 | def sort_str_by_num(l):
101 |     """ Sort the given list in the way that humans expect.
102 |     """
103 | 
104 |     def alphanum_key(s):
105 |         """ Turn a string into a list of string and number chunks.
106 |             "z23a" -> ["z", 23, "a"]
107 |         """
108 |         def tryint(s):
109 |             try: return int(s)
110 |             except ValueError: return s
111 |         import re
112 |         return [tryint(c) for c in re.split('([0-9]+)', s)]
113 | 
114 |     l.sort(key=alphanum_key)
-------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torchvision.models.vgg as models
4 | from torchvision import transforms
5 | import numpy as np
6 | from scipy.linalg import logm, norm
7 | import matplotlib.pyplot as plt
8 | import os
9 | degree_thres_list = range(0, 61, 1)
10 | def is_between(a, x, b):
11 |     return min(a, b) <= x <= max(a, b)
12 | 
13 | def rotationMatrixToEulerAngles(R):
14 |     x = -math.asin(-R[2, 1])
15 |     y = -math.atan2(R[2, 0], R[2, 2])
16 |     z = -math.atan2(R[0, 1], R[1, 1])
17 |     return np.array([y, x, z])
18 | def calc_ap(loss):
19 | 
20 |     acc = np.zeros_like(degree_thres_list)
21 |     for j in range(len(loss)):
22 |         for k in range(len(degree_thres_list)):
23 |             if int(loss[j]) <= degree_thres_list[k]:
24 |                 acc[k] += 1.0
25 |     acc = acc / len(loss) * 100.0
26 |     return acc
27 | 
28 | def show_AP(angle_loss, name, path):
29 | 
30 |     acc_il_step_1 = calc_ap(angle_loss)
31 |     plt.figure(figsize=(5, 5))
32 |     plt.plot(degree_thres_list,
33 |              acc_il_step_1,
34 |              linestyle="-",
35 |              color=(0, 0, 1, 0.2),
36 |              # marker="o",
37 |              alpha=0.5,
38 |              linewidth=3,
39 |              label=name)
40 |     plt.savefig(os.path.join(path, "AP_0_60_degree.png"))
41 | 
42 | def load_image(path, size=128):
43 |     from PIL import Image  # PIL is not imported at module level; import it locally here
44 |     transform_list = []
45 |     target_size = [size, size]
46 |     transform_list.append(transforms.Resize(target_size, Image.BICUBIC))
47 |     transform_list += [transforms.ToTensor()]
48 | 
49 |     image = Image.open(path).convert('RGB')
50 |     trans = transforms.Compose(transform_list)
51 |     image = trans(image)
52 |     return image
53 | def create_log_gaussian(mean, log_std, t):
54 |     quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
55 |     l = mean.shape
56 |     log_z = log_std
57 |     z = l[-1] * math.log(2 * math.pi)
58 |     log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
59 |     return log_p
60 | 
61 | 
62 | def logsumexp(inputs, dim=None, keepdim=False):
63 |     if dim is None:
64 |         inputs = inputs.view(-1)
65 |         dim = 0
66 |     s, _ = torch.max(inputs, dim=dim, keepdim=True)
67 |     outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
68 |     if not keepdim:
69 |         outputs = outputs.squeeze(dim)
70 |     return outputs
71 | 
72 | 
73 | def soft_update(target, source, tau):
74 |     for target_param, param in zip(target.parameters(), source.parameters()):
75 |         target_param.data.copy_(
76 |             target_param.data * (1.0 - tau) + param.data * tau)
77 | 
78 | 
79 | def hard_update(target, source):
80 |     for target_param, param in zip(target.parameters(), source.parameters()):
81 |         target_param.data.copy_(param.data)
82 | 
83 | 
84 | 
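
> Note: `soft_update` and `hard_update` are the usual target-network helpers, presumably consumed by the SAC agent in `sac_2/`. The snippet below is a minimal sketch, not repository code: the toy linear modules and `tau` value are illustrative stand-ins for the critic and its target.

``` python
# Minimal sketch, not repository code: toy modules stand in for the SAC critic
# and its target network.
import torch.nn as nn

from utils.utils import hard_update, soft_update

critic = nn.Linear(8, 1)
critic_target = nn.Linear(8, 1)

hard_update(critic_target, critic)             # copy the weights once at start-up
soft_update(critic_target, critic, tau=0.005)  # Polyak averaging after each gradient step
```
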
-------------------------------------------------------------------------------- /utils/visualizer/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrld/visual_navigation_pose_estimation/58d98a3592157f2558120f18af7c9ec77e795ee1/utils/visualizer/__init__.py
-------------------------------------------------------------------------------- /utils/visualizer/base_visualizer.py: --------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | 
4 | import numpy as np
5 | 
6 | from utils import util
7 | 
8 | from cv2 import INTER_CUBIC, resize
9 | 
10 | class BaseVisualizer():
11 |     @staticmethod
12 |     def modify_commandline_options(parser):
13 |         opt, _ = parser.parse_known_args()
14 | 
15 |         for vis_name in opt.visualizers:
16 |             vis_filename = "utils.visualizer." + vis_name + "_visualizer"
17 |             vislib = importlib.import_module(vis_filename)
18 |             vis = None
19 |             target_vis_name = vis_name + 'visualizer'
20 |             for name, cls in vislib.__dict__.items():
21 |                 if name.lower() == target_vis_name.lower() \
22 |                         and issubclass(cls, BaseVisualizer):
23 |                     vis = cls
24 | 
25 |             if vis is None:
26 |                 print(
27 |                     "In %s.py, there should be a subclass of BaseVisualizer with class name that matches %s in lowercase." % (
28 |                         vis_filename, target_vis_name))
29 |                 exit(0)
30 | 
31 |             parser = vis.modify_commandline_options(parser)
32 | 
33 |         return parser
34 | 
35 |     def __init__(self, opt):
36 |         self.visualizer_list = []
37 | 
38 |         for vis_name in opt.visualizers:
39 |             vis_filename = "utils.visualizer." + vis_name + "_visualizer"
40 |             vislib = importlib.import_module(vis_filename)
41 |             vis = None
42 |             target_vis_name = vis_name + 'visualizer'
43 |             for name, cls in vislib.__dict__.items():
44 |                 if name.lower() == target_vis_name.lower() \
45 |                         and issubclass(cls, BaseVisualizer):
46 |                     vis = cls
47 | 
48 |             if vis is None:
49 |                 print(
50 |                     "In %s.py, there should be a subclass of BaseVisualizer with class name that matches %s in lowercase." % (
51 |                         vis_filename, target_vis_name))
52 |                 exit(0)
53 | 
54 |             self.visualizer_list.append(vis(opt))
55 | 
56 |     def update_state(self, epochs, iters, times):
57 |         for visualizer in self.visualizer_list:
58 |             visualizer.update_state(epochs, iters, times)
59 | 
60 |     def display_current_results(self, visuals):
61 |         """Display the current image results with every registered visualizer.
62 | 
63 |         Parameters:
64 |             visuals (OrderedDict) -- dictionary of images to display or save
65 | 
66 |         Each visualizer in the list decides how to display or save the images.
67 |         """
68 |         for visualizer in self.visualizer_list:
69 |             visualizer.display_current_results(visuals)
70 | 
71 |     def display_current_videos(self, visuals):
72 |         """Display the current video results with every registered visualizer.
73 | 
74 |         Parameters:
75 |             visuals (OrderedDict) -- dictionary of image sequences to display or save
76 | 
77 |         Each visualizer in the list decides how to display or save the videos.
78 |         """
79 |         for visualizer in self.visualizer_list:
80 |             visualizer.display_current_videos(visuals)
81 | 
82 |     def plot_current_losses(self, losses):
83 |         """Plot the current losses with every registered visualizer.
84 | 
85 |         Parameters:
86 |             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
87 | 
88 |         The current epoch, iteration, and timing information comes from the state
89 |         set via `update_state`, so only the loss dictionary needs to be passed in;
90 |         each visualizer decides where the losses are printed or logged.
91 |         """
92 | 
93 |         for visualizer in self.visualizer_list:
94 |             visualizer.plot_current_losses(losses)
95 | 
96 | 
97 | 
98 | 
99 | # TODO merge image saver for test and training time
100 | def save_images(webpage, visuals, name, aspect_ratio=1.0, width=256):
101 |     """Save images to the disk.
102 | 
103 |     Parameters:
104 |         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
105 |         visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
106 |         name (str) -- the string used to create the image file names
107 |         aspect_ratio (float) -- the aspect ratio of saved images
108 |         width (int) -- the images will be resized to width x width
109 | 
110 |     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
111 |     """
112 |     image_dir = webpage.get_image_dir()
113 | 
114 |     ims, txts, links = [], [], []
115 | 
116 |     for label, im_data in visuals.items():
117 |         im = util.tensor2im(im_data)
118 |         image_name = '%s_%s.png' % (name, label)
119 |         save_path = os.path.join(image_dir, image_name)
120 |         h, w, _ = im.shape
121 |         if aspect_ratio > 1.0:
122 |             im = resize(im, (int(w * aspect_ratio), h), interpolation=INTER_CUBIC)  # cv2 expects dsize as (width, height)
123 |         if aspect_ratio < 1.0:
124 |             im = resize(im, (w, int(h / aspect_ratio)), interpolation=INTER_CUBIC)
125 |         util.save_image(im, save_path)
126 | 
127 |         ims.append(image_name)
128 |         txts.append(label)
129 |         links.append(image_name)
130 |     webpage.add_images(ims, txts, links, width=width)
131 | 
132 | 
133 | def get_img_from_fig(fig, dpi=180):
134 |     import io
135 |     import cv2
136 |     buf = io.BytesIO()
137 |     fig.savefig(buf, format="png", dpi=dpi)
138 |     buf.seek(0)
139 |     img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
140 |     buf.close()
141 |     img = cv2.imdecode(img_arr, 1)
142 |     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
143 | 
144 |     return img
145 | 
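
> Note: `get_img_from_fig` renders a matplotlib figure into an RGB numpy array via an in-memory PNG. A minimal usage sketch is below; the figure contents are purely illustrative.

``` python
# Minimal usage sketch: the plotted data is illustrative, not repository code.
import matplotlib.pyplot as plt

from utils.visualizer.base_visualizer import get_img_from_fig

fig = plt.figure()
plt.plot([0, 1, 2], [0, 1, 4])
img = get_img_from_fig(fig, dpi=120)  # numpy array of shape (H, W, 3), RGB order
```
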
-------------------------------------------------------------------------------- /utils/visualizer/terminal_visualizer.py: --------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | from utils import util
5 | from utils.visualizer.base_visualizer import BaseVisualizer
6 | 
7 | 
8 | class TerminalVisualizer(BaseVisualizer):
9 |     """This visualizer saves result images to disk and logs losses to a text file.
10 |     """
11 | 
12 |     @staticmethod
13 |     def modify_commandline_options(parser):
14 |         return parser
15 | 
16 |     def __init__(self, opt):
17 |         """Initialize the Visualizer class
18 |         """
19 |         self.opt = opt  # cache the option
20 |         self.name = opt.exp_name
21 |         self.win_size = opt.crop_size
22 |         self.epoch = -1
23 |         self.web_dir = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name, 'web')
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         print('create web directory %s...' % self.web_dir)
26 |         util.mkdirs([self.web_dir, self.img_dir])
27 | 
28 |         # create a logging file to store training losses
29 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name, 'loss_log.txt')
30 |         with open(self.log_name, "a") as log_file:
31 |             now = time.strftime("%c")
32 |             log_file.write('================ Training Loss (%s) ================\n' % now)
33 | 
34 |     def update_state(self, epochs, iters, times):
35 |         self.epochs = epochs
36 |         self.iters = iters
37 |         self.times = times
38 | 
39 |     def display_current_results(self, visuals):
40 | 
41 |         # save images to the disk
42 |         for label, image in visuals.items():
43 |             image_numpy = util.tensor2im(image)
44 |             img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (self.epochs, label))
45 |             util.save_image(image_numpy, img_path)
46 | 
47 |         # # update website
48 |         # webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
49 |         # for n in range(self.epochs, 0, -1):
50 |         #     webpage.add_header('epoch [%d]' % n)
51 |         #     ims, txts, links = [], [], []
52 | 
53 |         #     for label, image_numpy in visuals.items():
54 |         #         image_numpy = util.tensor2im(image)
55 |         #         img_path = 'epoch%.3d_%s.png' % (n, label)
56 |         #         ims.append(img_path)
57 |         #         txts.append(label)
58 |         #         links.append(img_path)
59 |         #     webpage.add_images(ims, txts, links, width=self.win_size)
60 |         # webpage.save()
61 | 
62 |     def display_current_videos(self, visuals):
63 |         import imageio
64 |         from torchvision.utils import make_grid
65 |         for label, visual in visuals.items():
66 |             frames = []
67 |             path = os.path.join(self.web_dir, 'epoch%.3d_%s.gif' % (self.epochs, label))
68 |             for frame in visual:
69 |                 image = util.tensor2im(make_grid(frame).unsqueeze(0))
70 |                 frames.append(image)
71 |             imageio.mimsave(path, frames)
72 | 
73 |     # losses: training losses stored as (name, float) pairs
74 |     def plot_current_losses(self, losses):
75 | 
76 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (self.epochs, self.iters, self.times['comp'], self.times['data'])
77 |         for k, v in losses.items():
78 |             message += '%s: %.3f ' % (k, v)
79 | 
80 |         print(message)  # print the message
81 |         with open(self.log_name, "a") as log_file:
82 |             log_file.write('%s\n' % message)  # save the message
-------------------------------------------------------------------------------- /utils/visualizer/wandb_visualizer.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import wandb
5 | 
6 | from utils import util
7 | from utils.visualizer.base_visualizer import BaseVisualizer
8 | 
9 | 
10 | class WandbVisualizer(BaseVisualizer):
11 |     """This class includes several functions that can display/save images and print/save logging information.
12 | 
13 |     It uses the 'wandb' library to log images, videos, and losses to Weights & Biases.
14 |     """
15 | 
16 |     @staticmethod
17 |     def modify_commandline_options(parser):
18 |         return parser
19 | 
20 |     def __init__(self, opt):
21 |         self.opt = opt  # cache the option
22 |         config_file = os.path.join(opt.checkpoints_dir, opt.project_name, opt.exp_name, opt.run_name, 'config.yaml')
23 |         import yaml
24 |         with open(config_file, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader)
25 |         wandb.init(project=opt.project_name, name=opt.run_name, group=opt.exp_name, config=config)
26 | 
27 |     def update_state(self, epochs, iters, times):
28 |         self.epochs = epochs
29 |         self.iters = iters
30 |         self.times = times
31 | 
32 |     def display_current_results(self, visuals):
33 |         from torchvision.utils import make_grid
34 |         visual_wandb = {}
35 |         for key, image_tensor in visuals.items():
36 |             visual_wandb[key] = wandb.Image(util.tensor2im(make_grid(image_tensor).unsqueeze(0)))
37 |         wandb.log(visual_wandb, step=self.iters)
38 | 
39 |     def display_current_videos(self, visuals):
40 |         from torchvision.utils import make_grid
41 |         video_wandb = {}
42 |         for label, visual in visuals.items():
43 |             frames = []
44 |             for frame in visual:
45 |                 image = util.tensor2im(make_grid(frame).unsqueeze(0))
46 |                 frames.append(image)
47 |             gif = np.stack(frames, axis=0)
48 |             gif = np.transpose(gif, (0, 3, 1, 2))
49 | 
50 |             video_wandb[label] = wandb.Video(gif, fps=20)
51 | 
52 |         wandb.log(video_wandb, step=self.iters)
53 | 
54 |     def plot_current_losses(self, losses):
55 |         wandb.log(losses, step=self.iters)
--------------------------------------------------------------------------------