├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── doc └── fig1.png ├── environment.yml ├── models └── download_pretrained.sh ├── pyproject.toml ├── results ├── correspondence │ ├── attention_interp_2_img.mp4 │ ├── attention_interp_2_img_2.mp4 │ ├── attention_interp_2_img_moto.mp4 │ └── correspondence.png └── figures │ ├── cityscapes_results.jpg │ ├── cocostuff27_results.jpg │ ├── correspondence.gif │ ├── dog_man_correspondence.png │ ├── img_correspondence.jpg │ └── stego.svg ├── scripts ├── cfg │ ├── convert_checkpoint_config.yaml │ ├── demo_config.yaml │ ├── eval_config.yaml │ ├── eval_wvn_config.yaml │ ├── knn_config.yaml │ ├── plot_config.yaml │ └── train_config.yaml ├── convert_original_stego_checkpoint.py ├── data_preprocessing │ ├── create_curated_dataset.py │ ├── crop_dataset.py │ ├── generate_traversability_labels_freiburg_forest.py │ ├── preprocess_RUGD.py │ ├── preprocess_cocostuff.py │ ├── preprocess_freiburg_forest.py │ └── preprocessing_utils.py ├── demo_segmentation.py ├── download_stego_datasets.py ├── download_stego_models.py ├── eval_clusters_wvn.py ├── eval_segmentation.py ├── plot.py ├── precompute_knns.py └── train.py ├── setup.py └── stego ├── __init__.py ├── backbones ├── __init__.py ├── backbone.py └── dino │ ├── __init__.py │ ├── utils.py │ └── vision_transformer.py ├── cfg └── model_config.yaml ├── data.py ├── modules.py ├── stego.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore=E402,E501,W503,C901,E731,W605 3 | extend-ignore = E203 4 | max-line-length = 120 5 | max-complexity = 18 6 | exclude=_*,.vscode,.git -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/* 2 | saved_models/* 3 | outputs/* 4 | data/* 5 | src/.env 6 | results/predictions/* 7 | src/iarpa 8 | **.egg-info 9 | **.pyc -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # .pre-commit-config.yaml 2 | # From https://sbarnea.com/lint/black/ 3 | --- 4 | repos: 5 | - repo: https://github.com/python/black.git 6 | rev: 22.12.0 7 | hooks: 8 | - id: black 9 | language_version: python3 10 | - repo: https://github.com/pycqa/flake8 11 | rev: 3.7.9 12 | hooks: 13 | - id: flake8 14 | additional_dependencies: 15 | - flake8-black>=0.1.1 16 | language_version: python3 17 | - repo: local 18 | hooks: 19 | - id: jupyter-nb-clear-output 20 | name: jupyter-nb-clear-output 21 | files: \.ipynb$ 22 | stages: [commit] 23 | language: system 24 | entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Mark Hamilton. All rights reserved. 4 | Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. All rights reserved. 
5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a 7 | copy of this software and associated documentation files (the 8 | "Software"), to deal in the Software without restriction, including 9 | without limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included 15 | in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Semantic Segmentation for Wild Visual Navigation 2 | 3 | This is the implementation of the project "Semantic Understanding of Outdoor Environments for Navigation" completed at the Robotic Systems Lab at ETH Zurich in the Spring Semester 2023. 4 | 5 | The goal of the project was to investigate how the recent unsupervised semantic segmentation model STEGO ([Hamilton et al., 2022](https://arxiv.org/pdf/2203.08414.pdf)) could be used in an outdoor navigation pipeline for a ground mobile robot, with the main focus on the context of the Wild Visual Navigation system ([Frey & Mattamala et al.](https://sites.google.com/leggedrobotics.com/wild-visual-navigation)). 6 | 7 | 8 | This package is built on a refactored version of [STEGO: Unsupervised Semantic Segmentation by Distilling Feature Correspondences](https://github.com/mhamilton723/STEGO) by Hamilton et al. 9 | 10 | ![image](doc/fig1.png) 11 | _SLIC (WVN's segmentation method), standard STEGO, and STEGO with per-image feature clustering segmenting natural scenes._ 12 | 13 | ## Contents 14 | * [Setup](#setup) 15 | * [Installation](#installation) 16 | * [Dataset Download](#download-datasets) 17 | * [Data Preprocessing](#preprocess-datasets) 18 | * [Model Download](#download-and-convert-stego-models) 19 | * [KNN Preprocessing](#precompute-knns) 20 | * [Demo Segmentation](#run-demo-segmentation) 21 | * [Evaluate Segmentation](#evaluate-segmentation) 22 | * [Train Segmentation](#train-segmentation) 23 | * [Evaluate for WVN](#evaluate-segmentation-for-wvn) 24 | * [Generate Plots](#generate-plots) 25 | * [License](#license) 26 | 27 | 28 | 29 | ## Setup 30 | 31 | ### Installation 32 | 33 | Clone the repository: 34 | ``` 35 | git clone https://github.com/leggedrobotics/self_supervised_segmentation.git 36 | cd self_supervised_segmentation 37 | ``` 38 | Install the environment: 39 | ``` 40 | conda env create -f environment.yml 41 | conda activate stego 42 | pip install -e . 43 | ``` 44 | 45 | ### Download datasets 46 | 47 | Download general datasets used by Hamilton et al.: 48 | ``` 49 | python scripts/download_stego_datasets.py 50 | ``` 51 | **_NOTE:_** `wget`, which is used in the download scripts might not always work well with these large downloads. 
In case the download fails, try downloading the selected datasets with [azcopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10). For example, to download the cocostuff dataset: 52 | 53 | ``` 54 | azcopy copy https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/cocostuff.zip ./cocostuff.zip 55 | ``` 56 | 57 | In the case of the cocostuff dataset, Hamilton et al. use subsets of training and validation samples for experiments, which were also used in this project. Lists of samples can be obtained by downloading the dataset from the link above. Then, a dataset curated according to a selected list can be generated with `scripts/create_curated_dataset.py`. 58 | 59 | Download datasets with natural scenes: 60 | ``` 61 | # Download RUGD 62 | wget http://rugd.vision/data/RUGD_frames-with-annotations.zip 63 | wget http://rugd.vision/data/RUGD_annotations.zip 64 | unzip RUGD_frames-with-annotations.zip -d RUGD 65 | unzip RUGD_annotations.zip -d RUGD 66 | rm RUGD_annotations.zip RUGD_frames-with-annotations.zip 67 | 68 | # Download Freiburg Forest 69 | wget http://deepscene.cs.uni-freiburg.de/static/datasets/download_freiburg_forest_annotated.sh 70 | bash download_freiburg_forest_annotated.sh 71 | tar -xzf freiburg_forest_annotated.tar.gz 72 | rm freiburg_forest_annotated.tar.gz* 73 | ``` 74 | 75 | ### Preprocess datasets 76 | 77 | To facilitate using various datasets with the package, preprocessing scripts have been added to `scripts/data_preprocessing`. Before running, adjust paths in each preprocessing script. 78 | 79 | Cocostuff preprocessing: 80 | ``` 81 | # Preprocess full Cocostuff 82 | python scripts/data_preprocessing/preprocess_cocostuff.py 83 | 84 | # Create the curated dataset 85 | python scripts/data_preprocessing/create_curated_dataset.py 86 | 87 | # Crop the dataset (only for training) 88 | python scripts/data_preprocessing/crop_dataset.py 89 | ``` 90 | 91 | RUGD preprocessing: 92 | ``` 93 | # Preprocess RUGD 94 | python scripts/data_preprocessing/preprocess_RUGD.py 95 | 96 | # Crop the dataset (only for training) 97 | python scripts/data_preprocessing/crop_dataset.py 98 | ``` 99 | 100 | Freiburg Forest preprocessing: 101 | ``` 102 | # Preprocess Freiburg Forest 103 | python scripts/data_preprocessing/preprocess_freiburg_forest.py 104 | 105 | # Crop the dataset (only for training) 106 | python scripts/data_preprocessing/crop_dataset.py 107 | ``` 108 | 109 | To use custom data with this package preprocess it to have the following structure: 110 | ``` 111 | YOUR_DATASET 112 | |-- imgs 113 | |-- train 114 | |-- val 115 | |-- labels 116 | |-- train 117 | |-- val 118 | ``` 119 | With RGB images in the `imgs` directory, and (optionally) annotations in the `labels` directory. 120 | 121 | If the `labels` directory is provided it should contain a label for all images in `imgs`, with each label named the same as its corresponding image (excluding file extension). 122 | Annotations should be provided as single-channel masks of the same size as their corresponding images. 123 | 124 | ### Download and convert STEGO models 125 | 126 | Download STEGO model checkpoints: 127 | 128 | ``` 129 | python scripts/download_stego_models.py 130 | ``` 131 | 132 | Convert selected checkpoints to the model structure used by this package. 
133 | Set input and output paths in `scripts/cfg/convert_checkpoint_config.yaml` and run: 134 | ``` 135 | python scripts/convert_original_stego_checkpoint.py 136 | ``` 137 | 138 | ### Precompute KNNs 139 | 140 | To use a preprocessed dataset with a selected model and at a selected resolution, the `precompute_knns.py` script needs to be run with the selected parameters and model. 141 | This will create the nearest neighbors file in a separate subdirectory `nns` of the selected dataset. 142 | Adjust the parameters in `scripts/cfg/knn_config.yaml` and run: 143 | ``` 144 | python scripts/precompute_knns.py 145 | ``` 146 | 147 | 148 | ## Run demo segmentation 149 | To generate segmentation predictions for a selected folder of images: 150 | - Adjust input and output paths in `scripts/cfg/demo_config.yaml` 151 | - Run: 152 | ``` 153 | python scripts/demo_segmentation.py 154 | ``` 155 | This will generate visualizations of unsupervised segmentations in `output_dir/experiment_name/cluster` and visualizations of linear probe segmentations in `output_dir/experiment_name/linear`. 156 | 157 | 158 | ## Evaluate segmentation 159 | To evaluate STEGO on a selected dataset with unsupervised metrics: 160 | - Adjust paths and parameters in `scripts/cfg/eval_config.yaml` 161 | - Run: 162 | ``` 163 | python scripts/eval_segmentation.py 164 | ``` 165 | The script will calculate and print the results of the evaluation on the given data. 166 | 167 | ## Train segmentation 168 | After performing the preprocessing steps outlined in [Setup](#setup), you can train STEGO on the selected data. 169 | 170 | Before training, select the backbone and adjust the parameters of the model and the training. 171 | 172 | ### STEGO's backbone 173 | 174 | STEGO was built based on DINO ViT, but it can be used with any Vision Transformer. 175 | All available backbones can be found in `stego/backbones/backbone.py`. 176 | To add a new backbone, add all code necessary for the backbone to the `stego/backbones` folder and modify `stego/backbone/backbones.py`: 177 | - Add an implementation of the `Backbone` class for your backbone, 178 | - Add your implementation to the `get_backbone` function with the desired name. 179 | 180 | ### Parameters 181 | 182 | The parameters of STEGO are specified in `stego/cfg/model_config.yaml`. In this file you can: 183 | - select the backbone, 184 | - specify other model parameters, 185 | - specify training parameters: learning rates and STEGO's correspondence loss parameters. 186 | 187 | Other parameters for training can be set in `scripts/cfg/train_config.yaml`. 188 | 189 | After adjusting the parameters, run the training with: 190 | ``` 191 | python scripts/train.py 192 | ``` 193 | 194 | ### Checkpoints and Logging 195 | 196 | STEGO is implemented with Pytorch Lightning, which handles saving the checkpoints during training, in a directory that can be specified in `scripts/cfg/train_config.yaml`. 197 | 198 | Logging is implemented with Weights & Biases. To use W&B over cloud, login to wandb: 199 | ``` 200 | wandb login 201 | ``` 202 | During training, apart from unsupervised metrics, loss values and other parameters, visualizations of sample segmentations and the learned feature similarity distribution plot are logged. 203 | 204 | 205 | ## Evaluate segmentation for WVN 206 | To run the experiment that compares segmentation methods in the context of an outdoor navigation pipeline: 207 | - Generate binary traversability labels for a selected dataset. 
Currently, only preprocessing for Freiburg Forest is available. However, you can also preprocess different datasets for this experiment with this script provided that you change `TRAVERSABLE_IDS` to IDs of traversable classes in your dataset. Run: 208 | ``` 209 | # Adjust paths in the script before running 210 | python scripts/data_preprocessing/generate_traversability_labels_freiburg_forest.py 211 | ``` 212 | - Adjust parameters in `scripts/cfg/eval_clusters_wvn.yaml` 213 | - Run: 214 | ``` 215 | python scripts/eval_clusters_wvn.py 216 | ``` 217 | The script will calculate and print the results of the evaluation on the given data, and save the selected visualizations. 218 | 219 | ## Generate plots 220 | The `scripts/plot.py` script enables generation of precision-recall curves showing the performance of features in predicting label co-occurrence. 221 | It also provides an interactive plot visualizing feature similarities in selected images. 222 | 223 | To generate the plots: 224 | - Adjust paths and parameters in `scripts/cfg/plot_config.yaml` 225 | - Run: 226 | ``` 227 | python scripts/plot.py 228 | ``` 229 | 230 | ## License 231 | ``` 232 | Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 233 | All rights reserved. Licensed under the MIT license. 234 | ``` 235 | 236 | This project is based on previous work by Mark Hamilton. 237 | ``` 238 | Copyright (c) Mark Hamilton. 239 | All rights reserved. Licensed under the MIT license. 240 | ``` 241 | 242 | Files in `stego/backbones/dino` are licensed under the Apache 2.0 license by Facebook, Inc. and its affiliates. See the file headers for details. -------------------------------------------------------------------------------- /doc/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/doc/fig1.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - anaconda 3 | - pytorch 4 | - conda-forge 5 | - menpo 6 | - nvidia 7 | dependencies: 8 | - python=3.6.9 9 | - pip>=21.0,<22 10 | - pytorch==1.7.1 11 | - torchvision>=0.8.2 12 | - torchaudio>=0.7.2 13 | - pytorch-cuda=11.6 14 | - nvidia-apex==0.1.0 15 | - scikit-learn 16 | - scikit-image 17 | - opencv 18 | - cupy 19 | - pytorch-lightning 20 | - pip: 21 | - matplotlib>=3.3,<3.4 22 | - psutil>=5.8,<5.9 23 | - tqdm>=4.59,<4.60 24 | - pandas>=1.1,<1.2 25 | - scipy>=1.5,<1.6 26 | - numpy>=1.10,<1.20 27 | - tensorboard==2.4.0 28 | - future==0.17.1 29 | - --find-links https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html 30 | - kornia==0.6.8 31 | - hydra-core 32 | - wget 33 | - seaborn 34 | - easydict 35 | - torchpq 36 | - pydensecrf 37 | - setuptools==59.5.0 38 | - pyDeprecate==0.3.1 39 | - fast_slic 40 | - pytictoc 41 | - wandb==0.14.2 42 | name: stego 43 | -------------------------------------------------------------------------------- /models/download_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (c) Mark Hamilton. All rights reserved. 4 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 5 | # All rights reserved. Licensed under the MIT license. 6 | # See LICENSE file in the project root for details. 
7 | # 8 | # This scripts gets the pretrained weights to run STEGO 9 | # 10 | 11 | # From https://stackoverflow.com/a/246128 12 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 13 | cd $SCRIPT_DIR 14 | 15 | gdown https://drive.google.com/uc?id=1t6dS_9LlN9meN2yoA6iSZEaLlVMEJ4de 16 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 -------------------------------------------------------------------------------- /results/correspondence/attention_interp_2_img.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/correspondence/attention_interp_2_img.mp4 -------------------------------------------------------------------------------- /results/correspondence/attention_interp_2_img_2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/correspondence/attention_interp_2_img_2.mp4 -------------------------------------------------------------------------------- /results/correspondence/attention_interp_2_img_moto.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/correspondence/attention_interp_2_img_moto.mp4 -------------------------------------------------------------------------------- /results/correspondence/correspondence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/correspondence/correspondence.png -------------------------------------------------------------------------------- /results/figures/cityscapes_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/figures/cityscapes_results.jpg -------------------------------------------------------------------------------- /results/figures/cocostuff27_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/figures/cocostuff27_results.jpg -------------------------------------------------------------------------------- /results/figures/correspondence.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/figures/correspondence.gif -------------------------------------------------------------------------------- /results/figures/dog_man_correspondence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/figures/dog_man_correspondence.png 
-------------------------------------------------------------------------------- /results/figures/img_correspondence.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/self_supervised_segmentation/7e9c65a14dc38cb7f506dfb7f469ddde133d01f9/results/figures/img_correspondence.jpg -------------------------------------------------------------------------------- /scripts/cfg/convert_checkpoint_config.yaml: -------------------------------------------------------------------------------- 1 | model_path: "/root/catkin_ws/src/self_supervised_segmentation/scripts/models/cocostuff27_vit_base_5.ckpt" 2 | output_path: "/root/catkin_ws/src/self_supervised_segmentation/scripts/models/stego_cocostuff27_vit_base_5_cluster_linear_fine_tuning.ckpt" 3 | 4 | hydra: 5 | run: 6 | dir: "." 7 | output_subdir: ~ -------------------------------------------------------------------------------- /scripts/cfg/demo_config.yaml: -------------------------------------------------------------------------------- 1 | image_dir: "/cluster/home/plibera/images" 2 | output_root: "/cluster/home/plibera/outputs/preds" 3 | model_path: "/root/catkin_ws/src/self_supervised_segmentation/models/stego_cocostuff27_vit_base_5_cluster_linear_fine_tuning.ckpt" 4 | # model_path: "/root/catkin_ws/src/self_supervised_segmentation/models/FF_training_200.ckpt" 5 | # model_path: "/root/catkin_ws/src/self_supervised_segmentation/models/cityscapes_vit_base_1.ckpt" 6 | experiment_name: "forest_youtube_val" 7 | resolution: 320 8 | 9 | # Loader params 10 | batch_size: 8 11 | num_workers: 24 12 | 13 | run_crf: True -------------------------------------------------------------------------------- /scripts/cfg/eval_config.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "/scratch/tmp.17616576.plibera" 2 | dataset_name: "cocostuff_curated" 3 | model_path: "/cluster/scratch/plibera/checkpoints/cocostuff_training/epoch=0-step=3599.ckpt" # Path to STEGO checkpoint 4 | output_root: "/cluster/home/plibera/outputs/eval" 5 | experiment_name: "cocostuff_val" 6 | 7 | resolution: 320 8 | batch_size: 8 9 | num_workers: 24 10 | 11 | n_batches: # Optionally, specify a number of batches in the shuffled dataset to calculate metrics on 12 | run_crf: False -------------------------------------------------------------------------------- /scripts/cfg/eval_wvn_config.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "/cluster/scratch/plibera" 2 | dataset_name: "freiburg_forest_preprocessed_trav" 3 | 4 | # Paths tp stego models with given numbers of clusters. The models should have the same backbone and segmentation head weights (only differ in the clustering step) 5 | model_paths: [] 6 | # Numbers of clusters of each corresponding model given in model_paths 7 | stego_n_clusters: [] 8 | # Set to True to run per-image STEGO clustering 9 | cluster_stego_by_image: False 10 | # Numbers of segments of SLIC models 11 | slic_n_clusters: [] 12 | slic_compactness: 10 13 | 14 | output_root: "/cluster/home/plibera/outputs/wvn" 15 | experiment_name: "freiburg_forest_labels" 16 | # Save visualizations of segmentations 17 | save_vis: True 18 | # Save plots of distributions of different metrics (e.g. 
feature variance per segment) 19 | save_plots: False 20 | # Save plots presenting distributions of metrics of several models in a single plot 21 | save_comparison_plots: False 22 | 23 | resolution: 320 24 | num_workers: 1 25 | 26 | n_imgs: # Optionally, specify a number of batches in the shuffled dataset to calculate metrics on 27 | run_crf: True -------------------------------------------------------------------------------- /scripts/cfg/knn_config.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "/cluster/scratch/plibera" 2 | dataset_name: "cocostuff_curated_cropped" 3 | image_sets: 4 | - "train" 5 | - "val" 6 | 7 | resolution: 224 8 | batch_size: 8 9 | num_workers: 24 -------------------------------------------------------------------------------- /scripts/cfg/plot_config.yaml: -------------------------------------------------------------------------------- 1 | # Data and paths 2 | model_path: 3 | cmap: "turbo" 4 | zero_mean: False # Mean normalize cosine similarity values in the displayed heatmap 5 | zero_clamp: True # Clamp cosine similarity heatmap to 0 6 | 7 | 8 | 9 | # Interactive correspondences plot 10 | plot_correspondences_interactive: False 11 | image_a_path: "/cluster/home/plibera/forest.png" 12 | image_b_path: # Can be omitted to generate self-correspondences for the first image 13 | correspondence_output_dir: "/cluster/home/plibera" 14 | display_resolution: 512 15 | 16 | # Augmentations params for image B 17 | brightness_factor: 1.0 # Non-negative, 1.0 for no change 18 | contrast_factor: 1.0 # Non-negative, 1.0 for no change 19 | saturation_factor: 1.0 # Non-negative, 1.0 for no change 20 | hue_factor: 0.0 # [-0.5, 0.5], 0.0 for no change 21 | gaussian_kernel_size: #13 22 | gaussian_sigma: #2.0 23 | 24 | 25 | 26 | # Precision-Recall curves plot 27 | plot_pr: True 28 | plot_backbone_pr: False 29 | plot_stego_pr: False 30 | data_dir: "/scratch/tmp.17524104.plibera" 31 | dataset_name: "RUGD" 32 | # Output path for plots 33 | pr_output_dir: "/cluster/home/plibera/self_supervised_segmentation/results/pr" 34 | # Output path for PR data 35 | pr_output_data_dir: "/cluster/scratch/plibera/results/pr" 36 | # Names of pickled PR data files to additionally display in the plot 37 | additional_pr_curves: ["RUGD_preprocessed_cropped_DINOv2-B-14_224.pkl", "RUGD_preprocessed_cropped_DINO-B-8_224.pkl", "RUGD_preprocessed_cropped_DINO-B-16_224.pkl"] 38 | pr_resolution: 224 39 | batch_size: 8 40 | num_workers: 24 -------------------------------------------------------------------------------- /scripts/cfg/train_config.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "/scratch/tmp.19104891.plibera" 2 | dataset_name: "RUGD_preprocessed_cropped" 3 | 4 | # Output directory to save the trained checkpoints during training 5 | checkpoint_dir: "/cluster/scratch/plibera/checkpoints/RUGD_fine-tuning" 6 | # (Optional) Checkpoint path to initialize the training 7 | model_path: "/cluster/home/plibera/self_supervised_segmentation/models/stego_cocostuff27_vit_base_5_cluster_linear_fine_tuning.ckpt" 8 | 9 | 10 | resolution: 224 11 | num_classes: 24 12 | 13 | num_workers: 24 14 | batch_size: 32 15 | 16 | scalar_log_freq: 1 17 | max_steps: 5000 18 | val_check_interval: 100 19 | 20 | num_neighbors: 5 21 | 22 | # Reset the cluster and linear probes of the loaded model. 
If True, cluster and linear probes will be reset with n_classes and extra_clusters given in this file 23 | reset_clusters: True 24 | extra_clusters: 0 25 | 26 | # WandB 27 | wandb_project: "stego" 28 | wandb_name: "RUGD_fine-tuning" 29 | wandb_log_model: "all" 30 | -------------------------------------------------------------------------------- /scripts/convert_original_stego_checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | ############################################ 9 | # Conversion of original STEGO checkpoints 10 | # 11 | # This script can be used to convert the original STEGO checkpoints, trained by Hamilton et al., to checkpoints that can be used in this package. 12 | # 13 | # Original checkpoints can be downloaded with download_stego_models.py 14 | # 15 | # Before running this script, adjust paths in the config file cfg/convert_checkpoint_config.yaml 16 | # 17 | ############################################ 18 | 19 | 20 | import hydra 21 | from omegaconf import DictConfig, OmegaConf 22 | import pytorch_lightning as pl 23 | from pytorch_lightning import Trainer 24 | from torch.utils.data import Dataset 25 | import torch.multiprocessing 26 | from torch import nn 27 | import os 28 | import omegaconf 29 | import copy 30 | 31 | import stego.backbones.dino.vision_transformer as vits 32 | from stego.utils import UnsupervisedMetrics, prep_args 33 | from stego.modules import ClusterLookup, ContrastiveCorrelationLoss 34 | from stego.stego import Stego 35 | 36 | 37 | class RandomDataset(Dataset): 38 | def __init__(self, length: int, size: tuple): 39 | self.len = length 40 | self.data = torch.randn(length, *size) 41 | 42 | def __getitem__(self, index: int) -> torch.Tensor: 43 | return self.data[index] 44 | 45 | def __len__(self) -> int: 46 | return self.len 47 | 48 | 49 | class DinoFeaturizer(nn.Module): 50 | """ 51 | Class from the original STEGO package, used to load the original checkpoint. 
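    Wraps a frozen DINO ViT backbone and projects its patch features to STEGO's code dimension
    through a linear 1x1-conv head (plus an additional nonlinear head when projection_type is "nonlinear").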
52 | """ 53 | 54 | def __init__(self, dim, cfg): 55 | super().__init__() 56 | self.cfg = cfg 57 | self.dim = dim 58 | patch_size = self.cfg.dino_patch_size 59 | self.patch_size = patch_size 60 | self.feat_type = self.cfg.dino_feat_type 61 | arch = self.cfg.model_type 62 | self.model = vits.__dict__[arch](patch_size=patch_size, num_classes=0) 63 | for p in self.model.parameters(): 64 | p.requires_grad = False 65 | self.model.eval().cuda() 66 | self.dropout = torch.nn.Dropout2d(p=0.1) 67 | 68 | if arch == "vit_small" and patch_size == 16: 69 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 70 | elif arch == "vit_small" and patch_size == 8: 71 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" 72 | elif arch == "vit_base" and patch_size == 16: 73 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 74 | elif arch == "vit_base" and patch_size == 8: 75 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 76 | else: 77 | raise ValueError("Unknown arch and patch size") 78 | 79 | if cfg.pretrained_weights is not None: 80 | state_dict = torch.load(cfg.pretrained_weights, map_location="cpu") 81 | # state_dict = state_dict["teacher"] 82 | # remove `module.` prefix 83 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 84 | # remove `backbone.` prefix induced by multicrop wrapper 85 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 86 | 87 | # state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()} 88 | # state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()} 89 | 90 | msg = self.model.load_state_dict(state_dict, strict=False) 91 | print("Pretrained weights found at {} and loaded with msg: {}".format(cfg.pretrained_weights, msg)) 92 | else: 93 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 94 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 95 | self.model.load_state_dict(state_dict, strict=True) 96 | 97 | if arch == "vit_small": 98 | self.n_feats = 384 99 | else: 100 | self.n_feats = 768 101 | self.cluster1 = self.make_clusterer(self.n_feats) 102 | self.proj_type = cfg.projection_type 103 | if self.proj_type == "nonlinear": 104 | self.cluster2 = self.make_nonlinear_clusterer(self.n_feats) 105 | 106 | def make_clusterer(self, in_channels): 107 | return torch.nn.Sequential(torch.nn.Conv2d(in_channels, self.dim, (1, 1))) # , 108 | 109 | def make_nonlinear_clusterer(self, in_channels): 110 | return torch.nn.Sequential( 111 | torch.nn.Conv2d(in_channels, in_channels, (1, 1)), 112 | torch.nn.ReLU(), 113 | torch.nn.Conv2d(in_channels, self.dim, (1, 1)), 114 | ) 115 | 116 | def forward(self, img, n=1, return_class_feat=False): 117 | self.model.eval() 118 | with torch.no_grad(): 119 | assert img.shape[2] % self.patch_size == 0 120 | assert img.shape[3] % self.patch_size == 0 121 | 122 | # get selected layer activations 123 | feat, attn, qkv = self.model.get_intermediate_feat(img, n=n) 124 | feat, attn, qkv = feat[0], attn[0], qkv[0] 125 | 126 | feat_h = img.shape[2] // self.patch_size 127 | feat_w = img.shape[3] // self.patch_size 128 | 129 | if self.feat_type == "feat": 130 | image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2) 131 | elif self.feat_type == "KK": 132 | image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1) 133 | B, H, I, J, D = image_k.shape 134 | image_feat 
= image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J) 135 | else: 136 | raise ValueError("Unknown feat type:{}".format(self.feat_type)) 137 | 138 | if return_class_feat: 139 | return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2) 140 | 141 | if self.proj_type is not None: 142 | code = self.cluster1(self.dropout(image_feat)) 143 | if self.proj_type == "nonlinear": 144 | code += self.cluster2(self.dropout(image_feat)) 145 | else: 146 | code = image_feat 147 | 148 | if self.cfg.dropout: 149 | return self.dropout(image_feat), code 150 | else: 151 | return image_feat, code 152 | 153 | 154 | class ContrastiveCRFLoss(nn.Module): 155 | """ 156 | Class from the original STEGO package, used to load the original checkpoint. 157 | """ 158 | 159 | def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift): 160 | super(ContrastiveCRFLoss, self).__init__() 161 | self.alpha = alpha 162 | self.beta = beta 163 | self.gamma = gamma 164 | self.w1 = w1 165 | self.w2 = w2 166 | self.n_samples = n_samples 167 | self.shift = shift 168 | 169 | def forward(self, guidance, clusters): 170 | device = clusters.device 171 | assert guidance.shape[0] == clusters.shape[0] 172 | assert guidance.shape[2:] == clusters.shape[2:] 173 | h = guidance.shape[2] 174 | w = guidance.shape[3] 175 | 176 | coords = torch.cat( 177 | [ 178 | torch.randint(0, h, size=[1, self.n_samples], device=device), 179 | torch.randint(0, w, size=[1, self.n_samples], device=device), 180 | ], 181 | 0, 182 | ) 183 | 184 | selected_guidance = guidance[:, :, coords[0, :], coords[1, :]] 185 | coord_diff = (coords.unsqueeze(-1) - coords.unsqueeze(1)).square().sum(0).unsqueeze(0) 186 | guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(2)).square().sum(1) 187 | 188 | sim_kernel = ( 189 | self.w1 * torch.exp(-coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) 190 | + self.w2 * torch.exp(-coord_diff / (2 * self.gamma)) 191 | - self.shift 192 | ) 193 | 194 | selected_clusters = clusters[:, :, coords[0, :], coords[1, :]] 195 | cluster_sims = torch.einsum("nka,nkb->nab", selected_clusters, selected_clusters) 196 | return -(cluster_sims * sim_kernel) 197 | 198 | 199 | class LitUnsupervisedSegmenter(pl.LightningModule): 200 | """ 201 | Class from the original STEGO package, used to load the original checkpoint. 
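    During checkpoint conversion only its cluster probe, linear probe, and the featurizer's
    projection heads (cluster1/cluster2) are copied into the new Stego model.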
202 | """ 203 | 204 | def __init__(self, n_classes, cfg): 205 | super().__init__() 206 | self.cfg = cfg 207 | self.n_classes = n_classes 208 | dim = cfg.dim 209 | self.net = DinoFeaturizer(dim, cfg) 210 | self.train_cluster_probe = ClusterLookup(dim, n_classes) 211 | self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters) 212 | self.linear_probe = nn.Conv2d(dim, n_classes, (1, 1)) 213 | self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1)) 214 | self.cluster_metrics = UnsupervisedMetrics("test/cluster/", n_classes, cfg.extra_clusters, True) 215 | self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False) 216 | self.test_cluster_metrics = UnsupervisedMetrics("final/cluster/", n_classes, cfg.extra_clusters, True) 217 | self.test_linear_metrics = UnsupervisedMetrics("final/linear/", n_classes, 0, False) 218 | self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss() 219 | self.crf_loss_fn = ContrastiveCRFLoss( 220 | cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift 221 | ) 222 | self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg) 223 | for p in self.contrastive_corr_loss_fn.parameters(): 224 | p.requires_grad = False 225 | self.automatic_optimization = False 226 | self.label_cmap = None 227 | self.val_steps = 0 228 | self.save_hyperparameters() 229 | 230 | def forward(self, x): 231 | # in lightning, forward defines the prediction/inference actions 232 | return self.net(x)[1] 233 | 234 | 235 | @hydra.main(config_path="cfg", config_name="convert_checkpoint_config.yaml") 236 | def my_app(cfg: DictConfig) -> None: 237 | model = LitUnsupervisedSegmenter.load_from_checkpoint(cfg.model_path) 238 | print(OmegaConf.to_yaml(model.cfg)) 239 | 240 | with open(os.path.join(os.path.dirname(__file__), "../stego/cfg/model_config.yaml"), "r") as file: 241 | model_cfg = omegaconf.OmegaConf.load(file) 242 | model_cfg.backbone = model.cfg.arch 243 | model_cfg.backbone_type = model.cfg.model_type 244 | model_cfg.patch_size = model.cfg.dino_patch_size 245 | model_cfg.dim = model.cfg.dim 246 | model_cfg.extra_clusters = model.cfg.extra_clusters 247 | n_classes = model.n_classes 248 | stego = Stego(n_classes, model_cfg) 249 | 250 | with torch.no_grad(): 251 | stego.cluster_probe = copy.deepcopy(model.cluster_probe) 252 | stego.linear_probe = copy.deepcopy(model.linear_probe) 253 | stego.segmentation_head.linear = copy.deepcopy(model.net.cluster1) 254 | stego.segmentation_head.nonlinear = copy.deepcopy(model.net.cluster2) 255 | 256 | trainer = Trainer(enable_checkpointing=False, max_steps=0) 257 | trainer.predict(stego, RandomDataset(1, (1, 3, 224, 224))) 258 | trainer.save_checkpoint(cfg.output_path) 259 | 260 | 261 | if __name__ == "__main__": 262 | prep_args() 263 | my_app() 264 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/create_curated_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | ############################################ 8 | # Curated dataset generation script 9 | # 10 | # This script uses a list of samples in the input dataset to create a separate dataset containing only those samples. 11 | # The list of samples should be saved in a text file, with each sample name in a separate line. 
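# For example, a list file could contain lines such as (hypothetical sample names):
#   sample_0001
#   sample_0002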
12 | # Sample names need to correspond to names of image and label files in the dataset (name of an image, without file extensions). 13 | # The samples in the new dataset are created as links to samples in the input dataset to save memory. 14 | # Hence, the input dataset to this script should already be preprocessed. 15 | # 16 | # Expected input structure: 17 | # DATA_DIR 18 | # |-- INPUT_NAME 19 | # |-- imgs 20 | # |-- train 21 | # |-- val 22 | # |-- labels 23 | # |-- train 24 | # |-- val 25 | # 26 | ############################################ 27 | 28 | 29 | import os 30 | from os.path import join 31 | import numpy as np 32 | from tqdm import tqdm 33 | 34 | from scripts.data_preprocessing.preprocessing_utils import * 35 | 36 | 37 | DATA_DIR = "/data" 38 | INPUT_NAME = "cocostuff_preprocessed" 39 | OUTPUT_NAME = "cocostuff_curated" 40 | 41 | TRAIN_SAMPLES_FILE = "/data/cocostuff/curated/train2017/Coco164kFull_Stuff_Coarse.txt" 42 | VAL_SAMPLES_FILE = "/data/cocostuff/curated/val2017/Coco164kFull_Stuff_Coarse.txt" 43 | 44 | 45 | def preprocess_samples(input_dir, output_dir, subset, input_subset, sample_file): 46 | print("Processing subset {}".format(subset)) 47 | image_files = [] 48 | label_files = [] 49 | sample_names = [] 50 | with open(sample_file, "r") as f: 51 | img_ids = [fn.rstrip() for fn in f.readlines()] 52 | for img_id in img_ids: 53 | image_files.append(join(input_dir, "imgs", input_subset, img_id + ".jpg")) 54 | label_files.append(join(input_dir, "labels", input_subset, img_id + ".png")) 55 | sample_names.append(img_id) 56 | for i, sample_name in tqdm(enumerate(sample_names)): 57 | img_path = image_files[i] 58 | label_path = label_files[i] 59 | os.link(img_path, join(output_dir, "imgs", subset, sample_name + ".jpg")) 60 | os.link(label_path, join(output_dir, "labels", subset, sample_name + ".png")) 61 | 62 | 63 | def main(): 64 | input_dir = join(DATA_DIR, INPUT_NAME) 65 | output_dir = join(DATA_DIR, OUTPUT_NAME) 66 | create_dataset_structure(output_dir) 67 | preprocess_samples(input_dir, output_dir, "train", "train", TRAIN_SAMPLES_FILE) 68 | preprocess_samples(input_dir, output_dir, "val", "val", VAL_SAMPLES_FILE) 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/crop_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | ############################################ 8 | # Cropped dataset generation script 9 | # 10 | # This script five-crops images from the input dataset and creates a new dataset with the generated crops. 11 | # In STEGO, cropping (five-crop) should be performed before KNN generation. 12 | # Hence, the images should first be cropped with this script, then KNNs should be generated for the new dataset. 
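# Each image and its label are five-cropped (four corner crops plus a center crop, via torchvision's five_crop),
# and the resulting crops are saved as <sample_name>_0 ... <sample_name>_4 in the output dataset.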
13 | # 14 | # Expected input structure: 15 | # DATA_DIR 16 | # |-- INPUT_NAME 17 | # |-- imgs 18 | # |-- train 19 | # |-- val 20 | # |-- labels 21 | # |-- train 22 | # |-- val 23 | # 24 | ############################################ 25 | 26 | 27 | import os 28 | from os.path import join 29 | from tqdm import tqdm 30 | from torchvision.transforms.functional import five_crop 31 | 32 | from scripts.data_preprocessing.preprocessing_utils import create_dataset_structure 33 | from PIL import Image 34 | 35 | 36 | DATA_DIR = "/data" 37 | INPUT_NAME = "cocostuff_curated" 38 | OUTPUT_NAME = "cocostuff_curated_cropped" 39 | 40 | # An image of size HxW will be five-cropped with target size of (CROP_RATIO*H)x(CROP_RATIO*W) 41 | CROP_RATIO = 0.5 42 | # File extension of images (in the imgs directory) 43 | IMAGE_EXT = ".jpg" 44 | 45 | 46 | def save_five_crop(input_name, output_dir, sample_name, file_ext): 47 | output_names = [join(output_dir, sample_name + "_" + str(i) + file_ext) for i in range(5)] 48 | all_exist = True 49 | for name in output_names: 50 | if not os.path.isfile(name): 51 | all_exist = False 52 | if all_exist: 53 | return 54 | image = Image.open(input_name) 55 | crops = five_crop(image, (CROP_RATIO * image.height, CROP_RATIO * image.width)) 56 | for i, crop in enumerate(crops): 57 | name = output_names[i] 58 | if not os.path.isfile(name): 59 | crop.save(name) 60 | 61 | 62 | def preprocess_samples(input_dir, output_dir, subset, input_subset): 63 | print("Processing subset {}".format(subset)) 64 | label_names = os.listdir(join(input_dir, "labels", input_subset)) 65 | for label_name in tqdm(label_names): 66 | sample_name = label_name.split(".")[0] 67 | img_path = join(input_dir, "imgs", input_subset, sample_name + IMAGE_EXT) 68 | label_path = join(input_dir, "labels", input_subset, label_name) 69 | save_five_crop(img_path, join(output_dir, "imgs", subset), sample_name, IMAGE_EXT) 70 | save_five_crop(label_path, join(output_dir, "labels", subset), sample_name, ".png") 71 | 72 | 73 | def main(): 74 | input_dir = join(DATA_DIR, INPUT_NAME) 75 | output_dir = join(DATA_DIR, OUTPUT_NAME) 76 | create_dataset_structure(output_dir) 77 | preprocess_samples(input_dir, output_dir, "train", "train") 78 | preprocess_samples(input_dir, output_dir, "val", "val") 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/generate_traversability_labels_freiburg_forest.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 
5 | # 6 | # 7 | ############################################ 8 | # Freiburg Forest traversability labels 9 | # 10 | # This script takes the preprocessed Freiburg Forest dataset and generates a copy with binary traversability labels 11 | # 12 | # 13 | # Expected input structure after unpacking in DATA_DIR: 14 | # DATA_DIR 15 | # |-- OUTPUT_NAME 16 | # |-- imgs 17 | # |-- train 18 | # |-- val 19 | # |-- labels 20 | # |-- train 21 | # |-- val 22 | # 23 | # Output structure after processing: 24 | # DATA_DIR 25 | # |-- OUTPUT_NAME 26 | # |-- imgs 27 | # |-- train 28 | # |-- val 29 | # |-- labels 30 | # |-- train 31 | # |-- val 32 | ############################################ 33 | 34 | 35 | import os 36 | from os.path import join 37 | import numpy as np 38 | from tqdm import tqdm 39 | from PIL import Image 40 | 41 | from scripts.data_preprocessing.preprocessing_utils import create_dataset_structure, preprocess_and_copy_image 42 | 43 | 44 | DATA_DIR = "/data" 45 | INPUT_NAME = "freiburg_forest_preprocessed" 46 | OUTPUT_NAME = "freiburg_forest_preprocessed_trav" 47 | 48 | 49 | FF_CMAP = np.array( 50 | [ 51 | (0, 0, 0), # Object 52 | (170, 170, 170), # Road 53 | (0, 255, 0), # Grass 54 | (102, 102, 51), # Vegetation 55 | (0, 120, 255), # Sky 56 | ( 57 | 0, 58 | 60, 59 | 0, 60 | ), # Tree (separate color present in the dataset, but assigned to class Vegetation in the dataset's official readme) 61 | ] 62 | ) 63 | 64 | TRAVERSABLE_IDS = [1, 2] # Road and Grass 65 | 66 | 67 | def preprocess_and_save_trav_label(input_name, output_name, traversable_ids): 68 | if os.path.isfile(output_name): 69 | return 70 | image = Image.open(input_name) 71 | img = np.array(image) 72 | label = np.zeros(img.shape) 73 | for id in traversable_ids: 74 | label = np.where(img == id, 1, label) 75 | image = Image.fromarray(label.astype(np.uint8)) 76 | image.save(output_name) 77 | 78 | 79 | def preprocess_samples(input_dir, output_dir, subset, input_subset): 80 | print("Processing subset {}".format(subset)) 81 | label_names = os.listdir(join(input_dir, "labels", input_subset)) 82 | for label_name in tqdm(label_names): 83 | sample_name = os.path.splitext(label_name)[0] 84 | img_path = join(input_dir, "imgs", input_subset, sample_name + ".jpg") 85 | label_path = join(input_dir, "labels", input_subset, label_name) 86 | preprocess_and_copy_image(img_path, join(output_dir, "imgs", subset, sample_name + ".jpg"), False) 87 | preprocess_and_save_trav_label( 88 | label_path, 89 | join(output_dir, "labels", subset, sample_name + ".png"), 90 | TRAVERSABLE_IDS, 91 | ) 92 | 93 | 94 | def main(): 95 | input_dir = join(DATA_DIR, INPUT_NAME) 96 | output_dir = join(DATA_DIR, OUTPUT_NAME) 97 | create_dataset_structure(output_dir) 98 | preprocess_samples(input_dir, output_dir, "train", "train") 99 | preprocess_samples(input_dir, output_dir, "val", "val") 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/preprocess_RUGD.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | ############################################ 8 | # RUGD preprocessing script 9 | # 10 | # RUGD: http://rugd.vision/ 11 | # Wigness, M., Eum, S., Rogers, J. G., Han, D., & Kwon, H. (2019). 
12 | # A RUGD Dataset for Autonomous Navigation and Visual Perception in Unstructured Outdoor Environments. 13 | # International Conference on Intelligent Robots and Systems (IROS). 14 | # 15 | # 16 | # 17 | # Download RUGD: 18 | # wget http://rugd.vision/data/RUGD_frames-with-annotations.zip 19 | # wget http://rugd.vision/data/RUGD_annotations.zip 20 | # 21 | # unzip RUGD_frames-with-annotations.zip -d RUGD 22 | # unzip RUGD_annotations.zip -d RUGD 23 | # rm RUGD_annotations.zip RUGD_frames-with-annotations.zip 24 | # 25 | # 26 | # Expected input structure: 27 | # DATA_DIR 28 | # |-- INPUT_NAME 29 | # |-- RUGD_annotations 30 | # |-- RUGD_frames-with-annotations 31 | # 32 | # Output structure: 33 | # DATA_DIR 34 | # |-- OUTPUT_NAME 35 | # |-- imgs 36 | # |-- train 37 | # |-- val 38 | # |-- labels 39 | # |-- train 40 | # |-- val 41 | ############################################ 42 | 43 | 44 | import os 45 | from os.path import join 46 | import numpy as np 47 | from tqdm import tqdm 48 | 49 | from scripts.data_preprocessing.preprocessing_utils import * 50 | 51 | 52 | DATA_DIR = "/data" 53 | INPUT_NAME = "RUGD" 54 | OUTPUT_NAME = "RUGD_preprocessed" 55 | 56 | # Split according to the paper 57 | TRAIN_SAMPLES = [ 58 | "park-2", 59 | "trail", 60 | "trail-3", 61 | "trail-4", 62 | "trail-6", 63 | "trail-9", 64 | "trail-10", 65 | "trail-11", 66 | "trail-12", 67 | "trail-14", 68 | "trail-15", 69 | "village", 70 | ] 71 | VAL_SAMPLES = ["park-8", "trail-5"] 72 | TEST_SAMPLES = ["creek", "park-1", "trail-7", "trail-13"] 73 | 74 | 75 | RUGD_CMAP = np.array( 76 | [ 77 | (0, 0, 0), # void 78 | (108, 64, 20), # dirt 79 | (255, 229, 204), # sand 80 | (0, 102, 0), # grass 81 | (0, 255, 0), # tree 82 | (0, 153, 153), # pole 83 | (0, 128, 255), # water 84 | (0, 0, 255), # sky 85 | (255, 255, 0), # vehicle 86 | (255, 0, 127), # container/generic-object 87 | (64, 64, 64), # asphalt 88 | (255, 128, 0), # gravel 89 | (255, 0, 0), # building 90 | (153, 76, 0), # mulch 91 | (102, 102, 0), # rock-bed 92 | (102, 0, 0), # log 93 | (0, 255, 128), # bicycle 94 | (204, 153, 255), # person 95 | (102, 0, 204), # fence 96 | (255, 153, 204), # bush 97 | (0, 102, 102), # sign 98 | (153, 204, 255), # rock 99 | (102, 255, 255), # bridge 100 | (101, 101, 11), # concrete 101 | (114, 85, 47), 102 | ] 103 | ) # picnic-table 104 | 105 | 106 | def preprocess_samples(input_dir, output_dir, subset, samples): 107 | print("Processing subset {}".format(subset)) 108 | for sample in tqdm(samples): 109 | img_path = join(input_dir, "RUGD_frames-with-annotations", sample) 110 | label_path = join(input_dir, "RUGD_annotations", sample) 111 | img_files = os.listdir(img_path) 112 | label_files = os.listdir(label_path) 113 | for img_file in img_files: 114 | if img_file.endswith((".png")): 115 | preprocess_and_copy_image( 116 | join(img_path, img_file), 117 | join(output_dir, "imgs", subset, img_file), 118 | False, 119 | ) 120 | print("Processing labels of sample {}".format(sample)) 121 | for label_file in tqdm(label_files): 122 | if label_file.endswith((".png")): 123 | preprocess_and_copy_image( 124 | join(label_path, label_file), 125 | join(output_dir, "labels", subset, label_file), 126 | True, 127 | True, 128 | RUGD_CMAP, 129 | ) 130 | 131 | 132 | def main(): 133 | input_dir = join(DATA_DIR, INPUT_NAME) 134 | output_dir = join(DATA_DIR, OUTPUT_NAME) 135 | create_dataset_structure(output_dir) 136 | preprocess_samples(input_dir, output_dir, "train", TRAIN_SAMPLES + VAL_SAMPLES) 137 | preprocess_samples(input_dir, output_dir, "val", TEST_SAMPLES) 
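    # The official RUGD train and val scenes are merged into this package's train split above,
    # while the official test scenes are used as the val split.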
138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/preprocess_cocostuff.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | ############################################ 8 | # Cocostuff preprocessing script 9 | # 10 | # https://github.com/nightrome/cocostuff 11 | # COCO-Stuff: Thing and Stuff Classes in Context 12 | # H. Caesar, J. Uijlings, V. Ferrari, 13 | # In Computer Vision and Pattern Recognition (CVPR), 2018. 14 | # 15 | # This preprocessing script converts fine labels of the Cocostuff dataset to 27 coarse labels of the dataset. 16 | # Additionally, it converts grayscale, single-channel images to 3-channel images 17 | # 18 | # Expected input structure: 19 | # DATA_DIR 20 | # |-- INPUT_NAME 21 | # |-- images 22 | # |-- train2017 23 | # |-- val2017 24 | # |-- annotations 25 | # |-- train2017 26 | # |-- val2017 27 | # 28 | # Output structure after preprocessing: 29 | # DATA_DIR 30 | # |-- OUTPUT_NAME 31 | # |-- imgs 32 | # |-- train 33 | # |-- val 34 | # |-- labels 35 | # |-- train 36 | # |-- val 37 | ############################################ 38 | 39 | 40 | import os 41 | from os.path import join 42 | import numpy as np 43 | from tqdm import tqdm 44 | from PIL import Image 45 | 46 | from scripts.data_preprocessing.preprocessing_utils import create_dataset_structure 47 | 48 | 49 | DATA_DIR = "/data" 50 | INPUT_NAME = "cocostuff" 51 | OUTPUT_NAME = "cocostuff_preprocessed" 52 | 53 | 54 | def cocostuff_to_27_classes(mask): 55 | fine_to_coarse = { 56 | 0: 9, 57 | 1: 11, 58 | 2: 11, 59 | 3: 11, 60 | 4: 11, 61 | 5: 11, 62 | 6: 11, 63 | 7: 11, 64 | 8: 11, 65 | 9: 8, 66 | 10: 8, 67 | 11: 8, 68 | 12: 8, 69 | 13: 8, 70 | 14: 8, 71 | 15: 7, 72 | 16: 7, 73 | 17: 7, 74 | 18: 7, 75 | 19: 7, 76 | 20: 7, 77 | 21: 7, 78 | 22: 7, 79 | 23: 7, 80 | 24: 7, 81 | 25: 6, 82 | 26: 6, 83 | 27: 6, 84 | 28: 6, 85 | 29: 6, 86 | 30: 6, 87 | 31: 6, 88 | 32: 6, 89 | 33: 10, 90 | 34: 10, 91 | 35: 10, 92 | 36: 10, 93 | 37: 10, 94 | 38: 10, 95 | 39: 10, 96 | 40: 10, 97 | 41: 10, 98 | 42: 10, 99 | 43: 5, 100 | 44: 5, 101 | 45: 5, 102 | 46: 5, 103 | 47: 5, 104 | 48: 5, 105 | 49: 5, 106 | 50: 5, 107 | 51: 2, 108 | 52: 2, 109 | 53: 2, 110 | 54: 2, 111 | 55: 2, 112 | 56: 2, 113 | 57: 2, 114 | 58: 2, 115 | 59: 2, 116 | 60: 2, 117 | 61: 3, 118 | 62: 3, 119 | 63: 3, 120 | 64: 3, 121 | 65: 3, 122 | 66: 3, 123 | 67: 3, 124 | 68: 3, 125 | 69: 3, 126 | 70: 3, 127 | 71: 0, 128 | 72: 0, 129 | 73: 0, 130 | 74: 0, 131 | 75: 0, 132 | 76: 0, 133 | 77: 1, 134 | 78: 1, 135 | 79: 1, 136 | 80: 1, 137 | 81: 1, 138 | 82: 1, 139 | 83: 4, 140 | 84: 4, 141 | 85: 4, 142 | 86: 4, 143 | 87: 4, 144 | 88: 4, 145 | 89: 4, 146 | 90: 4, 147 | 91: 17, 148 | 92: 17, 149 | 93: 22, 150 | 94: 20, 151 | 95: 20, 152 | 96: 22, 153 | 97: 15, 154 | 98: 25, 155 | 99: 16, 156 | 100: 13, 157 | 101: 12, 158 | 102: 12, 159 | 103: 17, 160 | 104: 17, 161 | 105: 23, 162 | 106: 15, 163 | 107: 15, 164 | 108: 17, 165 | 109: 15, 166 | 110: 21, 167 | 111: 15, 168 | 112: 25, 169 | 113: 13, 170 | 114: 13, 171 | 115: 13, 172 | 116: 13, 173 | 117: 13, 174 | 118: 22, 175 | 119: 26, 176 | 120: 14, 177 | 121: 14, 178 | 122: 15, 179 | 123: 22, 180 | 124: 21, 181 | 125: 21, 182 | 126: 24, 183 | 127: 20, 184 | 
128: 22, 185 | 129: 15, 186 | 130: 17, 187 | 131: 16, 188 | 132: 15, 189 | 133: 22, 190 | 134: 24, 191 | 135: 21, 192 | 136: 17, 193 | 137: 25, 194 | 138: 16, 195 | 139: 21, 196 | 140: 17, 197 | 141: 22, 198 | 142: 16, 199 | 143: 21, 200 | 144: 21, 201 | 145: 25, 202 | 146: 21, 203 | 147: 26, 204 | 148: 21, 205 | 149: 24, 206 | 150: 20, 207 | 151: 17, 208 | 152: 14, 209 | 153: 21, 210 | 154: 26, 211 | 155: 15, 212 | 156: 23, 213 | 157: 20, 214 | 158: 21, 215 | 159: 24, 216 | 160: 15, 217 | 161: 24, 218 | 162: 22, 219 | 163: 25, 220 | 164: 15, 221 | 165: 20, 222 | 166: 17, 223 | 167: 17, 224 | 168: 22, 225 | 169: 14, 226 | 170: 18, 227 | 171: 18, 228 | 172: 18, 229 | 173: 18, 230 | 174: 18, 231 | 175: 18, 232 | 176: 18, 233 | 177: 26, 234 | 178: 26, 235 | 179: 19, 236 | 180: 19, 237 | 181: 24, 238 | } 239 | new_mask = np.zeros(mask.shape) 240 | for class_id in fine_to_coarse: 241 | new_mask = np.where(mask == class_id, fine_to_coarse[class_id], new_mask) 242 | return new_mask.astype(np.uint8) 243 | 244 | 245 | def preprocess_and_copy_label_cocostuff(input_name, output_name): 246 | if os.path.isfile(output_name): 247 | return 248 | image = Image.open(input_name) 249 | img = np.array(image) 250 | img = cocostuff_to_27_classes(img) 251 | image = Image.fromarray(img) 252 | image.save(output_name) 253 | 254 | 255 | def preprocess_and_copy_image_cocostuff(input_name, output_name): 256 | if os.path.isfile(output_name): 257 | return 258 | image = Image.open(input_name).convert("RGB") 259 | image.save(output_name) 260 | 261 | 262 | def preprocess_samples(input_dir, output_dir, subset, input_subset): 263 | print("Processing subset {}".format(subset)) 264 | label_names = os.listdir(join(input_dir, "annotations", input_subset)) 265 | for label_name in tqdm(label_names): 266 | sample_name = label_name.split(".")[0] 267 | img_path = join(input_dir, "images", input_subset, sample_name + ".jpg") 268 | label_path = join(input_dir, "annotations", input_subset, label_name) 269 | preprocess_and_copy_image_cocostuff(img_path, join(output_dir, "imgs", subset, sample_name + ".jpg")) 270 | preprocess_and_copy_label_cocostuff(label_path, join(output_dir, "labels", subset, sample_name + ".png")) 271 | 272 | 273 | def main(): 274 | input_dir = join(DATA_DIR, INPUT_NAME) 275 | output_dir = join(DATA_DIR, OUTPUT_NAME) 276 | create_dataset_structure(output_dir) 277 | preprocess_samples(input_dir, output_dir, "train", "train2017") 278 | preprocess_samples(input_dir, output_dir, "val", "val2017") 279 | 280 | 281 | if __name__ == "__main__": 282 | main() 283 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/preprocess_freiburg_forest.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | ############################################ 8 | # Freiburg Forest preprocessing script 9 | # 10 | # http://deepscene.cs.uni-freiburg.de/ 11 | # Abhinav Valada, Gabriel L. Oliveira, Thomas Brox, Wolfram Burgard 12 | # Deep Multispectral Semantic Scene Understanding of Forested Environments using Multimodal Fusion 13 | # International Symposium on Experimental Robotics (ISER), Tokyo, Japan, 2016. 
14 | # 15 | # 16 | # Download Friburg Forest: 17 | # wget http://deepscene.cs.uni-freiburg.de/static/datasets/download_freiburg_forest_annotated.sh 18 | # bash download_freiburg_forest_annotated.sh 19 | # tar -xzf freiburg_forest_annotated.tar.gz 20 | # rm freiburg_forest_annotated.tar.gz* 21 | # 22 | # 23 | # 24 | # Expected input structure after unpacking in DATA_DIR: 25 | # DATA_DIR 26 | # |-- INPUT_NAME 27 | # |-- train 28 | # |-- rgb 29 | # |-- GT_color 30 | # |-- test 31 | # |-- rgb 32 | # |-- GT_color 33 | # 34 | # Output structure after preprocessing: 35 | # DATA_DIR 36 | # |-- OUTPUT_NAME 37 | # |-- imgs 38 | # |-- train 39 | # |-- val 40 | # |-- labels 41 | # |-- train 42 | # |-- val 43 | ############################################ 44 | 45 | 46 | import os 47 | from os.path import join 48 | import numpy as np 49 | from tqdm import tqdm 50 | from PIL import Image 51 | 52 | from scripts.data_preprocessing.preprocessing_utils import ( 53 | create_dataset_structure, 54 | convert_rgb_label, 55 | preprocess_and_copy_image, 56 | ) 57 | 58 | 59 | DATA_DIR = "/data" 60 | INPUT_NAME = "freiburg_forest_annotated" 61 | OUTPUT_NAME = "freiburg_forest_preprocessed" 62 | 63 | 64 | FF_CMAP = np.array( 65 | [ 66 | (0, 0, 0), # Object 67 | (170, 170, 170), # Road 68 | (0, 255, 0), # Grass 69 | (102, 102, 51), # Vegetation 70 | (0, 120, 255), # Sky 71 | ( 72 | 0, 73 | 60, 74 | 0, 75 | ), # Tree (separate color present in the dataset, but assigned to class Vegetation in the dataset's official readme) 76 | ] 77 | ) 78 | 79 | 80 | def preprocess_and_copy_label_FF(input_name, output_name, cmap): 81 | if os.path.isfile(output_name): 82 | return 83 | image = Image.open(input_name).convert("RGB") 84 | img = np.array(image) 85 | img = convert_rgb_label(img, cmap) 86 | img[img == 5] = 3 # Class Tree assigned to Vegetation 87 | image = Image.fromarray(img) 88 | image.save(output_name) 89 | 90 | 91 | def preprocess_samples(input_dir, output_dir, subset, input_subset): 92 | print("Processing subset {}".format(subset)) 93 | label_names = os.listdir(join(input_dir, input_subset, "GT_color")) 94 | for label_name in tqdm(label_names): 95 | sample_name = label_name.split("_")[0] 96 | img_path = join(input_dir, input_subset, "rgb", sample_name + "_Clipped.jpg") 97 | label_path = join(input_dir, input_subset, "GT_color", label_name) 98 | preprocess_and_copy_image(img_path, join(output_dir, "imgs", subset, sample_name + ".jpg"), False) 99 | preprocess_and_copy_label_FF( 100 | label_path, 101 | join(output_dir, "labels", subset, sample_name + ".png"), 102 | FF_CMAP, 103 | ) 104 | 105 | 106 | def main(): 107 | input_dir = join(DATA_DIR, INPUT_NAME) 108 | output_dir = join(DATA_DIR, OUTPUT_NAME) 109 | create_dataset_structure(output_dir) 110 | preprocess_samples(input_dir, output_dir, "train", "train") 111 | preprocess_samples(input_dir, output_dir, "val", "test") 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /scripts/data_preprocessing/preprocessing_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 
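Before preprocessing the whole dataset, it can be useful to check the colour-to-index conversion used by `preprocess_freiburg_forest.py` above on a toy example. The following standalone snippet is not part of the repository; it reuses `FF_CMAP` and `convert_rgb_label` together with the Tree-to-Vegetation merge from `preprocess_and_copy_label_FF`, and the import paths assume the package layout of this repository.
```
import numpy as np

from scripts.data_preprocessing.preprocessing_utils import convert_rgb_label
from scripts.data_preprocessing.preprocess_freiburg_forest import FF_CMAP

# 2x2 toy label: Road, Grass / Sky, Tree (colours taken from FF_CMAP).
toy_rgb = np.array(
    [[(170, 170, 170), (0, 255, 0)], [(0, 120, 255), (0, 60, 0)]],
    dtype=np.uint8,
)

toy_idx = convert_rgb_label(toy_rgb.copy(), FF_CMAP)
toy_idx[toy_idx == 5] = 3  # Tree merged into Vegetation, as in preprocess_and_copy_label_FF
print(toy_idx)  # expected: [[1 2] [4 3]] -> Road, Grass / Sky, Vegetation
```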
5 | # 6 | # 7 | import os 8 | from os.path import join 9 | 10 | import numpy as np 11 | from PIL import Image 12 | import shutil 13 | 14 | 15 | def create_dataset_structure(dataset_dir): 16 | os.makedirs(dataset_dir, exist_ok=True) 17 | os.makedirs(join(dataset_dir, "imgs"), exist_ok=True) 18 | os.makedirs(join(dataset_dir, "imgs", "train"), exist_ok=True) 19 | os.makedirs(join(dataset_dir, "imgs", "val"), exist_ok=True) 20 | os.makedirs(join(dataset_dir, "labels"), exist_ok=True) 21 | os.makedirs(join(dataset_dir, "labels", "train"), exist_ok=True) 22 | os.makedirs(join(dataset_dir, "labels", "val"), exist_ok=True) 23 | 24 | 25 | def convert_rgb_label(label, cmap): 26 | for i in range(cmap.shape[0]): 27 | color = cmap[i] 28 | indices = np.all(label == color, axis=2) 29 | label[indices] = i 30 | return np.unique(label, axis=-1).squeeze() 31 | 32 | 33 | def preprocess_and_copy_image(input_name, output_name, is_label=False, rgb_label=False, cmap=None): 34 | if os.path.isfile(output_name): 35 | return 36 | if is_label and rgb_label: 37 | if cmap is None: 38 | raise ValueError("No colormap provided to convert the RGB label") 39 | image = Image.open(input_name).convert("RGB") 40 | img = np.array(image) 41 | img = convert_rgb_label(img, cmap) 42 | image = Image.fromarray(img) 43 | image.save(output_name) 44 | else: 45 | shutil.copyfile(input_name, output_name) 46 | -------------------------------------------------------------------------------- /scripts/demo_segmentation.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 
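The helpers in `preprocessing_utils.py` above are the common building blocks of the dataset scripts: `create_dataset_structure` creates the `imgs`/`labels` output tree and `preprocess_and_copy_image` either copies a file unchanged or converts an RGB-coded label into class indices. A preprocessing script for a new dataset with RGB-coded labels could therefore follow the hypothetical sketch below; the directory layout, file extensions and colormap are placeholders and need to be adapted to the actual dataset.
```
import os
from os.path import join

import numpy as np

from scripts.data_preprocessing.preprocessing_utils import (
    create_dataset_structure,
    preprocess_and_copy_image,
)

DATA_DIR = "/data"
MY_CMAP = np.array([(0, 0, 0), (0, 255, 0), (170, 170, 170)])  # placeholder colours, one row per class


def preprocess_my_dataset():
    input_dir = join(DATA_DIR, "my_dataset")  # hypothetical input: <subset>/rgb and <subset>/labels_rgb
    output_dir = join(DATA_DIR, "my_dataset_preprocessed")
    create_dataset_structure(output_dir)
    for subset in ["train", "val"]:
        for name in os.listdir(join(input_dir, subset, "rgb")):
            stem = name.split(".")[0]
            # Images are copied unchanged.
            preprocess_and_copy_image(
                join(input_dir, subset, "rgb", name),
                join(output_dir, "imgs", subset, name),
                is_label=False,
            )
            # RGB labels are converted to single-channel class-index masks.
            preprocess_and_copy_image(
                join(input_dir, subset, "labels_rgb", stem + ".png"),
                join(output_dir, "labels", subset, stem + ".png"),
                is_label=True,
                rgb_label=True,
                cmap=MY_CMAP,
            )


if __name__ == "__main__":
    preprocess_my_dataset()
```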
6 | # 7 | # 8 | ############################################ 9 | # Segmentation demonstration with STEGO 10 | # 11 | # This script can be used to generate segmentations of given images with the given STEGO checkpoint 12 | # 13 | # Before running the script, adjust the parameters in cfg/demo_config.yaml: 14 | # - image_dir - path to the folder with images (images from its subfolders won't be processed) 15 | # - model_path - path to the STEGO checkpoint 16 | # - output_root - path to the folder to save the segmentation in (segmentations will be saved in a subfolder named after experiment_name) 17 | # 18 | ############################################ 19 | 20 | 21 | import hydra 22 | import torch.multiprocessing 23 | from PIL import Image 24 | from omegaconf import DictConfig 25 | from torch.utils.data import DataLoader 26 | from tqdm import tqdm 27 | import os 28 | import numpy as np 29 | 30 | from stego.utils import prep_args, flexible_collate, get_transform 31 | from stego.data import UnlabeledImageFolder, create_cityscapes_colormap 32 | from stego.stego import Stego 33 | 34 | 35 | torch.multiprocessing.set_sharing_strategy("file_system") 36 | 37 | 38 | @hydra.main(config_path="cfg", config_name="demo_config.yaml") 39 | def my_app(cfg: DictConfig) -> None: 40 | result_dir = os.path.join(cfg.output_root, cfg.experiment_name) 41 | os.makedirs(result_dir, exist_ok=True) 42 | os.makedirs(os.path.join(result_dir, "cluster"), exist_ok=True) 43 | os.makedirs(os.path.join(result_dir, "linear"), exist_ok=True) 44 | 45 | model = Stego.load_from_checkpoint(cfg.model_path) 46 | 47 | dataset = UnlabeledImageFolder( 48 | root=cfg.image_dir, 49 | transform=get_transform(cfg.resolution, False, "center"), 50 | ) 51 | 52 | loader = DataLoader( 53 | dataset, 54 | cfg.batch_size * 2, 55 | shuffle=False, 56 | num_workers=cfg.num_workers, 57 | pin_memory=True, 58 | collate_fn=flexible_collate, 59 | ) 60 | 61 | model.eval().cuda() 62 | cmap = create_cityscapes_colormap() 63 | 64 | for i, (img, name) in enumerate(tqdm(loader)): 65 | with torch.no_grad(): 66 | img = img.cuda() 67 | code = model.get_code(img) 68 | cluster_crf, linear_crf = model.postprocess( 69 | code=code, 70 | img=img, 71 | use_crf_cluster=cfg.run_crf, 72 | use_crf_linear=cfg.run_crf, 73 | ) 74 | cluster_crf = cluster_crf.cpu() 75 | linear_crf = linear_crf.cpu() 76 | for j in range(img.shape[0]): 77 | new_name = ".".join(name[j].split(".")[:-1]) + ".png" 78 | Image.fromarray(cmap[linear_crf[j]].astype(np.uint8)).save(os.path.join(result_dir, "linear", new_name)) 79 | Image.fromarray(cmap[cluster_crf[j]].astype(np.uint8)).save( 80 | os.path.join(result_dir, "cluster", new_name) 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | prep_args() 86 | my_app() 87 | -------------------------------------------------------------------------------- /scripts/download_stego_datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | ############################################ 9 | # Download datasets used by Hamilton et al. 10 | # 11 | # In case of problems, try azcopy (see README). 
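`demo_segmentation.py` above drives this inference through Hydra and a `DataLoader`; for a quick single-image test the same path can also be called directly from Python. The sketch below is illustrative only: the checkpoint and image paths are placeholders, it assumes a CUDA device (as the script does), and it assumes that `get_transform` returns a torchvision-style transform mapping a PIL image to a normalized tensor, which is how the datasets above use it.
```
import numpy as np
import torch
from PIL import Image

from stego.stego import Stego
from stego.data import create_cityscapes_colormap
from stego.utils import get_transform

checkpoint_path = "models/stego_checkpoint.ckpt"  # placeholder: a checkpoint saved by this package
image_path = "data/example.jpg"                   # placeholder
resolution = 448                                  # divisible by the ViT patch sizes used by the backbones

model = Stego.load_from_checkpoint(checkpoint_path).eval().cuda()
cmap = create_cityscapes_colormap()

img = get_transform(resolution, False, "center")(Image.open(image_path).convert("RGB"))
img = img.unsqueeze(0).cuda()

with torch.no_grad():
    code = model.get_code(img)
    cluster_pred, linear_pred = model.postprocess(
        code=code, img=img, use_crf_cluster=True, use_crf_linear=True
    )

# Save colormapped cluster and linear-probe segmentations, as the demo script does.
Image.fromarray(cmap[cluster_pred.cpu()[0]].astype(np.uint8)).save("cluster_demo.png")
Image.fromarray(cmap[linear_pred.cpu()[0]].astype(np.uint8)).save("linear_demo.png")
```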
12 | # 13 | ############################################ 14 | 15 | 16 | import hydra 17 | from omegaconf import DictConfig 18 | import os 19 | from os.path import join 20 | import wget 21 | 22 | from stego.utils import prep_args 23 | 24 | 25 | @hydra.main(config_path="configs", config_name="eval_config.yml") 26 | def my_app(cfg: DictConfig) -> None: 27 | pytorch_data_dir = cfg.pytorch_data_dir 28 | dataset_names = ["potsdam", "cityscapes", "cocostuff", "potsdamraw"] 29 | url_base = "https://marhamilresearch4.blob.core.windows.net/stego-public/pytorch_data/" 30 | 31 | os.makedirs(pytorch_data_dir, exist_ok=True) 32 | for dataset_name in dataset_names: 33 | if (not os.path.exists(join(pytorch_data_dir, dataset_name))) or ( 34 | not os.path.exists(join(pytorch_data_dir, dataset_name + ".zip")) 35 | ): 36 | print("\n Downloading {}".format(dataset_name)) 37 | wget.download( 38 | url_base + dataset_name + ".zip", 39 | join(pytorch_data_dir, dataset_name + ".zip"), 40 | ) 41 | else: 42 | print("\n Found {}, skipping download".format(dataset_name)) 43 | 44 | 45 | if __name__ == "__main__": 46 | prep_args() 47 | my_app() 48 | -------------------------------------------------------------------------------- /scripts/download_stego_models.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | ############################################ 9 | # Download model checkpoints trained by Hamilton et al. 10 | # 11 | ############################################ 12 | 13 | 14 | from os.path import join, exists 15 | import wget 16 | import os 17 | 18 | models_dir = "models" 19 | os.makedirs(models_dir, exist_ok=True) 20 | model_url_root = "https://marhamilresearch4.blob.core.windows.net/stego-public/models/models/" 21 | model_names = [] 22 | # Optionally, uncomment to download all original models: 23 | # model_names = ["moco_v2_800ep_pretrain.pth.tar", 24 | # "model_epoch_0720_iter_085000.pth", 25 | # "picie.pkl"] 26 | 27 | saved_model_url_root = "https://marhamilresearch4.blob.core.windows.net/stego-public/saved_models/" 28 | saved_model_names = [ 29 | "cityscapes_vit_base_1.ckpt", 30 | "cocostuff27_vit_base_5.ckpt", 31 | "picie_and_probes.pth", 32 | "potsdam_test.ckpt", 33 | ] 34 | 35 | target_files = [join(models_dir, mn) for mn in model_names] + [join(models_dir, mn) for mn in saved_model_names] 36 | 37 | target_urls = [model_url_root + mn for mn in model_names] + [saved_model_url_root + mn for mn in saved_model_names] 38 | 39 | for target_file, target_url in zip(target_files, target_urls): 40 | if not exists(target_file): 41 | print("\nDownloading file from {}".format(target_url)) 42 | wget.download(target_url, target_file) 43 | else: 44 | print("\nFound {}, skipping download".format(target_file)) 45 | -------------------------------------------------------------------------------- /scripts/eval_clusters_wvn.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 
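One practical note on `download_stego_datasets.py` above: it only fetches the `.zip` archives into `pytorch_data_dir`, so they still have to be unpacked before use. A possible follow-up step, not part of the repository and assuming each archive contains a top-level folder named after the dataset, is:
```
import zipfile
from os.path import join

pytorch_data_dir = "/data"  # placeholder: same value as cfg.pytorch_data_dir above

for dataset_name in ["potsdam", "cityscapes", "cocostuff", "potsdamraw"]:
    archive = join(pytorch_data_dir, dataset_name + ".zip")
    with zipfile.ZipFile(archive) as zf:
        # Creates pytorch_data_dir/<dataset_name>/ if the archive is laid out that way.
        zf.extractall(pytorch_data_dir)
```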
5 | # 6 | # 7 | ############################################ 8 | # STEGO for WVN experiment 9 | # 10 | # This script calculates the WVN metrics of SLIC and STEGO segmentation models. 11 | # 12 | # Before running the script, adjust parameters in cfg/eval_wvn_config.yaml: 13 | # - data_dir and dataset_name - the input data should be preprocessed for the WVN experiment 14 | # - model_paths and stego_n_clusters - paths to STEGO checkpoints and the number of clusters of each model 15 | # - output_root and experiment_name - outputs will be saved in the given directory, in a subfolder named after experiment_name 16 | # - optionally, adjust other parameters, e.g. to include SLIC segmentations or per-image STEGO clustering 17 | # 18 | ############################################ 19 | 20 | 21 | import os 22 | from os.path import join 23 | 24 | # from multiprocessing import Pool 25 | import hydra 26 | import torch.multiprocessing 27 | from omegaconf import DictConfig 28 | from torch.utils.data import DataLoader 29 | from tqdm import tqdm 30 | import matplotlib.pyplot as plt 31 | from fast_slic import Slic 32 | import kornia 33 | from pytictoc import TicToc 34 | import time 35 | import warnings 36 | import numpy as np 37 | from PIL import Image 38 | 39 | from stego.utils import ( 40 | prep_args, 41 | plot_distributions, 42 | unnorm, 43 | WVNMetrics, 44 | flexible_collate, 45 | get_transform, 46 | ) 47 | from stego.stego import Stego 48 | from stego.data import ContrastiveSegDataset 49 | 50 | torch.multiprocessing.set_sharing_strategy("file_system") 51 | warnings.filterwarnings("ignore") 52 | 53 | 54 | @hydra.main(config_path="cfg", config_name="eval_wvn_config.yaml") 55 | def my_app(cfg: DictConfig) -> None: 56 | result_dir = os.path.join(cfg.output_root, cfg.experiment_name) 57 | 58 | plot_dir = None 59 | if cfg.save_plots: 60 | plot_dir = join(result_dir, "plots") 61 | os.makedirs(plot_dir, exist_ok=True) 62 | plt.switch_backend("agg") 63 | 64 | if cfg.save_vis: 65 | os.makedirs(join(result_dir, "img"), exist_ok=True) 66 | os.makedirs(join(result_dir, "label"), exist_ok=True) 67 | 68 | models = [] 69 | for model_path in cfg.model_paths: 70 | models.append(Stego.load_from_checkpoint(model_path)) 71 | 72 | slic_models = [] 73 | for n_clusters in cfg.slic_n_clusters: 74 | slic_models.append(Slic(num_components=n_clusters, compactness=cfg.slic_compactness)) 75 | 76 | if cfg.save_vis: 77 | for n_clusters in cfg.stego_n_clusters: 78 | os.makedirs(join(result_dir, "stego_" + str(n_clusters)), exist_ok=True) 79 | if cfg.cluster_stego_by_image: 80 | os.makedirs(join(result_dir, "stego_code_" + str(n_clusters)), exist_ok=True) 81 | for n_clusters in cfg.slic_n_clusters: 82 | os.makedirs(join(result_dir, "slic_" + str(n_clusters)), exist_ok=True) 83 | 84 | test_dataset = ContrastiveSegDataset( 85 | data_dir=cfg.data_dir, 86 | dataset_name=cfg.dataset_name, 87 | image_set="val", 88 | transform=get_transform(cfg.resolution, False, "center"), 89 | target_transform=get_transform(cfg.resolution, True, "center"), 90 | model_type="dino", 91 | resolution=cfg.resolution, 92 | ) 93 | 94 | test_loader = DataLoader( 95 | test_dataset, 96 | 1, 97 | shuffle=False, 98 | num_workers=cfg.num_workers, 99 | pin_memory=True, 100 | collate_fn=flexible_collate, 101 | ) 102 | 103 | for model in models: 104 | model.eval().cuda() 105 | 106 | model_metrics = [ 107 | WVNMetrics("Stego_" + str(i), i, save_plots=cfg.save_plots, output_dir=plot_dir) for i in cfg.stego_n_clusters 108 | ] 109 | slic_metrics = [ 110 | WVNMetrics("SLIC_" + str(i), i, 
save_plots=cfg.save_plots, output_dir=plot_dir) for i in cfg.slic_n_clusters 111 | ] 112 | if cfg.cluster_stego_by_image: 113 | model_cluster_metrics = [ 114 | WVNMetrics( 115 | "Stego_code_" + str(i), 116 | i, 117 | save_plots=cfg.save_plots, 118 | output_dir=plot_dir, 119 | ) 120 | for i in cfg.stego_n_clusters 121 | ] 122 | 123 | t = TicToc() 124 | feature_times = [] 125 | for i, batch in enumerate(tqdm(test_loader)): 126 | if cfg.n_imgs is not None and i >= cfg.n_imgs: 127 | break 128 | with torch.no_grad(): 129 | img = batch["img"].squeeze() 130 | label = batch["label"].squeeze() 131 | 132 | if cfg.save_vis: 133 | image = Image.fromarray((kornia.utils.tensor_to_image(unnorm(img).cpu()) * 255).astype(np.uint8)) 134 | image.save(join(result_dir, "img", str(i) + ".png")) 135 | label_img = label.cpu().detach().numpy().astype(np.uint8) 136 | image = Image.fromarray(label_img) 137 | image.save(join(result_dir, "label", str(i) + ".png")) 138 | 139 | features = None 140 | code = None 141 | 142 | for model_index, model in enumerate(models): 143 | n_clusters = cfg.stego_n_clusters[model_index] 144 | t.tic() 145 | features, code = model(batch["img"].cuda()) 146 | feature_times.append(t.tocvalue(restart=True)) 147 | clusters = model.postprocess_cluster(code=code, img=batch["img"], use_crf=cfg.run_crf) 148 | time_val = t.tocvalue() 149 | model_metrics[model_index].update(clusters, label, features, code, time_val) 150 | if cfg.save_vis: 151 | image = Image.fromarray((clusters.squeeze().cpu().numpy()).astype(np.uint8)) 152 | image.save(join(result_dir, "stego_" + str(n_clusters), str(i) + ".png")) 153 | if cfg.cluster_stego_by_image: 154 | t.tic() 155 | clusters = model.postprocess_cluster( 156 | code=code, 157 | img=batch["img"], 158 | use_crf=cfg.run_crf, 159 | image_clustering=True, 160 | ) 161 | time_val = t.tocvalue() 162 | model_cluster_metrics[model_index].update(clusters, label, features, code, time_val) 163 | if cfg.save_vis: 164 | image = Image.fromarray((clusters.squeeze().cpu().numpy()).astype(np.uint8)) 165 | image.save( 166 | join( 167 | result_dir, 168 | "stego_code_" + str(n_clusters), 169 | str(i) + ".png", 170 | ) 171 | ) 172 | 173 | for model_index, model in enumerate(slic_models): 174 | img_np = kornia.utils.tensor_to_image(unnorm(img).cpu()) 175 | t.tic() 176 | clusters = model.iterate(np.uint8(np.ascontiguousarray(img_np) * 255)) 177 | time_val = t.tocvalue() 178 | slic_metrics[model_index].update(torch.from_numpy(clusters), label.cpu(), features, code, time_val) 179 | if cfg.save_vis: 180 | n_clusters = cfg.slic_n_clusters[model_index] 181 | image = Image.fromarray((clusters).astype(np.uint8)) 182 | image.save(join(result_dir, "slic_" + str(n_clusters), str(i) + ".png")) 183 | 184 | feature_times_np = np.array(feature_times) 185 | print("Feature extraction time: Mean: {} Var: {}".format(np.mean(feature_times_np), np.var(feature_times_np))) 186 | 187 | model_values = [] 188 | for metric in model_metrics: 189 | results, values = metric.compute(print_metrics=True) 190 | model_values.append(values) 191 | print() 192 | 193 | model_cluster_values = [] 194 | if cfg.cluster_stego_by_image: 195 | for metric in model_cluster_metrics: 196 | results, values = metric.compute(print_metrics=True) 197 | model_cluster_values.append(values) 198 | print() 199 | 200 | slic_values = [] 201 | for metric in slic_metrics: 202 | results, values = metric.compute(print_metrics=True) 203 | slic_values.append(values) 204 | 205 | time_now = int(time.time()) 206 | if cfg.save_plots and 
cfg.save_comparison_plots: 207 | for metric in ["Avg_clusters", "Feature_var", "Code_var", "Time"]: 208 | metric_stego_values = [values[metric] for values in model_values] 209 | metric_stego_names = ["Stego_" + str(i) for i in cfg.stego_n_clusters] 210 | plot_distributions( 211 | metric_stego_values, 212 | 100, 213 | metric_stego_names, 214 | metric, 215 | os.path.join( 216 | plot_dir, 217 | "Comparison_" + metric + "_Stego_" + str(time_now) + ".png", 218 | ), 219 | ) 220 | metric_slic_values = [values[metric] for values in slic_values] 221 | metric_slic_names = ["SLIC_" + str(i) for i in cfg.slic_n_clusters] 222 | plot_distributions( 223 | metric_slic_values, 224 | 100, 225 | metric_slic_names, 226 | metric, 227 | os.path.join(plot_dir, "Comparison_" + metric + "_SLIC_" + str(time_now) + ".png"), 228 | ) 229 | 230 | metric_stego_cluster_values = [] 231 | metric_stego_cluster_names = [] 232 | if cfg.cluster_stego_by_image: 233 | metric_stego_cluster_values = [values[metric] for values in model_cluster_values] 234 | metric_stego_cluster_names = ["Stego_code_" + str(i) for i in cfg.stego_n_clusters] 235 | plot_distributions( 236 | metric_stego_cluster_values, 237 | 100, 238 | metric_stego_cluster_names, 239 | metric, 240 | os.path.join( 241 | plot_dir, 242 | "Comparison_" + metric + "_Stego_code_" + str(time_now) + ".png", 243 | ), 244 | ) 245 | 246 | plot_distributions( 247 | metric_stego_values + metric_slic_values + metric_stego_cluster_values, 248 | 100, 249 | metric_stego_names + metric_slic_names + metric_stego_cluster_names, 250 | metric, 251 | os.path.join(plot_dir, "Comparison_" + metric + "_all_" + str(time_now) + ".png"), 252 | ) 253 | 254 | 255 | if __name__ == "__main__": 256 | prep_args() 257 | my_app() 258 | -------------------------------------------------------------------------------- /scripts/eval_segmentation.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 
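The WVN evaluation in `eval_clusters_wvn.py` above runs each STEGO checkpoint in two modes: the standard clustering against the model's global cluster centroids, and per-image clustering of the code features (`image_clustering=True`). A minimal side-by-side comparison for a single image, stripped of the WVN metric bookkeeping, could look like the sketch below; the paths are placeholders and, as in the script, the predicted cluster indices are written out as grayscale masks.
```
import numpy as np
import torch
from PIL import Image

from stego.stego import Stego
from stego.utils import get_transform

checkpoint_path = "models/stego_checkpoint.ckpt"  # placeholder
image_path = "data/example.jpg"                   # placeholder

model = Stego.load_from_checkpoint(checkpoint_path).eval().cuda()
img = get_transform(448, False, "center")(Image.open(image_path).convert("RGB")).unsqueeze(0).cuda()

with torch.no_grad():
    _, code = model(img)
    # Standard mode: pixels are assigned to the dataset-wide cluster centroids.
    global_clusters = model.postprocess_cluster(code=code, img=img, use_crf=False)
    # Per-image mode: clusters are formed from this image's code features alone.
    per_image_clusters = model.postprocess_cluster(code=code, img=img, use_crf=False, image_clustering=True)

Image.fromarray(global_clusters.squeeze().cpu().numpy().astype(np.uint8)).save("clusters_global.png")
Image.fromarray(per_image_clusters.squeeze().cpu().numpy().astype(np.uint8)).save("clusters_per_image.png")
```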
6 | # 7 | # 8 | ############################################ 9 | # Evaluation of STEGO segmentation 10 | # 11 | # This script calculates the unsupervised metrics for the given STEGO checkpoint and dataset 12 | # 13 | # Before running, adjust parameters in cfg/eval_config.yaml 14 | # 15 | ############################################ 16 | 17 | 18 | import os 19 | from os.path import join 20 | from collections import defaultdict 21 | from multiprocessing import Pool 22 | import hydra 23 | import seaborn as sns 24 | import torch.multiprocessing 25 | from omegaconf import DictConfig, OmegaConf 26 | from torch.utils.data import DataLoader 27 | from tqdm import tqdm 28 | import matplotlib as plt 29 | 30 | from stego.utils import * 31 | from stego.stego import Stego 32 | from stego.data import ContrastiveSegDataset 33 | 34 | torch.multiprocessing.set_sharing_strategy("file_system") 35 | 36 | 37 | @hydra.main(config_path="cfg", config_name="eval_config.yaml") 38 | def my_app(cfg: DictConfig) -> None: 39 | result_dir = os.path.join(cfg.output_root, cfg.experiment_name) 40 | os.makedirs(join(result_dir, "img"), exist_ok=True) 41 | os.makedirs(join(result_dir, "label"), exist_ok=True) 42 | os.makedirs(join(result_dir, "cluster"), exist_ok=True) 43 | os.makedirs(join(result_dir, "picie"), exist_ok=True) 44 | 45 | model = Stego.load_from_checkpoint(cfg.model_path) 46 | test_dataset = ContrastiveSegDataset( 47 | data_dir=cfg.data_dir, 48 | dataset_name=cfg.dataset_name, 49 | image_set="val", 50 | transform=get_transform(cfg.resolution, False, "center"), 51 | target_transform=get_transform(cfg.resolution, True, "center"), 52 | model_type=model.backbone_name, 53 | resolution=cfg.resolution, 54 | ) 55 | 56 | test_loader = DataLoader( 57 | test_dataset, 58 | cfg.batch_size, 59 | shuffle=True, 60 | num_workers=cfg.num_workers, 61 | pin_memory=True, 62 | collate_fn=flexible_collate, 63 | ) 64 | 65 | model.eval().cuda() 66 | 67 | for i, batch in enumerate(tqdm(test_loader)): 68 | if cfg.n_batches is not None and i >= cfg.n_batches: 69 | break 70 | with torch.no_grad(): 71 | img = batch["img"].cuda() 72 | label = batch["label"].cuda() 73 | 74 | code = model.get_code(img) 75 | cluster_preds, linear_preds = model.postprocess( 76 | code=code, 77 | img=img, 78 | use_crf_cluster=cfg.run_crf, 79 | use_crf_linear=cfg.run_crf, 80 | ) 81 | 82 | model.linear_metrics.update(linear_preds.cuda(), label) 83 | model.cluster_metrics.update(cluster_preds.cuda(), label) 84 | 85 | tb_metrics = { 86 | **model.linear_metrics.compute(), 87 | **model.cluster_metrics.compute(), 88 | } 89 | print(tb_metrics) 90 | 91 | 92 | if __name__ == "__main__": 93 | prep_args() 94 | my_app() 95 | -------------------------------------------------------------------------------- /scripts/plot.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 
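A brief note on what the cluster metrics computed in `eval_segmentation.py` above have to do internally: STEGO's cluster predictions are unsupervised, so each cluster id first has to be matched to a ground-truth class before an IoU can be computed. The repository's `cluster_metrics` object handles this for you; the snippet below only illustrates the usual matching step (Hungarian assignment on the cluster/class confusion matrix, here via `scipy`) and is a simplified sketch, not the package's implementation. It assumes integer arrays with labels in `[0, n_classes)` and no ignore index.
```
import numpy as np
from scipy.optimize import linear_sum_assignment


def remap_clusters_to_classes(cluster_pred, label, n_clusters, n_classes):
    # Confusion matrix between predicted cluster ids and ground-truth classes.
    hist = np.zeros((n_clusters, n_classes), dtype=np.int64)
    np.add.at(hist, (cluster_pred.reshape(-1), label.reshape(-1)), 1)
    # Hungarian matching: maximise the number of pixels that end up correctly assigned.
    row, col = linear_sum_assignment(hist, maximize=True)
    cluster_to_class = np.zeros(n_clusters, dtype=np.int64)  # unmatched clusters fall back to class 0
    cluster_to_class[row] = col
    return cluster_to_class[cluster_pred], hist
```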
6 | # 7 | # 8 | ############################################ 9 | # Plotting script for STEGO experiments 10 | # 11 | # Two types of plots are available: 12 | # - Correspondence plot - an interactive plot visualizing cosine similarities 13 | # between all features in the image and the selected query feature 14 | # - Precision-recall curves - a given STEGO checkpoint can be evaluated 15 | # on input data in predicting label co-occurrence with feature similarities 16 | # 17 | # Before running, adjust the parameters in cfg/plot_config.yaml 18 | # 19 | ############################################ 20 | 21 | 22 | from os.path import join 23 | import hydra 24 | 25 | # import matplotlib.animation as animation 26 | import matplotlib.pyplot as plt 27 | import pickle 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | from omegaconf import DictConfig, OmegaConf 32 | from tqdm import tqdm 33 | from mpl_toolkits.axes_grid1 import make_axes_locatable 34 | from sklearn.metrics import precision_recall_curve, average_precision_score 35 | 36 | from stego.stego import Stego 37 | from stego.data import ContrastiveSegDataset 38 | from stego.utils import ( 39 | prep_args, 40 | get_transform, 41 | remove_axes, 42 | sample, 43 | tensor_correlation, 44 | prep_for_plot, 45 | load_image_to_tensor, 46 | norm 47 | ) 48 | 49 | 50 | class Plotter: 51 | """ 52 | This class collects methods used for plot generation. 53 | """ 54 | 55 | def __init__(self, cfg): 56 | self.cfg = cfg 57 | if cfg.model_path is not None: 58 | self.stego = Stego.load_from_checkpoint(cfg.model_path).cuda() 59 | else: 60 | self.stego = Stego(1).cuda() 61 | 62 | def reset_axes(self, axes): 63 | axes[0].clear() 64 | remove_axes(axes) 65 | axes[0].set_title("Image A and Query Point", fontsize=20) 66 | axes[1].set_title("Feature Cosine Similarity", fontsize=20) 67 | axes[2].set_title("Image B", fontsize=20) 68 | 69 | def get_heatmaps(self, img, img_pos, query_points, zero_mean=True, zero_clamp=True): 70 | """ 71 | Runs STEGO on the given pair of images (img, img_pos) 72 | Generates a 2D heatmap of cosine similarities between STEGO's backbone features 73 | """ 74 | 75 | feats1, _ = self.stego.forward(img.cuda()) 76 | feats2, _ = self.stego.forward(img_pos.cuda()) 77 | 78 | sfeats1 = sample(feats1, query_points) 79 | 80 | attn_intra = torch.einsum("nchw,ncij->nhwij", F.normalize(sfeats1, dim=1), F.normalize(feats1, dim=1)) 81 | if zero_mean: 82 | attn_intra -= attn_intra.mean([3, 4], keepdims=True) 83 | if zero_clamp: 84 | attn_intra = attn_intra.clamp(0).squeeze(0) 85 | else: 86 | attn_intra = attn_intra.squeeze(0) 87 | 88 | attn_inter = torch.einsum("nchw,ncij->nhwij", F.normalize(sfeats1, dim=1), F.normalize(feats2, dim=1)) 89 | if zero_mean: 90 | attn_inter -= attn_inter.mean([3, 4], keepdims=True) 91 | if zero_clamp: 92 | attn_inter = attn_inter.clamp(0).squeeze(0) 93 | else: 94 | attn_inter = attn_inter.squeeze(0) 95 | 96 | heatmap_intra = ( 97 | F.interpolate(attn_intra, img.shape[2:], mode="bilinear", align_corners=True).squeeze(0).detach().cpu() 98 | ) 99 | heatmap_inter = ( 100 | F.interpolate(attn_inter, img_pos.shape[2:], mode="bilinear", align_corners=True).squeeze(0).detach().cpu() 101 | ) 102 | 103 | return heatmap_intra, heatmap_inter 104 | 105 | def plot_figure(self, img_a, img_b, query_point, axes, fig): 106 | """ 107 | Plots a single visualization in the interactive correspondence figure. 
108 | """ 109 | _, heatmap_correspondence = self.get_heatmaps( 110 | img_a, 111 | img_b, 112 | query_point, 113 | zero_mean=self.cfg.zero_mean, 114 | zero_clamp=self.cfg.zero_clamp, 115 | ) 116 | point = ((query_point[0, 0, 0] + 1) / 2 * self.cfg.display_resolution).cpu() 117 | self.reset_axes(axes) 118 | axes[0].imshow(prep_for_plot(img_a[0], rescale=False)) 119 | axes[2].imshow(prep_for_plot(img_b[0], rescale=False)) 120 | axes[0].scatter(point[0], point[1], color=(1, 0, 0), marker="x", s=500, linewidths=5) 121 | 122 | img_b_bw = prep_for_plot(img_b[0], rescale=False) * 0.8 123 | img_b_bw = np.ones_like(img_b_bw) * np.expand_dims( 124 | np.dot(np.array(img_b_bw)[..., :3], [0.2989, 0.5870, 0.1140]), -1 125 | ) 126 | axes[1].imshow(img_b_bw) 127 | im1 = None 128 | if self.cfg.zero_clamp: 129 | im1 = axes[1].imshow( 130 | heatmap_correspondence[0], 131 | alpha=0.5, 132 | cmap=self.cfg.cmap, 133 | vmin=0.0, 134 | vmax=1.0, 135 | ) 136 | else: 137 | im1 = axes[1].imshow( 138 | heatmap_correspondence[0], 139 | alpha=0.5, 140 | cmap=self.cfg.cmap, 141 | vmin=-1.0, 142 | vmax=1.0, 143 | ) 144 | 145 | divider = make_axes_locatable(axes[1]) 146 | cax = divider.append_axes("right", size="5%", pad=0.05) 147 | color_bar = fig.colorbar(im1, cax=cax, orientation="vertical") 148 | color_bar.set_alpha(1) 149 | color_bar.draw_all() 150 | plt.draw() 151 | 152 | def plot_correspondences_interactive(self): 153 | """ 154 | Plots the interactive correspondence figure and updates according to user input. 155 | """ 156 | img_a = load_image_to_tensor(self.cfg.image_a_path, self.cfg.display_resolution) 157 | image_b_path = self.cfg.image_b_path 158 | if image_b_path is None: 159 | image_b_path = self.cfg.image_a_path 160 | img_b = load_image_to_tensor( 161 | image_b_path, 162 | self.cfg.display_resolution, 163 | self.cfg.brightness_factor, 164 | self.cfg.contrast_factor, 165 | self.cfg.saturation_factor, 166 | self.cfg.hue_factor, 167 | self.cfg.gaussian_sigma, 168 | self.cfg.gaussian_kernel_size, 169 | ) 170 | 171 | fig, axes = plt.subplots(1, 3, figsize=(3 * 5, 1 * 5), dpi=100) 172 | self.reset_axes(axes) 173 | fig.tight_layout() 174 | 175 | def onclick(event): 176 | if event.xdata is not None and event.ydata is not None: 177 | x = (event.xdata - self.cfg.display_resolution / 2) / (self.cfg.display_resolution / 2) 178 | y = (event.ydata - self.cfg.display_resolution / 2) / (self.cfg.display_resolution / 2) 179 | query_point = torch.tensor([[x, y]]).float().reshape(1, 1, 1, 2).cuda() 180 | self.plot_figure(img_a, img_b, query_point, axes, fig) 181 | 182 | fig.canvas.mpl_connect("button_press_event", onclick) 183 | query_point = torch.tensor([[0.0, 0.0]]).reshape(1, 1, 1, 2).cuda() 184 | self.plot_figure(img_a, img_b, query_point, axes, fig) 185 | plt.show() 186 | 187 | def get_net_fd(self, feats1, feats2, label1, label2, coords1, coords2): 188 | with torch.no_grad(): 189 | feat_samples1 = sample(feats1, coords1) 190 | feat_samples2 = sample(feats2, coords2) 191 | label_samples1 = sample( 192 | F.one_hot(label1 + 1, self.n_classes + 1).to(torch.float).permute(0, 3, 1, 2), 193 | coords1, 194 | ) 195 | label_samples2 = sample( 196 | F.one_hot(label2 + 1, self.n_classes + 1).to(torch.float).permute(0, 3, 1, 2), 197 | coords2, 198 | ) 199 | fd = tensor_correlation(norm(feat_samples1), norm(feat_samples2)) 200 | ld = tensor_correlation(label_samples1, label_samples2) 201 | return ld, fd, label_samples1.argmax(1), label_samples2.argmax(1) 202 | 203 | def prep_fd(self, fd): 204 | fd -= fd.min() 205 | fd /= fd.max() 206 | 
return fd.reshape(-1) 207 | 208 | def generate_pr_plot(self, preds, targets, name): 209 | preds = preds.cpu().reshape(-1) 210 | preds -= preds.min() 211 | preds /= preds.max() 212 | targets = targets.to(torch.int64).cpu().reshape(-1) 213 | precisions, recalls, _ = precision_recall_curve(targets, preds) 214 | average_precision = average_precision_score(targets, preds) 215 | data = { 216 | "precisions": precisions, 217 | "recalls": recalls, 218 | "average_precision": average_precision, 219 | "name": name, 220 | } 221 | with open( 222 | join( 223 | self.cfg.pr_output_data_dir, 224 | self.cfg.dataset_name 225 | + "_" 226 | + self.stego.full_backbone_name 227 | + "_" 228 | + str(self.cfg.pr_resolution) 229 | + ".pkl", 230 | ), 231 | "wb", 232 | ) as handle: 233 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 234 | plt.plot( 235 | recalls, 236 | precisions, 237 | label="AP={}% {}".format(int(average_precision * 100), name), 238 | ) 239 | 240 | def plot_pr(self): 241 | self.n_classes = 256 # model.n_classes 242 | if self.cfg.plot_stego_pr or self.cfg.plot_backbone_pr: 243 | val_loader_crop = "center" 244 | val_dataset = ContrastiveSegDataset( 245 | data_dir=self.cfg.data_dir, 246 | dataset_name=self.cfg.dataset_name, 247 | image_set="val", 248 | transform=get_transform(self.cfg.pr_resolution, False, val_loader_crop), 249 | target_transform=get_transform(self.cfg.pr_resolution, True, val_loader_crop), 250 | model_type=self.stego.backbone_name, 251 | resolution=self.cfg.pr_resolution, 252 | mask=True, 253 | pos_images=True, 254 | pos_labels=True, 255 | ) 256 | print("Calculating PR curves for {} with model {}".format(self.cfg.dataset_name, self.cfg.model_path)) 257 | lds = [] 258 | backbone_fds = [] 259 | stego_fds = [] 260 | for data in tqdm(val_dataset): 261 | img = torch.unsqueeze(data["img"], dim=0).cuda() 262 | label = data["label"].cuda() 263 | feats, code = self.stego.forward(img) 264 | coord_shape = [ 265 | img.shape[0], 266 | self.stego.cfg.feature_samples, 267 | self.stego.cfg.feature_samples, 268 | 2, 269 | ] 270 | coords1 = torch.rand(coord_shape, device=img.device) * 2 - 1 271 | coords2 = torch.rand(coord_shape, device=img.device) * 2 - 1 272 | ld, stego_fd, _, _ = self.get_net_fd(code, code, label, label, coords1, coords2) 273 | ld, backbone_fd, _, _ = self.get_net_fd(feats, feats, label, label, coords1, coords2) 274 | lds.append(ld) 275 | backbone_fds.append(backbone_fd) 276 | stego_fds.append(stego_fd) 277 | ld = torch.cat(lds, dim=0) 278 | backbone_fd = torch.cat(backbone_fds, dim=0) 279 | stego_fd = torch.cat(stego_fds, dim=0) 280 | if self.cfg.plot_stego_pr: 281 | self.generate_pr_plot(self.prep_fd(stego_fd), ld, "Stego") 282 | if self.cfg.plot_backbone_pr: 283 | self.generate_pr_plot(self.prep_fd(backbone_fd), ld, self.stego.full_backbone_name) 284 | for filename in self.cfg.additional_pr_curves: 285 | with open(join(self.cfg.pr_output_data_dir, filename), "rb") as handle: 286 | data = pickle.load(handle) 287 | plt.plot( 288 | data["recalls"], 289 | data["precisions"], 290 | label="AP={}% {}".format(int(data["average_precision"] * 100), data["name"]), 291 | ) 292 | plt.xlim([0, 1]) 293 | plt.ylim([0, 1]) 294 | plt.legend(fontsize=12) 295 | plt.ylabel("Precision", fontsize=16) 296 | plt.xlabel("Recall", fontsize=16) 297 | plt.tight_layout() 298 | plt.savefig( 299 | join( 300 | self.cfg.pr_output_dir, 301 | self.cfg.dataset_name + "_" + self.stego.full_backbone_name + ".png", 302 | ) 303 | ) 304 | plt.show() 305 | 306 | def plot(self): 307 | if 
self.cfg.plot_correspondences_interactive: 308 | self.plot_correspondences_interactive() 309 | if self.cfg.plot_pr: 310 | plt.switch_backend("agg") 311 | self.plot_pr() 312 | 313 | 314 | @hydra.main(config_path="cfg", config_name="plot_config.yaml") 315 | def my_app(cfg: DictConfig) -> None: 316 | print(OmegaConf.to_yaml(cfg)) 317 | plotter = Plotter(cfg) 318 | plotter.plot() 319 | 320 | 321 | if __name__ == "__main__": 322 | prep_args() 323 | my_app() 324 | -------------------------------------------------------------------------------- /scripts/precompute_knns.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | ############################################ 9 | # KNN computation for datasets used for training STEGO 10 | # 11 | # This script generates the KNN file for a new dataset to be used with STEGO. 12 | # Before running the script, preprocess the dataset (including cropping). 13 | # Adjust the path to the dataset, subsets to be processed and target resolution in cfg/knn_config.yaml 14 | # 15 | ############################################ 16 | 17 | import os 18 | from os.path import join 19 | import hydra 20 | import numpy as np 21 | import torch.multiprocessing 22 | import torch.nn.functional as F 23 | from omegaconf import DictConfig, OmegaConf 24 | from pytorch_lightning import seed_everything 25 | from torch.utils.data import DataLoader 26 | from tqdm import tqdm 27 | 28 | from stego.data import ContrastiveSegDataset 29 | from stego.stego import Stego 30 | from stego.utils import prep_args, get_transform, get_nn_file_name 31 | 32 | 33 | def get_feats(model, loader): 34 | all_feats = [] 35 | for pack in tqdm(loader): 36 | img = pack["img"] 37 | feats = F.normalize(model.forward(img.cuda())[0].mean([2, 3]), dim=1) 38 | all_feats.append(feats.to("cpu", non_blocking=True)) 39 | return torch.cat(all_feats, dim=0).contiguous() 40 | 41 | 42 | @hydra.main(config_path="cfg", config_name="knn_config.yaml") 43 | def my_app(cfg: DictConfig) -> None: 44 | print(OmegaConf.to_yaml(cfg)) 45 | seed_everything(seed=0) 46 | os.makedirs(join(cfg.data_dir, cfg.dataset_name, "nns"), exist_ok=True) 47 | 48 | image_sets = cfg.image_sets 49 | 50 | res = cfg.resolution 51 | n_batches = 16 52 | model = Stego(1).cuda() 53 | 54 | for image_set in image_sets: 55 | feature_cache_file = get_nn_file_name(cfg.data_dir, cfg.dataset_name, model.backbone_name, image_set, res) 56 | if not os.path.exists(feature_cache_file): 57 | print("{} not found, computing".format(feature_cache_file)) 58 | dataset = ContrastiveSegDataset( 59 | data_dir=cfg.data_dir, 60 | dataset_name=cfg.dataset_name, 61 | image_set=image_set, 62 | transform=get_transform(res, False, "center"), 63 | target_transform=get_transform(res, True, "center"), 64 | model_type=model.backbone_name, 65 | resolution=res, 66 | ) 67 | 68 | loader = DataLoader( 69 | dataset, 70 | cfg.batch_size, 71 | shuffle=False, 72 | num_workers=cfg.num_workers, 73 | pin_memory=False, 74 | ) 75 | 76 | with torch.no_grad(): 77 | normed_feats = get_feats(model, loader) 78 | all_nns = [] 79 | step = normed_feats.shape[0] // n_batches 80 | print(normed_feats.shape) 81 | for i in tqdm(range(0, normed_feats.shape[0], step)): 82 | torch.cuda.empty_cache() 83 | batch_feats = 
normed_feats[i : i + step, :] 84 | pairwise_sims = torch.einsum("nf,mf->nm", batch_feats, normed_feats) 85 | all_nns.append(torch.topk(pairwise_sims, 30)[1]) 86 | del pairwise_sims 87 | nearest_neighbors = torch.cat(all_nns, dim=0) 88 | 89 | np.savez_compressed(feature_cache_file, nns=nearest_neighbors.numpy()) 90 | print("Saved NNs", model.backbone_name, cfg.dataset_name, image_set) 91 | 92 | 93 | if __name__ == "__main__": 94 | prep_args() 95 | my_app() 96 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | ############################################ 9 | # STEGO training script 10 | # 11 | # This script trains a new STEGO model from scratch or from a given checkpoint. 12 | # 13 | # Before running, adjust parameters in cfg/train_config.yaml. 14 | # 15 | # The hyperparameters of the model and the learning rates can be adjusted in stego/cfg/model_config.yaml. 16 | # 17 | ############################################ 18 | 19 | 20 | from torch.utils.data import DataLoader 21 | import hydra 22 | from omegaconf import DictConfig, OmegaConf 23 | 24 | # import pytorch_lightning as pl 25 | from pytorch_lightning import seed_everything 26 | from pytorch_lightning import Trainer 27 | 28 | # from pytorch_lightning.loggers import TensorBoardLogger 29 | 30 | # import torch.multiprocessing 31 | from pytorch_lightning.callbacks import ModelCheckpoint 32 | from pytorch_lightning.loggers import WandbLogger 33 | 34 | from stego.stego import Stego 35 | from stego.utils import prep_args, get_transform 36 | from stego.data import ContrastiveSegDataset 37 | 38 | 39 | @hydra.main(config_path="cfg", config_name="train_config.yaml") 40 | def my_app(cfg: DictConfig) -> None: 41 | OmegaConf.set_struct(cfg, False) 42 | print(OmegaConf.to_yaml(cfg)) 43 | 44 | seed_everything(seed=0) 45 | 46 | if cfg.model_path is not None: 47 | model = Stego.load_from_checkpoint(cfg.model_path).cuda() 48 | else: 49 | model = Stego(cfg.num_classes).cuda() 50 | 51 | if cfg.reset_clusters: 52 | model.reset_clusters(cfg.num_classes, cfg.extra_clusters) 53 | 54 | train_dataset = ContrastiveSegDataset( 55 | data_dir=cfg.data_dir, 56 | dataset_name=cfg.dataset_name, 57 | image_set="train", 58 | transform=get_transform(cfg.resolution, False, "center"), 59 | target_transform=get_transform(cfg.resolution, True, "center"), 60 | model_type=model.backbone_name, 61 | resolution=cfg.resolution, 62 | num_neighbors=cfg.num_neighbors, 63 | pos_images=True, 64 | pos_labels=True, 65 | ) 66 | 67 | val_dataset = ContrastiveSegDataset( 68 | data_dir=cfg.data_dir, 69 | dataset_name=cfg.dataset_name, 70 | image_set="val", 71 | transform=get_transform(cfg.resolution, False, "center"), 72 | target_transform=get_transform(cfg.resolution, True, "center"), 73 | model_type=model.backbone_name, 74 | resolution=cfg.resolution, 75 | ) 76 | 77 | train_loader = DataLoader( 78 | train_dataset, 79 | cfg.batch_size, 80 | shuffle=True, 81 | num_workers=cfg.num_workers, 82 | pin_memory=True, 83 | ) 84 | val_loader = DataLoader( 85 | val_dataset, 86 | cfg.batch_size, 87 | shuffle=False, 88 | num_workers=cfg.num_workers, 89 | pin_memory=True, 90 | ) 91 | 92 | wandb_logger = 
WandbLogger(project=cfg.wandb_project, name=cfg.wandb_name, log_model=cfg.wandb_log_model) 93 | 94 | trainer = Trainer( 95 | logger=wandb_logger, 96 | max_steps=cfg.max_steps, 97 | default_root_dir=cfg.checkpoint_dir, 98 | callbacks=[ 99 | ModelCheckpoint( 100 | dirpath=cfg.checkpoint_dir, 101 | every_n_train_steps=400, 102 | save_top_k=2, 103 | monitor="val/cluster/mIoU", 104 | mode="max", 105 | ) 106 | ], 107 | gpus=1, 108 | val_check_interval=cfg.val_check_interval, 109 | ) 110 | trainer.fit(model, train_loader, val_loader) 111 | 112 | 113 | if __name__ == "__main__": 114 | prep_args() 115 | my_app() 116 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | from setuptools import find_packages 8 | from distutils.core import setup 9 | 10 | INSTALL_REQUIRES = [ 11 | # generic 12 | "numpy", 13 | "tqdm", 14 | "kornia>=0.6.5", 15 | "pip", 16 | "torchvision", 17 | "torch>=1.21", 18 | "torchmetrics", 19 | "pytorch_lightning>=1.6.5", 20 | "pytest", 21 | "scipy", 22 | "scikit-image", 23 | "scikit-learn", 24 | "matplotlib", 25 | "seaborn", 26 | "pandas", 27 | "pytictac", 28 | "torch_geometric", 29 | "omegaconf", 30 | "optuna", 31 | "neptune", 32 | "fast-slic", 33 | "hydra-core", 34 | "prettytable", 35 | "termcolor", 36 | "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git", 37 | "liegroups@git+https://github.com/mmattamala/liegroups", 38 | "opencv-python", 39 | "wget", 40 | "rospkg", 41 | "wandb", 42 | "gdown" 43 | ] 44 | 45 | setup( 46 | name="stego", 47 | version="0.0.1", 48 | author="Piotr Libera, Jonas Frey, Matias Mattamala", 49 | author_email="plibera@student.ethz.ch, jonfrey@ethz.ch, matias@leggedrobotics.com", 50 | packages=find_packages(), 51 | python_requires=">=3.7", 52 | description="Self-supervised semantic segmentation package based on the STEGO model", 53 | install_requires=[INSTALL_REQUIRES], 54 | ) 55 | -------------------------------------------------------------------------------- /stego/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | import os 8 | 9 | from .data import UnlabeledImageFolder, DirectoryDataset, ContrastiveSegDataset 10 | from .stego import Stego 11 | 12 | 13 | STEGO_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14 | """Absolute path to the stego repository.""" 15 | -------------------------------------------------------------------------------- /stego/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 3 | # All rights reserved. Licensed under the MIT license. 4 | # See LICENSE file in the project root for details. 5 | # 6 | # 7 | from . import dino 8 | -------------------------------------------------------------------------------- /stego/backbones/backbone.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. 
All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | import torch 9 | from .dino import vision_transformer as vits 10 | from torch import nn 11 | import numpy as np 12 | from abc import ABC, abstractmethod 13 | 14 | 15 | def get_backbone(cfg): 16 | """ 17 | Returns a selected STEGO backbone. 18 | After implementing the Backbone class for your backbone, add it to be returned from this function with a desired named. 19 | The backbone can then be used by specifying its name in the STEGO configuration file. 20 | """ 21 | if not hasattr(cfg, "backbone"): 22 | raise ValueError("Could not find 'backbone' option in the config file. Please check it") 23 | 24 | if cfg.backbone == "dino": 25 | return DinoViT(cfg) 26 | elif cfg.backbone == "dinov2": 27 | return Dinov2ViT(cfg) 28 | else: 29 | raise ValueError("Backbone {} unavailable".format(cfg.backbone)) 30 | 31 | 32 | class Backbone(ABC, nn.Module): 33 | """ 34 | Base class to provide an interface for new STEGO backbones. 35 | 36 | To add a new backbone for use in STEGO, add a new implementation of this class. 37 | """ 38 | 39 | vit_name_long_to_short = { 40 | "vit_tiny": "T", 41 | "vit_small": "S", 42 | "vit_base": "B", 43 | "vit_large": "L", 44 | "vit_huge": "H", 45 | "vit_giant": "G", 46 | } 47 | 48 | # Initialize the backbone 49 | @abstractmethod 50 | def __init__(self, cfg): 51 | super().__init__() 52 | 53 | # Return the size of features generated by the backbone 54 | @abstractmethod 55 | def get_output_feat_dim(self) -> int: 56 | pass 57 | 58 | # Generate features for the given image 59 | @abstractmethod 60 | def forward(self, img): 61 | pass 62 | 63 | # Returh a name that identifies the type of the backbone 64 | @abstractmethod 65 | def get_backbone_name(self): 66 | pass 67 | 68 | 69 | class Dinov2ViT(Backbone): 70 | def __init__(self, cfg): 71 | super().__init__(cfg) 72 | self.cfg = cfg 73 | self.backbone_type = self.cfg.backbone_type 74 | self.patch_size = 14 75 | if self.backbone_type == "vit_small": 76 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14") 77 | elif self.backbone_type == "vit_base": 78 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14") 79 | elif self.backbone_type == "vit_small_reg": 80 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg") 81 | elif self.backbone_type == "vit_base_reg": 82 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg") 83 | else: 84 | raise ValueError("Model type {} unavailable".format(cfg.backbone_type)) 85 | 86 | for p in self.model.parameters(): 87 | p.requires_grad = False 88 | self.model.eval().cuda() 89 | self.dropout = torch.nn.Dropout2d(p=np.clip(self.cfg.dropout_p, 0.0, 1.0)) 90 | 91 | if self.backbone_type == "vit_small": 92 | self.n_feats = 384 93 | else: 94 | self.n_feats = 768 95 | 96 | def get_output_feat_dim(self): 97 | return self.n_feats 98 | 99 | def forward(self, img): 100 | self.model.eval() 101 | with torch.no_grad(): 102 | assert img.shape[2] % self.patch_size == 0 103 | assert img.shape[3] % self.patch_size == 0 104 | 105 | # get selected layer activations 106 | feat = self.model.get_intermediate_layers(img)[0] 107 | 108 | feat_h = img.shape[2] // self.patch_size 109 | feat_w = img.shape[3] // self.patch_size 110 | 111 | image_feat = feat[:, :, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 
1, 2) 112 | 113 | if self.cfg.dropout_p > 0: 114 | return self.dropout(image_feat) 115 | else: 116 | return image_feat 117 | 118 | def get_backbone_name(self): 119 | return "DINOv2-" + Backbone.vit_name_long_to_short[self.backbone_type] + "-" + str(self.patch_size) 120 | 121 | 122 | class DinoViT(Backbone): 123 | def __init__(self, cfg): 124 | super().__init__(cfg) 125 | self.cfg = cfg 126 | self.patch_size = self.cfg.patch_size 127 | self.backbone_type = self.cfg.backbone_type 128 | self.model = vits.__dict__[self.backbone_type](patch_size=self.patch_size, num_classes=0) 129 | for p in self.model.parameters(): 130 | p.requires_grad = False 131 | self.model.eval().cuda() 132 | self.dropout = torch.nn.Dropout2d(p=np.clip(self.cfg.dropout_p, 0.0, 1.0)) 133 | 134 | if self.backbone_type == "vit_small" and self.patch_size == 16: 135 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 136 | elif self.backbone_type == "vit_small" and self.patch_size == 8: 137 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" 138 | elif self.backbone_type == "vit_base" and self.patch_size == 16: 139 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 140 | elif self.backbone_type == "vit_base" and self.patch_size == 8: 141 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 142 | else: 143 | raise ValueError("Model type {} unavailable with patch size {}".format(self.backbone_type, self.patch_size)) 144 | 145 | if cfg.pretrained_weights is not None: 146 | state_dict = torch.load(cfg.pretrained_weights, map_location="cpu") 147 | # remove `module.` prefix 148 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 149 | # remove `backbone.` prefix induced by multicrop wrapper 150 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 151 | msg = self.model.load_state_dict(state_dict, strict=False) 152 | print("Pretrained weights found at {} and loaded with msg: {}".format(cfg.pretrained_weights, msg)) 153 | else: 154 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 155 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 156 | self.model.load_state_dict(state_dict, strict=True) 157 | 158 | if self.backbone_type == "vit_small": 159 | self.n_feats = 384 160 | else: 161 | self.n_feats = 768 162 | 163 | def get_output_feat_dim(self): 164 | return self.n_feats 165 | 166 | def forward(self, img): 167 | self.model.eval() 168 | with torch.no_grad(): 169 | assert img.shape[2] % self.patch_size == 0 170 | assert img.shape[3] % self.patch_size == 0 171 | 172 | # get selected layer activations 173 | feat, attn, qkv = self.model.get_intermediate_feat(img) 174 | feat, attn, qkv = feat[0], attn[0], qkv[0] 175 | 176 | feat_h = img.shape[2] // self.patch_size 177 | feat_w = img.shape[3] // self.patch_size 178 | 179 | image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2) 180 | 181 | if self.cfg.dropout_p > 0: 182 | return self.dropout(image_feat) 183 | else: 184 | return image_feat 185 | 186 | def get_backbone_name(self): 187 | return "DINO-" + Backbone.vit_name_long_to_short[self.backbone_type] + "-" + str(self.patch_size) 188 | -------------------------------------------------------------------------------- /stego/backbones/dino/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from . 
import vision_transformer 3 | -------------------------------------------------------------------------------- /stego/backbones/dino/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Misc functions. 16 | 17 | Mostly copy-paste from torchvision references or other public repos like DETR: 18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 19 | """ 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import random 25 | import datetime 26 | import subprocess 27 | import warnings 28 | import argparse 29 | from collections import defaultdict, deque 30 | 31 | import numpy as np 32 | import torch 33 | from torch import Tensor 34 | from torch import nn 35 | import torch.distributed as dist 36 | from PIL import ImageFilter, ImageOps 37 | 38 | 39 | class GaussianBlur(object): 40 | """ 41 | Apply Gaussian Blur to the PIL image. 42 | """ 43 | 44 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0): 45 | self.prob = p 46 | self.radius_min = radius_min 47 | self.radius_max = radius_max 48 | 49 | def __call__(self, img): 50 | do_it = random.random() <= self.prob 51 | if not do_it: 52 | return img 53 | 54 | return img.filter(ImageFilter.GaussianBlur(radius=random.uniform(self.radius_min, self.radius_max))) 55 | 56 | 57 | class Solarization(object): 58 | """ 59 | Apply Solarization to the PIL image. 
60 | """ 61 | 62 | def __init__(self, p): 63 | self.p = p 64 | 65 | def __call__(self, img): 66 | if random.random() < self.p: 67 | return ImageOps.solarize(img) 68 | else: 69 | return img 70 | 71 | 72 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): 73 | if os.path.isfile(pretrained_weights): 74 | state_dict = torch.load(pretrained_weights, map_location="cpu") 75 | if checkpoint_key is not None and checkpoint_key in state_dict: 76 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 77 | state_dict = state_dict[checkpoint_key] 78 | # remove `module.` prefix 79 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 80 | # remove `backbone.` prefix induced by multicrop wrapper 81 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 82 | msg = model.load_state_dict(state_dict, strict=False) 83 | print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 84 | else: 85 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 86 | url = None 87 | if model_name == "vit_small" and patch_size == 16: 88 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 89 | elif model_name == "vit_small" and patch_size == 8: 90 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 91 | elif model_name == "vit_base" and patch_size == 16: 92 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 93 | elif model_name == "vit_base" and patch_size == 8: 94 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 95 | if url is not None: 96 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 97 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 98 | model.load_state_dict(state_dict, strict=True) 99 | else: 100 | print("There is no reference weights available for this model => We use random weights.") 101 | 102 | 103 | def clip_gradients(model, clip): 104 | norms = [] 105 | for name, p in model.named_parameters(): 106 | if p.grad is not None: 107 | param_norm = p.grad.data.norm(2) 108 | norms.append(param_norm.item()) 109 | clip_coef = clip / (param_norm + 1e-6) 110 | if clip_coef < 1: 111 | p.grad.data.mul_(clip_coef) 112 | return norms 113 | 114 | 115 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 116 | if epoch >= freeze_last_layer: 117 | return 118 | for n, p in model.named_parameters(): 119 | if "last_layer" in n: 120 | p.grad = None 121 | 122 | 123 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 124 | """ 125 | Re-start from checkpoint 126 | """ 127 | if not os.path.isfile(ckp_path): 128 | return 129 | print("Found checkpoint at {}".format(ckp_path)) 130 | 131 | # open checkpoint file 132 | checkpoint = torch.load(ckp_path, map_location="cpu") 133 | 134 | # key is what to look for in the checkpoint file 135 | # value is the object to load 136 | # example: {'state_dict': model} 137 | for key, value in kwargs.items(): 138 | if key in checkpoint and value is not None: 139 | try: 140 | msg = value.load_state_dict(checkpoint[key], strict=False) 141 | print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) 142 | except TypeError: 143 | try: 144 | msg = value.load_state_dict(checkpoint[key]) 145 | print("=> loaded {} from checkpoint '{}'".format(key, ckp_path)) 146 | except ValueError: 147 | print("=> failed to load {} from 
checkpoint '{}'".format(key, ckp_path)) 148 | else: 149 | print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path)) 150 | 151 | # re load variable important for the run 152 | if run_variables is not None: 153 | for var_name in run_variables: 154 | if var_name in checkpoint: 155 | run_variables[var_name] = checkpoint[var_name] 156 | 157 | 158 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 159 | warmup_schedule = np.array([]) 160 | warmup_iters = warmup_epochs * niter_per_ep 161 | if warmup_epochs > 0: 162 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 163 | 164 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 165 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 166 | 167 | schedule = np.concatenate((warmup_schedule, schedule)) 168 | assert len(schedule) == epochs * niter_per_ep 169 | return schedule 170 | 171 | 172 | def bool_flag(s): 173 | """ 174 | Parse boolean arguments from the command line. 175 | """ 176 | FALSY_STRINGS = {"off", "false", "0"} 177 | TRUTHY_STRINGS = {"on", "true", "1"} 178 | if s.lower() in FALSY_STRINGS: 179 | return False 180 | elif s.lower() in TRUTHY_STRINGS: 181 | return True 182 | else: 183 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 184 | 185 | 186 | def fix_random_seeds(seed=31): 187 | """ 188 | Fix random seeds. 189 | """ 190 | torch.manual_seed(seed) 191 | torch.cuda.manual_seed_all(seed) 192 | np.random.seed(seed) 193 | 194 | 195 | class SmoothedValue(object): 196 | """Track a series of values and provide access to smoothed values over a 197 | window or the global series average. 198 | """ 199 | 200 | def __init__(self, window_size=20, fmt=None): 201 | if fmt is None: 202 | fmt = "{median:.6f} ({global_avg:.6f})" 203 | self.deque = deque(maxlen=window_size) 204 | self.total = 0.0 205 | self.count = 0 206 | self.fmt = fmt 207 | 208 | def update(self, value, n=1): 209 | self.deque.append(value) 210 | self.count += n 211 | self.total += value * n 212 | 213 | def synchronize_between_processes(self): 214 | """ 215 | Warning: does not synchronize the deque! 216 | """ 217 | if not is_dist_avail_and_initialized(): 218 | return 219 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 220 | dist.barrier() 221 | dist.all_reduce(t) 222 | t = t.tolist() 223 | self.count = int(t[0]) 224 | self.total = t[1] 225 | 226 | @property 227 | def median(self): 228 | d = torch.tensor(list(self.deque)) 229 | return d.median().item() 230 | 231 | @property 232 | def avg(self): 233 | d = torch.tensor(list(self.deque), dtype=torch.float32) 234 | return d.mean().item() 235 | 236 | @property 237 | def global_avg(self): 238 | return self.total / self.count 239 | 240 | @property 241 | def max(self): 242 | return max(self.deque) 243 | 244 | @property 245 | def value(self): 246 | return self.deque[-1] 247 | 248 | def __str__(self): 249 | return self.fmt.format( 250 | median=self.median, 251 | avg=self.avg, 252 | global_avg=self.global_avg, 253 | max=self.max, 254 | value=self.value, 255 | ) 256 | 257 | 258 | def reduce_dict(input_dict, average=True): 259 | """ 260 | Args: 261 | input_dict (dict): all the values will be reduced 262 | average (bool): whether to do average or sum 263 | Reduce the values in the dictionary from all processes so that all processes 264 | have the averaged results. Returns a dict with the same fields as 265 | input_dict, after reduction. 
266 | """ 267 | world_size = get_world_size() 268 | if world_size < 2: 269 | return input_dict 270 | with torch.no_grad(): 271 | names = [] 272 | values = [] 273 | # sort the keys so that they are consistent across processes 274 | for k in sorted(input_dict.keys()): 275 | names.append(k) 276 | values.append(input_dict[k]) 277 | values = torch.stack(values, dim=0) 278 | dist.all_reduce(values) 279 | if average: 280 | values /= world_size 281 | reduced_dict = {k: v for k, v in zip(names, values)} 282 | return reduced_dict 283 | 284 | 285 | class MetricLogger(object): 286 | def __init__(self, delimiter="\t"): 287 | self.meters = defaultdict(SmoothedValue) 288 | self.delimiter = delimiter 289 | 290 | def update(self, **kwargs): 291 | for k, v in kwargs.items(): 292 | if isinstance(v, torch.Tensor): 293 | v = v.item() 294 | assert isinstance(v, (float, int)) 295 | self.meters[k].update(v) 296 | 297 | def __getattr__(self, attr): 298 | if attr in self.meters: 299 | return self.meters[attr] 300 | if attr in self.__dict__: 301 | return self.__dict__[attr] 302 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) 303 | 304 | def __str__(self): 305 | loss_str = [] 306 | for name, meter in self.meters.items(): 307 | loss_str.append("{}: {}".format(name, str(meter))) 308 | return self.delimiter.join(loss_str) 309 | 310 | def synchronize_between_processes(self): 311 | for meter in self.meters.values(): 312 | meter.synchronize_between_processes() 313 | 314 | def add_meter(self, name, meter): 315 | self.meters[name] = meter 316 | 317 | def log_every(self, iterable, print_freq, header=None): 318 | i = 0 319 | if not header: 320 | header = "" 321 | start_time = time.time() 322 | end = time.time() 323 | iter_time = SmoothedValue(fmt="{avg:.6f}") 324 | data_time = SmoothedValue(fmt="{avg:.6f}") 325 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 326 | if torch.cuda.is_available(): 327 | log_msg = self.delimiter.join( 328 | [ 329 | header, 330 | "[{0" + space_fmt + "}/{1}]", 331 | "eta: {eta}", 332 | "{meters}", 333 | "time: {time}", 334 | "data: {data}", 335 | "max mem: {memory:.0f}", 336 | ] 337 | ) 338 | else: 339 | log_msg = self.delimiter.join( 340 | [ 341 | header, 342 | "[{0" + space_fmt + "}/{1}]", 343 | "eta: {eta}", 344 | "{meters}", 345 | "time: {time}", 346 | "data: {data}", 347 | ] 348 | ) 349 | MB = 1024.0 * 1024.0 350 | for obj in iterable: 351 | data_time.update(time.time() - end) 352 | yield obj 353 | iter_time.update(time.time() - end) 354 | if i % print_freq == 0 or i == len(iterable) - 1: 355 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 356 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 357 | if torch.cuda.is_available(): 358 | print( 359 | log_msg.format( 360 | i, 361 | len(iterable), 362 | eta=eta_string, 363 | meters=str(self), 364 | time=str(iter_time), 365 | data=str(data_time), 366 | memory=torch.cuda.max_memory_allocated() / MB, 367 | ) 368 | ) 369 | else: 370 | print( 371 | log_msg.format( 372 | i, 373 | len(iterable), 374 | eta=eta_string, 375 | meters=str(self), 376 | time=str(iter_time), 377 | data=str(data_time), 378 | ) 379 | ) 380 | i += 1 381 | end = time.time() 382 | total_time = time.time() - start_time 383 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 384 | print("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / len(iterable))) 385 | 386 | 387 | def get_sha(): 388 | cwd = os.path.dirname(os.path.abspath(__file__)) 389 | 390 | def 
_run(command): 391 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 392 | 393 | sha = "N/A" 394 | diff = "clean" 395 | branch = "N/A" 396 | try: 397 | sha = _run(["git", "rev-parse", "HEAD"]) 398 | subprocess.check_output(["git", "diff"], cwd=cwd) 399 | diff = _run(["git", "diff-index", "HEAD"]) 400 | diff = "has uncommited changes" if diff else "clean" 401 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 402 | except Exception: 403 | pass 404 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 405 | return message 406 | 407 | 408 | def is_dist_avail_and_initialized(): 409 | if not dist.is_available(): 410 | return False 411 | if not dist.is_initialized(): 412 | return False 413 | return True 414 | 415 | 416 | def get_world_size(): 417 | if not is_dist_avail_and_initialized(): 418 | return 1 419 | return dist.get_world_size() 420 | 421 | 422 | def get_rank(): 423 | if not is_dist_avail_and_initialized(): 424 | return 0 425 | return dist.get_rank() 426 | 427 | 428 | def is_main_process(): 429 | return get_rank() == 0 430 | 431 | 432 | def save_on_master(*args, **kwargs): 433 | if is_main_process(): 434 | torch.save(*args, **kwargs) 435 | 436 | 437 | def setup_for_distributed(is_master): 438 | """ 439 | This function disables printing when not in master process 440 | """ 441 | import builtins as __builtin__ 442 | 443 | builtin_print = __builtin__.print 444 | 445 | def print(*args, **kwargs): 446 | force = kwargs.pop("force", False) 447 | if is_master or force: 448 | builtin_print(*args, **kwargs) 449 | 450 | __builtin__.print = print 451 | 452 | 453 | def init_distributed_mode(args): 454 | # launched with torch.distributed.launch 455 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 456 | args.rank = int(os.environ["RANK"]) 457 | args.world_size = int(os.environ["WORLD_SIZE"]) 458 | args.gpu = int(os.environ["LOCAL_RANK"]) 459 | # launched with submitit on a slurm cluster 460 | elif "SLURM_PROCID" in os.environ: 461 | args.rank = int(os.environ["SLURM_PROCID"]) 462 | args.gpu = args.rank % torch.cuda.device_count() 463 | # launched naively with `python main_dino.py` 464 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 465 | elif torch.cuda.is_available(): 466 | print("Will run the code on one GPU.") 467 | args.rank, args.gpu, args.world_size = 0, 0, 1 468 | os.environ["MASTER_ADDR"] = "127.0.0.1" 469 | os.environ["MASTER_PORT"] = "29500" 470 | else: 471 | print("Does not support training without GPU.") 472 | sys.exit(1) 473 | 474 | dist.init_process_group( 475 | backend="nccl", 476 | init_method=args.dist_url, 477 | world_size=args.world_size, 478 | rank=args.rank, 479 | ) 480 | 481 | torch.cuda.set_device(args.gpu) 482 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) 483 | dist.barrier() 484 | setup_for_distributed(args.rank == 0) 485 | 486 | 487 | def accuracy(output, target, topk=(1,)): 488 | """Computes the accuracy over the k top predictions for the specified values of k""" 489 | maxk = max(topk) 490 | batch_size = target.size(0) 491 | _, pred = output.topk(maxk, 1, True, True) 492 | pred = pred.t() 493 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 494 | return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk] 495 | 496 | 497 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 498 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 499 | # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 500 | def norm_cdf(x): 501 | # Computes standard normal cumulative distribution function 502 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 503 | 504 | if (mean < a - 2 * std) or (mean > b + 2 * std): 505 | warnings.warn( 506 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 507 | "The distribution of values may be incorrect.", 508 | stacklevel=2, 509 | ) 510 | 511 | with torch.no_grad(): 512 | # Values are generated by using a truncated uniform distribution and 513 | # then using the inverse CDF for the normal distribution. 514 | # Get upper and lower cdf values 515 | l = norm_cdf((a - mean) / std) # noqa 516 | u = norm_cdf((b - mean) / std) 517 | 518 | # Uniformly fill tensor with values from [l, u], then translate to 519 | # [2l-1, 2u-1]. 520 | tensor.uniform_(2 * l - 1, 2 * u - 1) 521 | 522 | # Use inverse cdf transform for normal distribution to get truncated 523 | # standard normal 524 | tensor.erfinv_() 525 | 526 | # Transform to proper mean, std 527 | tensor.mul_(std * math.sqrt(2.0)) 528 | tensor.add_(mean) 529 | 530 | # Clamp to ensure it's in the proper range 531 | tensor.clamp_(min=a, max=b) 532 | return tensor 533 | 534 | 535 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 536 | # type: (Tensor, float, float, float, float) -> Tensor 537 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 538 | 539 | 540 | class LARS(torch.optim.Optimizer): 541 | """ 542 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 543 | """ 544 | 545 | def __init__( 546 | self, 547 | params, 548 | lr=0, 549 | weight_decay=0, 550 | momentum=0.9, 551 | eta=0.001, 552 | weight_decay_filter=None, 553 | lars_adaptation_filter=None, 554 | ): 555 | defaults = dict( 556 | lr=lr, 557 | weight_decay=weight_decay, 558 | momentum=momentum, 559 | eta=eta, 560 | weight_decay_filter=weight_decay_filter, 561 | lars_adaptation_filter=lars_adaptation_filter, 562 | ) 563 | super().__init__(params, defaults) 564 | 565 | @torch.no_grad() 566 | def step(self): 567 | for g in self.param_groups: 568 | for p in g["params"]: 569 | dp = p.grad 570 | 571 | if dp is None: 572 | continue 573 | 574 | if p.ndim != 1: 575 | dp = dp.add(p, alpha=g["weight_decay"]) 576 | 577 | if p.ndim != 1: 578 | param_norm = torch.norm(p) 579 | update_norm = torch.norm(dp) 580 | one = torch.ones_like(param_norm) 581 | q = torch.where( 582 | param_norm > 0.0, 583 | torch.where(update_norm > 0, (g["eta"] * param_norm / update_norm), one), 584 | one, 585 | ) 586 | dp = dp.mul(q) 587 | 588 | param_state = self.state[p] 589 | if "mu" not in param_state: 590 | param_state["mu"] = torch.zeros_like(p) 591 | mu = param_state["mu"] 592 | mu.mul_(g["momentum"]).add_(dp) 593 | 594 | p.add_(mu, alpha=-g["lr"]) 595 | 596 | 597 | class MultiCropWrapper(nn.Module): 598 | """ 599 | Perform forward pass separately on each resolution input. 600 | The inputs corresponding to a single resolution are clubbed and single 601 | forward is run on the same resolution inputs. Hence we do several 602 | forward passes = number of different resolutions used. We then 603 | concatenate all the output features and run the head forward on these 604 | concatenated features. 
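    Example (illustrative; the crop sizes are arbitrary): with two 224x224 global crops
    followed by four 96x96 local crops, idx_crops evaluates to [2, 6], so the backbone
    runs once on the two global crops and once on the four local crops, and the head
    receives all six concatenated feature vectors.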
605 | """ 606 | 607 | def __init__(self, backbone, head): 608 | super(MultiCropWrapper, self).__init__() 609 | # disable layers dedicated to ImageNet labels classification 610 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 611 | self.backbone = backbone 612 | self.head = head 613 | 614 | def forward(self, x): 615 | # convert to list 616 | if not isinstance(x, list): 617 | x = [x] 618 | idx_crops = torch.cumsum( 619 | torch.unique_consecutive( 620 | torch.tensor([inp.shape[-1] for inp in x]), 621 | return_counts=True, 622 | )[1], 623 | 0, 624 | ) 625 | start_idx = 0 626 | for end_idx in idx_crops: 627 | _out = self.backbone(torch.cat(x[start_idx:end_idx])) 628 | if start_idx == 0: 629 | output = _out 630 | else: 631 | output = torch.cat((output, _out)) 632 | start_idx = end_idx 633 | # Run the head forward on the concatenated features. 634 | return self.head(output) 635 | 636 | 637 | def get_params_groups(model): 638 | regularized = [] 639 | not_regularized = [] 640 | for name, param in model.named_parameters(): 641 | if not param.requires_grad: 642 | continue 643 | # we do not regularize biases nor Norm parameters 644 | if name.endswith(".bias") or len(param.shape) == 1: 645 | not_regularized.append(param) 646 | else: 647 | regularized.append(param) 648 | return [{"params": regularized}, {"params": not_regularized, "weight_decay": 0.0}] 649 | 650 | 651 | def has_batchnorms(model): 652 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 653 | for name, module in model.named_modules(): 654 | if isinstance(module, bn_types): 655 | return True 656 | return False 657 | -------------------------------------------------------------------------------- /stego/backbones/dino/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | from .utils import trunc_normal_ 24 | 25 | 26 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 27 | if drop_prob == 0.0 or not training: 28 | return x 29 | keep_prob = 1 - drop_prob 30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 32 | random_tensor.floor_() # binarize 33 | output = x.div(keep_prob) * random_tensor 34 | return output 35 | 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 39 | 40 | def __init__(self, drop_prob=None): 41 | super(DropPath, self).__init__() 42 | self.drop_prob = drop_prob 43 | 44 | def forward(self, x): 45 | return drop_path(x, self.drop_prob, self.training) 46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__( 50 | self, 51 | in_features, 52 | hidden_features=None, 53 | out_features=None, 54 | act_layer=nn.GELU, 55 | drop=0.0, 56 | ): 57 | super().__init__() 58 | out_features = out_features or in_features 59 | hidden_features = hidden_features or in_features 60 | self.fc1 = nn.Linear(in_features, hidden_features) 61 | self.act = act_layer() 62 | self.fc2 = nn.Linear(hidden_features, out_features) 63 | self.drop = nn.Dropout(drop) 64 | 65 | def forward(self, x): 66 | x = self.fc1(x) 67 | x = self.act(x) 68 | x = self.drop(x) 69 | x = self.fc2(x) 70 | x = self.drop(x) 71 | return x 72 | 73 | 74 | class Attention(nn.Module): 75 | def __init__( 76 | self, 77 | dim, 78 | num_heads=8, 79 | qkv_bias=False, 80 | qk_scale=None, 81 | attn_drop=0.0, 82 | proj_drop=0.0, 83 | ): 84 | super().__init__() 85 | self.num_heads = num_heads 86 | head_dim = dim // num_heads 87 | self.scale = qk_scale or head_dim**-0.5 88 | 89 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 90 | self.attn_drop = nn.Dropout(attn_drop) 91 | self.proj = nn.Linear(dim, dim) 92 | self.proj_drop = nn.Dropout(proj_drop) 93 | 94 | def forward(self, x, return_qkv=False): 95 | B, N, C = x.shape 96 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 97 | q, k, v = qkv[0], qkv[1], qkv[2] 98 | 99 | attn = (q @ k.transpose(-2, -1)) * self.scale 100 | attn = attn.softmax(dim=-1) 101 | attn = self.attn_drop(attn) 102 | 103 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 104 | x = self.proj(x) 105 | x = self.proj_drop(x) 106 | return x, attn, qkv 107 | 108 | 109 | class Block(nn.Module): 110 | def __init__( 111 | self, 112 | dim, 113 | num_heads, 114 | mlp_ratio=4.0, 115 | qkv_bias=False, 116 | qk_scale=None, 117 | drop=0.0, 118 | attn_drop=0.0, 119 | drop_path=0.0, 120 | act_layer=nn.GELU, 121 | norm_layer=nn.LayerNorm, 122 | ): 123 | super().__init__() 124 | self.norm1 = norm_layer(dim) 125 | self.attn = Attention( 126 | dim, 127 | num_heads=num_heads, 128 | qkv_bias=qkv_bias, 129 | qk_scale=qk_scale, 130 | attn_drop=attn_drop, 131 | proj_drop=drop, 132 | ) 133 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 134 | self.norm2 = norm_layer(dim) 135 | mlp_hidden_dim = int(dim * mlp_ratio) 136 | self.mlp = Mlp( 137 | in_features=dim, 138 | hidden_features=mlp_hidden_dim, 139 | act_layer=act_layer, 140 | drop=drop, 141 | ) 142 | 143 | def forward(self, x, return_attention=False, 
return_qkv=False): 144 | y, attn, qkv = self.attn(self.norm1(x)) 145 | if return_attention: 146 | return attn 147 | x = x + self.drop_path(y) 148 | x = x + self.drop_path(self.mlp(self.norm2(x))) 149 | if return_qkv: 150 | return x, attn, qkv 151 | return x 152 | 153 | 154 | class PatchEmbed(nn.Module): 155 | """Image to Patch Embedding""" 156 | 157 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 158 | super().__init__() 159 | num_patches = (img_size // patch_size) * (img_size // patch_size) 160 | self.img_size = img_size 161 | self.patch_size = patch_size 162 | self.num_patches = num_patches 163 | 164 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 165 | 166 | def forward(self, x): 167 | B, C, H, W = x.shape 168 | x = self.proj(x).flatten(2).transpose(1, 2) 169 | return x 170 | 171 | 172 | class VisionTransformer(nn.Module): 173 | """Vision Transformer""" 174 | 175 | def __init__( 176 | self, 177 | img_size=[224], 178 | patch_size=16, 179 | in_chans=3, 180 | num_classes=0, 181 | embed_dim=768, 182 | depth=12, 183 | num_heads=12, 184 | mlp_ratio=4.0, 185 | qkv_bias=False, 186 | qk_scale=None, 187 | drop_rate=0.0, 188 | attn_drop_rate=0.0, 189 | drop_path_rate=0.0, 190 | norm_layer=nn.LayerNorm, 191 | **kwargs 192 | ): 193 | super().__init__() 194 | 195 | self.num_features = self.embed_dim = embed_dim 196 | 197 | self.patch_embed = PatchEmbed( 198 | img_size=img_size[0], 199 | patch_size=patch_size, 200 | in_chans=in_chans, 201 | embed_dim=embed_dim, 202 | ) 203 | num_patches = self.patch_embed.num_patches 204 | 205 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 206 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 207 | self.pos_drop = nn.Dropout(p=drop_rate) 208 | 209 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 210 | self.blocks = nn.ModuleList( 211 | [ 212 | Block( 213 | dim=embed_dim, 214 | num_heads=num_heads, 215 | mlp_ratio=mlp_ratio, 216 | qkv_bias=qkv_bias, 217 | qk_scale=qk_scale, 218 | drop=drop_rate, 219 | attn_drop=attn_drop_rate, 220 | drop_path=dpr[i], 221 | norm_layer=norm_layer, 222 | ) 223 | for i in range(depth) 224 | ] 225 | ) 226 | self.norm = norm_layer(embed_dim) 227 | 228 | # Classifier head 229 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 230 | 231 | trunc_normal_(self.pos_embed, std=0.02) 232 | trunc_normal_(self.cls_token, std=0.02) 233 | self.apply(self._init_weights) 234 | 235 | def _init_weights(self, m): 236 | if isinstance(m, nn.Linear): 237 | trunc_normal_(m.weight, std=0.02) 238 | if isinstance(m, nn.Linear) and m.bias is not None: 239 | nn.init.constant_(m.bias, 0) 240 | elif isinstance(m, nn.LayerNorm): 241 | nn.init.constant_(m.bias, 0) 242 | nn.init.constant_(m.weight, 1.0) 243 | 244 | def interpolate_pos_encoding(self, x, w, h): 245 | npatch = x.shape[1] - 1 246 | N = self.pos_embed.shape[1] - 1 247 | if npatch == N and w == h: 248 | return self.pos_embed 249 | class_pos_embed = self.pos_embed[:, 0] 250 | patch_pos_embed = self.pos_embed[:, 1:] 251 | dim = x.shape[-1] 252 | w0 = w // self.patch_embed.patch_size 253 | h0 = h // self.patch_embed.patch_size 254 | # we add a small number to avoid floating point error in the interpolation 255 | # see discussion at https://github.com/facebookresearch/dino/issues/8 256 | w0, h0 = w0 + 0.1, h0 + 0.1 257 | patch_pos_embed = nn.functional.interpolate( 258 | patch_pos_embed.reshape(1, int(math.sqrt(N)), 
int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 259 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 260 | mode="bicubic", 261 | ) 262 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 263 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 264 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 265 | 266 | def prepare_tokens(self, x): 267 | B, nc, w, h = x.shape 268 | x = self.patch_embed(x) # patch linear embedding 269 | 270 | # add the [CLS] token to the embed patch tokens 271 | cls_tokens = self.cls_token.expand(B, -1, -1) 272 | x = torch.cat((cls_tokens, x), dim=1) 273 | 274 | # add positional encoding to each token 275 | x = x + self.interpolate_pos_encoding(x, w, h) 276 | 277 | return self.pos_drop(x) 278 | 279 | def forward(self, x): 280 | x = self.prepare_tokens(x) 281 | for blk in self.blocks: 282 | x = blk(x) 283 | x = self.norm(x) 284 | return x[:, 0] 285 | 286 | def forward_feats(self, x): 287 | x = self.prepare_tokens(x) 288 | for blk in self.blocks: 289 | x = blk(x) 290 | x = self.norm(x) 291 | return x 292 | 293 | def get_intermediate_feat(self, x, n=1): 294 | x = self.prepare_tokens(x) 295 | # we return the output tokens from the `n` last blocks 296 | feat = [] 297 | attns = [] 298 | qkvs = [] 299 | for i, blk in enumerate(self.blocks): 300 | x, attn, qkv = blk(x, return_qkv=True) 301 | if len(self.blocks) - i <= n: 302 | feat.append(self.norm(x)) 303 | qkvs.append(qkv) 304 | attns.append(attn) 305 | return feat, attns, qkvs 306 | 307 | def get_last_selfattention(self, x): 308 | x = self.prepare_tokens(x) 309 | for i, blk in enumerate(self.blocks): 310 | if i < len(self.blocks) - 1: 311 | x = blk(x) 312 | else: 313 | # return attention of the last block 314 | return blk(x, return_attention=True) 315 | 316 | def get_intermediate_layers(self, x, n=1): 317 | x = self.prepare_tokens(x) 318 | # we return the output tokens from the `n` last blocks 319 | output = [] 320 | for i, blk in enumerate(self.blocks): 321 | x = blk(x) 322 | if len(self.blocks) - i <= n: 323 | output.append(self.norm(x)) 324 | return output 325 | 326 | 327 | def vit_tiny(patch_size=16, **kwargs): 328 | model = VisionTransformer( 329 | patch_size=patch_size, 330 | embed_dim=192, 331 | depth=12, 332 | num_heads=3, 333 | mlp_ratio=4, 334 | qkv_bias=True, 335 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 336 | **kwargs 337 | ) 338 | return model 339 | 340 | 341 | def vit_small(patch_size=16, **kwargs): 342 | model = VisionTransformer( 343 | patch_size=patch_size, 344 | embed_dim=384, 345 | depth=12, 346 | num_heads=6, 347 | mlp_ratio=4, 348 | qkv_bias=True, 349 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 350 | **kwargs 351 | ) 352 | return model 353 | 354 | 355 | def vit_base(patch_size=16, **kwargs): 356 | model = VisionTransformer( 357 | patch_size=patch_size, 358 | embed_dim=768, 359 | depth=12, 360 | num_heads=12, 361 | mlp_ratio=4, 362 | qkv_bias=True, 363 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 364 | **kwargs 365 | ) 366 | return model 367 | 368 | 369 | class DINOHead(nn.Module): 370 | def __init__( 371 | self, 372 | in_dim, 373 | out_dim, 374 | use_bn=False, 375 | norm_last_layer=True, 376 | nlayers=3, 377 | hidden_dim=2048, 378 | bottleneck_dim=256, 379 | ): 380 | super().__init__() 381 | nlayers = max(nlayers, 1) 382 | if nlayers == 1: 383 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 384 | else: 385 | layers = [nn.Linear(in_dim, hidden_dim)] 386 | if use_bn: 387 | layers.append(nn.BatchNorm1d(hidden_dim)) 388 
| layers.append(nn.GELU()) 389 | for _ in range(nlayers - 2): 390 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 391 | if use_bn: 392 | layers.append(nn.BatchNorm1d(hidden_dim)) 393 | layers.append(nn.GELU()) 394 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 395 | self.mlp = nn.Sequential(*layers) 396 | self.apply(self._init_weights) 397 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 398 | self.last_layer.weight_g.data.fill_(1) 399 | if norm_last_layer: 400 | self.last_layer.weight_g.requires_grad = False 401 | 402 | def _init_weights(self, m): 403 | if isinstance(m, nn.Linear): 404 | trunc_normal_(m.weight, std=0.02) 405 | if isinstance(m, nn.Linear) and m.bias is not None: 406 | nn.init.constant_(m.bias, 0) 407 | 408 | def forward(self, x): 409 | x = self.mlp(x) 410 | x = nn.functional.normalize(x, dim=-1, p=2) 411 | x = self.last_layer(x) 412 | return x 413 | -------------------------------------------------------------------------------- /stego/cfg/model_config.yaml: -------------------------------------------------------------------------------- 1 | # Backbone parameters 2 | backbone: "dino" 3 | backbone_type: "vit_base" 4 | patch_size: 8 5 | dropout_p: 0.1 # Dropout probability on backbone output, clamped to [0,1]. For training, STEGO used 0.1. 6 | pretrained_weights: 7 | 8 | # Head 9 | dim: 90 # Note: Piotr used 70, but the original STEGO model uses 90 10 | # Clustering 11 | extra_clusters: 0 12 | 13 | # CRF 14 | crf_max_iter: 10 15 | pos_w: 3 16 | pos_xy_std: 1 17 | bi_w: 4 18 | bi_xy_std: 67 19 | bi_rgb_std: 3 20 | 21 | # Training params 22 | lr: 5e-4 23 | cluster_lr: 5e-3 24 | linear_lr: 5e-3 25 | val_n_imgs: 3 26 | 27 | # Feature Contrastive params 28 | zero_clamp: True 29 | stabilize: False 30 | pointwise: True 31 | feature_samples: 11 32 | neg_samples: 5 33 | 34 | neg_inter_weight: 1.0 35 | pos_inter_weight: 0.5 36 | pos_intra_weight: 1.0 37 | neg_inter_shift: 0.3 38 | pos_inter_shift: 0.2 39 | pos_intra_shift: 0.35 40 | -------------------------------------------------------------------------------- /stego/data.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | import torch 9 | import numpy as np 10 | from PIL import Image 11 | import random 12 | from torch.utils.data import Dataset 13 | import os 14 | 15 | from stego.utils import get_nn_file_name 16 | 17 | 18 | class UnlabeledImageFolder(Dataset): 19 | """ 20 | A simple Dataset class to read images from a given folder. 
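    Example (illustrative sketch; the folder path is a placeholder and torchvision
    transforms are assumed to be imported as T):
        dataset = UnlabeledImageFolder(root="data/my_images", transform=T.Compose([T.Resize((320, 320)), T.ToTensor()]))
        image, file_name = dataset[0]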
21 | """ 22 | 23 | def __init__(self, root, transform): 24 | super(UnlabeledImageFolder, self).__init__() 25 | self.root = root 26 | self.transform = transform 27 | self.images = os.listdir(self.root) 28 | 29 | def __getitem__(self, index): 30 | image = Image.open(os.path.join(self.root, self.images[index])).convert("RGB") 31 | seed = np.random.randint(2147483647) 32 | random.seed(seed) 33 | torch.manual_seed(seed) 34 | image = self.transform(image) 35 | 36 | return image, self.images[index] 37 | 38 | def __len__(self): 39 | return len(self.images) 40 | 41 | 42 | def create_cityscapes_colormap(): 43 | colors = [ 44 | (128, 64, 128), 45 | (244, 35, 232), 46 | (250, 170, 160), 47 | (230, 150, 140), 48 | (70, 70, 70), 49 | (102, 102, 156), 50 | (190, 153, 153), 51 | (180, 165, 180), 52 | (150, 100, 100), 53 | (150, 120, 90), 54 | (153, 153, 153), 55 | (153, 153, 153), 56 | (250, 170, 30), 57 | (220, 220, 0), 58 | (107, 142, 35), 59 | (152, 251, 152), 60 | (70, 130, 180), 61 | (220, 20, 60), 62 | (255, 0, 0), 63 | (0, 0, 142), 64 | (0, 0, 70), 65 | (0, 60, 100), 66 | (0, 0, 90), 67 | (0, 0, 110), 68 | (0, 80, 100), 69 | (0, 0, 230), 70 | (119, 11, 32), 71 | (0, 0, 0), 72 | ] 73 | return np.array(colors) 74 | 75 | 76 | class DirectoryDataset(Dataset): 77 | """ 78 | A Dataset class that reads images and (if available) labels from the given directory. 79 | The expected structure of the directory: 80 | data_dir 81 | |-- dataset_name 82 | |-- imgs 83 | |-- image_set 84 | |-- labels 85 | |-- image_set 86 | 87 | If available, file names in labels/image_set should be the same as file names in imgs/image_set (excluding extensions). 88 | If labels are not available (there is no labels folder) this class returns zero arrays of shape corresponding to the image shape. 
89 | """ 90 | 91 | def __init__(self, data_dir, dataset_name, image_set, transform, target_transform): 92 | super(DirectoryDataset, self).__init__() 93 | self.split = image_set 94 | self.dataset_name = dataset_name 95 | self.dir = os.path.join(data_dir, dataset_name) 96 | self.img_dir = os.path.join(self.dir, "imgs", self.split) 97 | self.label_dir = os.path.join(self.dir, "labels", self.split) 98 | 99 | self.transform = transform 100 | self.target_transform = target_transform 101 | 102 | self.img_files = np.array(sorted(os.listdir(self.img_dir))) 103 | assert len(self.img_files) > 0, "Could not find any images in dataset directory {}".format(self.img_dir) 104 | if os.path.exists(os.path.join(self.dir, "labels")): 105 | self.label_files = np.array(sorted(os.listdir(self.label_dir))) 106 | assert len(self.img_files) == len( 107 | self.label_files 108 | ), "The {} dataset contains a different number of images and labels: {} images and {} labels".format( 109 | self.dataset_name, len(self.img_files), len(self.label_files) 110 | ) 111 | else: 112 | self.label_files = None 113 | 114 | def __getitem__(self, index): 115 | image_name = self.img_files[index] 116 | img = Image.open(os.path.join(self.img_dir, image_name)) 117 | if self.label_files is not None: 118 | label_name = self.label_files[index] 119 | label = Image.open(os.path.join(self.label_dir, label_name)) 120 | 121 | seed = np.random.randint(2147483647) 122 | random.seed(seed) 123 | torch.manual_seed(seed) 124 | img = self.transform(img) 125 | if self.label_files is not None: 126 | random.seed(seed) 127 | torch.manual_seed(seed) 128 | label = self.target_transform(label) 129 | else: 130 | label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) 131 | 132 | mask = (label > 0).to(torch.float32) 133 | return img, label, mask 134 | 135 | def __len__(self): 136 | return len(self.img_files) 137 | 138 | 139 | class ContrastiveSegDataset(Dataset): 140 | """ 141 | The main Dataset class used by STEGO. 142 | Internally uses the DirectoryDataset class to load images. 143 | Additionally, this class uses the precomputed Nearest Neighbor files to extract the knn corresponding image for STEGO training. 144 | It returns a dictionary containing an image and its positive pair (one of the nearest neighbor images). 
145 | """ 146 | 147 | def __init__( 148 | self, 149 | data_dir, 150 | dataset_name, 151 | image_set, 152 | transform, 153 | target_transform, 154 | model_type, 155 | resolution, 156 | aug_geometric_transform=None, 157 | aug_photometric_transform=None, 158 | num_neighbors=5, 159 | mask=False, 160 | pos_labels=False, 161 | pos_images=False, 162 | extra_transform=None, 163 | ): 164 | super(ContrastiveSegDataset).__init__() 165 | self.num_neighbors = num_neighbors 166 | self.image_set = image_set 167 | self.dataset_name = dataset_name 168 | self.mask = mask 169 | self.pos_labels = pos_labels 170 | self.pos_images = pos_images 171 | self.extra_transform = extra_transform 172 | self.aug_geometric_transform = aug_geometric_transform 173 | self.aug_photometric_transform = aug_photometric_transform 174 | 175 | self.dataset = DirectoryDataset(data_dir, dataset_name, image_set, transform, target_transform) 176 | 177 | feature_cache_file = get_nn_file_name(data_dir, dataset_name, model_type, image_set, resolution) 178 | if pos_labels or pos_images: 179 | if not os.path.exists(feature_cache_file): 180 | raise ValueError("could not find nn file {} please run precompute_knns".format(feature_cache_file)) 181 | else: 182 | loaded = np.load(feature_cache_file) 183 | self.nns = loaded["nns"] 184 | assert ( 185 | len(self.dataset) == self.nns.shape[0] 186 | ), "Found different numbers of images in dataset {} and nn file {}".format(dataset_name, feature_cache_file) 187 | 188 | def __len__(self): 189 | return len(self.dataset) 190 | 191 | def _set_seed(self, seed): 192 | random.seed(seed) # apply this seed to img tranfsorms 193 | torch.manual_seed(seed) # needed for torchvision 0.7 194 | 195 | def __getitem__(self, ind): 196 | pack = self.dataset[ind] 197 | 198 | if self.pos_images or self.pos_labels: 199 | ind_pos = self.nns[ind][torch.randint(low=1, high=self.num_neighbors + 1, size=[]).item()] 200 | pack_pos = self.dataset[ind_pos] 201 | 202 | seed = np.random.randint(2147483647) # make a seed with numpy generator 203 | 204 | self._set_seed(seed) 205 | coord_entries = torch.meshgrid( 206 | [ 207 | torch.linspace(-1, 1, pack[0].shape[1]), 208 | torch.linspace(-1, 1, pack[0].shape[2]), 209 | ] 210 | ) 211 | coord = torch.cat([t.unsqueeze(0) for t in coord_entries], 0) 212 | 213 | if self.extra_transform is not None: 214 | extra_trans = self.extra_transform 215 | else: 216 | extra_trans = lambda i, x: x 217 | 218 | ret = { 219 | "ind": ind, 220 | "img": extra_trans(ind, pack[0]), 221 | "label": extra_trans(ind, pack[1]), 222 | } 223 | 224 | if self.pos_images: 225 | ret["img_pos"] = extra_trans(ind, pack_pos[0]) 226 | ret["ind_pos"] = ind_pos 227 | 228 | if self.mask: 229 | ret["mask"] = pack[2] 230 | 231 | if self.pos_labels: 232 | ret["label_pos"] = extra_trans(ind, pack_pos[1]) 233 | ret["mask_pos"] = pack_pos[2] 234 | 235 | if self.aug_photometric_transform is not None: 236 | img_aug = self.aug_photometric_transform(self.aug_geometric_transform(pack[0])) 237 | 238 | self._set_seed(seed) 239 | coord_aug = self.aug_geometric_transform(coord) 240 | 241 | ret["img_aug"] = img_aug 242 | ret["coord_aug"] = coord_aug.permute(1, 2, 0) 243 | 244 | return ret 245 | -------------------------------------------------------------------------------- /stego/modules.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. 
Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | import numpy as np 12 | import pydensecrf.densecrf as dcrf 13 | import pydensecrf.utils as utils 14 | import torchvision.transforms.functional as VF 15 | from kornia.core import Tensor 16 | from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE 17 | 18 | from stego.utils import unnorm, sample, super_perm, norm, tensor_correlation 19 | 20 | 21 | class SegmentationHead(nn.Module): 22 | """ 23 | STEGO's segmentation head module. 24 | """ 25 | 26 | def __init__(self, input_dim, dim): 27 | super().__init__() 28 | self.linear = torch.nn.Sequential(torch.nn.Conv2d(input_dim, dim, (1, 1))) 29 | self.nonlinear = torch.nn.Sequential( 30 | torch.nn.Conv2d(input_dim, input_dim, (1, 1)), 31 | torch.nn.ReLU(), 32 | torch.nn.Conv2d(input_dim, dim, (1, 1)), 33 | ) 34 | 35 | def forward(self, inputs): 36 | return self.linear(inputs) + self.nonlinear(inputs) 37 | 38 | 39 | class ClusterLookup(nn.Module): 40 | """ 41 | STEGO's clustering module. 42 | Performs cosine distance K-means on the given features. 43 | """ 44 | 45 | def __init__(self, dim: int, n_classes: int): 46 | super(ClusterLookup, self).__init__() 47 | self.n_classes = n_classes 48 | self.dim = dim 49 | self.clusters = torch.nn.Parameter(torch.randn(n_classes, dim)) 50 | 51 | def reset_parameters(self): 52 | with torch.no_grad(): 53 | self.clusters.copy_(torch.randn(self.n_classes, self.dim)) 54 | 55 | def forward(self, x, alpha, log_probs=False): 56 | normed_clusters = F.normalize(self.clusters, dim=1) 57 | normed_features = F.normalize(x, dim=1) 58 | inner_products = torch.einsum("bchw,nc->bnhw", normed_features, normed_clusters) 59 | 60 | if alpha is None: 61 | cluster_probs = ( 62 | F.one_hot(torch.argmax(inner_products, dim=1), self.clusters.shape[0]) 63 | .permute(0, 3, 1, 2) 64 | .to(torch.float32) 65 | ) 66 | else: 67 | cluster_probs = nn.functional.softmax(inner_products * alpha, dim=1) 68 | 69 | cluster_loss = -(cluster_probs * inner_products).sum(1).mean() 70 | if log_probs: 71 | return nn.functional.log_softmax(inner_products * alpha, dim=1) 72 | else: 73 | return cluster_loss, cluster_probs 74 | 75 | 76 | class ContrastiveCorrelationLoss(nn.Module): 77 | """ 78 | STEGO's correlation loss. 79 | """ 80 | 81 | def __init__( 82 | self, 83 | cfg, 84 | ): 85 | super(ContrastiveCorrelationLoss, self).__init__() 86 | self.cfg = cfg 87 | 88 | def standard_scale(self, t): 89 | t1 = t - t.mean() 90 | t2 = t1 / t1.std() 91 | return t2 92 | 93 | def helper(self, f1, f2, c1, c2, shift): 94 | with torch.no_grad(): 95 | # Comes straight from backbone which is currently frozen. this saves mem. 
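            # The feature correlation volume: cosine similarities between all pairs of
            # sampled feature locations from the two views. It serves as the fixed
            # target signal, so no gradients are needed here.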
96 | fd = tensor_correlation(norm(f1), norm(f2)) 97 | 98 | if self.cfg.pointwise: 99 | old_mean = fd.mean() 100 | fd -= fd.mean([3, 4], keepdim=True) 101 | fd = fd - fd.mean() + old_mean 102 | 103 | cd = tensor_correlation(norm(c1), norm(c2)) 104 | 105 | if self.cfg.zero_clamp: 106 | min_val = 0.0 107 | else: 108 | min_val = -9999.0 109 | 110 | if self.cfg.stabilize: 111 | loss = -cd.clamp(min_val, 0.8) * (fd - shift) 112 | else: 113 | loss = -cd.clamp(min_val) * (fd - shift) 114 | 115 | return loss, cd 116 | 117 | def forward( 118 | self, 119 | orig_feats: torch.Tensor, 120 | orig_feats_pos: torch.Tensor, 121 | orig_code: torch.Tensor, 122 | orig_code_pos: torch.Tensor, 123 | ): 124 | coord_shape = [ 125 | orig_feats.shape[0], 126 | self.cfg.feature_samples, 127 | self.cfg.feature_samples, 128 | 2, 129 | ] 130 | coords1 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1 131 | coords2 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1 132 | 133 | feats = sample(orig_feats, coords1) 134 | code = sample(orig_code, coords1) 135 | feats_pos = sample(orig_feats_pos, coords2) 136 | code_pos = sample(orig_code_pos, coords2) 137 | 138 | pos_intra_loss, pos_intra_cd = self.helper(feats, feats, code, code, self.cfg.pos_intra_shift) 139 | pos_inter_loss, pos_inter_cd = self.helper(feats, feats_pos, code, code_pos, self.cfg.pos_inter_shift) 140 | 141 | neg_losses = [] 142 | neg_cds = [] 143 | for i in range(self.cfg.neg_samples): 144 | perm_neg = super_perm(orig_feats.shape[0], orig_feats.device) 145 | feats_neg = sample(orig_feats[perm_neg], coords2) 146 | code_neg = sample(orig_code[perm_neg], coords2) 147 | neg_inter_loss, neg_inter_cd = self.helper(feats, feats_neg, code, code_neg, self.cfg.neg_inter_shift) 148 | neg_losses.append(neg_inter_loss) 149 | neg_cds.append(neg_inter_cd) 150 | neg_inter_loss = torch.cat(neg_losses, axis=0) 151 | neg_inter_cd = torch.cat(neg_cds, axis=0) 152 | 153 | return ( 154 | pos_intra_loss.mean(), 155 | pos_intra_cd, 156 | pos_inter_loss.mean(), 157 | pos_inter_cd, 158 | neg_inter_loss, 159 | neg_inter_cd, 160 | ) 161 | 162 | 163 | class CRF: 164 | """ 165 | Class encapsulating STEGO's CRF postprocessing step. 166 | """ 167 | 168 | def __init__(self, cfg): 169 | self.cfg = cfg 170 | 171 | def dense_crf(self, image_tensor: torch.FloatTensor, output_logits: torch.FloatTensor) -> torch.FloatTensor: 172 | image = np.array(VF.to_pil_image(unnorm(image_tensor)))[:, :, ::-1] 173 | H, W = image.shape[:2] 174 | image = np.ascontiguousarray(image) 175 | 176 | output_logits = F.interpolate( 177 | output_logits.unsqueeze(0), 178 | size=(H, W), 179 | mode="bilinear", 180 | align_corners=False, 181 | ).squeeze() 182 | output_probs = F.softmax(output_logits, dim=0).cpu().numpy() 183 | 184 | c = output_probs.shape[0] 185 | h = output_probs.shape[1] 186 | w = output_probs.shape[2] 187 | 188 | U = utils.unary_from_softmax(output_probs) 189 | U = np.ascontiguousarray(U) 190 | 191 | d = dcrf.DenseCRF2D(w, h, c) 192 | d.setUnaryEnergy(U) 193 | d.addPairwiseGaussian(sxy=self.cfg.pos_xy_std, compat=self.cfg.pos_w) 194 | d.addPairwiseBilateral( 195 | sxy=self.cfg.bi_xy_std, 196 | srgb=self.cfg.bi_rgb_std, 197 | rgbim=image, 198 | compat=self.cfg.bi_w, 199 | ) 200 | 201 | Q = d.inference(self.cfg.crf_max_iter) 202 | Q = np.array(Q).reshape((c, h, w)) 203 | return torch.from_numpy(Q) 204 | 205 | 206 | class KMeans: 207 | """Implements the kmeans clustering algorithm in PyTorch. 
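    A typical use (illustrative sketch; `features` is a placeholder (N, D) tensor):
        kmeans = KMeans(num_clusters=8, cluster_centers=None, distance_metric="cosine", max_iterations=100)
        kmeans.fit(features)
        labels = kmeans.get_cluster_assignments()  # (N,) cluster ids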
208 | The code of this class was based on: https://github.com/kornia/kornia/pull/2304 209 | 210 | Args: 211 | num_clusters: number of clusters the data has to be assigned to 212 | cluster_centers: tensor of starting cluster centres can be passed instead of num_clusters 213 | tolerance: float value. the algorithm terminates if the shift in centers is less than tolerance 214 | max_iterations: number of iterations to run the algorithm for 215 | distance_metric: {"euclidean", "cosine"}, type of the distance metric to use 216 | seed: number to set torch manual seed for reproducibility 217 | """ 218 | 219 | def __init__( 220 | self, 221 | num_clusters: int, 222 | cluster_centers: Tensor, 223 | tolerance: float = 10e-4, 224 | max_iterations: int = 0, 225 | distance_metric="euclidean", 226 | seed=None, 227 | ) -> None: 228 | KORNIA_CHECK(num_clusters != 0, "num_clusters can't be 0") 229 | 230 | # cluster_centers should have only 2 dimensions 231 | if cluster_centers is not None: 232 | KORNIA_CHECK_SHAPE(cluster_centers, ["C", "D"]) 233 | 234 | self.num_clusters = num_clusters 235 | self.cluster_centers = cluster_centers 236 | self.tolerance = tolerance 237 | self.max_iterations = max_iterations 238 | 239 | if distance_metric == "euclidean": 240 | self._pairwise_distance = self._pairwise_euclidean_distance 241 | elif distance_metric == "cosine": 242 | self._pairwise_distance = self._pairwise_cosine_distance 243 | else: 244 | raise ValueError("Unknown distance metric") 245 | 246 | self.final_cluster_assignments = None 247 | self.final_cluster_centers = None 248 | 249 | if seed is not None: 250 | torch.manual_seed(seed) 251 | 252 | def get_cluster_centers(self) -> Tensor: 253 | KORNIA_CHECK( 254 | self.final_cluster_centers is not None, 255 | "Model has not been fit to a dataset", 256 | ) 257 | return self.final_cluster_centers 258 | 259 | def get_cluster_assignments(self) -> Tensor: 260 | KORNIA_CHECK( 261 | self.final_cluster_assignments is not None, 262 | "Model has not been fit to a dataset", 263 | ) 264 | return self.final_cluster_assignments 265 | 266 | def _initialise_cluster_centers(self, X: Tensor, num_clusters: int) -> Tensor: 267 | """Chooses num_cluster points from X as the initial cluster centers. 268 | 269 | Args: 270 | X: 2D input tensor to be clustered 271 | num_clusters: number of desired cluster centers 272 | 273 | Returns: 274 | 2D Tensor with num_cluster rows 275 | """ 276 | num_samples = X.shape[0] 277 | perm = torch.randperm(num_samples, device=X.device) 278 | idx = perm[:num_clusters] 279 | initial_state = X[idx] 280 | return initial_state 281 | 282 | def _pairwise_euclidean_distance(self, data1: Tensor, data2: Tensor) -> Tensor: 283 | """Computes pairwise distance between 2 sets of vectors. 284 | 285 | Args: 286 | data1: 2D tensor of shape N, D 287 | data2: 2D tensor of shape C, D 288 | 289 | Returns: 290 | 2D tensor of shape N, C 291 | """ 292 | # N*1*D 293 | A = data1[:, None, ...] 294 | # 1*C*D 295 | B = data2[None, ...] 296 | distance = (A - B) ** 2.0 297 | # return N*C matrix for pairwise distance 298 | distance = distance.sum(dim=-1) 299 | return distance 300 | 301 | def _pairwise_cosine_distance(self, data1: Tensor, data2: Tensor) -> Tensor: 302 | """Computes pairwise distance between 2 sets of vectors. 
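        The distance is 1 minus cosine similarity, d(a, b) = 1 - (a . b) / (||a|| ||b||),
        so it lies in [0, 2], with 0 for vectors pointing in the same direction.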
303 | 304 | Args: 305 | data1: 2D tensor of shape N, D 306 | data2: 2D tensor of shape C, D 307 | 308 | Returns: 309 | 2D tensor of shape N, C 310 | """ 311 | normed_A = F.normalize(data1, dim=1) 312 | normed_B = F.normalize(data2, dim=1) 313 | distance = 1.0 - torch.einsum("nd,cd->nc", normed_A, normed_B) 314 | return distance 315 | 316 | def fit(self, X: Tensor) -> None: 317 | """Iterative KMeans clustering till a threshold for shift in cluster centers or a maximum no of iterations 318 | have reached. 319 | 320 | Args: 321 | X: 2D input tensor to be clustered 322 | """ 323 | KORNIA_CHECK_SHAPE(X, ["N", "D"]) 324 | 325 | if self.cluster_centers is None: 326 | self.cluster_centers = self._initialise_cluster_centers(X, self.num_clusters) 327 | else: 328 | # X and cluster_centers should have same number of columns 329 | KORNIA_CHECK( 330 | X.shape[1] == self.cluster_centers.shape[1], 331 | f"Dimensions at position 1 of X and cluster_centers do not match. \ 332 | {X.shape[1]} != {self.cluster_centers.shape[1]}", 333 | ) 334 | 335 | current_centers = self.cluster_centers 336 | 337 | previous_centers = None 338 | iteration: int = 0 339 | 340 | while True: 341 | # find distance between X and current_centers 342 | distance: Tensor = self._pairwise_distance(X, current_centers) 343 | 344 | cluster_assignment = torch.argmin(distance, dim=1) 345 | 346 | previous_centers = current_centers.clone() 347 | 348 | one_hot_assignments = torch.nn.functional.one_hot(cluster_assignment, self.num_clusters).float() 349 | sum_points = torch.mm(one_hot_assignments.T, X) 350 | num_points = one_hot_assignments.sum(0).unsqueeze(1) 351 | 352 | # Handle empty clusters by replacing them with a random point 353 | empty_clusters = num_points.squeeze() == 0 354 | random_points = X[torch.randint(len(X), (torch.sum(empty_clusters),))] 355 | sum_points[empty_clusters, :] = random_points 356 | num_points[empty_clusters] = 1 357 | 358 | current_centers = sum_points / num_points 359 | 360 | # sum of distance of how much the newly computed clusters have moved from their previous positions 361 | center_shift = torch.sum(torch.sqrt(torch.sum((current_centers - previous_centers) ** 2, dim=1))) 362 | 363 | iteration = iteration + 1 364 | 365 | if self.tolerance is not None and center_shift**2 < self.tolerance: 366 | break 367 | 368 | if self.max_iterations != 0 and iteration >= self.max_iterations: 369 | break 370 | 371 | self.final_cluster_assignments = cluster_assignment 372 | self.final_cluster_centers = current_centers 373 | 374 | def predict(self, x: Tensor) -> Tensor: 375 | """Find the cluster center closest to each point in x. 376 | 377 | Args: 378 | x: 2D tensor 379 | 380 | Returns: 381 | 1D tensor containing cluster id assigned to each data point in x 382 | """ 383 | 384 | # x and cluster_centers should have same number of columns 385 | KORNIA_CHECK( 386 | x.shape[1] == self.final_cluster_centers.shape[1], 387 | f"Dimensions at position 1 of x and cluster_centers do not match. \ 388 | {x.shape[1]} != {self.final_cluster_centers.shape[1]}", 389 | ) 390 | 391 | distance = self._pairwise_distance(x, self.final_cluster_centers) 392 | cluster_assignment = torch.argmin(distance, axis=1) 393 | return cluster_assignment, distance 394 | -------------------------------------------------------------------------------- /stego/stego.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 
3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala.
4 | # All rights reserved. Licensed under the MIT license.
5 | # See LICENSE file in the project root for details.
6 | #
7 | #
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import nn
11 | import pytorch_lightning as pl
12 | import omegaconf
13 | import os
14 | import wandb
15 | import matplotlib.pyplot as plt
16 | import io
17 | 
18 | 
19 | from stego.backbones.backbone import get_backbone
20 | from stego.utils import UnsupervisedMetrics
21 | from stego.data import Image
22 | from stego.modules import SegmentationHead, ClusterLookup, ContrastiveCorrelationLoss, CRF, KMeans
23 | 
24 | 
25 | class Stego(pl.LightningModule):
26 |     """
27 |     The main STEGO class.
28 |     """
29 | 
30 |     def __init__(self, n_classes, n_image_clusters=0, cfg=None):
31 |         super().__init__()
32 |         if cfg is None:
33 |             with open(os.path.join(os.path.dirname(__file__), "cfg/model_config.yaml"), "r") as file:
34 |                 self.cfg = omegaconf.OmegaConf.load(file)
35 |             cfg = self.cfg
36 |         else:
37 |             self.cfg = cfg
38 |         self.dim = self.cfg.dim
39 |         self.automatic_optimization = False
40 |         self.n_classes = n_classes
41 |         self.backbone = get_backbone(self.cfg)
42 |         self.full_backbone_name = self.backbone.get_backbone_name()
43 |         self.backbone.eval()
44 |         self.backbone_dim = self.backbone.get_output_feat_dim()
45 |         self.segmentation_head = SegmentationHead(self.backbone_dim, self.dim)
46 | 
47 |         self.cluster_probe = ClusterLookup(self.dim, self.n_classes + self.cfg.extra_clusters)
48 |         self.linear_probe = nn.Conv2d(self.dim, n_classes, (1, 1))
49 | 
50 |         self.cluster_metrics = UnsupervisedMetrics("test/cluster/", n_classes, self.cfg.extra_clusters, True)
51 |         self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)
52 | 
53 |         self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
54 |         self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(self.cfg)
55 |         for p in self.contrastive_corr_loss_fn.parameters():
56 |             p.requires_grad = False
57 | 
58 |         self.crf = CRF(self.cfg)
59 | 
60 |         self.n_image_clusters = n_image_clusters
61 |         if n_image_clusters == 0:
62 |             self.n_image_clusters = n_classes
63 |         self.kmeans = KMeans(
64 |             num_clusters=self.n_image_clusters,
65 |             cluster_centers=None,
66 |             distance_metric="cosine",
67 |             max_iterations=100,
68 |         )
69 | 
70 |         self.cd_hist = torch.zeros(40)
71 | 
72 |         self.save_hyperparameters()
73 | 
74 |     def reset_clusters(self, n_classes, extra_clusters):
75 |         """
76 |         Resets STEGO's cluster and linear probes, possibly with a different number of classes and extra clusters for the cluster probe.
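        Example (illustrative sketch; the class counts are arbitrary and assume the
        default model_config and backbone weights are available):
            model = Stego(n_classes=21)
            model.reset_clusters(n_classes=27, extra_clusters=0)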
77 | """ 78 | self.cluster_probe = ClusterLookup(self.dim, n_classes + extra_clusters) 79 | self.cluster_metrics = UnsupervisedMetrics("test/cluster/", n_classes, extra_clusters, True) 80 | self.linear_probe = nn.Conv2d(self.dim, n_classes, (1, 1)) 81 | self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False) 82 | self.n_classes = n_classes 83 | 84 | def configure_optimizers(self): 85 | main_params = list(self.backbone.parameters()) + list(self.segmentation_head.parameters()) 86 | net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr) 87 | linear_probe_optim = torch.optim.Adam(list(self.linear_probe.parameters()), lr=self.cfg.linear_lr) 88 | cluster_probe_optim = torch.optim.Adam(list(self.cluster_probe.parameters()), lr=self.cfg.cluster_lr) 89 | return net_optim, linear_probe_optim, cluster_probe_optim 90 | 91 | def forward(self, img): 92 | backbone_feats = self.backbone(img) 93 | return backbone_feats, self.segmentation_head(backbone_feats) 94 | 95 | def get_code(self, img): 96 | """ 97 | Returns segmentation features for a given image. 98 | Returned features are an average of two passes through STEGO, with the input image and its horizontal flip. 99 | """ 100 | # code1 = self.forward(img)[1] 101 | # code2 = self.forward(img.flip(dims=[3]))[1] 102 | # code = (code1 + code2.flip(dims=[3])) / 2 103 | # return code 104 | code = self.forward(img)[1] 105 | return code 106 | 107 | def postprocess_crf(self, img, probs): 108 | """ 109 | Performs the CRF postprocessing step on the given image and a set of predicted class probabilities. 110 | The class probabilities are interpolated to fit the image size inside the dense_crf function. 111 | """ 112 | pred = torch.empty(torch.Size(img.size()[:-3] + img.size()[-2:])) 113 | for j in range(img.shape[0]): 114 | single_img = img[j] 115 | x = self.crf.dense_crf(single_img, probs[j]).argmax(0) 116 | pred[j] = x 117 | return pred.int() 118 | 119 | def postprocess_cluster(self, code, img, use_crf=True, image_clustering=False): 120 | """ 121 | Cluster probe postprocessing of STEGO. 122 | For the given features, the cluster probe is run, followed by CRF (if enabled). 123 | If enabled, performs the K-means clustering only of the segmentation features in the given batch. 124 | 125 | Arguments: 126 | - code - STEGO's segmentation features. 127 | - img - input image. 128 | - use_crf - enables CRF on the image and class probabilities from the cluster probe. 129 | - image_clustering - enables per-image clustering. If True, STEGO's cluster probe is ignored and K-means is run on the given segmentation features to produce the cluster probabilities, 130 | """ 131 | orig_code = code.permute(0, 2, 3, 1) 132 | code = F.interpolate(code, img.shape[-2:], mode="bilinear", align_corners=False) 133 | if image_clustering: 134 | self.kmeans.fit(orig_code.reshape((-1, orig_code.shape[-1]))) 135 | 136 | normed_centers = F.normalize(self.kmeans.final_cluster_centers, dim=1) 137 | normed_code = F.normalize(code, dim=1).permute(0, 2, 3, 1) 138 | inner_products = torch.einsum("bhwc,nc->bnhw", normed_code, normed_centers) 139 | cluster_probs = nn.functional.softmax(inner_products * 2, dim=1) 140 | else: 141 | cluster_probs = self.cluster_probe(code, 2, log_probs=True) 142 | if use_crf: 143 | cluster_preds = self.postprocess_crf(img, cluster_probs) 144 | else: 145 | cluster_preds = cluster_probs.argmax(dim=1) 146 | return cluster_preds 147 | 148 | def postprocess_linear(self, code, img, use_crf=True): 149 | """ 150 | Linear probe postprocessing of STEGO. 
151 | For the given features, the linear probe is run, followed by CRF (if enabled). 152 | 153 | Arguments: 154 | - code - STEGO's segmentation features. 155 | - img - input image. 156 | - use_crf - enables CRF on the image and class probabilities from the linear probe. 157 | """ 158 | code = F.interpolate(code, img.shape[-2:], mode="bilinear", align_corners=False) 159 | linear_probs = torch.log_softmax(self.linear_probe(code), dim=1) 160 | if use_crf: 161 | linear_preds = self.postprocess_crf(img, linear_probs) 162 | else: 163 | linear_preds = linear_probs.argmax(1) 164 | return linear_preds 165 | 166 | def postprocess( 167 | self, 168 | code, 169 | img, 170 | use_crf_cluster=True, 171 | use_crf_linear=True, 172 | image_clustering=False, 173 | ): 174 | """ 175 | Complete postprocessing of STEGO. 176 | For the given features, both the cluster and linear probes are run, followed by CRF (if enabled). 177 | If enabled, performs the K-means clustering only of the segmentation features in the given batch. 178 | 179 | Arguments: 180 | - code - STEGO's segmentation features. 181 | - img - input image. 182 | - use_crf_cluster - enables CRF on the image and class probabilities from the cluster probe. 183 | - use_crf_linear - enables CRF on the image and class probabilities from the linear probe. 184 | - image_clustering - enables per-image clustering. If True, STEGO's cluster probe is ignored and K-means is run on the given segmentation features to produce the cluster probabilities, 185 | """ 186 | cluster_preds = self.postprocess_cluster(code, img, use_crf_cluster, image_clustering) 187 | linear_preds = self.postprocess_linear(code, img, use_crf_linear) 188 | return cluster_preds, linear_preds 189 | 190 | def training_step(self, batch, batch_idx): 191 | net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers() 192 | net_optim.zero_grad() 193 | linear_probe_optim.zero_grad() 194 | cluster_probe_optim.zero_grad() 195 | log_args = dict(sync_dist=False, rank_zero_only=True) 196 | 197 | with torch.no_grad(): 198 | img = batch["img"] 199 | img_pos = batch["img_pos"] 200 | label = batch["label"] 201 | 202 | feats, code = self.forward(img) 203 | feats_pos, code_pos = self.forward(img_pos) 204 | 205 | ( 206 | pos_intra_loss, 207 | pos_intra_cd, 208 | pos_inter_loss, 209 | pos_inter_cd, 210 | neg_inter_loss, 211 | neg_inter_cd, 212 | ) = self.contrastive_corr_loss_fn( 213 | feats, 214 | feats_pos, 215 | code, 216 | code_pos, 217 | ) 218 | neg_inter_loss = neg_inter_loss.mean() 219 | pos_intra_loss = pos_intra_loss.mean() 220 | pos_inter_loss = pos_inter_loss.mean() 221 | 222 | self.cd_hist = torch.add(self.cd_hist, torch.histc(pos_intra_cd.cpu(), bins=40, min=-1, max=1)) 223 | self.cd_hist = torch.add(self.cd_hist, torch.histc(pos_inter_cd.cpu(), bins=40, min=-1, max=1)) 224 | self.cd_hist = torch.add(self.cd_hist, torch.histc(neg_inter_cd.cpu(), bins=40, min=-1, max=1)) 225 | 226 | self.log("loss/pos_intra", pos_intra_loss) 227 | self.log("loss/pos_inter", pos_inter_loss) 228 | self.log("loss/neg_inter", neg_inter_loss) 229 | self.log("cd/pos_intra", pos_intra_cd.mean()) 230 | self.log("cd/pos_inter", pos_inter_cd.mean()) 231 | self.log("cd/neg_inter", neg_inter_cd.mean()) 232 | 233 | loss = ( 234 | self.cfg.pos_inter_weight * pos_inter_loss 235 | + self.cfg.pos_intra_weight * pos_intra_loss 236 | + self.cfg.neg_inter_weight * neg_inter_loss 237 | ) 238 | 239 | flat_label = label.reshape(-1) 240 | mask = (flat_label >= 0) & (flat_label < self.n_classes) 241 | 242 | detached_code = 
torch.clone(code.detach()) 243 | 244 | linear_logits = self.linear_probe(detached_code) 245 | linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode="bilinear", align_corners=False) 246 | linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes) 247 | linear_loss = self.linear_probe_loss_fn(linear_logits[mask], flat_label[mask]).mean() 248 | loss += linear_loss 249 | self.log("loss/linear", linear_loss, **log_args) 250 | 251 | cluster_loss, cluster_probs = self.cluster_probe(detached_code, None) 252 | loss += cluster_loss 253 | 254 | self.log("loss/cluster", cluster_loss, **log_args) 255 | self.log("loss/total", loss, **log_args) 256 | 257 | self.manual_backward(loss) 258 | net_optim.step() 259 | cluster_probe_optim.step() 260 | linear_probe_optim.step() 261 | 262 | return loss 263 | 264 | def validation_step(self, batch, batch_idx): 265 | img = batch["img"] 266 | label = batch["label"] 267 | 268 | with torch.no_grad(): 269 | code = self.forward(img)[1] 270 | code = F.interpolate(code, label.shape[-2:], mode="bilinear", align_corners=False) 271 | 272 | linear_preds = self.linear_probe(code) 273 | linear_preds = linear_preds.argmax(1) 274 | self.linear_metrics.update(linear_preds, label) 275 | 276 | cluster_loss, cluster_preds = self.cluster_probe(code, None) 277 | cluster_preds = cluster_preds.argmax(1) 278 | self.cluster_metrics.update(cluster_preds, label) 279 | 280 | linear_metrics = self.linear_metrics.compute() 281 | cluster_metrics = self.cluster_metrics.compute() 282 | 283 | self.log("val/linear/mIoU", linear_metrics["test/linear/mIoU"]) 284 | self.log("val/linear/Accuracy", linear_metrics["test/linear/Accuracy"]) 285 | self.log("val/cluster/mIoU", cluster_metrics["test/cluster/mIoU"]) 286 | self.log("val/cluster/Accuracy", cluster_metrics["test/cluster/Accuracy"]) 287 | 288 | return { 289 | "img": img[: self.cfg.val_n_imgs].detach().cpu(), 290 | "linear_preds": linear_preds[: self.cfg.val_n_imgs].detach().cpu(), 291 | "cluster_preds": cluster_preds[: self.cfg.val_n_imgs].detach().cpu(), 292 | "label": label[: self.cfg.val_n_imgs].detach().cpu(), 293 | } 294 | 295 | def on_validation_epoch_end(self, outputs) -> None: 296 | super().on_validation_epoch_end(outputs) 297 | with torch.no_grad(): 298 | self.linear_metrics.reset() 299 | self.cluster_metrics.reset() 300 | 301 | for i in range(self.cfg.val_n_imgs): 302 | img = outputs[0]["img"][i].cpu().numpy().transpose((1, 2, 0)) 303 | label = torch.squeeze(outputs[0]["label"][i]).cpu().numpy() 304 | cluster = torch.squeeze(outputs[0]["cluster_preds"][i]).cpu().numpy() 305 | linear = torch.squeeze(outputs[0]["linear_preds"][i]).cpu().numpy() 306 | vis = wandb.Image( 307 | img, 308 | masks={ 309 | "label": {"mask_data": label}, 310 | "cluster": {"mask_data": cluster}, 311 | "linear": {"mask_data": linear}, 312 | }, 313 | caption="Image" + str(i), 314 | ) 315 | self.logger.experiment.log({"Image" + str(i): vis}) 316 | 317 | self.cd_hist = self.cd_hist / torch.sum(self.cd_hist) 318 | x = [-1 + i * (2 / 40) + 1 / 40 for i in range(40)] 319 | plt.figure() 320 | ax = plt.axes() 321 | ax.plot(x, self.cd_hist) 322 | ax.set_xlim([-1, 1]) 323 | ax.set_ylim([0, 0.4]) 324 | 325 | img_buf = io.BytesIO() 326 | plt.savefig(img_buf, format="png") 327 | hist_img = Image.open(img_buf) 328 | hist_vis = wandb.Image(hist_img, caption="Learned Feature Similarity Distribution") 329 | self.logger.experiment.log({"Histogram": hist_vis}) 330 | img_buf.close() 331 | self.cd_hist = torch.zeros(40) 332 | 
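The class above exposes the full inference path: a backbone forward pass, the segmentation head, and the cluster/linear probe postprocessing with optional CRF refinement. Below is a minimal usage sketch, not part of the repository: it assumes the LightningModule defined in `stego/stego.py` is importable as `Stego`, that a trained checkpoint can be restored with Lightning's `load_from_checkpoint`, and that the checkpoint and image paths are placeholders.
```
# Hypothetical inference sketch -- the class name, checkpoint path, and image path are assumptions.
import torch

from stego.stego import Stego                    # assumed module/class name
from stego.utils import load_image_to_tensor

model = Stego.load_from_checkpoint("checkpoint.ckpt")  # hypothetical checkpoint path
model.eval()

img = load_image_to_tensor("example.jpg", resolution=320)  # hypothetical image path
with torch.no_grad():
    code = model.get_code(img)                   # segmentation features
    cluster_pred, linear_pred = model.postprocess(
        code,
        img,
        use_crf_cluster=True,
        use_crf_linear=True,
        image_clustering=False,                  # set True for per-image K-means clustering
    )
print(cluster_pred.shape, linear_pred.shape)      # per-pixel cluster and linear-probe labels
```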
-------------------------------------------------------------------------------- /stego/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Mark Hamilton. All rights reserved. 3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala. 4 | # All rights reserved. Licensed under the MIT license. 5 | # See LICENSE file in the project root for details. 6 | # 7 | # 8 | import collections 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data._utils.collate import ( 12 |     np_str_obj_array_pattern, 13 |     default_collate_err_msg_format, 14 | ) 15 | from torchvision import transforms as T 16 | from torchmetrics import Metric 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | from scipy.optimize import linear_sum_assignment 20 | from PIL import Image 21 | import sys 22 | import os 23 | 24 | 25 | def load_image_to_tensor( 26 |     img_path, 27 |     resolution=320, 28 |     brightness_factor=1.0, 29 |     contrast_factor=1.0, 30 |     saturation_factor=1.0, 31 |     hue_factor=0.0, 32 |     gaussian_sigma=None, 33 |     gaussian_kernel_size=None, 34 | ): 35 |     img = Image.open(img_path) 36 |     transforms = [] 37 |     if brightness_factor != 1.0 or contrast_factor != 1.0 or saturation_factor != 1.0 or hue_factor != 0.0: 38 |         transforms.append( 39 |             T.ColorJitter( 40 |                 brightness=(brightness_factor, brightness_factor), 41 |                 contrast=(contrast_factor, contrast_factor), 42 |                 saturation=(saturation_factor, saturation_factor), 43 |                 hue=(hue_factor, hue_factor), 44 |             ) 45 |         ) 46 |     if gaussian_sigma is not None and gaussian_kernel_size is not None: 47 |         transforms.append(T.GaussianBlur(kernel_size=gaussian_kernel_size, sigma=gaussian_sigma)) 48 |     elif gaussian_sigma is not None or gaussian_kernel_size is not None:  # exactly one of the two was specified 49 |         raise ValueError( 50 |             "Both sigma and kernel size for gaussian blur augmentation need to be None or specified, but exactly one was specified."
51 | ) 52 | transforms.append(get_transform(resolution, False, "center")) 53 | preprocess_transform = T.Compose(transforms) 54 | image_tensor = torch.unsqueeze(preprocess_transform(img), 0) 55 | return image_tensor 56 | 57 | 58 | def get_nn_file_name(data_dir, dataset_name, model_type, image_set, resolution): 59 | return os.path.join( 60 | data_dir, 61 | dataset_name, 62 | "nns", 63 | "nns_{}_{}_{}.npz".format(model_type, image_set, resolution), 64 | ) 65 | 66 | 67 | class UnNormalize(object): 68 | def __init__(self, mean, std): 69 | self.mean = mean 70 | self.std = std 71 | 72 | def __call__(self, image): 73 | image2 = torch.clone(image) 74 | for t, m, s in zip(image2, self.mean, self.std): 75 | t.mul_(s).add_(m) 76 | return image2 77 | 78 | 79 | class ToTargetTensor(object): 80 | def __call__(self, target): 81 | return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0) 82 | 83 | 84 | def get_transform(res, is_label, crop_type, is_tensor=False, do_normalize=True): 85 | if crop_type == "center": 86 | cropper = T.CenterCrop(res) 87 | elif crop_type == "random": 88 | cropper = T.RandomCrop(res) 89 | elif crop_type is None: 90 | cropper = T.Lambda(lambda x: x) 91 | res = (res, res) 92 | else: 93 | raise ValueError("Unknown Cropper {}".format(crop_type)) 94 | transform = [T.Resize(res, T.InterpolationMode.NEAREST), cropper] 95 | 96 | if is_label: 97 | transform.append(ToTargetTensor()) 98 | else: 99 | if not is_tensor: 100 | transform.append(T.ToTensor()) 101 | 102 | if do_normalize: 103 | transform.append(normalize) 104 | 105 | return T.Compose(transform) 106 | 107 | 108 | normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 109 | unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 110 | 111 | 112 | def tensor_correlation(a, b): 113 | return torch.einsum("nchw,ncij->nhwij", a, b) 114 | 115 | 116 | def norm(t): 117 | return F.normalize(t, dim=1, eps=1e-10) 118 | 119 | 120 | def sample(t: torch.Tensor, coords: torch.Tensor): 121 | return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode="border", align_corners=True) 122 | 123 | 124 | def sample_nonzero_locations(t, target_size): 125 | nonzeros = torch.nonzero(t) 126 | coords = torch.zeros(target_size, dtype=nonzeros.dtype, device=nonzeros.device) 127 | n = target_size[1] * target_size[2] 128 | for i in range(t.shape[0]): 129 | selected_nonzeros = nonzeros[nonzeros[:, 0] == i] 130 | if selected_nonzeros.shape[0] == 0: 131 | selected_coords = torch.randint(t.shape[1], size=(n, 2), device=nonzeros.device) 132 | else: 133 | selected_coords = selected_nonzeros[torch.randint(len(selected_nonzeros), size=(n,)), 1:] 134 | coords[i, :, :, :] = selected_coords.reshape(target_size[1], target_size[2], 2) 135 | coords = coords.to(torch.float32) / t.shape[1] 136 | coords = coords * 2 - 1 137 | return torch.flip(coords, dims=[-1]) 138 | 139 | 140 | @torch.jit.script 141 | def super_perm(size: int, device: torch.device): 142 | perm = torch.randperm(size, device=device, dtype=torch.long) 143 | perm[perm == torch.arange(size, device=device)] += 1 144 | return perm % size 145 | 146 | 147 | def prep_args(): 148 | old_args = sys.argv 149 | new_args = [old_args.pop(0)] 150 | while len(old_args) > 0: 151 | arg = old_args.pop(0) 152 | if len(arg.split("=")) == 2: 153 | new_args.append(arg) 154 | elif arg.startswith("--"): 155 | new_args.append(arg[2:] + "=" + old_args.pop(0)) 156 | else: 157 | raise ValueError("Unexpected arg style {}".format(arg)) 158 | sys.argv = new_args 159 | 160 | 161 | def _remove_axes(ax): 
162 | ax.xaxis.set_major_formatter(plt.NullFormatter()) 163 | ax.yaxis.set_major_formatter(plt.NullFormatter()) 164 | ax.set_xticks([]) 165 | ax.set_yticks([]) 166 | 167 | 168 | def remove_axes(axes): 169 | if len(axes.shape) == 2: 170 | for ax1 in axes: 171 | for ax in ax1: 172 | _remove_axes(ax) 173 | else: 174 | for ax in axes: 175 | _remove_axes(ax) 176 | 177 | 178 | def prep_for_plot(img, rescale=True, resize=None): 179 | if resize is not None: 180 | img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear") 181 | else: 182 | img = img.unsqueeze(0) 183 | 184 | plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0) 185 | if rescale: 186 | plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min()) 187 | return plot_img 188 | 189 | 190 | def plot_distributions(value_lists, n_bins, names, x_name, output_file): 191 | plt.clf() 192 | plt.xlabel(x_name) 193 | plt.ylabel("Frequency") 194 | plt.title("Distribution of {}".format(x_name)) 195 | for i, values in enumerate(value_lists): 196 | if len(values) == 0: 197 | continue 198 | values_np = np.array(values) 199 | hist, bin_edges = np.histogram( 200 | values_np, 201 | bins=np.linspace(np.min(values_np), np.max(values_np), num=n_bins + 1), 202 | density=True, 203 | ) 204 | x = (bin_edges[:-1] + bin_edges[1:]) / 2 205 | plt.plot(x, hist, label=names[i]) 206 | plt.legend() 207 | plt.savefig(output_file) 208 | 209 | 210 | class UnsupervisedMetrics(Metric): 211 | def __init__( 212 | self, 213 | prefix: str, 214 | n_classes: int, 215 | extra_clusters: int, 216 | compute_hungarian: bool, 217 | dist_sync_on_step=True, 218 | ): 219 | # call `self.add_state`for every internal state that is needed for the metrics computations 220 | # dist_reduce_fx indicates the function that should be used to reduce 221 | # state from multiple processes 222 | super().__init__(dist_sync_on_step=dist_sync_on_step) 223 | 224 | self.n_classes = n_classes 225 | self.extra_clusters = extra_clusters 226 | self.compute_hungarian = compute_hungarian 227 | self.prefix = prefix 228 | self.add_state( 229 | "stats", 230 | default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64), 231 | dist_reduce_fx="sum", 232 | ) 233 | 234 | def update(self, preds: torch.Tensor, target: torch.Tensor): 235 | with torch.no_grad(): 236 | actual = target.reshape(-1) 237 | preds = preds.reshape(-1) 238 | mask = ( 239 | (actual >= 0) 240 | & (actual < self.n_classes) 241 | & (preds >= 0) 242 | & (preds < self.n_classes + self.extra_clusters) 243 | ) 244 | actual = actual[mask] 245 | preds = preds[mask] 246 | self.stats += ( 247 | torch.bincount( 248 | (self.n_classes + self.extra_clusters) * actual + preds, 249 | minlength=self.n_classes * (self.n_classes + self.extra_clusters), 250 | ) 251 | .reshape(self.n_classes, self.n_classes + self.extra_clusters) 252 | .t() 253 | .to(self.stats.device) 254 | ) 255 | 256 | def map_clusters(self, clusters): 257 | if self.extra_clusters == 0: 258 | return torch.tensor(self.assignments[1])[clusters] 259 | else: 260 | missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))) 261 | cluster_to_class = self.assignments[1] 262 | for missing_entry in missing: 263 | if missing_entry == cluster_to_class.shape[0]: 264 | cluster_to_class = np.append(cluster_to_class, -1) 265 | else: 266 | cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1) 267 | cluster_to_class = torch.tensor(cluster_to_class) 268 | return cluster_to_class[clusters] 269 | 270 | def compute(self): 271 | 
if self.compute_hungarian: 272 | self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True) 273 | if self.extra_clusters == 0: 274 | self.histogram = self.stats[np.argsort(self.assignments[1]), :] 275 | if self.extra_clusters > 0: 276 | self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True) 277 | histogram = self.stats[self.assignments_t[1], :] 278 | missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])) 279 | new_row = self.stats[missing, :].sum(0, keepdim=True) 280 | histogram = torch.cat([histogram, new_row], axis=0) 281 | new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device) 282 | self.histogram = torch.cat([histogram, new_col], axis=1) 283 | else: 284 | self.assignments = ( 285 | torch.arange(self.n_classes).unsqueeze(1), 286 | torch.arange(self.n_classes).unsqueeze(1), 287 | ) 288 | self.histogram = self.stats 289 | 290 | tp = torch.diag(self.histogram) 291 | fp = torch.sum(self.histogram, dim=0) - tp 292 | fn = torch.sum(self.histogram, dim=1) - tp 293 | 294 | iou = tp / (tp + fp + fn) 295 | # prc = tp / (tp + fn) 296 | opc = torch.sum(tp) / torch.sum(self.histogram) 297 | 298 | metric_dict = { 299 | self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(), 300 | self.prefix + "Accuracy": opc.item(), 301 | } 302 | return {k: 100 * v for k, v in metric_dict.items()} 303 | 304 | 305 | class WVNMetrics(Metric): 306 | def __init__( 307 | self, 308 | prefix: str, 309 | n_clusters: int, 310 | dist_sync_on_step=True, 311 | save_plots=False, 312 | output_dir=None, 313 | ): 314 | super().__init__(dist_sync_on_step=dist_sync_on_step) 315 | 316 | self.prefix = prefix 317 | self.n_clusters = n_clusters 318 | self.output_dir = None 319 | if save_plots: 320 | self.output_dir = output_dir 321 | # TN, FP, FN, TP for the traversable class (1) 322 | self.add_state("stats", default=torch.zeros(4, dtype=torch.int64), dist_reduce_fx="sum") 323 | self.add_state("feature_var", default=[], dist_reduce_fx="cat") 324 | self.add_state("code_var", default=[], dist_reduce_fx="cat") 325 | self.add_state("avg_n_clusters", default=[], dist_reduce_fx="cat") 326 | self.add_state("time", default=[], dist_reduce_fx="cat") 327 | 328 | def update( 329 | self, 330 | clusters: torch.Tensor, 331 | target: torch.Tensor, 332 | features: torch.Tensor, 333 | code: torch.Tensor, 334 | time: float, 335 | ): 336 | with torch.no_grad(): 337 | actual = target.reshape(-1) 338 | pred_clusters = clusters.reshape(-1) 339 | preds = self.assign_pred_to_clusters(pred_clusters, actual) 340 | self.stats += torch.bincount(2 * actual + preds, minlength=4).to(self.stats.device) 341 | cluster_count = torch.unique(pred_clusters).size(0) 342 | self.avg_n_clusters.append(cluster_count) 343 | self.time.append(time) 344 | if features is not None: 345 | self.feature_var.extend(self.update_variance(clusters, features)) 346 | if code is not None: 347 | self.code_var.extend(self.update_variance(clusters, code)) 348 | 349 | def update_variance(self, clusters: torch.Tensor, features: torch.Tensor): 350 | upsampled_features = F.interpolate(features, clusters.shape[-2:], mode="bilinear", align_corners=False).permute( 351 | 0, 2, 3, 1 352 | ) 353 | mean_feature_vars = [] 354 | for i in range(self.n_clusters): 355 | mask = clusters == i 356 | if mask.shape[0] != 1: 357 | mask = mask.unsqueeze(0) 358 | cluster_features = upsampled_features[mask].reshape(-1, upsampled_features.shape[-1]) 359 | if cluster_features.shape[0] > 1: 360 | 
mean_feature_vars.append(torch.mean(torch.var(cluster_features, dim=0)).item()) 361 | return mean_feature_vars 362 | 363 | def assign_pred_to_clusters(self, clusters: torch.Tensor, target: torch.Tensor): 364 | counts = torch.zeros(2, self.n_clusters, dtype=torch.int64) 365 | for i in range(2): 366 | mask = target == i 367 | counts[i] = torch.bincount(clusters[mask], minlength=self.n_clusters) 368 | cluster_pred = torch.where(counts[0] > counts[1], 0, 1) 369 | pred = cluster_pred[clusters.long()] 370 | return pred 371 | 372 | def compute_list_metric(self, metric_name, values, metric_dict, print_metrics=False): 373 | if len(values) == 0: 374 | return 375 | value = np.mean(np.array(values)) 376 | metric_dict[self.prefix + "/" + metric_name] = value 377 | if print_metrics: 378 | print("\t{}: {}".format(metric_name, value)) 379 | if self.output_dir is not None: 380 | plot_distributions( 381 | [values], 382 | 100, 383 | [self.prefix], 384 | metric_name, 385 | os.path.join(self.output_dir, self.prefix + "_" + metric_name + ".png"), 386 | ) 387 | 388 | def compute(self, print_metrics=False): 389 | tn = self.stats[0] 390 | fp = self.stats[1] 391 | fn = self.stats[2] 392 | tp = self.stats[3] 393 | 394 | iou = tp / (tp + fp + fn) 395 | acc = (tp + tn) / (tp + tn + fp + fn) 396 | 397 | metric_dict = { 398 | self.prefix + "/IoU": iou.item(), 399 | self.prefix + "/Accuracy": acc.item(), 400 | } 401 | 402 | if print_metrics: 403 | print(self.prefix) 404 | print("\tIoU: {}".format(iou.item())) 405 | print("\tAccuracy: {}".format(acc.item())) 406 | 407 | self.compute_list_metric("Avg_clusters", self.avg_n_clusters, metric_dict, print_metrics) 408 | self.compute_list_metric("Feature_var", self.feature_var, metric_dict, print_metrics) 409 | self.compute_list_metric("Code_var", self.code_var, metric_dict, print_metrics) 410 | self.compute_list_metric("Time", self.time, metric_dict, print_metrics) 411 | 412 | values_dict = { 413 | "Avg_clusters": self.avg_n_clusters, 414 | "Feature_var": self.feature_var, 415 | "Code_var": self.code_var, 416 | "Time": self.time, 417 | } 418 | 419 | return metric_dict, values_dict 420 | 421 | 422 | def flexible_collate(batch): 423 | r"""Puts each data field into a tensor with outer dimension batch size""" 424 | 425 | elem = batch[0] 426 | elem_type = type(elem) 427 | if isinstance(elem, torch.Tensor): 428 | out = None 429 | if torch.utils.data.get_worker_info() is not None: 430 | # If we're in a background process, concatenate directly into a 431 | # shared memory tensor to avoid an extra copy 432 | numel = sum([x.numel() for x in batch]) 433 | storage = elem.storage()._new_shared(numel) 434 | out = elem.new(storage) 435 | try: 436 | return torch.stack(batch, 0, out=out) 437 | except RuntimeError: 438 | return batch 439 | elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": 440 | if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": 441 | # array of string classes and object 442 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 443 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 444 | 445 | return flexible_collate([torch.as_tensor(b) for b in batch]) 446 | elif elem.shape == (): # scalars 447 | return torch.as_tensor(batch) 448 | elif isinstance(elem, float): 449 | return torch.tensor(batch, dtype=torch.float64) 450 | elif isinstance(elem, int): 451 | return torch.tensor(batch) 452 | elif isinstance(elem, str): 453 | return batch 454 | elif isinstance(elem, 
collections.abc.Mapping): 455 | return {key: flexible_collate([d[key] for d in batch]) for key in elem} 456 | elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple 457 | return elem_type(*(flexible_collate(samples) for samples in zip(*batch))) 458 | elif isinstance(elem, collections.abc.Sequence): 459 | # check to make sure that the elements in batch have consistent size 460 | it = iter(batch) 461 | elem_size = len(next(it)) 462 | if not all(len(elem) == elem_size for elem in it): 463 | raise RuntimeError("each element in list of batch should be of equal size") 464 | transposed = zip(*batch) 465 | return [flexible_collate(samples) for samples in transposed] 466 | 467 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 468 | --------------------------------------------------------------------------------
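As a worked illustration of how `UnsupervisedMetrics` matches unlabelled cluster IDs to ground-truth classes with the Hungarian assignment, here is a small synthetic sketch (the tensors below are made up and not part of the repository): when the predicted cluster IDs are a fixed permutation of the true labels, the assignment recovers that permutation and the metric reports 100% mIoU and accuracy.
```
# Synthetic sketch of UnsupervisedMetrics with Hungarian matching -- the data is invented.
import torch

from stego.utils import UnsupervisedMetrics

metrics = UnsupervisedMetrics("test/cluster/", n_classes=3, extra_clusters=0, compute_hungarian=True)

target = torch.randint(0, 3, (2, 32, 32))   # fake ground-truth labels for a batch of 2 images
perm = torch.tensor([2, 0, 1])              # cluster k always corresponds to class perm[k]
preds = perm[target]                        # "cluster predictions" that permute the labels

metrics.update(preds, target)
print(metrics.compute())
# -> {'test/cluster/mIoU': 100.0, 'test/cluster/Accuracy': 100.0}
# linear_sum_assignment recovers the cluster-to-class permutation before scoring
```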