├── LICENSE ├── README.md ├── analysis.py ├── assets └── example.gif ├── dataset_generation.py ├── dataset_generation_const.py ├── figures.py ├── main.py ├── model.py ├── nets.py ├── requirements.txt ├── segmentation_metrics.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Wayve Technologies Limited 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Object-Centric Segmentation (SOCS) 2 | 3 | Code repository for the paper [Linking vision and motion for self-supervised object-centric perception](http://arxiv.org/abs/2307.07147). 
4 | 5 | ![gif](./assets/example.gif) 6 | 7 | ## Installation 8 | 9 | A virtual environment with Python 3.9 is recommended. 10 | 11 |
pip install -r requirements.txt
12 | 13 | ## Dataset generation 14 | We ran our experiments with the Waymo Open Perception dataset v1.4, which didn't provide utilities for loading contiguous frame sequences. Therefore, we extract frame sequences for the train and val splits with the `dataset_generation.py` script. First, download the [dataset](https://waymo.com/intl/en_us/open/download/) so that the local file structure looks like this (the script reads from `waymo_open_raw` by default; override it with `--in_dir`): 15 | 16 |
waymo_open_raw
17 |     train
18 |         segment-10017090168044687777_6380_000_6400_000_with_camera_labels.tfrecord
19 |         ...
20 |     val
21 |         segment-10203656353524179475_7625_000_7645_000_with_camera_labels.tfrecord
22 |         ...
23 | 24 | Then, install the Waymo Open dataset utilities from the [official repo](https://github.com/waymo-research/waymo-open-dataset). (It's recommended to do this in a separate virtual environment.) Run the data generation script for the train and val splits: 25 | 26 |
python dataset_generation.py train
27 | python dataset_generation.py val
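The frame-selection step at the heart of `dataset_generation.py` can be sketched on bare timestamps. This is a simplified, hypothetical re-implementation of the windowing in `make_train_seqs` that works on timestamps in microseconds instead of Waymo `Frame` protos. `SEQ_LEN`, `STRIDE`, and the ±0.05 s tolerance mirror the repository's constants. It assumes Waymo's 10 Hz frame rate (1e5 µs between frames), and it omits the behavioral-cloning waypoint collection and the start-ID deduplication that the real script performs.

```python
from typing import List

SEQ_LEN = 8            # frames per sequence, as in dataset_generation_const.py
STRIDE = 2             # target spacing between kept frames, in units of 1e5 us (0.1 s)
TOLERANCE_US = 0.05e6  # accept spacings within +/- 0.05 s of the target


def split_into_sequences(timestamps_us: List[int]) -> List[List[int]]:
    """Greedily group time-ordered timestamps (microseconds) into SEQ_LEN-frame sequences.

    Simplified sketch of the frame-selection logic in make_train_seqs: frames that
    arrive before the next target time are skipped, while a gap that is too large
    (or a jump backwards in time) restarts the sequence at the current frame.
    """
    target = STRIDE * 1e5
    sequences: List[List[int]] = []
    current: List[int] = []
    for t in timestamps_us:
        if not current:
            current = [t]  # start a new sequence at this frame
            continue
        delta = t - current[-1]
        if target - TOLERANCE_US < delta < target + TOLERANCE_US:
            current.append(t)
            if len(current) == SEQ_LEN:
                sequences.append(current)
                current = []
        elif delta < 0 or delta > target + TOLERANCE_US:
            current = [t]  # expected frame is missing: start over from here
        # otherwise the frame is too early; keep waiting for the next target time
    return sequences


# A 10 Hz stream (frames 1e5 us apart) with frame 10 dropped
frames = [i * 100_000 for i in range(40) if i != 10]
print(split_into_sequences(frames))
# -> [[1100000, 1300000, 1500000, 1700000, 1900000, 2100000, 2300000, 2500000]]
```

At 10 Hz, `STRIDE = 2` keeps every other frame, so one sequence spans (SEQ_LEN - 1) × 0.2 s = 1.4 s. This equals `TIMESTAMP_NORMALIZATION_FACTOR` (1.4e6 µs), so the saved, normalized timestamps fall roughly in [0, 1]. `DISTANCE_NORMALIZATION_FACTOR` is the distance covered at 48 km/h over the same 1.4 s, about 18.7 m.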
28 | 29 | You should end up with 40,000 train sequences and 208 val sequences. Generation is slow: for the train split, the script makes repeated passes over the dataset until it stops finding new sequences or reaches 40,000. For new code, it may be preferable to use the dataloading tools introduced in the 2.0 release of the dataset. 30 | 31 | ## Training 32 | 33 | Simply run: 34 |
python main.py --behavioral_cloning_task
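This README does not show how `main.py` applies `--downsample_factor`. The following is therefore only a hypothetical sketch of the general idea. It assumes the factor means "decode one randomly chosen pixel in every `downsample_factor`", samples a random subset of flat pixel indices across the whole frame sequence, and queries the decoder only at those indices. `sample_decode_pixels` is an illustrative name, not a function from this repository.

```python
import numpy as np


def sample_decode_pixels(num_frames: int, height: int, width: int,
                         downsample_factor: int, rng: np.random.Generator) -> np.ndarray:
    """Pick a random 1/downsample_factor subset of pixels to decode (hypothetical sketch).

    Returns sorted flat indices into a (num_frames, height, width) volume. A decoder
    queried only at these indices stores activations for num_keep pixels instead
    of all of them, which is where the memory saving would come from.
    """
    total = num_frames * height * width
    num_keep = max(1, total // downsample_factor)
    return np.sort(rng.choice(total, size=num_keep, replace=False))


# 8 frames x 3 cameras of the 96 x 224 images produced by dataset_generation.py
rng = np.random.default_rng(0)
idx = sample_decode_pixels(8 * 3, 96, 224, downsample_factor=16, rng=rng)
frame, row, col = np.unravel_index(idx, (8 * 3, 96, 224))
print(idx.size)  # 32256 of 516096 pixels
```

Sampling without replacement over the flattened volume keeps the decoded pixel count fixed for every batch, which makes memory use predictable.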
35 | 36 | In addition to the batch size, memory consumption depends on the `--downsample_factor` flag (what fraction of pixels are decoded for each sequence in the training batch). 37 | 38 | ## Inference 39 | 40 | Quantitative metrics and video clips of the qualitative results can be generated using the `analysis.py` script. By default it runs inference on a single random train and val sequence. To generate complete validation metrics, run: 41 | 42 |
python analysis.py example_logdir/version_0 --split val --num_seq_to_analyze 208 --num_seq_to_plot 10 --gpu 0
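`analysis.py` pickles its per-sequence metrics into the log directory as `metrics{name}_{train_step}.pkl`. Without `--name`, the file is `metrics__{train_step}.pkl`. The pickle holds a dict containing `train_step`, plus `train_metrics`/`val_metrics` (lists of per-sequence values keyed by metric name, e.g. `ari`, `seq_ari`) and the matching `*_seq_indices`. A minimal sketch for re-summarizing those files later mirrors the NaN-aware mean and standard deviation the script prints. `summarize_metrics` is a hypothetical helper, not part of the repository.

```python
import glob
import os
import pickle

import numpy as np


def summarize_metrics(log_dir: str, split: str = 'val') -> dict:
    """Average every metric in each metrics_*.pkl that analysis.py wrote to log_dir.

    Returns {file name: {metric: (nan-mean, nan-std)}}, skipping files that did not
    evaluate `split` and metrics that collected no values.
    """
    summary = {}
    for path in sorted(glob.glob(os.path.join(log_dir, 'metrics_*.pkl'))):
        with open(path, 'rb') as f:
            results = pickle.load(f)
        metrics = results.get(f'{split}_metrics')
        if metrics is None:
            continue  # this run did not analyze the requested split
        summary[os.path.basename(path)] = {
            key: (float(np.nanmean(vals)), float(np.nanstd(vals)))
            for key, vals in metrics.items()
            if len(vals) > 0  # skip metrics with no values for this split
        }
    return summary


if __name__ == '__main__':
    for fname, stats in summarize_metrics('example_logdir/version_0').items():
        print(fname, stats)
```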
43 | 44 | The GPU memory requirements can be reduced by setting the `--parallel_pix` flag to a smaller value. Additional figures in the paper were generated with the functions in `figures.py`. 45 | 46 | ## Citation 47 | If you find our work helpful, please cite our paper: 48 | 49 |
@article{stocking2023linking,
50 |   title={Linking vision and motion for self-supervised object-centric perception},
51 |   author={Stocking, Kaylene C and Murez, Zak and Badrinarayanan, Vijay and Shotton, Jamie and Kendall, Alex and Tomlin, Claire and Burgess, Christopher P},
52 |   journal={arXiv preprint arXiv:2307.07147},
53 |   year={2023}
54 | }
55 | -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import imageio 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | import pickle 7 | import torch 8 | import yaml 9 | 10 | from pytorch_lightning import Trainer 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | from main import InferenceDataset 15 | from model import SOCS 16 | from util import parse_train_step, get_checkpoint_path 17 | 18 | PLOT_CHOICES = ['ground_truth_rgb', 19 | 'ground_truth_seg', 20 | 'greedy_pred_rgb', 21 | 'mixture_pred_rgb', 22 | 'pred_seg', 23 | 'pred_seg_foreground', 24 | 'pixel_score'] 25 | PLOT_CHOICES = { key: idx for (idx, key) in enumerate(PLOT_CHOICES) } 26 | 27 | def render_fig(fig): 28 | """Renders a figure into an RGB image.""" 29 | canvas = fig.canvas 30 | canvas.draw() 31 | image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 32 | w, h = canvas.get_width_height() 33 | return image.reshape([h, w, 3]) 34 | 35 | def get_hparams(logdir): 36 | with open(os.path.join(logdir, 'hparams.yaml'), 'r') as f: 37 | hparams = yaml.safe_load(f) 38 | return hparams 39 | 40 | def plot_frame_sequence_from_single_batch(ckpt, batch, results, plot_types, fig_width=10): 41 | img_dims = tuple(batch['decode_dims']) 42 | num_rows = len(plot_types) 43 | if ckpt.hparams['cameras'] is not None: 44 | cameras = ckpt.hparams['cameras'] 45 | else: 46 | cameras = [1] 47 | num_cameras = len(cameras) 48 | seq_len = ckpt.hparams['sequence_len'] 49 | 50 | fig_imgs = [] 51 | img_seq_dims = (seq_len, num_cameras, img_dims[1], img_dims[2]) 52 | img_seq = (batch['img_seq'] * 255).astype('uint8').reshape(img_seq_dims + (3,)) 53 | obj_weights = results['per_object_weights'] 54 | fig_height_width_ratio = (img_dims[1] / img_dims[2]) * (num_rows / num_cameras) 55 | 56 | for frame in range(seq_len): 57 | (f, 
axes) = plt.subplots(num_rows, num_cameras, figsize=fig_width*np.array([1, fig_height_width_ratio])) 58 | # Make sure axes have 2 dims even in case where only 1 row and/or camera 59 | axes = axes.reshape(num_rows, num_cameras) 60 | for (row, plot_type) in enumerate(plot_types): 61 | for cam in range(num_cameras): 62 | frame_idx = np.ravel_multi_index((frame, cam), (seq_len, num_cameras)) 63 | if plot_type == PLOT_CHOICES['ground_truth_rgb']: 64 | im = img_seq[frame, cam] 65 | elif plot_type == PLOT_CHOICES['ground_truth_seg']: 66 | im = ckpt.show_ground_truth_seg(batch['instance_oh'], img_dims, idx=frame_idx) 67 | elif plot_type == PLOT_CHOICES['mixture_pred_rgb']: 68 | im = ckpt.reconstruct_image(results['preds'], img_dims, idx=frame_idx) 69 | elif plot_type == PLOT_CHOICES['greedy_pred_rgb']: 70 | im = ckpt.reconstruct_image(results['greedy_preds'], img_dims, idx=frame_idx) 71 | elif plot_type == PLOT_CHOICES['pred_seg']: 72 | im = ckpt.show_object_masks(obj_weights, img_dims, idx=frame_idx) 73 | elif plot_type == PLOT_CHOICES['pred_seg_foreground']: 74 | foreground_seg = batch['instance_mask'].reshape(img_seq_dims) 75 | im = ckpt.show_object_masks_foreground(obj_weights, foreground_seg, img_dims, idx=frame_idx) 76 | elif plot_type == PLOT_CHOICES['pixel_score']: 77 | im = ckpt.show_pixel_scores(obj_weights, batch['instance_oh'], img_dims, idx=frame_idx) 78 | axes[row, cam].imshow(im) 79 | axes[row, cam].set_axis_off() 80 | 81 | plt.tight_layout(pad=0, h_pad=0.5) 82 | fig_imgs.append(render_fig(f)) 83 | plt.close(f) 84 | 85 | return fig_imgs 86 | 87 | def plot_seg_sequence(ckpt, batch, results, timepts, cam=1): 88 | (TC, H, W, _) = batch['img_seq'].shape 89 | if 'cameras' in ckpt.hparams and ckpt.hparams['cameras'] is not None: 90 | cameras = ckpt.hparams['cameras'] 91 | else: 92 | cameras = [1] 93 | num_cameras = len(cameras) 94 | seq_len = ckpt.hparams['sequence_len'] 95 | img_dims = (TC, H, W) 96 | 97 | (fig, axes) = plt.subplots(2, 
len(timepts)) 98 | img_seq = (batch['img_seq'] * 255).astype('uint8') 99 | obj_weights = results['per_object_weights'] 100 | for (i, t) in enumerate(timepts): 101 | frame_idx = np.ravel_multi_index((t, cam), (seq_len, num_cameras)) 102 | axes[0, i].imshow(img_seq[frame_idx]) 103 | seg = ckpt.show_object_masks(obj_weights, img_dims, idx=frame_idx) 104 | axes[1, i].imshow(seg) 105 | 106 | plt.tight_layout(pad=0, h_pad=0.5) 107 | return fig 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('log_root', help='Path to log directory or specific checkpoint') 112 | parser.add_argument('--name', default=None) 113 | parser.add_argument('--data_root', default=None) 114 | parser.add_argument('--split', default='both', choices=['train', 'val', 'both']) 115 | parser.add_argument('--idx', type=int, default=None, nargs='+') 116 | parser.add_argument('--idx_file', default=None) 117 | parser.add_argument('--num_seq_to_analyze', type=int, default=1) 118 | parser.add_argument('--num_seq_to_plot', type=int, default=1) 119 | parser.add_argument('--num_train_seq', type=int, default=40000) 120 | parser.add_argument('--num_val_seq', type=int, default=208) 121 | parser.add_argument('--video_format', default='both', choices=['gif', 'mp4', 'both']) 122 | parser.add_argument('--gpu', type=int, default=None, nargs='+') 123 | parser.add_argument('--parallel_pix', type=int, default=10000, 124 | help='Number of pixels to decode in each pass. 
Larger values use more memory but require fewer passes.') 125 | parser.add_argument('--num_workers', type=int, default=4) 126 | parser.add_argument('--plot_types', 127 | default=['ground_truth_rgb', 'mixture_pred_rgb', 'pred_seg'], 128 | nargs='+', choices=PLOT_CHOICES.keys()) 129 | 130 | args = parser.parse_args() 131 | 132 | if args.log_root.endswith('.ckpt'): 133 | checkpoint_path = args.log_root 134 | checkpoint_fname = os.path.basename(checkpoint_path) 135 | log_dir = os.path.dirname(checkpoint_path) 136 | train_step = parse_train_step(checkpoint_fname) 137 | else: 138 | checkpoint_dir = os.path.join(args.log_root, 'checkpoints') 139 | (checkpoint_path, train_step) = get_checkpoint_path(checkpoint_dir) 140 | log_dir = args.log_root 141 | 142 | print(f'Loading checkpoint: {checkpoint_path}') 143 | ckpt = SOCS.load_from_checkpoint(checkpoint_path) 144 | ckpt.inference_parallel_pixels = args.parallel_pix 145 | 146 | if args.idx is not None: 147 | train_indices = args.idx 148 | val_indices = args.idx 149 | elif args.idx_file is not None: 150 | with open(args.idx_file, 'r') as f: 151 | indices = yaml.safe_load(f) 152 | train_indices = indices['train'] 153 | val_indices = indices['val'] 154 | if train_indices is None: 155 | train_indices = [] 156 | if val_indices is None: 157 | val_indices = [] 158 | if args.name is None: 159 | args.name = os.path.basename(args.idx_file).split('.')[0] 160 | else: 161 | train_indices = np.random.choice(range(args.num_train_seq), args.num_seq_to_analyze, replace=False) 162 | val_indices = np.random.choice(range(args.num_val_seq), args.num_seq_to_analyze, replace=False) 163 | 164 | if (args.split == 'both' or args.split == 'train') and len(train_indices) > 0: 165 | do_train = True 166 | else: 167 | do_train = False 168 | if (args.split == 'both' or args.split == 'val') and len(val_indices) > 0: 169 | do_val = True 170 | else: 171 | do_val = False 172 | 173 | if args.data_root is None: 174 | data_root = 
ckpt.hparams['dataset_root']
175 | else: 176 | data_root = args.data_root 177 | 178 | img_dim_hw = ckpt.hparams['img_dim_hw'] 179 | 180 | if 'cameras' in ckpt.hparams and ckpt.hparams['cameras'] is not None: 181 | cameras = ckpt.hparams['cameras'] 182 | else: 183 | cameras = [1] 184 | 185 | print(ckpt.hparams) 186 | 187 | train_data_root = os.path.join(data_root, 'train') 188 | val_data_root = os.path.join(data_root, 'val') 189 | 190 | plot_types = [ PLOT_CHOICES[choice] for choice in args.plot_types ] 191 | 192 | result_categories = ['avg_centroid_dist', 193 | 'reconstruction_err', 194 | 'instance_reconstruction_err', 195 | 'ari', 196 | 'seq_ari'] 197 | 198 | def run_analysis(split, num_seq, indices, data_root): 199 | add_instance_seg = True if split == 'val' else False 200 | dataset = InferenceDataset(ckpt.hparams['sequence_len'], 201 | ckpt.hparams['spatial_patch_hw'], 202 | data_root=data_root, 203 | num_sequences=num_seq, 204 | img_dim_hw=img_dim_hw, 205 | camera_choice=cameras, 206 | decode_pixel_downsample_factor=1, 207 | add_instance_seg=add_instance_seg, 208 | no_viewpoint = not ckpt.hparams['provide_viewpoint']) 209 | dataset.set_indices(indices) 210 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers) 211 | if args.gpu is not None: 212 | trainer = Trainer(gpus=args.gpu, strategy="ddp" if len(args.gpu) > 1 else None, logger=False) 213 | else: 214 | trainer = Trainer(accelerator='cpu', logger=False) 215 | r = trainer.predict(ckpt, dataloaders=dataloader) 216 | 217 | results = {key: [] for key in result_categories} 218 | all_plots = [] 219 | for (i, batch_results) in tqdm(enumerate(r)): 220 | batch = dataset.__getitem__(i) 221 | plots = plot_frame_sequence_from_single_batch(ckpt, batch, batch_results, plot_types) 222 | for key in result_categories: 223 | if key in batch_results: 224 | results[key].append(batch_results[key]) 225 | if i < args.num_seq_to_plot: 226 | all_plots.append(plots) 227 | 228 | print(f'Results on {split} dataset:') 229 | 
for (key, val) in results.items(): 230 | print(f'Mean {key}: {np.nanmean(val)}, std: {np.nanstd(val)}') 231 | if args.num_seq_to_plot > 0: 232 | for (index, frames) in zip(indices, all_plots): 233 | if args.video_format == 'both' or args.video_format == 'gif': 234 | imageio.mimwrite(os.path.join(log_dir, f'{split}_{index}_{train_step}.gif'), frames, fps=2) 235 | if args.video_format == 'both' or args.video_format == 'mp4': 236 | imageio.mimwrite(os.path.join(log_dir, f'{split}_{index}_{train_step}.mp4'), frames, fps=2) 237 | 238 | return results 239 | 240 | overall_results = {'train_step': train_step} 241 | if do_train: 242 | results = run_analysis('train', args.num_train_seq, train_indices, train_data_root) 243 | overall_results['train_metrics'] = results 244 | overall_results['train_seq_indices'] = train_indices 245 | if do_val: 246 | results = run_analysis('val', args.num_val_seq, val_indices, val_data_root) 247 | overall_results['val_metrics'] = results 248 | overall_results['val_seq_indices'] = val_indices 249 | 250 | if args.name is not None: 251 | name = f'_{args.name}' 252 | else: 253 | name = '_' 254 | metrics_path = os.path.join(log_dir, f'metrics{name}_{train_step}.pkl') 255 | print(f'Saving metrics at {metrics_path}') 256 | with open(metrics_path, 'wb') as f: 257 | pickle.dump(overall_results, f) 258 | -------------------------------------------------------------------------------- /assets/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wayveai/SOCS/4e11116c7481ec2991359ef925df6a9c2eff99c8/assets/example.gif -------------------------------------------------------------------------------- /dataset_generation.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | tf.enable_eager_execution() 3 | tf.config.set_visible_devices([], 'GPU') 4 | 5 | import argparse 6 | import os 7 | import numpy as np 8 | import 
pickle 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from waymo_open_dataset import dataset_pb2 as open_dataset 13 | from waymo_open_dataset.utils import camera_segmentation_utils 14 | 15 | from dataset_generation_const import * 16 | 17 | CAM_TYPE = {open_dataset.CameraName.FRONT_LEFT: 0, open_dataset.CameraName.FRONT: 1, open_dataset.CameraName.FRONT_RIGHT: 2} 18 | 19 | def save_train_seq(seq, bc_seq, seq_num, out_dir): 20 | num_cam = len(CAM_TYPE) 21 | temporal_transform = np.zeros((len(seq), 4, 4)) 22 | timestamps = np.zeros((len(seq), num_cam)) 23 | top_crop_ratio = 60/140 24 | top_offset = RESIZE_DIM[1]*top_crop_ratio 25 | crop_ltrb = (0, top_offset, RESIZE_DIM[0], RESIZE_DIM[1]) 26 | final_dim_hw = (RESIZE_DIM[1] - int(top_offset), RESIZE_DIM[0]) 27 | rgbs = np.zeros((len(seq), num_cam, final_dim_hw[0], final_dim_hw[1], 3)) 28 | extrinsics = np.zeros((len(seq), num_cam, 4, 4)) 29 | 30 | for (t, frame) in enumerate(seq): 31 | temporal_transform[t] = np.array(frame.pose.transform).reshape((4,4)) 32 | for image in frame.images: 33 | if image.name in CAM_TYPE: 34 | f = CAM_TYPE[image.name] 35 | timestamps[t, f] = frame.timestamp_micros 36 | rgb = tf.image.decode_jpeg(image.image).numpy() 37 | rgb_img = Image.fromarray(rgb).resize((RESIZE_DIM)).crop((crop_ltrb)) 38 | rgbs[t, f] = np.array(rgb_img).astype('float') / 255 39 | 40 | for calibration in frame.context.camera_calibrations: 41 | if calibration.name in CAM_TYPE: 42 | f = CAM_TYPE[calibration.name] 43 | # sensor to vehicle frame 44 | extrinsics[t, f] = np.array(calibration.extrinsic.transform).reshape((4, 4)) 45 | 46 | origin = np.expand_dims(temporal_transform[0], 0) # 1 x 4 x 4 47 | viewpoint_transform = np.expand_dims(np.linalg.solve(origin, temporal_transform), 1) @ extrinsics # T x C x 4 x 4 48 | viewpoint_transform[:, :, :3, 3] /= DISTANCE_NORMALIZATION_FACTOR 49 | viewpoint_transform = viewpoint_transform[:, :, :-1].reshape((len(seq), num_cam, -1)) 50 | 51 | timestamps -= 
timestamps[0] 52 | timestamps /= TIMESTAMP_NORMALIZATION_FACTOR 53 | 54 | bc_waypoints = np.zeros((len(bc_seq), 2)) 55 | bc_mask = np.zeros((len(bc_seq))) 56 | bc_origin = temporal_transform[-1] # 4 x 4 57 | for (t, frame) in enumerate(bc_seq): 58 | if frame is not None: 59 | bc_mask[t] = 1 60 | bc_transform = np.linalg.solve(bc_origin, np.array(frame.pose.transform).reshape((4,4))) 61 | bc_waypoints[t] = bc_transform[:2, 3].T / DISTANCE_NORMALIZATION_FACTOR # 1 x 2 62 | 63 | with open(os.path.join(out_dir, f'{seq_num}.npz'), 'wb') as f: 64 | np.savez_compressed(f, rgb=rgbs, 65 | viewpoint_transform=viewpoint_transform, 66 | time=timestamps, 67 | bc_waypoints=bc_waypoints, 68 | bc_mask=bc_mask) 69 | 70 | return 71 | 72 | def make_train_seqs(first_seq_num, unique_start_ids, in_dir, out_dir): 73 | seq_num = first_seq_num 74 | train_files = os.path.join(in_dir, '*.tfrecord') 75 | filenames = tf.io.matching_files(train_files) 76 | dataset = tf.data.TFRecordDataset(filenames) 77 | current_seq = [] 78 | current_BC = [None for _ in range(BC_LEN)] 79 | current_t = 0 80 | new_seq = True 81 | collect_BC = False 82 | for data in tqdm(dataset): 83 | frame = open_dataset.Frame() 84 | frame.ParseFromString(bytearray(data.numpy())) 85 | new_t = frame.timestamp_micros 86 | 87 | # Collecting image frames for autoencoding 88 | if new_seq or (not collect_BC and ((new_t - current_t > STRIDE*1e5 - 0.05e6) and (new_t - current_t < STRIDE*1e5 + 0.05e6))): 89 | if new_seq and new_t in unique_start_ids: 90 | continue 91 | 92 | current_seq.append(frame) 93 | new_seq = False 94 | current_t = new_t 95 | # Once we've collected enough image frames, switch to collecting trajectory info 96 | if len(current_seq) == SEQ_LEN: 97 | collect_BC = True 98 | seq_end_t = current_t 99 | 100 | # Collecting future trajectory information for BC 101 | elif collect_BC: 102 | elapsed_t = new_t - seq_end_t 103 | index = int(np.round(elapsed_t / (BC_STRIDE*1e5)) - 1) 104 | # Deal with timeskips 105 | if index < 0: 
106 | index = np.inf 107 | if index < BC_LEN: 108 | current_BC[index] = frame 109 | if index >= BC_LEN - 1: 110 | save_train_seq(current_seq, current_BC, seq_num, out_dir) 111 | seq_start_t = current_seq[0].timestamp_micros 112 | unique_start_ids[seq_start_t] = seq_num 113 | seq_num += 1 114 | current_seq = [] 115 | current_BC = [None for _ in range(BC_LEN)] 116 | new_seq = True 117 | collect_BC = False 118 | current_t = new_t 119 | 120 | # The frame we wanted with the correct timestamp is missing, so start over with a new sequence 121 | elif (new_t - current_t < 0) or (new_t - current_t > STRIDE*1e5 + 0.05e6): 122 | if new_t not in unique_start_ids: 123 | current_seq = [frame] 124 | current_t = new_t 125 | new_seq = False 126 | else: 127 | current_seq = [] 128 | current_t = new_t 129 | new_seq = True 130 | 131 | # Note that if we don't meet either above condition, it just means we haven't reached the next frame with the 132 | # correct timestamp yet 133 | return seq_num 134 | 135 | def save_val_seq(seq, seq_num, out_dir): 136 | panoptic_label_inds = range(len(seq)) 137 | num_cam = len(CAM_TYPE) 138 | seg_protos = [0 for _ in range(num_cam*SEQ_LEN)] 139 | temporal_transform = np.zeros((len(seq), 4, 4)) 140 | timestamps = np.zeros((len(seq), num_cam)) 141 | top_crop_ratio = 60/140 142 | top_offset = RESIZE_DIM[1]*top_crop_ratio 143 | crop_ltrb = (0, top_offset, RESIZE_DIM[0], RESIZE_DIM[1]) 144 | final_dim_hw = (RESIZE_DIM[1] - int(top_offset), RESIZE_DIM[0]) 145 | rgbs = np.zeros((len(seq), num_cam, final_dim_hw[0], final_dim_hw[1], 3)) 146 | extrinsics = np.zeros((len(seq), num_cam, 4, 4)) 147 | 148 | for (t, frame) in enumerate(seq): 149 | temporal_transform[t] = np.array(frame.pose.transform).reshape((4,4)) 150 | for image in frame.images: 151 | if image.name in CAM_TYPE: 152 | f = CAM_TYPE[image.name] 153 | timestamps[t, f] = frame.timestamp_micros 154 | rgb = tf.image.decode_jpeg(image.image).numpy() 155 | rgb_img = 
Image.fromarray(rgb).resize((RESIZE_DIM)).crop((crop_ltrb)) 156 | rgbs[t, f] = np.array(rgb_img).astype('float') / 255 157 | if t in panoptic_label_inds: 158 | idx = np.ravel_multi_index((t, f), (SEQ_LEN, num_cam)) 159 | seg_protos[idx] = image.camera_segmentation_label 160 | 161 | for calibration in frame.context.camera_calibrations: 162 | if calibration.name in CAM_TYPE: 163 | f = CAM_TYPE[calibration.name] 164 | # sensor to vehicle frame 165 | extrinsics[t, f] = np.array(calibration.extrinsic.transform).reshape((4, 4)) 166 | 167 | origin = np.expand_dims(temporal_transform[0], 0) # 1 x 4 x 4 168 | viewpoint_transform = np.expand_dims(np.linalg.solve(origin, temporal_transform), 1) @ extrinsics # T x C x 4 x 4 169 | viewpoint_transform[:, :, :3, 3] /= DISTANCE_NORMALIZATION_FACTOR 170 | viewpoint_transform = viewpoint_transform[:, :, :-1].reshape((len(seq), num_cam, -1)) 171 | 172 | timestamps -= timestamps[0] 173 | timestamps /= TIMESTAMP_NORMALIZATION_FACTOR 174 | 175 | (panoptic_labels, _, panoptic_label_divisor) = camera_segmentation_utils.decode_multi_frame_panoptic_labels_from_protos( 176 | seg_protos, remap_values=True 177 | ) 178 | semantic_segs = np.zeros(rgbs.shape[:-1], dtype='int') 179 | instance_segs = np.zeros(rgbs.shape[:-1], dtype='int') 180 | for (i, label) in enumerate(panoptic_labels): 181 | (semantic_label_front, instance_label_front) = camera_segmentation_utils.decode_semantic_and_instance_labels_from_panoptic_label( 182 | label, 183 | panoptic_label_divisor) 184 | 185 | (t, f) = np.unravel_index(i, (SEQ_LEN, num_cam)) 186 | semantic_label_img = Image.fromarray(semantic_label_front.astype('uint8').squeeze()) 187 | semantic_label_img = semantic_label_img.resize(RESIZE_DIM, resample=Image.Resampling.NEAREST).crop(crop_ltrb) 188 | semantic_segs[t, f] = np.array(semantic_label_img).astype('int') 189 | instance_label_img = Image.fromarray(instance_label_front.astype('uint8').squeeze()) 190 | instance_label_img = 
instance_label_img.resize(RESIZE_DIM, resample=Image.Resampling.NEAREST).crop(crop_ltrb) 191 | instance_segs[t, f] = np.array(instance_label_img).astype('int') 192 | 193 | with open(os.path.join(out_dir, f'{seq_num}.npz'), 'wb') as f: 194 | np.savez_compressed(f, rgb=rgbs, 195 | semantic_seg=semantic_segs, 196 | instance_seg=instance_segs, 197 | viewpoint_transform=viewpoint_transform, 198 | time=timestamps) 199 | 200 | return 201 | 202 | def collate_val_seqs(in_dir): 203 | """ 204 | Find every sequence of frames with panoptic segmentation labels in the dataset. 205 | """ 206 | val_files = os.path.join(in_dir, '*.tfrecord') 207 | filenames = tf.io.matching_files(val_files) 208 | dataset = tf.data.TFRecordDataset(filenames) 209 | seq_dict = {} 210 | for data in tqdm(dataset): 211 | frame = open_dataset.Frame() 212 | frame.ParseFromString(bytearray(data.numpy())) 213 | for image in frame.images: 214 | if image.name in CAM_TYPE: 215 | break 216 | if image.camera_segmentation_label.panoptic_label: 217 | seq_id = image.camera_segmentation_label.sequence_id 218 | if seq_id in seq_dict: 219 | seq_dict[seq_id].append(frame) 220 | else: 221 | seq_dict[seq_id] = [frame] 222 | 223 | return seq_dict 224 | 225 | def make_val_seqs(out_dir, seq_dict): 226 | """ 227 | Save frame sequences where there are no missing frames. 
228 | """ 229 | seq_num = 0 230 | for (_, seq) in tqdm(seq_dict.items()): 231 | seq = sorted(seq, key=lambda frame: frame.timestamp_micros) 232 | new_seq = True 233 | current_seq = [] 234 | for frame in seq: 235 | new_t = frame.timestamp_micros 236 | if new_seq or ((new_t - current_t > STRIDE*1e5 - 0.05e6) and (new_t - current_t < STRIDE*1e5 + 0.05e6)): 237 | 238 | current_seq.append(frame) 239 | new_seq = False 240 | current_t = new_t 241 | if len(current_seq) == SEQ_LEN: 242 | save_val_seq(current_seq, seq_num, out_dir) 243 | seq_num += 1 244 | current_seq = [] 245 | new_seq = True 246 | else: 247 | current_seq = [frame] 248 | current_t = new_t 249 | new_seq = False 250 | 251 | if __name__ == '__main__': 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument('split', choices=['train', 'val']) 254 | parser.add_argument('--in_dir', default='waymo_open_raw') 255 | parser.add_argument('--out_dir', default='waymo_open') 256 | parser.add_argument('--load_seq_ids', default=None, 257 | help='To resume generating training sequences, load previously generated IDs from file') 258 | parser.add_argument('--save_seq_ids', default=None, 259 | help='To resume generating training sequences later, save newly generated IDs to file') 260 | 261 | args = parser.parse_args() 262 | 263 | in_dir = os.path.join(args.in_dir, args.split) 264 | out_dir = os.path.join(args.out_dir, args.split) 265 | os.makedirs(out_dir, exist_ok=True) 266 | 267 | if args.split == 'train': 268 | 269 | if args.load_seq_ids is not None: 270 | with open(args.load_seq_ids, 'rb') as f: 271 | unique_start_ids = pickle.load(f) 272 | else: 273 | unique_start_ids = {} 274 | 275 | prev_seq_num = len(unique_start_ids) 276 | new_seq_num = make_train_seqs(prev_seq_num, unique_start_ids, in_dir, out_dir) 277 | if args.save_seq_ids is not None: 278 | with open(args.save_seq_ids, 'wb') as f: 279 | pickle.dump(unique_start_ids, f) 280 | 281 | while new_seq_num > prev_seq_num and new_seq_num < MAX_NUM_TRAIN_SEQ: 282 
| prev_seq_num = new_seq_num 283 | new_seq_num = make_train_seqs(prev_seq_num, unique_start_ids, in_dir, out_dir) 284 | 285 | if args.save_seq_ids is not None: 286 | with open(args.save_seq_ids, 'wb') as f: 287 | pickle.dump(unique_start_ids, f) 288 | 289 | elif args.split == 'val': 290 | seq_dict = collate_val_seqs(in_dir) 291 | make_val_seqs(out_dir, seq_dict) -------------------------------------------------------------------------------- /dataset_generation_const.py: -------------------------------------------------------------------------------- 1 | IN_DIR = 'waymo_open/training' 2 | OUT_DIR = 'octopus/data/waymo_8_bc/train' 3 | MAX_NUM_TRAIN_SEQ = 40000 4 | SEQ_LEN = 8 5 | STRIDE = 2 6 | BC_STRIDE = 1 7 | BC_LEN = 16 8 | TIMESTAMP_NORMALIZATION_FACTOR = (SEQ_LEN-1)*STRIDE*1e5 9 | DISTANCE_NORMALIZATION_FACTOR = 48 * (1000 / 3600) * (TIMESTAMP_NORMALIZATION_FACTOR*1e-6) 10 | RESIZE_DIM = (224, 168) -------------------------------------------------------------------------------- /figures.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import pickle 6 | import torch 7 | 8 | from PIL import Image 9 | from pytorch_lightning import Trainer 10 | from torch.utils.data import DataLoader 11 | 12 | from dataset_generation_const import DISTANCE_NORMALIZATION_FACTOR 13 | from main import InferenceDataset 14 | from model import SOCS 15 | from segmentation_metrics import get_centroid_matches 16 | from util import parse_train_step, get_checkpoint_path, MASK_COLORS 17 | 18 | def overlay_masks_on_img(img_arr, mask_weights): 19 | """ 20 | Given an input image (numpy array), superimpose colored masks 21 | """ 22 | (H, W, _) = img_arr.shape 23 | img = Image.fromarray(img_arr) 24 | for (i, mask) in enumerate(mask_weights): 25 | color_arr = np.ones((H, W, 3)) * MASK_COLORS[i] 26 | color_img = Image.fromarray(color_arr.astype('uint8')) 27 | mask_img = 
Image.fromarray((mask * 255).astype('uint8'), mode='L') 28 | img = Image.composite(color_img, img, mask_img) 29 | return np.array(img) 30 | 31 | def plot_seg_sequence(ckpt, batch, results, timepts, cam=1): 32 | """ 33 | Show the input image sequence and predicted segmentation sequence over time. 34 | """ 35 | (TC, H, W, _) = batch['img_seq'].shape 36 | if 'cameras' in ckpt.hparams and ckpt.hparams['cameras'] is not None: 37 | cameras = ckpt.hparams['cameras'] 38 | else: 39 | cameras = [1] 40 | num_cameras = len(cameras) 41 | seq_len = ckpt.hparams['num_frame_slots'] // num_cameras 42 | img_dims = (TC, H, W) 43 | fig_width = 10 44 | fig_height_width_ratio = (img_dims[1] / img_dims[2]) * (2 / len(timepts)) 45 | 46 | (fig, axes) = plt.subplots(2, len(timepts), figsize=fig_width*np.array([1, fig_height_width_ratio])) 47 | img_seq = (batch['img_seq'] * 255).astype('uint8') 48 | obj_weights = results['per_object_weights'] 49 | for (i, t) in enumerate(timepts): 50 | frame_idx = np.ravel_multi_index((t, cam), (seq_len, num_cameras)) 51 | axes[0, i].imshow(img_seq[frame_idx]) 52 | axes[0, i].set_axis_off() 53 | seg = ckpt.show_object_masks(obj_weights, img_dims, idx=frame_idx) 54 | axes[1, i].imshow(seg) 55 | axes[1, i].set_axis_off() 56 | 57 | plt.tight_layout(pad=0, h_pad=0.5) 58 | return fig 59 | 60 | def plot_seg_overlay(ckpt, batch, results, timepts, cam=1, chosen_inds=None): 61 | """ 62 | Show the input image sequence with best-fit predicted masks superimposed on ground-truth 63 | objects. 
64 | """ 65 | (TC, H, W, _) = batch['img_seq'].shape 66 | cameras = ckpt.hparams['cameras'] 67 | C = len(cameras) 68 | T = TC // C 69 | obj_weights = torch.tensor(results['per_object_weights']).reshape((-1, T, C, H, W)) 70 | gt_weights = torch.tensor(batch['instance_oh'].reshape((T, C, H, W, -1))).moveaxis(-1, 0) 71 | (_, pred_inds, gt_inds) = get_centroid_matches(obj_weights, gt_weights) 72 | chosen_pred_weights = obj_weights[pred_inds] 73 | if chosen_inds is not None: 74 | chosen_pred_weights = chosen_pred_weights[chosen_inds] 75 | 76 | fig_width = 10 77 | fig_height_width_ratio = (H / W) * (1 / len(timepts)) 78 | (fig, axes) = plt.subplots(1, len(timepts), figsize=fig_width*np.array([1, fig_height_width_ratio])) 79 | img_seq = (batch['img_seq'] * 255).astype('uint8') 80 | 81 | for (i, t) in enumerate(timepts): 82 | frame_idx = np.ravel_multi_index((t, cam), (T, C)) 83 | img = img_seq[frame_idx] 84 | img = overlay_masks_on_img(img, chosen_pred_weights[:, t, cam].numpy()) 85 | axes[i].imshow(img) 86 | axes[i].set_axis_off() 87 | 88 | plt.tight_layout(pad=0, h_pad=0.5, w_pad=0.5) 89 | return fig 90 | 91 | def plot_raw_seg_overlay(ckpt, batch, results, timepts, cam=1, chosen_inds=None): 92 | """ 93 | Show the input image sequence with (selected) predicted masks superimposed. 
94 | """ 95 | (TC, H, W, _) = batch['img_seq'].shape 96 | cameras = ckpt.hparams['cameras'] 97 | C = len(cameras) 98 | T = TC // C 99 | obj_weights = torch.tensor(results['per_object_weights']).reshape((-1, T, C, H, W)) 100 | if chosen_inds is not None: 101 | obj_weights = obj_weights[chosen_inds] 102 | 103 | fig_width = 10 104 | fig_height_width_ratio = (H / W) * (1 / len(timepts)) 105 | (fig, axes) = plt.subplots(1, len(timepts), figsize=fig_width*np.array([1, fig_height_width_ratio])) 106 | img_seq = (batch['img_seq'] * 255).astype('uint8') 107 | 108 | for (i, t) in enumerate(timepts): 109 | frame_idx = np.ravel_multi_index((t, cam), (T, C)) 110 | img = img_seq[frame_idx] 111 | img = overlay_masks_on_img(img, obj_weights[:, t, cam].numpy()) 112 | axes[i].imshow(img) 113 | axes[i].set_axis_off() 114 | 115 | plt.tight_layout(pad=0, h_pad=0.5, w_pad=0.5) 116 | return fig 117 | 118 | # TODO remove OpenCV dependency (or add to requirements if it stays) 119 | def plot_waypoints(ckpt, data_path, batch, results): 120 | """ 121 | Plot the ground-truth and predicted future waypoints on the last image in the sequence. 122 | Note that this requires a special dataset containing full-res images and the ground-truth 123 | waypoints. Also requires OpenCV package. 
124 | """ 125 | import cv2 126 | with open(data_path, 'rb') as f: 127 | data = np.load(f) 128 | img_seq = data['full_rgb'] / 255 129 | loaded_intrinsics = data['intrinsics'] 130 | loaded_extrinsics = data['extrinsics'] 131 | img = img_seq[-1, 1] 132 | 133 | intrinsics_matrix = np.zeros((3, 3), dtype='double') 134 | intrinsics = loaded_intrinsics[1].flatten() 135 | intrinsics_matrix[0, 0] = intrinsics[0] # f_x 136 | intrinsics_matrix[0, 2] = intrinsics[2] # c_x 137 | intrinsics_matrix[1, 1] = intrinsics[1] # f_y 138 | intrinsics_matrix[1, 2] = intrinsics[3] # c_y 139 | intrinsics_matrix[2, 2] = 1 140 | intrinsics = torch.tensor(intrinsics_matrix, dtype=torch.double) 141 | extrinsics = torch.tensor(np.array([[0, -1, 0, 0], 142 | [0, 0, -1, loaded_extrinsics[0,1,0,3]], 143 | [1, 0, 0, 0], 144 | [0, 0, 0, 1]]), dtype=torch.double) 145 | 146 | img = img.copy() 147 | 148 | expert_waypoints = torch.tensor(batch['bc_waypoints'] * DISTANCE_NORMALIZATION_FACTOR, dtype=torch.double) 149 | expert_img_points = _get_img_points(intrinsics, extrinsics, expert_waypoints) 150 | for x, y in expert_img_points: 151 | cv2.circle(img, (x, y), 10, (0,1.,0), -1) # Green 152 | 153 | pred_waypoints = torch.tensor(results['bc_waypoints'].squeeze() * DISTANCE_NORMALIZATION_FACTOR, dtype=torch.double) 154 | pred_img_points = _get_img_points(intrinsics, extrinsics, pred_waypoints) 155 | for x, y in pred_img_points: 156 | cv2.circle(img, (x, y), 10, (.9,.6,0), -1) # Orange 157 | 158 | fig = plt.figure() 159 | plt.imshow(img) 160 | plt.gca().set_axis_off() 161 | plt.tight_layout() 162 | return fig 163 | 164 | def _get_img_points(intrinsics, extrinsics, waypoints): 165 | """ 166 | Find the (x, y) coordinates in the image plane for provided waypoints in space. 
167 | """ 168 | N = waypoints.shape[0] 169 | points = intrinsics.new_zeros((N, 3)) 170 | points[:, :2] = waypoints 171 | 172 | # [num_points, 4] 173 | homogeneous_points = torch.nn.functional.pad(points, (0, 1), value=1) 174 | # [3, 4] @ [4, num_points] = [3, num_points] 175 | camera_points = extrinsics[:3, :] @ homogeneous_points.T 176 | # [3, 3] @ [3, num_points] = [3, num_points] 177 | homogeneous_image_points = intrinsics @ camera_points 178 | # Filter out points behind the camera origin. 179 | homogeneous_image_points = homogeneous_image_points[:, homogeneous_image_points[2] > 0] 180 | # [2, num_points] 181 | image_points = homogeneous_image_points[:2] / homogeneous_image_points[2] 182 | # [num_points, 2] 183 | int_image_points = image_points.T.round().int().numpy() 184 | return int_image_points 185 | 186 | if __name__ == '__main__': 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('log_root') 189 | parser.add_argument('--out_path', default='figure.png') 190 | parser.add_argument('--cache_save', default=None) 191 | parser.add_argument('--cache_load', default=None) 192 | parser.add_argument('--data_root', default=None) 193 | parser.add_argument('--camera', type=int, choices=[0,1,2], default=1, help='Which camera to plot') 194 | parser.add_argument('--idx', type=int, default=0) 195 | parser.add_argument('--fig_type', choices=['seg_seq', 'seg_overlay', 'raw_seg_overlay', 'waypoints'], default='seg_seq') 196 | parser.add_argument('--split', choices=['train', 'val'], default='train') 197 | parser.add_argument('--gpu', type=int, default=None) 198 | parser.add_argument('--parallel_pix', type=int, default=10000, 199 | help='Number of pixels to decode in each pass.
Larger values use more memory but require fewer passes.') 200 | args = parser.parse_args() 201 | 202 | if args.log_root.endswith('.ckpt'): 203 | checkpoint_path = args.log_root 204 | checkpoint_fname = os.path.basename(checkpoint_path) 205 | log_dir = os.path.dirname(checkpoint_path) 206 | train_step = parse_train_step(checkpoint_fname) 207 | else: 208 | checkpoint_dir = os.path.join(args.log_root, 'checkpoints') 209 | (checkpoint_path, train_step) = get_checkpoint_path(checkpoint_dir) 210 | log_dir = args.log_root 211 | 212 | print(f'Loading checkpoint: {checkpoint_path}') 213 | ckpt = SOCS.load_from_checkpoint(checkpoint_path) 214 | 215 | if args.data_root is None: 216 | data_root = ckpt.hparams['dataset_root'] 217 | else: 218 | data_root = args.data_root 219 | 220 | if args.split == 'train': 221 | add_instance_seg = False 222 | num_seq = 40000 223 | if args.data_root is None: 224 | data_root = os.path.join(data_root, 'train') 225 | else: 226 | add_instance_seg = True 227 | num_seq = 208 228 | if args.data_root is None: 229 | data_root = os.path.join(data_root, 'val') 230 | 231 | if args.cache_load: 232 | with open(args.cache_load, 'rb') as f: 233 | data = pickle.load(f) 234 | batch = data['batch'] 235 | batch_results = data['batch_results'] 236 | else: 237 | ckpt.inference_parallel_pixels = args.parallel_pix 238 | dataset = InferenceDataset(ckpt.hparams['sequence_len'], 239 | ckpt.hparams['spatial_patch_hw'], 240 | data_root=data_root, 241 | num_sequences=num_seq, 242 | img_dim_hw=ckpt.hparams['img_dim_hw'], 243 | camera_choice=ckpt.hparams['cameras'], 244 | decode_pixel_downsample_factor=1, 245 | add_instance_seg=add_instance_seg, 246 | no_viewpoint = not ckpt.hparams['provide_viewpoint']) 247 | dataset.set_indices([args.idx]) 248 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 249 | if args.gpu is not None: 250 | trainer = Trainer(accelerator='gpu', devices=[args.gpu], logger=False) 
251 | else: 252 | trainer =
Trainer(accelerator='cpu', logger=False) 253 | 254 | batch = dataset.__getitem__(0) 255 | batch_results = trainer.predict(ckpt, dataloaders=dataloader)[0] 256 | 257 | if args.cache_save: 258 | full_results = dict(batch=batch, batch_results=batch_results) 259 | with open(args.cache_save, 'wb') as f: 260 | pickle.dump(full_results, f) 261 | 262 | frames = range(ckpt.hparams['sequence_len']) 263 | cam = args.camera 264 | if args.fig_type == 'seg_seq': 265 | fig = plot_seg_sequence(ckpt, batch, batch_results, frames, cam=cam) 266 | elif args.fig_type == 'seg_overlay': 267 | fig = plot_seg_overlay(ckpt, batch, batch_results, frames, cam=cam) 268 | elif args.fig_type == 'raw_seg_overlay': 269 | fig = plot_raw_seg_overlay(ckpt, batch, batch_results, frames, cam=cam) 270 | elif args.fig_type == 'waypoints': 271 | data_path = os.path.join(data_root, f'{args.idx}.npz') 272 | fig = plot_waypoints(ckpt, data_path, batch, batch_results) 273 | fig.savefig(args.out_path) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from pytorch_lightning import Trainer 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | from model import SOCS 13 | from util import fourier_embeddings 14 | 15 | class SOCSDataset(Dataset): 16 | def __init__(self, 17 | sequence_length, 18 | spatial_patch_hw, 19 | data_root, 20 | num_sequences=1, 21 | decode_pixel_downsample_factor=16, 22 | img_dim_hw=(0,0), 23 | camera_choice=[1], 24 | add_instance_seg=False, 25 | num_fourier_bands=10, 26 | fourier_sampling_rate=60, 27 | no_viewpoint=False): 28 | 29 | self.img_dim_hw = img_dim_hw 30 | self.seq_len = sequence_length 31 |
self.decode_pixel_downsample_factor = decode_pixel_downsample_factor 32 | self.spatial_patch_hw = spatial_patch_hw 33 | self.data_root = data_root 34 | self.num_sequences = num_sequences 35 | self.camera_choice = camera_choice 36 | self.add_instance_seg = add_instance_seg 37 | self.num_fourier_bands = num_fourier_bands 38 | self.fourier_sampling_rate = fourier_sampling_rate 39 | self.provide_viewpoint = not no_viewpoint 40 | 41 | def __len__(self): 42 | return self.num_sequences 43 | 44 | def __getitem__(self, idx): 45 | (item, decode_mask) = self._set_pixels_to_decode(self._loaditem(idx)) 46 | if self.add_instance_seg: 47 | self._load_instance_seg(idx, item, decode_mask) 48 | return item 49 | 50 | # Override in a subclass for the specific dataset 51 | # Must return a dict with 'img_seq', 'viewpoint_seq' and 'time_seq' (and optionally 'bc_waypoints' and 'bc_mask') 52 | def _loaditem(self, idx): 53 | raise NotImplementedError 54 | 55 | def _set_pixels_to_decode(self, item): 56 | """ 57 | Given a loaded sequence, find the positional embeddings for the transformer and the queries for 58 | the output decoder.
59 | """ 60 | num_frames = self.seq_len*len(self.camera_choice) 61 | random_h_offset = np.random.randint(self.decode_pixel_downsample_factor) 62 | decode_pixel_h_inds = slice(random_h_offset, self.img_dim_hw[0], self.decode_pixel_downsample_factor) 63 | random_w_offset = np.random.randint(self.decode_pixel_downsample_factor) 64 | decode_pixel_w_inds = slice(random_w_offset, self.img_dim_hw[1], self.decode_pixel_downsample_factor) 65 | 66 | # Mask that determines which of the pixels in the input data will be decoded 67 | decode_mask = np.zeros((num_frames,) + self.img_dim_hw, dtype='bool') 68 | decode_mask[:, decode_pixel_h_inds, decode_pixel_w_inds] = True 69 | 70 | all_inds = np.array(np.meshgrid(range(num_frames), range(self.img_dim_hw[0]), range(self.img_dim_hw[1]), indexing='ij')) 71 | decode_inds = all_inds[:, decode_mask].T # \in num_p x 3 72 | 73 | img_seq = item['img_seq'] 74 | viewpoint_seq = item['viewpoint_seq'] 75 | # Remove the last row of the transform matrix if it's stored 76 | if viewpoint_seq.shape[1] > 12: 77 | viewpoint_seq = viewpoint_seq[:, :-4] 78 | 79 | if self.provide_viewpoint: 80 | viewpoint_size = viewpoint_seq.shape[1] 81 | else: 82 | viewpoint_size = 1 83 | 84 | # 3D positional information to be used as a query for the decoder 85 | # The columns are time, y, x, viewpoint_transform 86 | base_decoder_queries = np.zeros((decode_inds.shape[0], 3 + viewpoint_size)) 87 | base_decoder_queries[:,0] = item['time_seq'][decode_inds[:,0]] # time 88 | base_decoder_queries[:,1] = (2*decode_inds[:,1] - (self.img_dim_hw[0] - 1)).flatten() / self.img_dim_hw[0] # y 89 | base_decoder_queries[:,2] = (2*decode_inds[:,2] - (self.img_dim_hw[1] - 1)).flatten() / self.img_dim_hw[1] # x 90 | if self.provide_viewpoint: 91 | base_decoder_queries[:,3:] = viewpoint_seq[decode_inds[:,0]] # viewpoint transform 92 | # If not providing the camera transform matrix, recover the index of the camera to provide instead 93 | else: 94 | num_viewpoints = 
len(self.camera_choice)
95 | num_timepoints = self.seq_len 96 | camera_inds = np.unravel_index(decode_inds[:,0], (num_timepoints, num_viewpoints))[1] 97 | base_decoder_queries[:,3] = np.ones(decode_inds.shape[0]) * camera_inds 98 | 99 | decoder_queries = fourier_embeddings(base_decoder_queries, self.num_fourier_bands, self.fourier_sampling_rate) 100 | 101 | # Prepare the chunk labels for the first transformer 102 | base_patch_embeddings = np.zeros((num_frames, self.spatial_patch_hw[0], self.spatial_patch_hw[1], 3 + viewpoint_size)) 103 | 104 | for i in range(num_frames): # frames 105 | time_offset = item['time_seq'][i] 106 | if self.provide_viewpoint: 107 | view_offset = viewpoint_seq[i] 108 | else: 109 | view_offset = [np.unravel_index(i, (num_timepoints, num_viewpoints))[1]] 110 | for j in range(self.spatial_patch_hw[0]): # height 111 | patch_y_offset = ((2*j) / (self.spatial_patch_hw[0] - 1)) - 1 112 | for k in range(self.spatial_patch_hw[1]): # width 113 | patch_x_offset = ((2*k) / (self.spatial_patch_hw[1] - 1)) - 1 114 | base_patch_embeddings[i,j,k] = np.array([time_offset, patch_y_offset, patch_x_offset, *view_offset]) 115 | 116 | patch_positional_embeddings = fourier_embeddings(base_patch_embeddings, self.num_fourier_bands, self.fourier_sampling_rate) 117 | 118 | data = dict( 119 | img_seq = img_seq.astype('float32'), 120 | decode_dims = np.array([num_frames, 121 | self.img_dim_hw[0] // self.decode_pixel_downsample_factor, 122 | self.img_dim_hw[1] // self.decode_pixel_downsample_factor]), 123 | ground_truth_rgb = img_seq[decode_mask], 124 | patch_positional_embeddings = patch_positional_embeddings.astype('float32'), 125 | decoder_queries = decoder_queries.astype('float32'), 126 | ) 127 | 128 | if 'bc_waypoints' in item: 129 | data['bc_waypoints'] = item['bc_waypoints'] 130 | if 'bc_mask' in item: 131 | data['bc_mask'] = item['bc_mask'] 132 | 133 | return (data, decode_mask) 134 | 135 | class LocalDataset(SOCSDataset): 136 | def _loaditem(self, idx, data_root=None): 137 | data_root 
= data_root if data_root is not None else self.data_root 138 | num_frames = self.seq_len*len(self.camera_choice) 139 | data_path = os.path.join(data_root, f'{idx}.npz') 140 | with open(data_path, 'rb') as f: 141 | data = np.load(f) 142 | img_seq = data['rgb'][:self.seq_len, self.camera_choice] 143 | img_seq = img_seq.reshape((num_frames,) + img_seq.shape[2:]).astype('float32') 144 | viewpoint_seq = data['viewpoint_transform'][:self.seq_len, self.camera_choice] 145 | viewpoint_seq = viewpoint_seq.reshape((num_frames,) + viewpoint_seq.shape[2:]) 146 | time_seq = data['time'][:self.seq_len].flatten() 147 | 148 | loaded_data = dict(img_seq=img_seq, 149 | viewpoint_seq=viewpoint_seq, 150 | time_seq=time_seq) 151 | 152 | if 'bc_waypoints' in data: 153 | loaded_data['bc_waypoints'] = data['bc_waypoints'] 154 | if 'bc_mask' in data: 155 | loaded_data['bc_mask'] = data['bc_mask'] 156 | 157 | return loaded_data 158 | 159 | def _load_instance_seg(self, idx, item, decode_mask, data_root=None): 160 | num_frames = len(self.camera_choice)*self.seq_len 161 | data_root = data_root if data_root is not None else self.data_root 162 | data_path = os.path.join(data_root, f'{idx}.npz') 163 | with open(data_path, 'rb') as f: 164 | data = np.load(f) 165 | if 'instance_seg' in data: 166 | instance_segs = data['instance_seg'] 167 | instance_segs = instance_segs.reshape((num_frames,) + instance_segs.shape[2:])[decode_mask] 168 | else: 169 | instance_segs = np.zeros(item['img_seq'].shape[:-1]) 170 | 171 | instance_masks = np.zeros(instance_segs.shape, dtype='bool') 172 | instance_masks[np.where(instance_segs != 0)[0]] = True 173 | 174 | instances = np.unique(instance_segs[instance_masks]) 175 | num_instances = len(instances) 176 | if num_instances > 0: 177 | instance_oh = np.zeros(instance_masks.shape + (num_instances,)) 178 | for i in range(num_instances): 179 | single_mask = np.where(instance_segs == instances[i])[0] 180 | instance_oh[:, i][single_mask] = 1 181 | else: 182 | instance_oh = 
np.zeros(0) 183 | 184 | item['instance_oh'] = instance_oh 185 | item['instance_mask'] = instance_masks 186 | 187 | # A dataset class designed to only load a specified subset of the full dataset 188 | class InferenceDataset(LocalDataset): 189 | def set_indices(self, indices): 190 | self.indices = indices 191 | 192 | def __len__(self): 193 | return len(self.indices) 194 | 195 | def __getitem__(self, idx): 196 | (item, decode_mask) = self._set_pixels_to_decode(self._loaditem(self.indices[idx])) 197 | if self.add_instance_seg: 198 | self._load_instance_seg(self.indices[idx], item, decode_mask) 199 | return item 200 | 201 | if __name__ == '__main__': 202 | parser = argparse.ArgumentParser() 203 | 204 | # Basic training parameters 205 | parser.add_argument('--name', default='SOCS') 206 | parser.add_argument('--batch_size', type=int, default=8) 207 | parser.add_argument('--gpu', type=int, default=[0], nargs='+') 208 | parser.add_argument('--seed', type=int, default=1) 209 | parser.add_argument('--num_train_seq', type=int, default=40000) 210 | parser.add_argument('--num_epochs', type=int, default=1000) # -1 for infinite epochs 211 | parser.add_argument('--data_root', default='waymo_open') 212 | parser.add_argument('--dataset', default='waymo', choices=['waymo']) 213 | parser.add_argument('--lr', type=float, default=1e-4) 214 | 215 | # Network hyperparameters 216 | parser.add_argument('--no_viewpoint', action='store_true') 217 | parser.add_argument('--num_gaussian_heads', type=int, default=3) 218 | parser.add_argument('--behavioral_cloning_task', action='store_true') 219 | parser.add_argument('--sequence_length', type=int, default=8) 220 | parser.add_argument('--beta', type=float, default=5e-7) 221 | parser.add_argument('--bc_loss_weight') 222 | parser.add_argument('--sigma', type=float, default=0.08) 223 | parser.add_argument('--downsample_factor', type=int, default=16) 224 | parser.add_argument('--num_patches_height', type=int, default=None) 225 | 
parser.add_argument('--num_patches_width', type=int, default=None) 226 | parser.add_argument('--checkpoint_path', default=None) 227 | parser.add_argument('--decoder_layers', type=int, default=3) 228 | parser.add_argument('--decoder_size', type=int, default=1536) 229 | parser.add_argument('--transformer_heads', type=int, default=4) 230 | parser.add_argument('--transformer_head_size', type=int, default=128) 231 | parser.add_argument('--transformer_ff_size', type=int, default=1024) 232 | parser.add_argument('--transformer_layers', type=int, default=3) 233 | parser.add_argument('--num_object_slots', type=int, default=None) 234 | parser.add_argument('--object_latent_size', type=int, default=32) 235 | parser.add_argument('--cameras', type=int, default=[0, 1, 2], nargs='+') 236 | parser.add_argument('--num_fourier_bands', type=int, default=10) 237 | parser.add_argument('--fourier_sampling_rate', type=int, default=60) 238 | 239 | args = parser.parse_args() 240 | batch_size = args.batch_size // len(args.gpu) 241 | num_frames = len(args.cameras) * args.sequence_length 242 | 243 | torch.manual_seed(args.seed) 244 | 245 | if args.dataset == 'waymo': 246 | img_dim_hw = (96, 224) 247 | if args.no_viewpoint: 248 | viewpoint_size = 1 + 3 249 | else: 250 | viewpoint_size = 12 + 3 251 | default_patches_hw = (6, 14) 252 | default_num_object_slots = 21 253 | 254 | viewpoint_size *= (1 + 2*args.num_fourier_bands) 255 | nph = args.num_patches_height if args.num_patches_height is not None else default_patches_hw[0] 256 | npw = args.num_patches_width if args.num_patches_width is not None else default_patches_hw[1] 257 | spatial_patch_hw = (nph, npw) 258 | num_objects = args.num_object_slots if args.num_object_slots is not None else default_num_object_slots 259 | 260 | train_dataloader = DataLoader(LocalDataset(args.sequence_length, 261 | spatial_patch_hw, 262 | os.path.join(args.data_root, 'train'), 263 | num_sequences=args.num_train_seq, 264 | img_dim_hw=img_dim_hw, 265 | 
decode_pixel_downsample_factor=args.downsample_factor, 266 | camera_choice=args.cameras, 267 | no_viewpoint=args.no_viewpoint), 268 | batch_size=batch_size, shuffle=True, num_workers=args.batch_size) 269 | 270 | model = SOCS(img_dim_hw=img_dim_hw, 271 | embed_dim=args.object_latent_size, 272 | beta=args.beta, 273 | sigma_x=args.sigma, 274 | viewpoint_size=viewpoint_size, 275 | learning_rate=args.lr, 276 | num_transformer_layers=args.transformer_layers, 277 | num_transformer_heads=args.transformer_heads, 278 | transformer_head_dim=args.transformer_head_size, 279 | transformer_hidden_dim=args.transformer_ff_size, 280 | num_decoder_layers=args.decoder_layers, 281 | decoder_hidden_dim=args.decoder_size, 282 | num_object_slots=num_objects, 283 | spatial_patch_hw=spatial_patch_hw, 284 | pixel_downsample_factor=args.downsample_factor, 285 | num_fourier_bands=args.num_fourier_bands, 286 | fourier_sampling_rate=args.fourier_sampling_rate, 287 | cameras=args.cameras, 288 | provide_viewpoint=not args.no_viewpoint, 289 | num_gaussian_heads=args.num_gaussian_heads, 290 | bc_task=args.behavioral_cloning_task, 291 | seed=args.seed, 292 | dataset_name=args.dataset, 293 | dataset_root=args.data_root, 294 | sequence_len=args.sequence_length) 295 | 296 | 297 | logger = TensorBoardLogger(save_dir=os.getcwd(), name=os.path.join('logs', args.name)) 298 | 299 | recent_checkpoint_callback = ModelCheckpoint( 300 | filename="last_{step}", 301 | every_n_train_steps=100 302 | ) 303 | historical_checkpoint_callback = ModelCheckpoint( 304 | save_top_k=-1, 305 | every_n_train_steps=100000, 306 | filename="{step}" 307 | ) 308 | trainer = Trainer(accelerator='gpu', 309 | devices=args.gpu, 310 | strategy="ddp" if len(args.gpu) > 1 else None, 311 | check_val_every_n_epoch=1, 312 | logger=logger, 313 | max_epochs=args.num_epochs, 314 | precision=16, 315 | callbacks=[recent_checkpoint_callback, historical_checkpoint_callback]) 316 | 317 | trainer.fit(model, train_dataloader, 
ckpt_path=args.checkpoint_path) 318 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import matplotlib.colors as pltcolors 3 | import matplotlib.cm as pltcm 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | from pytorch_lightning import LightningModule 10 | from torch.distributions.normal import Normal 11 | 12 | from nets import fc_net, CNNEncoder, TransformerBlock, QueryDecoder 13 | from segmentation_metrics import adjusted_rand_index, closest_centroids_metric 14 | from util import MASK_COLORS 15 | 16 | # Note: only square images for now 17 | class SOCS(LightningModule): 18 | def __init__(self, 19 | img_dim_hw=(64,64), 20 | embed_dim=32, 21 | num_transformer_heads=4, 22 | transformer_head_dim=128, 23 | transformer_hidden_dim=1024, 24 | num_transformer_layers=4, 25 | decoder_hidden_dim=1536, 26 | num_decoder_layers=3, 27 | spatial_patch_hw=(6, 14), 28 | num_object_slots=21, 29 | beta=1e-6, 30 | bc_loss_weight=1e-4, 31 | sigma_x=0.08, 32 | num_gaussian_heads=3, 33 | viewpoint_size=12, 34 | learning_rate=1e-4, 35 | bc_task = True, 36 | bc_task_transformer_layers = 2, 37 | num_target_points=16, 38 | 39 | # This and below are just here so they're saved w/ other hyperparameters 40 | provide_viewpoint=True, 41 | sequence_len=8, 42 | seed=1, 43 | pixel_downsample_factor=None, 44 | num_fourier_bands=10, 45 | fourier_sampling_rate=60, 46 | dataset_name=None, 47 | dataset_root=None, 48 | cameras=None, 49 | ): 50 | 51 | super(SOCS, self).__init__() 52 | self.save_hyperparameters() 53 | self.transformer_dim = num_transformer_heads*transformer_head_dim 54 | 55 | self.encoder = CNNEncoder(img_dim_hw=img_dim_hw, output_dim=self.transformer_dim-self.hparams.viewpoint_size) 56 | 57 | transformer_1_layers = [ 58 | TransformerBlock(self.transformer_dim, num_transformer_heads, 
dim_feedforward=transformer_hidden_dim, dropout=0) 59 | for _ in range(num_transformer_layers)] 60 | self.transformer_1 = nn.Sequential(*transformer_1_layers) 61 | 62 | num_patches = spatial_patch_hw[0] * spatial_patch_hw[1] 63 | if num_patches == num_object_slots: 64 | self.spatial_pool = nn.Identity() 65 | elif num_patches == 4 * num_object_slots: 66 | self.spatial_pool = nn.AvgPool2d((2,2), (2,2)) 67 | else: 68 | raise ValueError(f'No pooling implementation for {num_patches} patches and {num_object_slots} objects') 69 | 70 | transformer_2_layers = [ 71 | TransformerBlock(self.transformer_dim, num_transformer_heads, dim_feedforward=transformer_hidden_dim, dropout=0) 72 | for _ in range(num_transformer_layers)] 73 | self.transformer_2 = nn.Sequential(*transformer_2_layers) 74 | 75 | # Decode transformer output into an object latent 76 | self.latent_decoder = fc_net( 77 | num_layers=1, in_size=self.transformer_dim, hidden_size=1024, out_proj_size=2*self.hparams.embed_dim) 78 | 79 | # Use object latents to render specific pixels 80 | self.query_decoder = QueryDecoder( 81 | input_size=self.hparams.embed_dim, query_size=self.hparams.viewpoint_size, 82 | hidden_size=decoder_hidden_dim, output_size=(self.hparams.num_gaussian_heads*4)+1, num_hidden_layers=num_decoder_layers) 83 | 84 | self.kl_prior = Normal(torch.tensor(0), torch.tensor(1)) 85 | 86 | if bc_task: 87 | task_output_dim = num_target_points * 2 88 | task_transformer_layers = [ 89 | TransformerBlock(self.transformer_dim, num_transformer_heads, 90 | dim_feedforward=transformer_hidden_dim, dropout=0) 91 | for _ in range(num_transformer_layers)] 92 | self.task_transformer = nn.Sequential(*task_transformer_layers) 93 | self.task_mlp = fc_net(num_layers=1, in_size=self.transformer_dim, hidden_size=1024, out_proj_size=task_output_dim) 94 | 95 | # How many pixels should be decoded at a time when aggregating reconstructions for an entire image sequence 96 | self.inference_parallel_pixels = 10000 97 | 98 | # Dimension 
keys: B - batch size, T - number of frames, Y - image pixel height, X - image pixel width 99 | # U - number of spatial patches (height), V - number of spatial patches (width); U', V' - pooled patch grid height and width, with U'*V' = K 100 | # E - embedding dim, N - number of pixels to decode across all frames in sequence, K - number of object slots 101 | # S - viewpoint supervision dimension 102 | def forward(self, data): 103 | slot_tokens = self.get_slot_tokens(data) 104 | return self.decode_latents(data, slot_tokens) 105 | 106 | def get_slot_tokens(self, data): 107 | x = data['img_seq'] # \in B x T x Y x X x 3 108 | positional_embeddings = data['patch_positional_embeddings'] # \in B x T x U x V x S 109 | batch_size = x.shape[0] 110 | num_frame_slots = x.shape[1] 111 | 112 | # Encode the entire sequence of images 113 | x = self.encoder(x) # \in B x T x U x V x E-S 114 | x = torch.cat((x, positional_embeddings), -1).flatten(1, 3) # \in B x T*U*V x E 115 | 116 | # Transformers 117 | x = x.moveaxis(0, 1) # \in T*U*V x B x E 118 | x = self.transformer_1(x) 119 | # unflatten back to T x U x V x B x E 120 | x = x.reshape((num_frame_slots, self.hparams.spatial_patch_hw[0], self.hparams.spatial_patch_hw[1], batch_size, self.transformer_dim)) 121 | x = x.moveaxis(1, -1).moveaxis(1, -1) # \in T x B x E x U x V 122 | x = x.flatten(0, 1) # \in T*B x E x U x V 123 | x = self.spatial_pool(x).mul(2) # \in T*B x E x U' x V' 124 | x = x.reshape((num_frame_slots, batch_size) + x.shape[1:]) # \in T x B x E x U' x V' 125 | x = x.flatten(3, 4) # \in T x B x E x K 126 | x = x.moveaxis(-1, 1) # \in T x K x B x E 127 | x = x.flatten(0, 1) # \in T*K x B x E 128 | x = self.transformer_2(x) 129 | x = x.moveaxis(1, 0) # \in B x T*K x E 130 | x = x.reshape((batch_size, num_frame_slots, self.hparams.num_object_slots, self.transformer_dim)) # Uncollapse the patches 131 | 132 | # Aggregate along the frame latents and get the object latents 133 | slot_tokens = torch.mean(x, 1) # \in B x K x E 134 | return 
slot_tokens 135 | 136 | def
decode_latents(self, data, slot_tokens, eval=False): 137 | output = {} 138 | decoder_queries = data['decoder_queries'] # \in B x N x S 139 | object_latent_pars = self.latent_decoder(slot_tokens) 140 | object_latent_mean = object_latent_pars[..., :self.hparams.embed_dim] # \in B x K x E 141 | 142 | if eval: 143 | object_latents = object_latent_mean 144 | else: 145 | object_latent_var = nn.functional.softplus(object_latent_pars[..., self.hparams.embed_dim:]) 146 | # Sample object latents from gaussian distribution 147 | object_latent_distribution = Normal(object_latent_mean, object_latent_var) 148 | object_latents = object_latent_distribution.rsample() 149 | 150 | # Use queries and object latents to decode the selected pixels for loss calculations 151 | queries = decoder_queries.unsqueeze(1).tile(1, self.hparams.num_object_slots, 1, 1) # \in B x K x N x S 152 | x = self.query_decoder(object_latents, queries) # \in B x K x N x (M*4)+1 153 | per_object_preds = x[..., :3*self.hparams.num_gaussian_heads].unflatten(-1, (self.hparams.num_gaussian_heads, 3)) # \in B x K x N x M x 3 154 | per_mode_log_weights = nn.functional.log_softmax(x[..., 3*self.hparams.num_gaussian_heads : -1], -1) # \in B x K x N x M 155 | per_object_log_weights = nn.functional.log_softmax(x[..., -1], 1) # \in B x K x N 156 | 157 | # Independent gaussians for M values of R, M values of G, M values of B 158 | per_object_pixel_distributions = Normal(per_object_preds, self.hparams.sigma_x) 159 | ground_truth_rgb = data['ground_truth_rgb'].unsqueeze(1).unsqueeze(3) # \in B x 1 x N x 1 x 3 160 | 161 | per_object_pixel_log_likelihoods = per_object_pixel_distributions.log_prob(ground_truth_rgb) # \in B x K x N x M x 3 162 | # Sum across RGB because we assume the probabilities of each channel are independent, so log(P(R,G,B)) = log(P(R)P(G)P(B)) = log(P(R)) + log(P(G)) + log(P(B)) 163 | per_object_pixel_log_likelihoods = per_object_pixel_log_likelihoods.sum(-1) # \in B x K x N x M 164 | # First sum the 
likelihoods and weights - this is equivalent to multiplying them in regular space 165 | # Then apply logsumexp to add the weighted likelihoods in regular space and then convert back to log space 166 | weighted_mixture_log_likelihood = torch.logsumexp(per_object_pixel_log_likelihoods + per_mode_log_weights, -1) # \in B x K x N 167 | weighted_mixture_log_likelihood = torch.logsumexp(weighted_mixture_log_likelihood + per_object_log_weights, 1) # \in B x N 168 | 169 | # If using semantic segmentation to mask out the background for the reconstruction loss, do that here 170 | reconstruction_loss = -(weighted_mixture_log_likelihood).mean() 171 | output['reconstruction_loss'] = reconstruction_loss 172 | 173 | if self.hparams.bc_task and 'bc_waypoints' in data: 174 | task_tokens = self.task_transformer(slot_tokens.swapaxes(0, 1)) # \in K x B x E 175 | task_preds = self.task_mlp(task_tokens.mean(0)) # \in B x task_dim 176 | bc_mask = data['bc_mask'].unsqueeze(-1) 177 | targets = (data['bc_waypoints'] * bc_mask).flatten(1) 178 | preds = (task_preds.unflatten(-1, (-1, 2)) * bc_mask).flatten(1) 179 | bc_loss = nn.functional.smooth_l1_loss(preds, targets) 180 | output['bc_loss'] = bc_loss 181 | output['bc_waypoints'] = task_preds.unflatten(-1, (-1, 2)) 182 | 183 | if eval: 184 | per_mode_weights = torch.exp(per_mode_log_weights) # \in B x K x N x M 185 | unimodal_per_object_preds = per_object_preds.mul(per_mode_weights.unsqueeze(-1)).sum(-2) # \in B x K x N x 3 186 | output['per_object_preds'] = unimodal_per_object_preds 187 | per_object_weights = torch.exp(per_object_log_weights) # \in B x K x N 188 | output['per_object_weights'] = per_object_weights 189 | else: 190 | # Sum across the object latent dimension, and across all objects 191 | kl_loss = torch.distributions.kl_divergence(object_latent_distribution, self.kl_prior).sum((1,2)).mean() 192 | output['kl_loss'] = kl_loss 193 | 194 | return output 195 | 196 | def training_step(self, batch, batch_idx): 197 | output =
self(batch) 198 | self.log('reconstruction_loss', output['reconstruction_loss']) 199 | self.log('distribution_loss', output['kl_loss']) 200 | loss = (output['reconstruction_loss'] 201 | + output['kl_loss'].mul(self.hparams.beta)) 202 | 203 | if self.hparams.bc_task and 'bc_loss' in output: 204 | self.log('bc_loss', output['bc_loss']) 205 | loss += output['bc_loss'].mul(self.hparams.bc_loss_weight) 206 | 207 | self.log('total_loss', loss) 208 | return loss 209 | 210 | def predict_step(self, batch, batch_idx): 211 | return self.inference_and_metrics(batch) 212 | 213 | def inference_and_metrics(self, batch, batch_ind=0): 214 | num_pix = self.inference_parallel_pixels 215 | pixel_ind = 0 216 | (B, F, H, W, _) = batch['img_seq'].shape 217 | (B, total_num_pix) = batch['decoder_queries'].shape[:2] 218 | 219 | # Predicting all pixels at once consumes too much memory, 220 | # so only predict num_pix pixels at a time and aggregate the results 221 | per_object_preds = batch['img_seq'].new_zeros(B, self.hparams.num_object_slots, total_num_pix, 3) 222 | per_object_weights = batch['img_seq'].new_zeros(B, self.hparams.num_object_slots, total_num_pix) 223 | slot_tokens = self.get_slot_tokens(batch) 224 | while pixel_ind < total_num_pix: 225 | # Not a minibatch of sequences, just a chunk of this batch's pixels
226 | minibatch = {} 227 | minibatch['decoder_queries'] = batch['decoder_queries'][:, pixel_ind : pixel_ind + num_pix] 228 | minibatch['ground_truth_rgb'] = batch['ground_truth_rgb'][:, pixel_ind : pixel_ind + num_pix] 229 | if 'bc_waypoints' in batch: 230 | minibatch['bc_waypoints'] = batch['bc_waypoints'] 231 | minibatch['bc_mask'] = batch['bc_mask'] 232 | 233 | mini_output = self.decode_latents(minibatch, slot_tokens, eval=True) 234 | per_object_preds[:, :, pixel_ind : pixel_ind + num_pix] = mini_output['per_object_preds'].detach() 235 | per_object_weights[:, :, pixel_ind : pixel_ind + num_pix] = mini_output['per_object_weights'].detach() 236 | num_pix = min(num_pix, total_num_pix - pixel_ind) 237 | pixel_ind += num_pix 238 | 239 | preds = self.mixture_preds(per_object_preds, per_object_weights) 240 | greedy_preds = self.greedy_preds(per_object_preds, per_object_weights)[batch_ind].cpu().detach().numpy() 241 | 242 | pred_rgb = preds[batch_ind].cpu().detach().numpy() 243 | pred_weights_tensor = per_object_weights[batch_ind].cpu().detach() 244 | 245 | ground_truth_rgb = batch['ground_truth_rgb'][batch_ind].cpu().numpy() 246 | reconstruction_err = np.mean((pred_rgb - ground_truth_rgb)**2) 247 | 248 | decode_dims = tuple(batch['decode_dims'][batch_ind].cpu()) 249 | 250 | results_dict = dict( 251 | reconstruction_err = reconstruction_err, 252 | preds = pred_rgb, 253 | greedy_preds = greedy_preds, 254 | per_object_weights = pred_weights_tensor.numpy()) 255 | 256 | # Calculate segmentation metrics when ground truth instance segmentation is available 257 | if 'instance_oh' in batch: 258 | ground_truth_segmentation_tensor = batch['instance_oh'][batch_ind].cpu().unsqueeze(0) 259 | if len(ground_truth_segmentation_tensor.flatten()) > 0: 260 | pred_segmentation_tensor = pred_weights_tensor.swapaxes(0, 1).unsqueeze(0) 261 | seq_ari = adjusted_rand_index(ground_truth_segmentation_tensor, pred_segmentation_tensor).numpy().item() 262 | else: 263 | seq_ari =
float('nan') 264 | results_dict['seq_ari'] = seq_ari 265 | 266 | instance_mask = batch['instance_mask'][batch_ind].cpu().numpy() 267 | if np.any(instance_mask): 268 | instance_reconstruction_err = np.mean((pred_rgb[instance_mask] - ground_truth_rgb[instance_mask])**2) 269 | else: 270 | instance_reconstruction_err = float('nan') 271 | results_dict['instance_reconstruction_err'] = instance_reconstruction_err 272 | 273 | # Per-frame ARI 274 | frame_dims = (decode_dims[0], decode_dims[1]*decode_dims[2]) 275 | pred_weights_seq_tensor = pred_weights_tensor.reshape((self.hparams.num_object_slots,) + frame_dims) 276 | instance_mask_seq = instance_mask.reshape(frame_dims) 277 | ground_truth_segmentation_seq_tensor = ground_truth_segmentation_tensor.reshape(frame_dims + (-1,)) 278 | ari_frame = [] 279 | for i in range(decode_dims[0]): 280 | instance_mask = instance_mask_seq[i] 281 | if np.any(instance_mask): 282 | frame_weights_tensor = pred_weights_seq_tensor[:, i].swapaxes(0, 1).unsqueeze(0) 283 | ground_truth_segmentation_frame_tensor = ground_truth_segmentation_seq_tensor[i].unsqueeze(0) 284 | ari_frame.append(adjusted_rand_index(ground_truth_segmentation_frame_tensor, frame_weights_tensor).numpy().item()) 285 | else: 286 | ari_frame.append(float('nan')) 287 | results_dict['ari'] = np.nanmean(ari_frame) 288 | 289 | # Centroid distance metric 290 | T = F // len(self.hparams.cameras) 291 | C = len(self.hparams.cameras) 292 | gt_weights = batch['instance_oh'][batch_ind].reshape((T, C, H, W, -1)).moveaxis(-1, 0).cpu() 293 | pred_weights = per_object_weights[batch_ind].reshape((-1, T, C, H, W)).cpu() 294 | avg_centroid_dist = closest_centroids_metric(pred_weights, gt_weights) 295 | results_dict['avg_centroid_dist'] = avg_centroid_dist 296 | 297 | if 'bc_waypoints' in mini_output: 298 | results_dict['bc_waypoints'] = mini_output['bc_waypoints'] 299 | 300 | return results_dict 301 | 302 | def configure_optimizers(self): 303 | return optim.Adam(self.parameters(), 
lr=self.hparams.learning_rate) 304 | 305 | def mixture_preds(self, per_object_preds, per_object_weights): 306 | """ 307 | For each pixel, return the weighted average of the predictions across objects. 308 | """ 309 | return per_object_preds.mul(per_object_weights.unsqueeze(3)).sum(1) # \in B x N x 3 310 | 311 | def greedy_preds(self, per_object_preds, per_object_weights): 312 | """ 313 | For each pixel, return the prediction of the object mask with the highest weight. 314 | """ 315 | (num_batch, _, num_p, _) = per_object_preds.size() 316 | best_obj_ids = torch.argmax(per_object_weights, 1).reshape((num_batch, 1, num_p, 1)) 317 | preds = torch.gather(per_object_preds, 1, best_obj_ids.tile(1,1,1,3)).squeeze(1) 318 | return preds 319 | 320 | def reconstruct_image(self, preds, dims, idx=0): 321 | """ 322 | Show the reconstructed RGB image. 323 | """ 324 | img_arr = preds.reshape(dims + (3,))[idx] 325 | img_arr = np.clip(img_arr, 0, 1) * 255 326 | return img_arr.astype('uint8') 327 | 328 | def show_object_masks(self, per_object_weights, dims, idx=0): 329 | """ 330 | Show the predicted segmentation masks. 331 | """ 332 | best_obj_ids = np.argmax(per_object_weights, 0).reshape(dims)[idx] 333 | color_inds = np.mod(best_obj_ids, len(MASK_COLORS)) 334 | mask_img = MASK_COLORS[color_inds.flatten()].reshape(dims[1:] + (3,)) 335 | return mask_img.astype('uint8') 336 | 337 | def show_object_masks_foreground(self, per_object_weights, foreground_seg, dims, idx=0): 338 | """ 339 | Show the predicted segmentation masks only for pixels that belong to ground-truth objects. 
        """
        mask_img = self.show_object_masks(per_object_weights, dims, idx=idx)
        frame_foreground_seg = foreground_seg.reshape(dims)[idx]
        background_inds = np.logical_not(frame_foreground_seg.flatten())
        mask_arr_flat = mask_img.reshape(-1, 3)
        mask_arr_flat[background_inds, :] = [0,0,0]
        mask_img = mask_arr_flat.reshape(mask_img.shape)
        return mask_img

    def show_ground_truth_seg(self, instance_oh, dims, idx=0):
        """
        Show the ground-truth object segmentation.
        """
        frame_instance_oh = instance_oh.reshape(dims + (-1,))[idx]
        instance_seg_flat = np.zeros(dims[1:], dtype='uint8').flatten()
        n_total_ground_truth_obj = frame_instance_oh.shape[-1]
        frame_instance_oh_flat = frame_instance_oh.reshape(-1, n_total_ground_truth_obj)
        for i in range(n_total_ground_truth_obj):
            mask = np.where(frame_instance_oh_flat[:, i] == 1)
            instance_seg_flat[mask] = i + 1
        color_inds = np.mod(instance_seg_flat, len(MASK_COLORS))
        colors = np.concatenate(([[0,0,0]], MASK_COLORS))
        mask_img = colors[color_inds.flatten()].reshape(dims[1:] + (3,))
        return mask_img

    def show_pixel_scores(self, per_object_weights, instance_oh, dims, idx=0):
        """
        Assign a segmentation quality score to each pixel belonging to a ground-truth object, and
        plot.
        """
        frame_pred_seg = np.argmax(per_object_weights, 0).reshape(dims)[idx]
        frame_instance_oh = instance_oh.reshape(dims + (-1,)).astype('uint8')[idx]
        frame_instance_seg_flat = np.zeros(dims[1:], dtype='uint8').flatten()
        n_total_ground_truth_obj = frame_instance_oh.shape[-1]
        frame_instance_oh_flat = frame_instance_oh.reshape(-1, n_total_ground_truth_obj)
        for i in range(n_total_ground_truth_obj):
            mask = np.where(frame_instance_oh_flat[:, i] == 1)
            # np.where returns a tuple of index arrays, so test the array, not the tuple
            if mask[0].size > 0:
                frame_instance_seg_flat[mask] = i

        foreground_inds = np.where(frame_instance_seg_flat != 0)
        ground_truth_seg_foreground = frame_instance_seg_flat[foreground_inds]
        pred_seg_foreground = frame_pred_seg.flatten()[foreground_inds]
        score_img = np.zeros((dims[1]*dims[2], 3))
        n_foreground_pixels = len(foreground_inds[0])
        if n_foreground_pixels > 0:

            pixel_scores = np.zeros(n_foreground_pixels)
            for pixel in range(n_foreground_pixels):
                pred_class = pred_seg_foreground[pixel]
                pred_pairs = pred_seg_foreground == pred_class
                gt_class = ground_truth_seg_foreground[pixel]
                gt_pairs = ground_truth_seg_foreground == gt_class
                true_pos = np.sum(pred_pairs & gt_pairs)
                false_pos = np.sum(pred_pairs & ~gt_pairs)
                true_neg = np.sum(~pred_pairs & ~gt_pairs)
                false_neg = np.sum(~pred_pairs & gt_pairs)
                pixel_rand = (true_pos + true_neg) / (true_pos + true_neg + false_pos + false_neg)
                pixel_scores[pixel] = pixel_rand

            norm = pltcolors.Normalize(vmin=0., vmax=1.)
            cmap = pltcm.ScalarMappable(norm=norm, cmap='cool')
            score_rgbs = [cmap.to_rgba(score)[:-1] for score in pixel_scores]
            score_img[foreground_inds] = score_rgbs

        return score_img.reshape(dims[1:] + (3,))
--------------------------------------------------------------------------------
/nets.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

def fc_block(in_features, out_features, dropout_prob=None):
    layers = [nn.Linear(in_features, out_features)]
    if dropout_prob:
        layers.append(nn.Dropout(p=dropout_prob))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

def fc_net(num_layers, in_size, hidden_size, out_proj_size=None, dropout_prob=None):
    layers = [
        fc_block(in_size if l == 0 else hidden_size, hidden_size, dropout_prob)
        for l in range(num_layers)]
    if out_proj_size is not None:
        layers.append(nn.Linear(hidden_size, out_proj_size))
    return nn.Sequential(*layers)

class CNNEncoder(nn.Module):
    def __init__(self, img_dim_hw, output_dim, embed_dim=128, stride=2, kernel_size=4, num_conv_layers=4):
        super(CNNEncoder, self).__init__()
        # Each conv halves H and W exactly, so both must be divisible by 2**num_conv_layers
        # Assumes stride=2, kernel_size=4, padding=1
        assert img_dim_hw[0] % 2**num_conv_layers == 0
        assert img_dim_hw[1] % 2**num_conv_layers == 0
        self.layers = []
        self.layers.append(nn.Conv2d(3, embed_dim, kernel_size, stride=stride, padding=1))
        self.layers.append(nn.ReLU())
        for _ in range(num_conv_layers - 1):
            self.layers.append(torch.nn.Conv2d(embed_dim, embed_dim, kernel_size, stride=stride, padding=1))
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(embed_dim, output_dim))
        self.layers.append(nn.ReLU())

        self.layers = nn.ModuleList(self.layers)

    # input \in B x T x H x W x 3
    # output \in B x T x U x V x E
    def forward(self, x):
        batch_size = x.shape[0]
        num_frames = x.shape[1]

        # nn.Conv2d expects B x C x H x W, so collapse frames and batches, and move channels first
        x = x.flatten(0, 1) # B*T x H x W x C

        x = x.moveaxis(-1, 1) # B*T x C x H x W
        for layer in self.layers[:-2]:
            x = layer(x)

        # Move back from ((B*T) x C x U x V) to (B x T x U x V x C)
        x = x.moveaxis(1, -1) # B*T x U x V x C
        x = x.reshape((batch_size, num_frames) + x.shape[1:])
        for layer in self.layers[-2:]:
            x = layer(x)

        return x

# Input should have the format seq_len x batch_size x d_model
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation = nn.functional.relu, layer_norm_eps: float = 1e-5, norm_first: bool = False) -> None:
        super(TransformerBlock, self).__init__()

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if isinstance(activation, str):
            if activation == 'relu':
                self.activation = nn.functional.relu
            elif activation == 'gelu':
                self.activation = nn.functional.gelu
            else:
                raise ValueError("activation should be relu/gelu, not {}".format(activation))
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x))
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x))
            x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.self_attn(x, x, x, need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

class QueryDecoder(nn.Module):
    def __init__(
            self, input_size, query_size, hidden_size, output_size, num_hidden_layers,
            query_scale=1.0, output_scale=1.0):
        super(QueryDecoder, self).__init__()

        # we handle arbitrary output shapes, or just a scalar number of output dimensions
        output_size = [output_size] if np.isscalar(output_size) else output_size
        self.output_size = output_size
        num_output_dims = np.prod(output_size)

        # add the main fully connected layers
        next_input_size = input_size + query_size
        hidden_layers = []
        for _ in range(num_hidden_layers):
            hidden_layers.append(fc_block(next_input_size, hidden_size))
            next_input_size = hidden_size
        self.hiddens = nn.Sequential(*hidden_layers)

        # final linear projection to the target number of output dimensions
        self.final_project = nn.Linear(next_input_size, num_output_dims)
        # register buffers to store the query and output scale factors
        self.register_buffer('query_scale', torch.tensor(query_scale))
        self.register_buffer('output_scale', torch.tensor(output_scale))

    def forward(self, z, query):
        # z.shape (B, ..., L), query.shape (B, ..., N, Q)
        query = query * self.query_scale
        num_queries = query.shape[-2]
        # replicate z over all queries (B, ..., L) --> (B, ..., N, L)
        z_query_tiled = torch.stack([z]*num_queries, -2)
        out = torch.cat([z_query_tiled, query], -1) # concatenate z and query on last axis
        out = self.hiddens(out) # apply main hidden layers
        out = self.final_project(out) # apply final output projection
        out = out.unflatten(-1, self.output_size) # reshape last axis to target output_size
        return out * self.output_scale
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.13.1+cu116
imageio[ffmpeg]==2.27.0
matplotlib==3.7.1
numpy==1.22.4
pytorch-lightning==1.8.6
scipy==1.10.1
tensorflow==2.11.0
torchdata==0.5.1

--------------------------------------------------------------------------------
/segmentation_metrics.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

from scipy.optimize import linear_sum_assignment

def adjusted_rand_index(gt_mask_oh, pred_mask_prob):
    """
    Compute the adjusted Rand index (ARI). This ignores the special case where there is only a
    single ground-truth object and will return NaN in this case.
    """
    num_pred_instances = pred_mask_prob.shape[-1]
    gt_mask_oh = gt_mask_oh.type(torch.float32)
    pred_instance_ids = torch.argmax(pred_mask_prob, dim=-1)
    pred_mask_oh = torch.nn.functional.one_hot(pred_instance_ids, num_pred_instances).type(torch.float32)
    num_points = gt_mask_oh.sum(dim=[1, 2])
    nij = torch.einsum('bji,bjk->bki', pred_mask_oh, gt_mask_oh)
    a = nij.sum(dim=1)
    b = nij.sum(dim=2)
    r_idx = torch.sum(nij * (nij - 1), dim=[1, 2])
    a_idx = torch.sum(a * (a - 1), dim=1)
    b_idx = torch.sum(b * (b - 1), dim=1)
    expected_r_idx = (a_idx * b_idx) / (num_points * (num_points - 1))
    max_r_idx = (a_idx + b_idx) / 2
    ari = (r_idx - expected_r_idx) / (max_r_idx - expected_r_idx)
    return ari

def centroid_distance(pred_trace, gt_trace):
    dist = 0
    pred_trace = pred_trace.reshape(-1, 2)
    gt_trace = gt_trace.reshape(-1, 2)
    num_frames_with_gt_obj = 0
    for t in range(pred_trace.shape[0]):
        # No penalty if the object isn't in the frame
        if torch.any(torch.isnan(gt_trace[t])):
            dist += 0
        # Maximum penalty if there is no predicted object in the frame (this almost never happens with soft weights)
        elif torch.any(torch.isnan(pred_trace[t])):
            num_frames_with_gt_obj += 1
            dist += np.linalg.norm([2, 2])
        else:
            num_frames_with_gt_obj += 1
            dist += np.linalg.norm(pred_trace[t] - gt_trace[t])
    return (dist, num_frames_with_gt_obj)

# Assumes the last 2 dimensions are H and W; returns (y, x) centroids in [-1, 1] image coordinates
def get_centroids(weights):
    (H, W) = weights.shape[-2:]
    total_ob_weights = torch.sum(weights.flatten(-2), -1)
    xs = torch.tile(torch.arange(W) * (2 / (W-1)) - 1, (H, 1))
    ys = torch.tile((torch.arange(H) * (2 / (H-1)) - 1).unsqueeze(1), (1, W))
    # Weighted average of the x (y) coordinates of every pixel in the mask
    x_centroids = torch.sum((weights * xs).flatten(-2), -1) / total_ob_weights
    y_centroids = torch.sum((weights * ys).flatten(-2), -1) / total_ob_weights
    centroids = torch.cat((y_centroids.unsqueeze(-1), x_centroids.unsqueeze(-1)), -1)
    return centroids

def get_centroid_matches(pred_weights, gt_weights):
    (_, T, C, H, W) = gt_weights.shape

    # Only consider large enough objects
    total_area = T*C*H*W
    area_threshold = 0.005 / C # 0.5% of one frame as in SAVi++
    large_gt_obj = []
    large_gt_obj_inds = []
    for (i, gt_obj) in enumerate(range(gt_weights.shape[0])):
        total_obj_area = torch.sum(gt_weights[gt_obj].flatten(), -1)
        total_ratio = total_obj_area / total_area
        if total_ratio >= area_threshold:
            large_gt_obj.append(gt_obj)
            large_gt_obj_inds.append(i)

    # Only consider predicted slots that aren't "empty"
    argmaxed_pred_obj = np.unique(np.argmax(pred_weights, 0))

    num_gt_obj = len(large_gt_obj)
    num_pred_obj = len(argmaxed_pred_obj)
    obj_dists = np.zeros((num_gt_obj, num_pred_obj))
    # TODO: only calculate centroids for filtered objects/slots
    pred_centroids = get_centroids(pred_weights)
    gt_centroids = get_centroids(gt_weights)

    # Distance between the centroids of each gt/pred object pair, summed across frames
    # and normalised by the maximum possible distance over those frames
    for i, gt_obj in enumerate(large_gt_obj):
        for j, pred_obj in enumerate(argmaxed_pred_obj):
            gt_trace = gt_centroids[gt_obj]
            pred_trace = pred_centroids[pred_obj]
            (dist, num_frames_with_gt_obj) = centroid_distance(pred_trace, gt_trace)
            max_dist = np.linalg.norm([2, 2]) * num_frames_with_gt_obj
            obj_dists[i, j] = dist / max_dist

    (row_ind, col_ind) = linear_sum_assignment(obj_dists)
    best_dists = obj_dists[row_ind, col_ind]
    gt_inds = np.array(large_gt_obj_inds)
    pred_inds = col_ind
    return (best_dists, pred_inds, gt_inds)

def closest_centroids_metric(pred_weights, gt_weights):
    best_dists = get_centroid_matches(pred_weights, gt_weights)[0]
    return np.mean(best_dists)
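To make the pair-counting in `adjusted_rand_index` easier to check by hand, here is a small NumPy restatement for hard labelings. It is a sketch, not part of the repository: the name `ari_from_labels` is hypothetical, and it assumes every point carries a ground-truth label, whereas the torch version counts only pixels whose one-hot ground-truth row is non-zero. Like the original, it counts ordered pairs `n * (n - 1)` instead of "n choose 2"; the factor of 2 cancels in the final ratio.

```python
import numpy as np

def ari_from_labels(gt_labels, pred_labels):
    """Adjusted Rand index between two hard labelings of the same points (hypothetical helper)."""
    # Map arbitrary label values to contiguous ids 0..K-1
    gt_ids = np.unique(np.asarray(gt_labels), return_inverse=True)[1].ravel()
    pred_ids = np.unique(np.asarray(pred_labels), return_inverse=True)[1].ravel()
    n = gt_ids.size

    # Contingency table: nij[p, g] counts points with predicted id p and ground-truth id g
    nij = np.zeros((pred_ids.max() + 1, gt_ids.max() + 1))
    np.add.at(nij, (pred_ids, gt_ids), 1)
    pred_sizes = nij.sum(axis=1)  # points per predicted cluster
    gt_sizes = nij.sum(axis=0)    # points per ground-truth cluster

    # Ordered-pair counts, mirroring r_idx, a_idx and b_idx in adjusted_rand_index
    r_idx = np.sum(nij * (nij - 1))
    a_idx = np.sum(pred_sizes * (pred_sizes - 1))
    b_idx = np.sum(gt_sizes * (gt_sizes - 1))
    expected_r_idx = a_idx * b_idx / (n * (n - 1))
    max_r_idx = (a_idx + b_idx) / 2

    # 0/0 (e.g. both labelings put every point in one cluster) yields NaN with a RuntimeWarning
    return (r_idx - expected_r_idx) / (max_r_idx - expected_r_idx)
```

ARI is invariant to relabelling, so `ari_from_labels([0, 0, 1, 1], [1, 1, 0, 0])` is 1.0, while the crossed assignment `[0, 1, 0, 1]` scores below chance (-0.5). Applied to the model, the predicted labels would be the per-pixel argmax over slot weights, which is what `adjusted_rand_index` computes internally before building its contingency table.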
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import matplotlib.colors as pltcolors
import numpy as np
import os
import re
import warnings

mask_color_names = ['purple', 'blue', 'pink', 'red', 'orange', 'teal', 'magenta', 'olive', 'ecru', 'yellow', 'lilac', 'peach', 'pale green', 'sky blue', 'white', 'mustard',
    'grey', 'cyan', 'light brown', 'bright pink', 'ice blue', 'dark green', 'mauve', 'dark red', 'red orange', 'greyish purple', 'neon purple', 'cobalt',
    'medium blue', 'clay', 'avocado', 'pinky red', 'orange yellow', 'ivory', 'wheat', 'shamrock green', 'pear', 'ultramarine blue', 'greeny brown',
    'very light pink', 'carnation', 'dusty red', 'petrol', 'pumpkin orange', 'saffron', 'greenish turquoise', 'light khaki', 'bluey grey', 'hazel',
    'topaz', 'light pea green', 'battleship grey', 'deep brown', 'bruise', 'dark cream', 'stormy blue', 'orange pink', 'candy pink', 'bland', 'macaroni and cheese',
    'cloudy blue', 'snot', 'auburn', 'strawberry']
MASK_COLORS = [np.array(pltcolors.to_rgb(f'xkcd:{color_name}')) * 255 for color_name in mask_color_names]
MASK_COLORS = np.array(MASK_COLORS, dtype='uint8')

def parse_train_step(ckpt_name):
    try:
        # Raw string avoids the invalid escape sequence warning for '\D'
        train_step = int(re.split(r'\D', ckpt_name.split('step=')[1], maxsplit=1)[0])
    except (IndexError, ValueError):
        # No 'step=' in the name, or no digits after it
        train_step = 0
    return train_step

def get_checkpoint_path(checkpoint_dir):
    """
    Given a directory containing model checkpoints, return the path of the one with the highest number of train steps.
    """
    checkpoint_fnames = [fname for fname in os.listdir(checkpoint_dir) if fname.endswith('.ckpt')]
    if not checkpoint_fnames:
        raise FileNotFoundError(f'No checkpoints found in {checkpoint_dir}')

    best_train_step = 0
    for fname in checkpoint_fnames:
        train_step = parse_train_step(fname)
        if train_step > best_train_step:
            best_train_step = train_step
            best_checkpoint_fname = fname

    if best_train_step == 0:
        warnings.warn('Failed to parse the train step from the checkpoint filename(s); the most recent checkpoint may not be loaded.',
                      RuntimeWarning)
        best_checkpoint_fname = checkpoint_fnames[0]

    checkpoint_fname = best_checkpoint_fname
    train_step = best_train_step

    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_fname)

    return (checkpoint_path, train_step)

def fourier_embeddings(data, num_freqs=10, max_sampling_rate=60):
    freqs = np.linspace(1, max_sampling_rate, num_freqs) * (np.pi/2)
    num_embeds = 2*num_freqs + 1
    output = np.zeros(data.shape + (num_embeds,))
    output[..., 0] = data
    for (ind, freq) in enumerate(freqs):
        output[..., 2*ind + 1] = np.sin(freq*data)
        output[..., 2*ind + 2] = np.cos(freq*data)

    # Merge the input's last dimension with the embedding dimension: (..., D, 2F+1) -> (..., D*(2F+1))
    return output.reshape(data.shape[:-1] + (-1,))
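`fourier_embeddings` keeps each raw input value and appends one sin/cos pair per frequency, in the order `[x, sin(f0 x), cos(f0 x), sin(f1 x), ...]`, before merging the last two axes. The sketch below restates it with broadcasting instead of a Python loop over frequencies. The names `fourier_embeddings_loop` (a verbatim copy of the function above, kept for comparison) and `fourier_embeddings_vectorised` are hypothetical, and the vectorised form is offered as an equivalent, not as what the repository uses.

```python
import numpy as np

def fourier_embeddings_loop(data, num_freqs=10, max_sampling_rate=60):
    # Copy of util.fourier_embeddings, used here as the reference
    freqs = np.linspace(1, max_sampling_rate, num_freqs) * (np.pi / 2)
    output = np.zeros(data.shape + (2 * num_freqs + 1,))
    output[..., 0] = data
    for (ind, freq) in enumerate(freqs):
        output[..., 2 * ind + 1] = np.sin(freq * data)
        output[..., 2 * ind + 2] = np.cos(freq * data)
    return output.reshape(data.shape[:-1] + (-1,))

def fourier_embeddings_vectorised(data, num_freqs=10, max_sampling_rate=60):
    freqs = np.linspace(1, max_sampling_rate, num_freqs) * (np.pi / 2)  # (F,)
    scaled = data[..., None] * freqs                                    # (..., D, F)
    # Stack sin/cos on a new last axis, then flatten so they interleave per frequency
    sincos = np.stack([np.sin(scaled), np.cos(scaled)], axis=-1)        # (..., D, F, 2)
    sincos = sincos.reshape(data.shape + (2 * num_freqs,))              # (..., D, 2F)
    out = np.concatenate([data[..., None], sincos], axis=-1)            # (..., D, 2F + 1)
    # Merge the input's last dimension with the embedding dimension
    return out.reshape(data.shape[:-1] + (-1,))
```

For example, a batch of 3-D query coordinates of shape `(4, 5, 3)` embeds to shape `(4, 5, 63)` with the default 10 frequencies, because each of the 3 input values expands to `2 * 10 + 1 = 21` features. The vectorised version computes all sines and cosines in two array calls, which mainly matters when embedding many decoder queries at once.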