├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── configs ├── __init__.py ├── dataset │ ├── clevr.yaml │ ├── clevr6.yaml │ ├── clevr6_old_splits.yaml │ ├── clevr_old_splits.yaml │ ├── coco.yaml │ ├── coco_nocrowd.yaml │ ├── movi_c.yaml │ ├── movi_c_image.yaml │ ├── movi_e.yaml │ ├── movi_e_image.yaml │ ├── test.py │ ├── voc2012.yaml │ ├── voc2012_trainaug.yaml │ └── voc2012_trainval.yaml └── experiment │ ├── _output_path.yaml │ ├── projects │ └── bridging │ │ └── dinosaur │ │ ├── _base_feature_recon.yaml │ │ ├── _base_feature_recon_gumbel.yaml │ │ ├── _metrics_clevr_patch.yaml │ │ ├── _metrics_coco.yaml │ │ ├── _preprocessing_coco_dino_feature_recon_ccrop.yaml │ │ ├── _preprocessing_coco_dino_feature_recon_origres.yaml │ │ ├── _preprocessing_coco_dino_feature_recon_randcrop.yaml │ │ ├── _preprocessing_movi_dino_feature_recon.yaml │ │ ├── _preprocessing_voc2012_segm_dino_feature_recon.yaml │ │ ├── coco_feat_rec_dino_base16.yaml │ │ ├── coco_feat_rec_dino_base16_adaslot.yaml │ │ ├── coco_feat_rec_dino_base16_adaslot_eval.yaml │ │ ├── movi_c_feat_rec_vitb16.yaml │ │ ├── movi_c_feat_rec_vitb16_adaslot.yaml │ │ ├── movi_c_feat_rec_vitb16_adaslot_eval.yaml │ │ ├── movi_e_feat_rec_vitb16.yaml │ │ ├── movi_e_feat_rec_vitb16_adaslot.yaml │ │ └── movi_e_feat_rec_vitb16_adaslot_eval.yaml │ └── slot_attention │ ├── _base.yaml │ ├── _base_gumbel.yaml │ ├── _base_large.yaml │ ├── _metrics_clevr.yaml │ ├── _metrics_coco.yaml │ ├── _preprocessing_cater.yaml │ ├── _preprocessing_clevr.yaml │ ├── _preprocessing_clevr_no_norm.yaml │ ├── clevr10.yaml │ ├── clevr10_adaslot.yaml │ └── clevr10_adaslot_eval.yaml ├── framework.png ├── ocl ├── __init__.py ├── base.py ├── cli │ ├── cli_utils.py │ ├── compute_dataset_size.py │ ├── eval.py │ ├── eval_utils.py │ ├── train.py │ └── train_adaslot.py ├── combined_model.py ├── conditioning.py ├── config │ ├── __init__.py │ ├── conditioning.py │ ├── datasets.py │ ├── feature_extractors.py │ ├── 
metrics.py │ ├── neural_networks.py │ ├── optimizers.py │ ├── perceptual_groupings.py │ ├── plugins.py │ ├── predictor.py │ └── utils.py ├── consistency.py ├── datasets.py ├── decoding.py ├── distillation.py ├── feature_extractors.py ├── hooks.py ├── losses.py ├── matching.py ├── memory.py ├── memory_rollout.py ├── metrics.py ├── mha.py ├── models │ ├── __init__.py │ ├── image_grouping.py │ ├── image_grouping_adaslot.py │ └── image_grouping_adaslot_pixel.py ├── neural_networks │ ├── __init__.py │ ├── convenience.py │ ├── extensions.py │ ├── feature_pyramid_networks.py │ ├── positional_embedding.py │ ├── slate.py │ └── wrappers.py ├── path_defaults.py ├── perceptual_grouping.py ├── plugins.py ├── predictor.py ├── preprocessing.py ├── scheduling.py ├── trees.py ├── utils │ ├── __init__.py │ ├── annealing.py │ ├── bboxes.py │ ├── logging.py │ ├── masking.py │ ├── resizing.py │ ├── routing.py │ ├── trees.py │ └── windows.py ├── visualization_types.py └── visualizations.py ├── poetry.lock ├── pyproject.toml └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official PyTorch Implementation of Adaptive Slot Attention: Object Discovery with Dynamic Slot Number 2 | [![ArXiv](https://img.shields.io/badge/ArXiv-2406.09196-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.09196)[![HomePage](https://img.shields.io/badge/HomePage-Visit-blue.svg?logo=homeadvisor&logoColor=f5f5f5)](https://kfan21.github.io/AdaSlot/)![License](https://img.shields.io/badge/License-Apache%202.0-green.svg) 3 | > [**Adaptive Slot Attention: Object Discovery with Dynamic Slot Number**](https://arxiv.org/abs/2406.09196)
4 | > [Ke Fan](https://kfan21.github.io/), [Zechen Bai](https://www.baizechen.site/), [Tianjun Xiao](http://tianjunxiao.com/), [Tong He](https://hetong007.github.io/), [Max Horn](https://expectationmax.github.io/), [Yanwei Fu†](http://yanweifu.github.io/), [Francesco Locatello](https://www.francescolocatello.com/), [Zheng Zhang](https://scholar.google.com/citations?hl=zh-CN&user=k0KiE4wAAAAJ) 5 | 6 | 7 | This is the official implementation of the CVPR'24 paper [Adaptive Slot Attention: Object Discovery with Dynamic Slot Number](https://openaccess.thecvf.com/content/CVPR2024/html/Fan_Adaptive_Slot_Attention_Object_Discovery_with_Dynamic_Slot_Number_CVPR_2024_paper.html). 8 | 9 | ## Introduction 10 | 11 | ![framework](framework.png) 12 | 13 | Object-centric learning (OCL) uses slots to extract object representations, enhancing flexibility and interpretability. Slot attention, a common OCL method, refines slot representations with attention mechanisms but requires predefined slot numbers, ignoring object variability. To address this, a novel complexity-aware object auto-encoder framework introduces adaptive slot attention (AdaSlot), dynamically determining the optimal slot count based on data content through a discrete slot sampling module. A masked slot decoder suppresses unselected slots during decoding. Extensive testing shows this framework matches or exceeds fixed-slot models, adapting slot numbers based on instance complexity and promising further research opportunities. 14 | 15 | ## News! 16 | - [2024.11.02] We released the pre-trained checkpoints! Please find them at this [link](https://drive.google.com/drive/folders/1SRKE9Q5XF2UeYj1XB8kyjxORDmB7c7Mz)! 17 | - [2024.08.24] We open-sourced the code! 18 | 19 | ## Development Setup 20 | 21 | Installing AdaSlot requires Python 3.8 or later. Installation can be done using [poetry](https://python-poetry.org/docs/#installation). 
After installing `poetry`, check out the repo and set up a development environment: 22 | 23 | ```bash 24 | # install python3.8 25 | sudo apt update 26 | sudo apt install software-properties-common 27 | sudo add-apt-repository ppa:deadsnakes/ppa 28 | sudo apt install python3.8 29 | 30 | # install poetry with python3.8 31 | curl -sSL https://install.python-poetry.org | python3.8 - --version 1.2.0 32 | ## add poetry to environment variable 33 | 34 | # create virtual environment with poetry 35 | cd $code_path 36 | poetry install -E timm 37 | ``` 38 | 39 | This installs the `ocl` package and the CLI scripts used for running experiments in a poetry-managed virtual environment. Activate the poetry virtual environment with `poetry shell` before running the experiments. 40 | 41 | ## Running experiments 42 | 43 | Experiments are defined in the folder `configs/experiment` and can be run 44 | by setting the experiment variable. For example, the commands below train and evaluate models on MOVi-E, MOVi-C, COCO, and CLEVR10: 45 | 46 | ```bash 47 | poetry shell 48 | 49 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml 50 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT 51 | python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT 52 | 53 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml 54 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT 55 | python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT 56 | 57 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml 58 | python -m ocl.cli.train 
+experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT 59 | python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT 60 | 61 | python -m ocl.cli.train +experiment=slot_attention/clevr10.yaml 62 | python -m ocl.cli.train +experiment=slot_attention/clevr10_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT 63 | python -m ocl.cli.eval +experiment=slot_attention/clevr10_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT 64 | ``` 65 | 66 | The result is saved in a timestamped subdirectory in `outputs/`, i.e. `outputs/OC-MOT/cater/_