├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── configs
│   ├── __init__.py
│   ├── dataset
│   │   ├── clevr.yaml
│   │   ├── clevr6.yaml
│   │   ├── clevr6_old_splits.yaml
│   │   ├── clevr_old_splits.yaml
│   │   ├── coco.yaml
│   │   ├── coco_nocrowd.yaml
│   │   ├── movi_c.yaml
│   │   ├── movi_c_image.yaml
│   │   ├── movi_e.yaml
│   │   ├── movi_e_image.yaml
│   │   ├── test.py
│   │   ├── voc2012.yaml
│   │   ├── voc2012_trainaug.yaml
│   │   └── voc2012_trainval.yaml
│   └── experiment
│       ├── _output_path.yaml
│       ├── projects
│       │   └── bridging
│       │       └── dinosaur
│       │           ├── _base_feature_recon.yaml
│       │           ├── _base_feature_recon_gumbel.yaml
│       │           ├── _metrics_clevr_patch.yaml
│       │           ├── _metrics_coco.yaml
│       │           ├── _preprocessing_coco_dino_feature_recon_ccrop.yaml
│       │           ├── _preprocessing_coco_dino_feature_recon_origres.yaml
│       │           ├── _preprocessing_coco_dino_feature_recon_randcrop.yaml
│       │           ├── _preprocessing_movi_dino_feature_recon.yaml
│       │           ├── _preprocessing_voc2012_segm_dino_feature_recon.yaml
│       │           ├── coco_feat_rec_dino_base16.yaml
│       │           ├── coco_feat_rec_dino_base16_adaslot.yaml
│       │           ├── coco_feat_rec_dino_base16_adaslot_eval.yaml
│       │           ├── movi_c_feat_rec_vitb16.yaml
│       │           ├── movi_c_feat_rec_vitb16_adaslot.yaml
│       │           ├── movi_c_feat_rec_vitb16_adaslot_eval.yaml
│       │           ├── movi_e_feat_rec_vitb16.yaml
│       │           ├── movi_e_feat_rec_vitb16_adaslot.yaml
│       │           └── movi_e_feat_rec_vitb16_adaslot_eval.yaml
│       └── slot_attention
│           ├── _base.yaml
│           ├── _base_gumbel.yaml
│           ├── _base_large.yaml
│           ├── _metrics_clevr.yaml
│           ├── _metrics_coco.yaml
│           ├── _preprocessing_cater.yaml
│           ├── _preprocessing_clevr.yaml
│           ├── _preprocessing_clevr_no_norm.yaml
│           ├── clevr10.yaml
│           ├── clevr10_adaslot.yaml
│           └── clevr10_adaslot_eval.yaml
├── framework.png
├── ocl
│   ├── __init__.py
│   ├── base.py
│   ├── cli
│   │   ├── cli_utils.py
│   │   ├── compute_dataset_size.py
│   │   ├── eval.py
│   │   ├── eval_utils.py
│   │   ├── train.py
│   │   └── train_adaslot.py
│   ├── combined_model.py
│   ├── conditioning.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── conditioning.py
│   │   ├── datasets.py
│   │   ├── feature_extractors.py
│   │   ├── metrics.py
│   │   ├── neural_networks.py
│   │   ├── optimizers.py
│   │   ├── perceptual_groupings.py
│   │   ├── plugins.py
│   │   ├── predictor.py
│   │   └── utils.py
│   ├── consistency.py
│   ├── datasets.py
│   ├── decoding.py
│   ├── distillation.py
│   ├── feature_extractors.py
│   ├── hooks.py
│   ├── losses.py
│   ├── matching.py
│   ├── memory.py
│   ├── memory_rollout.py
│   ├── metrics.py
│   ├── mha.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── image_grouping.py
│   │   ├── image_grouping_adaslot.py
│   │   └── image_grouping_adaslot_pixel.py
│   ├── neural_networks
│   │   ├── __init__.py
│   │   ├── convenience.py
│   │   ├── extensions.py
│   │   ├── feature_pyramid_networks.py
│   │   ├── positional_embedding.py
│   │   ├── slate.py
│   │   └── wrappers.py
│   ├── path_defaults.py
│   ├── perceptual_grouping.py
│   ├── plugins.py
│   ├── predictor.py
│   ├── preprocessing.py
│   ├── scheduling.py
│   ├── trees.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── annealing.py
│   │   ├── bboxes.py
│   │   ├── logging.py
│   │   ├── masking.py
│   │   ├── resizing.py
│   │   ├── routing.py
│   │   ├── trees.py
│   │   └── windows.py
│   ├── visualization_types.py
│   └── visualizations.py
├── poetry.lock
├── pyproject.toml
└── setup.cfg
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. Our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), so looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Official PyTorch Implementation of Adaptive Slot Attention: Object Discovery with Dynamic Slot Number
2 | [arXiv](https://arxiv.org/abs/2406.09196) [Project Page](https://kfan21.github.io/AdaSlot/)
3 | > [**Adaptive Slot Attention: Object Discovery with Dynamic Slot Number**](https://arxiv.org/abs/2406.09196)
4 | > [Ke Fan](https://kfan21.github.io/), [Zechen Bai](https://www.baizechen.site/), [Tianjun Xiao](http://tianjunxiao.com/), [Tong He](https://hetong007.github.io/), [Max Horn](https://expectationmax.github.io/), [Yanwei Fu†](http://yanweifu.github.io/), [Francesco Locatello](https://www.francescolocatello.com/), [Zheng Zhang](https://scholar.google.com/citations?hl=zh-CN&user=k0KiE4wAAAAJ)
5 |
6 |
7 | This is the official implementation of the CVPR'24 paper [Adaptive Slot Attention: Object Discovery with Dynamic Slot Number](https://openaccess.thecvf.com/content/CVPR2024/html/Fan_Adaptive_Slot_Attention_Object_Discovery_with_Dynamic_Slot_Number_CVPR_2024_paper.html).
8 |
9 | ## Introduction
10 |
11 | 
12 |
13 | Object-centric learning (OCL) uses slots to extract object representations, enhancing flexibility and interpretability. Slot attention, a common OCL method, refines slot representations with attention mechanisms, but it requires a predefined slot number and thus ignores the variability in object counts across scenes. To address this, AdaSlot introduces a complexity-aware object auto-encoder framework with adaptive slot attention, which dynamically determines the optimal number of slots from the data content through a discrete slot sampling module; a masked slot decoder suppresses unselected slots during decoding. Extensive experiments show that this framework matches or exceeds fixed-slot models while adapting the slot number to instance complexity, opening up further research opportunities.
14 |
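The slot-selection step can be pictured with the short sketch below. It is a minimal, illustrative example only (assuming a per-slot keep/drop score network and a straight-through Gumbel-softmax sample, mirroring the `single_gumbel_score_network`, `hard_keep_decision`, and `slots_keep_prob` entries that appear in the configs), not the actual implementation in `ocl/`:

```python
# Illustrative sketch of adaptive slot selection (hypothetical helper, not the repo code).
import torch
import torch.nn.functional as F


def select_slots(slots: torch.Tensor, score_network: torch.nn.Module, tau: float = 1.0):
    """slots: [batch, n_slots, object_dim]; score_network maps object_dim -> 2 logits."""
    logits = score_network(slots)                       # [batch, n_slots, 2] (keep vs. drop)
    slots_keep_prob = logits.softmax(dim=-1)[..., 0]    # soft keep probability per slot
    # Straight-through Gumbel-softmax: hard 0/1 decisions forward, soft gradients backward.
    hard_keep_decision = F.gumbel_softmax(logits, tau=tau, hard=True)[..., 0]
    masked_slots = slots * hard_keep_decision.unsqueeze(-1)  # suppress unselected slots
    return masked_slots, hard_keep_decision, slots_keep_prob
```

During training, the experiment configs additionally apply a sparsity penalty (`ocl.losses.SparsePenalty`) to `hard_keep_decision`, encouraging the model to keep only as many slots as the scene requires.
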
15 | ## News!
16 | - [2024.11.02] We released the pre-trained checkpoints! Please find them at this [link](https://drive.google.com/drive/folders/1SRKE9Q5XF2UeYj1XB8kyjxORDmB7c7Mz)!
17 | - [2024.08.24] We open-sourced the code!
18 |
19 | ## Development Setup
20 |
21 | Installing AdaSlot requires at least Python 3.8. Installation is done with [poetry](https://python-poetry.org/docs/#installation). After installing `poetry`, check out the repo and set up a development environment:
22 |
23 | ```bash
24 | # install python3.8
25 | sudo apt update
26 | sudo apt install software-properties-common
27 | sudo add-apt-repository ppa:deadsnakes/ppa
28 | sudo apt install python3.8
29 |
30 | # install poetry with python3.8
31 | curl -sSL https://install.python-poetry.org | python3.8 - --version 1.2.0
32 | ## add poetry to your PATH (the official installer places it in ~/.local/bin by default)
33 |
34 | # create virtual environment with poetry
35 | cd $code_path
36 | poetry install -E timm
37 | ```
38 |
39 | This installs the `ocl` package and the CLI scripts used for running experiments into a poetry-managed virtual environment. Activate the virtual environment with `poetry shell` before running experiments.
40 |
41 | ## Running experiments
42 |
43 | Experiments are defined in the folder `configs/experiment` and can be run
44 | by setting the `experiment` variable. For example, to train and evaluate AdaSlot on MOVi-E, MOVi-C, COCO, and CLEVR10, run:
45 |
46 | ```bash
47 | poetry shell
48 |
49 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml
50 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
51 | python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
52 |
53 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml
54 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
55 | python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
56 |
57 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml
58 | python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
59 | python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
60 |
61 | python -m ocl.cli.train +experiment=slot_attention/clevr10.yaml
62 | python -m ocl.cli.train +experiment=slot_attention/clevr10_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
63 | python -m ocl.cli.eval +experiment=slot_attention/clevr10_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
64 | ```
65 |
66 | Results are saved in a timestamped subdirectory of `outputs/`, i.e. `outputs/<experiment-name>/<timestamp>`. The prefix path `outputs` can be configured using the `experiment.root_output_folder` variable (see `configs/experiment/_output_path.yaml`).
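
Because experiments are composed with Hydra, individual config values can also be overridden on the command line. The snippet below is only a sketch (the path is a placeholder); `trainer.gpus` and `dataset.batch_size` are keys that appear in the experiment configs of this repository:

```bash
# Illustrative overrides: single-GPU training with a smaller per-GPU batch size
# and a custom output folder (placeholder path).
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml \
  trainer.gpus=1 dataset.batch_size=4 \
  +experiment.root_output_folder=/path/to/outputs
```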
67 |
68 | ## Citation
69 |
70 | Please cite our paper if you find this repo useful!
71 |
72 | ```bibtex
73 | @inproceedings{fan2024adaptive,
74 | title={Adaptive slot attention: Object discovery with dynamic slot number},
75 | author={Fan, Ke and Bai, Zechen and Xiao, Tianjun and He, Tong and Horn, Max and Fu, Yanwei and Locatello, Francesco and Zhang, Zheng},
76 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
77 | pages={23062--23071},
78 | year={2024}
79 | }
80 | ```
81 |
82 | Related projects that this paper is developed upon:
83 |
84 | ```bibtex
85 | @misc{oclf,
86 | author = {Max Horn and Maximilian Seitzer and Andrii Zadaianchuk and Zixu Zhao and Dominik Zietlow and Florian Wenzel and Tianjun Xiao},
87 | title = {Object Centric Learning Framework (version 0.1)},
88 | year = {2023},
89 | url = {https://github.com/amazon-science/object-centric-learning-framework},
90 | }
91 | ```
92 |
93 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Hydra needs this file to recognize the config folder when using hydra.main from console scripts
2 |
--------------------------------------------------------------------------------
/configs/dataset/clevr.yaml:
--------------------------------------------------------------------------------
1 | # Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets .
2 | defaults:
3 | - webdataset
4 |
5 | train_shards: "/home/ubuntu/clevr_with_masks_new_splits/train/shard-{000000..000114}.tar"
6 | train_size: 70000
7 | val_shards: "/home/ubuntu/clevr_with_masks_new_splits/val/shard-{000000..000024}.tar"
8 | val_size: 15000
9 | test_shards: "/home/ubuntu/clevr_with_masks_new_splits/test/shard-{000000..000024}.tar"
10 | test_size: 15000
--------------------------------------------------------------------------------
/configs/dataset/clevr6.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Image dataset containing instances from CLEVR with at most 6 objects in each scene.
3 | defaults:
4 | - /dataset/clevr@dataset
5 | - /plugins/subset_dataset@plugins.01_clevr6_subset
6 | - _self_
7 |
8 | dataset:
9 | # Values derived from running `bin/compute_dataset_size.py`
10 | train_size: 26240
11 | val_size: 5553
12 | test_size: 5600
13 |
14 | plugins:
15 | 01_clevr6_subset:
16 | predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}"
17 | fields:
18 | - visibility
19 |
--------------------------------------------------------------------------------
/configs/dataset/clevr6_old_splits.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Image dataset containing instances from CLEVR with at most 6 objects in each scene.
3 | defaults:
4 | - clevr_old_splits@dataset
5 | - /plugins/subset_dataset@plugins.01_clevr6_subset
6 | - _self_
7 |
8 | dataset:
9 | # Values derived from running `bin/compute_dataset_size.py`
10 | train_size: 29948
11 | val_size: 3674
12 | test_size: 3771
13 |
14 | plugins:
15 | 01_clevr6_subset:
16 | predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}"
17 | fields:
18 | - visibility
19 |
--------------------------------------------------------------------------------
/configs/dataset/clevr_old_splits.yaml:
--------------------------------------------------------------------------------
1 | # Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets .
2 | defaults:
3 | - webdataset
4 |
5 | train_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/train/shard-{000000..000131}.tar"}
6 | train_size: 80000
7 | val_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/val/shard-{000000..000016}.tar"}
8 | val_size: 10000
9 | test_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/test/shard-{000000..000016}.tar"}
10 | test_size: 10000
11 |
--------------------------------------------------------------------------------
/configs/dataset/coco.yaml:
--------------------------------------------------------------------------------
1 | # The coco2017 dataset with instance, stuff and caption annotations.
2 | defaults:
3 | - webdataset
4 |
5 | train_shards: "/home/ubuntu/coco2017/train/shard-{000000..000412}.tar"
6 | train_size: 118287
7 | val_shards: "/home/ubuntu/coco2017/val/shard-{000000..000017}.tar"
8 | val_size: 5000
9 | test_shards: "/home/ubuntu/coco2017/test/shard-{000000..000126}.tar"
10 | test_size: 40670
11 | use_autopadding: true
12 |
--------------------------------------------------------------------------------
/configs/dataset/coco_nocrowd.yaml:
--------------------------------------------------------------------------------
1 | # The coco2017 dataset with instance, stuff and caption annotations.
2 | # Validation dataset does not contain any crowd annotations.
3 | defaults:
4 | - webdataset
5 |
6 | train_shards: ${dataset_prefix:"coco2017/train/shard-{000000..000412}.tar"}
7 | train_size: 118287
8 | val_shards: ${dataset_prefix:"coco2017/val_nocrowd/shard-{000000..000017}.tar"}
9 | val_size: 5000
10 | test_shards: ${dataset_prefix:"coco2017/test/shard-{000000..000126}.tar"}
11 | test_size: 40670
12 | use_autopadding: true
13 |
--------------------------------------------------------------------------------
/configs/dataset/movi_c.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - webdataset@dataset
4 | - _self_
5 |
6 | dataset:
7 | train_shards: "/home/ubuntu/movi_c/train/shard-{000000..000298}.tar"
8 | train_size: 9737
9 | val_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar"
10 | val_size: 250
11 | test_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar"
12 | test_size: 250
13 | use_autopadding: true
14 |
15 | plugins:
16 | 00_1_rename_fields:
17 | _target_: ocl.plugins.RenameFields
18 | train_mapping:
19 | video: image
20 | evaluation_mapping:
21 | video: image
22 | segmentations: mask
23 | 00_2_adapt_mask_format:
24 | _target_: ocl.plugins.SingleElementPreprocessing
25 | training_transform: null
26 | evaluation_transform:
27 | _target_: ocl.preprocessing.IntegerToOneHotMask
28 | output_axis: -4
29 | max_instances: 10
30 | ignore_typical_background: false
31 | element_key: mask
32 |
--------------------------------------------------------------------------------
/configs/dataset/movi_c_image.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Image dataset containing subsampled frames from MOVI_C dataset.
3 | defaults:
4 | - /dataset/movi_c
5 | - /plugins/sample_frames_from_video@plugins.02_sample_frames
6 | - _self_
7 |
8 | dataset:
9 | # Values derived from running `bin/compute_dataset_size.py`.
10 | train_size: 87633
11 | val_size: 6000
12 | test_size: 6000
13 |
14 | plugins:
15 | 02_sample_frames:
16 | n_frames_per_video: 9
17 | n_eval_frames_per_video: -1
18 | training_fields:
19 | - image
20 | evaluation_fields:
21 | - image
22 | - mask
23 | dim: 0
24 | seed: 457834752
25 | shuffle_buffer_size: 1000
26 |
--------------------------------------------------------------------------------
/configs/dataset/movi_e.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - webdataset@dataset
4 | - _self_
5 |
6 | dataset:
7 | train_shards: "/home/ubuntu/movi_e/train/shard-{000000..000679}.tar"
8 | train_size: 9749
9 | val_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar"
10 | val_size: 250
11 | test_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar"
12 | test_size: 250
13 | use_autopadding: true
14 | plugins:
15 | 00_1_rename_fields:
16 | _target_: ocl.plugins.RenameFields
17 | train_mapping:
18 | video: image
19 | evaluation_mapping:
20 | video: image
21 | segmentations: mask
22 | 00_2_adapt_mask_format:
23 | _target_: ocl.plugins.SingleElementPreprocessing
24 | training_transform: null
25 | evaluation_transform:
26 | _target_: ocl.preprocessing.IntegerToOneHotMask
27 | output_axis: -4
28 | max_instances: 23
29 | ignore_typical_background: false
30 | element_key: mask
31 |
--------------------------------------------------------------------------------
/configs/dataset/movi_e_image.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Image dataset containing subsampled frames from MOVI_E dataset.
3 | defaults:
4 | - /dataset/movi_e
5 | - /plugins/sample_frames_from_video@plugins.02_sample_frames
6 | - _self_
7 |
8 | dataset:
9 | # Values derived from running `bin/compute_dataset_size.py`.
10 | train_size: 87741
11 | val_size: 6000
12 | test_size: 6000
13 |
14 | plugins:
15 | 02_sample_frames:
16 | n_frames_per_video: 9
17 | n_eval_frames_per_video: -1
18 | training_fields:
19 | - image
20 | evaluation_fields:
21 | - image
22 | - mask
23 | dim: 0
24 | seed: 457834752
25 | shuffle_buffer_size: 1000
26 |
--------------------------------------------------------------------------------
/configs/dataset/test.py:
--------------------------------------------------------------------------------
1 | lambda visibility, color, pixel_coords: True
--------------------------------------------------------------------------------
/configs/dataset/voc2012.yaml:
--------------------------------------------------------------------------------
1 | # The PASCAL VOC 2012 dataset. Does not contain segmentation annotations.
2 | defaults:
3 | - webdataset
4 |
5 | train_shards: ${dataset_prefix:"voc2012_detection/train/shard-{000000..000021}.tar"}
6 | train_size: 5717
7 | val_shards: ${dataset_prefix:"voc2012_detection/val/shard-{000000..000022}.tar"}
8 | val_size: 5823
9 | test_shards: ${dataset_prefix:"voc2012_detection/test/shard-{000000..000041}.tar"}
10 | test_size: 10991
11 | use_autopadding: true
12 |
--------------------------------------------------------------------------------
/configs/dataset/voc2012_trainaug.yaml:
--------------------------------------------------------------------------------
1 | # The PASCAL VOC 2012 dataset in the trainaug variant with instance segmentation masks.
2 | defaults:
3 | - webdataset
4 |
5 | train_shards: "/home/ubuntu/voc2012/trainaug/shard-{000000..000040}.tar"
6 | train_size: 10582
7 | val_shards: "/home/ubuntu/voc2012/val/shard-{000000..000011}.tar"
8 | val_size: 1449
9 | test_shards: null
10 | test_size: null
11 | use_autopadding: true
12 |
--------------------------------------------------------------------------------
/configs/dataset/voc2012_trainval.yaml:
--------------------------------------------------------------------------------
1 | # The PASCAL VOC 2012 dataset, using joint train+val splits for training and validation.
2 | # This setting is often used in the unsupervised case.
3 | defaults:
4 | - webdataset
5 |
6 | train_shards:
7 | - ${dataset_prefix:"voc2012_detection/train/shard-000000.tar"}
8 | - ${dataset_prefix:"voc2012_detection/train/shard-000001.tar"}
9 | - ${dataset_prefix:"voc2012_detection/train/shard-000002.tar"}
10 | - ${dataset_prefix:"voc2012_detection/train/shard-000003.tar"}
11 | - ${dataset_prefix:"voc2012_detection/train/shard-000004.tar"}
12 | - ${dataset_prefix:"voc2012_detection/train/shard-000005.tar"}
13 | - ${dataset_prefix:"voc2012_detection/train/shard-000006.tar"}
14 | - ${dataset_prefix:"voc2012_detection/train/shard-000007.tar"}
15 | - ${dataset_prefix:"voc2012_detection/train/shard-000008.tar"}
16 | - ${dataset_prefix:"voc2012_detection/train/shard-000009.tar"}
17 | - ${dataset_prefix:"voc2012_detection/train/shard-000010.tar"}
18 | - ${dataset_prefix:"voc2012_detection/train/shard-000011.tar"}
19 | - ${dataset_prefix:"voc2012_detection/train/shard-000012.tar"}
20 | - ${dataset_prefix:"voc2012_detection/train/shard-000013.tar"}
21 | - ${dataset_prefix:"voc2012_detection/train/shard-000014.tar"}
22 | - ${dataset_prefix:"voc2012_detection/train/shard-000015.tar"}
23 | - ${dataset_prefix:"voc2012_detection/train/shard-000016.tar"}
24 | - ${dataset_prefix:"voc2012_detection/train/shard-000017.tar"}
25 | - ${dataset_prefix:"voc2012_detection/train/shard-000018.tar"}
26 | - ${dataset_prefix:"voc2012_detection/train/shard-000019.tar"}
27 | - ${dataset_prefix:"voc2012_detection/train/shard-000020.tar"}
28 | - ${dataset_prefix:"voc2012_detection/train/shard-000021.tar"}
29 | - ${dataset_prefix:"voc2012_detection/val/shard-000000.tar"}
30 | - ${dataset_prefix:"voc2012_detection/val/shard-000001.tar"}
31 | - ${dataset_prefix:"voc2012_detection/val/shard-000002.tar"}
32 | - ${dataset_prefix:"voc2012_detection/val/shard-000003.tar"}
33 | - ${dataset_prefix:"voc2012_detection/val/shard-000004.tar"}
34 | - ${dataset_prefix:"voc2012_detection/val/shard-000005.tar"}
35 | - ${dataset_prefix:"voc2012_detection/val/shard-000006.tar"}
36 | - ${dataset_prefix:"voc2012_detection/val/shard-000007.tar"}
37 | - ${dataset_prefix:"voc2012_detection/val/shard-000008.tar"}
38 | - ${dataset_prefix:"voc2012_detection/val/shard-000009.tar"}
39 | - ${dataset_prefix:"voc2012_detection/val/shard-000010.tar"}
40 | - ${dataset_prefix:"voc2012_detection/val/shard-000011.tar"}
41 | - ${dataset_prefix:"voc2012_detection/val/shard-000012.tar"}
42 | - ${dataset_prefix:"voc2012_detection/val/shard-000013.tar"}
43 | - ${dataset_prefix:"voc2012_detection/val/shard-000014.tar"}
44 | - ${dataset_prefix:"voc2012_detection/val/shard-000015.tar"}
45 | - ${dataset_prefix:"voc2012_detection/val/shard-000016.tar"}
46 | - ${dataset_prefix:"voc2012_detection/val/shard-000017.tar"}
47 | - ${dataset_prefix:"voc2012_detection/val/shard-000018.tar"}
48 | - ${dataset_prefix:"voc2012_detection/val/shard-000019.tar"}
49 | - ${dataset_prefix:"voc2012_detection/val/shard-000020.tar"}
50 | - ${dataset_prefix:"voc2012_detection/val/shard-000021.tar"}
51 | - ${dataset_prefix:"voc2012_detection/val/shard-000022.tar"}
52 | train_size: 11540
53 | val_shards:
54 | - ${dataset_prefix:"voc2012_detection/train/shard-000000.tar"}
55 | - ${dataset_prefix:"voc2012_detection/train/shard-000001.tar"}
56 | - ${dataset_prefix:"voc2012_detection/train/shard-000002.tar"}
57 | - ${dataset_prefix:"voc2012_detection/train/shard-000003.tar"}
58 | - ${dataset_prefix:"voc2012_detection/train/shard-000004.tar"}
59 | - ${dataset_prefix:"voc2012_detection/train/shard-000005.tar"}
60 | - ${dataset_prefix:"voc2012_detection/train/shard-000006.tar"}
61 | - ${dataset_prefix:"voc2012_detection/train/shard-000007.tar"}
62 | - ${dataset_prefix:"voc2012_detection/train/shard-000008.tar"}
63 | - ${dataset_prefix:"voc2012_detection/train/shard-000009.tar"}
64 | - ${dataset_prefix:"voc2012_detection/train/shard-000010.tar"}
65 | - ${dataset_prefix:"voc2012_detection/train/shard-000011.tar"}
66 | - ${dataset_prefix:"voc2012_detection/train/shard-000012.tar"}
67 | - ${dataset_prefix:"voc2012_detection/train/shard-000013.tar"}
68 | - ${dataset_prefix:"voc2012_detection/train/shard-000014.tar"}
69 | - ${dataset_prefix:"voc2012_detection/train/shard-000015.tar"}
70 | - ${dataset_prefix:"voc2012_detection/train/shard-000016.tar"}
71 | - ${dataset_prefix:"voc2012_detection/train/shard-000017.tar"}
72 | - ${dataset_prefix:"voc2012_detection/train/shard-000018.tar"}
73 | - ${dataset_prefix:"voc2012_detection/train/shard-000019.tar"}
74 | - ${dataset_prefix:"voc2012_detection/train/shard-000020.tar"}
75 | - ${dataset_prefix:"voc2012_detection/train/shard-000021.tar"}
76 | - ${dataset_prefix:"voc2012_detection/val/shard-000000.tar"}
77 | - ${dataset_prefix:"voc2012_detection/val/shard-000001.tar"}
78 | - ${dataset_prefix:"voc2012_detection/val/shard-000002.tar"}
79 | - ${dataset_prefix:"voc2012_detection/val/shard-000003.tar"}
80 | - ${dataset_prefix:"voc2012_detection/val/shard-000004.tar"}
81 | - ${dataset_prefix:"voc2012_detection/val/shard-000005.tar"}
82 | - ${dataset_prefix:"voc2012_detection/val/shard-000006.tar"}
83 | - ${dataset_prefix:"voc2012_detection/val/shard-000007.tar"}
84 | - ${dataset_prefix:"voc2012_detection/val/shard-000008.tar"}
85 | - ${dataset_prefix:"voc2012_detection/val/shard-000009.tar"}
86 | - ${dataset_prefix:"voc2012_detection/val/shard-000010.tar"}
87 | - ${dataset_prefix:"voc2012_detection/val/shard-000011.tar"}
88 | - ${dataset_prefix:"voc2012_detection/val/shard-000012.tar"}
89 | - ${dataset_prefix:"voc2012_detection/val/shard-000013.tar"}
90 | - ${dataset_prefix:"voc2012_detection/val/shard-000014.tar"}
91 | - ${dataset_prefix:"voc2012_detection/val/shard-000015.tar"}
92 | - ${dataset_prefix:"voc2012_detection/val/shard-000016.tar"}
93 | - ${dataset_prefix:"voc2012_detection/val/shard-000017.tar"}
94 | - ${dataset_prefix:"voc2012_detection/val/shard-000018.tar"}
95 | - ${dataset_prefix:"voc2012_detection/val/shard-000019.tar"}
96 | - ${dataset_prefix:"voc2012_detection/val/shard-000020.tar"}
97 | - ${dataset_prefix:"voc2012_detection/val/shard-000021.tar"}
98 | - ${dataset_prefix:"voc2012_detection/val/shard-000022.tar"}
99 | val_size: 11540
100 | test_shards: ${dataset_prefix:"voc2012_detection/test/shard-{000000..000041}.tar"}
101 | test_size: 10991
102 | use_autopadding: true
103 |
--------------------------------------------------------------------------------
/configs/experiment/_output_path.yaml:
--------------------------------------------------------------------------------
1 | # @package hydra
2 |
3 | run:
4 | dir: ${oc.select:experiment.root_output_folder,outputs}/${hydra:runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S}
5 | sweep:
6 | dir: ${oc.select:experiment.root_output_folder,multirun}
7 | subdir: ${hydra:runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S}
8 | output_subdir: config
9 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_base_feature_recon.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Default parameters for slot attention with a ViT decoder for feature reconstruction.
3 | defaults:
4 | - /experiment/_output_path
5 | - /training_config
6 | - /feature_extractor/timm_model@models.feature_extractor
7 | - /perceptual_grouping/slot_attention@models.perceptual_grouping
8 | - /plugins/optimization@plugins.optimize_parameters
9 | - /optimizers/adam@plugins.optimize_parameters.optimizer
10 | - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
11 | - _self_
12 |
13 | trainer:
14 | gradient_clip_val: 1.0
15 |
16 | models:
17 | feature_extractor:
18 | model_name: vit_small_patch16_224_dino
19 | pretrained: false
20 | freeze: true
21 | feature_level: 12
22 |
23 | perceptual_grouping:
24 | input_dim: 384
25 | feature_dim: ${.object_dim}
26 | object_dim: ${models.conditioning.object_dim}
27 | use_projection_bias: false
28 | positional_embedding:
29 | _target_: ocl.neural_networks.wrappers.Sequential
30 | _args_:
31 | - _target_: ocl.neural_networks.positional_embedding.DummyPositionEmbed
32 | - _target_: ocl.neural_networks.build_two_layer_mlp
33 | input_dim: ${....input_dim}
34 | output_dim: ${....feature_dim}
35 | hidden_dim: ${....input_dim}
36 | initial_layer_norm: true
37 | ff_mlp:
38 | _target_: ocl.neural_networks.build_two_layer_mlp
39 | input_dim: ${..object_dim}
40 | output_dim: ${..object_dim}
41 | hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
42 | initial_layer_norm: true
43 | residual: true
44 |
45 | object_decoder:
46 | object_dim: ${models.perceptual_grouping.object_dim}
47 | output_dim: ${models.perceptual_grouping.input_dim}
48 | num_patches: 196
49 | object_features_path: perceptual_grouping.objects
50 | target_path: feature_extractor.features
51 | image_path: input.image
52 |
53 | plugins:
54 | optimize_parameters:
55 | optimizer:
56 | lr: 0.0004
57 | lr_scheduler:
58 | decay_rate: 0.5
59 | decay_steps: 100000
60 | warmup_steps: 10000
61 |
62 | losses:
63 | mse:
64 | _target_: ocl.losses.ReconstructionLoss
65 | loss_type: mse
66 | input_path: object_decoder.reconstruction
67 | target_path: object_decoder.target # Object decoder does some resizing.
68 |
69 | visualizations:
70 | input:
71 | _target_: ocl.visualizations.Image
72 | denormalization:
73 | _target_: ocl.preprocessing.Denormalize
74 | mean: [0.485, 0.456, 0.406]
75 | std: [0.229, 0.224, 0.225]
76 | image_path: input.image
77 | masks:
78 | _target_: ocl.visualizations.Mask
79 | mask_path: object_decoder.masks_as_image
80 | pred_segmentation:
81 | _target_: ocl.visualizations.Segmentation
82 | denormalization:
83 | _target_: ocl.preprocessing.Denormalize
84 | mean: [0.485, 0.456, 0.406]
85 | std: [0.229, 0.224, 0.225]
86 | image_path: input.image
87 | mask_path: object_decoder.masks_as_image
88 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Default parameters for slot attention with a ViT decoder for feature reconstruction.
3 | defaults:
4 | - /experiment/_output_path
5 | - /training_config
6 | - /feature_extractor/timm_model@models.feature_extractor
7 | - /perceptual_grouping/slot_attention_gumbel_v1@models.perceptual_grouping
8 | - /plugins/optimization@plugins.optimize_parameters
9 | - /optimizers/adam@plugins.optimize_parameters.optimizer
10 | - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
11 | - _self_
12 |
13 | trainer:
14 | gradient_clip_val: 1.0
15 |
16 | models:
17 | feature_extractor:
18 | model_name: vit_small_patch16_224_dino
19 | pretrained: false
20 | freeze: true
21 | feature_level: 12
22 |
23 | perceptual_grouping:
24 | input_dim: 384
25 | feature_dim: ${.object_dim}
26 | object_dim: ${models.conditioning.object_dim}
27 | use_projection_bias: false
28 | positional_embedding:
29 | _target_: ocl.neural_networks.wrappers.Sequential
30 | _args_:
31 | - _target_: ocl.neural_networks.positional_embedding.DummyPositionEmbed
32 | - _target_: ocl.neural_networks.build_two_layer_mlp
33 | input_dim: ${....input_dim}
34 | output_dim: ${....feature_dim}
35 | hidden_dim: ${....input_dim}
36 | initial_layer_norm: true
37 | ff_mlp:
38 | _target_: ocl.neural_networks.build_two_layer_mlp
39 | input_dim: ${..object_dim}
40 | output_dim: ${..object_dim}
41 | hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
42 | initial_layer_norm: true
43 | residual: true
44 | single_gumbel_score_network:
45 | _target_: ocl.neural_networks.build_two_layer_mlp
46 | input_dim: ${..object_dim}
47 | output_dim: 2
48 | hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
49 | initial_layer_norm: true
50 | residual: false
51 | object_decoder:
52 | object_dim: ${models.perceptual_grouping.object_dim}
53 | output_dim: ${models.perceptual_grouping.input_dim}
54 | num_patches: 196
55 | object_features_path: perceptual_grouping.objects
56 | target_path: feature_extractor.features
57 | image_path: input.image
58 |
59 | plugins:
60 | optimize_parameters:
61 | optimizer:
62 | lr: 0.0004
63 | lr_scheduler:
64 | decay_rate: 0.5
65 | decay_steps: 100000
66 | warmup_steps: 10000
67 |
68 | losses:
69 | mse:
70 | _target_: ocl.losses.ReconstructionLoss
71 | loss_type: mse
72 | input_path: object_decoder.reconstruction
73 | target_path: object_decoder.target # Object decoder does some resizing.
74 |
75 |
76 | visualizations:
77 | input:
78 | _target_: ocl.visualizations.Image
79 | denormalization:
80 | _target_: ocl.preprocessing.Denormalize
81 | mean: [0.485, 0.456, 0.406]
82 | std: [0.229, 0.224, 0.225]
83 | image_path: input.image
84 | masks:
85 | _target_: ocl.visualizations.Mask
86 | mask_path: object_decoder.masks_as_image
87 | pred_segmentation:
88 | _target_: ocl.visualizations.Segmentation
89 | denormalization:
90 | _target_: ocl.preprocessing.Denormalize
91 | mean: [0.485, 0.456, 0.406]
92 | std: [0.229, 0.224, 0.225]
93 | image_path: input.image
94 | mask_path: object_decoder.masks_as_image
95 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_metrics_clevr_patch.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /metrics/ari_metric@evaluation_metrics.ari
4 | - /metrics/average_best_overlap_metric@evaluation_metrics.abo
5 |
6 | evaluation_metrics:
7 | ari:
8 | prediction_path: masks_as_image
9 | target_path: input.mask
10 | abo:
11 | prediction_path: masks_as_image
12 | target_path: input.mask
13 | ignore_background: true
14 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_metrics_coco.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /metrics/ari_metric@evaluation_metrics.instance_mask_ari
4 | - /metrics/unsupervised_mask_iou_metric@evaluation_metrics.instance_abo
5 | - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
6 |
7 | evaluation_metrics:
8 | instance_mask_ari:
9 | prediction_path: object_decoder.masks_as_image
10 | target_path: input.instance_mask
11 | foreground: False
12 | convert_target_one_hot: True
13 | ignore_overlaps: True
14 | instance_abo:
15 | prediction_path: object_decoder.masks_as_image
16 | target_path: input.instance_mask
17 | use_threshold: False
18 | matching: best_overlap
19 | ignore_overlaps: True
20 | instance_mask_corloc:
21 | prediction_path: object_decoder.masks_as_image
22 | target_path: input.instance_mask
23 | use_threshold: False
24 | ignore_overlaps: True
25 |
26 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/data_preprocessing@plugins.03a_preprocessing
4 | - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
5 |
6 | plugins:
7 | 03a_preprocessing:
8 | evaluation_fields:
9 | - image
10 | - instance_mask
11 | - instance_category
12 | evaluation_transform:
13 | _target_: torchvision.transforms.Compose
14 | transforms:
15 | - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
16 | - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
17 | # Drop instance_category again as some images do not contain it
18 | - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
19 | - _target_: ocl.preprocessing.AddEmptyMasks
20 | mask_keys:
21 | - instance_mask
22 | - segmentation_mask
23 |
24 | 03b_preprocessing:
25 | training_transforms:
26 | image:
27 | _target_: torchvision.transforms.Compose
28 | transforms:
29 | - _target_: torchvision.transforms.ToTensor
30 | - _target_: torchvision.transforms.Resize
31 | size: 224
32 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
33 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
34 | - _target_: torchvision.transforms.CenterCrop
35 | size: 224
36 | - _target_: torchvision.transforms.RandomHorizontalFlip
37 | - _target_: torchvision.transforms.Normalize
38 | mean: [0.485, 0.456, 0.406]
39 | std: [0.229, 0.224, 0.225]
40 | evaluation_transforms:
41 | image:
42 | _target_: torchvision.transforms.Compose
43 | transforms:
44 | - _target_: torchvision.transforms.ToTensor
45 | - _target_: torchvision.transforms.Resize
46 | size: 224
47 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
48 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
49 | - _target_: torchvision.transforms.CenterCrop
50 | size: 224
51 | - _target_: torchvision.transforms.Normalize
52 | mean: [0.485, 0.456, 0.406]
53 | std: [0.229, 0.224, 0.225]
54 | instance_mask:
55 | _target_: torchvision.transforms.Compose
56 | transforms:
57 | - _target_: ocl.preprocessing.DenseMaskToTensor
58 | - _target_: ocl.preprocessing.ResizeNearestExact
59 | size: 224
60 | - _target_: torchvision.transforms.CenterCrop
61 | size: 224
62 | segmentation_mask:
63 | _target_: torchvision.transforms.Compose
64 | transforms:
65 | - _target_: ocl.preprocessing.DenseMaskToTensor
66 | - _target_: ocl.preprocessing.ResizeNearestExact
67 | size: 224
68 | - _target_: torchvision.transforms.CenterCrop
69 | size: 224
70 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_origres.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/data_preprocessing@plugins.03a_preprocessing
4 | - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
5 |
6 | plugins:
7 | 03a_preprocessing:
8 | evaluation_fields:
9 | - image
10 | - instance_mask
11 | - instance_category
12 | evaluation_transform:
13 | _target_: torchvision.transforms.Compose
14 | transforms:
15 | - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
16 | - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
17 | # Drop instance_category again as some images do not contain it
18 | - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
19 | - _target_: ocl.preprocessing.AddEmptyMasks
20 | mask_keys:
21 | - instance_mask
22 | - segmentation_mask
23 |
24 | 03b_preprocessing:
25 | training_transforms:
26 | image:
27 | _target_: torchvision.transforms.Compose
28 | transforms:
29 | - _target_: torchvision.transforms.ToTensor
30 | - _target_: torchvision.transforms.Resize
31 | _convert_: partial
32 | size: [224, 224]
33 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
34 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
35 | - _target_: torchvision.transforms.RandomHorizontalFlip
36 | - _target_: torchvision.transforms.Normalize
37 | mean: [0.485, 0.456, 0.406]
38 | std: [0.229, 0.224, 0.225]
39 | evaluation_transforms:
40 | image:
41 | _target_: torchvision.transforms.Compose
42 | transforms:
43 | - _target_: torchvision.transforms.ToTensor
44 | - _target_: torchvision.transforms.Resize
45 | _convert_: partial
46 | size: [224, 224]
47 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
48 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
49 | - _target_: torchvision.transforms.Normalize
50 | mean: [0.485, 0.456, 0.406]
51 | std: [0.229, 0.224, 0.225]
52 | instance_mask:
53 | _target_: torchvision.transforms.Compose
54 | transforms:
55 | - _target_: ocl.preprocessing.DenseMaskToTensor
56 | - _target_: ocl.preprocessing.ResizeNearestExact
57 | _convert_: partial
58 | size: [224, 224]
59 | segmentation_mask:
60 | _target_: torchvision.transforms.Compose
61 | transforms:
62 | - _target_: ocl.preprocessing.DenseMaskToTensor
63 | - _target_: ocl.preprocessing.ResizeNearestExact
64 | _convert_: partial
65 | size: [224, 224]
66 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_randcrop.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/data_preprocessing@plugins.03a_preprocessing
4 | - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
5 |
6 | plugins:
7 | 03a_preprocessing:
8 | evaluation_fields:
9 | - image
10 | - instance_mask
11 | - instance_category
12 | evaluation_transform:
13 | _target_: torchvision.transforms.Compose
14 | transforms:
15 | - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
16 | - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
17 | # Drop instance_category again as some images do not contain it
18 | - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
19 | - _target_: ocl.preprocessing.AddEmptyMasks
20 | mask_keys:
21 | - instance_mask
22 | - segmentation_mask
23 |
24 | 03b_preprocessing:
25 | training_transforms:
26 | image:
27 | _target_: torchvision.transforms.Compose
28 | transforms:
29 | - _target_: torchvision.transforms.ToTensor
30 | - _target_: torchvision.transforms.Resize
31 | size: 224
32 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
33 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
34 | - _target_: torchvision.transforms.RandomCrop
35 | size: 224
36 | - _target_: torchvision.transforms.RandomHorizontalFlip
37 | - _target_: torchvision.transforms.Normalize
38 | mean: [0.485, 0.456, 0.406]
39 | std: [0.229, 0.224, 0.225]
40 | evaluation_transforms:
41 | image:
42 | _target_: torchvision.transforms.Compose
43 | transforms:
44 | - _target_: torchvision.transforms.ToTensor
45 | - _target_: torchvision.transforms.Resize
46 | size: 224
47 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
48 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
49 | - _target_: torchvision.transforms.CenterCrop
50 | size: 224
51 | - _target_: torchvision.transforms.Normalize
52 | mean: [0.485, 0.456, 0.406]
53 | std: [0.229, 0.224, 0.225]
54 | instance_mask:
55 | _target_: torchvision.transforms.Compose
56 | transforms:
57 | - _target_: ocl.preprocessing.DenseMaskToTensor
58 | - _target_: ocl.preprocessing.ResizeNearestExact
59 | size: 224
60 | - _target_: torchvision.transforms.CenterCrop
61 | size: 224
62 | segmentation_mask:
63 | _target_: torchvision.transforms.Compose
64 | transforms:
65 | - _target_: ocl.preprocessing.DenseMaskToTensor
66 | - _target_: ocl.preprocessing.ResizeNearestExact
67 | size: 224
68 | - _target_: torchvision.transforms.CenterCrop
69 | size: 224
70 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/multi_element_preprocessing@plugins.03_preprocessing
4 | - _self_
5 |
6 | plugins:
7 | 03_preprocessing:
8 | training_transforms:
9 | image:
10 | _target_: torchvision.transforms.Compose
11 | transforms:
12 | - _target_: torchvision.transforms.ToTensor
13 | - _target_: torchvision.transforms.Resize
14 | size: 224
15 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
16 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
17 | - _target_: torchvision.transforms.Normalize
18 | mean: [0.485, 0.456, 0.406]
19 | std: [0.229, 0.224, 0.225]
20 | evaluation_transforms:
21 | image:
22 | _target_: torchvision.transforms.Compose
23 | transforms:
24 | - _target_: torchvision.transforms.ToTensor
25 | - _target_: torchvision.transforms.Resize
26 | size: 224
27 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
28 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
29 | - _target_: torchvision.transforms.Normalize
30 | mean: [0.485, 0.456, 0.406]
31 | std: [0.229, 0.224, 0.225]
32 | mask:
33 | _target_: torchvision.transforms.Compose
34 | transforms:
35 | - _target_: ocl.preprocessing.MultiMaskToTensor
36 | - _target_: ocl.preprocessing.ResizeNearestExact
37 | size: 128
38 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/_preprocessing_voc2012_segm_dino_feature_recon.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/multi_element_preprocessing@plugins.02a_format_consistency
4 | - /plugins/data_preprocessing@plugins.02b_format_consistency
5 | - /plugins/data_preprocessing@plugins.03a_preprocessing
6 | - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
7 |
8 | plugins:
9 | # Make VOC2012 consistent with COCO.
10 | 02a_format_consistency:
11 | evaluation_transforms:
12 | # Convert to one-hot encoding.
13 | segmentation-instance:
14 | _target_: ocl.preprocessing.IntegerToOneHotMask
15 |
16 | 02b_format_consistency:
17 | evaluation_fields:
18 | - "segmentation-instance"
19 | - "segmentation-class"
20 | - "image"
21 | evaluation_transform:
22 | _target_: torchvision.transforms.Compose
23 | transforms:
24 | # Create segmentation mask.
25 | - _target_: ocl.preprocessing.VOCInstanceMasksToDenseMasks
26 | instance_mask_key: segmentation-instance
27 | class_mask_key: segmentation-class
28 | classes_key: instance_category
29 | - _target_: ocl.preprocessing.RenameFields
30 | mapping:
31 | segmentation-instance: instance_mask
32 | 03a_preprocessing:
33 | evaluation_fields:
34 | - image
35 | - instance_mask
36 | - instance_category
37 | evaluation_transform:
38 | _target_: torchvision.transforms.Compose
39 | transforms:
40 | # This is not needed for VOC.
41 | # - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
42 | - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
43 | # Drop instance_category again as some images do not contain it
44 | - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
45 | - _target_: ocl.preprocessing.AddEmptyMasks
46 | mask_keys:
47 | - instance_mask
48 | - segmentation_mask
49 |
50 | 03b_preprocessing:
51 | training_transforms:
52 | image:
53 | _target_: torchvision.transforms.Compose
54 | transforms:
55 | - _target_: torchvision.transforms.ToTensor
56 | - _target_: torchvision.transforms.Resize
57 | size: 224
58 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
59 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
60 | - _target_: torchvision.transforms.RandomCrop
61 | size: 224
62 | - _target_: torchvision.transforms.RandomHorizontalFlip
63 | - _target_: torchvision.transforms.Normalize
64 | mean: [0.485, 0.456, 0.406]
65 | std: [0.229, 0.224, 0.225]
66 | evaluation_transforms:
67 | image:
68 | _target_: torchvision.transforms.Compose
69 | transforms:
70 | - _target_: torchvision.transforms.ToTensor
71 | - _target_: torchvision.transforms.Resize
72 | size: 224
73 | interpolation: "${torchvision_interpolation_mode:BICUBIC}"
74 | - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
75 | - _target_: torchvision.transforms.CenterCrop
76 | size: 224
77 | - _target_: torchvision.transforms.Normalize
78 | mean: [0.485, 0.456, 0.406]
79 | std: [0.229, 0.224, 0.225]
80 | instance_mask:
81 | _target_: torchvision.transforms.Compose
82 | transforms:
83 | - _target_: ocl.preprocessing.DenseMaskToTensor
84 | - _target_: ocl.preprocessing.ResizeNearestExact
85 | size: 224
86 | - _target_: torchvision.transforms.CenterCrop
87 | size: 224
88 | segmentation_mask:
89 | _target_: torchvision.transforms.Compose
90 | transforms:
91 | - _target_: ocl.preprocessing.DenseMaskToTensor
92 | - _target_: ocl.preprocessing.ResizeNearestExact
93 | size: 224
94 | - _target_: torchvision.transforms.CenterCrop
95 | size: 224
96 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /conditioning/random@models.conditioning
4 | - /experiment/projects/bridging/dinosaur/_base_feature_recon
5 | - /neural_networks/mlp@models.object_decoder.decoder
6 | - /dataset: coco
7 | - /experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop
8 | - /experiment/projects/bridging/dinosaur/_metrics_coco
9 | - _self_
10 |
11 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
12 | trainer:
13 | gpus: 8
14 | max_steps: 200000
15 | max_epochs: null
16 | strategy: ddp
17 |
18 | dataset:
19 | num_workers: 4
20 | batch_size: 8
21 |
22 | models:
23 | conditioning:
24 | n_slots: 7
25 | object_dim: 256
26 |
27 | feature_extractor:
28 | model_name: vit_base_patch16_224_dino
29 | pretrained: true
30 | freeze: true
31 |
32 | perceptual_grouping:
33 | input_dim: 768
34 |
35 | object_decoder:
36 | _target_: ocl.decoding.PatchDecoder
37 | decoder:
38 | features: [2048, 2048, 2048]
39 |
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /conditioning/random@models.conditioning
4 | - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
5 | - /neural_networks/mlp@models.object_decoder.decoder
6 | - /dataset: coco
7 | - /experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop
8 | - /experiment/projects/bridging/dinosaur/_metrics_coco
9 | - /metrics/tensor_statistic@training_metrics.hard_keep_decision
10 | - /metrics/tensor_statistic@training_metrics.slots_keep_prob
11 | - _self_
12 |
13 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
14 | trainer:
15 | gpus: 8
16 | max_steps: 500000
17 | max_epochs: null
18 | strategy: ddp
19 |
20 | dataset:
21 | num_workers: 4
22 | batch_size: 8
23 |
24 | models:
25 | _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
26 | conditioning:
27 | n_slots: 33
28 | object_dim: 256
29 |
30 | feature_extractor:
31 | model_name: vit_base_patch16_224_dino
32 | pretrained: true
33 | freeze: true
34 |
35 | perceptual_grouping:
36 | input_dim: 768
37 | low_bound: 0
38 |
39 | object_decoder:
40 | _target_: ocl.decoding.PatchDecoderGumbelV1
41 | decoder:
42 | features: [2048, 2048, 2048]
43 | left_mask_path: None
44 | mask_type: mask_normalized
45 |
46 | losses:
47 | sparse_penalty:
48 | _target_: ocl.losses.SparsePenalty
49 | linear_weight: 0.1
50 | quadratic_weight: 0.0
51 | quadratic_bias: 0.5
52 | input_path: hard_keep_decision
53 |
54 | # Log statistics of the model outputs "hard_keep_decision" and "slots_keep_prob"
55 | # (produced by ocl.models.image_grouping_adaslot.GroupingImgGumbel) during training.
56 | training_metrics:
57 | hard_keep_decision:
58 | path: hard_keep_decision
59 | reduction: sum
60 |
61 | slots_keep_prob:
62 | path: slots_keep_prob
63 | reduction: mean
64 |
65 | load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml/2023-05-02_16-57-16/lightning_logs/version_0/checkpoints/epoch=95-step=177408.ckpt
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /conditioning/random@models.conditioning
4 | - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
5 | - /neural_networks/mlp@models.object_decoder.decoder
6 | - /dataset: coco
7 | - /experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop
8 | - /experiment/projects/bridging/dinosaur/_metrics_coco
9 | - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
10 | - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
11 | - /metrics/ami_metric@evaluation_metrics.ami
12 | - /metrics/nmi_metric@evaluation_metrics.nmi
13 | - /metrics/purity_metric@evaluation_metrics.purity
14 | - /metrics/precision_metric@evaluation_metrics.precision
15 | - /metrics/recall_metric@evaluation_metrics.recall
16 | - /metrics/f1_metric@evaluation_metrics.f1
17 | - _self_
18 |
19 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
20 | trainer:
21 | gpus: 8
22 | max_steps: 500000
23 | max_epochs: null
24 | strategy: ddp
25 |
26 | dataset:
27 | num_workers: 4
28 | batch_size: 8
29 |
30 | models:
31 | _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
32 | conditioning:
33 | n_slots: 33
34 | object_dim: 256
35 |
36 | feature_extractor:
37 | model_name: vit_base_patch16_224_dino
38 | pretrained: true
39 | freeze: true
40 |
41 | perceptual_grouping:
42 | input_dim: 768
43 | low_bound: 0
44 |
45 | object_decoder:
46 | _target_: ocl.decoding.PatchDecoderGumbelV1
47 | decoder:
48 | features: [2048, 2048, 2048]
49 | left_mask_path: None
50 | mask_type: mask_normalized
51 |
52 | losses:
53 | sparse_penalty:
54 | _target_: ocl.losses.SparsePenalty
55 | linear_weight: 0.1
56 | quadratic_weight: 0.0
57 | quadratic_bias: 0.5
58 | input_path: hard_keep_decision
59 |
60 | evaluation_metrics:
61 | hard_keep_decision:
62 | path: hard_keep_decision
63 | reduction: sum
64 |
65 | slots_keep_prob:
66 | path: slots_keep_prob
67 | reduction: mean
68 |
69 | ami:
70 | prediction_path: object_decoder.masks_as_image
71 | target_path: input.instance_mask
72 | foreground: true
73 | convert_target_one_hot: true
74 | ignore_overlaps: true
75 | back_as_class: false
76 |
77 | nmi:
78 | prediction_path: object_decoder.masks_as_image
79 | target_path: input.instance_mask
80 | foreground: true
81 | convert_target_one_hot: true
82 | ignore_overlaps: true
83 | back_as_class: false
84 |
85 | purity:
86 | prediction_path: object_decoder.masks_as_image
87 | target_path: input.instance_mask
88 | foreground: true
89 | convert_target_one_hot: true
90 | ignore_overlaps: true
91 | back_as_class: false
92 |
93 | precision:
94 | prediction_path: object_decoder.masks_as_image
95 | target_path: input.instance_mask
96 | foreground: true
97 | convert_target_one_hot: true
98 | ignore_overlaps: true
99 | back_as_class: false
100 |
101 | recall:
102 | prediction_path: object_decoder.masks_as_image
103 | target_path: input.instance_mask
104 | foreground: true
105 | convert_target_one_hot: true
106 | ignore_overlaps: true
107 | back_as_class: false
108 |
109 | f1:
110 | prediction_path: object_decoder.masks_as_image
111 | target_path: input.instance_mask
112 | foreground: true
113 | convert_target_one_hot: true
114 | ignore_overlaps: true
115 | back_as_class: false
116 |
117 |
--------------------------------------------------------------------------------
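
The evaluation config above adds clustering metrics (AMI, NMI, purity, precision, recall, F1) that compare the decoder masks, resized to image resolution, against COCO instance masks. As a rough illustration of what a foreground AMI score measures on per-pixel cluster labels, here is a sketch built on scikit-learn; the repository's `ocl.metrics` implementation additionally handles one-hot conversion, overlapping instances and `back_as_class`, which are omitted here.

import numpy as np
from sklearn.metrics import adjusted_mutual_info_score

def foreground_ami_sketch(pred_masks: np.ndarray, true_masks: np.ndarray) -> float:
    # pred_masks: (num_slots, H, W) soft masks; true_masks: (num_objects, H, W) binary masks,
    # with the background mask assumed at index 0 (illustrative convention only).
    pred_labels = pred_masks.argmax(axis=0).reshape(-1)
    true_labels = true_masks.argmax(axis=0).reshape(-1)
    foreground = true_labels != 0  # score foreground pixels only
    return adjusted_mutual_info_score(true_labels[foreground], pred_labels[foreground])

pred = np.random.rand(33, 32, 32)
true = np.eye(5)[np.random.randint(0, 5, (32, 32))].transpose(2, 0, 1)
print(foreground_ami_sketch(pred, true))
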
/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # ViT feature reconstruction on MOVI-C.
3 | defaults:
4 | - /conditioning/random@models.conditioning
5 | - /experiment/projects/bridging/dinosaur/_base_feature_recon
6 | - /neural_networks/mlp@models.object_decoder.decoder
7 | - /dataset: movi_c_image
8 | - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
9 | - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
10 | - _self_
11 |
12 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
13 | trainer:
14 | gpus: 8
15 | max_steps: 500000
16 | max_epochs: null
17 | strategy: ddp
18 |
19 | dataset:
20 | num_workers: 4
21 | batch_size: 8
22 |
23 | models:
24 | conditioning:
25 | n_slots: 11
26 | object_dim: 128
27 |
28 | feature_extractor:
29 | model_name: vit_base_patch16_224_dino
30 | pretrained: true
31 |
32 | perceptual_grouping:
33 | input_dim: 768
34 |
35 | object_decoder:
36 | _target_: ocl.decoding.PatchDecoder
37 | num_patches: 196
38 | decoder:
39 | features: [1024, 1024, 1024]
40 |
41 | masks_as_image:
42 | _target_: ocl.utils.resizing.Resize
43 | input_path: object_decoder.masks
44 | size: 128
45 | resize_mode: bilinear
46 | patch_mode: true
47 |
--------------------------------------------------------------------------------
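
The `masks_as_image` entry above turns the 196 patch-level decoder masks (a 14x14 grid for ViT-B/16 at 224x224 input) into 128x128 masks via bilinear resizing (`patch_mode: true`). A minimal sketch of that reshape-and-upsample step:

import math
import torch
import torch.nn.functional as F

def patch_masks_to_image(masks: torch.Tensor, size: int = 128) -> torch.Tensor:
    # masks: (batch, num_slots, num_patches) patch-level masks from the decoder.
    batch, num_slots, num_patches = masks.shape
    grid = int(math.sqrt(num_patches))  # 196 patches -> 14 x 14 grid
    masks = masks.view(batch * num_slots, 1, grid, grid)
    masks = F.interpolate(masks, size=(size, size), mode="bilinear", align_corners=False)
    return masks.view(batch, num_slots, size, size)

print(patch_masks_to_image(torch.rand(2, 11, 196)).shape)  # torch.Size([2, 11, 128, 128])
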
/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # ViT feature reconstruction on MOVI-C.
3 | defaults:
4 | - /conditioning/random@models.conditioning
5 | - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
6 | - /neural_networks/mlp@models.object_decoder.decoder
7 | - /dataset: movi_c_image
8 | - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
9 | - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
10 | - /metrics/tensor_statistic@training_metrics.hard_keep_decision
11 | - /metrics/tensor_statistic@training_metrics.slots_keep_prob
12 | - _self_
13 |
14 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
15 | trainer:
16 | gpus: 8
17 | max_steps: 500000
18 | max_epochs: null
19 | strategy: ddp
20 |
21 | dataset:
22 | num_workers: 4
23 | batch_size: 8
24 |
25 | models:
26 | _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
27 | conditioning:
28 | n_slots: 11
29 | object_dim: 128
30 |
31 | feature_extractor:
32 | model_name: vit_base_patch16_224_dino
33 | pretrained: true
34 |
35 | perceptual_grouping:
36 | input_dim: 768
37 | low_bound: 0
38 |
39 | object_decoder:
40 | _target_: ocl.decoding.PatchDecoderGumbelV1
41 | num_patches: 196
42 | decoder:
43 | features: [1024, 1024, 1024]
44 | left_mask_path: None
45 | mask_type: mask_normalized
46 |
47 | masks_as_image:
48 | _target_: ocl.utils.resizing.Resize
49 | input_path: object_decoder.masks
50 | size: 128
51 | resize_mode: bilinear
52 | patch_mode: true
53 |
54 | losses:
55 | sparse_penalty:
56 | _target_: ocl.losses.SparsePenalty
57 | linear_weight: 0.1
58 | quadratic_weight: 0.0
59 | quadratic_bias: 0.5
60 | input_path: hard_keep_decision
61 |
62 | # The training metrics below track the "hard_keep_decision" and "slots_keep_prob"
63 | # outputs produced by the perceptual grouping module.
64 | training_metrics:
65 | hard_keep_decision:
66 | path: hard_keep_decision
67 | reduction: sum
68 |
69 | slots_keep_prob:
70 | path: slots_keep_prob
71 | reduction: mean
72 |
73 | load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml/2023-04-28_15-20-10/lightning_logs/version_0/checkpoints/epoch=328-step=450401.ckpt
--------------------------------------------------------------------------------
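
Relative to the baseline config, the AdaSlot variant swaps in `GroupingImgGumbel` and `PatchDecoderGumbelV1` and sets `low_bound: 0`, so each slot receives a keep/drop decision whose statistics are logged as `hard_keep_decision` and `slots_keep_prob`. The sketch below shows the general straight-through Gumbel-softmax mechanism behind such a decision; the scoring network and masking details in `ocl.perceptual_grouping` may differ, and `score_net` here is a placeholder for the configured two-layer MLP.

import torch
import torch.nn.functional as F

def keep_decision_sketch(slots: torch.Tensor, score_net: torch.nn.Module, tau: float = 1.0):
    # slots: (batch, num_slots, dim). score_net maps each slot to two logits (drop, keep).
    logits = score_net(slots)                                    # (batch, num_slots, 2)
    slots_keep_prob = logits.softmax(dim=-1)[..., 1]             # soft keep probability
    hard = F.gumbel_softmax(logits, tau=tau, hard=True)[..., 1]  # straight-through 0/1 sample
    kept_slots = slots * hard.unsqueeze(-1)                      # drop slots by zero-masking
    return kept_slots, hard, slots_keep_prob

score_net = torch.nn.Linear(128, 2)  # placeholder scoring head
slots = torch.randn(4, 11, 128)
_, hard_keep_decision, slots_keep_prob = keep_decision_sketch(slots, score_net)
print(hard_keep_decision.shape, slots_keep_prob.shape)
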
/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # ViT feature reconstruction on MOVI-C.
3 | defaults:
4 | - /conditioning/random@models.conditioning
5 | - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
6 | - /neural_networks/mlp@models.object_decoder.decoder
7 | - /dataset: movi_c_image
8 | - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
9 | - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
10 | - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
11 | - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
12 | - /metrics/ami_metric@evaluation_metrics.ami
13 | - /metrics/nmi_metric@evaluation_metrics.nmi
14 | - /metrics/purity_metric@evaluation_metrics.purity
15 | - /metrics/precision_metric@evaluation_metrics.precision
16 | - /metrics/recall_metric@evaluation_metrics.recall
17 | - /metrics/f1_metric@evaluation_metrics.f1
18 | - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
19 | - _self_
20 |
21 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
22 | trainer:
23 | gpus: 8
24 | max_steps: 500000
25 | max_epochs: null
26 | strategy: ddp
27 |
28 | dataset:
29 | num_workers: 4
30 | batch_size: 8
31 |
32 | models:
33 | _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
34 | conditioning:
35 | n_slots: 11
36 | object_dim: 128
37 |
38 | feature_extractor:
39 | model_name: vit_base_patch16_224_dino
40 | pretrained: true
41 |
42 | perceptual_grouping:
43 | input_dim: 768
44 | low_bound: 0
45 |
46 | object_decoder:
47 | _target_: ocl.decoding.PatchDecoderGumbelV1
48 | num_patches: 196
49 | decoder:
50 | features: [1024, 1024, 1024]
51 | left_mask_path: None
52 | mask_type: mask_normalized
53 |
54 | masks_as_image:
55 | _target_: ocl.utils.resizing.Resize
56 | input_path: object_decoder.masks
57 | size: 128
58 | resize_mode: bilinear
59 | patch_mode: true
60 |
61 | losses:
62 | sparse_penalty:
63 | _target_: ocl.losses.SparsePenalty
64 | linear_weight: 0.1
65 | quadratic_weight: 0.0
66 | quadratic_bias: 0.5
67 | input_path: hard_keep_decision
68 |
69 | evaluation_metrics:
70 | hard_keep_decision:
71 | path: hard_keep_decision
72 | reduction: sum
73 |
74 | slots_keep_prob:
75 | path: slots_keep_prob
76 | reduction: mean
77 |
78 | ami:
79 | prediction_path: masks_as_image
80 | target_path: input.mask
81 | foreground: true
82 | convert_target_one_hot: false
83 | ignore_overlaps: false
84 |
85 | nmi:
86 | prediction_path: masks_as_image
87 | target_path: input.mask
88 | foreground: true
89 | convert_target_one_hot: false
90 | ignore_overlaps: false
91 |
92 | purity:
93 | prediction_path: masks_as_image
94 | target_path: input.mask
95 | foreground: true
96 | convert_target_one_hot: false
97 | ignore_overlaps: false
98 |
99 | precision:
100 | prediction_path: masks_as_image
101 | target_path: input.mask
102 | foreground: true
103 | convert_target_one_hot: false
104 | ignore_overlaps: false
105 |
106 | recall:
107 | prediction_path: masks_as_image
108 | target_path: input.mask
109 | foreground: true
110 | convert_target_one_hot: false
111 | ignore_overlaps: false
112 |
113 | f1:
114 | prediction_path: masks_as_image
115 | target_path: input.mask
116 | foreground: true
117 | convert_target_one_hot: false
118 | ignore_overlaps: false
119 |
120 | instance_mask_corloc:
121 | prediction_path: masks_as_image
122 | target_path: input.mask
123 |     use_threshold: false
124 |     ignore_background: true
125 |     ignore_overlaps: false
126 |
127 | plugins:
128 | 02_sample_frames:
129 | n_frames_per_video: 24
130 |
--------------------------------------------------------------------------------
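
Compared with the training config, this eval config turns the slot statistics into `evaluation_metrics` and samples 24 frames per MOVi video through the `02_sample_frames` plugin so that the image model can be scored frame by frame. A rough sketch of such frame subsampling (evenly spaced frames are an assumption; the actual plugin may sample differently):

import torch

def sample_frames_sketch(video: torch.Tensor, n_frames_per_video: int = 24) -> torch.Tensor:
    # video: (num_frames, C, H, W). Pick n evenly spaced frames.
    num_frames = video.shape[0]
    idx = torch.linspace(0, num_frames - 1, steps=min(n_frames_per_video, num_frames)).long()
    return video[idx]

print(sample_frames_sketch(torch.randn(24, 3, 128, 128)).shape)  # torch.Size([24, 3, 128, 128])
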
/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # ViT feature reconstruction on MOVI-E.
3 | defaults:
4 | - /conditioning/random@models.conditioning
5 | - /experiment/projects/bridging/dinosaur/_base_feature_recon
6 | - /neural_networks/mlp@models.object_decoder.decoder
7 | - /dataset: movi_e_image
8 | - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
9 | - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
10 | - _self_
11 |
12 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
13 | trainer:
14 | gpus: 8
15 | max_steps: 100000
16 | max_epochs: null
17 | strategy: ddp
18 |
19 | dataset:
20 | num_workers: 4
21 | batch_size: 8
22 |
23 | models:
24 | conditioning:
25 | n_slots: 24
26 | object_dim: 128
27 |
28 | feature_extractor:
29 | model_name: vit_base_patch16_224_dino
30 | pretrained: true
31 |
32 | perceptual_grouping:
33 | input_dim: 768
34 |
35 | object_decoder:
36 | _target_: ocl.decoding.PatchDecoder
37 | num_patches: 196
38 | decoder:
39 | features: [1024, 1024, 1024]
40 |
41 | masks_as_image:
42 | _target_: ocl.utils.resizing.Resize
43 | input_path: object_decoder.masks
44 | size: 128
45 | resize_mode: bilinear
46 | patch_mode: true
47 |
--------------------------------------------------------------------------------
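
All of these DINOSAUR-style configs rely on the frozen DINO ViT-B/16 backbone referenced by the timm name `vit_base_patch16_224_dino`, with `input_dim: 768` and 196 patch tokens. The sketch below shows roughly how such a frozen patch-token extractor can be instantiated; the actual `ocl.feature_extractors` wrapper selects features and registers hooks, and newer timm releases may expose the model under a slightly different name (assumption).

import timm
import torch

# DINO-pretrained ViT-B/16, name taken from the configs above.
model = timm.create_model("vit_base_patch16_224_dino", pretrained=True)
model.eval()
for p in model.parameters():  # freeze: true in the configs
    p.requires_grad_(False)

with torch.no_grad():
    tokens = model.forward_features(torch.randn(1, 3, 224, 224))
# Depending on the timm version, forward_features returns (1, 197, 768) tokens (CLS + 196
# patch tokens) or a dict; the 196 patch tokens of dimension 768 are what the perceptual
# grouping consumes (input_dim: 768, num_patches: 196 in the configs).
print(tokens.shape if isinstance(tokens, torch.Tensor) else type(tokens))
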
/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # ViT feature reconstruction on MOVI-E.
3 | defaults:
4 | - /conditioning/random@models.conditioning
5 | - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
6 | - /neural_networks/mlp@models.object_decoder.decoder
7 | - /dataset: movi_e_image
8 | - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
9 | - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
10 | - /metrics/tensor_statistic@training_metrics.hard_keep_decision
11 | - /metrics/tensor_statistic@training_metrics.slots_keep_prob
12 | - _self_
13 |
14 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
15 | trainer:
16 | gpus: 8
17 | max_steps: 500000
18 | max_epochs: null
19 | strategy: ddp
20 |
21 | dataset:
22 | num_workers: 4
23 | batch_size: 8
24 |
25 | models:
26 | _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
27 | conditioning:
28 | n_slots: 24
29 | object_dim: 128
30 |
31 | feature_extractor:
32 | model_name: vit_base_patch16_224_dino
33 | pretrained: true
34 |
35 | perceptual_grouping:
36 | input_dim: 768
37 | low_bound: 0
38 |
39 | object_decoder:
40 | _target_: ocl.decoding.PatchDecoderGumbelV1
41 | num_patches: 196
42 | decoder:
43 | features: [1024, 1024, 1024]
44 | left_mask_path: None
45 | mask_type: mask_normalized
46 |
47 | masks_as_image:
48 | _target_: ocl.utils.resizing.Resize
49 | input_path: object_decoder.masks
50 | size: 128
51 | resize_mode: bilinear
52 | patch_mode: true
53 |
54 | losses:
55 | sparse_penalty:
56 | _target_: ocl.losses.SparsePenalty
57 | linear_weight: 0.1
58 | quadratic_weight: 0.0
59 | quadratic_bias: 0.5
60 | input_path: hard_keep_decision
61 |
62 | # The training metrics below track the "hard_keep_decision" and "slots_keep_prob"
63 | # outputs produced by the perceptual grouping module.
64 | training_metrics:
65 | hard_keep_decision:
66 | path: hard_keep_decision
67 | reduction: sum
68 |
69 | slots_keep_prob:
70 | path: slots_keep_prob
71 | reduction: mean
72 |
73 | # load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml/2023-04-25_17-13-38/lightning_logs/version_0/checkpoints/epoch=86-step=119190.ckpt
--------------------------------------------------------------------------------
/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # ViT feature reconstruction on MOVI-E.
3 | defaults:
4 | - /conditioning/random@models.conditioning
5 | - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
6 | - /neural_networks/mlp@models.object_decoder.decoder
7 | - /dataset: movi_e_image
8 | - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
9 | - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
10 | - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
11 | - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
12 | - /metrics/ami_metric@evaluation_metrics.ami
13 | - /metrics/nmi_metric@evaluation_metrics.nmi
14 | - /metrics/purity_metric@evaluation_metrics.purity
15 | - /metrics/precision_metric@evaluation_metrics.precision
16 | - /metrics/recall_metric@evaluation_metrics.recall
17 | - /metrics/f1_metric@evaluation_metrics.f1
18 | - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
19 | - _self_
20 |
21 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
22 | trainer:
23 | gpus: 8
24 | max_steps: 500000
25 | max_epochs: null
26 | strategy: ddp
27 |
28 | dataset:
29 | num_workers: 4
30 | batch_size: 8
31 |
32 | models:
33 | _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
34 | conditioning:
35 | n_slots: 24
36 | object_dim: 128
37 |
38 | feature_extractor:
39 | model_name: vit_base_patch16_224_dino
40 | pretrained: true
41 |
42 | perceptual_grouping:
43 | input_dim: 768
44 | low_bound: 0
45 |
46 | object_decoder:
47 | _target_: ocl.decoding.PatchDecoderGumbelV1
48 | num_patches: 196
49 | decoder:
50 | features: [1024, 1024, 1024]
51 | left_mask_path: None
52 | mask_type: mask_normalized
53 |
54 | masks_as_image:
55 | _target_: ocl.utils.resizing.Resize
56 | input_path: object_decoder.masks
57 | size: 128
58 | resize_mode: bilinear
59 | patch_mode: true
60 |
61 | losses:
62 | sparse_penalty:
63 | _target_: ocl.losses.SparsePenalty
64 | linear_weight: 0.1
65 | quadratic_weight: 0.0
66 | quadratic_bias: 0.5
67 | input_path: hard_keep_decision
68 |
69 | evaluation_metrics:
70 | hard_keep_decision:
71 | path: hard_keep_decision
72 | reduction: sum
73 |
74 | slots_keep_prob:
75 | path: slots_keep_prob
76 | reduction: mean
77 |
78 | ami:
79 | prediction_path: masks_as_image
80 | target_path: input.mask
81 | foreground: true
82 | convert_target_one_hot: false
83 | ignore_overlaps: false
84 |
85 | nmi:
86 | prediction_path: masks_as_image
87 | target_path: input.mask
88 | foreground: true
89 | convert_target_one_hot: false
90 | ignore_overlaps: false
91 |
92 | purity:
93 | prediction_path: masks_as_image
94 | target_path: input.mask
95 | foreground: true
96 | convert_target_one_hot: false
97 | ignore_overlaps: false
98 |
99 | precision:
100 | prediction_path: masks_as_image
101 | target_path: input.mask
102 | foreground: true
103 | convert_target_one_hot: false
104 | ignore_overlaps: false
105 |
106 | recall:
107 | prediction_path: masks_as_image
108 | target_path: input.mask
109 | foreground: true
110 | convert_target_one_hot: false
111 | ignore_overlaps: false
112 |
113 | f1:
114 | prediction_path: masks_as_image
115 | target_path: input.mask
116 | foreground: true
117 | convert_target_one_hot: false
118 | ignore_overlaps: false
119 |
120 | instance_mask_corloc:
121 | prediction_path: masks_as_image
122 | target_path: input.mask
123 |     use_threshold: false
124 |     ignore_background: true
125 |     ignore_overlaps: false
126 |
127 | plugins:
128 | 02_sample_frames:
129 | n_frames_per_video: 24
130 |
--------------------------------------------------------------------------------
/configs/experiment/slot_attention/_base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Default parameters for slot attention.
3 | defaults:
4 | - /experiment/_output_path
5 | - /training_config
6 | - /feature_extractor/slot_attention@models.feature_extractor
7 | - /conditioning/random@models.conditioning
8 | - /perceptual_grouping/slot_attention@models.perceptual_grouping
9 | - /plugins/optimization@plugins.optimize_parameters
10 | - /optimizers/adam@plugins.optimize_parameters.optimizer
11 | - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
12 | - _self_
13 |
14 | models:
15 | conditioning:
16 | object_dim: 64
17 |
18 | perceptual_grouping:
19 | feature_dim: 64
20 | object_dim: ${..conditioning.object_dim}
21 | kvq_dim: 128
22 | positional_embedding:
23 | _target_: ocl.neural_networks.wrappers.Sequential
24 | _args_:
25 | - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
26 | n_spatial_dims: 2
27 | feature_dim: 64
28 | - _target_: ocl.neural_networks.build_two_layer_mlp
29 | input_dim: 64
30 | output_dim: 64
31 | hidden_dim: 128
32 | initial_layer_norm: true
33 | residual: false
34 | ff_mlp:
35 | _target_: ocl.neural_networks.build_two_layer_mlp
36 | input_dim: 64
37 | output_dim: 64
38 | hidden_dim: 128
39 | initial_layer_norm: true
40 | residual: true
41 |
42 | object_decoder:
43 | _target_: ocl.decoding.SlotAttentionDecoder
44 | object_features_path: perceptual_grouping.objects
45 | decoder:
46 | _target_: ocl.decoding.get_slotattention_decoder_backbone
47 | object_dim: ${models.perceptual_grouping.object_dim}
48 | positional_embedding:
49 | _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
50 | n_spatial_dims: 2
51 | feature_dim: ${models.perceptual_grouping.object_dim}
52 | cnn_channel_order: true
53 |
54 | plugins:
55 | optimize_parameters:
56 | optimizer:
57 | lr: 0.0004
58 | lr_scheduler:
59 | decay_rate: 0.5
60 | decay_steps: 100000
61 | warmup_steps: 10000
62 |
63 | losses:
64 | mse:
65 | _target_: ocl.losses.ReconstructionLoss
66 | loss_type: mse_sum
67 | input_path: object_decoder.reconstruction
68 | target_path: input.image
69 |
70 | visualizations:
71 | input:
72 | _target_: ocl.visualizations.Image
73 | denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
74 | image_path: input.image
75 | reconstruction:
76 | _target_: ocl.visualizations.Image
77 | denormalization: ${..input.denormalization}
78 | image_path: object_decoder.reconstruction
79 | objects:
80 | _target_: ocl.visualizations.VisualObject
81 | denormalization: ${..input.denormalization}
82 | object_path: object_decoder.object_reconstructions
83 | mask_path: object_decoder.masks
84 | pred_segmentation:
85 | _target_: ocl.visualizations.Segmentation
86 | denormalization: ${..input.denormalization}
87 | image_path: input.image
88 | mask_path: object_decoder.masks
89 |
--------------------------------------------------------------------------------
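
The base slot-attention config builds its positional-embedding head and `ff_mlp` through `ocl.neural_networks.build_two_layer_mlp` with `initial_layer_norm` and `residual` flags. The following is a plausible sketch of such a builder matching the parameters used above; it is an assumption about the interface, not the repository's implementation.

import torch
from torch import nn


class Residual(nn.Module):
    """Wraps a module with a skip connection (requires input_dim == output_dim)."""

    def __init__(self, body: nn.Module):
        super().__init__()
        self.body = body

    def forward(self, x):
        return x + self.body(x)


def build_two_layer_mlp_sketch(input_dim, output_dim, hidden_dim,
                               initial_layer_norm=False, residual=False) -> nn.Module:
    layers = [nn.LayerNorm(input_dim)] if initial_layer_norm else []
    layers += [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True),
               nn.Linear(hidden_dim, output_dim)]
    mlp = nn.Sequential(*layers)
    return Residual(mlp) if residual else mlp


# Matches the ff_mlp entry above: 64 -> 128 -> 64 with layer norm and a residual connection.
ff_mlp = build_two_layer_mlp_sketch(64, 64, 128, initial_layer_norm=True, residual=True)
print(ff_mlp(torch.randn(2, 11, 64)).shape)  # torch.Size([2, 11, 64])
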
/configs/experiment/slot_attention/_base_gumbel.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Default parameters for slot attention.
3 | defaults:
4 | - /experiment/_output_path
5 | - /training_config
6 | - /feature_extractor/slot_attention@models.feature_extractor
7 | - /conditioning/random@models.conditioning
8 | - /perceptual_grouping/slot_attention_gumbel_v1@models.perceptual_grouping
9 | - /plugins/optimization@plugins.optimize_parameters
10 | - /optimizers/adam@plugins.optimize_parameters.optimizer
11 | - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
12 | - _self_
13 |
14 | models:
15 | conditioning:
16 | object_dim: 64
17 |
18 | perceptual_grouping:
19 | feature_dim: 64
20 | object_dim: ${..conditioning.object_dim}
21 | kvq_dim: 128
22 | positional_embedding:
23 | _target_: ocl.neural_networks.wrappers.Sequential
24 | _args_:
25 | - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
26 | n_spatial_dims: 2
27 | feature_dim: 64
28 | - _target_: ocl.neural_networks.build_two_layer_mlp
29 | input_dim: 64
30 | output_dim: 64
31 | hidden_dim: 128
32 | initial_layer_norm: true
33 | residual: false
34 | ff_mlp:
35 | _target_: ocl.neural_networks.build_two_layer_mlp
36 | input_dim: 64
37 | output_dim: 64
38 | hidden_dim: 128
39 | initial_layer_norm: true
40 | residual: true
41 |
42 | single_gumbel_score_network:
43 | _target_: ocl.neural_networks.build_two_layer_mlp
44 | input_dim: ${..object_dim}
45 | output_dim: 2
46 | hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
47 | initial_layer_norm: true
48 | residual: false
49 |
50 | object_decoder:
51 | _target_: ocl.decoding.SlotAttentionDecoder
52 | object_features_path: perceptual_grouping.objects
53 | decoder:
54 | _target_: ocl.decoding.get_slotattention_decoder_backbone
55 | object_dim: ${models.perceptual_grouping.object_dim}
56 | positional_embedding:
57 | _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
58 | n_spatial_dims: 2
59 | feature_dim: ${models.perceptual_grouping.object_dim}
60 | cnn_channel_order: true
61 |
62 | plugins:
63 | optimize_parameters:
64 | optimizer:
65 | lr: 0.0004
66 | lr_scheduler:
67 | decay_rate: 0.5
68 | decay_steps: 100000
69 | warmup_steps: 10000
70 |
71 | losses:
72 | mse:
73 | _target_: ocl.losses.ReconstructionLoss
74 | loss_type: mse_sum
75 | input_path: object_decoder.reconstruction
76 | target_path: input.image
77 |
78 | visualizations:
79 | input:
80 | _target_: ocl.visualizations.Image
81 | denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
82 | image_path: input.image
83 | reconstruction:
84 | _target_: ocl.visualizations.Image
85 | denormalization: ${..input.denormalization}
86 | image_path: object_decoder.reconstruction
87 | objects:
88 | _target_: ocl.visualizations.VisualObject
89 | denormalization: ${..input.denormalization}
90 | object_path: object_decoder.object_reconstructions
91 | mask_path: object_decoder.masks
92 | pred_segmentation:
93 | _target_: ocl.visualizations.Segmentation
94 | denormalization: ${..input.denormalization}
95 | image_path: input.image
96 | mask_path: object_decoder.masks
97 |
--------------------------------------------------------------------------------
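
The Gumbel base config uses `${eval_lambda:...}` and `${lambda_fn:...}` interpolations, which depend on custom OmegaConf resolvers registered by the framework before configs are composed. A minimal sketch of how such resolvers could be registered (the resolver names are taken from the configs; the registration code itself is an assumption):

from omegaconf import OmegaConf

# "lambda_fn" turns a string into a callable (e.g. the denormalization used by visualizations);
# "eval_lambda" additionally applies that callable to the interpolated argument.
OmegaConf.register_new_resolver("lambda_fn", lambda expr: eval(expr))
OmegaConf.register_new_resolver("eval_lambda", lambda expr, *args: eval(expr)(*args))

cfg = OmegaConf.create(
    {
        "object_dim": 64,
        "hidden_dim": "${eval_lambda:'lambda dim: 4 * dim', ${object_dim}}",
    }
)
print(cfg.hidden_dim)  # 256, i.e. 4 * object_dim as in single_gumbel_score_network above
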
/configs/experiment/slot_attention/_base_large.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Default parameters for slot attention on resolution 128x128 with a ResNet encoder
3 | defaults:
4 | - /experiment/_output_path
5 | - /training_config
6 | - /feature_extractor/timm_model@models.feature_extractor
7 | - /perceptual_grouping/slot_attention@models.perceptual_grouping
8 | - /plugins/optimization@plugins.optimize_parameters
9 | - /optimizers/adam@plugins.optimize_parameters.optimizer
10 | - /lr_schedulers/cosine_annealing@plugins.optimize_parameters.lr_scheduler
11 | - _self_
12 |
13 | models:
14 | feature_extractor:
15 | model_name: resnet34_savi
16 | feature_level: 4
17 | pretrained: false
18 | freeze: false
19 |
20 | perceptual_grouping:
21 | feature_dim: ${models.perceptual_grouping.object_dim}
22 | object_dim: ${models.conditioning.object_dim}
23 | kvq_dim: ${models.perceptual_grouping.object_dim}
24 | positional_embedding:
25 | _target_: ocl.neural_networks.wrappers.Sequential
26 | _args_:
27 | - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
28 | n_spatial_dims: 2
29 | feature_dim: 512
30 | savi_style: true
31 | - _target_: ocl.neural_networks.build_two_layer_mlp
32 | input_dim: 512
33 | output_dim: ${models.perceptual_grouping.object_dim}
34 | hidden_dim: ${models.perceptual_grouping.object_dim}
35 | initial_layer_norm: true
36 | ff_mlp:
37 | _target_: ocl.neural_networks.build_two_layer_mlp
38 | input_dim: ${models.perceptual_grouping.object_dim}
39 | output_dim: ${models.perceptual_grouping.object_dim}
40 | hidden_dim: "${eval_lambda:'lambda dim: 2 * dim', ${.input_dim}}"
41 | initial_layer_norm: true
42 | residual: true
43 |
44 | object_decoder:
45 | _target_: ocl.decoding.SlotAttentionDecoder
46 | final_activation: tanh
47 | decoder:
48 | _target_: ocl.decoding.get_savi_decoder_backbone
49 | object_dim: ${models.perceptual_grouping.object_dim}
50 | larger_input_arch: true
51 | channel_multiplier: 1
52 | positional_embedding:
53 | _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
54 | n_spatial_dims: 2
55 | feature_dim: ${models.perceptual_grouping.object_dim}
56 | cnn_channel_order: true
57 | savi_style: true
58 | object_features_path: perceptual_grouping.objects
59 |
60 | plugins:
61 | optimize_parameters:
62 | optimizer:
63 | lr: 0.0002
64 | lr_scheduler:
65 | warmup_steps: 2500
66 | T_max: ${trainer.max_steps}
67 |
68 | losses:
69 | mse:
70 | _target_: ocl.losses.ReconstructionLoss
71 | loss_type: mse
72 | input_path: object_decoder.reconstruction
73 | target_path: input.image
74 |
75 | visualizations:
76 | input:
77 | _target_: ocl.visualizations.Image
78 | denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
79 | image_path: input.image
80 | reconstruction:
81 | _target_: ocl.visualizations.Image
82 | denormalization: ${..input.denormalization}
83 | image_path: object_decoder.reconstruction
84 | objects:
85 | _target_: ocl.visualizations.VisualObject
86 | denormalization: ${..input.denormalization}
87 | object_path: object_decoder.object_reconstructions
88 | mask_path: object_decoder.masks
89 | pred_segmentation:
90 | _target_: ocl.visualizations.Segmentation
91 | denormalization: ${..input.denormalization}
92 | image_path: input.image
93 | mask_path: object_decoder.masks
94 |
--------------------------------------------------------------------------------
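
The large-resolution variant replaces exponential decay with cosine annealing (`warmup_steps: 2500`, `T_max` tied to `trainer.max_steps`, base `lr: 0.0002`). Below is a minimal sketch of a linear-warmup-then-cosine schedule built on `torch.optim.lr_scheduler.LambdaLR`, as an illustration rather than the repository's `lr_schedulers/cosine_annealing` config.

import math
import torch

def warmup_cosine_lambda(warmup_steps: int, t_max: int):
    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # linear warmup from 0 to 1
        progress = (step - warmup_steps) / max(1, t_max - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine decay to 0
    return fn

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2e-4)  # lr: 0.0002 as in the config
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine_lambda(2500, 500_000))
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())
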
/configs/experiment/slot_attention/_metrics_clevr.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Metrics for CLEVR-like datasets
3 | defaults:
4 | - /metrics/ari_metric@evaluation_metrics.ari
5 | - /metrics/average_best_overlap_metric@evaluation_metrics.abo
6 | - /metrics/ari_metric@evaluation_metrics.ari_bg
7 | - /metrics/average_best_overlap_metric@evaluation_metrics.abo_bg
8 | evaluation_metrics:
9 | ari:
10 | prediction_path: object_decoder.masks
11 | target_path: input.mask
12 | abo:
13 | prediction_path: object_decoder.masks
14 | target_path: input.mask
15 | ignore_background: true
16 | ari_bg:
17 | prediction_path: object_decoder.masks
18 | target_path: input.mask
19 | foreground: false
20 | abo_bg:
21 | prediction_path: object_decoder.masks
22 | target_path: input.mask
23 | ignore_background: false
--------------------------------------------------------------------------------
/configs/experiment/slot_attention/_metrics_coco.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # Metrics for COCO-like datasets
3 | defaults:
4 | - /metrics/ari_metric@evaluation_metrics.instance_mask_ari
5 | - /metrics/unsupervised_mask_iou_metric@evaluation_metrics.instance_mask_iou
6 | - /metrics/unsupervised_mask_iou_metric@evaluation_metrics.segmentation_mask_iou
7 | - /metrics/average_best_overlap_metric@evaluation_metrics.instance_mask_abo
8 | - /metrics/average_best_overlap_metric@evaluation_metrics.segmentation_mask_abo
9 | - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
10 |
11 | evaluation_metrics:
12 | instance_mask_ari:
13 | prediction_path: object_decoder.masks
14 | target_path: input.instance_mask
15 | foreground: false
16 | ignore_overlaps: true
17 | convert_target_one_hot: true
18 | instance_mask_iou:
19 | prediction_path: object_decoder.masks
20 | target_path: input.instance_mask
21 | ignore_overlaps: true
22 | segmentation_mask_iou:
23 | prediction_path: object_decoder.masks
24 | target_path: input.segmentation_mask
25 | instance_mask_abo:
26 | prediction_path: object_decoder.masks
27 | target_path: input.instance_mask
28 | ignore_overlaps: true
29 | segmentation_mask_abo:
30 | prediction_path: object_decoder.masks
31 | target_path: input.segmentation_mask
32 | instance_mask_corloc:
33 | prediction_path: object_decoder.masks
34 | target_path: input.instance_mask
35 |     use_threshold: false
36 | ignore_overlaps: true
37 |
--------------------------------------------------------------------------------
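
The COCO metrics above pair ARI with unsupervised mask IoU and average-best-overlap scores, which require matching predicted masks to ground-truth masks. The sketch below illustrates the common Hungarian-matching formulation over a pairwise IoU matrix (CorLoc-style metrics use the same matrix with a 0.5 best-overlap threshold instead of a matching); the actual `ocl.metrics` code also handles crowd/overlap regions and empty masks, which are ignored here.

import numpy as np
from scipy.optimize import linear_sum_assignment

def matched_mask_iou_sketch(pred_masks: np.ndarray, true_masks: np.ndarray) -> float:
    # pred_masks: (num_slots, H, W) binary; true_masks: (num_objects, H, W) binary.
    pred = pred_masks.reshape(len(pred_masks), -1).astype(float)
    true = true_masks.reshape(len(true_masks), -1).astype(float)
    inter = true @ pred.T
    union = true.sum(1, keepdims=True) + pred.sum(1) - inter
    iou = inter / np.maximum(union, 1)
    rows, cols = linear_sum_assignment(-iou)  # maximize total IoU over one-to-one matches
    return float(iou[rows, cols].mean())      # average IoU over matched ground-truth objects

pred = np.random.rand(7, 64, 64) > 0.5
true = np.eye(4)[np.random.randint(0, 4, (64, 64))].transpose(2, 0, 1).astype(bool)
print(matched_mask_iou_sketch(pred, true))
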
/configs/experiment/slot_attention/_preprocessing_cater.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/multi_element_preprocessing@plugins.03_preprocessing
4 | - _self_
5 |
6 |
7 | plugins:
8 | 03_preprocessing:
9 | training_transforms:
10 | image:
11 | _target_: torchvision.transforms.Compose
12 | transforms:
13 | - _target_: torchvision.transforms.ToTensor
14 | - _target_: torchvision.transforms.Resize
15 | size: 128
16 | - _target_: torchvision.transforms.Normalize
17 | mean: [0.5, 0.5, 0.5]
18 | std: [0.5, 0.5, 0.5]
19 | evaluation_transforms:
20 | image:
21 | _target_: torchvision.transforms.Compose
22 | transforms:
23 | - _target_: torchvision.transforms.ToTensor
24 | - _target_: torchvision.transforms.Resize
25 | size: 128
26 | - _target_: torchvision.transforms.Normalize
27 | mean: [0.5, 0.5, 0.5]
28 | std: [0.5, 0.5, 0.5]
29 | mask:
30 | _target_: torchvision.transforms.Compose
31 | transforms:
32 | - _target_: ocl.preprocessing.MultiMaskToTensor
33 | - _target_: ocl.preprocessing.ResizeNearestExact
34 | size: 128
35 |
--------------------------------------------------------------------------------
/configs/experiment/slot_attention/_preprocessing_clevr.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/multi_element_preprocessing@plugins.03_preprocessing
4 |
5 | plugins:
6 | 03_preprocessing:
7 | training_transforms:
8 | image:
9 | _target_: torchvision.transforms.Compose
10 | transforms:
11 | - _target_: torchvision.transforms.ToTensor
12 | - _target_: torchvision.transforms.CenterCrop
13 | size: [192, 192]
14 | - _target_: torchvision.transforms.Resize
15 | size: 128
16 | - _target_: torchvision.transforms.Normalize
17 | mean: [0.5, 0.5, 0.5]
18 | std: [0.5, 0.5, 0.5]
19 | evaluation_transforms:
20 | image:
21 | _target_: torchvision.transforms.Compose
22 | transforms:
23 | - _target_: torchvision.transforms.ToTensor
24 | - _target_: torchvision.transforms.CenterCrop
25 | size: [192, 192]
26 | - _target_: torchvision.transforms.Resize
27 | size: 128
28 | - _target_: torchvision.transforms.Normalize
29 | mean: [0.5, 0.5, 0.5]
30 | std: [0.5, 0.5, 0.5]
31 | mask:
32 | _target_: torchvision.transforms.Compose
33 | transforms:
34 | - _target_: ocl.preprocessing.MaskToTensor
35 | - _target_: torchvision.transforms.CenterCrop
36 | size: [192, 192]
37 | - _target_: ocl.preprocessing.ResizeNearestExact
38 | size: 128
39 |
--------------------------------------------------------------------------------
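
For a quick sanity check outside Hydra, the image branch of this preprocessing pipeline can be rebuilt directly with torchvision; the snippet mirrors the transform above (the mask branch uses the repo-specific `MaskToTensor` and `ResizeNearestExact` ops and is not reproduced).

import numpy as np
from PIL import Image
from torchvision import transforms

image_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.CenterCrop([192, 192]),
        transforms.Resize(128),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Apply to a dummy CLEVR-sized frame (320x240).
frame = Image.fromarray((np.random.rand(240, 320, 3) * 255).astype("uint8"))
print(image_transform(frame).shape)  # torch.Size([3, 128, 128])
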
/configs/experiment/slot_attention/_preprocessing_clevr_no_norm.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /plugins/multi_element_preprocessing@plugins.03_preprocessing
4 |
5 | plugins:
6 | 03_preprocessing:
7 | training_transforms:
8 | image:
9 | _target_: torchvision.transforms.Compose
10 | transforms:
11 | - _target_: torchvision.transforms.ToTensor
12 | - _target_: torchvision.transforms.CenterCrop
13 | size: [192, 192]
14 | - _target_: torchvision.transforms.Resize
15 | size: 128
16 | evaluation_transforms:
17 | image:
18 | _target_: torchvision.transforms.Compose
19 | transforms:
20 | - _target_: torchvision.transforms.ToTensor
21 | - _target_: torchvision.transforms.CenterCrop
22 | size: [192, 192]
23 | - _target_: torchvision.transforms.Resize
24 | size: 128
25 | mask:
26 | _target_: torchvision.transforms.Compose
27 | transforms:
28 | - _target_: ocl.preprocessing.MaskToTensor
29 | - _target_: torchvision.transforms.CenterCrop
30 | size: [192, 192]
31 | - _target_: ocl.preprocessing.ResizeNearestExact
32 | size: 128
33 |
--------------------------------------------------------------------------------
/configs/experiment/slot_attention/clevr10.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /experiment/slot_attention/_base
4 | - /dataset: clevr
5 | - /experiment/slot_attention/_preprocessing_clevr
6 | - /experiment/slot_attention/_metrics_clevr
7 | - _self_
8 |
9 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
10 | trainer:
11 | gpus: 8
12 | max_steps: 500000
13 | max_epochs: null
14 | strategy: ddp
15 | dataset:
16 | num_workers: 4
17 | batch_size: 8
18 |
19 | models:
20 | conditioning:
21 | n_slots: 11
22 |
--------------------------------------------------------------------------------
/configs/experiment/slot_attention/clevr10_adaslot.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /experiment/slot_attention/_base_gumbel
4 | - /dataset: clevr
5 | - /experiment/slot_attention/_preprocessing_clevr
6 | - /experiment/slot_attention/_metrics_clevr
7 | - /metrics/tensor_statistic@training_metrics.hard_keep_decision
8 | - /metrics/tensor_statistic@training_metrics.slots_keep_prob
9 | - _self_
10 |
11 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
12 | trainer:
13 | gpus: 8
14 | max_steps: 500000
15 | max_epochs: null
16 | strategy: ddp
17 | dataset:
18 | num_workers: 4
19 | batch_size: 8
20 |
21 | models:
22 | _target_: ocl.models.image_grouping_adaslot_pixel.GroupingImgGumbel
23 | conditioning:
24 | n_slots: 11
25 |
26 | perceptual_grouping:
27 | low_bound: 0
28 |
29 | object_decoder:
30 | _target_: ocl.decoding.SlotAttentionDecoderGumbel
31 | left_mask_path: None
32 | mask_type: mask_normalized
33 |
34 | losses:
35 | sparse_penalty:
36 | _target_: ocl.losses.SparsePenalty
37 | linear_weight: 10
38 | quadratic_weight: 0.0
39 | quadratic_bias: 0.5
40 | input_path: hard_keep_decision
41 |
42 | training_metrics:
43 | hard_keep_decision:
44 | path: hard_keep_decision
45 | reduction: sum
46 |
47 | slots_keep_prob:
48 | path: slots_keep_prob
49 | reduction: mean
50 |
51 | load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/slot_attention/clevr10.yaml/2023-05-11_11-51-55/lightning_logs/version_0/checkpoints/epoch=457-step=500000.ckpt
--------------------------------------------------------------------------------
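
`load_model_weight` points the AdaSlot run at a checkpoint from the plain slot-attention experiment, so adaptive slot selection is fine-tuned from pretrained weights; the path is machine-specific and has to be replaced. Below is a minimal sketch of loading such a Lightning checkpoint with non-strict key matching (an assumption about how the training script consumes this field).

import torch

def load_model_weight_sketch(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    # Lightning checkpoints store the model weights under the "state_dict" key.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # strict=False tolerates the extra Gumbel scoring parameters that only exist in AdaSlot.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    return model

# Usage (model construction and path are placeholders):
# model = load_model_weight_sketch(model, "/path/to/pretrained.ckpt")
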
/configs/experiment/slot_attention/clevr10_adaslot_eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /experiment/slot_attention/_base_gumbel
4 | - /dataset: clevr
5 | - /experiment/slot_attention/_preprocessing_clevr
6 | - /experiment/slot_attention/_metrics_clevr
7 | - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
8 | - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
9 | - /metrics/ami_metric@evaluation_metrics.ami
10 | - /metrics/nmi_metric@evaluation_metrics.nmi
11 | - /metrics/purity_metric@evaluation_metrics.purity
12 | - /metrics/precision_metric@evaluation_metrics.precision
13 | - /metrics/recall_metric@evaluation_metrics.recall
14 | - /metrics/f1_metric@evaluation_metrics.f1
15 | - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
16 | - _self_
17 |
18 | # The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
19 | trainer:
20 | gpus: 8
21 | max_steps: 500000
22 | max_epochs: null
23 | strategy: ddp
24 | dataset:
25 | num_workers: 4
26 | batch_size: 8
27 |
28 | models:
29 | _target_: ocl.models.image_grouping_adaslot_pixel.GroupingImgGumbel
30 | conditioning:
31 | n_slots: 11
32 |
33 | perceptual_grouping:
34 | low_bound: 0
35 |
36 | object_decoder:
37 | _target_: ocl.decoding.SlotAttentionDecoderGumbel
38 | left_mask_path: None
39 | mask_type: mask_normalized
40 |
41 | losses:
42 | sparse_penalty:
43 | _target_: ocl.losses.SparsePenalty
44 | linear_weight: 10
45 | quadratic_weight: 0.0
46 | quadratic_bias: 0.5
47 | input_path: hard_keep_decision
48 |
49 | evaluation_metrics:
50 | hard_keep_decision:
51 | path: hard_keep_decision
52 | reduction: sum
53 |
54 | slots_keep_prob:
55 | path: slots_keep_prob
56 | reduction: mean
57 |
58 | ami:
59 | prediction_path: object_decoder.masks
60 | target_path: input.mask
61 | foreground: true
62 | convert_target_one_hot: false
63 | ignore_overlaps: false
64 |
65 | nmi:
66 | prediction_path: object_decoder.masks
67 | target_path: input.mask
68 | foreground: true
69 | convert_target_one_hot: false
70 | ignore_overlaps: false
71 |
72 | purity:
73 | prediction_path: object_decoder.masks
74 | target_path: input.mask
75 | foreground: true
76 | convert_target_one_hot: false
77 | ignore_overlaps: false
78 |
79 | precision:
80 | prediction_path: object_decoder.masks
81 | target_path: input.mask
82 | foreground: true
83 | convert_target_one_hot: false
84 | ignore_overlaps: false
85 |
86 | recall:
87 | prediction_path: object_decoder.masks
88 | target_path: input.mask
89 | foreground: true
90 | convert_target_one_hot: false
91 | ignore_overlaps: false
92 |
93 | f1:
94 | prediction_path: object_decoder.masks
95 | target_path: input.mask
96 | foreground: true
97 | convert_target_one_hot: false
98 | ignore_overlaps: false
99 |
100 | instance_mask_corloc:
101 | prediction_path: object_decoder.masks
102 | target_path: input.mask
103 |     use_threshold: false
104 |     ignore_background: true
105 |     ignore_overlaps: false
106 |
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/AdaSlot/6a60387f4ee985e55b254274f41974e1aa5130e8/framework.png
--------------------------------------------------------------------------------
/ocl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/AdaSlot/6a60387f4ee985e55b254274f41974e1aa5130e8/ocl/__init__.py
--------------------------------------------------------------------------------
/ocl/cli/cli_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | from hydra.core.hydra_config import HydraConfig
5 |
6 |
7 | def get_commandline_config_path():
8 | """Get the path of a config path specified on the command line."""
9 | hydra_cfg = HydraConfig.get()
10 | config_sources = hydra_cfg.runtime.config_sources
11 | config_path = None
12 | for source in config_sources:
13 | if source.schema == "file" and source.provider == "command-line":
14 | config_path = source.path
15 | break
16 | return config_path
17 |
18 |
19 | def find_checkpoint(path):
20 | """Find checkpoint in output path of previous run."""
21 | checkpoints = glob.glob(
22 | os.path.join(path, "lightning_logs", "version_*", "checkpoints", "*.ckpt")
23 | )
24 | checkpoints.sort()
25 | # Return the last checkpoint.
26 | # TODO (hornmax): If more than one checkpoint is stored this might not lead to the most recent
27 | # checkpoint being loaded. Generally, I think this is ok as we still allow people to set the
28 | # checkpoint manually.
29 | return checkpoints[-1]
30 |
--------------------------------------------------------------------------------
/ocl/cli/compute_dataset_size.py:
--------------------------------------------------------------------------------
1 | """Script to compute the size of a dataset.
2 |
3 | This is useful when subsampling data using transformations in order to determine the final dataset
4 | size. The size of the dataset is typically needed when running distributed training in order to
5 | ensure that all nodes and gpu training processes are presented with the same number of batches.
6 | """
7 | import dataclasses
8 | import logging
9 | import os
10 | from typing import Dict
11 |
12 | import hydra
13 | import hydra_zen
14 | import tqdm
15 | from pluggy import PluginManager
16 |
17 | import ocl.hooks
18 | from ocl.config.datasets import DataModuleConfig
19 |
20 |
21 | @dataclasses.dataclass
22 | class ComputeSizeConfig:
23 | """Configuration of a training run."""
24 |
25 | dataset: DataModuleConfig
26 | plugins: Dict[str, Dict] = dataclasses.field(default_factory=dict)
27 |
28 |
29 | hydra.core.config_store.ConfigStore.instance().store(
30 | name="compute_size_config",
31 | node=ComputeSizeConfig,
32 | )
33 |
34 |
35 | @hydra.main(config_name="compute_size_config", config_path="../../configs", version_base="1.1")
36 | def compute_size(config: ComputeSizeConfig):
37 | pm = PluginManager("ocl")
38 | pm.add_hookspecs(ocl.hooks)
39 |
40 | datamodule = hydra_zen.instantiate(config.dataset, hooks=pm.hook)
41 | pm.register(datamodule)
42 |
43 | plugins = hydra_zen.instantiate(config.plugins)
44 | for plugin_name in sorted(plugins.keys())[::-1]:
45 | pm.register(plugins[plugin_name])
46 |
47 | # Compute dataset sizes
48 | # TODO(hornmax): This is needed for webdataset shuffling, is there a way to make this more
49 | # elegant and less specific?
50 | os.environ["WDS_EPOCH"] = str(0)
51 | train_size = sum(
52 | 1
53 | for _ in tqdm.tqdm(
54 | datamodule.train_data_iterator(), desc="Reading train split", unit="samples"
55 | )
56 | )
57 | logging.info("Train split size: %d", train_size)
58 | val_size = sum(
59 | 1
60 | for _ in tqdm.tqdm(
61 | datamodule.val_data_iterator(), desc="Reading validation split", unit="samples"
62 | )
63 | )
64 | logging.info("Validation split size: %d", val_size)
65 | test_size = sum(
66 | 1
67 | for _ in tqdm.tqdm(
68 | datamodule.test_data_iterator(), desc="Reading test split", unit="samples"
69 | )
70 | )
71 | logging.info("Test split size: %d", test_size)
72 |
73 |
74 | if __name__ == "__main__":
75 | compute_size()
76 |
--------------------------------------------------------------------------------
/ocl/cli/eval.py:
--------------------------------------------------------------------------------
1 | """Train a slot attention type model."""
2 | import dataclasses
3 | from typing import Any, Dict, Optional
4 |
5 | import hydra
6 | import hydra_zen
7 | import pytorch_lightning as pl
8 | from pluggy import PluginManager
9 |
10 | import ocl.hooks
11 | from ocl import base
12 | from ocl.combined_model import CombinedModel
13 | from ocl.config.datasets import DataModuleConfig
14 | from ocl.config.metrics import MetricConfig
15 | from ocl.plugins import Plugin
16 | from ocl.cli import eval_utils
17 |
18 |
19 | TrainerConf = hydra_zen.builds(
20 | pl.Trainer, max_epochs=100, zen_partial=False, populate_full_signature=True
21 | )
22 |
23 |
24 | @dataclasses.dataclass
25 | class TrainingConfig:
26 | """Configuration of a training run."""
27 |
28 | dataset: DataModuleConfig
29 |     models: Any  # When provided with a dict, wrap it in `utils.Combined`; otherwise interpret it as the model.
30 | losses: Dict[str, Any]
31 | visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
32 | plugins: Dict[str, Any] = dataclasses.field(default_factory=dict)
33 | trainer: TrainerConf = TrainerConf
34 | training_vis_frequency: Optional[int] = None
35 | training_metrics: Optional[Dict[str, MetricConfig]] = None
36 | evaluation_metrics: Optional[Dict[str, MetricConfig]] = None
37 | load_checkpoint: Optional[str] = None
38 | # load_model_weight: Optional[str] = None
39 | seed: Optional[int] = None
40 | experiment: Optional[Any] = None
41 | root_output_folder: Optional[str] = None
42 |
43 |
44 | hydra.core.config_store.ConfigStore.instance().store(
45 | name="training_config",
46 | node=TrainingConfig,
47 | )
48 |
49 |
50 | def create_plugin_manager() -> PluginManager:
51 | pm = PluginManager("ocl")
52 | pm.add_hookspecs(ocl.hooks)
53 | return pm
54 |
55 |
56 | def build_and_register_datamodule_from_config(
57 | config: TrainingConfig,
58 | hooks: base.PluggyHookRelay,
59 | plugin_manager: Optional[PluginManager] = None,
60 | **datamodule_kwargs,
61 | ) -> pl.LightningDataModule:
62 | datamodule = hydra_zen.instantiate(
63 | config.dataset, hooks=hooks, _convert_="all", **datamodule_kwargs
64 | )
65 |
66 | if plugin_manager:
67 | plugin_manager.register(datamodule)
68 |
69 | return datamodule
70 |
71 |
72 | def build_and_register_plugins_from_config(
73 | config: TrainingConfig, plugin_manager: Optional[PluginManager] = None
74 | ) -> Dict[str, Plugin]:
75 | plugins = hydra_zen.instantiate(config.plugins)
76 |     # Use lexicographical sorting of the plugin names to allow influencing the registration
77 |     # order, as certain plugins might need to be called before others. Pluggy calls hooks in
78 |     # FILO (first in, last out) order, which is slightly unintuitive. We therefore register
79 |     # plugins in reverse sorted order, so that they effectively run FIFO (first in, first out)
80 |     # with respect to their sorted position.
81 | if plugin_manager:
82 | for plugin_name in sorted(plugins.keys())[::-1]:
83 | plugin_manager.register(plugins[plugin_name])
84 |
85 | return plugins
86 |
87 |
88 | def build_model_from_config(
89 | config: TrainingConfig,
90 | hooks: base.PluggyHookRelay,
91 | checkpoint_path: Optional[str] = None,
92 | ) -> pl.LightningModule:
93 | models = hydra_zen.instantiate(config.models, _convert_="all")
94 | losses = hydra_zen.instantiate(config.losses, _convert_="all")
95 | visualizations = hydra_zen.instantiate(config.visualizations, _convert_="all")
96 |
97 | training_metrics = hydra_zen.instantiate(config.training_metrics)
98 | evaluation_metrics = hydra_zen.instantiate(config.evaluation_metrics)
99 |
100 | train_vis_freq = config.training_vis_frequency if config.training_vis_frequency else 100
101 |
102 | if checkpoint_path is None:
103 | model = CombinedModel(
104 | models=models,
105 | losses=losses,
106 | visualizations=visualizations,
107 | hooks=hooks,
108 | training_metrics=training_metrics,
109 | evaluation_metrics=evaluation_metrics,
110 | vis_log_frequency=train_vis_freq,
111 | )
112 | else:
113 | model = CombinedModel.load_from_checkpoint(
114 | checkpoint_path,
115 | strict=False,
116 | models=models,
117 | losses=losses,
118 | visualizations=visualizations,
119 | hooks=hooks,
120 | training_metrics=training_metrics,
121 | evaluation_metrics=evaluation_metrics,
122 | vis_log_frequency=train_vis_freq,
123 | )
124 | return model
125 |
126 |
127 | @hydra.main(config_name="training_config", config_path="../../configs/", version_base="1.1")
128 | def evaluate(config: TrainingConfig):
129 | # Set all relevant random seeds. If `config.seed is None`, the function samples a random value.
130 | # The function takes care of correctly distributing the seed across nodes in multi-node training,
131 | # and assigns each dataloader worker a different random seed.
132 | # IMPORTANTLY, we need to take care not to set a custom `worker_init_fn` function on the
133 | # dataloaders (or take care of worker seeding ourselves).
134 | pl.seed_everything(config.seed, workers=True)
135 |
136 |     # `build_from_train_config` below creates its own plugin manager; guard against a missing checkpoint.
137 |     if config.load_checkpoint is None:
138 |         raise ValueError("`load_checkpoint` must point to a trained checkpoint for evaluation.")
139 | checkpoint_path = hydra.utils.to_absolute_path(config.load_checkpoint)
140 | datamodule, model, pm = eval_utils.build_from_train_config(
141 | config, checkpoint_path
142 | )
143 |
144 | trainer: pl.Trainer = hydra_zen.instantiate(
145 | config.trainer,
146 | _convert_="all",
147 | enable_progress_bar=True,
148 | gpus=[0],
149 | )
150 |
151 | print("******start validate model******")
152 | trainer.validate(model, datamodule.val_dataloader())
153 |
154 |
155 | if __name__ == "__main__":
156 |     evaluate()
157 |
--------------------------------------------------------------------------------
/ocl/cli/eval_utils.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import pickle
3 | from collections import defaultdict
4 | from typing import Any, Callable, Dict, List, Optional
5 |
6 | import numpy
7 | import pytorch_lightning as pl
8 | import torch
9 |
10 | from ocl import path_defaults
11 | from ocl.cli import train
12 | from ocl.utils.trees import get_tree_element
13 |
14 |
15 | def build_from_train_config(
16 | config: train.TrainingConfig, checkpoint_path: Optional[str], seed: bool = True
17 | ):
18 | if seed:
19 | pl.seed_everything(config.seed, workers=True)
20 |
21 | pm = train.create_plugin_manager()
22 | datamodule = train.build_and_register_datamodule_from_config(config, pm.hook, pm)
23 | train.build_and_register_plugins_from_config(config, pm)
24 | model = train.build_model_from_config(config, pm.hook, checkpoint_path)
25 |
26 | return datamodule, model, pm
27 |
28 |
29 | class ExtractDataFromPredictions(pl.callbacks.Callback):
30 | """Callback used for extracting model outputs during validation and prediction."""
31 |
32 | def __init__(
33 | self,
34 | paths: List[str],
35 | output_paths: Optional[List[str]] = None,
36 | transform: Optional[Callable] = None,
37 | max_samples: Optional[int] = None,
38 | flatten_batches: bool = False,
39 | ):
40 | self.paths = paths
41 | self.output_paths = output_paths if output_paths is not None else paths
42 | self.transform = transform
43 | self.max_samples = max_samples
44 | self.flatten_batches = flatten_batches
45 |
46 | self.outputs = defaultdict(list)
47 | self._n_samples = 0
48 |
49 | def _start(self):
50 | self._n_samples = 0
51 | self.outputs = defaultdict(list)
52 |
53 | def _process_outputs(self, outputs, batch):
54 | if self.max_samples is not None and self._n_samples >= self.max_samples:
55 | return
56 |
57 |         data = {path_defaults.INPUT: batch, **outputs}
58 |         data = {path: get_tree_element(data, path.split(".")) for path in self.paths}
59 |
60 | if self.transform:
61 | data = self.transform(data)
62 |
63 | first_path = True
64 | for path in self.output_paths:
65 | elems = data[path].detach().cpu()
66 | if not self.flatten_batches:
67 | elems = [elems]
68 |
69 | for idx in range(len(elems)):
70 | self.outputs[path].append(elems[idx])
71 | if first_path:
72 | self._n_samples += 1
73 |
74 | first_path = False
75 |
76 | def on_validation_start(self, trainer, pl_module):
77 | self._start()
78 |
79 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
80 | assert (
81 | outputs is not None
82 | ), "Model returned no outputs. Set `model.return_outputs_on_validation=True`"
83 | self._process_outputs(outputs, batch)
84 |
85 | def on_predict_start(self, trainer, pl_module):
86 | self._start()
87 |
88 | def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
89 | self._process_outputs(outputs, batch)
90 |
91 | def get_outputs(self) -> List[Dict[str, Any]]:
92 | state = []
93 | for idx in range(self._n_samples):
94 | state.append({})
95 | for key, values in self.outputs.items():
96 | state[-1][key] = values[idx]
97 |
98 | return state
99 |
100 |
101 | def save_outputs(dir_path: str, outputs: List[Dict[str, Any]], verbose: bool = False):
102 | """Save outputs to disk in numpy or pickle format."""
103 | dir_path = pathlib.Path(dir_path)
104 | dir_path.mkdir(parents=True, exist_ok=True)
105 |
106 | def get_path(path, prefix, key, extension):
107 | return str(path / f"{prefix}.{key}.{extension}")
108 |
109 | idx_fmt = "{:0" + str(len(str(len(outputs)))) + "d}" # Get number of total digits
110 | for idx, entry in enumerate(outputs):
111 | idx_prefix = idx_fmt.format(idx)
112 | for key, value in entry.items():
113 | if isinstance(value, torch.Tensor):
114 | value = value.numpy()
115 |
116 | if isinstance(value, numpy.ndarray):
117 | path = get_path(dir_path, idx_prefix, key, "npy")
118 | if verbose:
119 | print(f"Saving numpy array to {path}.")
120 | numpy.save(path, value)
121 | else:
122 | path = get_path(dir_path, idx_prefix, key, "pkl")
123 | if verbose:
124 | print(f"Saving pickle to {path}.")
125 | with open(path, "wb") as f:
126 | pickle.dump(value, f)
127 |
--------------------------------------------------------------------------------
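
A short end-to-end example of these helpers: build the datamodule and model from a training config plus checkpoint with `build_from_train_config`, attach `ExtractDataFromPredictions` to a Lightning trainer to capture selected output paths during validation, and write them out with `save_outputs`. The config and checkpoint objects below are placeholders supplied by the caller.

import pytorch_lightning as pl

from ocl.cli import eval_utils


def dump_decoder_masks(config, checkpoint_path: str, out_dir: str = "eval_outputs"):
    # config is a composed TrainingConfig; checkpoint_path points at a trained .ckpt file.
    datamodule, model, _ = eval_utils.build_from_train_config(config, checkpoint_path)
    # Required so validation steps return the output dict (see the assert in the callback above).
    model.return_outputs_on_validation = True

    extractor = eval_utils.ExtractDataFromPredictions(
        paths=["object_decoder.masks"],  # any path into the model's output tree
        max_samples=100,
        flatten_batches=True,
    )
    trainer = pl.Trainer(gpus=[0], callbacks=[extractor])
    trainer.validate(model, datamodule.val_dataloader())
    eval_utils.save_outputs(out_dir, extractor.get_outputs(), verbose=True)
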
/ocl/cli/train.py:
--------------------------------------------------------------------------------
1 | """Train a slot attention type model."""
2 | import dataclasses
3 | from typing import Any, Dict, Optional
4 |
5 | import hydra
6 | import hydra_zen
7 | import pytorch_lightning as pl
8 | from pluggy import PluginManager
9 |
10 | import ocl.hooks
11 | from ocl import base
12 | from ocl.combined_model import CombinedModel
13 | from ocl.config.datasets import DataModuleConfig
14 | from ocl.config.metrics import MetricConfig
15 | from ocl.plugins import Plugin
16 |
17 | TrainerConf = hydra_zen.builds(
18 | pl.Trainer, max_epochs=100, zen_partial=False, populate_full_signature=True
19 | )
20 |
21 |
22 | @dataclasses.dataclass
23 | class TrainingConfig:
24 | """Configuration of a training run."""
25 |
26 | dataset: DataModuleConfig
27 |     models: Any  # When provided with a dict, wrap it in `utils.Combined`; otherwise interpret it as the model.
28 | losses: Dict[str, Any]
29 | visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
30 | plugins: Dict[str, Any] = dataclasses.field(default_factory=dict)
31 | trainer: TrainerConf = TrainerConf
32 | training_vis_frequency: Optional[int] = None
33 | training_metrics: Optional[Dict[str, MetricConfig]] = None
34 | evaluation_metrics: Optional[Dict[str, MetricConfig]] = None
35 | load_checkpoint: Optional[str] = None
36 | seed: Optional[int] = None
37 | experiment: Optional[Any] = None
38 | root_output_folder: Optional[str] = None
39 |
40 |
41 | hydra.core.config_store.ConfigStore.instance().store(
42 | name="training_config",
43 | node=TrainingConfig,
44 | )
45 |
46 |
47 | def create_plugin_manager() -> PluginManager:
48 | pm = PluginManager("ocl")
49 | pm.add_hookspecs(ocl.hooks)
50 | return pm
51 |
52 |
53 | def build_and_register_datamodule_from_config(
54 | config: TrainingConfig,
55 | hooks: base.PluggyHookRelay,
56 | plugin_manager: Optional[PluginManager] = None,
57 | **datamodule_kwargs,
58 | ) -> pl.LightningDataModule:
59 | datamodule = hydra_zen.instantiate(
60 | config.dataset, hooks=hooks, _convert_="all", **datamodule_kwargs
61 | )
62 |
63 | if plugin_manager:
64 | plugin_manager.register(datamodule)
65 |
66 | return datamodule
67 |
68 |
69 | def build_and_register_plugins_from_config(
70 | config: TrainingConfig, plugin_manager: Optional[PluginManager] = None
71 | ) -> Dict[str, Plugin]:
72 | plugins = hydra_zen.instantiate(config.plugins)
73 |     # Use lexicographical sorting of the plugin names to allow influencing the registration
74 |     # order, as certain plugins might need to be called before others. Pluggy calls hooks in
75 |     # FILO (first in, last out) order, which is slightly unintuitive. We therefore register
76 |     # plugins in reverse sorted order, so that they effectively run FIFO (first in, first out)
77 |     # with respect to their sorted position.
78 | if plugin_manager:
79 | for plugin_name in sorted(plugins.keys())[::-1]:
80 | plugin_manager.register(plugins[plugin_name])
81 |
82 | return plugins
83 |
84 |
85 | def build_model_from_config(
86 | config: TrainingConfig,
87 | hooks: base.PluggyHookRelay,
88 | checkpoint_path: Optional[str] = None,
89 | ) -> pl.LightningModule:
90 | models = hydra_zen.instantiate(config.models, _convert_="all")
91 | losses = hydra_zen.instantiate(config.losses, _convert_="all")
92 | visualizations = hydra_zen.instantiate(config.visualizations, _convert_="all")
93 |
94 | training_metrics = hydra_zen.instantiate(config.training_metrics)
95 | evaluation_metrics = hydra_zen.instantiate(config.evaluation_metrics)
96 |
97 | train_vis_freq = config.training_vis_frequency if config.training_vis_frequency else 100
98 |
99 | if checkpoint_path is None:
100 | model = CombinedModel(
101 | models=models,
102 | losses=losses,
103 | visualizations=visualizations,
104 | hooks=hooks,
105 | training_metrics=training_metrics,
106 | evaluation_metrics=evaluation_metrics,
107 | vis_log_frequency=train_vis_freq,
108 | )
109 | else:
110 | model = CombinedModel.load_from_checkpoint(
111 | checkpoint_path,
112 | strict=False,
113 | models=models,
114 | losses=losses,
115 | visualizations=visualizations,
116 | hooks=hooks,
117 | training_metrics=training_metrics,
118 | evaluation_metrics=evaluation_metrics,
119 | vis_log_frequency=train_vis_freq,
120 | )
121 | return model
122 |
123 |
124 | @hydra.main(config_name="training_config", config_path="../../configs/", version_base="1.1")
125 | def train(config: TrainingConfig):
126 | # Set all relevant random seeds. If `config.seed is None`, the function samples a random value.
127 | # The function takes care of correctly distributing the seed across nodes in multi-node training,
128 | # and assigns each dataloader worker a different random seed.
129 | # IMPORTANTLY, we need to take care not to set a custom `worker_init_fn` function on the
130 | # dataloaders (or take care of worker seeding ourselves).
131 | pl.seed_everything(config.seed, workers=True)
132 |
133 | pm = create_plugin_manager()
134 |
135 | datamodule = build_and_register_datamodule_from_config(config, pm.hook, pm)
136 |
137 | build_and_register_plugins_from_config(config, pm)
138 |
139 | if config.load_checkpoint:
140 | checkpoint_path = hydra.utils.to_absolute_path(config.load_checkpoint)
141 | else:
142 | checkpoint_path = None
143 |
144 | model = build_model_from_config(config, pm.hook)
145 |
146 | callbacks = hydra_zen.instantiate(config.trainer.callbacks, _convert_="all")
147 | callbacks = callbacks if callbacks else []
148 | if config.trainer.logger is not False:
149 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
150 | callbacks.append(lr_monitor)
151 |
152 | trainer: pl.Trainer = hydra_zen.instantiate(config.trainer, callbacks=callbacks, _convert_="all")
153 |
154 | trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)
155 |
156 |
157 | if __name__ == "__main__":
158 | train()
159 |
--------------------------------------------------------------------------------
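Since `train` is a standard Hydra entry point, experiments are selected and individual fields overridden on the command line (e.g. via a `+experiment=<config under configs/experiment>` override). The structured `TrainerConf` can also be instantiated directly; a minimal sketch, assuming only that `ocl.cli.train` is importable:

import hydra_zen
import pytorch_lightning as pl

from ocl.cli.train import TrainerConf

# TrainerConf wraps pl.Trainer with its full signature, so keyword overrides map onto
# Trainer constructor arguments and instantiation returns a regular Trainer object.
trainer = hydra_zen.instantiate(TrainerConf, max_epochs=1)
assert isinstance(trainer, pl.Trainer)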
/ocl/cli/train_adaslot.py:
--------------------------------------------------------------------------------
1 | """Train a slot attention type model."""
2 | import dataclasses
3 | from typing import Any, Dict, Optional
4 |
5 | import hydra
6 | import hydra_zen
7 | import pytorch_lightning as pl
8 | from pluggy import PluginManager
9 |
10 | import ocl.hooks
11 | from ocl import base
12 | from ocl.combined_model import CombinedModel
13 | from ocl.config.datasets import DataModuleConfig
14 | from ocl.config.metrics import MetricConfig
15 | from ocl.plugins import Plugin
16 | import torch
17 |
18 | TrainerConf = hydra_zen.builds(
19 | pl.Trainer, max_epochs=100, zen_partial=False, populate_full_signature=True
20 | )
21 |
22 |
23 | @dataclasses.dataclass
24 | class TrainingConfig:
25 | """Configuration of a training run."""
26 |
27 | dataset: DataModuleConfig
28 |     models: Any  # If given a dict, it is wrapped in `utils.Combined`; otherwise it is used as the model directly.
29 | losses: Dict[str, Any]
30 | visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
31 | plugins: Dict[str, Any] = dataclasses.field(default_factory=dict)
32 | trainer: TrainerConf = TrainerConf
33 | training_vis_frequency: Optional[int] = None
34 | training_metrics: Optional[Dict[str, MetricConfig]] = None
35 | evaluation_metrics: Optional[Dict[str, MetricConfig]] = None
36 | load_checkpoint: Optional[str] = None
37 | load_model_weight: Optional[str] = None
38 | seed: Optional[int] = None
39 | experiment: Optional[Any] = None
40 | root_output_folder: Optional[str] = None
41 |
42 |
43 | hydra.core.config_store.ConfigStore.instance().store(
44 | name="training_config",
45 | node=TrainingConfig,
46 | )
47 |
48 |
49 | def create_plugin_manager() -> PluginManager:
50 | pm = PluginManager("ocl")
51 | pm.add_hookspecs(ocl.hooks)
52 | return pm
53 |
54 |
55 | def build_and_register_datamodule_from_config(
56 | config: TrainingConfig,
57 | hooks: base.PluggyHookRelay,
58 | plugin_manager: Optional[PluginManager] = None,
59 | **datamodule_kwargs,
60 | ) -> pl.LightningDataModule:
61 | datamodule = hydra_zen.instantiate(
62 | config.dataset, hooks=hooks, _convert_="all", **datamodule_kwargs
63 | )
64 |
65 | if plugin_manager:
66 | plugin_manager.register(datamodule)
67 |
68 | return datamodule
69 |
70 |
71 | def build_and_register_plugins_from_config(
72 | config: TrainingConfig, plugin_manager: Optional[PluginManager] = None
73 | ) -> Dict[str, Plugin]:
74 | plugins = hydra_zen.instantiate(config.plugins)
75 |     # Plugins are registered in lexicographical order of their names so that the registration
76 |     # order can be influenced. This is necessary because certain plugins must be called before
77 |     # others, and pluggy calls hooks in FILO (first in, last out) order, which is slightly
78 |     # unintuitive. We therefore register plugins in reverse sorted order, which yields FIFO
79 |     # (first in, first out) behavior with respect to the sorted position.
80 | if plugin_manager:
81 | for plugin_name in sorted(plugins.keys())[::-1]:
82 | plugin_manager.register(plugins[plugin_name])
83 |
84 | return plugins
85 |
86 |
87 | def build_model_from_config(
88 | config: TrainingConfig,
89 | hooks: base.PluggyHookRelay,
90 | checkpoint_path: Optional[str] = None,
91 | ) -> pl.LightningModule:
92 | models = hydra_zen.instantiate(config.models, _convert_="all")
93 | losses = hydra_zen.instantiate(config.losses, _convert_="all")
94 | visualizations = hydra_zen.instantiate(config.visualizations, _convert_="all")
95 |
96 | training_metrics = hydra_zen.instantiate(config.training_metrics)
97 | evaluation_metrics = hydra_zen.instantiate(config.evaluation_metrics)
98 |
99 | train_vis_freq = config.training_vis_frequency if config.training_vis_frequency else 100
100 |
101 | if checkpoint_path is None:
102 | model = CombinedModel(
103 | models=models,
104 | losses=losses,
105 | visualizations=visualizations,
106 | hooks=hooks,
107 | training_metrics=training_metrics,
108 | evaluation_metrics=evaluation_metrics,
109 | vis_log_frequency=train_vis_freq,
110 | )
111 | else:
112 | model = CombinedModel.load_from_checkpoint(
113 | checkpoint_path,
114 | strict=False,
115 | models=models,
116 | losses=losses,
117 | visualizations=visualizations,
118 | hooks=hooks,
119 | training_metrics=training_metrics,
120 | evaluation_metrics=evaluation_metrics,
121 | vis_log_frequency=train_vis_freq,
122 | )
123 | return model
124 |
125 |
126 | @hydra.main(config_name="training_config", config_path="../../configs/", version_base="1.1")
127 | def train(config: TrainingConfig):
128 | # Set all relevant random seeds. If `config.seed is None`, the function samples a random value.
129 | # The function takes care of correctly distributing the seed across nodes in multi-node training,
130 | # and assigns each dataloader worker a different random seed.
131 | # IMPORTANTLY, we need to take care not to set a custom `worker_init_fn` function on the
132 | # dataloaders (or take care of worker seeding ourselves).
133 | pl.seed_everything(config.seed, workers=True)
134 |
135 | pm = create_plugin_manager()
136 |
137 | datamodule = build_and_register_datamodule_from_config(config, pm.hook, pm)
138 |
139 | build_and_register_plugins_from_config(config, pm)
140 |
141 | if config.load_checkpoint:
142 | checkpoint_path = hydra.utils.to_absolute_path(config.load_checkpoint)
143 | else:
144 | checkpoint_path = None
145 |
146 |
147 | model = build_model_from_config(config, pm.hook)
148 | if config.load_model_weight:
149 | model_weight_path = hydra.utils.to_absolute_path(config.load_model_weight)
150 | ckpt_weight = torch.load(model_weight_path, map_location=torch.device('cpu'))["state_dict"]
151 | model.load_state_dict(ckpt_weight, strict=False)
152 | else:
153 | model_weight_path = None
154 | callbacks = hydra_zen.instantiate(config.trainer.callbacks, _convert_="all")
155 | callbacks = callbacks if callbacks else []
156 | if config.trainer.logger is not False:
157 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
158 | callbacks.append(lr_monitor)
159 |
160 | trainer: pl.Trainer = hydra_zen.instantiate(config.trainer, callbacks=callbacks, _convert_="all")
161 |
162 | trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)
163 |
164 |
165 | if __name__ == "__main__":
166 | train()
167 |
--------------------------------------------------------------------------------
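The `load_model_weight` branch above differs from `load_checkpoint`: it copies weights non-strictly into a freshly built model, whereas the `ckpt_path` passed to `trainer.fit` resumes the full training state. A minimal sketch of the same non-strict transfer (the checkpoint path is hypothetical, and `model` is assumed to be the module returned by `build_model_from_config`):

import torch

# "state_dict" is the key under which PyTorch Lightning stores module weights in its checkpoints.
state_dict = torch.load("pretrained.ckpt", map_location="cpu")["state_dict"]
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")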
/ocl/config/__init__.py:
--------------------------------------------------------------------------------
1 | from hydra.core.config_store import ConfigStore
2 | from omegaconf import OmegaConf
3 |
4 | from ocl.config import (
5 | conditioning,
6 | datasets,
7 | feature_extractors,
8 | metrics,
9 | neural_networks,
10 | optimizers,
11 | perceptual_groupings,
12 | plugins,
13 | predictor,
14 | utils,
15 | )
16 |
17 | config_store = ConfigStore.instance()
18 |
19 | conditioning.register_configs(config_store)
20 |
21 | datasets.register_configs(config_store)
22 | datasets.register_resolvers(OmegaConf)
23 |
24 | feature_extractors.register_configs(config_store)
25 |
26 | metrics.register_configs(config_store)
27 |
28 | neural_networks.register_configs(config_store)
29 |
30 | optimizers.register_configs(config_store)
31 |
32 | perceptual_groupings.register_configs(config_store)
33 | predictor.register_configs(config_store)
34 |
35 | plugins.register_configs(config_store)
36 | plugins.register_resolvers(OmegaConf)
37 |
38 | utils.register_configs(config_store)
39 | utils.register_resolvers(OmegaConf)
40 |
--------------------------------------------------------------------------------
/ocl/config/conditioning.py:
--------------------------------------------------------------------------------
1 | """Configuration of slot conditionings."""
2 | import dataclasses
3 |
4 | from hydra_zen import builds
5 | from omegaconf import SI
6 |
7 | from ocl import conditioning
8 |
9 |
10 | @dataclasses.dataclass
11 | class ConditioningConfig:
12 | """Base class for conditioning module configuration."""
13 |
14 |
15 | # Unfortunately, we cannot define object_dim as part of the base config class as this prevents using
16 | # required positional arguments in all subclasses. We thus instead pass them here.
17 | LearntConditioningConfig = builds(
18 | conditioning.LearntConditioning,
19 | object_dim=SI("${perceptual_grouping.object_dim}"),
20 | builds_bases=(ConditioningConfig,),
21 | populate_full_signature=True,
22 | )
23 |
24 | RandomConditioningConfig = builds(
25 | conditioning.RandomConditioning,
26 | object_dim=SI("${perceptual_grouping.object_dim}"),
27 | builds_bases=(ConditioningConfig,),
28 | populate_full_signature=True,
29 | )
30 |
31 | RandomConditioningWithQMCSamplingConfig = builds(
32 | conditioning.RandomConditioningWithQMCSampling,
33 | object_dim=SI("${perceptual_grouping.object_dim}"),
34 | builds_bases=(ConditioningConfig,),
35 | populate_full_signature=True,
36 | )
37 |
38 | SlotwiseLearntConditioningConfig = builds(
39 | conditioning.SlotwiseLearntConditioning,
40 | object_dim=SI("${perceptual_grouping.object_dim}"),
41 | builds_bases=(ConditioningConfig,),
42 | populate_full_signature=True,
43 | )
44 |
45 |
46 | def register_configs(config_store):
47 | config_store.store(group="schemas", name="conditioning", node=ConditioningConfig)
48 |
49 | config_store.store(group="conditioning", name="learnt", node=LearntConditioningConfig)
50 | config_store.store(group="conditioning", name="random", node=RandomConditioningConfig)
51 | config_store.store(
52 | group="conditioning",
53 | name="random_with_qmc_sampling",
54 | node=RandomConditioningWithQMCSamplingConfig,
55 | )
56 | config_store.store(
57 | group="conditioning", name="slotwise_learnt_random", node=SlotwiseLearntConditioningConfig
58 | )
59 |
--------------------------------------------------------------------------------
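All four conditioning configs delegate `object_dim` to the `${perceptual_grouping.object_dim}` interpolation, so the value is filled in when the full training config is composed rather than set here. A quick way to inspect the generated node (the YAML shown in the comments is approximate):

import hydra_zen

from ocl.config.conditioning import LearntConditioningConfig

print(hydra_zen.to_yaml(LearntConditioningConfig))
# _target_: ocl.conditioning.LearntConditioning
# object_dim: ${perceptual_grouping.object_dim}
# ... (remaining LearntConditioning constructor arguments follow)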
/ocl/config/datasets.py:
--------------------------------------------------------------------------------
1 | """Register all dataset related configs."""
2 | import dataclasses
3 | import os
4 |
5 | import yaml
6 | from hydra.utils import to_absolute_path
7 | from hydra_zen import builds
8 |
9 | from ocl import datasets
10 |
11 |
12 | def get_region():
13 | """Determine the region this EC2 instance is running in.
14 |
15 | Returns None if not running on an EC2 instance.
16 | """
17 | import requests
18 |
19 | try:
20 | r = requests.get(
21 | "http://169.254.169.254/latest/dynamic/instance-identity/document", timeout=0.5
22 | )
23 | response_json = r.json()
24 | return response_json.get("region")
25 | except Exception:
26 | # Not running on an ec2 instance.
27 | return None
28 |
29 |
30 | # Determine the region name and select the bucket accordingly.
31 | AWS_REGION = get_region()
32 | if AWS_REGION in ["us-east-2", "us-west-2", "eu-west-1"]:
33 | # Select bucket in same region.
34 | DEFAULT_S3_PATH = f"s3://object-centric-datasets-{AWS_REGION}"
35 |     # Example commands for mirroring a dataset locally from the us-west-2 bucket:
36 |     #   aws s3 cp --recursive s3://object-centric-datasets-us-west-2/clevr_with_masks_new_splits ./clevr_with_masks_new_splits
37 |     #   aws s3 cp --recursive s3://object-centric-datasets-us-west-2/movi_e/ movi_e
38 |     # (use `aws s3 ls s3://object-centric-datasets-us-west-2/` to list the available datasets)
39 | else:
40 | # Use MRAP to find closest bucket.
41 | DEFAULT_S3_PATH = "s3://arn:aws:s3::436622332146:accesspoint/m6p4hmmybeu97.mrap"
42 |
43 |
44 | @dataclasses.dataclass
45 | class DataModuleConfig:
46 | """Base class for PyTorch Lightning DataModules.
47 |
48 | This class does not actually do anything but ensures that datasets behave like pytorch lightning
49 | datamodules.
50 | """
51 |
52 |
53 | def dataset_prefix(path):
54 | prefix = os.environ.get("DATASET_PREFIX")
55 | if prefix:
56 | return f"{prefix}/{path}"
57 |     # Otherwise stream the data from S3 (same-region bucket or multi-region access point).
58 | return f"pipe:aws s3 cp --quiet {DEFAULT_S3_PATH}/{path} -"
59 |
60 |
61 | def read_yaml(path):
62 | with open(to_absolute_path(path), "r") as f:
63 | return yaml.safe_load(f)
64 |
65 |
66 | WebdatasetDataModuleConfig = builds(
67 | datasets.WebdatasetDataModule, populate_full_signature=True, builds_bases=(DataModuleConfig,)
68 | )
69 | DummyDataModuleConfig = builds(
70 | datasets.DummyDataModule, populate_full_signature=True, builds_bases=(DataModuleConfig,)
71 | )
72 |
73 |
74 | def register_configs(config_store):
75 | config_store.store(group="schemas", name="dataset", node=DataModuleConfig)
76 | config_store.store(group="dataset", name="webdataset", node=WebdatasetDataModuleConfig)
77 | config_store.store(group="dataset", name="dummy_dataset", node=DummyDataModuleConfig)
78 |
79 |
80 | def register_resolvers(omegaconf):
81 | omegaconf.register_new_resolver("dataset_prefix", dataset_prefix)
82 | omegaconf.register_new_resolver("read_yaml", read_yaml)
83 |
--------------------------------------------------------------------------------
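The `dataset_prefix` resolver either prepends a local `DATASET_PREFIX` or falls back to streaming the shard from S3 through a `pipe:` URL that webdataset understands. A small sketch (the shard path is illustrative):

import os

from ocl.config.datasets import dataset_prefix

os.environ.pop("DATASET_PREFIX", None)
print(dataset_prefix("movi_c/train/shard-000000.tar"))
# pipe:aws s3 cp --quiet <DEFAULT_S3_PATH>/movi_c/train/shard-000000.tar -

os.environ["DATASET_PREFIX"] = "/data"
print(dataset_prefix("movi_c/train/shard-000000.tar"))
# /data/movi_c/train/shard-000000.tar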
/ocl/config/feature_extractors.py:
--------------------------------------------------------------------------------
1 | """Configurations for feature extractors."""
2 | import dataclasses
3 |
4 | from hydra_zen import make_custom_builds_fn
5 |
6 | from ocl import feature_extractors
7 |
8 |
9 | @dataclasses.dataclass
10 | class FeatureExtractorConfig:
11 | """Base class for PyTorch Lightning DataModules.
12 |
13 | This class does not actually do anything but ensures that feature extractors give outputs of
14 | a defined structure.
15 | """
16 |
17 | pass
18 |
19 |
20 | builds_feature_extractor = make_custom_builds_fn(
21 | populate_full_signature=True,
22 | )
23 |
24 | TimmFeatureExtractorConfig = builds_feature_extractor(
25 | feature_extractors.TimmFeatureExtractor,
26 | builds_bases=(FeatureExtractorConfig,),
27 | )
28 | SlotAttentionFeatureExtractorConfig = builds_feature_extractor(
29 | feature_extractors.SlotAttentionFeatureExtractor,
30 | builds_bases=(FeatureExtractorConfig,),
31 | )
32 | DVAEFeatureExtractorConfig = builds_feature_extractor(
33 | feature_extractors.DVAEFeatureExtractor,
34 | builds_bases=(FeatureExtractorConfig,),
35 | )
36 | SAViFeatureExtractorConfig = builds_feature_extractor(
37 | feature_extractors.SAViFeatureExtractor,
38 | builds_bases=(FeatureExtractorConfig,),
39 | )
40 |
41 |
42 | def register_configs(config_store):
43 | config_store.store(group="schemas", name="feature_extractor", node=FeatureExtractorConfig)
44 | config_store.store(
45 | group="feature_extractor",
46 | name="timm_model",
47 | node=TimmFeatureExtractorConfig,
48 | )
49 | config_store.store(
50 | group="feature_extractor",
51 | name="slot_attention",
52 | node=SlotAttentionFeatureExtractorConfig,
53 | )
54 | config_store.store(
55 | group="feature_extractor",
56 | name="savi",
57 | node=SAViFeatureExtractorConfig,
58 | )
59 | config_store.store(
60 | group="feature_extractor",
61 | name="dvae",
62 | node=DVAEFeatureExtractorConfig,
63 | )
64 |
--------------------------------------------------------------------------------
/ocl/config/metrics.py:
--------------------------------------------------------------------------------
1 | """Register metric related configs."""
2 | import dataclasses
3 |
4 | from hydra_zen import builds, make_custom_builds_fn
5 |
6 | from ocl import metrics
7 | @dataclasses.dataclass
8 | class MetricConfig:
9 | """Base class for metrics."""
10 | pass
11 |
12 |
13 | builds_metric = make_custom_builds_fn(
14 | populate_full_signature=True,
15 | )
16 |
17 | TensorStatisticConfig = builds_metric(metrics.TensorStatistic, builds_bases=(MetricConfig,))
18 |
19 |
20 | TorchmetricsWrapperConfig = builds_metric(metrics.TorchmetricsWrapper, builds_bases=(MetricConfig,))
21 |
22 |
23 | PurityMetricConfig = builds_metric(
24 | metrics.MutualInfoAndPairCounting,
25 | metric_name="purity",
26 | builds_bases=(MetricConfig,),
27 | )
28 | PrecisionMetricConfig = builds_metric(
29 | metrics.MutualInfoAndPairCounting,
30 | metric_name="precision",
31 | builds_bases=(MetricConfig,),
32 | )
33 | RecallMetricConfig = builds_metric(
34 | metrics.MutualInfoAndPairCounting,
35 | metric_name="recall",
36 | builds_bases=(MetricConfig,),
37 | )
38 | F1MetricConfig = builds_metric(
39 | metrics.MutualInfoAndPairCounting,
40 | metric_name="f1",
41 | builds_bases=(MetricConfig,),
42 | )
43 | AMIMetricConfig = builds_metric(
44 | metrics.MutualInfoAndPairCounting,
45 | metric_name="ami",
46 | builds_bases=(MetricConfig,),
47 | )
48 | NMIMetricConfig = builds_metric(
49 | metrics.MutualInfoAndPairCounting,
50 | metric_name="nmi",
51 | builds_bases=(MetricConfig,),
52 | )
53 | ARISklearnMetricConfig = builds_metric(
54 | metrics.MutualInfoAndPairCounting,
55 | metric_name="ari_sklearn",
56 | builds_bases=(MetricConfig,),
57 | )
58 |
59 | ARIMetricConfig = builds_metric(metrics.ARIMetric, builds_bases=(MetricConfig,))
60 | PatchARIMetricConfig = builds_metric(
61 | metrics.PatchARIMetric,
62 | builds_bases=(MetricConfig,),
63 | )
64 | UnsupervisedMaskIoUMetricConfig = builds_metric(
65 | metrics.UnsupervisedMaskIoUMetric,
66 | builds_bases=(MetricConfig,),
67 | )
68 | MOTMetricConfig = builds_metric(
69 | metrics.MOTMetric,
70 | builds_bases=(MetricConfig,),
71 | )
72 | MaskCorLocMetricConfig = builds_metric(
73 | metrics.UnsupervisedMaskIoUMetric,
74 | matching="best_overlap",
75 | correct_localization=True,
76 | builds_bases=(MetricConfig,),
77 | )
78 | AverageBestOverlapMetricConfig = builds_metric(
79 | metrics.UnsupervisedMaskIoUMetric,
80 | matching="best_overlap",
81 | builds_bases=(MetricConfig,),
82 | )
83 | BestOverlapObjectRecoveryMetricConfig = builds_metric(
84 | metrics.UnsupervisedMaskIoUMetric,
85 | matching="best_overlap",
86 | compute_discovery_fraction=True,
87 | builds_bases=(MetricConfig,),
88 | )
89 | UnsupervisedBboxIoUMetricConfig = builds_metric(
90 | metrics.UnsupervisedBboxIoUMetric,
91 | builds_bases=(MetricConfig,),
92 | )
93 | BboxCorLocMetricConfig = builds_metric(
94 | metrics.UnsupervisedBboxIoUMetric,
95 | matching="best_overlap",
96 | correct_localization=True,
97 | builds_bases=(MetricConfig,),
98 | )
99 | BboxRecallMetricConfig = builds_metric(
100 | metrics.UnsupervisedBboxIoUMetric,
101 | matching="best_overlap",
102 | compute_discovery_fraction=True,
103 | builds_bases=(MetricConfig,),
104 | )
105 |
106 |
107 | DatasetSemanticMaskIoUMetricConfig = builds_metric(metrics.DatasetSemanticMaskIoUMetric)
108 |
109 | SklearnClusteringConfig = builds(
110 | metrics.SklearnClustering,
111 | populate_full_signature=True,
112 | )
113 |
114 |
115 | def register_configs(config_store):
116 | config_store.store(group="metrics", name="tensor_statistic", node=TensorStatisticConfig)
117 |
118 | config_store.store(group="metrics", name="torchmetric", node=TorchmetricsWrapperConfig)
119 | config_store.store(group="metrics", name="ami_metric", node=AMIMetricConfig)
120 | config_store.store(group="metrics", name="nmi_metric", node=NMIMetricConfig)
121 | config_store.store(group="metrics", name="ari_sklearn_metric", node=ARISklearnMetricConfig)
122 | config_store.store(group="metrics", name="purity_metric", node=PurityMetricConfig)
123 | config_store.store(group="metrics", name="precision_metric", node=PrecisionMetricConfig)
124 | config_store.store(group="metrics", name="recall_metric", node=RecallMetricConfig)
125 | config_store.store(group="metrics", name="f1_metric", node=F1MetricConfig)
126 |
127 | config_store.store(group="metrics", name="mot_metric", node=MOTMetricConfig)
128 | config_store.store(group="metrics", name="ari_metric", node=ARIMetricConfig)
129 | config_store.store(group="metrics", name="patch_ari_metric", node=PatchARIMetricConfig)
130 | config_store.store(
131 | group="metrics", name="unsupervised_mask_iou_metric", node=UnsupervisedMaskIoUMetricConfig
132 | )
133 | config_store.store(group="metrics", name="mask_corloc_metric", node=MaskCorLocMetricConfig)
134 | config_store.store(
135 | group="metrics", name="average_best_overlap_metric", node=AverageBestOverlapMetricConfig
136 | )
137 | config_store.store(
138 | group="metrics",
139 | name="best_overlap_object_recovery_metric",
140 | node=BestOverlapObjectRecoveryMetricConfig,
141 | )
142 | config_store.store(
143 | group="metrics", name="unsupervised_bbox_iou_metric", node=UnsupervisedBboxIoUMetricConfig
144 | )
145 | config_store.store(group="metrics", name="bbox_corloc_metric", node=BboxCorLocMetricConfig)
146 | config_store.store(group="metrics", name="bbox_recall_metric", node=BboxRecallMetricConfig)
147 |
148 | config_store.store(
149 | group="metrics",
150 | name="dataset_semantic_mask_iou",
151 | node=DatasetSemanticMaskIoUMetricConfig,
152 | )
153 | config_store.store(
154 | group="clustering",
155 | name="sklearn_clustering",
156 | node=SklearnClusteringConfig,
157 | )
158 |
--------------------------------------------------------------------------------
/ocl/config/neural_networks.py:
--------------------------------------------------------------------------------
1 | """Configs for neural networks."""
2 | import omegaconf
3 | from hydra_zen import builds
4 |
5 | from ocl import neural_networks
6 |
7 | MLPBuilderConfig = builds(
8 | neural_networks.build_mlp,
9 | features=omegaconf.MISSING,
10 | zen_partial=True,
11 | populate_full_signature=True,
12 | )
13 | TransformerEncoderBuilderConfig = builds(
14 | neural_networks.build_transformer_encoder,
15 | n_layers=omegaconf.MISSING,
16 | n_heads=omegaconf.MISSING,
17 | zen_partial=True,
18 | populate_full_signature=True,
19 | )
20 | TransformerDecoderBuilderConfig = builds(
21 | neural_networks.build_transformer_decoder,
22 | n_layers=omegaconf.MISSING,
23 | n_heads=omegaconf.MISSING,
24 | zen_partial=True,
25 | populate_full_signature=True,
26 | )
27 |
28 |
29 | def register_configs(config_store):
30 | config_store.store(group="neural_networks", name="mlp", node=MLPBuilderConfig)
31 | config_store.store(
32 | group="neural_networks", name="transformer_encoder", node=TransformerEncoderBuilderConfig
33 | )
34 | config_store.store(
35 | group="neural_networks", name="transformer_decoder", node=TransformerDecoderBuilderConfig
36 | )
37 |
--------------------------------------------------------------------------------
/ocl/config/optimizers.py:
--------------------------------------------------------------------------------
1 | """Pytorch optimizers."""
2 | import dataclasses
3 |
4 | import torch.optim
5 | from hydra_zen import make_custom_builds_fn
6 |
7 |
8 | @dataclasses.dataclass
9 | class OptimizerConfig:
10 | pass
11 |
12 |
13 | # TODO(hornmax): We cannot automatically extract type information from the torch SGD implementation,
14 | # thus we define it manually here.
15 | @dataclasses.dataclass
16 | class SGDConfig(OptimizerConfig):
17 | learning_rate: float
18 | momentum: float = 0.0
19 | dampening: float = 0.0
20 |     nesterov: bool = False
21 | _target_: str = "hydra_zen.funcs.zen_processing"
22 | _zen_target: str = "torch.optim.SGD"
23 | _zen_partial: bool = True
24 |
25 |
26 | pbuilds = make_custom_builds_fn(
27 | zen_partial=True,
28 | populate_full_signature=True,
29 | )
30 |
31 | AdamConfig = pbuilds(torch.optim.Adam, builds_bases=(OptimizerConfig,))
32 | AdamWConfig = pbuilds(torch.optim.AdamW, builds_bases=(OptimizerConfig,))
33 |
34 |
35 | def register_configs(config_store):
36 | config_store.store(group="optimizers", name="sgd", node=SGDConfig)
37 | config_store.store(group="optimizers", name="adam", node=AdamConfig)
38 | config_store.store(group="optimizers", name="adamw", node=AdamWConfig)
39 |
--------------------------------------------------------------------------------
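Because these configs are built with `zen_partial=True`, instantiating them yields a `functools.partial` over the optimizer class; the model parameters are supplied later (typically from a `configure_optimizers` hook). A minimal sketch:

import hydra_zen
import torch

from ocl.config.optimizers import AdamConfig

partial_adam = hydra_zen.instantiate(AdamConfig, lr=4e-4)     # functools.partial(torch.optim.Adam, lr=0.0004)
optimizer = partial_adam(torch.nn.Linear(4, 4).parameters())  # parameters supplied at this point
assert isinstance(optimizer, torch.optim.Adam)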
/ocl/config/perceptual_groupings.py:
--------------------------------------------------------------------------------
1 | """Perceptual grouping models."""
2 | import dataclasses
3 |
4 | from hydra_zen import builds
5 |
6 | from ocl import perceptual_grouping
7 |
8 |
9 | @dataclasses.dataclass
10 | class PerceptualGroupingConfig:
11 | """Configuration class of perceptual grouping models."""
12 |
13 |
14 | SlotAttentionConfig = builds(
15 | perceptual_grouping.SlotAttentionGrouping,
16 | builds_bases=(PerceptualGroupingConfig,),
17 | populate_full_signature=True,
18 | )
19 | SlotAttentionGumbelV1Config = builds(
20 | perceptual_grouping.SlotAttentionGroupingGumbelV1,
21 | builds_bases=(PerceptualGroupingConfig,),
22 | populate_full_signature=True,
23 | )
24 |
25 |
26 | def register_configs(config_store):
27 | config_store.store(group="schemas", name="perceptual_grouping", node=PerceptualGroupingConfig)
28 | config_store.store(group="perceptual_grouping", name="slot_attention", node=SlotAttentionConfig)
29 | config_store.store(group="perceptual_grouping", name="slot_attention_gumbel_v1", node=SlotAttentionGumbelV1Config)
30 |
--------------------------------------------------------------------------------
/ocl/config/predictor.py:
--------------------------------------------------------------------------------
1 | """Perceptual grouping models."""
2 | import dataclasses
3 |
4 | from hydra_zen import builds
5 |
6 | from ocl import predictor
7 |
8 |
9 | @dataclasses.dataclass
10 | class PredictorConfig:
11 | """Configuration class of Predictor."""
12 |
13 |
14 | TransitionConfig = builds(
15 | predictor.Predictor,
16 | builds_bases=(PredictorConfig,),
17 | populate_full_signature=True,
18 | )
19 |
20 |
21 | def register_configs(config_store):
22 | config_store.store(group="schemas", name="predictor", node=PredictorConfig)
23 | config_store.store(group="predictor", name="multihead_attention", node=TransitionConfig)
24 |
--------------------------------------------------------------------------------
/ocl/config/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions useful for configuration."""
2 | import ast
3 | from typing import Any, Callable
4 |
5 | from hydra_zen import builds
6 |
7 | from ocl.config.feature_extractors import FeatureExtractorConfig
8 | from ocl.config.perceptual_groupings import PerceptualGroupingConfig
9 | from ocl.distillation import EMASelfDistillation
10 | from ocl.utils.masking import CreateSlotMask
11 | from ocl.utils.routing import Combined, Recurrent
12 | 
13 |
14 | def lambda_string_to_function(function_string: str) -> Callable[..., Any]:
15 | """Convert string of the form "lambda x: x" into a callable Python function."""
16 |     # This is a bit hacky, but it verifies that the input has valid syntax and defines a lambda
17 |     # function before the string is passed to `eval`.
18 | parsed = ast.parse(function_string)
19 | is_lambda = isinstance(parsed.body[0], ast.Expr) and isinstance(parsed.body[0].value, ast.Lambda)
20 | if not is_lambda:
21 | raise ValueError(f"'{function_string}' is not a valid lambda definition.")
22 |
23 | return eval(function_string)
24 |
25 |
26 | class ConfigDefinedLambda:
27 | """Lambda function defined in the config.
28 |
29 | This allows lambda functions defined in the config to be pickled.
30 | """
31 |
32 | def __init__(self, function_string: str):
33 | self.__setstate__(function_string)
34 |
35 | def __getstate__(self) -> str:
36 | return self.function_string
37 |
38 | def __setstate__(self, function_string: str):
39 | self.function_string = function_string
40 | self._fn = lambda_string_to_function(function_string)
41 |
42 | def __call__(self, *args, **kwargs):
43 | return self._fn(*args, **kwargs)
44 |
45 |
46 | def eval_lambda(function_string, *args):
47 | lambda_fn = lambda_string_to_function(function_string)
48 | return lambda_fn(*args)
49 |
50 |
51 | FunctionConfig = builds(ConfigDefinedLambda, populate_full_signature=True)
52 |
53 | # Inherit from all so it can be used in place of any module.
54 | CombinedConfig = builds(
55 | Combined,
56 | populate_full_signature=True,
57 | builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig),
58 | )
59 | RecurrentConfig = builds(
60 | Recurrent,
61 | populate_full_signature=True,
62 | builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig),
63 | )
64 | CreateSlotMaskConfig = builds(CreateSlotMask, populate_full_signature=True)
65 |
66 |
67 | EMASelfDistillationConfig = builds(
68 | EMASelfDistillation,
69 | populate_full_signature=True,
70 | builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig),
71 | )
72 |
73 |
74 | def make_slice(expr):
75 | if isinstance(expr, int):
76 | return expr
77 |
78 |     pieces = [int(s) if s else None for s in expr.split(":")]
79 | if len(pieces) == 1:
80 | return slice(pieces[0], pieces[0] + 1)
81 | else:
82 | return slice(*pieces)
83 |
84 |
85 | def slice_string(string: str, split_char: str, slice_str: str) -> str:
86 | """Split a string according to a split_char and slice.
87 |
88 | If the output contains more than one element, join these using the split char again.
89 | """
90 | sl = make_slice(slice_str)
91 | res = string.split(split_char)[sl]
92 | if isinstance(res, list):
93 | res = split_char.join(res)
94 | return res
95 |
96 |
97 | def register_configs(config_store):
98 | config_store.store(group="schemas", name="lambda_fn", node=FunctionConfig)
99 | config_store.store(group="utils", name="combined", node=CombinedConfig)
100 | config_store.store(group="utils", name="selfdistillation", node=EMASelfDistillationConfig)
101 | config_store.store(group="utils", name="recurrent", node=RecurrentConfig)
102 | config_store.store(group="utils", name="create_slot_mask", node=CreateSlotMaskConfig)
103 |
104 |
105 | def register_resolvers(omegaconf):
106 | omegaconf.register_new_resolver("lambda_fn", ConfigDefinedLambda)
107 | omegaconf.register_new_resolver("eval_lambda", eval_lambda)
108 | omegaconf.register_new_resolver("slice", slice_string)
109 |
--------------------------------------------------------------------------------
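The `slice` and `eval_lambda` resolvers registered above are small string utilities used inside interpolations. Two worked calls (the inputs are illustrative):

from ocl.config.utils import eval_lambda, slice_string

print(slice_string("coco_feat_rec_dino_base16", "_", "0:2"))  # "coco_feat"
print(eval_lambda("lambda x: x * 2", 8))                      # 16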
/ocl/consistency.py:
--------------------------------------------------------------------------------
1 | """Modules to compute the IoU matching cost and solve the corresponding LSAP."""
2 | import numpy as np
3 | import torch
4 | from scipy.optimize import linear_sum_assignment
5 | from torch import nn
6 |
7 |
8 | class HungarianMatcher(nn.Module):
9 | """This class computes an assignment between the targets and the predictions of the network."""
10 |
11 | @torch.no_grad()
12 | def forward(self, mask_preds, mask_targets):
13 | """Performs the matching.
14 |
15 | Params:
16 | mask_preds: Tensor of dim [batch_size, n_objects, N, N] with the predicted masks
17 | mask_targets: Tensor of dim [batch_size, n_objects, N, N]
18 | with the target masks from another augmentation
19 |
20 | Returns:
21 | A list of size batch_size, containing tuples of (index_i, index_j) where:
22 | - index_i is the indices of the selected predictions
23 | - index_j is the indices of the corresponding selected targets
24 | """
25 | bs, n_objects, _, _ = mask_preds.shape
26 |         # Compute the IoU cost between masks
27 | cost_iou = -get_iou_matrix(mask_preds, mask_targets)
28 | cost_iou = cost_iou.reshape(bs, n_objects, bs, n_objects).cpu()
29 | self.costs = torch.stack([cost_iou[i, :, i, :][None] for i in range(bs)])
30 | indices = [linear_sum_assignment(c[0]) for c in self.costs]
31 | return torch.as_tensor(np.array(indices))
32 |
33 |
34 | def get_iou_matrix(preds, targets):
35 |
36 | bs, n_objects, H, W = targets.shape
37 | targets = targets.reshape(bs * n_objects, H * W).float()
38 | preds = preds.reshape(bs * n_objects, H * W).float()
39 |
40 | intersection = torch.matmul(targets, preds.t())
41 | targets_area = targets.sum(dim=1).view(1, -1)
42 | preds_area = preds.sum(dim=1).view(1, -1)
43 | union = (targets_area.t() + preds_area) - intersection
44 |
45 | return torch.where(
46 | union == 0,
47 | torch.tensor(0.0, device=targets.device),
48 | intersection / union,
49 | )
50 |
--------------------------------------------------------------------------------
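A toy usage sketch of the matcher above: for random binary masks, the returned tensor has shape `(batch_size, 2, n_objects)` and holds, per sample, the row (prediction) and column (target) indices of the IoU-optimal assignment:

import torch

from ocl.consistency import HungarianMatcher

matcher = HungarianMatcher()
preds = (torch.rand(2, 4, 16, 16) > 0.5).float()    # (batch_size, n_objects, N, N)
targets = (torch.rand(2, 4, 16, 16) > 0.5).float()
indices = matcher(preds, targets)
print(indices.shape)                                # torch.Size([2, 2, 4])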
/ocl/distillation.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Any, Dict, List, Optional, Union
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from ocl import scheduling
8 | from ocl.utils.routing import Combined
9 | from ocl.utils.trees import get_tree_element
10 |
11 |
12 | class EMASelfDistillation(nn.Module):
13 | def __init__(
14 | self,
15 | student: Union[nn.Module, Dict[str, nn.Module]],
16 | schedule: scheduling.HPScheduler,
17 | student_remapping: Optional[Dict[str, str]] = None,
18 | teacher_remapping: Optional[Dict[str, str]] = None,
19 | ):
20 | super().__init__()
21 | # Do this for convenience to reduce crazy amount of nesting.
22 | if isinstance(student, dict):
23 | student = Combined(student)
24 | if student_remapping is None:
25 | student_remapping = {}
26 | if teacher_remapping is None:
27 | teacher_remapping = {}
28 |
29 | self.student = student
30 | self.teacher = copy.deepcopy(student)
31 | self.schedule = schedule
32 | self.student_remapping = {key: value.split(".") for key, value in student_remapping.items()}
33 | self.teacher_remapping = {key: value.split(".") for key, value in teacher_remapping.items()}
34 |
35 | def build_input_dict(self, inputs, remapping):
36 | if not remapping:
37 | return inputs
38 |         # This allows us to bring the initial input and previous_output into a similar format.
39 | output_dict = {}
40 | for output_path, input_path in remapping.items():
41 | source = get_tree_element(inputs, input_path)
42 |
43 | output_path = output_path.split(".")
44 | cur_search = output_dict
45 | for path_part in output_path[:-1]:
46 | # Iterate along path and create nodes that do not exist yet.
47 | try:
48 | # Get element prior to last.
49 | cur_search = get_tree_element(cur_search, [path_part])
50 | except ValueError:
51 | # Element does not yet exist.
52 | cur_search[path_part] = {}
53 | cur_search = cur_search[path_part]
54 |
55 | cur_search[output_path[-1]] = source
56 | return output_dict
57 |
58 | def forward(self, inputs: Dict[str, Any]):
59 | if self.training:
60 | with torch.no_grad():
61 | m = self.schedule(inputs["global_step"]) # momentum parameter
62 | for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
63 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
64 |
65 | # prefix variable similar to combined module.
66 | prefix: List[str]
67 | if "prefix" in inputs.keys():
68 | prefix = inputs["prefix"]
69 | else:
70 | prefix = []
71 | inputs["prefix"] = prefix
72 |
73 | outputs = get_tree_element(inputs, prefix)
74 |
75 | # Forward pass student.
76 | prefix.append("student")
77 | outputs["student"] = {}
78 | student_inputs = self.build_input_dict(inputs, self.student_remapping)
79 | outputs["student"] = self.student(inputs={**inputs, **student_inputs})
80 |         # Teacher and student share the same code, so output paths also need to be the same. To
81 |         # ensure this, we save the student outputs and run the teacher as if it were the student.
82 | student_output = outputs["student"]
83 |
84 | # Forward pass teacher, but pretending to be student.
85 | outputs["student"] = {}
86 | teacher_inputs = self.build_input_dict(inputs, self.teacher_remapping)
87 |
88 | with torch.no_grad():
89 | outputs["teacher"] = self.teacher(inputs={**inputs, **teacher_inputs})
90 | prefix.pop()
91 |
92 | # Set correct outputs again.
93 | outputs["student"] = student_output
94 |
95 | return outputs
96 |
--------------------------------------------------------------------------------
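The core of `EMASelfDistillation.forward` is the exponential-moving-average update of the teacher from the student, theta_teacher <- m * theta_teacher + (1 - m) * theta_student. A standalone sketch of just that update (the momentum value is illustrative; in the module it comes from the schedule):

import copy

import torch
from torch import nn

student = nn.Linear(8, 8)
teacher = copy.deepcopy(student)

m = 0.99  # momentum
with torch.no_grad():
    for param_q, param_k in zip(student.parameters(), teacher.parameters()):
        param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)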
/ocl/hooks.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Tuple
2 |
3 | import webdataset
4 | from pluggy import HookimplMarker, HookspecMarker
5 |
6 | from ocl.combined_model import CombinedModel
7 |
8 | hook_specification = HookspecMarker("ocl")
9 | hook_implementation = HookimplMarker("ocl")
10 |
11 |
12 | class FakeHooks:
13 | """Class that mimics the behavior of the plugin manager hooks property."""
14 |
15 | def __getattr__(self, attribute):
16 | """Return a fake hook handler for any attribute query."""
17 |
18 | def fake_hook_handler(*args, **kwargs):
19 | return tuple()
20 |
21 | return fake_hook_handler
22 |
23 |
24 | # @transform_hooks
25 | # def input_dependencies() -> Tuple[str, ...]:
26 | # """Provide list of variables that are required for the plugin to function."""
27 | #
28 | #
29 | # @transform_hooks
30 | # def provided_inputs() -> Tuple[str, ...]:
31 | # """Provide list of variables that are provided by the plugin."""
32 |
33 |
34 | @hook_specification
35 | def training_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
36 | """Provide a transformation which processes a component of a webdataset pipeline."""
37 |
38 |
39 | @hook_specification
40 | def training_batch_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
41 | """Provide a transformation which processes a batched component of a webdataset pipeline."""
42 |
43 |
44 | @hook_specification
45 | def training_fields() -> Tuple[str]:
46 | """Provide list of fields that are required to be decoded during training."""
47 |
48 |
49 | @hook_specification
50 | def evaluation_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
51 | """Provide a transformation which processes a component of a webdataset pipeline."""
52 |
53 |
54 | @hook_specification
55 | def evaluation_batch_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
56 | """Provide a transformation which processes a batched component of a webdataset pipeline."""
57 |
58 |
59 | @hook_specification
60 | def evaluation_fields() -> Tuple[str]:
61 | """Provide list of fields that are required to be decoded during evaluation."""
62 |
63 |
64 | @hook_specification
65 | def configure_optimizers(model: CombinedModel) -> Dict[str, Any]:
66 | """Return optimizers in the format of pytorch lightning."""
67 |
68 |
69 | @hook_specification
70 | def on_train_start(model: CombinedModel) -> None:
71 | """Hook called when starting training."""
72 |
73 |
74 | @hook_specification
75 | def on_train_epoch_start(model: CombinedModel) -> None:
76 | """Hook called when starting training epoch."""
77 |
--------------------------------------------------------------------------------
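Plugins implement these specifications by decorating methods with `hook_implementation` and registering the object on the plugin manager created in the CLI scripts. A minimal, hypothetical plugin:

from ocl.hooks import hook_implementation


class RequireImageField:
    """Hypothetical plugin declaring that the `image` field must be decoded."""

    @hook_implementation
    def training_fields(self):
        return ("image",)

    @hook_implementation
    def evaluation_fields(self):
        return ("image",)


# pm = create_plugin_manager(); pm.register(RequireImageField())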
/ocl/matching.py:
--------------------------------------------------------------------------------
1 | """Methods for matching between sets of elements."""
2 | from typing import Tuple, Type
3 |
4 | import numpy as np
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 | from torchtyping import TensorType
8 |
9 | # Avoid errors due to flake:
10 | batch_size = None
11 | n_elements = None
12 |
13 | CostMatrix = Type[TensorType["batch_size", "n_elements", "n_elements"]]
14 | AssignmentMatrix = Type[TensorType["batch_size", "n_elements", "n_elements"]]
15 | CostVector = Type[TensorType["batch_size"]]
16 |
17 |
18 | class Matcher(torch.nn.Module):
19 | """Matcher base class to define consistent interface."""
20 |
21 | def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]:
22 | pass
23 |
24 |
25 | class CPUHungarianMatcher(Matcher):
26 | """Implementaiton of a cpu hungarian matcher using scipy.optimize.linear_sum_assignment."""
27 |
28 | def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]:
29 | X = torch.zeros_like(C)
30 | C_cpu: np.ndarray = C.detach().cpu().numpy()
31 | for i, cost_matrix in enumerate(C_cpu):
32 | row_ind, col_ind = linear_sum_assignment(cost_matrix)
33 | X[i][row_ind, col_ind] = 1.0
34 | return X, (C * X).sum(dim=(1, 2))
35 |
--------------------------------------------------------------------------------
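A quick sketch of `CPUHungarianMatcher`: given a batched cost matrix, it returns a binary assignment matrix of the same shape together with the summed cost of the selected entries:

import torch

from ocl.matching import CPUHungarianMatcher

matcher = CPUHungarianMatcher()
C = torch.rand(2, 3, 3)            # batch of 3x3 cost matrices
X, total_cost = matcher(C)
print(X.shape, total_cost.shape)   # torch.Size([2, 3, 3]) torch.Size([2])
assert torch.all(X.sum(dim=-1) == 1) and torch.all(X.sum(dim=-2) == 1)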
/ocl/memory_rollout.py:
--------------------------------------------------------------------------------
1 | """Memory roll-out module, following GPT-2 architecture.
2 |
3 | References:
4 | 1) minGPT by Andrej Karpathy:
5 | https://github.com/karpathy/minGPT/tree/master/mingpt
6 | 2) the official GPT-2 TensorFlow implementation released by OpenAI:
7 | https://github.com/openai/gpt-2/blob/master/src/model.py
8 | """
9 |
10 | import math
11 |
12 | import torch
13 | from torch import nn
14 |
15 | # -----------------------------------------------------------------------------
16 |
17 |
18 | class GELU(nn.Module):
19 | def forward(self, x):
20 | return (
21 | 0.5
22 | * x
23 | * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
24 | )
25 |
26 |
27 | class Block(nn.Module):
28 | """One GPT-2 decoder block, consists of a Masked Self-Attn and a FFN."""
29 |
30 | def __init__(self, n_embd, n_heads, dropout_rate):
31 | super().__init__()
32 | self.ln_1 = nn.LayerNorm(n_embd)
33 | self.attn = nn.MultiheadAttention(n_embd, n_heads, batch_first=True)
34 | self.ln_2 = nn.LayerNorm(n_embd)
35 | self.mlp = nn.ModuleDict(
36 | dict(
37 | c_fc=nn.Linear(n_embd, 4 * n_embd),
38 | c_proj=nn.Linear(4 * n_embd, n_embd),
39 | act=GELU(),
40 | dropout=nn.Dropout(dropout_rate),
41 | )
42 | )
43 | m = self.mlp
44 | self.ffn = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
45 |
46 | def forward(self, x, causal_mask):
47 | att, att_weights = self.attn(
48 | query=self.ln_1(x), key=self.ln_1(x), value=self.ln_1(x), attn_mask=causal_mask
49 | )
50 |
51 | x = x + att
52 | x = x + self.ffn(self.ln_2(x))
53 | return x, att_weights
54 |
55 |
56 | class GPT(nn.Module):
57 | """Memory roll-out GPT."""
58 |
59 | def __init__(
60 | self, buffer_len, n_layer, n_head, n_embd, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0
61 | ):
62 | super().__init__()
63 | self.buffer_len = buffer_len
64 | self.n_layer = n_layer
65 | self.n_head = n_head
66 | self.n_embd = n_embd
67 | self.embd_pdrop = embd_pdrop
68 | self.resid_pdrop = resid_pdrop
69 | self.attn_pdrop = attn_pdrop
70 |
71 | self.transformer = nn.ModuleDict(
72 | dict(
73 | wte=nn.Linear(self.n_embd, self.n_embd, bias=False),
74 | wpe=nn.Embedding(self.buffer_len, self.n_embd),
75 | drop=nn.Dropout(self.embd_pdrop),
76 | h=nn.ModuleList(
77 | [Block(self.n_embd, self.n_head, self.resid_pdrop) for _ in range(self.n_layer)]
78 | ),
79 | ln_f=nn.LayerNorm(self.n_embd),
80 | )
81 | )
82 | # roll out to the same dimension
83 | self.roll_out_head = nn.Linear(self.n_embd, self.n_embd, bias=False)
84 |
85 | # init all weights
86 | self.apply(self._init_weights)
87 | for pn, p in self.named_parameters():
88 | if pn.endswith("c_proj.weight"):
89 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layer))
90 |
91 |         # report number of parameters (the roll_out_head parameters are not counted here)
92 | n_params = sum(p.numel() for p in self.transformer.parameters())
93 | print("number of parameters: %.2fM" % (n_params / 1e6,))
94 |
95 | def _init_weights(self, module):
96 | if isinstance(module, nn.Linear):
97 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
98 | if module.bias is not None:
99 | torch.nn.init.zeros_(module.bias)
100 | elif isinstance(module, nn.Embedding):
101 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
102 | elif isinstance(module, nn.LayerNorm):
103 | torch.nn.init.zeros_(module.bias)
104 | torch.nn.init.ones_(module.weight)
105 |
106 | def forward(self, mem, mem_table, targets=None):
107 | device = mem.device
108 | b, t, n, d = mem.shape
109 |
110 |         # reshape to merge the batch and num_buffer dimensions
111 | mem = mem.permute(0, 2, 1, 3).reshape(b * n, t, d)
112 | mem_table = mem_table.view(b * n, -1)
113 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
114 |
115 | tok_emb = self.transformer.wte(mem) # token embeddings of shape (b, t, n_embd)
116 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
117 | x = self.transformer.drop(tok_emb + pos_emb)
118 |
119 |         # Create causal attention masks: positions beyond the occupied length may not attend to one another.
120 |         # TODO: verify the correctness of this masking.
121 | causal_masks = []
122 | for idx in range(b * n):
123 | occupied_len = mem_table[idx].cpu().numpy().astype(int)[0]
124 | if occupied_len == 0:
125 | occupied_len = 1
126 | # causal_mask = torch.tril(torch.ones(self.buffer_len, self.buffer_len).to(device)).view(
127 | # 1, self.buffer_len, self.buffer_len
128 | # )
129 | causal_mask = (
130 | torch.zeros(self.buffer_len, self.buffer_len)
131 | .to(device)
132 | .view(1, self.buffer_len, self.buffer_len)
133 | )
134 | causal_mask[:, occupied_len:, occupied_len:] = 1
135 | causal_mask = causal_mask > 0
136 | causal_masks.append(causal_mask)
137 | causal_masks = torch.stack(causal_masks)
138 | causal_masks = causal_masks.repeat(1, self.n_head, 1, 1).view(-1, t, t)
139 |
140 | for block in self.transformer.h:
141 | x, attn_weights = block(x, causal_masks)
142 | x = self.transformer.ln_f(x)
143 | x = self.roll_out_head(x) # [b*n, t, d]
144 |
145 | out = torch.zeros((b * n, d)).to(device)
146 |
147 | for idx in range(b * n):
148 | t_pos = mem_table[idx].cpu().numpy().astype(int)[0]
149 | if t_pos > 0 and t_pos < t:
150 | out[idx] = x[idx, t_pos - 1]
151 | return out.view(b, n, d)
152 |
--------------------------------------------------------------------------------
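A shape-level sketch of the roll-out GPT above: memory has shape `(batch, buffer_len, num_buffers, dim)`, the memory table stores how many entries of each buffer are occupied, and the output is one rolled-out vector per buffer (all sizes below are illustrative):

import torch

from ocl.memory_rollout import GPT

gpt = GPT(buffer_len=8, n_layer=2, n_head=4, n_embd=64)
mem = torch.randn(2, 8, 5, 64)               # (batch, buffer_len, num_buffers, dim)
mem_table = torch.randint(1, 8, (2, 5, 1))   # occupied length of each buffer
out = gpt(mem, mem_table)
print(out.shape)                             # torch.Size([2, 5, 64])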
/ocl/mha.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class ScaledDotProductAttention(nn.Module):
7 | """Scaled Dot-Product Attention."""
8 |
9 | def __init__(self, temperature, attn_dropout=0.0):
10 | super().__init__()
11 | self.temperature = temperature
12 | self.dropout = nn.Dropout(attn_dropout)
13 |
14 | def forward(self, q, k, v, mask=None):
15 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
16 |
17 | # if mask is not None:
18 | # attn = attn.masked_fill(mask == 0, -1e9)
19 | if mask is not None:
20 | bias = (1 - mask) * (-1e9)
21 | attn = attn * mask + bias
22 |
23 | attn = F.softmax(attn, dim=-1)
24 | output = torch.matmul(attn, v)
25 |
26 | return output, attn
27 |
28 |
29 | class MultiHeadAttention_for_index(nn.Module):
30 | """Multi-Head Attention module."""
31 |
32 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0):
33 | super().__init__()
34 |
35 | self.n_head = n_head
36 | self.d_k = d_k
37 | self.d_v = d_v
38 |
39 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
40 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
41 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
42 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
43 |
44 | nn.init.eye_(self.w_ks.weight)
45 | nn.init.eye_(self.w_vs.weight)
46 |
47 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5) # temperature=d_k ** 0.5
48 |
49 | self.dropout = nn.Dropout(dropout)
50 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
51 |
52 | def forward(self, q, k, v, mask=None):
53 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
54 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
55 |
56 | # Pass through the pre-attention projection: b x lq x (n*dv)
57 | # Separate different heads: b x lq x n x dv
58 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
59 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
60 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
61 |
62 | # Transpose for attention dot product: b x n x lq x dv
63 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
64 |
65 | if mask is not None:
66 | mask = mask.unsqueeze(1) # For head axis broadcasting.
67 |
68 | q, attn = self.attention(q, k, v, mask=mask)
69 |
70 | # Transpose to move the head dimension back: b x lq x n x dv
71 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
72 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
73 | q = self.dropout(self.fc(q))
74 |
75 | q = self.layer_norm(q)
76 |
77 | attn = torch.mean(attn, 1)
78 | return q, attn
79 |
80 |
81 | class MultiHeadAttention(nn.Module):
82 | """Multi-Head Attention module."""
83 |
84 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0):
85 | super().__init__()
86 |
87 | self.n_head = n_head
88 | self.d_k = d_k
89 | self.d_v = d_v
90 |
91 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
92 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
93 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
94 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
95 |
96 | nn.init.eye_(self.w_qs.weight)
97 | nn.init.eye_(self.w_ks.weight)
98 | nn.init.eye_(self.w_vs.weight)
99 | nn.init.eye_(self.fc.weight)
100 | self.attention = ScaledDotProductAttention(temperature=0.5)
101 |
102 | self.dropout = nn.Dropout(dropout)
103 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
104 |
105 | def forward(self, q, k, v, mask=None):
106 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
107 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
108 |
109 | # Pass through the pre-attention projection: b x lq x (n*dv)
110 | # Separate different heads: b x lq x n x dv
111 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
112 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
113 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
114 |
115 | # Transpose for attention dot product: b x n x lq x dv
116 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
117 |
118 | if mask is not None:
119 | mask = mask.unsqueeze(1) # For head axis broadcasting.
120 |
121 | q, attn = self.attention(q, k, v, mask=mask)
122 |
123 | # Transpose to move the head dimension back: b x lq x n x dv
124 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
125 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
126 | q = self.dropout(self.fc(q))
127 |
128 |         # Just return the weighted sum; do not apply a residual connection.
129 | attn = torch.mean(attn, 1)
130 | return q, attn
131 |
--------------------------------------------------------------------------------
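The masking in `ScaledDotProductAttention` is multiplicative rather than a `masked_fill`: positions where the mask is 0 receive a large negative bias and hence near-zero weight after the softmax. A small single-head sketch:

import torch

from ocl.mha import ScaledDotProductAttention

attn = ScaledDotProductAttention(temperature=8 ** 0.5)
q = torch.randn(1, 1, 3, 8)                      # (batch, heads, queries, dim)
k = v = torch.randn(1, 1, 5, 8)                  # (batch, heads, keys, dim)
mask = torch.ones(1, 1, 3, 5)
mask[..., -1] = 0                                # forbid attending to the last key
out, weights = attn(q, k, v, mask=mask)
print(out.shape, float(weights[..., -1].max()))  # torch.Size([1, 1, 3, 8]) ~0.0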
/ocl/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Models defined in code."""
2 | # from ocl.models.sa_detr import SA_DETR
3 | # from ocl.models.savi import SAVi
4 | # from ocl.models.savi_with_memory import SAVi_mem
5 | 
6 | # __all__ = ["SAVi", "SAVi_mem", "SA_DETR"]
7 | 
--------------------------------------------------------------------------------
/ocl/models/image_grouping.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from torch import nn
4 |
5 | from ocl.path_defaults import VIDEO
6 | from ocl.utils.trees import get_tree_element
7 |
8 |
9 | class GroupingImg(nn.Module):
10 | def __init__(
11 | self,
12 | conditioning: nn.Module,
13 | feature_extractor: nn.Module,
14 | perceptual_grouping: nn.Module,
15 | object_decoder: nn.Module,
16 |         masks_as_image=None,
17 |         decoder_mode="MLP",
18 |
19 | ):
20 | super().__init__()
21 | self.conditioning = conditioning
22 | self.feature_extractor = feature_extractor
23 | self.perceptual_grouping = perceptual_grouping
24 | self.object_decoder = object_decoder
25 | self.masks_as_image = masks_as_image
26 | self.decoder_mode = decoder_mode
27 |
28 | def forward(self, inputs: Dict[str, Any]):
29 | outputs = inputs
30 | video = get_tree_element(inputs, VIDEO.split("."))
31 | video.shape
32 |
33 | # feature extraction
34 | features = self.feature_extractor(video=video)
35 | outputs["feature_extractor"] = features
36 |
37 | # slot initialization
38 | batch_size = video.shape[0]
39 | conditioning = self.conditioning(batch_size=batch_size)
40 | outputs["conditioning"] = conditioning
41 |
42 | # slot computation
43 | perceptual_grouping_output = self.perceptual_grouping(
44 | extracted_features=features, conditioning=conditioning
45 | )
46 | outputs["perceptual_grouping"] = perceptual_grouping_output
47 |
48 | # slot decoding
49 | object_features = get_tree_element(outputs, "perceptual_grouping.objects".split("."))
50 | masks = get_tree_element(outputs, "perceptual_grouping.feature_attributions".split("."))
51 | target = get_tree_element(outputs, "feature_extractor.features".split("."))
52 | image = get_tree_element(outputs, "input.image".split("."))
53 | empty_object = None
54 |
55 | if self.decoder_mode == "MLP":
56 | decoder_output = self.object_decoder(object_features=object_features,
57 | target=target,
58 | image = image)
59 | elif self.decoder_mode == "Transformer":
60 | decoder_output = self.object_decoder(object_features=object_features,
61 | masks=masks,
62 | target=target,
63 | image=image,
64 | empty_objects = None)
65 | else:
66 |             raise RuntimeError(f"Unknown decoder_mode: {self.decoder_mode}")
67 |
68 | outputs["object_decoder"] = decoder_output
69 | outputs["masks_as_image"]= self.masks_as_image(tensor = get_tree_element(outputs, "object_decoder.masks".split(".")))
70 |
71 | return outputs
72 |
--------------------------------------------------------------------------------
/ocl/models/image_grouping_adaslot.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from torch import nn
4 |
5 | from ocl.path_defaults import VIDEO
6 | from ocl.utils.trees import get_tree_element
7 | import torch
8 |
9 | class GroupingImgGumbel(nn.Module):
10 | def __init__(
11 | self,
12 | conditioning: nn.Module,
13 | feature_extractor: nn.Module,
14 | perceptual_grouping: nn.Module,
15 | object_decoder: nn.Module,
16 |         masks_as_image=None,
17 |         decoder_mode="MLP",
18 | ):
19 | super().__init__()
20 | self.conditioning = conditioning
21 | self.feature_extractor = feature_extractor
22 | self.perceptual_grouping = perceptual_grouping
23 | self.object_decoder = object_decoder
24 | self.masks_as_image = masks_as_image
25 | self.decoder_mode = decoder_mode
26 | object_dim = self.conditioning.object_dim
27 |
28 | def forward(self, inputs: Dict[str, Any]):
29 | outputs = inputs
30 | video = get_tree_element(inputs, VIDEO.split("."))
31 |
32 |
33 | # feature extraction
34 | features = self.feature_extractor(video=video)
35 | outputs["feature_extractor"] = features
36 |
37 | # slot initialization
38 | batch_size = video.shape[0]
39 | conditioning = self.conditioning(batch_size=batch_size)
40 | outputs["conditioning"] = conditioning
41 |
42 | # slot computation
43 | perceptual_grouping_output = self.perceptual_grouping(
44 | extracted_features=features, conditioning=conditioning
45 | )
46 | outputs["perceptual_grouping"] = perceptual_grouping_output
47 | outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
48 | outputs["slots_keep_prob"] = perceptual_grouping_output["slots_keep_prob"]
49 |
50 |         # Unpack slots and the per-slot keep decisions from the grouping output.
51 |         object_features = perceptual_grouping_output["objects"]  # (b * t, s, d)
52 |         hard_keep_decision = perceptual_grouping_output["hard_keep_decision"]  # (b * t, s, n)
53 |
54 |         # slot decoding
55 |         masks = get_tree_element(outputs, "perceptual_grouping.feature_attributions".split("."))
56 |         target = get_tree_element(outputs, "feature_extractor.features".split("."))
57 |         image = get_tree_element(outputs, "input.image".split("."))
58 |         empty_object = None
59 |
60 |         if self.decoder_mode == "MLP":
61 |             decoder_output = self.object_decoder(
62 |                 object_features=object_features,
63 |                 target=target,
64 |                 image=image,
65 |                 left_mask=hard_keep_decision,
66 |             )
67 |         elif self.decoder_mode == "Transformer":
68 |             decoder_output = self.object_decoder(
69 |                 object_features=object_features,
70 |                 masks=masks,
71 |                 target=target,
72 |                 image=image,
73 |                 empty_objects=empty_object,
74 |                 left_mask=hard_keep_decision,
75 |             )
76 |         else:
77 |             raise RuntimeError(f"Unknown decoder_mode: {self.decoder_mode}")
78 |
79 |         outputs["object_decoder"] = decoder_output
80 |         if self.masks_as_image is not None:
81 |             outputs["masks_as_image"] = self.masks_as_image(
82 |                 tensor=get_tree_element(outputs, "object_decoder.masks".split("."))
83 |             )
84 |         return outputs
--------------------------------------------------------------------------------
/ocl/models/image_grouping_adaslot_pixel.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from torch import nn
4 |
5 | from ocl.path_defaults import VIDEO
6 | from ocl.utils.trees import get_tree_element
7 | import torch
8 |
9 | class GroupingImgGumbel(nn.Module):
10 | def __init__(
11 | self,
12 | conditioning: nn.Module,
13 | feature_extractor: nn.Module,
14 | perceptual_grouping: nn.Module,
15 | object_decoder: nn.Module,
16 |         masks_as_image=None,
17 | ):
18 | super().__init__()
19 | self.conditioning = conditioning
20 | self.feature_extractor = feature_extractor
21 | self.perceptual_grouping = perceptual_grouping
22 | self.object_decoder = object_decoder
23 | self.masks_as_image = masks_as_image
24 | object_dim = self.conditioning.object_dim
25 |
26 | def forward(self, inputs: Dict[str, Any]):
27 | outputs = inputs
28 | video = get_tree_element(inputs, VIDEO.split("."))
29 |
30 |
31 | # feature extraction
32 | features = self.feature_extractor(video=video)
33 | outputs["feature_extractor"] = features
34 |
35 | # slot initialization
36 | batch_size = video.shape[0]
37 | conditioning = self.conditioning(batch_size=batch_size)
38 | outputs["conditioning"] = conditioning
39 |
40 | # slot computation
41 | perceptual_grouping_output = self.perceptual_grouping(
42 | extracted_features=features, conditioning=conditioning
43 | )
44 | outputs["perceptual_grouping"] = perceptual_grouping_output
45 | outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
46 | outputs["slots_keep_prob"] = perceptual_grouping_output["slots_keep_prob"]
47 |
48 |         # Unpack slots and the per-slot keep decisions from the grouping output.
49 |         object_features = perceptual_grouping_output["objects"]  # (b * t, s, d)
50 |         hard_keep_decision = perceptual_grouping_output["hard_keep_decision"]  # (b * t, s, n)
51 |
52 |         # slot decoding
53 |         decoder_output = self.object_decoder(
54 |             object_features=object_features, left_mask=hard_keep_decision
55 |         )
56 |
57 |         outputs["object_decoder"] = decoder_output
58 |         if self.masks_as_image is not None:
59 |             outputs["masks_as_image"] = self.masks_as_image(
60 |                 tensor=get_tree_element(outputs, "object_decoder.masks".split("."))
61 |             )
62 |         return outputs
--------------------------------------------------------------------------------
/ocl/neural_networks/__init__.py:
--------------------------------------------------------------------------------
1 | from ocl.neural_networks.convenience import (
2 | build_mlp,
3 | build_transformer_decoder,
4 | build_transformer_encoder,
5 | build_two_layer_mlp,
6 | )
7 |
8 | __all__ = [
9 | "build_mlp",
10 | "build_transformer_decoder",
11 | "build_transformer_encoder",
12 | "build_two_layer_mlp",
13 | ]
14 |
--------------------------------------------------------------------------------
/ocl/neural_networks/convenience.py:
--------------------------------------------------------------------------------
1 | """Convenience functions for the construction neural networks using config."""
2 | from typing import Callable, List, Optional, Union
3 |
4 | from torch import nn
5 |
6 | from ocl.neural_networks.extensions import TransformerDecoderWithAttention
7 | from ocl.neural_networks.wrappers import Residual
8 |
9 |
10 | class ReLUSquared(nn.Module):
11 | def __init__(self, inplace=False):
12 | super().__init__()
13 | self.inplace = inplace
14 |
15 | def forward(self, x):
16 | return nn.functional.relu(x, inplace=self.inplace) ** 2
17 |
18 |
19 | def get_activation_fn(name: str, inplace: bool = True, leaky_relu_slope: Optional[float] = None):
20 | if callable(name):
21 | return name
22 |
23 | name = name.lower()
24 | if name == "relu":
25 | return nn.ReLU(inplace=inplace)
26 | elif name == "relu_squared":
27 | return ReLUSquared(inplace=inplace)
28 | elif name == "leaky_relu":
29 | if leaky_relu_slope is None:
30 | raise ValueError("Slope of leaky ReLU was not defined")
31 | return nn.LeakyReLU(leaky_relu_slope, inplace=inplace)
32 | elif name == "tanh":
33 | return nn.Tanh()
34 | elif name == "sigmoid":
35 | return nn.Sigmoid()
36 | elif name == "identity":
37 | return nn.Identity()
38 | else:
39 | raise ValueError(f"Unknown activation function {name}")
40 |
41 |
42 | def build_mlp(
43 | input_dim: int,
44 | output_dim: int,
45 | features: List[int],
46 | activation_fn: Union[str, Callable] = "relu",
47 | final_activation_fn: Optional[Union[str, Callable]] = None,
48 | initial_layer_norm: bool = False,
49 | residual: bool = False,
50 | ) -> nn.Sequential:
51 | layers = []
52 | current_dim = input_dim
53 | if initial_layer_norm:
54 | layers.append(nn.LayerNorm(current_dim))
55 |
56 | for n_features in features:
57 | layers.append(nn.Linear(current_dim, n_features))
58 | nn.init.zeros_(layers[-1].bias)
59 | layers.append(get_activation_fn(activation_fn))
60 | current_dim = n_features
61 |
62 | layers.append(nn.Linear(current_dim, output_dim))
63 | nn.init.zeros_(layers[-1].bias)
64 | if final_activation_fn is not None:
65 | layers.append(get_activation_fn(final_activation_fn))
66 |
67 | if residual:
68 | return Residual(nn.Sequential(*layers))
69 | return nn.Sequential(*layers)
70 |
71 |
72 | def build_two_layer_mlp(
73 | input_dim, output_dim, hidden_dim, initial_layer_norm: bool = False, residual: bool = False
74 | ):
75 | """Build a two layer MLP, with optional initial layer norm.
76 |
77 |     Separate function, as this type of construction is used very often for slot attention and
78 | transformers.
79 | """
80 | return build_mlp(
81 | input_dim, output_dim, [hidden_dim], initial_layer_norm=initial_layer_norm, residual=residual
82 | )
83 |
84 |
85 | def build_transformer_encoder(
86 | input_dim: int,
87 | output_dim: int,
88 | n_layers: int,
89 | n_heads: int,
90 | hidden_dim: Optional[int] = None,
91 | dropout: float = 0.0,
92 | activation_fn: Union[str, Callable] = "relu",
93 | layer_norm_eps: float = 1e-5,
94 | use_output_transform: bool = True,
95 | ):
96 | if hidden_dim is None:
97 | hidden_dim = 4 * input_dim
98 |
99 | layers = []
100 | for _ in range(n_layers):
101 | layers.append(
102 | nn.TransformerEncoderLayer(
103 | d_model=input_dim,
104 | nhead=n_heads,
105 | dim_feedforward=hidden_dim,
106 | dropout=dropout,
107 | activation=activation_fn,
108 | layer_norm_eps=layer_norm_eps,
109 | batch_first=True,
110 | norm_first=True,
111 | )
112 | )
113 |
114 | if use_output_transform:
115 | layers.append(nn.LayerNorm(input_dim, eps=layer_norm_eps))
116 | output_transform = nn.Linear(input_dim, output_dim, bias=True)
117 | nn.init.xavier_uniform_(output_transform.weight)
118 | nn.init.zeros_(output_transform.bias)
119 | layers.append(output_transform)
120 |
121 | return nn.Sequential(*layers)
122 |
123 |
124 | def build_transformer_decoder(
125 | input_dim: int,
126 | output_dim: int,
127 | n_layers: int,
128 | n_heads: int,
129 | hidden_dim: Optional[int] = None,
130 | dropout: float = 0.0,
131 | activation_fn: Union[str, Callable] = "relu",
132 | layer_norm_eps: float = 1e-5,
133 | return_attention_weights: bool = False,
134 | attention_weight_type: Union[int, str] = -1,
135 | ):
136 | if hidden_dim is None:
137 | hidden_dim = 4 * input_dim
138 |
139 | decoder_layer = nn.TransformerDecoderLayer(
140 | d_model=input_dim,
141 | nhead=n_heads,
142 | dim_feedforward=hidden_dim,
143 | dropout=dropout,
144 | activation=activation_fn,
145 | layer_norm_eps=layer_norm_eps,
146 | batch_first=True,
147 | norm_first=True,
148 | )
149 |
150 | if return_attention_weights:
151 | return TransformerDecoderWithAttention(
152 | decoder_layer,
153 | n_layers,
154 | return_attention_weights=True,
155 | attention_weight_type=attention_weight_type,
156 | )
157 | else:
158 | return nn.TransformerDecoder(decoder_layer, n_layers)
159 |
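A minimal usage sketch of the builders above; the dimensions are made up. build_two_layer_mlp produces LayerNorm -> Linear -> ReLU -> Linear, optionally wrapped in a Residual.

import torch
from ocl.neural_networks.convenience import build_two_layer_mlp

# Slot-attention-style MLP: residual two-layer MLP with an initial LayerNorm.
mlp = build_two_layer_mlp(input_dim=64, output_dim=64, hidden_dim=128,
                          initial_layer_norm=True, residual=True)
slots = torch.randn(2, 7, 64)  # (batch, n_slots, object_dim), shapes illustrative
out = mlp(slots)               # residual wrapper keeps the shape: (2, 7, 64)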
--------------------------------------------------------------------------------
/ocl/neural_networks/extensions.py:
--------------------------------------------------------------------------------
1 | """Extensions of existing layers to implement additional functionality."""
2 | from typing import Optional, Union
3 |
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class TransformerDecoderWithAttention(nn.TransformerDecoder):
9 | """Modified nn.TransformerDecoder class that returns attention weights over memory."""
10 |
11 | def __init__(
12 | self,
13 | decoder_layer,
14 | num_layers,
15 | norm=None,
16 | return_attention_weights=False,
17 | attention_weight_type: Union[int, str] = "mean",
18 | ):
19 | super(TransformerDecoderWithAttention, self).__init__(decoder_layer, num_layers, norm)
20 |
21 | if return_attention_weights:
22 | self.attention_hooks = []
23 | for layer in self.layers:
24 | self.attention_hooks.append(self._prepare_layer(layer))
25 | else:
26 | self.attention_hooks = None
27 |
28 | if isinstance(attention_weight_type, int):
29 | if attention_weight_type >= num_layers or attention_weight_type < -num_layers:
30 | raise ValueError(
31 | f"Index {attention_weight_type} exceeds number of layers {num_layers}"
32 | )
33 | elif attention_weight_type != "mean":
34 |             raise ValueError("`attention_weight_type` needs to be an int or 'mean'.")
35 | self.weights = attention_weight_type
36 |
37 | def _prepare_layer(self, layer):
38 | assert isinstance(layer, nn.TransformerDecoderLayer)
39 |
40 | def _mha_block(self, x, mem, attn_mask, key_padding_mask):
41 | x = self.multihead_attn(
42 | x,
43 | mem,
44 | mem,
45 | attn_mask=attn_mask,
46 | key_padding_mask=key_padding_mask,
47 | need_weights=True,
48 | )[0]
49 | return self.dropout2(x)
50 |
51 | # Patch _mha_block method to compute attention weights
52 | layer._mha_block = _mha_block.__get__(layer, nn.TransformerDecoderLayer)
53 |
54 | class AttentionHook:
55 | def __init__(self):
56 | self._attention = None
57 |
58 | def pop(self) -> torch.Tensor:
59 | assert self._attention is not None, "Forward was not called yet!"
60 | attention = self._attention
61 | self._attention = None
62 | return attention
63 |
64 | def __call__(self, module, inp, outp):
65 | self._attention = outp[1]
66 |
67 | hook = AttentionHook()
68 | layer.multihead_attn.register_forward_hook(hook)
69 | return hook
70 |
71 | def forward(
72 | self,
73 | tgt: torch.Tensor,
74 | memory: torch.Tensor,
75 | tgt_mask: Optional[torch.Tensor] = None,
76 | memory_mask: Optional[torch.Tensor] = None,
77 | tgt_key_padding_mask: Optional[torch.Tensor] = None,
78 | memory_key_padding_mask: Optional[torch.Tensor] = None,
79 | ) -> torch.Tensor:
80 | output = tgt
81 |
82 | for mod in self.layers:
83 | output = mod(
84 | output,
85 | memory,
86 | tgt_mask=tgt_mask,
87 | memory_mask=memory_mask,
88 | tgt_key_padding_mask=tgt_key_padding_mask,
89 | memory_key_padding_mask=memory_key_padding_mask,
90 | )
91 |
92 | if self.norm is not None:
93 | output = self.norm(output)
94 |
95 | if self.attention_hooks is not None:
96 | attentions = []
97 | for hook in self.attention_hooks:
98 | attentions.append(hook.pop())
99 |
100 | if self.weights == "mean":
101 | attentions = torch.stack(attentions, dim=-1)
102 | # Take mean over all layers
103 | attention = attentions.mean(dim=-1)
104 | else:
105 | attention = attentions[self.weights]
106 |
107 | return output, attention.transpose(1, 2)
108 | else:
109 | return output
110 |
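A hedged usage sketch of the attention-returning decoder, constructed through build_transformer_decoder from ocl.neural_networks.convenience; all shapes below are illustrative.

import torch
from ocl.neural_networks.convenience import build_transformer_decoder

decoder = build_transformer_decoder(
    input_dim=256, output_dim=256, n_layers=2, n_heads=4,
    return_attention_weights=True, attention_weight_type="mean",
)
tgt = torch.randn(2, 196, 256)    # e.g. tokens to decode
memory = torch.randn(2, 7, 256)   # e.g. slots
out, attn = decoder(tgt, memory)  # out: (2, 196, 256); attention over memory: (2, 7, 196)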
--------------------------------------------------------------------------------
/ocl/neural_networks/feature_pyramid_networks.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from ocl.utils.routing import RoutableMixin
8 |
9 |
10 | class FeaturePyramidDecoder(nn.Module, RoutableMixin):
11 | def __init__(
12 | self,
13 | slot_dim: int,
14 | feature_dim: int,
15 | mask_path: Optional[str] = None,
16 | slots_path: Optional[str] = None,
17 | features_path: Optional[str] = None,
18 | ):
19 | nn.Module.__init__(self)
20 | RoutableMixin.__init__(
21 | self,
22 | {
23 | "slots": slots_path,
24 | "mask": mask_path,
25 | "features": features_path,
26 | },
27 | )
28 |
29 | inter_dims = [slot_dim, slot_dim // 2, slot_dim // 4, slot_dim // 8, slot_dim // 16]
30 | # Depth dimension is slot dimension, no padding there and kernel size 1.
31 | self.lay1 = torch.nn.Conv3d(inter_dims[0], inter_dims[0], (1, 3, 3), padding=(0, 1, 1))
32 | self.gn1 = torch.nn.GroupNorm(8, inter_dims[0])
33 | self.lay2 = torch.nn.Conv3d(inter_dims[0], inter_dims[1], (1, 3, 3), padding=(0, 1, 1))
34 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
35 | self.lay3 = torch.nn.Conv3d(inter_dims[1], inter_dims[2], (1, 3, 3), padding=(0, 1, 1))
36 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
37 | self.lay4 = torch.nn.Conv3d(inter_dims[2], inter_dims[3], (1, 3, 3), padding=(0, 1, 1))
38 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
39 | self.lay5 = torch.nn.Conv3d(inter_dims[3], inter_dims[4], (1, 3, 3), padding=(0, 1, 1))
40 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
41 | self.out_lay = torch.nn.ConvTranspose3d(
42 | inter_dims[4],
43 | 1,
44 | stride=(1, 2, 2),
45 | kernel_size=(1, 3, 3),
46 | padding=(0, 1, 1),
47 | output_padding=(0, 1, 1),
48 | )
49 |
50 | upsampled_dim = feature_dim // 8
51 | self.upsampling = nn.ConvTranspose2d(
52 | feature_dim, upsampled_dim, kernel_size=8, stride=8
53 | ) # 112 x 112
54 | self.adapter1 = nn.Conv2d(
55 | upsampled_dim, inter_dims[0], kernel_size=5, padding=2, stride=8
56 | ) # Should downsample 112 to 14
57 | self.adapter2 = nn.Conv2d(
58 | upsampled_dim, inter_dims[1], kernel_size=5, padding=2, stride=4
59 | ) # 28x28
60 | self.adapter3 = nn.Conv2d(
61 | upsampled_dim, inter_dims[2], kernel_size=5, padding=2, stride=2
62 | ) # 56 x 56
63 | self.adapter4 = nn.Conv2d(
64 | upsampled_dim, inter_dims[3], kernel_size=5, padding=2, stride=1
65 | ) # 112 x 112
66 |
67 | for m in self.modules():
68 | if isinstance(m, nn.Conv3d):
69 | nn.init.kaiming_uniform_(m.weight, a=1)
70 | nn.init.constant_(m.bias, 0)
71 |
72 | def forward(self, slots: torch.Tensor, mask: torch.Tensor, features: torch.Tensor):
73 | # Bring features into image format with channels first
74 | features = features.unflatten(1, (14, 14)).permute(0, 3, 1, 2)
75 | mask = mask.unflatten(-1, (14, 14))
76 | # Use depth dimension for slots
77 | x = slots.transpose(1, 2)[..., None, None] * mask.unsqueeze(1)
78 | bs, n_channels, n_slots, width, height = x.shape
79 |
80 | upsampled_features = self.upsampling(features)
81 |
82 | # Add fake depth dimension for broadcasting and upsample representation.
83 | x = self.lay1(x) + self.adapter1(upsampled_features).unsqueeze(2)
84 | x = self.gn1(x)
85 | x = F.relu(x)
86 | x = self.lay2(x)
87 | x = self.gn2(x)
88 | x = F.relu(x)
89 |
90 | cur_fpn = self.adapter2(upsampled_features)
91 | # Add fake depth dimension for broadcasting and upsample representation.
92 | x = cur_fpn.unsqueeze(2) + F.interpolate(
93 | x, size=(n_slots,) + cur_fpn.shape[-2:], mode="nearest"
94 | )
95 | x = self.lay3(x)
96 | x = self.gn3(x)
97 | x = F.relu(x)
98 |
99 | cur_fpn = self.adapter3(upsampled_features)
100 | x = cur_fpn.unsqueeze(2) + F.interpolate(
101 | x, size=(n_slots,) + cur_fpn.shape[-2:], mode="nearest"
102 | )
103 | x = self.lay4(x)
104 | x = self.gn4(x)
105 | x = F.relu(x)
106 |
107 | cur_fpn = self.adapter4(upsampled_features)
108 | x = cur_fpn.unsqueeze(2) + F.interpolate(
109 | x, size=(n_slots,) + cur_fpn.shape[-2:], mode="nearest"
110 | )
111 | x = self.lay5(x)
112 | x = self.gn5(x)
113 | x = F.relu(x)
114 |
115 | # Squeeze channel dimension.
116 | x = self.out_lay(x).squeeze(1).softmax(1)
117 | return x
118 |
--------------------------------------------------------------------------------
/ocl/neural_networks/positional_embedding.py:
--------------------------------------------------------------------------------
1 | """Implementation of different positional embeddings."""
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class SoftPositionEmbed(nn.Module):
7 |     """Embedding of positions using a convex combination of learnable tensors.
8 |
9 | This assumes that the input positions are between 0 and 1.
10 | """
11 |
12 | def __init__(
13 | self, n_spatial_dims: int, feature_dim: int, cnn_channel_order=False, savi_style=False
14 | ):
15 | """__init__.
16 |
17 | Args:
18 | n_spatial_dims (int): Number of spatial dimensions.
19 | feature_dim (int): Dimensionality of the input features.
20 | cnn_channel_order (bool): Assume features are in CNN channel order (i.e. C x H x W).
21 | savi_style (bool): Use savi style positional encoding, where positions are normalized
22 | between -1 and 1 and a single dense layer is used for embedding.
23 | """
24 | super().__init__()
25 | self.savi_style = savi_style
26 | n_features = n_spatial_dims if savi_style else 2 * n_spatial_dims
27 | self.dense = nn.Linear(in_features=n_features, out_features=feature_dim)
28 | self.cnn_channel_order = cnn_channel_order
29 |
30 | def forward(self, inputs: torch.Tensor, positions: torch.Tensor):
31 | if self.savi_style:
32 | # Rescale positional encoding to -1 to 1
33 | positions = (positions - 0.5) * 2
34 | else:
35 | positions = torch.cat([positions, 1 - positions], axis=-1)
36 | emb_proj = self.dense(positions)
37 | if self.cnn_channel_order:
38 | emb_proj = emb_proj.permute(*range(inputs.ndim - 3), -1, -3, -2)
39 | return inputs + emb_proj
40 |
41 |
42 | class LearnedAdditivePositionalEmbed(nn.Module):
43 | """Add positional encoding as in SLATE."""
44 |
45 | def __init__(self, max_len, d_model, dropout=0.0):
46 | super().__init__()
47 | self.dropout = nn.Dropout(dropout)
48 | self.pe = nn.Parameter(torch.zeros(1, max_len, d_model), requires_grad=True)
49 | nn.init.trunc_normal_(self.pe)
50 |
51 | def forward(self, input):
52 | T = input.shape[1]
53 | return self.dropout(input + self.pe[:, :T])
54 |
55 |
56 | class DummyPositionEmbed(nn.Module):
57 | """Embedding that just passes through inputs without adding any positional embeddings."""
58 |
59 | def __init__(self):
60 | super().__init__()
61 |
62 | def forward(self, inputs: torch.Tensor, positions: torch.Tensor):
63 | return inputs
64 |
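A small usage sketch of SoftPositionEmbed with a channels-last feature map; the grid size and feature dimension are illustrative.

import torch
from ocl.neural_networks.positional_embedding import SoftPositionEmbed

embed = SoftPositionEmbed(n_spatial_dims=2, feature_dim=64)
h = w = 14
ys, xs = torch.meshgrid(torch.linspace(0, 1, h), torch.linspace(0, 1, w), indexing="ij")
positions = torch.stack([ys, xs], dim=-1).unsqueeze(0)  # (1, 14, 14, 2), values in [0, 1]
features = torch.randn(1, h, w, 64)                     # channels-last features
out = embed(features, positions)                        # (1, 14, 14, 64)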
--------------------------------------------------------------------------------
/ocl/neural_networks/slate.py:
--------------------------------------------------------------------------------
1 | """Neural networks used for the implemenation of SLATE."""
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class OneHotDictionary(nn.Module):
7 | def __init__(self, vocab_size: int, emb_size: int):
8 | super().__init__()
9 | self.dictionary = nn.Embedding(vocab_size, emb_size)
10 |
11 | def forward(self, x):
12 | tokens = torch.argmax(x, dim=-1) # batch_size x N
13 | token_embs = self.dictionary(tokens) # batch_size x N x emb_size
14 | return token_embs
15 |
16 |
17 | class Conv2dBlockWithGroupNorm(nn.Module):
18 | def __init__(
19 | self,
20 | in_channels,
21 | out_channels,
22 | kernel_size,
23 | stride=1,
24 | padding=0,
25 | dilation=1,
26 | groups=1,
27 | bias=True,
28 | padding_mode="zeros",
29 | weight_init="xavier",
30 | ):
31 | super().__init__()
32 | self.conv2d = nn.Conv2d(
33 | in_channels,
34 | out_channels,
35 | kernel_size,
36 | stride,
37 | padding,
38 | dilation,
39 | groups,
40 | bias,
41 | padding_mode,
42 | )
43 |
44 | if weight_init == "kaiming":
45 | nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity="relu")
46 | else:
47 | nn.init.xavier_uniform_(self.conv2d.weight)
48 |
49 | if bias:
50 | nn.init.zeros_(self.conv2d.bias)
51 | self.group_norm = nn.GroupNorm(1, out_channels)
52 |
53 | def forward(self, x):
54 | x = self.conv2d(x)
55 | return nn.functional.relu(self.group_norm(x))
56 |
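A small sketch of OneHotDictionary (vocabulary size and token count are made up): one-hot token scores are converted back to indices and looked up in the embedding table.

import torch
from ocl.neural_networks.slate import OneHotDictionary

vocab = OneHotDictionary(vocab_size=16, emb_size=32)
tokens = torch.randint(0, 16, (2, 196))
one_hot = torch.nn.functional.one_hot(tokens, num_classes=16).float()
embs = vocab(one_hot)  # (2, 196, 32)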
--------------------------------------------------------------------------------
/ocl/neural_networks/wrappers.py:
--------------------------------------------------------------------------------
1 | """Wrapper modules with allow the introduction or residuals or the combination of other modules."""
2 | from torch import nn
3 |
4 |
5 | class Residual(nn.Module):
6 | def __init__(self, module: nn.Module):
7 | super().__init__()
8 | self.module = module
9 |
10 | def forward(self, inputs):
11 | return inputs + self.module(inputs)
12 |
13 |
14 | class Sequential(nn.Module):
15 | """Extended sequential module that supports multiple inputs and outputs to layers.
16 |
17 | This allows a stack of layers where for example the first layer takes two inputs and only has
18 | a single output or where a layer has multiple outputs and the downstream layer takes multiple
19 | inputs.
20 | """
21 |
22 | def __init__(self, *layers):
23 | super().__init__()
24 | self.layers = nn.ModuleList(layers)
25 |
26 | def forward(self, *inputs):
27 | outputs = inputs
28 | for layer in self.layers:
29 | if isinstance(outputs, (tuple, list)):
30 | outputs = layer(*outputs)
31 | else:
32 | outputs = layer(outputs)
33 | return outputs
34 |
--------------------------------------------------------------------------------
/ocl/path_defaults.py:
--------------------------------------------------------------------------------
1 | """Default paths for different types of inputs.
2 |
3 | These are only defined for convenience and can also be overwritten using the appropriate *_path
4 | constructor variables of RoutableMixin subclasses.
5 | """
6 | MODEL = "model"
7 | INPUT = "input"
8 | VIDEO = f"{INPUT}.image"
9 | TEXT = f"{INPUT}.caption"
10 | BATCH_SIZE = f"{INPUT}.batch_size"
11 | BOX = f"{INPUT}.instance_bbox"
12 | MASK = f"{INPUT}.mask"
13 | ID = f"{INPUT}.instance_id"
14 | GLOBAL_STEP = "global_step"
15 | FEATURES = "feature_extractor"
16 | CONDITIONING = "conditioning"
17 | # TODO(hornmax): Currently decoders are nested in the task and accept PerceptualGroupingOutput as
18 | # input. In the future this will change and decoders should just be regular parts of the model.
19 | OBJECTS = "perceptual_grouping.objects"
20 | FEATURE_ATTRIBUTIONS = "perceptual_grouping.feature_attributions"
21 | OBJECT_DECODER = "object_decoder"
22 |
--------------------------------------------------------------------------------
/ocl/predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 |
5 |
6 | class Predictor(nn.Module):
7 | def __init__(
8 | self,
9 | embed_dim: int = 128,
10 | num_heads: int = 4,
11 | qkv_size: int = 128,
12 | mlp_size: int = 256,
13 | pre_norm: bool = False,
14 | ):
15 | nn.Module.__init__(self)
16 |
17 | self.embed_dim = embed_dim
18 | self.qkv_size = qkv_size
19 | self.mlp_size = mlp_size
20 | self.num_heads = num_heads
21 | self.pre_norm = pre_norm
22 | self.MHA = nn.MultiheadAttention(embed_dim, num_heads)
23 |
24 | self.head_dim = qkv_size // num_heads
25 | self.mlp = torchvision.ops.MLP(embed_dim, [mlp_size, embed_dim])
26 | # layernorms
27 | self.layernorm_query = nn.LayerNorm(embed_dim, eps=1e-6)
28 | self.layernorm_mlp = nn.LayerNorm(embed_dim, eps=1e-6)
29 | # weights
30 | self.dense_q = nn.Linear(embed_dim, qkv_size)
31 | self.dense_k = nn.Linear(embed_dim, qkv_size)
32 | self.dense_v = nn.Linear(embed_dim, qkv_size)
33 | if self.num_heads > 1:
34 | self.dense_o = nn.Linear(qkv_size, embed_dim)
35 | self.multi_head = True
36 | else:
37 | self.multi_head = False
38 |
39 | def forward(
40 | self, object_features: torch.Tensor
41 | ): # TODO: add general attention for q, k, v, not just for x = qkv
42 | assert object_features.ndim == 3
43 | B, L, _ = object_features.shape
44 | head_dim = self.embed_dim // self.num_heads
45 |
46 | if self.pre_norm:
47 | # Self-attention.
48 | x = self.layernorm_query(object_features)
49 | q = self.dense_q(x).view(B, L, self.num_heads, head_dim)
50 | k = self.dense_k(x).view(B, L, self.num_heads, head_dim)
51 | v = self.dense_v(x).view(B, L, self.num_heads, head_dim)
52 | x, _ = self.MHA(q, k, v)
53 | if self.multi_head:
54 | x = self.dense_o(x.reshape(B, L, self.qkv_size)).view(B, L, self.embed_dim)
55 | else:
56 | x = x.squeeze(-2)
57 | x = x + object_features
58 |
59 | y = x
60 |
61 | # MLP
62 | z = self.layernorm_mlp(y)
63 | z = self.mlp(z)
64 | z = z + y
65 | else:
66 | # Self-attention on queries.
67 | x = object_features
68 | q = self.dense_q(x).view(B, L, self.num_heads, head_dim)
69 | k = self.dense_k(x).view(B, L, self.num_heads, head_dim)
70 | v = self.dense_v(x).view(B, L, self.num_heads, head_dim)
71 | x, _ = self.MHA(q, k, v)
72 | if self.multi_head:
73 | x = self.dense_o(x.reshape(B, L, self.qkv_size)).view(B, L, self.embed_dim)
74 | else:
75 | x = x.squeeze(-2)
76 | x = x + object_features
77 | x = self.layernorm_query(x)
78 |
79 | y = x
80 |
81 | # MLP
82 | z = self.mlp(y)
83 | z = z + y
84 | z = self.layernorm_mlp(z)
85 | return z
86 |
--------------------------------------------------------------------------------
/ocl/scheduling.py:
--------------------------------------------------------------------------------
1 | """Scheduling of learning rate and hyperparameters."""
2 | import abc
3 | import math
4 | import warnings
5 | from typing import Callable
6 |
7 | from torch.optim.lr_scheduler import _LRScheduler
8 |
9 |
10 | def warmup_fn(step: int, warmup_steps: int) -> float:
11 | """Learning rate warmup.
12 |
13 | Maps the step to a factor for rescaling the learning rate.
14 | """
15 | if warmup_steps:
16 | return min(1.0, step / warmup_steps)
17 | else:
18 | return 1.0
19 |
20 |
21 | def exp_decay_after_warmup_fn(
22 | step: int, decay_rate: float, decay_steps: int, warmup_steps: int
23 | ) -> float:
24 |     """Decay function for exponential decay that starts only after learning rate warmup ends.
25 |
26 | Maps the step to a factor for rescaling the learning rate.
27 | """
28 | factor = warmup_fn(step, warmup_steps)
29 | if step < warmup_steps:
30 | return factor
31 | else:
32 | return factor * (decay_rate ** ((step - warmup_steps) / decay_steps))
33 |
34 |
35 | def exp_decay_with_warmup_fn(
36 | step: int, decay_rate: float, decay_steps: int, warmup_steps: int
37 | ) -> float:
38 |     """Decay function combining exponential decay from step 0 with learning rate warmup.
39 |
40 | Maps the step to a factor for rescaling the learning rate.
41 | """
42 | factor = warmup_fn(step, warmup_steps)
43 | return factor * (decay_rate ** (step / decay_steps))
44 |
45 |
46 | class CosineAnnealingWithWarmup(_LRScheduler):
47 | """Cosine annealing with warmup."""
48 |
49 | def __init__(
50 | self,
51 | optimizer,
52 | T_max: int,
53 | warmup_steps: int = 0,
54 | eta_min: float = 0.0,
55 | last_epoch: int = -1,
56 | error_on_exceeding_steps: bool = True,
57 | verbose: bool = False,
58 | ):
59 | self.T_max = T_max
60 | self.warmup_steps = warmup_steps
61 | self.eta_min = eta_min
62 | self.error_on_exceeding_steps = error_on_exceeding_steps
63 | super().__init__(optimizer, last_epoch, verbose)
64 |
65 | def _linear_lr_warmup(self, base_lr, step_num):
66 | return base_lr * ((step_num + 0.5) / self.warmup_steps)
67 |
68 | def _cosine_annealing(self, base_lr, step_num):
69 | fraction_of_steps = (step_num - self.warmup_steps) / (self.T_max - self.warmup_steps - 1)
70 | return self.eta_min + 1 / 2 * (base_lr - self.eta_min) * (
71 | 1 + math.cos(math.pi * fraction_of_steps)
72 | )
73 |
74 | def get_lr(self):
75 | if not self._get_lr_called_within_step:
76 | warnings.warn(
77 | "To get the last learning rate computed by the scheduler, "
78 | "please use `get_last_lr()`."
79 | )
80 | step_num = self.last_epoch
81 |
82 | if step_num < self.warmup_steps:
83 | # Warmup.
84 | return [self._linear_lr_warmup(base_lr, step_num) for base_lr in self.base_lrs]
85 | elif step_num < self.T_max:
86 | # Cosine annealing.
87 | return [self._cosine_annealing(base_lr, step_num) for base_lr in self.base_lrs]
88 | else:
89 | if self.error_on_exceeding_steps:
90 | raise ValueError(
91 | "Tried to step {} times. The specified number of total steps is {}".format(
92 | step_num + 1, self.T_max
93 | )
94 | )
95 | else:
96 | return [self.eta_min for _ in self.base_lrs]
97 |
98 |
99 | HPSchedulerT = Callable[[int], float] # Type for function signatures.
100 |
101 |
102 | class HPScheduler(metaclass=abc.ABCMeta):
103 | """Base class for scheduling of scalar hyperparameters based on the number of training steps."""
104 |
105 | @abc.abstractmethod
106 | def __call__(self, step: int) -> float:
107 | """Return current value of hyperparameter based on global step."""
108 | pass
109 |
110 |
111 | class LinearHPScheduler(HPScheduler):
112 | def __init__(
113 | self, end_value: float, end_step: int, start_value: float = 0.0, start_step: int = 0
114 | ):
115 | super().__init__()
116 | if start_step > end_step:
117 | raise ValueError("`start_step` needs to be smaller equal to `end_step`.")
118 |
119 | self.start_value = start_value
120 | self.end_value = end_value
121 | self.start_step = start_step
122 | self.end_step = end_step
123 |
124 | def __call__(self, step: int) -> float:
125 | if step < self.start_step:
126 | return self.start_value
127 | elif step > self.end_step:
128 | return self.end_value
129 | else:
130 | t = step - self.start_step
131 | T = self.end_step - self.start_step
132 | return self.start_value + t * (self.end_value - self.start_value) / T
133 |
134 |
135 | class StepHPScheduler(HPScheduler):
136 | def __init__(self, end_value: float, switch_step: int, start_value: float = 0.0):
137 | super().__init__()
138 | self.start_value = start_value
139 | self.end_value = end_value
140 | self.switch_step = switch_step
141 |
142 | def __call__(self, step: int) -> float:
143 | if step < self.switch_step:
144 | return self.start_value
145 | elif step >= self.switch_step:
146 | return self.end_value
147 |
148 |
149 | class CosineAnnealingHPScheduler(HPScheduler):
150 | """Cosine annealing."""
151 |
152 | def __init__(self, start_value: float, end_value: float, start_step: int, end_step: int):
153 | super().__init__()
154 | assert start_value >= end_value
155 | assert start_step <= end_step
156 | self.start_value = start_value
157 | self.end_value = end_value
158 | self.start_step = start_step
159 | self.end_step = end_step
160 |
161 | def __call__(self, step: int) -> float:
162 |
163 | if step < self.start_step:
164 | value = self.start_value
165 | elif step >= self.end_step:
166 | value = self.end_value
167 | else:
168 | a = 0.5 * (self.start_value - self.end_value)
169 | b = 0.5 * (self.start_value + self.end_value)
170 | progress = (step - self.start_step) / (self.end_step - self.start_step)
171 | value = a * math.cos(math.pi * progress) + b
172 |
173 | return value
174 |
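A short worked example of the hyperparameter schedulers (the numbers are illustrative):

from ocl.scheduling import CosineAnnealingHPScheduler, LinearHPScheduler

tau = CosineAnnealingHPScheduler(start_value=1.0, end_value=0.0, start_step=0, end_step=100)
tau(0)    # 1.0
tau(50)   # ~0.5, the cosine midpoint
tau(100)  # 0.0

weight = LinearHPScheduler(end_value=1.0, end_step=10)
weight(5)  # 0.5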
--------------------------------------------------------------------------------
/ocl/trees.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import dataclasses
3 | from collections import OrderedDict, abc
4 | from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
5 |
6 | import torch
7 |
8 | Tree = Union[Dict, List, Tuple]
9 |
10 | def get_tree_element(d: Tree, path: List[str]):
11 | """Get element of a tree."""
12 | next_element = d
13 |
14 | for next_element_name in path:
15 | if isinstance(next_element, abc.Mapping) and next_element_name in next_element:
16 | next_element = next_element[next_element_name]
17 | elif hasattr(next_element, next_element_name):
18 | next_element = getattr(next_element, next_element_name)
19 | elif isinstance(next_element, (list, tuple)) and next_element_name.isnumeric():
20 | next_element = next_element[int(next_element_name)]
21 | else:
22 | try:
23 | next_element = getattr(next_element, next_element_name)
24 | except AttributeError:
25 | msg = f"Trying to access path {'.'.join(path)}, "
26 | if isinstance(next_element, abc.Mapping):
27 | msg += f"but element {next_element_name} is not among keys {next_element.keys()}"
28 | elif isinstance(next_element, (list, tuple)):
29 | msg += f"but cannot index into list with {next_element_name}"
30 | else:
31 | msg += (
32 | f"but element {next_element_name} cannot be used to access attribute of "
33 | f"object of type {type(next_element)}"
34 | )
35 | raise ValueError(msg)
36 | return next_element
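A brief usage sketch (the dictionary below is made up): paths are dot-separated strings split into a list and may step through dict keys, attributes, or list/tuple indices.

from ocl.trees import get_tree_element

outputs = {"perceptual_grouping": {"objects": [1, 2, 3]}}
get_tree_element(outputs, "perceptual_grouping.objects".split("."))    # [1, 2, 3]
get_tree_element(outputs, "perceptual_grouping.objects.0".split("."))  # 1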
--------------------------------------------------------------------------------
/ocl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # We added this here to avoid issues in rebasing.
2 | # Long term the imports should be updated.
3 | from ocl.utils.windows import JoinWindows
4 |
5 | __all__ = ["JoinWindows"]
6 |
--------------------------------------------------------------------------------
/ocl/utils/annealing.py:
--------------------------------------------------------------------------------
1 | import math
2 | def cosine_anneal_factory(start_value, final_value, start_step, final_step):
3 | def cosine_anneal(step):
4 | assert start_value >= final_value
5 | assert start_step <= final_step
6 |
7 | if step < start_step:
8 | value = start_value
9 | elif step >= final_step:
10 | value = final_value
11 | else:
12 | a = 0.5 * (start_value - final_value)
13 | b = 0.5 * (start_value + final_value)
14 | progress = (step - start_step) / (final_step - start_step)
15 | value = a * math.cos(math.pi * progress) + b
16 |
17 | return value
18 | return cosine_anneal
--------------------------------------------------------------------------------
/ocl/utils/bboxes.py:
--------------------------------------------------------------------------------
1 | """Utilities for handling bboxes."""
2 | import torch
3 |
4 |
5 | def box_cxcywh_to_xyxy(x):
6 | x_c, y_c, w, h = x.unbind(-1)
7 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
8 | return torch.stack(b, dim=-1)
9 |
10 |
11 | def box_xyxy_to_cxcywh(x):
12 | x0, y0, x1, y1 = x.unbind(-1)
13 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
14 | return torch.stack(b, dim=-1)
15 |
--------------------------------------------------------------------------------
/ocl/utils/masking.py:
--------------------------------------------------------------------------------
1 | """Utilities related to masking."""
2 | import math
3 | from typing import Optional
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from ocl.utils.routing import RoutableMixin
9 |
10 |
11 | class CreateSlotMask(nn.Module, RoutableMixin):
12 | """Module intended to create a mask that marks empty slots.
13 |
14 | Module takes a tensor holding the number of slots per batch entry, and returns a binary mask of
15 | shape (batch_size, max_slots) where entries exceeding the number of slots are masked out.
16 | """
17 |
18 | def __init__(self, max_slots: int, n_slots_path: str):
19 | nn.Module.__init__(self)
20 | RoutableMixin.__init__(self, {"n_slots": n_slots_path})
21 | self.max_slots = max_slots
22 |
23 | @RoutableMixin.route
24 | def forward(self, n_slots: torch.Tensor) -> torch.Tensor:
25 | (batch_size,) = n_slots.shape
26 |
27 | # Create mask of shape B x K where the first n_slots entries per-row are false, the rest true
28 | indices = torch.arange(self.max_slots, device=n_slots.device)
29 | masks = indices.unsqueeze(0).expand(batch_size, -1) >= n_slots.unsqueeze(1)
30 |
31 | return masks
32 |
33 |
34 | class CreateRandomMaskPatterns(nn.Module, RoutableMixin):
35 | """Create random masks.
36 |
37 | Useful for showcasing behavior of metrics.
38 | """
39 |
40 | def __init__(
41 | self, pattern: str, masks_path: str, n_slots: Optional[int] = None, n_cols: int = 2
42 | ):
43 | nn.Module.__init__(self)
44 | RoutableMixin.__init__(self, {"masks": masks_path})
45 | if pattern not in ("random", "blocks"):
46 | raise ValueError(f"Unknown pattern {pattern}")
47 | self.pattern = pattern
48 | self.n_slots = n_slots
49 | self.n_cols = n_cols
50 |
51 | @RoutableMixin.route
52 | def forward(self, masks: torch.Tensor) -> torch.Tensor:
53 | if self.pattern == "random":
54 | rand_mask = torch.rand_like(masks)
55 | return rand_mask / rand_mask.sum(1, keepdim=True)
56 | elif self.pattern == "blocks":
57 | n_slots = masks.shape[1] if self.n_slots is None else self.n_slots
58 | height, width = masks.shape[-2:]
59 | new_masks = torch.zeros(
60 | len(masks), n_slots, height, width, device=masks.device, dtype=masks.dtype
61 | )
62 | blocks_per_col = int(n_slots // self.n_cols)
63 | remainder = n_slots - (blocks_per_col * self.n_cols)
64 | slot = 0
65 | for col in range(self.n_cols):
66 | rows = blocks_per_col if col < self.n_cols - 1 else blocks_per_col + remainder
67 | for row in range(rows):
68 | block_width = math.ceil(width / self.n_cols)
69 | block_height = math.ceil(height / rows)
70 | x = col * block_width
71 | y = row * block_height
72 | new_masks[:, slot, y : y + block_height, x : x + block_width] = 1
73 | slot += 1
74 | assert torch.allclose(new_masks.sum(1), torch.ones_like(masks[:, 0]))
75 | return new_masks
76 |
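A worked example of the comparison used in CreateSlotMask.forward (values made up), written with plain tensors so it can be checked without the routing machinery:

import torch

max_slots = 4
n_slots = torch.tensor([2, 3])  # slots actually used per batch entry
indices = torch.arange(max_slots)
masks = indices.unsqueeze(0).expand(len(n_slots), -1) >= n_slots.unsqueeze(1)
# masks == tensor([[False, False,  True,  True],
#                  [False, False, False,  True]])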
--------------------------------------------------------------------------------
/ocl/utils/resizing.py:
--------------------------------------------------------------------------------
1 | """Utilities related to resizing of tensors."""
2 | import math
3 | from typing import Optional, Tuple, Union
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from ocl.utils.routing import RoutableMixin
9 |
10 |
11 | class Resize(nn.Module, RoutableMixin):
12 | """Module resizing tensors."""
13 |
14 | MODES = {"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"}
15 |
16 | def __init__(
17 | self,
18 | input_path: str,
19 | size: Optional[Union[int, Tuple[int, int]]] = None,
20 | take_size_from: Optional[str] = None,
21 | resize_mode: str = "bilinear",
22 | patch_mode: bool = False,
23 | channels_last: bool = False,
24 | ):
25 | nn.Module.__init__(self)
26 | RoutableMixin.__init__(self, {"tensor": input_path, "size_tensor": take_size_from})
27 |
28 | if size is not None and take_size_from is not None:
29 | raise ValueError("`size` and `take_size_from` can not be set at the same time")
30 | self.size = size
31 |
32 | if resize_mode not in Resize.MODES:
33 | raise ValueError(f"`mode` must be one of {Resize.MODES}")
34 | self.resize_mode = resize_mode
35 | self.patch_mode = patch_mode
36 | self.channels_last = channels_last
37 | self.expected_dims = 3 if patch_mode else 4
38 |
39 | @RoutableMixin.route
40 | def forward(
41 | self, tensor: torch.Tensor, size_tensor: Optional[torch.Tensor] = None
42 | ) -> torch.Tensor:
43 | """Resize tensor.
44 |
45 | Args:
46 | tensor: Tensor to resize. If `patch_mode=False`, assumed to be of shape (..., C, H, W).
47 | If `patch_mode=True`, assumed to be of shape (..., C, P), where P is the number of
48 | patches. Patches are assumed to be viewable as a perfect square image. If
49 | `channels_last=True`, channel dimension is assumed to be the last dimension instead.
50 | size_tensor: Tensor which size to resize to. If tensor has <=2 dimensions and the last
51 | dimension of this tensor has length 2, the two entries are taken as height and width.
52 | Otherwise, the size of the last two dimensions of this tensor are used as height
53 | and width.
54 |
55 | Returns: Tensor of shape (..., C, H, W), where height and width are either specified by
56 | `size` or `size_tensor`.
57 | """
58 | dims_to_flatten = tensor.ndim - self.expected_dims
59 | if dims_to_flatten > 0:
60 | flattened_dims = tensor.shape[: dims_to_flatten + 1]
61 | tensor = tensor.flatten(0, dims_to_flatten)
62 | elif dims_to_flatten < 0:
63 | raise ValueError(
64 | f"Tensor needs at least {self.expected_dims} dimensions, but only has {tensor.ndim}"
65 | )
66 |
67 | if self.patch_mode:
68 | if self.channels_last:
69 | tensor = tensor.transpose(-2, -1)
70 | n_channels, n_patches = tensor.shape[-2:]
71 | patch_size_float = math.sqrt(n_patches)
72 | patch_size = int(math.sqrt(n_patches))
73 | if patch_size_float != patch_size:
74 | raise ValueError(
75 | f"The number of patches needs to be a perfect square, but is {n_patches}."
76 | )
77 | tensor = tensor.view(-1, n_channels, patch_size, patch_size)
78 | else:
79 | if self.channels_last:
80 | tensor = tensor.permute(0, 3, 1, 2)
81 |
82 | if self.size is None:
83 | if size_tensor is None:
84 | raise ValueError("`size` is `None` but no `size_tensor` was passed.")
85 | if size_tensor.ndim <= 2 and size_tensor.shape[-1] == 2:
86 | height, width = size_tensor.unbind(-1)
87 | height = torch.atleast_1d(height)[0].squeeze().detach().cpu()
88 | width = torch.atleast_1d(width)[0].squeeze().detach().cpu()
89 | size = (int(height), int(width))
90 | else:
91 | size = size_tensor.shape[-2:]
92 | else:
93 | size = self.size
94 |
95 | tensor = torch.nn.functional.interpolate(
96 | tensor,
97 | size=size,
98 | mode=self.resize_mode,
99 | )
100 |
101 | if dims_to_flatten > 0:
102 | tensor = tensor.unflatten(0, flattened_dims)
103 |
104 | return tensor
105 |
106 |
107 | def resize_patches_to_image(
108 | patches: torch.Tensor,
109 | size: Optional[int] = None,
110 | scale_factor: Optional[float] = None,
111 | resize_mode: str = "bilinear",
112 | ) -> torch.Tensor:
113 | """Convert and resize a tensor of patches to image shape.
114 |
115 | This method requires that the patches can be converted to a square image.
116 |
117 | Args:
118 | patches: Patches to be converted of shape (..., C, P), where C is the number of channels and
119 | P the number of patches.
120 | size: Image size to resize to.
121 | scale_factor: Scale factor by which to resize the patches. Can be specified alternatively to
122 | `size`.
123 | resize_mode: Method to resize with. Valid options are "nearest", "nearest-exact", "bilinear",
124 | "bicubic".
125 |
126 | Returns: Tensor of shape (..., C, S, S) where S is the image size.
127 | """
128 |     has_size = size is not None
129 |     has_scale = scale_factor is not None
130 | if has_size == has_scale:
131 | raise ValueError("Exactly one of `size` or `scale_factor` must be specified.")
132 |
133 | n_channels = patches.shape[-2]
134 | n_patches = patches.shape[-1]
135 | patch_size_float = math.sqrt(n_patches)
136 | patch_size = int(math.sqrt(n_patches))
137 | if patch_size_float != patch_size:
138 | raise ValueError("The number of patches needs to be a perfect square.")
139 |
140 | image = torch.nn.functional.interpolate(
141 | patches.view(-1, n_channels, patch_size, patch_size),
142 | size=size,
143 | scale_factor=scale_factor,
144 | mode=resize_mode,
145 | )
146 |
147 | return image.view(*patches.shape[:-1], image.shape[-2], image.shape[-1])
148 |
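A usage sketch of resize_patches_to_image (shapes illustrative): 196 patches are viewed as a 14x14 grid and upsampled to the requested image size.

import torch
from ocl.utils.resizing import resize_patches_to_image

masks = torch.rand(2, 7, 1, 196)                   # (batch, n_slots, channels, n_patches)
images = resize_patches_to_image(masks, size=224)  # (2, 7, 1, 224, 224)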
--------------------------------------------------------------------------------
/ocl/utils/windows.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ocl.utils.routing import RoutableMixin
5 |
6 |
7 | class JoinWindows(nn.Module, RoutableMixin):
8 | def __init__(self, n_windows: int, size, masks_path: str, key_path: str = "input.__key__"):
9 | nn.Module.__init__(self)
10 | RoutableMixin.__init__(self, {"masks": masks_path, "keys": key_path})
11 | self.n_windows = n_windows
12 | self.size = size
13 |
14 | @RoutableMixin.route
15 | def forward(self, masks: torch.Tensor, keys: str) -> torch.Tensor:
16 | assert len(masks) == self.n_windows
17 | keys_split = [key.split("_") for key in keys]
18 | pad_left = [int(elems[1]) for elems in keys_split]
19 | pad_top = [int(elems[2]) for elems in keys_split]
20 |
21 | target_height, target_width = self.size
22 | n_masks = masks.shape[0] * masks.shape[1]
23 | height, width = masks.shape[2], masks.shape[3]
24 | full_mask = torch.zeros(n_masks, *self.size).to(masks)
25 | x = 0
26 | y = 0
27 | for idx, mask in enumerate(masks):
28 | elems = masks.shape[1]
29 | x_start = 0 if pad_left[idx] >= 0 else -pad_left[idx]
30 | x_end = min(width, target_width - pad_left[idx])
31 | y_start = 0 if pad_top[idx] >= 0 else -pad_top[idx]
32 | y_end = min(height, target_height - pad_top[idx])
33 | cropped = mask[:, y_start:y_end, x_start:x_end]
34 | full_mask[
35 | idx * elems : (idx + 1) * elems, y : y + cropped.shape[-2], x : x + cropped.shape[-1]
36 | ] = cropped
37 | x += cropped.shape[-1]
38 | if x > target_width:
39 | y += cropped.shape[-2]
40 | x = 0
41 |
42 | assert torch.all(torch.abs(torch.sum(full_mask, axis=0) - 1) <= 1e-2)
43 |
44 | return full_mask.unsqueeze(0)
45 |
--------------------------------------------------------------------------------
/ocl/visualization_types.py:
--------------------------------------------------------------------------------
1 | """Classes for handling different types of visualizations."""
2 | import dataclasses
3 | from typing import Any, List, Optional, Union
4 |
5 | import matplotlib.pyplot
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 | from torchtyping import TensorType
9 |
10 |
11 | def dataclass_to_dict(d):
12 | return {field.name: getattr(d, field.name) for field in dataclasses.fields(d)}
13 |
14 |
15 | @dataclasses.dataclass
16 | class Visualization:
17 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
18 | pass
19 |
20 |
21 | @dataclasses.dataclass
22 | class Figure(Visualization):
23 | """Matplotlib figure."""
24 |
25 | figure: matplotlib.pyplot.figure
26 | close: bool = True
27 |
28 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
29 | experiment.add_figure(**dataclass_to_dict(self), tag=tag, global_step=global_step)
30 |
31 |
32 | @dataclasses.dataclass
33 | class Image(Visualization):
34 | """Single image."""
35 |
36 | img_tensor: torch.Tensor
37 | dataformats: str = "CHW"
38 |
39 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
40 | experiment.add_image(**dataclass_to_dict(self), tag=tag, global_step=global_step)
41 |
42 |
43 | @dataclasses.dataclass
44 | class Images(Visualization):
45 | """Batch of images."""
46 |
47 | img_tensor: torch.Tensor
48 | dataformats: str = "NCHW"
49 |
50 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
51 | experiment.add_images(**dataclass_to_dict(self), tag=tag, global_step=global_step)
52 |
53 |
54 | @dataclasses.dataclass
55 | class Video(Visualization):
56 | """Batch of videos."""
57 |
58 | vid_tensor: TensorType["batch_size", "frames", "channels", "height", "width"] # noqa: F821
59 | fps: Union[int, float] = 4
60 |
61 | def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
62 | experiment.add_video(**dataclass_to_dict(self), tag=tag, global_step=global_step)
63 |
64 |
65 | @dataclasses.dataclass
66 | class Embedding(Visualization):
67 |     """Batch of embeddings."""
68 |
69 |     mat: TensorType["batch_size", "feature_dim"] # noqa: F821
70 |     metadata: Optional[List[Any]] = None
71 |     label_img: Optional[TensorType["batch_size", "channels", "height", "width"]] = None # noqa: F821
72 |     metadata_header: Optional[List[str]] = None
73 |
74 |     def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
75 |         experiment.add_embedding(**dataclass_to_dict(self), tag=tag, global_step=global_step)
76 |
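A usage sketch (tag and log directory are made up): each visualization type simply forwards its fields to the corresponding SummaryWriter method.

import torch
from torch.utils.tensorboard import SummaryWriter
from ocl.visualization_types import Image

writer = SummaryWriter("outputs/example_run")
viz = Image(img_tensor=torch.rand(3, 64, 64))
viz.add_to_experiment(writer, tag="val/input", global_step=0)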
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "ocl"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Max Horn "]
6 |
7 | [tool.poetry.scripts]
8 | ocl_train = "ocl.cli.train:train"
9 | ocl_eval = "ocl.cli.eval:evaluate"
10 | ocl_compute_dataset_size = "ocl.cli.compute_dataset_size:compute_size"
11 |
12 | [tool.poetry.dependencies]
13 | python = ">=3.8,<3.9"
14 | webdataset = "^0.1.103"
15 | # There seems to be an issue in torch 1.12.x with masking and multi-head
16 | # attention. This prevents the usage of masks without a batch dimension.
17 | # Staying with torch 1.11.x version for now.
18 | torch = "1.12.*"
19 | pytorch-lightning = "^1.5.10"
20 | hydra-zen = "^0.7.0"
21 | torchtyping = "^0.1.4"
22 | hydra-core = "^1.2.0"
23 | pluggy = "^1.0.0"
24 | importlib-metadata = "4.2"
25 | torchvision = "0.13.*"
26 | Pillow = "9.0.1" # Newer versions of pillow seem to result in segmentation faults.
27 | torchmetrics = "^0.8.1"
28 | matplotlib = "^3.5.1"
29 | moviepy = "^1.0.3"
30 | scipy = "<=1.8"
31 | awscli = "^1.22.90"
32 | scikit-learn = "^1.0.2"
33 | pyamg = "^4.2.3"
34 | botocore = { extras = ["crt"], version = "^1.27.22" }
35 | timm = {version = "0.6.7", optional = true}
36 | hydra-submitit-launcher = { version = "^1.2.0", optional = true }
37 | decord = "0.6.0"
38 | motmetrics = "^1.2.5"
39 |
40 | ftfy = {version = "^6.1.1", optional = true}
41 | regex = {version = "^2022.7.9", optional = true}
42 | mlflow = {version = "^2.9.0", optional = true}
43 | einops = "^0.6.0"
44 | jupyter = "^1.0.0"
45 |
46 | [tool.poetry.dev-dependencies]
47 | black = "^22.1.0"
48 | pytest = "^7.0.1"
49 | flake8 = "^4.0.1"
50 | flake8-isort = "^4.1.1"
51 | pre-commit = "^2.17.0"
52 | flake8-tidy-imports = "^4.6.0"
53 | flake8-bugbear = "^22.1.11"
54 | flake8-docstrings = "^1.6.0"
55 |
56 | [tool.poetry.extras]
57 | timm = ["timm"]
58 | clip = ["clip", "ftfy", "regex"]
59 | submitit = ["hydra-submitit-launcher"]
60 | mlflow = ["mlflow"]
61 |
62 | [build-system]
63 | requires = ["poetry-core<=1.0.4"]
64 | build-backend = "poetry.core.masonry.api"
65 |
66 | [tool.black]
67 | line-length = 101
68 | target-version = ["py38"]
69 |
70 | [tool.isort]
71 | profile = "black"
72 | line_length = 101
73 | skip_gitignore = true
74 | remove_redundant_aliases = true
75 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select=
3 | # F: errors from pyflake
4 | F,
5 | # W, E: warnings/errors from pycodestyle (PEP8)
6 | W, E,
7 | # I: problems with imports
8 | I,
9 | # B: bugbear warnings ("likely bugs and design problems")
10 | B,
11 | # D: docstring warnings from pydocstyle
12 | D
13 | ignore=
14 | # E203: whitespace before ':' (incompatible with black)
15 | E203,
16 | # E731: do not use a lambda expression, use a def (local def is often ugly)
17 | E731,
18 | # W503: line break before binary operator (incompatible with black)
19 | W503,
20 | # D1: docstring warnings related to missing documentation
21 | D1
22 | max-line-length = 101
23 | ban-relative-imports = true
24 | docstring-convention = google
25 | exclude = .*,__pycache__,./outputs
26 |
--------------------------------------------------------------------------------