├── .gitignore ├── LICENSE ├── README.md ├── multimodal ├── __init__.py ├── configs │ └── training_default.yaml ├── dataloaders │ ├── MultimodalManipulationDataset.py │ ├── ProcessFlow.py │ ├── ProcessForce.py │ ├── ToTensor.py │ └── __init__.py ├── dataset │ └── download_data.sh ├── logger.py ├── mini_main.py ├── models │ ├── __init__.py │ ├── base_models │ │ ├── __init__.py │ │ ├── decoders.py │ │ ├── encoders.py │ │ └── layers.py │ ├── models_utils.py │ ├── sensor_fusion.py │ └── tests │ │ ├── __init__.py │ │ └── test_layers.py ├── scripts │ └── run_all_tests.sh ├── trainers │ ├── __init__.py │ └── selfsupervised.py └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # directories 2 | 3 | *.DS_Store* 4 | data/ 5 | tmp/ 6 | **/*.pkl 7 | 8 | logging/* 9 | 10 | *.h5py 11 | sftp-config.json 12 | 13 | # Byte-compiled / optimized / DLL files 14 | **__pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # Vim swap files 19 | *~ 20 | *.swp 21 | *.swo 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Stanford Interactive Perception and Robot Learning Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software 
is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Representation 2 | 3 | Code for Making Sense of Vision and Touch. 4 | https://sites.google.com/view/visionandtouch 5 | 6 | Code written by: Matthew Tan, Michelle Lee, Peter Zachares, Yuke Zhu 7 | 8 | ## requirements 9 | `pip install -r requirements.txt` 10 | 11 | ## get dataset 12 | 13 | ``` 14 | cd multimodal/dataset 15 | ./download_data.sh 16 | ``` 17 | ## run training 18 | 19 | `python mini_main.py --config configs/training_default.yaml` 20 | 21 | 22 | ## ROBOT DATASET 23 | ---- 24 | action Dataset {50, 4}\ 25 | contact Dataset {50, 50}\ 26 | depth_data Dataset {50, 128, 128, 1}\ 27 | ee_forces_continuous Dataset {50, 50, 6}\ 28 | ee_ori Dataset {50, 4}\ 29 | ee_pos Dataset {50, 3}\ 30 | ee_vel Dataset {50, 3}\ 31 | ee_vel_ori Dataset {50, 3}\ 32 | ee_yaw Dataset {50, 4}\ 33 | ee_yaw_delta Dataset {50, 4}\ 34 | image Dataset {50, 128, 128, 3}\ 35 | joint_pos Dataset {50, 7}\ 36 | joint_vel Dataset {50, 7}\ 37 | optical_flow Dataset {50, 128, 128, 2}\ 38 | proprio Dataset {50, 8}\ 39 | 40 | -------------------------------------------------------------------------------- /multimodal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-iprl-lab/multimodal_representation/e302c61a4b4fa884250a5688aae2775952276352/multimodal/__init__.py -------------------------------------------------------------------------------- /multimodal/configs/training_default.yaml: -------------------------------------------------------------------------------- 1 | training_type: "selfsupervised" 2 | log_level: 'INFO' 3 | 4 | test: False 5 | 6 | # Ablations 7 | encoder: False 8 | deterministic: False 9 | vision: 1.0 10 | depth: 1.0 11 | proprio: 1.0 12 | force: 1.0 13 | sceneflow: 1.0 14 | opticalflow: 1.0 15 | contact: 1.0 16 | pairing: 1.0 17 | eedelta: 1.0 18 | 19 | # Training parameters 20 | lr: 0.0001 21 | beta1: 0.9 22 | seed: 1234 23 | max_epoch: 50 24 | batch_size: 64 25 | ep_length: 50 26 | zdim: 128 27 | action_dim: 4 28 | 29 | # Dataset params 30 | dataset_params: 31 | force_name: "force" 32 | action_dim: 4 33 | 34 | load: '' 35 | logging_folder: logging/ 36 | 37 | 38 | # path to dataset hdf5 file' 39 | dataset: "dataset/triangle_real_data/" 40 | 41 | val_ratio: 0.20 42 | cuda: True 43 | num_workers: 8 44 | img_record_n: 500 45 | -------------------------------------------------------------------------------- /multimodal/dataloaders/MultimodalManipulationDataset.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import ipdb 4 | from tqdm import tqdm 5 | 6 | from 
torch.utils.data import Dataset 7 | 8 | 9 | class MultimodalManipulationDataset(Dataset): 10 | """Multimodal Manipulation dataset.""" 11 | 12 | def __init__( 13 | self, 14 | filename_list, 15 | transform=None, 16 | episode_length=50, 17 | training_type="selfsupervised", 18 | n_time_steps=1, 19 | action_dim=4, 20 | pairing_tolerance=0.06 21 | ): 22 | """ 23 | Args: 24 | hdf5_file (handle): h5py handle of the hdf5 file with annotations. 25 | transform (callable, optional): Optional transform to be applied 26 | on a sample. 27 | """ 28 | self.dataset_path = filename_list 29 | self.transform = transform 30 | self.episode_length = episode_length 31 | self.training_type = training_type 32 | self.n_time_steps = n_time_steps 33 | self.dataset = {} 34 | self.action_dim = action_dim 35 | self.pairing_tolerance = pairing_tolerance 36 | 37 | self._config_checks() 38 | self._init_paired_filenames() 39 | 40 | def __len__(self): 41 | return len(self.dataset_path) * (self.episode_length - self.n_time_steps) 42 | 43 | def __getitem__(self, idx): 44 | 45 | list_index = idx // (self.episode_length - self.n_time_steps) 46 | dataset_index = idx % (self.episode_length - self.n_time_steps) 47 | filename = self.dataset_path[list_index][:-8] 48 | 49 | file_number, filename = self._parse_filename(filename) 50 | 51 | unpaired_filename, unpaired_idx = self.paired_filenames[(list_index, dataset_index)] 52 | 53 | if dataset_index >= self.episode_length - self.n_time_steps - 1: 54 | dataset_index = np.random.randint( 55 | self.episode_length - self.n_time_steps - 1 56 | ) 57 | 58 | sample = self._get_single( 59 | self.dataset_path[list_index], 60 | list_index, 61 | unpaired_filename, 62 | dataset_index, 63 | unpaired_idx, 64 | ) 65 | return sample 66 | 67 | def _get_single( 68 | self, dataset_name, list_index, unpaired_filename, dataset_index, unpaired_idx 69 | ): 70 | 71 | dataset = h5py.File(dataset_name, "r", swmr=True, libver="latest") 72 | unpaired_dataset = h5py.File(unpaired_filename, "r", swmr=True, libver="latest") 73 | 74 | if self.training_type == "selfsupervised": 75 | 76 | image = dataset["image"][dataset_index] 77 | depth = dataset["depth_data"][dataset_index] 78 | proprio = dataset["proprio"][dataset_index][:8] 79 | force = dataset["ee_forces_continuous"][dataset_index] 80 | 81 | if image.shape[0] == 3: 82 | image = np.transpose(image, (2, 1, 0)) 83 | 84 | if depth.ndim == 2: 85 | depth = depth.reshape((128, 128, 1)) 86 | 87 | flow = np.array(dataset["optical_flow"][dataset_index]) 88 | flow_mask = np.expand_dims( 89 | np.where( 90 | flow.sum(axis=2) == 0, 91 | np.zeros_like(flow.sum(axis=2)), 92 | np.ones_like(flow.sum(axis=2)), 93 | ), 94 | 2, 95 | ) 96 | 97 | unpaired_image = image 98 | unpaired_depth = depth 99 | unpaired_proprio = unpaired_dataset["proprio"][unpaired_idx][:8] 100 | unpaired_force = unpaired_dataset["ee_forces_continuous"][unpaired_idx] 101 | 102 | sample = { 103 | "image": image, 104 | "depth": depth, 105 | "flow": flow, 106 | "flow_mask": flow_mask, 107 | "action": dataset["action"][dataset_index + 1], 108 | "force": force, 109 | "proprio": proprio, 110 | "ee_yaw_next": dataset["proprio"][dataset_index + 1][:self.action_dim], 111 | "contact_next": np.array( 112 | [dataset["contact"][dataset_index + 1].sum() > 0] 113 | ).astype(np.float), 114 | "unpaired_image": unpaired_image, 115 | "unpaired_force": unpaired_force, 116 | "unpaired_proprio": unpaired_proprio, 117 | "unpaired_depth": unpaired_depth, 118 | } 119 | 120 | dataset.close() 121 | unpaired_dataset.close() 122 | 123 | if 
self.transform: 124 | sample = self.transform(sample) 125 | 126 | return sample 127 | 128 | def _init_paired_filenames(self): 129 | """ 130 | Precalculates the paired filenames. 131 | Imposes a distance tolerance between paired images 132 | """ 133 | tolerance = self.pairing_tolerance 134 | 135 | all_combos = set() 136 | 137 | self.paired_filenames = {} 138 | for list_index in tqdm(range(len(self.dataset_path)), desc="pairing_files"): 139 | filename = self.dataset_path[list_index] 140 | file_number, _ = self._parse_filename(filename[:-8]) 141 | 142 | dataset = h5py.File(filename, "r", swmr=True, libver="latest") 143 | 144 | for idx in range(self.episode_length - self.n_time_steps): 145 | 146 | proprio_dist = None 147 | while proprio_dist is None or proprio_dist < tolerance: 148 | # Get a random idx, file that is not the same as current 149 | unpaired_dataset_idx = np.random.randint(self.__len__()) 150 | unpaired_filename, unpaired_idx, _ = self._idx_to_filename_idx(unpaired_dataset_idx) 151 | 152 | while unpaired_filename == filename: 153 | unpaired_dataset_idx = np.random.randint(self.__len__()) 154 | unpaired_filename, unpaired_idx, _ = self._idx_to_filename_idx(unpaired_dataset_idx) 155 | 156 | with h5py.File(unpaired_filename, "r", swmr=True, libver="latest") as unpaired_dataset: 157 | proprio_dist = np.linalg.norm(dataset['proprio'][idx][:3] - unpaired_dataset['proprio'][unpaired_idx][:3]) 158 | 159 | self.paired_filenames[(list_index, idx)] = (unpaired_filename, unpaired_idx) 160 | all_combos.add((unpaired_filename, unpaired_idx)) 161 | 162 | dataset.close() 163 | 164 | def _idx_to_filename_idx(self, idx): 165 | """ 166 | Utility function for finding info about a dataset index 167 | 168 | Args: 169 | idx (int): Dataset index 170 | 171 | Returns: 172 | filename (string): Filename associated with dataset index 173 | dataset_index (int): Index of data within that file 174 | list_index (int): Index of data in filename list 175 | """ 176 | list_index = idx // (self.episode_length - self.n_time_steps) 177 | dataset_index = idx % (self.episode_length - self.n_time_steps) 178 | filename = self.dataset_path[list_index] 179 | return filename, dataset_index, list_index 180 | 181 | def _parse_filename(self, filename): 182 | """ Parses the filename to get the file number and filename""" 183 | if filename[-2] == "_": 184 | file_number = int(filename[-1]) 185 | filename = filename[:-1] 186 | else: 187 | file_number = int(filename[-2:]) 188 | filename = filename[:-2] 189 | 190 | return file_number, filename 191 | 192 | def _config_checks(self): 193 | if self.training_type != "selfsupervised": 194 | raise ValueError( 195 | "Training type not supported: {}".format(self.training_type) 196 | ) 197 | -------------------------------------------------------------------------------- /multimodal/dataloaders/ProcessFlow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ProcessFlow(object): 6 | """Process optical flow into a pyramid. 
7 | Args: 8 | pyramid_scale (list): scaling factors to downsample 9 | the spatial pyramid 10 | """ 11 | 12 | def __init__(self, pyramid_scales=[2, 4, 8]): 13 | assert isinstance(pyramid_scales, list) 14 | self.pyramid_scales = pyramid_scales 15 | 16 | def __call__(self, sample): 17 | # subsampling to create small flow images 18 | for scale in self.pyramid_scales: 19 | scaled_flow = sample['flow'][::scale, ::scale] 20 | sample['flow{}'.format(scale)] = scaled_flow 21 | return sample 22 | -------------------------------------------------------------------------------- /multimodal/dataloaders/ProcessForce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ProcessForce(object): 6 | """Truncate a time series of force readings with a window size. 7 | Args: 8 | window_size (int): Length of the history window that is 9 | used to truncate the force readings 10 | """ 11 | 12 | def __init__(self, window_size, key='force', tanh=False): 13 | assert isinstance(window_size, int) 14 | self.window_size = window_size 15 | self.key = key 16 | self.tanh = tanh 17 | 18 | def __call__(self, sample): 19 | force = sample[self.key] 20 | force = force[-self.window_size:] 21 | if self.tanh: 22 | force = np.tanh(force) # remove very large force readings 23 | sample[self.key] = force.transpose() 24 | return sample 25 | -------------------------------------------------------------------------------- /multimodal/dataloaders/ToTensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ToTensor(object): 6 | """Convert ndarrays in sample to Tensors.""" 7 | 8 | def __init__(self, device=None): 9 | self.device = device 10 | 11 | def __call__(self, sample): 12 | # swap color axis because 13 | # numpy image: H x W x C 14 | # torch image: C X H X W 15 | 16 | # transpose flow into 2 x H x W 17 | for k in sample.keys(): 18 | if k.startswith('flow'): 19 | sample[k] = sample[k].transpose((2, 0, 1)) 20 | 21 | # convert numpy arrays to pytorch tensors 22 | new_dict = dict() 23 | for k, v in sample.items(): 24 | if self.device is None: 25 | # torch.tensor(v, device = self.device, dtype = torch.float32) 26 | new_dict[k] = torch.FloatTensor(v) 27 | else: 28 | new_dict[k] = torch.from_numpy(v).float() 29 | 30 | return new_dict 31 | -------------------------------------------------------------------------------- /multimodal/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .MultimodalManipulationDataset import MultimodalManipulationDataset 2 | from .ProcessForce import ProcessForce 3 | from .ProcessFlow import ProcessFlow 4 | from .ToTensor import ToTensor 5 | -------------------------------------------------------------------------------- /multimodal/dataset/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://downloads.cs.stanford.edu/juno/triangle_real_data.zip -O _tmp.zip 4 | 5 | unzip _tmp.zip 6 | rm _tmp.zip 7 | -------------------------------------------------------------------------------- /multimodal/logger.py: -------------------------------------------------------------------------------- 1 | import git 2 | from tensorboardX import SummaryWriter 3 | import datetime 4 | import time 5 | import os 6 | 7 | import logging 8 | import sys 9 | import yaml 10 | 11 | 12 | class Logger(object): 13 | """ 14 | Hooks 
for print statements and tensorboard logging 15 | """ 16 | 17 | def __init__(self, configs): 18 | 19 | self.configs = configs 20 | 21 | time_str = datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M") 22 | prefix_str = time_str + "_" + configs["notes"] 23 | if configs["dev"]: 24 | prefix_str = "dev_" + prefix_str 25 | 26 | self.log_folder = os.path.join(self.configs["logging_folder"], prefix_str) 27 | self.tb_prefix = prefix_str 28 | 29 | self.setup_checks() 30 | self.create_folder_structure() 31 | self.setup_loggers() 32 | self.dump_init_info() 33 | 34 | def create_folder_structure(self): 35 | """ 36 | Creates the folder structure for logging. Subfolders can be added here 37 | """ 38 | base_dir = self.log_folder 39 | sub_folders = ["runs", "models"] 40 | 41 | if not os.path.exists(self.configs["logging_folder"]): 42 | os.mkdir(self.configs["logging_folder"]) 43 | 44 | if not os.path.exists(base_dir): 45 | os.mkdir(base_dir) 46 | 47 | for sf in sub_folders: 48 | if not os.path.exists(os.path.join(base_dir, sf)): 49 | os.mkdir(os.path.join(base_dir, sf)) 50 | 51 | def setup_loggers(self): 52 | """ 53 | Sets up a logger that logs to both file and stdout 54 | """ 55 | log_path = os.path.join(self.log_folder, "log.log") 56 | 57 | self.print_logger = logging.getLogger() 58 | self.print_logger.setLevel( 59 | getattr(logging, self.configs["log_level"].upper(), None) 60 | ) 61 | handlers = [logging.StreamHandler(sys.stdout), logging.FileHandler(log_path)] 62 | formatter = logging.Formatter( 63 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 64 | ) 65 | for h in handlers: 66 | h.setFormatter(formatter) 67 | self.print_logger.addHandler(h) 68 | 69 | # Setup Tensorboard 70 | self.tb = SummaryWriter(os.path.join(self.log_folder, "runs", self.tb_prefix)) 71 | 72 | def setup_checks(self): 73 | """ 74 | Verifies that all changes have been committed 75 | Verifies that hashes match (if continuation) 76 | """ 77 | repo = git.Repo(search_parent_directories=True) 78 | sha = repo.head.object.hexsha 79 | 80 | 81 | # Test for continuation 82 | if self.configs["continuation"]: 83 | self.log_folder = self.configs["logging_folder"] 84 | with open(os.path.join(self.log_folder, "log.log"), "r") as old_log: 85 | for line in old_log: 86 | find_str = "Git hash" 87 | if line.find(find_str) is not -1: 88 | old_sha = line[line.find(find_str) + len(find_str) + 2 : -4] 89 | assert sha == old_sha 90 | 91 | def dump_init_info(self): 92 | """ 93 | Saves important info for replicability 94 | """ 95 | if not self.configs["continuation"]: 96 | self.configs["logging_folder"] = self.log_folder 97 | else: 98 | self.print("=" * 80) 99 | self.print("Continuing log") 100 | self.print("=" * 80) 101 | 102 | repo = git.Repo(search_parent_directories=True) 103 | sha = repo.head.object.hexsha 104 | 105 | self.print("Git hash: {}".format(sha)) 106 | self.print("Dumping YAML file") 107 | self.print("Configs: ", yaml.dump(self.configs)) 108 | 109 | # Save the start of every run 110 | if "start_weights" not in self.configs: 111 | self.configs["start_weights"] = [] 112 | self.configs["start_weights"].append(self.configs["load"]) 113 | 114 | with open(os.path.join(self.log_folder, "configs.yml"), "w") as outfile: 115 | yaml.dump(self.configs, outfile) 116 | self.tb.add_text("hyperparams", str(self.configs)) 117 | 118 | def end_itr(self, weights_path): 119 | """ 120 | Perform all operations needed at end of iteration 121 | 1). 
Save configs with latest weights 122 | """ 123 | self.configs["latest_weights"] = weights_path 124 | with open(os.path.join(self.log_folder, "configs.yml"), "w") as outfile: 125 | yaml.dump(self.configs, outfile) 126 | 127 | def print(self, *args): 128 | """ 129 | Wrapper for print statement 130 | """ 131 | self.print_logger.info(args) 132 | 133 | -------------------------------------------------------------------------------- /multimodal/mini_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import yaml 4 | 5 | from logger import Logger 6 | from trainers.selfsupervised import selfsupervised 7 | 8 | if __name__ == "__main__": 9 | 10 | # Load the config file 11 | parser = argparse.ArgumentParser(description="Sensor fusion model") 12 | parser.add_argument("--config", help="YAML config file") 13 | parser.add_argument("--notes", default="", help="run notes") 14 | parser.add_argument("--dev", type=bool, default=False, help="run in dev mode") 15 | parser.add_argument( 16 | "--continuation", 17 | type=bool, 18 | default=False, 19 | help="continue a previous run. Will continue the log file", 20 | ) 21 | args = parser.parse_args() 22 | 23 | # Add the yaml to the config args parse 24 | with open(args.config) as f: 25 | configs = yaml.load(f) 26 | 27 | # Merge configs and args 28 | for arg in vars(args): 29 | configs[arg] = getattr(args, arg) 30 | 31 | # Initialize the loggers 32 | logger = Logger(configs) 33 | 34 | # Initialize the trainer 35 | trainer = selfsupervised(configs, logger) 36 | 37 | trainer.train() 38 | -------------------------------------------------------------------------------- /multimodal/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-iprl-lab/multimodal_representation/e302c61a4b4fa884250a5688aae2775952276352/multimodal/models/__init__.py -------------------------------------------------------------------------------- /multimodal/models/base_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-iprl-lab/multimodal_representation/e302c61a4b4fa884250a5688aae2775952276352/multimodal/models/base_models/__init__.py -------------------------------------------------------------------------------- /multimodal/models/base_models/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.models_utils import init_weights 4 | from models.base_models.layers import ( 5 | conv2d, 6 | predict_flow, 7 | deconv, 8 | crop_like 9 | ) 10 | 11 | 12 | class OpticalFlowDecoder(nn.Module): 13 | def __init__(self, z_dim, initailize_weights=True): 14 | """ 15 | Decodes the optical flow and optical flow mask. 
16 | """ 17 | super().__init__() 18 | 19 | self.optical_flow_conv = conv2d(2 * z_dim, 64, kernel_size=1, stride=1) 20 | 21 | self.img_deconv6 = deconv(64, 64) 22 | self.img_deconv5 = deconv(64, 32) 23 | self.img_deconv4 = deconv(162, 32) 24 | self.img_deconv3 = deconv(98, 32) 25 | self.img_deconv2 = deconv(98, 32) 26 | 27 | self.predict_optical_flow6 = predict_flow(64) 28 | self.predict_optical_flow5 = predict_flow(162) 29 | self.predict_optical_flow4 = predict_flow(98) 30 | self.predict_optical_flow3 = predict_flow(98) 31 | self.predict_optical_flow2 = predict_flow(66) 32 | 33 | self.upsampled_optical_flow6_to_5 = nn.ConvTranspose2d( 34 | 2, 2, 4, 2, 1, bias=False 35 | ) 36 | self.upsampled_optical_flow5_to_4 = nn.ConvTranspose2d( 37 | 2, 2, 4, 2, 1, bias=False 38 | ) 39 | self.upsampled_optical_flow4_to_3 = nn.ConvTranspose2d( 40 | 2, 2, 4, 2, 1, bias=False 41 | ) 42 | self.upsampled_optical_flow3_to_2 = nn.ConvTranspose2d( 43 | 2, 2, 4, 2, 1, bias=False 44 | ) 45 | 46 | self.predict_optical_flow2_mask = nn.Conv2d( 47 | 66, 1, kernel_size=3, stride=1, padding=1, bias=False 48 | ) 49 | 50 | if initailize_weights: 51 | init_weights(self.modules()) 52 | 53 | def forward(self, tiled_feat, img_out_convs): 54 | """ 55 | Predicts the optical flow and optical flow mask. 56 | 57 | Args: 58 | tiled_feat: action conditioned z (output of fusion + action network) 59 | img_out_convs: outputs of the image encoders (skip connections) 60 | """ 61 | out_img_conv1, out_img_conv2, out_img_conv3, out_img_conv4, out_img_conv5, out_img_conv6 = ( 62 | img_out_convs 63 | ) 64 | 65 | optical_flow_in_f = torch.cat([out_img_conv6, tiled_feat], 1) 66 | optical_flow_in_f2 = self.optical_flow_conv(optical_flow_in_f) 67 | optical_flow_in_feat = self.img_deconv6(optical_flow_in_f2) 68 | 69 | # predict optical flow pyramids 70 | optical_flow6 = self.predict_optical_flow6(optical_flow_in_feat) 71 | optical_flow6_up = crop_like( 72 | self.upsampled_optical_flow6_to_5(optical_flow6), out_img_conv5 73 | ) 74 | out_img_deconv5 = crop_like( 75 | self.img_deconv5(optical_flow_in_feat), out_img_conv5 76 | ) 77 | 78 | concat5 = torch.cat((out_img_conv5, out_img_deconv5, optical_flow6_up), 1) 79 | optical_flow5 = self.predict_optical_flow5(concat5) 80 | optical_flow5_up = crop_like( 81 | self.upsampled_optical_flow5_to_4(optical_flow5), out_img_conv4 82 | ) 83 | out_img_deconv4 = crop_like(self.img_deconv4(concat5), out_img_conv4) 84 | 85 | concat4 = torch.cat((out_img_conv4, out_img_deconv4, optical_flow5_up), 1) 86 | optical_flow4 = self.predict_optical_flow4(concat4) 87 | optical_flow4_up = crop_like( 88 | self.upsampled_optical_flow4_to_3(optical_flow4), out_img_conv3 89 | ) 90 | out_img_deconv3 = crop_like(self.img_deconv3(concat4), out_img_conv3) 91 | 92 | concat3 = torch.cat((out_img_conv3, out_img_deconv3, optical_flow4_up), 1) 93 | optical_flow3 = self.predict_optical_flow3(concat3) 94 | optical_flow3_up = crop_like( 95 | self.upsampled_optical_flow3_to_2(optical_flow3), out_img_conv2 96 | ) 97 | out_img_deconv2 = crop_like(self.img_deconv2(concat3), out_img_conv2) 98 | 99 | concat2 = torch.cat((out_img_conv2, out_img_deconv2, optical_flow3_up), 1) 100 | 101 | optical_flow2_unmasked = self.predict_optical_flow2(concat2) 102 | 103 | optical_flow2_mask = self.predict_optical_flow2_mask(concat2) 104 | 105 | optical_flow2 = optical_flow2_unmasked * torch.sigmoid(optical_flow2_mask) 106 | 107 | return optical_flow2, optical_flow2_mask 108 | 109 | 110 | class EeDeltaDecoder(nn.Module): 111 | def __init__(self, z_dim, 
action_dim, initailize_weights=True): 112 | """ 113 | Decodes the EE Delta 114 | """ 115 | super().__init__() 116 | 117 | self.ee_delta_decoder = nn.Sequential( 118 | nn.Linear(z_dim, 128), 119 | nn.LeakyReLU(0.1, inplace=True), 120 | nn.Linear(128, 64), 121 | nn.LeakyReLU(0.1, inplace=True), 122 | nn.Linear(64, 32), 123 | nn.LeakyReLU(0.1, inplace=True), 124 | nn.Linear(32, action_dim), 125 | ) 126 | 127 | if initailize_weights: 128 | init_weights(self.modules()) 129 | 130 | def forward(self, mm_act_feat): 131 | return self.ee_delta_decoder(mm_act_feat) 132 | -------------------------------------------------------------------------------- /multimodal/models/base_models/encoders.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.models_utils import init_weights 3 | from models.base_models.layers import CausalConv1D, Flatten, conv2d 4 | 5 | 6 | class ProprioEncoder(nn.Module): 7 | def __init__(self, z_dim, initailize_weights=True): 8 | """ 9 | Image encoder taken from selfsupervised code 10 | """ 11 | super().__init__() 12 | self.z_dim = z_dim 13 | 14 | self.proprio_encoder = nn.Sequential( 15 | nn.Linear(8, 32), 16 | nn.LeakyReLU(0.1, inplace=True), 17 | nn.Linear(32, 64), 18 | nn.LeakyReLU(0.1, inplace=True), 19 | nn.Linear(64, 128), 20 | nn.LeakyReLU(0.1, inplace=True), 21 | nn.Linear(128, 2 * self.z_dim), 22 | nn.LeakyReLU(0.1, inplace=True), 23 | ) 24 | 25 | if initailize_weights: 26 | init_weights(self.modules()) 27 | 28 | def forward(self, proprio): 29 | return self.proprio_encoder(proprio).unsqueeze(2) 30 | 31 | 32 | class ForceEncoder(nn.Module): 33 | def __init__(self, z_dim, initailize_weights=True): 34 | """ 35 | Force encoder taken from selfsupervised code 36 | """ 37 | super().__init__() 38 | self.z_dim = z_dim 39 | 40 | self.frc_encoder = nn.Sequential( 41 | CausalConv1D(6, 16, kernel_size=2, stride=2), 42 | nn.LeakyReLU(0.1, inplace=True), 43 | CausalConv1D(16, 32, kernel_size=2, stride=2), 44 | nn.LeakyReLU(0.1, inplace=True), 45 | CausalConv1D(32, 64, kernel_size=2, stride=2), 46 | nn.LeakyReLU(0.1, inplace=True), 47 | CausalConv1D(64, 128, kernel_size=2, stride=2), 48 | nn.LeakyReLU(0.1, inplace=True), 49 | CausalConv1D(128, 2 * self.z_dim, kernel_size=2, stride=2), 50 | nn.LeakyReLU(0.1, inplace=True), 51 | ) 52 | 53 | if initailize_weights: 54 | init_weights(self.modules()) 55 | 56 | def forward(self, force): 57 | return self.frc_encoder(force) 58 | 59 | 60 | class ImageEncoder(nn.Module): 61 | def __init__(self, z_dim, initailize_weights=True): 62 | """ 63 | Image encoder taken from Making Sense of Vision and Touch 64 | """ 65 | super().__init__() 66 | self.z_dim = z_dim 67 | 68 | self.img_conv1 = conv2d(3, 16, kernel_size=7, stride=2) 69 | self.img_conv2 = conv2d(16, 32, kernel_size=5, stride=2) 70 | self.img_conv3 = conv2d(32, 64, kernel_size=5, stride=2) 71 | self.img_conv4 = conv2d(64, 64, stride=2) 72 | self.img_conv5 = conv2d(64, 128, stride=2) 73 | self.img_conv6 = conv2d(128, self.z_dim, stride=2) 74 | self.img_encoder = nn.Linear(4 * self.z_dim, 2 * self.z_dim) 75 | self.flatten = Flatten() 76 | 77 | if initailize_weights: 78 | init_weights(self.modules()) 79 | 80 | def forward(self, image): 81 | # image encoding layers 82 | out_img_conv1 = self.img_conv1(image) 83 | out_img_conv2 = self.img_conv2(out_img_conv1) 84 | out_img_conv3 = self.img_conv3(out_img_conv2) 85 | out_img_conv4 = self.img_conv4(out_img_conv3) 86 | out_img_conv5 = self.img_conv5(out_img_conv4) 87 | out_img_conv6 = 
self.img_conv6(out_img_conv5) 88 | 89 | img_out_convs = ( 90 | out_img_conv1, 91 | out_img_conv2, 92 | out_img_conv3, 93 | out_img_conv4, 94 | out_img_conv5, 95 | out_img_conv6, 96 | ) 97 | 98 | # image embedding parameters 99 | flattened = self.flatten(out_img_conv6) 100 | img_out = self.img_encoder(flattened).unsqueeze(2) 101 | 102 | return img_out, img_out_convs 103 | 104 | 105 | class DepthEncoder(nn.Module): 106 | def __init__(self, z_dim, initailize_weights=True): 107 | """ 108 | Simplified Depth Encoder taken from Making Sense of Vision and Touch 109 | """ 110 | super().__init__() 111 | self.z_dim = z_dim 112 | 113 | self.depth_conv1 = conv2d(1, 32, kernel_size=3, stride=2) 114 | self.depth_conv2 = conv2d(32, 64, kernel_size=3, stride=2) 115 | self.depth_conv3 = conv2d(64, 64, kernel_size=4, stride=2) 116 | self.depth_conv4 = conv2d(64, 64, stride=2) 117 | self.depth_conv5 = conv2d(64, 128, stride=2) 118 | self.depth_conv6 = conv2d(128, self.z_dim, stride=2) 119 | 120 | self.depth_encoder = nn.Linear(4 * self.z_dim, 2 * self.z_dim) 121 | self.flatten = Flatten() 122 | 123 | if initailize_weights: 124 | init_weights(self.modules()) 125 | 126 | def forward(self, depth): 127 | # depth encoding layers 128 | out_depth_conv1 = self.depth_conv1(depth) 129 | out_depth_conv2 = self.depth_conv2(out_depth_conv1) 130 | out_depth_conv3 = self.depth_conv3(out_depth_conv2) 131 | out_depth_conv4 = self.depth_conv4(out_depth_conv3) 132 | out_depth_conv5 = self.depth_conv5(out_depth_conv4) 133 | out_depth_conv6 = self.depth_conv6(out_depth_conv5) 134 | 135 | depth_out_convs = ( 136 | out_depth_conv1, 137 | out_depth_conv2, 138 | out_depth_conv3, 139 | out_depth_conv4, 140 | out_depth_conv5, 141 | out_depth_conv6, 142 | ) 143 | 144 | # depth embedding parameters 145 | flattened = self.flatten(out_depth_conv6) 146 | depth_out = self.depth_encoder(flattened).unsqueeze(2) 147 | 148 | return depth_out, depth_out_convs 149 | -------------------------------------------------------------------------------- /multimodal/models/base_models/layers.py: -------------------------------------------------------------------------------- 1 | """Neural network layers. 2 | """ 3 | 4 | import torch.nn as nn 5 | 6 | 7 | def crop_like(input, target): 8 | if input.size()[2:] == target.size()[2:]: 9 | return input 10 | else: 11 | return input[:, :, : target.size(2), : target.size(3)] 12 | 13 | 14 | def deconv(in_planes, out_planes): 15 | return nn.Sequential( 16 | nn.ConvTranspose2d( 17 | in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False 18 | ), 19 | nn.LeakyReLU(0.1, inplace=True), 20 | ) 21 | 22 | 23 | def predict_flow(in_planes): 24 | return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False) 25 | 26 | 27 | def conv2d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 28 | """`same` convolution with LeakyReLU, i.e. output shape equals input shape. 29 | Args: 30 | in_planes (int): The number of input feature maps. 31 | out_planes (int): The number of output feature maps. 32 | kernel_size (int): The filter size. 33 | dilation (int): The filter dilation factor. 34 | stride (int): The filter stride. 
35 | """ 36 | # compute new filter size after dilation 37 | # and necessary padding for `same` output size 38 | dilated_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size 39 | same_padding = (dilated_kernel_size - 1) // 2 40 | 41 | return nn.Sequential( 42 | nn.Conv2d( 43 | in_channels, 44 | out_channels, 45 | kernel_size=kernel_size, 46 | stride=stride, 47 | padding=same_padding, 48 | dilation=dilation, 49 | bias=bias, 50 | ), 51 | nn.LeakyReLU(0.1, inplace=True), 52 | ) 53 | 54 | 55 | class View(nn.Module): 56 | def __init__(self, size): 57 | super(View, self).__init__() 58 | self.size = size 59 | 60 | def forward(self, tensor): 61 | return tensor.view(self.size) 62 | 63 | 64 | class Flatten(nn.Module): 65 | """Flattens convolutional feature maps for fc layers. 66 | """ 67 | 68 | def __init__(self): 69 | super().__init__() 70 | 71 | def forward(self, x): 72 | return x.reshape(x.size(0), -1) 73 | 74 | 75 | class CausalConv1D(nn.Conv1d): 76 | """A causal 1D convolution. 77 | """ 78 | 79 | def __init__( 80 | self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True 81 | ): 82 | self.__padding = (kernel_size - 1) * dilation 83 | 84 | super().__init__( 85 | in_channels, 86 | out_channels, 87 | kernel_size=kernel_size, 88 | stride=stride, 89 | padding=self.__padding, 90 | dilation=dilation, 91 | bias=bias, 92 | ) 93 | 94 | def forward(self, x): 95 | res = super().forward(x) 96 | if self.__padding != 0: 97 | return res[:, :, : -self.__padding] 98 | return res 99 | 100 | 101 | class ResidualBlock(nn.Module): 102 | """A simple residual block. 103 | """ 104 | 105 | def __init__(self, channels): 106 | super().__init__() 107 | 108 | self.conv1 = conv2d(channels, channels, bias=False) 109 | self.conv2 = conv2d(channels, channels, bias=False) 110 | self.bn1 = nn.BatchNorm2d(channels) 111 | self.bn2 = nn.BatchNorm2d(channels) 112 | self.act = nn.LeakyReLU(0.1, inplace=True) # nn.ReLU(inplace=True) 113 | 114 | def forward(self, x): 115 | out = self.act(x) 116 | out = self.act(self.bn1(self.conv1(out))) 117 | out = self.bn2(self.conv2(out)) 118 | return out + x 119 | -------------------------------------------------------------------------------- /multimodal/models/models_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.distributions import Normal 5 | 6 | 7 | def init_weights(modules): 8 | """ 9 | Weight initialization from original SensorFusion Code 10 | """ 11 | for m in modules: 12 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 13 | nn.init.kaiming_normal_(m.weight.data) 14 | if m.bias is not None: 15 | m.bias.data.zero_() 16 | elif isinstance(m, nn.BatchNorm2d): 17 | m.weight.data.fill_(1) 18 | m.bias.data.zero_() 19 | 20 | 21 | def sample_gaussian(m, v, device): 22 | 23 | epsilon = Normal(0, 1).sample(m.size()) 24 | z = m + torch.sqrt(v) * epsilon.to(device) 25 | 26 | return z 27 | 28 | 29 | def gaussian_parameters(h, dim=-1): 30 | 31 | m, h = torch.split(h, h.size(dim) // 2, dim=dim) 32 | v = F.softplus(h) + 1e-8 33 | return m, v 34 | 35 | 36 | def product_of_experts(m_vect, v_vect): 37 | 38 | T_vect = 1.0 / v_vect 39 | 40 | mu = (m_vect * T_vect).sum(2) * (1 / T_vect.sum(2)) 41 | var = 1 / T_vect.sum(2) 42 | 43 | return mu, var 44 | 45 | 46 | def duplicate(x, rep): 47 | 48 | return x.expand(rep, *x.shape).reshape(-1, *x.shape[1:]) 49 | 50 | 51 | def depth_deconv(in_planes, out_planes): 52 | return 
nn.Sequential( 53 | nn.Conv2d( 54 | in_planes, 16, kernel_size=3, stride=1, padding=(3 - 1) // 2, bias=True 55 | ), 56 | nn.LeakyReLU(0.1, inplace=True), 57 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=(3 - 1) // 2, bias=True), 58 | nn.LeakyReLU(0.1, inplace=True), 59 | nn.ConvTranspose2d( 60 | 16, out_planes, kernel_size=4, stride=2, padding=1, bias=True 61 | ), 62 | nn.LeakyReLU(0.1, inplace=True), 63 | ) 64 | 65 | 66 | def rescaleImage(image, output_size=128, scale=1 / 255.0): 67 | """Rescale the image in a sample to a given size. 68 | Args: 69 | output_size (tuple or int): Desired output size. If tuple, output is 70 | matched to output_size. If int, smaller of image edges is matched 71 | to output_size keeping aspect ratio the same. 72 | """ 73 | image_transform = image * scale 74 | return image_transform.transpose(1, 3).transpose(2, 3) 75 | 76 | 77 | def filter_depth(depth_image): 78 | depth_image = torch.where( 79 | depth_image > 1e-7, depth_image, torch.zeros_like(depth_image) 80 | ) 81 | return torch.where(depth_image < 2, depth_image, torch.zeros_like(depth_image)) 82 | -------------------------------------------------------------------------------- /multimodal/models/sensor_fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.models_utils import ( 4 | duplicate, 5 | gaussian_parameters, 6 | rescaleImage, 7 | product_of_experts, 8 | sample_gaussian, 9 | filter_depth, 10 | ) 11 | from models.base_models.encoders import ( 12 | ProprioEncoder, 13 | ForceEncoder, 14 | ImageEncoder, 15 | DepthEncoder, 16 | ) 17 | from models.base_models.decoders import ( 18 | OpticalFlowDecoder, 19 | EeDeltaDecoder 20 | ) 21 | 22 | 23 | class SensorFusion(nn.Module): 24 | """ 25 | # 26 | Regular SensorFusionNetwork Architecture 27 | Number of parameters: 28 | Inputs: 29 | image: batch_size x 3 x 128 x 128 30 | force: batch_size x 6 x 32 31 | proprio: batch_size x 8 32 | action: batch_size x action_dim 33 | """ 34 | 35 | def __init__( 36 | self, device, z_dim=128, action_dim=4, encoder=False, deterministic=False 37 | ): 38 | super().__init__() 39 | 40 | self.z_dim = z_dim 41 | self.encoder_bool = encoder 42 | self.device = device 43 | self.deterministic = deterministic 44 | 45 | # zero centered, 1 std normal distribution 46 | self.z_prior_m = torch.nn.Parameter( 47 | torch.zeros(1, self.z_dim), requires_grad=False 48 | ) 49 | self.z_prior_v = torch.nn.Parameter( 50 | torch.ones(1, self.z_dim), requires_grad=False 51 | ) 52 | self.z_prior = (self.z_prior_m, self.z_prior_v) 53 | 54 | # ----------------------- 55 | # Modality Encoders 56 | # ----------------------- 57 | self.img_encoder = ImageEncoder(self.z_dim) 58 | self.depth_encoder = DepthEncoder(self.z_dim) 59 | self.frc_encoder = ForceEncoder(self.z_dim) 60 | self.proprio_encoder = ProprioEncoder(self.z_dim) 61 | 62 | # ----------------------- 63 | # Action Encoders 64 | # ----------------------- 65 | self.action_encoder = nn.Sequential( 66 | nn.Linear(action_dim, 32), 67 | nn.LeakyReLU(0.1, inplace=True), 68 | nn.Linear(32, 32), 69 | nn.LeakyReLU(0.1, inplace=True), 70 | ) 71 | 72 | # ----------------------- 73 | # action fusion network 74 | # ----------------------- 75 | self.st_fusion_fc1 = nn.Sequential( 76 | nn.Linear(32 + self.z_dim, 128), nn.LeakyReLU(0.1, inplace=True) 77 | ) 78 | self.st_fusion_fc2 = nn.Sequential( 79 | nn.Linear(128, self.z_dim), nn.LeakyReLU(0.1, inplace=True) 80 | ) 81 | 82 | if deterministic: 83 | # 
----------------------- 84 | # modality fusion network 85 | # ----------------------- 86 | # 4 Total modalities each (2 * z_dim) 87 | self.fusion_fc1 = nn.Sequential( 88 | nn.Linear(4 * 2 * self.z_dim, 128), nn.LeakyReLU(0.1, inplace=True) 89 | ) 90 | self.fusion_fc2 = nn.Sequential( 91 | nn.Linear(self.z_dim, self.z_dim), nn.LeakyReLU(0.1, inplace=True) 92 | ) 93 | 94 | # ----------------------- 95 | # weight initialization 96 | # ----------------------- 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 99 | nn.init.kaiming_normal_(m.weight.data) 100 | if m.bias is not None: 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | def forward_encoder(self, vis_in, frc_in, proprio_in, depth_in, action_in): 107 | 108 | # batch size 109 | batch_dim = vis_in.size()[0] 110 | 111 | image = rescaleImage(vis_in) 112 | depth = filter_depth(depth_in) 113 | 114 | # Get encoded outputs 115 | img_out, img_out_convs = self.img_encoder(image) 116 | depth_out, depth_out_convs = self.depth_encoder(depth) 117 | frc_out = self.frc_encoder(frc_in) 118 | proprio_out = self.proprio_encoder(proprio_in) 119 | 120 | if self.deterministic: 121 | # multimodal embedding 122 | mm_f1 = torch.cat([img_out, frc_out, proprio_out, depth_out], 1).squeeze() 123 | mm_f2 = self.fusion_fc1(mm_f1) 124 | z = self.fusion_fc2(mm_f2) 125 | 126 | else: 127 | # Encoder priors 128 | mu_prior, var_prior = self.z_prior 129 | 130 | # Duplicate prior parameters for each data point in the batch 131 | mu_prior_resized = duplicate(mu_prior, batch_dim).unsqueeze(2) 132 | var_prior_resized = duplicate(var_prior, batch_dim).unsqueeze(2) 133 | 134 | # Modality Mean and Variances 135 | mu_z_img, var_z_img = gaussian_parameters(img_out, dim=1) 136 | mu_z_frc, var_z_frc = gaussian_parameters(frc_out, dim=1) 137 | mu_z_proprio, var_z_proprio = gaussian_parameters(proprio_out, dim=1) 138 | mu_z_depth, var_z_depth = gaussian_parameters(depth_out, dim=1) 139 | 140 | # Tile distribution parameters using concatonation 141 | m_vect = torch.cat( 142 | [mu_z_img, mu_z_frc, mu_z_proprio, mu_z_depth, mu_prior_resized], dim=2 143 | ) 144 | var_vect = torch.cat( 145 | [var_z_img, var_z_frc, var_z_proprio, var_z_depth, var_prior_resized], 146 | dim=2, 147 | ) 148 | 149 | # Fuse modalities mean / variances using product of experts 150 | mu_z, var_z = product_of_experts(m_vect, var_vect) 151 | 152 | # Sample Gaussian to get latent 153 | z = sample_gaussian(mu_z, var_z, self.device) 154 | 155 | if self.encoder_bool or action_in is None: 156 | if self.deterministic: 157 | return img_out, frc_out, proprio_out, depth_out, z 158 | else: 159 | return img_out_convs, img_out, frc_out, proprio_out, depth_out, z 160 | else: 161 | # action embedding 162 | act_feat = self.action_encoder(action_in) 163 | 164 | # state-action feature 165 | mm_act_f1 = torch.cat([z, act_feat], 1) 166 | mm_act_f2 = self.st_fusion_fc1(mm_act_f1) 167 | mm_act_feat = self.st_fusion_fc2(mm_act_f2) 168 | 169 | if self.deterministic: 170 | return img_out_convs, mm_act_feat, z 171 | else: 172 | return img_out_convs, mm_act_feat, z, mu_z, var_z, mu_prior, var_prior 173 | 174 | def weight_parameters(self): 175 | return [param for name, param in self.named_parameters() if "weight" in name] 176 | 177 | def bias_parameters(self): 178 | return [param for name, param in self.named_parameters() if "bias" in name] 179 | 180 | 181 | class SensorFusionSelfSupervised(SensorFusion): 182 | 
""" 183 | Regular SensorFusionNetwork Architecture 184 | Inputs: 185 | image: batch_size x 3 x 128 x 128 186 | force: batch_size x 6 x 32 187 | proprio: batch_size x 8 188 | action: batch_size x action_dim 189 | """ 190 | 191 | def __init__( 192 | self, device, z_dim=128, action_dim=4, encoder=False, deterministic=False 193 | ): 194 | 195 | super().__init__(device, z_dim, action_dim, encoder, deterministic) 196 | 197 | self.deterministic = deterministic 198 | 199 | # ----------------------- 200 | # optical flow predictor 201 | # ----------------------- 202 | self.optical_flow_decoder = OpticalFlowDecoder(z_dim) 203 | 204 | # ----------------------- 205 | # ee delta decoder 206 | # ----------------------- 207 | self.ee_delta_decoder = EeDeltaDecoder(z_dim, action_dim) 208 | 209 | # ----------------------- 210 | # pairing decoder 211 | # ----------------------- 212 | self.pair_fc = nn.Sequential(nn.Linear(self.z_dim, 1)) 213 | 214 | # ----------------------- 215 | # contact decoder 216 | # ----------------------- 217 | self.contact_fc = nn.Sequential(nn.Linear(self.z_dim, 1)) 218 | 219 | # ----------------------- 220 | # weight initialization 221 | # ----------------------- 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 224 | nn.init.kaiming_normal_(m.weight.data) 225 | if m.bias is not None: 226 | m.bias.data.zero_() 227 | elif isinstance(m, nn.BatchNorm2d): 228 | m.weight.data.fill_(1) 229 | m.bias.data.zero_() 230 | 231 | def forward( 232 | self, 233 | vis_in, 234 | frc_in, 235 | proprio_in, 236 | depth_in, 237 | action_in, 238 | ): 239 | 240 | if self.encoder_bool: 241 | # returning latent space representation if model is set in encoder mode 242 | z = self.forward_encoder(vis_in, frc_in, proprio_in, depth_in, action_in) 243 | return z 244 | 245 | elif action_in is None: 246 | z = self.forward_encoder(vis_in, frc_in, proprio_in, depth_in, None) 247 | pair_out = self.pair_fc(z) 248 | return pair_out 249 | 250 | else: 251 | if self.deterministic: 252 | img_out_convs, mm_act_feat, z = self.forward_encoder( 253 | vis_in, frc_in, proprio_in, depth_in, action_in 254 | ) 255 | else: 256 | img_out_convs, mm_act_feat, z, mu_z, var_z, mu_prior, var_prior = self.forward_encoder( 257 | vis_in, 258 | frc_in, 259 | proprio_in, 260 | depth_in, 261 | action_in, 262 | ) 263 | 264 | # ---------------- Training Objectives ---------------- 265 | 266 | # tile state-action features and append to conv map 267 | batch_dim = mm_act_feat.size(0) # batch size 268 | tiled_feat = mm_act_feat.view(batch_dim, self.z_dim, 1, 1).expand(-1, -1, 2, 2) 269 | 270 | # -------------------------------------# 271 | # Pairing / Contact / EE Delta Decoder # 272 | # -------------------------------------# 273 | pair_out = self.pair_fc(z) 274 | contact_out = self.contact_fc(mm_act_feat) 275 | ee_delta_out = self.ee_delta_decoder(mm_act_feat) 276 | 277 | # -------------------------# 278 | # Optical Flow Prediction # 279 | # -------------------------# 280 | optical_flow2, optical_flow2_mask = self.optical_flow_decoder( 281 | tiled_feat, img_out_convs 282 | ) 283 | 284 | if self.deterministic: 285 | return ( 286 | pair_out, 287 | contact_out, 288 | optical_flow2, 289 | optical_flow2_mask, 290 | ee_delta_out, 291 | z, 292 | ) 293 | else: 294 | return ( 295 | pair_out, 296 | contact_out, 297 | optical_flow2, 298 | optical_flow2_mask, 299 | ee_delta_out, 300 | z, 301 | mu_z, 302 | var_z, 303 | mu_prior, 304 | var_prior, 305 | ) 306 | 
-------------------------------------------------------------------------------- /multimodal/models/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-iprl-lab/multimodal_representation/e302c61a4b4fa884250a5688aae2775952276352/multimodal/models/tests/__init__.py -------------------------------------------------------------------------------- /multimodal/models/tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from models.base_models.layers import conv2d, CausalConv1D 9 | 10 | 11 | class TestConv2d: 12 | def test_same_shape_no_dilation(self): 13 | x = torch.randn(1, 1, 5, 5) 14 | conv = conv2d(1, 1, 3) 15 | with torch.no_grad(): 16 | out = conv(x) 17 | assert out.shape[2:] == x.shape[2:] 18 | 19 | def test_same_shape_with_dilation(self): 20 | x = torch.randn(1, 1, 5, 5) 21 | conv = conv2d(1, 1, 3, dilation=2) 22 | with torch.no_grad(): 23 | out = conv(x) 24 | assert out.shape[2:] == x.shape[2:] 25 | 26 | 27 | class TestCausalConv1d: 28 | def test_same_shape_no_dilation(self): 29 | x = torch.randn(1, 1, 6) 30 | conv1d = CausalConv1D(1, 1, 3) 31 | with torch.no_grad(): 32 | out = conv1d(x) 33 | assert out.shape[2:] == x.shape[2:] 34 | 35 | def test_same_shape_with_dilation(self): 36 | x = torch.randn(1, 1, 6) 37 | conv1d = CausalConv1D(1, 1, 3, dilation=2) 38 | with torch.no_grad(): 39 | out = conv1d(x) 40 | assert out.shape[2:] == x.shape[2:] 41 | 42 | def test_causality_no_dilation(self): 43 | stride = 1 44 | length = 6 45 | dilation = 1 46 | kernel_size = 3 47 | x = torch.randn(1, 1, length) 48 | conv1d = CausalConv1D(1, 1, kernel_size, stride, dilation, bias=False) 49 | with torch.no_grad(): 50 | actual = conv1d(x) 51 | actual = actual.numpy().squeeze() 52 | weight = conv1d.weight.detach().clone().squeeze().numpy() 53 | padding = (int((kernel_size - 1) * dilation), 0) 54 | padded_x = F.pad(x, padding).detach().squeeze().numpy() 55 | expected = [] 56 | for i in range(length): 57 | expected.append(weight @ padded_x[i : i + 3]) 58 | expected = np.asarray(expected) 59 | assert np.allclose(actual, expected) 60 | -------------------------------------------------------------------------------- /multimodal/scripts/run_all_tests.sh: -------------------------------------------------------------------------------- 1 | python -m pytest models/tests -------------------------------------------------------------------------------- /multimodal/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-iprl-lab/multimodal_representation/e302c61a4b4fa884250a5688aae2775952276352/multimodal/trainers/__init__.py -------------------------------------------------------------------------------- /multimodal/trainers/selfsupervised.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import os 10 | from tqdm import tqdm 11 | 12 | from models.sensor_fusion import SensorFusionSelfSupervised 13 | from utils import ( 14 | kl_normal, 15 | realEPE, 16 | compute_accuracy, 17 | flow2rgb, 18 | set_seeds, 19 | augment_val, 20 | ) 21 | 22 | from dataloaders import 
ProcessForce, ToTensor 23 | from dataloaders import MultimodalManipulationDataset 24 | from torch.utils.data import DataLoader 25 | from torch.utils.data.sampler import SubsetRandomSampler 26 | from torchvision import transforms 27 | 28 | 29 | class selfsupervised: 30 | def __init__(self, configs, logger): 31 | 32 | # ------------------------ 33 | # Sets seed and cuda 34 | # ------------------------ 35 | use_cuda = configs["cuda"] and torch.cuda.is_available() 36 | 37 | self.configs = configs 38 | self.logger = logger 39 | self.device = torch.device("cuda" if use_cuda else "cpu") 40 | 41 | if use_cuda: 42 | logger.print("Let's use", torch.cuda.device_count(), "GPUs!") 43 | 44 | set_seeds(configs["seed"], use_cuda) 45 | 46 | # model 47 | self.model = SensorFusionSelfSupervised( 48 | device=self.device, 49 | encoder=configs["encoder"], 50 | deterministic=configs["deterministic"], 51 | z_dim=configs["zdim"], 52 | action_dim=configs["action_dim"], 53 | ).to(self.device) 54 | 55 | self.optimizer = optim.Adam( 56 | self.model.parameters(), 57 | lr=self.configs["lr"], 58 | betas=(self.configs["beta1"], 0.999), 59 | weight_decay=0.0, 60 | ) 61 | 62 | self.deterministic = configs["deterministic"] 63 | self.encoder = configs["encoder"] 64 | 65 | # losses 66 | self.loss_ee_pos = nn.MSELoss() 67 | self.loss_contact_next = nn.BCEWithLogitsLoss() 68 | self.loss_optical_flow_mask = nn.BCEWithLogitsLoss() 69 | self.loss_reward_prediction = nn.MSELoss() 70 | self.loss_is_paired = nn.BCEWithLogitsLoss() 71 | self.loss_dynamics = nn.MSELoss() 72 | 73 | # validation set variables 74 | self.val_contact_accuracy = 0.0 75 | self.val_paired_accuracy = 0.0 76 | 77 | # test set variables 78 | self.test_flow_loss = 0.0 79 | self.test_paired_accuracy = 0.0 80 | self.test_contact_accuracy = 0.0 81 | 82 | # Weights for loss 83 | self.alpha_optical_flow = 10.0 * configs["opticalflow"] 84 | self.alpha_optical_flow_mask = 1.0 85 | self.alpha_kl = 0.05 86 | self.alpha_contact = 1.0 * configs["contact"] 87 | self.alpha_pair = 0.5 * configs["pairing"] 88 | self.alpha_ee_fut = 1.0 * configs["eedelta"] 89 | 90 | # Weights for input 91 | self.alpha_vision = configs["vision"] 92 | self.alpha_depth = configs["depth"] 93 | self.alpha_proprio = configs["proprio"] 94 | self.alpha_force = configs["force"] 95 | 96 | # Global Counts For Logging 97 | self.global_cnt = {"train": 0, "val": 0} 98 | 99 | # ------------------------ 100 | # Handles Initialization 101 | # ------------------------ 102 | if configs["load"]: 103 | self.load_model(configs["load"]) 104 | 105 | self._init_dataloaders() 106 | 107 | def train(self): 108 | 109 | for i_epoch in tqdm(range(self.configs["max_epoch"])): 110 | # --------------------------- 111 | # Train Step 112 | # --------------------------- 113 | self.logger.print("Training epoch #{}...".format(i_epoch)) 114 | self.model.train() 115 | 116 | for i_iter, sample_batched in tqdm(enumerate(self.dataloaders["val"])): 117 | 118 | t_st = time.time() 119 | self.optimizer.zero_grad() 120 | 121 | loss, mm_feat, results, image_packet = self.loss_calc(sample_batched) 122 | 123 | loss.backward() 124 | self.optimizer.step() 125 | 126 | self.record_results(loss, results, self.global_cnt["train"], t_st) 127 | 128 | if self.global_cnt["train"] % self.configs["img_record_n"] == 0: 129 | self.logger.print( 130 | "processed {} mini-batches...".format(self.global_cnt["train"]) 131 | ) 132 | self._record_image(image_packet, self.global_cnt["train"]) 133 | 134 | self.global_cnt["train"] += 1 135 | 136 | if 
self.configs["val_ratio"] != 0: 137 | self.validate(i_epoch) 138 | 139 | # --------------------------- 140 | # Save weights 141 | # --------------------------- 142 | ckpt_path = os.path.join( 143 | self.logger.log_folder, "models", "weights_itr_{}.ckpt".format(i_epoch) 144 | ) 145 | self.logger.print("checkpoint path: ", ckpt_path) 146 | self.logger.print("Saving checkpoint after epoch #{}".format(i_epoch)) 147 | 148 | torch.save(self.model.state_dict(), ckpt_path) 149 | self.logger.end_itr(ckpt_path) 150 | 151 | def validate(self, i_epoch): 152 | self.logger.print( 153 | "calculating validation results after #{} epochs".format(i_epoch) 154 | ) 155 | 156 | self.val_contact_accuracy = 0.0 157 | self.val_paired_accuracy = 0.0 158 | 159 | for i_iter, sample_batched in enumerate(self.dataloaders["val"]): 160 | self.model.eval() 161 | 162 | loss_val, mm_feat_val, results_val, image_packet_val = self.loss_calc( 163 | sample_batched 164 | ) 165 | 166 | flow_loss, contact_loss, is_paired_loss, contact_accuracy, is_paired_accuracy, ee_delta_loss, kl = ( 167 | results_val 168 | ) 169 | 170 | self.val_contact_accuracy += contact_accuracy.item() / self.val_len_data 171 | self.val_paired_accuracy += is_paired_accuracy.item() / self.val_len_data 172 | 173 | if i_iter == 0: 174 | self._record_image( 175 | image_packet_val, self.global_cnt["val"], string="val/" 176 | ) 177 | 178 | self.logger.tb.add_scalar("val/loss/optical_flow", flow_loss.item(), self.global_cnt["val"]) 179 | self.logger.tb.add_scalar("val/loss/contact", contact_loss.item(), self.global_cnt["val"]) 180 | self.logger.tb.add_scalar("val/loss/is_paired", is_paired_loss.item(), self.global_cnt["val"]) 181 | self.logger.tb.add_scalar("val/loss/kl", kl.item(), self.global_cnt["val"]) 182 | self.logger.tb.add_scalar("val/loss/total_loss", loss_val.item(), self.global_cnt["val"]) 183 | self.logger.tb.add_scalar("val/loss/ee_delta", ee_delta_loss, self.global_cnt["val"]) 184 | self.global_cnt["val"] += 1 185 | 186 | # --------------------------- 187 | # Record Epoch Level Variables 188 | # --------------------------- 189 | self.logger.tb.add_scalar( 190 | "val/accuracy/contact", self.val_contact_accuracy, self.global_cnt["val"] 191 | ) 192 | self.logger.tb.add_scalar( 193 | "val/accuracy/is_paired", self.val_paired_accuracy, self.global_cnt["val"] 194 | ) 195 | 196 | def load_model(self, path): 197 | self.logger.print("Loading model from {}...".format(path)) 198 | ckpt = torch.load(path) 199 | self.model.load_state_dict(ckpt) 200 | self.model.eval() 201 | 202 | def loss_calc(self, sampled_batched): 203 | 204 | # input data 205 | image = self.alpha_vision * sampled_batched["image"].to(self.device) 206 | force = self.alpha_force * sampled_batched["force"].to(self.device) 207 | proprio = self.alpha_proprio * sampled_batched["proprio"].to(self.device) 208 | depth = self.alpha_depth * sampled_batched["depth"].to(self.device).transpose( 209 | 1, 3 210 | ).transpose(2, 3) 211 | 212 | action = sampled_batched["action"].to(self.device) 213 | 214 | contact_label = sampled_batched["contact_next"].to(self.device) 215 | optical_flow_label = sampled_batched["flow"].to(self.device) 216 | optical_flow_mask_label = sampled_batched["flow_mask"].to(self.device) 217 | 218 | # unpaired data for sampled point 219 | unpaired_image = self.alpha_vision * sampled_batched["unpaired_image"].to( 220 | self.device 221 | ) 222 | unpaired_force = self.alpha_force * sampled_batched["unpaired_force"].to( 223 | self.device 224 | ) 225 | unpaired_proprio = self.alpha_proprio * 
sampled_batched["unpaired_proprio"].to( 226 | self.device 227 | ) 228 | unpaired_depth = self.alpha_depth * sampled_batched["unpaired_depth"].to( 229 | self.device 230 | ).transpose(1, 3).transpose(2, 3) 231 | 232 | # labels to predict 233 | gt_ee_pos_delta = sampled_batched["ee_yaw_next"].to(self.device) 234 | 235 | if self.deterministic: 236 | paired_out, contact_out, flow2, optical_flow2_mask, ee_delta_out, mm_feat = self.model( 237 | image, force, proprio, depth, action 238 | ) 239 | kl = torch.tensor([0]).to(self.device).type(torch.cuda.FloatTensor) 240 | else: 241 | paired_out, contact_out, flow2, optical_flow2_mask, ee_delta_out, mm_feat, mu_z, var_z, mu_prior, var_prior = self.model( 242 | image, force, proprio, depth, action 243 | ) 244 | kl = self.alpha_kl * torch.mean( 245 | kl_normal(mu_z, var_z, mu_prior.squeeze(0), var_prior.squeeze(0)) 246 | ) 247 | 248 | flow_loss = self.alpha_optical_flow * realEPE( 249 | flow2, optical_flow_label, self.device 250 | ) 251 | 252 | # Scene flow losses 253 | 254 | b, _, h, w = optical_flow_label.size() 255 | 256 | optical_flow_mask = nn.functional.upsample( 257 | optical_flow2_mask, size=(h, w), mode="bilinear" 258 | ) 259 | 260 | flow_mask_loss = self.alpha_optical_flow_mask * self.loss_optical_flow_mask( 261 | optical_flow_mask, optical_flow_mask_label 262 | ) 263 | 264 | contact_loss = self.alpha_contact * self.loss_contact_next( 265 | contact_out, contact_label 266 | ) 267 | 268 | ee_delta_loss = self.alpha_ee_fut * self.loss_ee_pos( 269 | ee_delta_out, gt_ee_pos_delta 270 | ) 271 | 272 | paired_loss = self.alpha_pair * self.loss_is_paired( 273 | paired_out, torch.ones(paired_out.size(0), 1).to(self.device) 274 | ) 275 | 276 | unpaired_total_losses = self.model( 277 | unpaired_image, unpaired_force, unpaired_proprio, unpaired_depth, action 278 | ) 279 | unpaired_out = unpaired_total_losses[0] 280 | unpaired_loss = self.alpha_pair * self.loss_is_paired( 281 | unpaired_out, torch.zeros(unpaired_out.size(0), 1).to(self.device) 282 | ) 283 | 284 | loss = ( 285 | contact_loss 286 | + paired_loss 287 | + unpaired_loss 288 | + ee_delta_loss 289 | + kl 290 | + flow_loss 291 | + flow_mask_loss 292 | ) 293 | 294 | contact_pred = nn.Sigmoid()(contact_out).detach() 295 | contact_accuracy = compute_accuracy(contact_pred, contact_label.detach()) 296 | 297 | paired_pred = nn.Sigmoid()(paired_out).detach() 298 | paired_accuracy = compute_accuracy( 299 | paired_pred, torch.ones(paired_pred.size()[0], 1, device=self.device) 300 | ) 301 | 302 | unpaired_pred = nn.Sigmoid()(unpaired_out).detach() 303 | unpaired_accuracy = compute_accuracy( 304 | unpaired_pred, torch.zeros(unpaired_pred.size()[0], 1, device=self.device) 305 | ) 306 | 307 | is_paired_accuracy = (paired_accuracy + unpaired_accuracy) / 2.0 308 | 309 | # logging 310 | is_paired_loss = paired_loss + unpaired_loss 311 | 312 | return ( 313 | loss, 314 | mm_feat, 315 | ( 316 | flow_loss, 317 | contact_loss, 318 | is_paired_loss, 319 | contact_accuracy, 320 | is_paired_accuracy, 321 | ee_delta_loss, 322 | kl, 323 | ), 324 | (flow2, optical_flow_label, image), 325 | ) 326 | 327 | def record_results(self, total_loss, results, global_cnt, t_st): 328 | 329 | flow_loss, contact_loss, is_paired_loss, contact_accuracy, is_paired_accuracy, ee_delta_loss, kl = ( 330 | results 331 | ) 332 | 333 | self.logger.tb.add_scalar("loss/optical_flow", flow_loss.item(), global_cnt) 334 | self.logger.tb.add_scalar("loss/contact", contact_loss.item(), global_cnt) 335 | self.logger.tb.add_scalar("loss/is_paired", 
is_paired_loss.item(), global_cnt) 336 | self.logger.tb.add_scalar("loss/kl", kl.item(), global_cnt) 337 | self.logger.tb.add_scalar("loss/total_loss", total_loss.item(), global_cnt) 338 | self.logger.tb.add_scalar("loss/ee_delta", ee_delta_loss, global_cnt) 339 | 340 | self.logger.tb.add_scalar( 341 | "accuracy/contact", contact_accuracy.item(), global_cnt 342 | ) 343 | self.logger.tb.add_scalar( 344 | "accuracy/is_paired", is_paired_accuracy.item(), global_cnt 345 | ) 346 | 347 | self.logger.tb.add_scalar("stats/iter_time", time.time() - t_st, global_cnt) 348 | 349 | def _init_dataloaders(self): 350 | 351 | filename_list = [] 352 | for file in os.listdir(self.configs["dataset"]): 353 | if file.endswith(".h5"): 354 | filename_list.append(self.configs["dataset"] + file) 355 | 356 | self.logger.print( 357 | "Number of files in multifile dataset = {}".format(len(filename_list)) 358 | ) 359 | 360 | val_filename_list = [] 361 | 362 | val_index = np.random.randint( 363 | 0, len(filename_list), int(len(filename_list) * self.configs["val_ratio"]) 364 | ) 365 | 366 | for index in val_index: 367 | val_filename_list.append(filename_list[index]) 368 | 369 | while val_index.size > 0: 370 | filename_list.pop(val_index[0]) 371 | val_index = np.where(val_index > val_index[0], val_index - 1, val_index) 372 | val_index = val_index[1:] 373 | 374 | self.logger.print("Initial finished") 375 | 376 | val_filename_list1, filename_list1 = augment_val( 377 | val_filename_list, filename_list 378 | ) 379 | 380 | self.logger.print("Listing finished") 381 | 382 | self.dataloaders = {} 383 | self.samplers = {} 384 | self.datasets = {} 385 | 386 | self.samplers["val"] = SubsetRandomSampler( 387 | range(len(val_filename_list1) * (self.configs["ep_length"] - 1)) 388 | ) 389 | self.samplers["train"] = SubsetRandomSampler( 390 | range(len(filename_list1) * (self.configs["ep_length"] - 1)) 391 | ) 392 | 393 | self.logger.print("Sampler finished") 394 | 395 | self.datasets["train"] = MultimodalManipulationDataset( 396 | filename_list1, 397 | transform=transforms.Compose( 398 | [ 399 | ProcessForce(32, "force", tanh=True), 400 | ProcessForce(32, "unpaired_force", tanh=True), 401 | ToTensor(device=self.device), 402 | ] 403 | ), 404 | episode_length=self.configs["ep_length"], 405 | training_type=self.configs["training_type"], 406 | action_dim=self.configs["action_dim"], 407 | 408 | ) 409 | 410 | self.datasets["val"] = MultimodalManipulationDataset( 411 | val_filename_list1, 412 | transform=transforms.Compose( 413 | [ 414 | ProcessForce(32, "force", tanh=True), 415 | ProcessForce(32, "unpaired_force", tanh=True), 416 | ToTensor(device=self.device), 417 | ] 418 | ), 419 | episode_length=self.configs["ep_length"], 420 | training_type=self.configs["training_type"], 421 | action_dim=self.configs["action_dim"], 422 | 423 | ) 424 | 425 | self.logger.print("Dataset finished") 426 | 427 | self.dataloaders["val"] = DataLoader( 428 | self.datasets["val"], 429 | batch_size=self.configs["batch_size"], 430 | num_workers=self.configs["num_workers"], 431 | sampler=self.samplers["val"], 432 | pin_memory=True, 433 | drop_last=True, 434 | ) 435 | self.dataloaders["train"] = DataLoader( 436 | self.datasets["train"], 437 | batch_size=self.configs["batch_size"], 438 | num_workers=self.configs["num_workers"], 439 | sampler=self.samplers["train"], 440 | pin_memory=True, 441 | drop_last=True, 442 | ) 443 | 444 | self.len_data = len(self.dataloaders["train"]) 445 | self.val_len_data = len(self.dataloaders["val"]) 446 | 447 | self.logger.print("Finished 
setting up data") 448 | 449 | def _record_image(self, image_packet, global_cnt, string=None): 450 | 451 | if string is None: 452 | string = "" 453 | 454 | flow2, flow_label, image = image_packet 455 | image_index = 0 456 | 457 | b, c, h, w = flow_label.size() 458 | 459 | upsampled_flow = nn.functional.upsample(flow2, size=(h, w), mode="bilinear") 460 | upsampled_flow = upsampled_flow.cpu().detach().numpy() 461 | orig_image = image[image_index].cpu().numpy() 462 | 463 | orig_flow = flow2rgb( 464 | flow_label[image_index].cpu().detach().numpy(), max_value=None 465 | ) 466 | pred_flow = flow2rgb(upsampled_flow[image_index], max_value=None) 467 | 468 | concat_image = np.concatenate([orig_image, orig_flow, pred_flow], 1) 469 | 470 | concat_image = concat_image * 255 471 | concat_image = concat_image.astype(np.uint8) 472 | concat_image = concat_image.transpose(2, 0, 1) 473 | 474 | self.logger.tb.add_image(string + "predicted_flow", concat_image, global_cnt) 475 | -------------------------------------------------------------------------------- /multimodal/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import random 5 | import copy 6 | import math 7 | from tqdm import tqdm 8 | 9 | 10 | def detach_var(var): 11 | """Detaches a var from torch 12 | 13 | Args: 14 | var (torch var): Torch variable that requires grad 15 | 16 | Returns: 17 | TYPE: numpy array 18 | """ 19 | return var.cpu().detach().numpy() 20 | 21 | 22 | def set_seeds(seed, use_cuda): 23 | """Set Seeds 24 | 25 | Args: 26 | seed (int): Sets the seed for numpy, torch and random 27 | """ 28 | random.seed(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | 32 | if use_cuda: 33 | torch.cuda.manual_seed(seed) 34 | else: 35 | torch.manual_seed(seed) 36 | 37 | 38 | def quaternion_to_euler(x, y, z, w): 39 | 40 | # t0 = +2.0 * (w * x + y * z) 41 | # t1 = +1.0 - 2.0 * (x * x + y * y) 42 | # X = np.arctan2(t0, t1) 43 | 44 | # t2 = +2.0 * (w * y - z * x) 45 | # t2 = +1.0 if t2 > +1.0 else t2 46 | # t2 = -1.0 if t2 < -1.0 else t2 47 | # Y = np.arcsin(t2) 48 | 49 | t3 = +2.0 * (w * z + x * y) 50 | t4 = +1.0 - 2.0 * (y * y + z * z) 51 | Z = np.arctan2(t3, t4) 52 | 53 | Z = -Z - np.pi / 2 54 | 55 | return Z 56 | 57 | 58 | def compute_accuracy(pred, target): 59 | pred_1 = torch.where(pred > 0.5, torch.ones_like(pred), torch.zeros_like(pred)) 60 | target_1 = torch.where(target > 0.5, torch.ones_like(pred), torch.zeros_like(pred)) 61 | batch_size = target.size()[0] * 1.0 62 | 63 | num_correct = 1.0 * torch.where( 64 | pred_1 == target_1, torch.ones_like(pred), torch.zeros_like(pred) 65 | ).sum().float() 66 | 67 | accuracy = num_correct / batch_size 68 | return accuracy 69 | 70 | 71 | def rescaleImage(image, output_size=128, scale=1 / 255.0): 72 | """Scale pixel values and convert a batch of images from NHWC to NCHW. 73 | Args: 74 | image (tensor): batch of images with shape (N, H, W, C). 75 | scale (float): factor applied to pixel values (default 1 / 255). 76 | output_size (int): unused by this function. 
77 | """ 78 | image_transform = image * scale 79 | # torch.from_numpy(img.transpose((0, 3, 1, 2))).float() 80 | return image_transform.transpose(1, 3).transpose(2, 3) 81 | 82 | 83 | def log_normal(x, m, v): 84 | 85 | log_prob = -((x - m) ** 2 / (2 * v)) - 0.5 * torch.log(2 * math.pi * v) 86 | 87 | return log_prob 88 | 89 | 90 | def kl_normal(qm, qv, pm, pv): 91 | element_wise = 0.5 * ( 92 | torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1 93 | ) 94 | kl = element_wise.sum(-1) 95 | return kl 96 | 97 | 98 | def augment_val(val_filename_list, filename_list): 99 | 100 | filename_list1 = copy.deepcopy(filename_list) 101 | val_filename_list1 = [] 102 | 103 | for name in tqdm(val_filename_list): 104 | filename = name[:-8] 105 | found = True 106 | 107 | if filename[-2] == "_": 108 | file_number = int(filename[-1]) 109 | filename = filename[:-1] 110 | else: 111 | file_number = int(filename[-2:]) 112 | filename = filename[:-2] 113 | 114 | if file_number < 10: 115 | comp_number = 19 116 | filename1 = filename + str(comp_number) + "_1000.h5" 117 | while (filename1 not in filename_list1) and ( 118 | filename1 not in val_filename_list1 119 | ): 120 | comp_number += -1 121 | filename1 = filename + str(comp_number) + "_1000.h5" 122 | if comp_number < 0: 123 | found = False 124 | break 125 | else: 126 | comp_number = 0 127 | filename1 = filename + str(comp_number) + "_1000.h5" 128 | while (filename1 not in filename_list1) and ( 129 | filename1 not in val_filename_list1 130 | ): 131 | comp_number += 1 132 | filename1 = filename + str(comp_number) + "_1000.h5" 133 | if comp_number > 19: 134 | found = False 135 | break 136 | 137 | if found: 138 | if filename1 in filename_list1: 139 | filename_list1.remove(filename1) 140 | 141 | if filename1 not in val_filename_list: 142 | val_filename_list1.append(filename1) 143 | 144 | val_filename_list1 += val_filename_list 145 | 146 | return val_filename_list1, filename_list1 147 | 148 | 149 | def flow2rgb(flow_map, max_value=None): 150 | global args 151 | _, h, w = flow_map.shape 152 | # flow_map[:,(flow_map[0] == 0) & (flow_map[1] == 0)] = float('nan') 153 | rgb_map = np.ones((h, w, 3)).astype(np.float32) 154 | if max_value is not None: 155 | normalized_flow_map = flow_map / max_value 156 | else: 157 | normalized_flow_map = flow_map / (np.abs(flow_map).max()) 158 | rgb_map[:, :, 0] += normalized_flow_map[0] 159 | rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1]) 160 | rgb_map[:, :, 2] += normalized_flow_map[1] 161 | return rgb_map.clip(0, 1) 162 | 163 | 164 | def scene_flow2rgb(flow_map): 165 | global args 166 | 167 | flow_map = np.where(flow_map > 1e-6, flow_map, np.zeros_like(flow_map)) 168 | 169 | indices1 = np.nonzero(flow_map[0, :, :]) 170 | indices2 = np.nonzero(flow_map[1, :, :]) 171 | indices3 = np.nonzero(flow_map[2, :, :]) 172 | 173 | normalized_flow_map = np.zeros_like(flow_map) 174 | 175 | divisor_1 = 0 176 | divisor_2 = 0 177 | divisor_3 = 0 178 | 179 | if np.array(indices1).size > 0: 180 | divisor_1 = ( 181 | flow_map[0, :, :][indices1].max() - flow_map[0, :, :][indices1].min() 182 | ) 183 | 184 | if np.array(indices2).size > 0: 185 | divisor_2 = ( 186 | flow_map[1, :, :][indices2].max() - flow_map[1, :, :][indices2].min() 187 | ) 188 | 189 | if np.array(indices3).size > 0: 190 | divisor_3 = ( 191 | flow_map[2, :, :][indices3].max() - flow_map[2, :, :][indices3].min() 192 | ) 193 | 194 | if divisor_1 > 0: 195 | normalized_flow_map[0, :, :][indices1] = ( 196 | flow_map[0, :, :][indices1] - flow_map[0, :, 
:][indices1].min() 197 | ) / divisor_1 198 | 199 | if divisor_2 > 0: 200 | normalized_flow_map[1, :, :][indices2] = ( 201 | flow_map[1, :, :][indices2] - flow_map[1, :, :][indices2].min() 202 | ) / divisor_2 203 | 204 | if divisor_3 > 0: 205 | normalized_flow_map[2, :, :][indices3] = ( 206 | flow_map[2, :, :][indices3] - flow_map[2, :, :][indices3].min() 207 | ) / divisor_3 208 | 209 | return normalized_flow_map 210 | 211 | 212 | def point_cloud2rgb(flow_map): 213 | global args 214 | 215 | flow_map = np.where(flow_map > 5e-4, flow_map, np.zeros_like(flow_map)) 216 | 217 | flow_map = np.tile( 218 | np.expand_dims(np.sqrt(np.sum(np.square(flow_map), axis=0)), axis=0), (3, 1, 1) 219 | ) 220 | return flow_map 221 | 222 | 223 | def EPE(input_flow, target_flow, device, sparse=False, mean=True): 224 | # torch.cuda.init() 225 | 226 | EPE_map = torch.norm(target_flow.cpu() - input_flow.cpu(), 2, 1) 227 | batch_size = EPE_map.size(0) 228 | if sparse: 229 | # invalid flow is defined with both flow coordinates to be exactly 0 230 | mask = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0) 231 | 232 | EPE_map = EPE_map[~mask.data] 233 | if mean: 234 | return EPE_map.mean().to(device) 235 | else: 236 | return (EPE_map.sum() / batch_size).to(device) 237 | 238 | 239 | def realEPE(output, target, device, sparse=False): 240 | b, _, h, w = target.size() 241 | 242 | upsampled_output = nn.functional.upsample(output, size=(h, w), mode="bilinear") 243 | return EPE(upsampled_output, target, device, sparse, mean=True) 244 | 245 | 246 | def realAAE(output, target, device, sparse=False): 247 | b, _, h, w = target.size() 248 | upsampled_output = nn.functional.upsample(output, size=(h, w), mode="bilinear") 249 | return AAE(upsampled_output, target, device, sparse, mean=True) 250 | 251 | 252 | def AAE(input_flow, target_flow, device, sparse=False, mean=True): 253 | b, _, h, w = target_flow.size() 254 | ones = torch.ones([b, 1, h, w]) 255 | target = torch.cat((target_flow.cpu(), ones), 1) 256 | inp = torch.cat((input_flow.cpu(), ones), 1) 257 | target = target.permute(0, 2, 3, 1).contiguous().view(b * h * w, -1) 258 | inp = inp.permute(0, 2, 3, 1).contiguous().view(b * h * w, -1) 259 | 260 | target = target.div(torch.norm(target, dim=1, keepdim=True).expand_as(target)) 261 | inp = inp.div(torch.norm(inp, dim=1, keepdim=True).expand_as(inp)) 262 | 263 | dot_prod = torch.bmm((target.view(b * h * w, 1, -1)), inp.view(b * h * w, -1, 1)) 264 | AAE_map = torch.acos(torch.clamp(dot_prod, -1, 1)) 265 | 266 | return AAE_map.mean().to(device) 267 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | catkin-pkg==0.4.10 7 | certifi==2018.11.29 8 | cffi==1.13.2 9 | chardet==3.0.4 10 | cycler==0.10.0 11 | Cython==0.29.6 12 | DateTime==4.3 13 | decorator==4.3.0 14 | defusedxml==0.5.0 15 | e==1.4.5 16 | entrypoints==0.3 17 | envs==1.3 18 | future==0.17.1 19 | gast==0.2.2 20 | gitdb2==2.0.6 21 | GitPython==3.0.4 22 | glfw==1.9.1 23 | grpcio==1.20.0 24 | gym==0.10.9 25 | h5py==2.9.0 26 | hjson==3.0.1 27 | idna==2.8 28 | imageio==2.6.1 29 | importlib-metadata==1.5.0 30 | ipdb==0.12.2 31 | ipykernel==5.1.0 32 | ipython==7.2.0 33 | ipython-genutils==0.2.0 34 | ipywidgets==7.4.2 35 | jedi==0.13.2 36 | Jinja2==2.10 37 | jsonschema==2.6.0 38 | jupyter==1.0.0 39 | jupyter-client==5.2.4 40 | jupyter-console==6.0.0 41 | 
jupyter-core==4.4.0 42 | Keras-Applications==1.0.7 43 | Keras-Preprocessing==1.0.9 44 | kiwisolver==1.0.1 45 | lockfile==0.12.2 46 | lxml==4.4.2 47 | Markdown==3.1 48 | MarkupSafe==1.1.0 49 | matplotlib==3.0.2 50 | mistune==0.8.4 51 | mock==2.0.0 52 | more-itertools==8.2.0 53 | mpmath==1.1.0 54 | nbconvert==5.4.0 55 | nbformat==4.4.0 56 | notebook==5.7.4 57 | numpy==1.16.0 58 | packaging==20.1 59 | pandas==0.24.2 60 | pandocfilters==1.4.2 61 | parso==0.3.1 62 | pathlib2==2.3.5 63 | pbr==5.1.3 64 | pexpect==4.6.0 65 | pickleshare==0.7.5 66 | Pillow==5.4.1 67 | pluggy==0.13.1 68 | prometheus-client==0.5.0 69 | prompt-toolkit==2.0.7 70 | protobuf==3.6.1 71 | ptyprocess==0.6.0 72 | py==1.8.1 73 | pycodestyle==2.5.0 74 | pycparser==2.19 75 | pyflakes==2.1.1 76 | pyglet==1.3.2 77 | Pygments==2.3.1 78 | pyparsing==2.3.1 79 | pyquaternion==0.9.5 80 | pytest==5.3.5 81 | python-dateutil==2.7.5 82 | pytz==2018.9 83 | PyYAML==3.13 84 | pyzmq==17.1.2 85 | qtconsole==4.4.3 86 | requests==2.21.0 87 | scipy==1.2.0 88 | seaborn==0.9.0 89 | Send2Trash==1.5.0 90 | six==1.12.0 91 | smmap2==2.0.5 92 | snakeviz==2.0.1 93 | sympy==1.3 94 | tensorboard==1.12.2 95 | tensorboardX==1.6 96 | tensorflow-estimator==1.13.0 97 | termcolor==1.1.0 98 | terminado==0.8.1 99 | testpath==0.4.2 100 | torch==1.1.0 101 | torchsummary==1.5.1 102 | torchvision==0.3.0 103 | tornado==5.1.1 104 | tqdm==4.36.1 105 | traitlets==4.3.2 106 | urllib3==1.24.1 107 | wcwidth==0.1.7 108 | --------------------------------------------------------------------------------
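The helper functions in multimodal/utils.py are small enough to exercise on their own. What follows is a minimal usage sketch, not part of the repository: it assumes the multimodal/ directory is on PYTHONPATH (matching the flat "from dataloaders import ..." style used in the trainer above), and the tensor shapes are illustrative assumptions chosen to satisfy the functions' own indexing (flow maps as [B, 2, H, W], diagonal-Gaussian parameters of matching shape).

import torch
from utils import compute_accuracy, kl_normal, realEPE, flow2rgb  # assumes multimodal/ is on PYTHONPATH

# Accuracy of a thresholded binary prediction, as used for the contact / pairing heads.
pred = torch.sigmoid(torch.randn(8, 1))        # fake probabilities
target = torch.randint(0, 2, (8, 1)).float()   # fake 0/1 labels
print(compute_accuracy(pred, target))

# kl_normal sums the per-dimension KL divergence of two diagonal Gaussians over the
# last axis; identical distributions give zero.
mu_q, var_q = torch.zeros(8, 16), torch.ones(8, 16)
mu_p, var_p = torch.zeros(8, 16), torch.ones(8, 16)
print(kl_normal(mu_q, var_q, mu_p, var_p).mean())

# realEPE upsamples a coarse flow prediction to the label resolution and returns the
# mean end-point error; flow2rgb renders one flow map as an RGB array clipped to [0, 1].
coarse_flow = torch.randn(8, 2, 32, 32)        # network output at reduced resolution
gt_flow = torch.randn(8, 2, 128, 128)          # full-resolution flow label
print(realEPE(coarse_flow, gt_flow, torch.device("cpu")))
rgb = flow2rgb(gt_flow[0].numpy(), max_value=None)   # shape (128, 128, 3)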