├── ocl ├── __init__.py ├── models │ ├── __init__.py │ ├── savi.py │ └── savi_with_memory.py ├── config │ ├── memory.py │ ├── predictor.py │ ├── perceptual_groupings.py │ ├── __init__.py │ ├── neural_networks.py │ ├── optimizers.py │ ├── feature_extractors.py │ ├── datasets.py │ ├── conditioning.py │ ├── utils.py │ ├── metrics.py │ └── plugins.py ├── cli │ ├── eval_utils.py │ ├── cli_utils.py │ ├── compute_dataset_size.py │ ├── eval.py │ └── train.py ├── path_defaults.py ├── matching.py ├── hooks.py ├── consistency.py ├── visualization_types.py ├── distillation.py ├── scheduling.py ├── tree_utils.py ├── combined_model.py ├── memory_rollout.py ├── neural_networks.py ├── losses.py ├── mha.py ├── base.py ├── predictor.py ├── conditioning.py └── visualizations.py ├── NOTICE ├── configs ├── .DS_Store ├── experiment │ ├── .DS_Store │ ├── _output_path.yaml │ ├── OC-MOT │ │ ├── _cater_bbox_mot_preprocessing.yaml │ │ ├── cater_eval.yaml │ │ └── cater.yaml │ └── SAVi │ │ ├── _cater_bbox_mot_preprocessing.yaml │ │ └── cater.yaml └── dataset │ └── cater.yaml ├── srcs ├── cater_demo.gif ├── framework.png └── real_world_demo.gif ├── CODE_OF_CONDUCT.md ├── setup.cfg ├── pyproject.toml ├── requirements.txt ├── CONTRIBUTING.md ├── README.md └── LICENSE /ocl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/object-centric-multiple-object-tracking/HEAD/configs/.DS_Store -------------------------------------------------------------------------------- /srcs/cater_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/object-centric-multiple-object-tracking/HEAD/srcs/cater_demo.gif -------------------------------------------------------------------------------- /srcs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/object-centric-multiple-object-tracking/HEAD/srcs/framework.png -------------------------------------------------------------------------------- /srcs/real_world_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/object-centric-multiple-object-tracking/HEAD/srcs/real_world_demo.gif -------------------------------------------------------------------------------- /configs/experiment/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/object-centric-multiple-object-tracking/HEAD/configs/experiment/.DS_Store -------------------------------------------------------------------------------- /ocl/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models defined in code.""" 2 | from ocl.models.savi import SAVi 3 | from ocl.models.savi_with_memory import SAVi_mem 4 | __all__ = ["SAVi", "SAVi_mem"] 5 | -------------------------------------------------------------------------------- 
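For orientation, the two model classes exported above are assembled from the sub-modules listed in the tree (conditioning, feature extractor, perceptual grouping, decoder, transition model; see `ocl/models/savi.py` further down). The following is a minimal, illustrative sketch of that composition only: the `nn.Identity` placeholders are stand-ins, and the component names in the comments are taken from the experiment configs later in this dump, not a runnable training setup.

```python
# Minimal sketch of how ocl.models.SAVi is assembled from sub-modules. The
# nn.Identity placeholders are stand-ins for illustration only; in the
# experiments these slots are filled from the Hydra configs (e.g.
# ocl.conditioning.LearntConditioning, ocl.feature_extractors.SAViFeatureExtractor,
# ocl.perceptual_grouping.SlotAttentionGrouping, ocl.decoding.SlotAttentionDecoder).
from torch import nn

from ocl.models import SAVi

model = SAVi(
    conditioning=nn.Identity(),
    feature_extractor=nn.Identity(),
    perceptual_grouping=nn.Identity(),
    decoder=nn.Identity(),
    transition_model=nn.Identity(),
)
print(type(model).__name__)  # -> "SAVi"
```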
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /configs/experiment/_output_path.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra 2 | 3 | run: 4 | dir: ${oc.select:experiment.root_output_folder,outputs}/${hydra:runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S} 5 | sweep: 6 | dir: ${oc.select:experiment.root_output_folder,multirun} 7 | subdir: ${hydra:runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S} 8 | output_subdir: config 9 | -------------------------------------------------------------------------------- /configs/dataset/cater.yaml: -------------------------------------------------------------------------------- 1 | # Video dataset CATER based on https://github.com/deepmind/multi_object_datasets . 2 | defaults: 3 | - webdataset 4 | train_shards: ${dataset_prefix:"cater_with_masks/train/shard-{000000..000152}.tar"} 5 | train_size: 35427 6 | val_shards: ${dataset_prefix:"cater_with_masks/val/shard-{000000..000016}.tar"} 7 | val_size: 50 #3937 8 | test_shards: ${dataset_prefix:"cater_with_masks/test/shard-{000000..000073}.tar"} 9 | test_size: 17100 10 | -------------------------------------------------------------------------------- /ocl/config/memory.py: -------------------------------------------------------------------------------- 1 | """Perceptual grouping models.""" 2 | import dataclasses 3 | 4 | from hydra_zen import builds 5 | 6 | from ocl import memory 7 | 8 | 9 | @dataclasses.dataclass 10 | class MemoryConfig: 11 | """Configuration class of Predictor.""" 12 | 13 | 14 | TransitionConfig = builds( 15 | memory.SelfSupervisedMemory, 16 | builds_bases=(MemoryConfig,), 17 | populate_full_signature=True, 18 | ) 19 | 20 | 21 | def register_configs(config_store): 22 | config_store.store(group="schemas", name="memory", node=MemoryConfig) 23 | config_store.store(group="memory", name="mem", node=TransitionConfig) 24 | -------------------------------------------------------------------------------- /ocl/cli/eval_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytorch_lightning 4 | 5 | from ocl.cli import train 6 | 7 | 8 | def build_from_train_config( 9 | config: train.TrainingConfig, checkpoint_path: Optional[str], seed: bool = True 10 | ): 11 | if seed: 12 | pytorch_lightning.seed_everything(config.seed, workers=True) 13 | 14 | pm = train.create_plugin_manager() 15 | datamodule = train.build_and_register_datamodule_from_config(config, pm.hook, pm) 16 | train.build_and_register_plugins_from_config(config, pm) 17 | model = train.build_model_from_config(config, pm.hook, checkpoint_path) 18 | 19 | return datamodule, model, pm 20 | -------------------------------------------------------------------------------- /ocl/config/predictor.py: -------------------------------------------------------------------------------- 1 | """Perceptual grouping models.""" 2 | import dataclasses 3 | 4 | from hydra_zen import builds 5 | 6 | from ocl import predictor 7 | 8 | 9 | @dataclasses.dataclass 10 | 
class PredictorConfig: 11 | """Configuration class of Predictor.""" 12 | 13 | 14 | TransitionConfig = builds( 15 | predictor.TransformerBlock, 16 | builds_bases=(PredictorConfig,), 17 | populate_full_signature=True, 18 | ) 19 | 20 | 21 | 22 | def register_configs(config_store): 23 | config_store.store(group="schemas", name="perceptual_grouping", node=PredictorConfig) 24 | config_store.store(group="perceptual_grouping", name="slot_attention", node=TransitionConfig) 25 | 26 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select= 3 | # F: errors from pyflake 4 | F, 5 | # W, E: warnings/errors from pycodestyle (PEP8) 6 | W, E, 7 | # I: problems with imports 8 | I, 9 | # B: bugbear warnings ("likely bugs and design problems") 10 | B, 11 | # D: docstring warnings from pydocstyle 12 | D 13 | ignore= 14 | # E203: whitespace before ':' (incompatible with black) 15 | E203, 16 | # E731: do not use a lambda expression, use a def (local def is often ugly) 17 | E731, 18 | # W503: line break before binary operator (incompatible with black) 19 | W503, 20 | # D1: docstring warnings related to missing documentation 21 | D1 22 | max-line-length = 101 23 | ban-relative-imports = true 24 | docstring-convention = google 25 | exclude = .*,__pycache__,./outputs 26 | -------------------------------------------------------------------------------- /ocl/path_defaults.py: -------------------------------------------------------------------------------- 1 | """Default paths for different types of inputs. 2 | 3 | These are only defined for convenience and can also be overwritten using the appropriate *_path 4 | constructor variables of RoutableMixin subclasses. 5 | """ 6 | INPUT = "input" 7 | VIDEO = f"{INPUT}.image" 8 | BOX = f"{INPUT}.instance_bbox" 9 | MASK = f"{INPUT}.mask" 10 | ID = f"{INPUT}.instance_id" 11 | BATCH_SIZE = f"{INPUT}.batch_size" 12 | GLOBAL_STEP = "global_step" 13 | FEATURES = "feature_extractor" 14 | CONDITIONING = "conditioning" 15 | # TODO(hornmax): Currently decoders are nested in the task and accept PerceptualGroupingOutput as 16 | # input. In the future this will change and decoders should just be regular parts of the model. 
17 | OBJECTS = "perceptual_grouping.objects" 18 | FEATURE_ATTRIBUTIONS = "perceptual_grouping.feature_attributions" 19 | OBJECT_DECODER = "object_decoder" 20 | -------------------------------------------------------------------------------- /ocl/config/perceptual_groupings.py: -------------------------------------------------------------------------------- 1 | """Perceptual grouping models.""" 2 | import dataclasses 3 | 4 | from hydra_zen import builds 5 | 6 | from ocl import perceptual_grouping 7 | 8 | 9 | @dataclasses.dataclass 10 | class PerceptualGroupingConfig: 11 | """Configuration class of perceptual grouping models.""" 12 | 13 | 14 | SlotAttentionConfig = builds( 15 | perceptual_grouping.SlotAttentionGrouping, 16 | builds_bases=(PerceptualGroupingConfig,), 17 | populate_full_signature=True, 18 | ) 19 | StickBreakingGroupingConfig = builds( 20 | perceptual_grouping.StickBreakingGrouping, 21 | builds_bases=(PerceptualGroupingConfig,), 22 | populate_full_signature=True, 23 | ) 24 | 25 | 26 | def register_configs(config_store): 27 | config_store.store(group="schemas", name="perceptual_grouping", node=PerceptualGroupingConfig) 28 | config_store.store(group="perceptual_grouping", name="slot_attention", node=SlotAttentionConfig) 29 | config_store.store( 30 | group="perceptual_grouping", name="stick_breaking", node=StickBreakingGroupingConfig 31 | ) 32 | -------------------------------------------------------------------------------- /ocl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from hydra.core.config_store import ConfigStore 2 | from omegaconf import OmegaConf 3 | 4 | from ocl.config import ( 5 | conditioning, 6 | datasets, 7 | feature_extractors, 8 | metrics, 9 | neural_networks, 10 | optimizers, 11 | perceptual_groupings, 12 | plugins, 13 | predictor, 14 | memory, 15 | utils, 16 | ) 17 | 18 | config_store = ConfigStore.instance() 19 | 20 | conditioning.register_configs(config_store) 21 | 22 | datasets.register_configs(config_store) 23 | datasets.register_resolvers(OmegaConf) 24 | 25 | feature_extractors.register_configs(config_store) 26 | 27 | metrics.register_configs(config_store) 28 | 29 | neural_networks.register_configs(config_store) 30 | 31 | optimizers.register_configs(config_store) 32 | 33 | perceptual_groupings.register_configs(config_store) 34 | predictor.register_configs(config_store) 35 | memory.register_configs(config_store) 36 | 37 | plugins.register_configs(config_store) 38 | plugins.register_resolvers(OmegaConf) 39 | 40 | utils.register_configs(config_store) 41 | utils.register_resolvers(OmegaConf) 42 | -------------------------------------------------------------------------------- /ocl/cli/cli_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from hydra.core.hydra_config import HydraConfig 5 | 6 | 7 | def get_commandline_config_path(): 8 | """Get the path of a config path specified on the command line.""" 9 | hydra_cfg = HydraConfig.get() 10 | config_sources = hydra_cfg.runtime.config_sources 11 | config_path = None 12 | for source in config_sources: 13 | if source.schema == "file" and source.provider == "command-line": 14 | config_path = source.path 15 | break 16 | return config_path 17 | 18 | 19 | def find_checkpoint(path): 20 | """Find checkpoint in output path of previous run.""" 21 | checkpoints = glob.glob( 22 | os.path.join(path, "lightning_logs", "version_*", "checkpoints", "*.ckpt") 23 | ) 24 | checkpoints.sort() 25 | # 
Return the last checkpoint. 26 | # TODO (hornmax): If more than one checkpoint is stored this might not lead to the most recent 27 | # checkpoint being loaded. Generally, I think this is ok as we still allow people to set the 28 | # checkpoint manually. 29 | return checkpoints[-1] 30 | -------------------------------------------------------------------------------- /ocl/config/neural_networks.py: -------------------------------------------------------------------------------- 1 | """Configs for neural networks.""" 2 | import omegaconf 3 | from hydra_zen import builds 4 | 5 | from ocl import neural_networks 6 | 7 | MLPBuilderConfig = builds( 8 | neural_networks.build_mlp, 9 | features=omegaconf.MISSING, 10 | zen_partial=True, 11 | populate_full_signature=True, 12 | ) 13 | TransformerEncoderBuilderConfig = builds( 14 | neural_networks.build_transformer_encoder, 15 | n_layers=omegaconf.MISSING, 16 | n_heads=omegaconf.MISSING, 17 | zen_partial=True, 18 | populate_full_signature=True, 19 | ) 20 | TransformerDecoderBuilderConfig = builds( 21 | neural_networks.build_transformer_decoder, 22 | n_layers=omegaconf.MISSING, 23 | n_heads=omegaconf.MISSING, 24 | zen_partial=True, 25 | populate_full_signature=True, 26 | ) 27 | 28 | 29 | def register_configs(config_store): 30 | config_store.store(group="neural_networks", name="mlp", node=MLPBuilderConfig) 31 | config_store.store( 32 | group="neural_networks", name="transformer_encoder", node=TransformerEncoderBuilderConfig 33 | ) 34 | config_store.store( 35 | group="neural_networks", name="transformer_decoder", node=TransformerDecoderBuilderConfig 36 | ) 37 | -------------------------------------------------------------------------------- /ocl/config/optimizers.py: -------------------------------------------------------------------------------- 1 | """Pytorch optimizers.""" 2 | import dataclasses 3 | 4 | import torch.optim 5 | from hydra_zen import make_custom_builds_fn 6 | 7 | 8 | @dataclasses.dataclass 9 | class OptimizerConfig: 10 | pass 11 | 12 | 13 | # TODO(hornmax): We cannot automatically extract type information from the torch SGD implementation, 14 | # thus we define it manually here. 
15 | @dataclasses.dataclass 16 | class SGDConfig(OptimizerConfig): 17 | learning_rate: float 18 | momentum: float = 0.0 19 | dampening: float = 0.0 20 | nestov: bool = False 21 | _target_: str = "hydra_zen.funcs.zen_processing" 22 | _zen_target: str = "torch.optim.SGD" 23 | _zen_partial: bool = True 24 | 25 | 26 | pbuilds = make_custom_builds_fn( 27 | zen_partial=True, 28 | populate_full_signature=True, 29 | ) 30 | 31 | AdamConfig = pbuilds(torch.optim.Adam, builds_bases=(OptimizerConfig,)) 32 | AdamWConfig = pbuilds(torch.optim.AdamW, builds_bases=(OptimizerConfig,)) 33 | 34 | 35 | def register_configs(config_store): 36 | config_store.store(group="optimizers", name="sgd", node=SGDConfig) 37 | config_store.store(group="optimizers", name="adam", node=AdamConfig) 38 | config_store.store(group="optimizers", name="adamw", node=AdamWConfig) 39 | -------------------------------------------------------------------------------- /ocl/matching.py: -------------------------------------------------------------------------------- 1 | """Methods for matching between sets of elements.""" 2 | from typing import Tuple, Type 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torchtyping import TensorType 8 | 9 | # Avoid errors due to flake: 10 | batch_size = None 11 | n_elements = None 12 | 13 | CostMatrix = Type[TensorType["batch_size", "n_elements", "n_elements"]] 14 | AssignmentMatrix = Type[TensorType["batch_size", "n_elements", "n_elements"]] 15 | CostVector = Type[TensorType["batch_size"]] 16 | 17 | 18 | class Matcher(torch.nn.Module): 19 | """Matcher base class to define consistent interface.""" 20 | 21 | def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]: 22 | pass 23 | 24 | 25 | class CPUHungarianMatcher(Matcher): 26 | """Implementaiton of a cpu hungarian matcher using scipy.optimize.linear_sum_assignment.""" 27 | 28 | def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]: 29 | X = torch.zeros_like(C) 30 | C_cpu: np.ndarray = C.detach().cpu().numpy() 31 | for i, cost_matrix in enumerate(C_cpu): 32 | row_ind, col_ind = linear_sum_assignment(cost_matrix) 33 | X[i][row_ind, col_ind] = 1.0 34 | return X, (C * X).sum(dim=(1, 2)) 35 | -------------------------------------------------------------------------------- /configs/experiment/OC-MOT/_cater_bbox_mot_preprocessing.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /plugins/data_preprocessing@plugins.03a_preprocessing 4 | - /plugins/multi_element_preprocessing@plugins.03b_preprocessing 5 | 6 | plugins: 7 | 03a_preprocessing: 8 | training_fields: 9 | - image 10 | - mask 11 | - object_positions 12 | - __key__ 13 | training_transform: 14 | _target_: torchvision.transforms.Compose 15 | transforms: 16 | - _target_: ocl.preprocessing.AddBBoxFromInstanceMasks 17 | evaluation_fields: 18 | - image 19 | - mask 20 | - object_positions 21 | - __key__ 22 | evaluation_transform: 23 | _target_: torchvision.transforms.Compose 24 | transforms: 25 | - _target_: ocl.preprocessing.AddBBoxFromInstanceMasks 26 | 27 | 03b_preprocessing: 28 | training_transforms: 29 | image: 30 | _target_: ocl.preprocessing.VideoToTensor 31 | instance_bbox: 32 | _target_: ocl.preprocessing.BBoxToTensor 33 | instance_cls: 34 | _target_: ocl.preprocessing.ClsToTensor 35 | instance_id: 36 | _target_: ocl.preprocessing.IDToTensor 37 | mask: 38 | _target_: ocl.preprocessing.MultiMaskToTensor 39 | evaluation_transforms: 40 | 
image: 41 | _target_: ocl.preprocessing.VideoToTensor 42 | instance_bbox: 43 | _target_: ocl.preprocessing.BBoxToTensor 44 | instance_cls: 45 | _target_: ocl.preprocessing.ClsToTensor 46 | instance_id: 47 | _target_: ocl.preprocessing.IDToTensor 48 | mask: 49 | _target_: ocl.preprocessing.MultiMaskToTensor 50 | -------------------------------------------------------------------------------- /configs/experiment/SAVi/_cater_bbox_mot_preprocessing.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /plugins/data_preprocessing@plugins.03a_preprocessing 4 | - /plugins/multi_element_preprocessing@plugins.03b_preprocessing 5 | 6 | plugins: 7 | 03a_preprocessing: 8 | training_fields: 9 | - image 10 | - mask 11 | - object_positions 12 | - __key__ 13 | training_transform: 14 | _target_: torchvision.transforms.Compose 15 | transforms: 16 | - _target_: ocl.preprocessing.AddBBoxFromInstanceMasks 17 | evaluation_fields: 18 | - image 19 | - mask 20 | - object_positions 21 | - __key__ 22 | evaluation_transform: 23 | _target_: torchvision.transforms.Compose 24 | transforms: 25 | - _target_: ocl.preprocessing.AddBBoxFromInstanceMasks 26 | 27 | 03b_preprocessing: 28 | training_transforms: 29 | image: 30 | _target_: ocl.preprocessing.VideoToTensor 31 | instance_bbox: 32 | _target_: ocl.preprocessing.BBoxToTensor 33 | instance_cls: 34 | _target_: ocl.preprocessing.ClsToTensor 35 | instance_id: 36 | _target_: ocl.preprocessing.IDToTensor 37 | mask: 38 | _target_: ocl.preprocessing.MultiMaskToTensor 39 | evaluation_transforms: 40 | image: 41 | _target_: ocl.preprocessing.VideoToTensor 42 | instance_bbox: 43 | _target_: ocl.preprocessing.BBoxToTensor 44 | instance_cls: 45 | _target_: ocl.preprocessing.ClsToTensor 46 | instance_id: 47 | _target_: ocl.preprocessing.IDToTensor 48 | mask: 49 | _target_: ocl.preprocessing.MultiMaskToTensor 50 | -------------------------------------------------------------------------------- /ocl/config/feature_extractors.py: -------------------------------------------------------------------------------- 1 | """Configurations for feature extractors.""" 2 | import dataclasses 3 | 4 | from hydra_zen import make_custom_builds_fn 5 | 6 | from ocl import feature_extractors 7 | 8 | 9 | @dataclasses.dataclass 10 | class FeatureExtractorConfig: 11 | """Base class for PyTorch Lightning DataModules. 12 | 13 | This class does not actually do anything but ensures that feature extractors give outputs of 14 | a defined structure. 
15 | """ 16 | 17 | pass 18 | 19 | 20 | builds_feature_extractor = make_custom_builds_fn( 21 | populate_full_signature=True, 22 | ) 23 | 24 | TimmFeatureExtractorConfig = builds_feature_extractor( 25 | feature_extractors.TimmFeatureExtractor, 26 | builds_bases=(FeatureExtractorConfig,), 27 | ) 28 | SlotAttentionFeatureExtractorConfig = builds_feature_extractor( 29 | feature_extractors.SlotAttentionFeatureExtractor, 30 | builds_bases=(FeatureExtractorConfig,), 31 | ) 32 | SAViFeatureExtractorConfig = builds_feature_extractor( 33 | feature_extractors.SAViFeatureExtractor, 34 | builds_bases=(FeatureExtractorConfig,), 35 | ) 36 | 37 | 38 | def register_configs(config_store): 39 | config_store.store(group="schemas", name="feature_extractor", node=FeatureExtractorConfig) 40 | config_store.store( 41 | group="feature_extractor", 42 | name="timm_model", 43 | node=TimmFeatureExtractorConfig, 44 | ) 45 | config_store.store( 46 | group="feature_extractor", 47 | name="slot_attention", 48 | node=SlotAttentionFeatureExtractorConfig, 49 | ) 50 | config_store.store( 51 | group="feature_extractor", 52 | name="savi", 53 | node=SAViFeatureExtractorConfig, 54 | ) 55 | -------------------------------------------------------------------------------- /ocl/hooks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Tuple 2 | 3 | import webdataset 4 | from pluggy import HookimplMarker, HookspecMarker 5 | 6 | from ocl.combined_model import CombinedModel 7 | 8 | hook_specification = HookspecMarker("ocl") 9 | hook_implementation = HookimplMarker("ocl") 10 | 11 | 12 | class FakeHooks: 13 | """Class that mimics the behavior of the plugin manager hooks property.""" 14 | 15 | def __getattr__(self, attribute): 16 | """Return a fake hook handler for any attribute query.""" 17 | 18 | def fake_hook_handler(*args, **kwargs): 19 | return tuple() 20 | 21 | return fake_hook_handler 22 | 23 | 24 | # @transform_hooks 25 | # def input_dependencies() -> Tuple[str, ...]: 26 | # """Provide list of variables that are required for the plugin to function.""" 27 | # 28 | # 29 | # @transform_hooks 30 | # def provided_inputs() -> Tuple[str, ...]: 31 | # """Provide list of variables that are provided by the plugin.""" 32 | 33 | 34 | @hook_specification 35 | def training_transform() -> Callable[[webdataset.Processor], webdataset.Processor]: 36 | """Provide a transformation which processes a component of a webdataset pipeline.""" 37 | 38 | 39 | @hook_specification 40 | def training_fields() -> Tuple[str]: 41 | """Provide list of fields that are required to be decoded during training.""" 42 | 43 | 44 | @hook_specification 45 | def evaluation_transform() -> Callable[[webdataset.Processor], webdataset.Processor]: 46 | """Provide a transformation which processes a component of a webdataset pipeline.""" 47 | 48 | 49 | @hook_specification 50 | def evaluation_fields() -> Tuple[str]: 51 | """Provide list of fields that are required to be decoded during evaluation.""" 52 | 53 | 54 | @hook_specification 55 | def configure_optimizers(model: CombinedModel) -> Dict[str, Any]: 56 | """Return optimizers in the format of pytorch lightning.""" 57 | 58 | 59 | @hook_specification 60 | def on_train_epoch_start(model: CombinedModel) -> None: 61 | """Hook called when starting training epoch.""" 62 | -------------------------------------------------------------------------------- /ocl/consistency.py: -------------------------------------------------------------------------------- 1 | """Modules 
to compute the IoU matching cost and solve the corresponding LSAP.""" 2 | import numpy as np 3 | import torch 4 | from scipy.optimize import linear_sum_assignment 5 | from torch import nn 6 | 7 | 8 | class HungarianMatcher(nn.Module): 9 | """This class computes an assignment between the targets and the predictions of the network.""" 10 | 11 | @torch.no_grad() 12 | def forward(self, mask_preds, mask_targets): 13 | """Performs the matching. 14 | 15 | Params: 16 | mask_preds: Tensor of dim [batch_size, n_objects, N, N] with the predicted masks 17 | mask_targets: Tensor of dim [batch_size, n_objects, N, N] 18 | with the target masks from another augmentation 19 | 20 | Returns: 21 | A list of size batch_size, containing tuples of (index_i, index_j) where: 22 | - index_i is the indices of the selected predictions 23 | - index_j is the indices of the corresponding selected targets 24 | """ 25 | bs, n_objects, _, _ = mask_preds.shape 26 | # Compute the iou cost betwen masks 27 | cost_iou = -get_iou_matrix(mask_preds, mask_targets) 28 | cost_iou = cost_iou.reshape(bs, n_objects, bs, n_objects).cpu() 29 | self.costs = torch.stack([cost_iou[i, :, i, :][None] for i in range(bs)]) 30 | indices = [linear_sum_assignment(c[0]) for c in self.costs] 31 | return torch.as_tensor(np.array(indices)) 32 | 33 | 34 | def get_iou_matrix(preds, targets): 35 | 36 | bs, n_objects, H, W = targets.shape 37 | targets = targets.reshape(bs * n_objects, H * W).float() 38 | preds = preds.reshape(bs * n_objects, H * W).float() 39 | 40 | intersection = torch.matmul(targets, preds.t()) 41 | targets_area = targets.sum(dim=1).view(1, -1) 42 | preds_area = preds.sum(dim=1).view(1, -1) 43 | union = (targets_area.t() + preds_area) - intersection 44 | 45 | return torch.where( 46 | union == 0, 47 | torch.tensor(0.0, device=targets.device), 48 | intersection / union, 49 | ) 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ocl" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Max Horn "] 6 | 7 | [tool.poetry.scripts] 8 | ocl_train = "ocl.cli.train:train" 9 | ocl_eval = "ocl.cli.eval:evaluate" 10 | ocl_compute_dataset_size = "ocl.cli.compute_dataset_size:compute_size" 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.7.1,<3.9" 14 | webdataset = "^0.1.103" 15 | # There seems to be an issue in torch 1.12.x with masking and multi-head 16 | # attention. This prevents the usage of makes without a batch dimension. 17 | # Staying with torch 1.11.x version for now. 18 | torch = "1.12.*" 19 | pytorch-lightning = "^1.5.10" 20 | hydra-zen = "^0.7.0" 21 | torchtyping = "^0.1.4" 22 | hydra-core = "^1.2.0" 23 | pluggy = "^1.0.0" 24 | importlib-metadata = "4.2" 25 | torchvision = "0.13.*" 26 | Pillow = "9.0.1" # Newer versions of pillow seem to result in segmentation faults. 
27 | torchmetrics = "^0.8.1" 28 | matplotlib = "^3.5.1" 29 | moviepy = "^1.0.3" 30 | scipy = "<=1.8" 31 | awscli = "^1.22.90" 32 | scikit-learn = "^1.0.2" 33 | pyamg = "^4.2.3" 34 | botocore = { extras = ["crt"], version = "^1.27.22" } 35 | timm = {version = "0.6.7", optional = true} 36 | hydra-submitit-launcher = { version = "^1.2.0", optional = true } 37 | decord = "0.6.0" 38 | motmetrics = "^1.2.5" 39 | clip = {git = "https://github.com/openai/CLIP.git", rev = "main", optional = true} 40 | ftfy = {version = "^6.1.1", optional = true} 41 | regex = {version = "^2022.7.9", optional = true} 42 | 43 | [tool.poetry.dev-dependencies] 44 | black = "^22.1.0" 45 | pytest = "^7.0.1" 46 | flake8 = "^4.0.1" 47 | flake8-isort = "^4.1.1" 48 | pre-commit = "^2.17.0" 49 | flake8-tidy-imports = "^4.6.0" 50 | flake8-bugbear = "^22.1.11" 51 | flake8-docstrings = "^1.6.0" 52 | 53 | [tool.poetry.extras] 54 | timm = ["timm"] 55 | clip = ["clip", "ftfy", "regex"] 56 | submitit = ["hydra-submitit-launcher"] 57 | 58 | [build-system] 59 | requires = ["poetry-core<=1.0.4"] 60 | build-backend = "poetry.core.masonry.api" 61 | 62 | [tool.black] 63 | line-length = 101 64 | target-version = ["py38"] 65 | 66 | [tool.isort] 67 | profile = "black" 68 | line_length = 101 69 | skip_gitignore = true 70 | remove_redundant_aliases = true 71 | -------------------------------------------------------------------------------- /ocl/config/datasets.py: -------------------------------------------------------------------------------- 1 | """Register all dataset related configs.""" 2 | import dataclasses 3 | import os 4 | 5 | from hydra_zen import builds 6 | 7 | from ocl import datasets 8 | 9 | 10 | def get_region(): 11 | """Determine the region this EC2 instance is running in. 12 | 13 | Returns None if not running on an EC2 instance. 14 | """ 15 | import requests 16 | 17 | try: 18 | r = requests.get( 19 | "http://169.254.169.254/latest/dynamic/instance-identity/document", timeout=0.5 20 | ) 21 | response_json = r.json() 22 | return response_json.get("region") 23 | except Exception: 24 | # Not running on an ec2 instance. 25 | return None 26 | 27 | 28 | # Detemine region name and select bucket accordingly. 29 | AWS_REGION = get_region() 30 | if AWS_REGION in ["us-east-2", "us-west-2", "eu-west-1"]: 31 | # Select bucket in same region. 32 | DEFAULT_S3_PATH = f"s3://object-centric-datasets-{AWS_REGION}" 33 | else: 34 | # Use MRAP to find closest bucket. 35 | DEFAULT_S3_PATH = "s3://arn:aws:s3::436622332146:accesspoint/m6p4hmmybeu97.mrap" 36 | 37 | 38 | @dataclasses.dataclass 39 | class DataModuleConfig: 40 | """Base class for PyTorch Lightning DataModules. 41 | 42 | This class does not actually do anything but ensures that datasets behave like pytorch lightning 43 | datamodules. 44 | """ 45 | 46 | 47 | def dataset_prefix(path): 48 | # prefix = os.environ.get("DATASET_PREFIX") 49 | prefix = '/home/ubuntu/data' 50 | if prefix: 51 | return f"{prefix}/{path}" 52 | # Use the path to the multi-region bucket if no override is specified. 
53 | return f"pipe:aws s3 cp --quiet {DEFAULT_S3_PATH}/{path} -" 54 | 55 | 56 | WebdatasetDataModuleConfig = builds( 57 | datasets.WebdatasetDataModule, populate_full_signature=True, builds_bases=(DataModuleConfig,) 58 | ) 59 | DummyDataModuleConfig = builds( 60 | datasets.DummyDataModule, populate_full_signature=True, builds_bases=(DataModuleConfig,) 61 | ) 62 | 63 | 64 | def register_configs(config_store): 65 | config_store.store(group="schemas", name="dataset", node=DataModuleConfig) 66 | config_store.store(group="dataset", name="webdataset", node=WebdatasetDataModuleConfig) 67 | config_store.store(group="dataset", name="dummy_dataset", node=DummyDataModuleConfig) 68 | 69 | 70 | def register_resolvers(omegaconf): 71 | omegaconf.register_new_resolver("dataset_prefix", dataset_prefix) 72 | -------------------------------------------------------------------------------- /ocl/config/conditioning.py: -------------------------------------------------------------------------------- 1 | """Configuration of slot conditionings.""" 2 | import dataclasses 3 | 4 | from hydra_zen import builds 5 | from omegaconf import SI 6 | 7 | from ocl import conditioning 8 | 9 | 10 | @dataclasses.dataclass 11 | class ConditioningConfig: 12 | """Base class for conditioning module configuration.""" 13 | 14 | 15 | # Unfortunately, we cannot define object_dim as part of the base config class as this prevents using 16 | # required positional arguments in all subclasses. We thus instead pass them here. 17 | LearntConditioningConfig = builds( 18 | conditioning.LearntConditioning, 19 | object_dim=SI("${perceptual_grouping.object_dim}"), 20 | builds_bases=(ConditioningConfig,), 21 | populate_full_signature=True, 22 | ) 23 | 24 | RandomConditioningConfig = builds( 25 | conditioning.RandomConditioning, 26 | object_dim=SI("${perceptual_grouping.object_dim}"), 27 | builds_bases=(ConditioningConfig,), 28 | populate_full_signature=True, 29 | ) 30 | 31 | RandomConditioningWithQMCSamplingConfig = builds( 32 | conditioning.RandomConditioningWithQMCSampling, 33 | object_dim=SI("${perceptual_grouping.object_dim}"), 34 | builds_bases=(ConditioningConfig,), 35 | populate_full_signature=True, 36 | ) 37 | 38 | SlotwiseLearntConditioningConfig = builds( 39 | conditioning.SlotwiseLearntConditioning, 40 | object_dim=SI("${perceptual_grouping.object_dim}"), 41 | builds_bases=(ConditioningConfig,), 42 | populate_full_signature=True, 43 | ) 44 | CoordinateEncoderStateInitConfig = builds( 45 | conditioning.CoordinateEncoderStateInit, 46 | object_dim=SI("${perceptual_grouping.object_dim}"), 47 | builds_bases=(ConditioningConfig,), 48 | populate_full_signature=True, 49 | ) 50 | 51 | def register_configs(config_store): 52 | config_store.store(group="schemas", name="conditioning", node=ConditioningConfig) 53 | 54 | config_store.store(group="conditioning", name="learnt", node=LearntConditioningConfig) 55 | config_store.store(group="conditioning", name="random", node=RandomConditioningConfig) 56 | config_store.store( 57 | group="conditioning", 58 | name="random_with_qmc_sampling", 59 | node=RandomConditioningWithQMCSamplingConfig, 60 | ) 61 | config_store.store( 62 | group="conditioning", name="slotwise_learnt_random", node=SlotwiseLearntConditioningConfig 63 | ) 64 | config_store.store(group="conditioning", name="boxhint", node=CoordinateEncoderStateInitConfig) -------------------------------------------------------------------------------- /ocl/cli/compute_dataset_size.py: 
-------------------------------------------------------------------------------- 1 | """Script to compute the size of a dataset. 2 | 3 | This is useful when subsampling data using transformations in order to determine the final dataset 4 | size. The size of the dataset is typically need when running distributed training in order to 5 | ensure that all nodes and gpu training processes are presented with the same number of batches. 6 | """ 7 | import dataclasses 8 | import logging 9 | import os 10 | from typing import Dict 11 | 12 | import hydra 13 | import hydra_zen 14 | import tqdm 15 | from pluggy import PluginManager 16 | 17 | import ocl.hooks 18 | from ocl.config.datasets import DataModuleConfig 19 | from ocl.config.plugins import PluginConfig 20 | 21 | 22 | @dataclasses.dataclass 23 | class ComputeSizeConfig: 24 | """Configuration of a training run.""" 25 | 26 | dataset: DataModuleConfig 27 | plugins: Dict[str, PluginConfig] = dataclasses.field(default_factory=dict) 28 | 29 | 30 | hydra.core.config_store.ConfigStore.instance().store( 31 | name="compute_size_config", 32 | node=ComputeSizeConfig, 33 | ) 34 | 35 | 36 | @hydra.main(config_name="compute_size_config", config_path="../../configs", version_base="1.1") 37 | def compute_size(config: ComputeSizeConfig): 38 | pm = PluginManager("ocl") 39 | pm.add_hookspecs(ocl.hooks) 40 | 41 | datamodule = hydra_zen.instantiate(config.dataset, hooks=pm.hook) 42 | pm.register(datamodule) 43 | 44 | plugins = hydra_zen.instantiate(config.plugins) 45 | for plugin in plugins.values(): 46 | pm.register(plugin) 47 | 48 | # Compute dataset sizes 49 | # TODO(hornmax): This is needed for webdataset shuffling, is there a way to make this more 50 | # elegant and less specific? 51 | os.environ["WDS_EPOCH"] = str(0) 52 | train_size = sum( 53 | 1 54 | for _ in tqdm.tqdm( 55 | datamodule.train_data_iterator(), desc="Reading train split", unit="samples" 56 | ) 57 | ) 58 | logging.info("Train split size: %d", train_size) 59 | val_size = sum( 60 | 1 61 | for _ in tqdm.tqdm( 62 | datamodule.val_data_iterator(), desc="Reading validation split", unit="samples" 63 | ) 64 | ) 65 | logging.info("Validation split size: %d", val_size) 66 | test_size = sum( 67 | 1 68 | for _ in tqdm.tqdm( 69 | datamodule.test_data_iterator(), desc="Reading test split", unit="samples" 70 | ) 71 | ) 72 | logging.info("Test split size: %d", test_size) 73 | 74 | 75 | if __name__ == "__main__": 76 | compute_size() 77 | -------------------------------------------------------------------------------- /ocl/visualization_types.py: -------------------------------------------------------------------------------- 1 | """Classes for handling different types of visualizations.""" 2 | import dataclasses 3 | from typing import Any, List, Optional, Union 4 | 5 | import matplotlib.pyplot 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | from torchtyping import TensorType 9 | 10 | 11 | def dataclass_to_dict(d): 12 | return {field.name: getattr(d, field.name) for field in dataclasses.fields(d)} 13 | 14 | 15 | @dataclasses.dataclass 16 | class Visualization: 17 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int): 18 | pass 19 | 20 | 21 | @dataclasses.dataclass 22 | class Figure(Visualization): 23 | """Matplotlib figure.""" 24 | 25 | figure: matplotlib.pyplot.figure 26 | close: bool = True 27 | 28 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int): 29 | experiment.add_figure(**dataclass_to_dict(self), tag=tag, 
global_step=global_step) 30 | 31 | 32 | @dataclasses.dataclass 33 | class Image(Visualization): 34 | """Single image.""" 35 | 36 | img_tensor: torch.Tensor 37 | dataformats: str = "CHW" 38 | 39 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int): 40 | experiment.add_image(**dataclass_to_dict(self), tag=tag, global_step=global_step) 41 | 42 | 43 | @dataclasses.dataclass 44 | class Images(Visualization): 45 | """Batch of images.""" 46 | 47 | img_tensor: torch.Tensor 48 | dataformats: str = "NCHW" 49 | 50 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int): 51 | experiment.add_images(**dataclass_to_dict(self), tag=tag, global_step=global_step) 52 | 53 | 54 | @dataclasses.dataclass 55 | class Video(Visualization): 56 | """Batch of videos.""" 57 | 58 | vid_tensor: TensorType["batch_size", "frames", "channels", "height", "width"] # noqa: F821 59 | fps: Union[int, float] = 4 60 | 61 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int): 62 | experiment.add_video(**dataclass_to_dict(self), tag=tag, global_step=global_step) 63 | 64 | 65 | class Embedding(Visualization): 66 | """Batch of embeddings.""" 67 | 68 | mat: TensorType["batch_size", "feature_dim"] # noqa: F821 69 | metadata: Optional[List[Any]] = None 70 | label_img: Optional[TensorType["batch_size", "channels", "height", "width"]] = None # noqa: F821 71 | metadata_header: Optional[List[str]] = None 72 | 73 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int): 74 | experiment.add_embedding(**dataclass_to_dict(self), tag=tag, global_step=global_step) 75 | -------------------------------------------------------------------------------- /ocl/models/savi.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Dict 3 | import copy 4 | import torch 5 | from torch import nn 6 | 7 | from ocl.path_defaults import VIDEO, BOX 8 | from ocl.tree_utils import get_tree_element, reduce_tree 9 | 10 | 11 | class SAVi(nn.Module): 12 | def __init__( 13 | self, 14 | conditioning: nn.Module, 15 | feature_extractor: nn.Module, 16 | perceptual_grouping: nn.Module, 17 | decoder: nn.Module, 18 | transition_model: nn.Module, 19 | ): 20 | super().__init__() 21 | self.conditioning = conditioning 22 | self.feature_extractor = feature_extractor 23 | self.perceptual_grouping = perceptual_grouping 24 | self.decoder = decoder 25 | self.transition_model = transition_model 26 | self.batched_input = None 27 | 28 | def forward(self, inputs: Dict[str, Any], phase = 'train'): 29 | # if self.batched_input is None: 30 | # video = get_tree_element(inputs, VIDEO.split(".")) 31 | # # if video.shape[1] == 6: 32 | # self.batched_input = copy.deepcopy(inputs) 33 | # else: 34 | # print ('use catched') 35 | # inputs = self.batched_input 36 | 37 | output = inputs 38 | video = get_tree_element(inputs, VIDEO.split(".")) 39 | box = get_tree_element(inputs, BOX.split(".")) 40 | batch_size = video.shape[0] 41 | 42 | features = self.feature_extractor(video=video) 43 | output["feature_extractor"] = features 44 | conditioning = self.conditioning(batch_size=batch_size) 45 | # conditioning = self.conditioning(batch_size=batch_size) 46 | output["initial_conditioning"] = conditioning 47 | 48 | # Loop over time. 
49 | perceptual_grouping_outputs = [] 50 | decoder_outputs = [] 51 | transition_model_outputs = [] 52 | trackers = [] 53 | for frame_features in features: 54 | perceptual_grouping_output = self.perceptual_grouping( 55 | extracted_features=frame_features, conditioning=conditioning 56 | ) 57 | slots = perceptual_grouping_output.objects 58 | decoder_output = self.decoder(object_features=slots) 59 | 60 | # remove background 61 | masks = decoder_output.masks_eval 62 | valid_idx = [0,1,2,4,5,6,7,8,9,10] 63 | masks_obj = masks[:, valid_idx] 64 | 65 | conditioning = self.transition_model(slots) 66 | # Store outputs. 67 | perceptual_grouping_outputs.append(slots) 68 | decoder_outputs.append(decoder_output) 69 | transition_model_outputs.append(conditioning) 70 | trackers.append(masks_obj) 71 | 72 | # Stack all recurrent outputs. 73 | stacking_fn = partial(torch.stack, dim=1) 74 | output["perceptual_grouping"] = reduce_tree(perceptual_grouping_outputs, stacking_fn) 75 | output["decoder"] = reduce_tree(decoder_outputs, stacking_fn) 76 | output["transition_model"] = reduce_tree(transition_model_outputs, stacking_fn) 77 | output["tracks"] = reduce_tree(trackers, stacking_fn) 78 | return output 79 | 80 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | antlr4-python3-runtime==4.9.3 5 | async-timeout==4.0.2 6 | attrs==21.4.0 7 | awscli==1.25.22 8 | awscrt==0.13.8 9 | black==22.6.0 10 | botocore==1.27.22 11 | braceexpand==0.1.7 12 | CacheControl==0.12.11 13 | cachetools==5.2.0 14 | cachy==0.3.0 15 | certifi==2022.6.15 16 | cfgv==3.3.1 17 | charset-normalizer==2.1.0 18 | cleo==0.8.1 19 | click==8.1.3 20 | clikit==0.6.2 21 | colorama==0.4.4 22 | crashtest==0.3.1 23 | cycler==0.11.0 24 | decorator==4.4.2 25 | decord==0.6.0 26 | distlib==0.3.4 27 | docutils==0.16 28 | filelock==3.7.1 29 | filterpy==1.4.5 30 | flake8==4.0.1 31 | flake8-bugbear==22.7.1 32 | flake8-docstrings==1.6.0 33 | flake8-isort==4.1.1 34 | flake8-tidy-imports==4.8.0 35 | fonttools==4.33.3 36 | frozenlist==1.3.0 37 | fsspec==2022.7.1 38 | google-auth==2.9.0 39 | google-auth-oauthlib==0.4.6 40 | grpcio==1.47.0 41 | html5lib==1.1 42 | hydra-core==1.2.0 43 | hydra-zen==0.7.1 44 | identify==2.5.1 45 | idna==3.3 46 | imageio==2.19.3 47 | imageio-ffmpeg==0.4.7 48 | importlib-metadata==4.2.0 49 | importlib-resources==5.8.0 50 | iniconfig==1.1.1 51 | isort==5.10.1 52 | jeepney==0.8.0 53 | jmespath==1.0.1 54 | joblib==1.1.0 55 | keyring==23.8.2 56 | kiwisolver==1.4.3 57 | lap==0.4.0 58 | llvmlite==0.39.1 59 | lockfile==0.12.2 60 | Markdown==3.3.5 61 | matplotlib==3.5.2 62 | mccabe==0.6.1 63 | motmetrics==1.2.5 64 | moviepy==1.0.3 65 | msgpack==1.0.4 66 | multidict==6.0.2 67 | mypy-extensions==0.4.3 68 | nodeenv==1.7.0 69 | numba==0.56.4 70 | numpy==1.21.6 71 | oauthlib==3.2.0 72 | # Editable install with no version control (ocl==0.1.0) 73 | -e /home/ubuntu/object-centric-learning-models-mainline 74 | omegaconf==2.2.2 75 | packaging==21.3 76 | pandas==1.3.5 77 | pastel==0.2.1 78 | pathspec==0.9.0 79 | pexpect==4.8.0 80 | Pillow==9.0.1 81 | pkginfo==1.8.3 82 | platformdirs==2.5.2 83 | pluggy==1.0.0 84 | poetry==1.1.14 85 | poetry-core==1.0.8 86 | pre-commit==2.19.0 87 | proglog==0.1.10 88 | protobuf==3.20.1 89 | ptyprocess==0.7.0 90 | py==1.11.0 91 | pyamg==4.2.3 92 | pyasn1==0.4.8 93 | pyasn1-modules==0.2.8 94 | pycodestyle==2.8.0 95 | pyDeprecate==0.3.2 96 | 
pydocstyle==6.1.1 97 | pyflakes==2.4.0 98 | pylev==1.4.0 99 | pyparsing==3.0.9 100 | pytest==7.1.2 101 | python-dateutil==2.8.2 102 | pytorch-lightning==1.6.4 103 | pytz==2022.4 104 | PyYAML==5.4.1 105 | requests==2.28.1 106 | requests-oauthlib==1.3.1 107 | requests-toolbelt==0.9.1 108 | rsa==4.7.2 109 | s3transfer==0.6.0 110 | scikit-learn==1.0.2 111 | scipy==1.7.3 112 | SecretStorage==3.3.2 113 | setuptools-scm==7.0.4 114 | shellingham==1.5.0 115 | six==1.16.0 116 | snowballstemmer==2.2.0 117 | tensorboard==2.9.0 118 | tensorboard-data-server==0.6.1 119 | tensorboard-plugin-wit==1.8.1 120 | testfixtures==6.18.5 121 | threadpoolctl==3.1.0 122 | toml==0.10.2 123 | tomli==2.0.1 124 | tomlkit==0.11.3 125 | torch==1.12.1 126 | torchmetrics==0.8.2 127 | torchtyping==0.1.4 128 | torchvision==0.13.1 129 | tqdm==4.64.0 130 | typeguard==2.13.3 131 | typing_extensions==4.3.0 132 | urllib3==1.26.9 133 | virtualenv==20.15.1 134 | webdataset==0.1.103 135 | webencodings==0.5.1 136 | Werkzeug==2.1.2 137 | xmltodict==0.13.0 138 | yarl==1.7.2 139 | zipp==3.8.0 140 | 141 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 
41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /ocl/config/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions useful for configuration.""" 2 | import ast 3 | from typing import Any, Callable 4 | 5 | from hydra_zen import builds 6 | 7 | from ocl.config.feature_extractors import FeatureExtractorConfig 8 | from ocl.config.perceptual_groupings import PerceptualGroupingConfig 9 | from ocl.config.predictor import PredictorConfig 10 | from ocl.distillation import EMASelfDistillation 11 | from ocl.utils import Combined, CreateSlotMask, Recurrent 12 | 13 | 14 | def lambda_string_to_function(function_string: str) -> Callable[..., Any]: 15 | """Convert string of the form "lambda x: x" into a callable Python function.""" 16 | # This is a bit hacky but ensures that the syntax of the input is correct and contains 17 | # a valid lambda function definition without requiring to run `eval`. 18 | parsed = ast.parse(function_string) 19 | is_lambda = isinstance(parsed.body[0], ast.Expr) and isinstance(parsed.body[0].value, ast.Lambda) 20 | if not is_lambda: 21 | raise ValueError(f"'{function_string}' is not a valid lambda definition.") 22 | 23 | return eval(function_string) 24 | 25 | 26 | class ConfigDefinedLambda: 27 | """Lambda function defined in the config. 28 | 29 | This allows lambda functions defined in the config to be pickled. 30 | """ 31 | 32 | def __init__(self, function_string: str): 33 | self.__setstate__(function_string) 34 | 35 | def __getstate__(self) -> str: 36 | return self.function_string 37 | 38 | def __setstate__(self, function_string: str): 39 | self.function_string = function_string 40 | self._fn = lambda_string_to_function(function_string) 41 | 42 | def __call__(self, *args, **kwargs): 43 | return self._fn(*args, **kwargs) 44 | 45 | 46 | def eval_lambda(function_string, *args): 47 | lambda_fn = lambda_string_to_function(function_string) 48 | return lambda_fn(*args) 49 | 50 | 51 | FunctionConfig = builds(ConfigDefinedLambda, populate_full_signature=True) 52 | 53 | # Inherit from all so it can be used in place of any module. 
54 | CombinedConfig = builds( 55 | Combined, 56 | populate_full_signature=True, 57 | builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig, PredictorConfig), 58 | ) 59 | RecurrentConfig = builds( 60 | Recurrent, 61 | populate_full_signature=True, 62 | builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig, PredictorConfig), 63 | ) 64 | CreateSlotMaskConfig = builds(CreateSlotMask, populate_full_signature=True) 65 | 66 | 67 | EMASelfDistillationConfig = builds( 68 | EMASelfDistillation, 69 | populate_full_signature=True, 70 | builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig, PredictorConfig), 71 | ) 72 | 73 | 74 | def register_configs(config_store): 75 | config_store.store(group="schemas", name="lambda_fn", node=FunctionConfig) 76 | config_store.store(group="utils", name="combined", node=CombinedConfig) 77 | config_store.store(group="utils", name="selfdistillation", node=EMASelfDistillationConfig) 78 | config_store.store(group="utils", name="recurrent", node=RecurrentConfig) 79 | config_store.store(group="utils", name="create_slot_mask", node=CreateSlotMaskConfig) 80 | 81 | 82 | def register_resolvers(omegaconf): 83 | omegaconf.register_new_resolver("lambda_fn", ConfigDefinedLambda) 84 | omegaconf.register_new_resolver("eval_lambda", eval_lambda) 85 | -------------------------------------------------------------------------------- /ocl/distillation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from ocl import scheduling, tree_utils, utils 8 | 9 | 10 | class EMASelfDistillation(nn.Module): 11 | def __init__( 12 | self, 13 | student: Union[nn.Module, Dict[str, nn.Module]], 14 | schedule: scheduling.HPScheduler, 15 | student_remapping: Optional[Dict[str, str]] = None, 16 | teacher_remapping: Optional[Dict[str, str]] = None, 17 | ): 18 | super().__init__() 19 | # Do this for convenience to reduce crazy amount of nesting. 20 | if isinstance(student, dict): 21 | student = utils.Combined(student) 22 | if student_remapping is None: 23 | student_remapping = {} 24 | if teacher_remapping is None: 25 | teacher_remapping = {} 26 | 27 | self.student = student 28 | self.teacher = copy.deepcopy(student) 29 | self.schedule = schedule 30 | self.student_remapping = {key: value.split(".") for key, value in student_remapping.items()} 31 | self.teacher_remapping = {key: value.split(".") for key, value in teacher_remapping.items()} 32 | 33 | def build_input_dict(self, inputs, remapping): 34 | if not remapping: 35 | return inputs 36 | # This allows us to bing the initial input and previous_output into a similar format. 37 | output_dict = {} 38 | for output_path, input_path in remapping.items(): 39 | source = tree_utils.get_tree_element(inputs, input_path) 40 | 41 | output_path = output_path.split(".") 42 | cur_search = output_dict 43 | for path_part in output_path[:-1]: 44 | # Iterate along path and create nodes that do not exist yet. 45 | try: 46 | # Get element prior to last. 47 | cur_search = tree_utils.get_tree_element(cur_search, [path_part]) 48 | except ValueError: 49 | # Element does not yet exist. 
50 | cur_search[path_part] = {} 51 | cur_search = cur_search[path_part] 52 | 53 | cur_search[output_path[-1]] = source 54 | return output_dict 55 | 56 | def forward(self, inputs: Dict[str, Any]): 57 | if self.training: 58 | with torch.no_grad(): 59 | m = self.schedule(inputs["global_step"]) # momentum parameter 60 | for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()): 61 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 62 | 63 | # prefix variable similar to combined module. 64 | prefix: List[str] 65 | if "prefix" in inputs.keys(): 66 | prefix = inputs["prefix"] 67 | else: 68 | prefix = [] 69 | inputs["prefix"] = prefix 70 | 71 | outputs = tree_utils.get_tree_element(inputs, prefix) 72 | 73 | # Forward pass student. 74 | prefix.append("student") 75 | outputs["student"] = {} 76 | student_inputs = self.build_input_dict(inputs, self.student_remapping) 77 | outputs["student"] = self.student(inputs={**inputs, **student_inputs}) 78 | # Teacher and student share the same code, thus paths also need to be the same. To ensure 79 | # that we save the student outputs and run the teacher as if it where the student. 80 | student_output = outputs["student"] 81 | 82 | # Forward pass teacher, but pretending to be student. 83 | outputs["student"] = {} 84 | teacher_inputs = self.build_input_dict(inputs, self.teacher_remapping) 85 | 86 | with torch.no_grad(): 87 | outputs["teacher"] = self.teacher(inputs={**inputs, **teacher_inputs}) 88 | prefix.pop() 89 | 90 | # Set correct outputs again. 91 | outputs["student"] = student_output 92 | 93 | return outputs 94 | -------------------------------------------------------------------------------- /configs/experiment/SAVi/cater.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /experiment/_output_path 3 | - /training_config 4 | - /dataset: cater 5 | - /plugins/optimization@plugins.optimize_parameters 6 | - /plugins/random_strided_window@plugins.02_random_strided_window # Used during training. 7 | - /plugins/multi_element_preprocessing@plugins.03_preprocessing 8 | - /optimizers/adam@plugins.optimize_parameters.optimizer 9 | - /lr_schedulers/cosine_annealing@plugins.optimize_parameters.lr_scheduler 10 | - /experiment/SAVi/_cater_bbox_mot_preprocessing 11 | # - /metrics/three_d_iou@evaluation_metrics.iou 12 | # - /metrics/mot_metric@evaluation_metrics.mot 13 | - /metrics/ari_metric@evaluation_metrics.ari 14 | - _self_ 15 | 16 | 17 | 18 | load_checkpoint: outputs/SAVi/savi/2023-02-20_23-49-54/checkpoints/epoch=18-step=1064.ckpt 19 | 20 | trainer: 21 | gpus: 8 22 | gradient_clip_val: 0.05 23 | gradient_clip_algorithm: "norm" 24 | max_epochs: null 25 | max_steps: 2000005 26 | strategy: 'ddp' 27 | callbacks: 28 | - _target_: pytorch_lightning.callbacks.LearningRateMonitor 29 | logging_interval: "step" 30 | 31 | dataset: 32 | num_workers: 4 33 | batch_size: 30 34 | 35 | models: 36 | _target_: ocl.models.SAVi 37 | conditioning: 38 | _target_: ocl.conditioning.LearntConditioning 39 | n_slots: 11 40 | object_dim: 128 41 | 42 | feature_extractor: 43 | # Use the smaller verion of the feature extractor architecture. 
44 | _target_: ocl.feature_extractors.SAViFeatureExtractor 45 | larger_input_arch: False 46 | 47 | perceptual_grouping: 48 | _target_: ocl.perceptual_grouping.SlotAttentionGrouping 49 | feature_dim: 32 50 | object_dim: ${models.conditioning.object_dim} 51 | iters: 2 52 | kvq_dim: 128 53 | use_projection_bias: false 54 | positional_embedding: 55 | _target_: ocl.utils.Sequential 56 | _args_: 57 | - _target_: ocl.utils.SoftPositionEmbed 58 | n_spatial_dims: 2 59 | feature_dim: 32 60 | savi_style: true 61 | - _target_: ocl.neural_networks.build_two_layer_mlp 62 | input_dim: 32 63 | output_dim: 32 64 | hidden_dim: 64 65 | initial_layer_norm: true 66 | ff_mlp: null 67 | 68 | decoder: 69 | _target_: ocl.decoding.SlotAttentionDecoder 70 | decoder: 71 | _target_: ocl.decoding.get_savi_decoder_backbone 72 | object_dim: ${models.perceptual_grouping.object_dim} 73 | larger_input_arch: False 74 | positional_embedding: 75 | _target_: ocl.utils.SoftPositionEmbed 76 | n_spatial_dims: 2 77 | feature_dim: ${models.perceptual_grouping.object_dim} 78 | cnn_channel_order: true 79 | savi_style: true 80 | 81 | transition_model: 82 | _target_: torch.nn.Identity 83 | 84 | losses: 85 | mse: 86 | _target_: ocl.losses.ReconstructionLoss 87 | loss_type: mse_sum 88 | input_path: decoder.reconstruction 89 | target_path: input.image 90 | 91 | plugins: 92 | optimize_parameters: 93 | optimizer: 94 | lr: 0.0001 95 | lr_scheduler: 96 | T_max: 200000 97 | eta_min: 0.0 98 | warmup_steps: 0 99 | 02_random_strided_window: 100 | n_consecutive_frames: 6 101 | training_fields: 102 | - image 103 | evaluation_fields: [] 104 | 105 | visualizations: 106 | input: 107 | _target_: ocl.visualizations.Video 108 | denormalization: null 109 | video_path: input.image 110 | reconstruction: 111 | _target_: ocl.visualizations.Video 112 | denormalization: ${..input.denormalization} 113 | video_path: decoder.reconstruction 114 | objects: 115 | _target_: ocl.visualizations.VisualObject 116 | denormalization: ${..input.denormalization} 117 | object_path: decoder.object_reconstructions 118 | mask_path: decoder.masks_eval 119 | objectmot: 120 | _target_: ocl.visualizations.ObjectMOT 121 | n_clips: 5 122 | denormalization: null 123 | video_path: input.image 124 | mask_path: tracks 125 | 126 | evaluation_metrics: 127 | ari: 128 | prediction_path: decoder.masks 129 | target_path: input.mask 130 | 131 | 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Object-Centric Multiple Object Tracking (OC-MOT) 2 | This is the official implementation of the ICCV'23 paper [Object-Centric Multiple Object Tracking](https://arxiv.org/abs/2309.00233). The code was implemented by [Zixu Zhao](https://github.com/zxzhaoeric), [Jiaze Wang](https://jiazewang.com/), [Max Horn](https://github.com/ExpectationMax) and [Tianjun Xiao](http://tianjunxiao.com/). 3 | 4 | ## Introduction 5 | 6 | ![framework](srcs/framework.png) 7 | 8 | OC-MOT is a framework designed to perform multiple object tracking on object-centric representations without object ID labels. It consists of an index-merge module that adapts the object-centric slots into detection outputs and an unsupervised memory module that builds complete object prototypes to handle occlusions. Benefited from object-centric learning, we only requires sparse detection labels for object localization and feature binding. 
Our experiments significantly narrow the gap between existing object-centric models and the fully supervised state of the art, and outperform several unsupervised trackers. 9 | 10 | 11 | ## Development Setup 12 | Installing OC-MOT requires at least Python 3.8. Installation can be done using [poetry](https://python-poetry.org/docs/#installation). After installing `poetry`, check out the repo and set up a development environment: 13 | 14 | ```bash 15 | git clone https://github.com/amazon-science/object-centric-multiple-object-tracking.git 16 | cd object-centric-multiple-object-tracking 17 | poetry install 18 | ``` 19 | 20 | This installs the `ocl` package and the CLI scripts used for running experiments in a poetry-managed virtual environment. Activate the virtual environment with `poetry shell` before running experiments. 21 | 22 | ## Running experiments 23 | 24 | Experiments are defined in the folder `configs/experiment` and can be run 25 | by setting the `experiment` variable. For example, to train and evaluate OC-MOT on the CATER dataset: 26 | 27 | ```bash 28 | poetry run python -m ocl.cli.train +experiment=OC-MOT/cater 29 | poetry run python -m ocl.cli.eval +experiment=OC-MOT/cater_eval 30 | ``` 31 | 32 | Results are saved in a timestamped subdirectory of `outputs/`, e.g. `outputs/OC-MOT/cater/_
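As an illustration of how a finished run can also be inspected programmatically (rather than through `ocl.cli.eval`), the helper `build_from_train_config` from `ocl/cli/eval_utils.py` above can be used roughly as follows. This is a sketch only: it assumes the training schema is registered under the name `training_config` (as the experiment defaults above suggest), that it is run from the repository root, and the checkpoint path is a placeholder.

```python
# Sketch only: compose the same Hydra config an experiment uses, then rebuild the
# datamodule and model from it together with a checkpoint. The config name
# "training_config" and the checkpoint path below are assumptions/placeholders.
from hydra import compose, initialize

from ocl.cli import eval_utils

with initialize(config_path="configs", version_base="1.1"):
    config = compose(config_name="training_config", overrides=["+experiment=OC-MOT/cater"])

datamodule, model, pm = eval_utils.build_from_train_config(
    config,
    checkpoint_path="outputs/OC-MOT/cater/<timestamp>/checkpoints/last.ckpt",  # placeholder
)
```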