├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md └── sscd.jpg ├── .gitignore ├── LICENSE ├── README.md ├── bin └── train_slurm_wrapper.sh ├── docs ├── Evaluation.md └── Training.md ├── requirements.txt └── sscd ├── CHANGELOG.md ├── __init__.py ├── copydays_eval.py ├── datasets ├── __init__.py ├── copydays.py ├── disc.py ├── image_folder.py └── isc │ ├── __init__.py │ ├── descriptor_matching.py │ ├── io.py │ └── metrics.py ├── disc_eval.py ├── lib ├── __init__.py ├── distributed_util.py ├── fix_paths.py ├── inference.py ├── initialize.py ├── suppress_warnings.py └── util.py ├── models ├── __init__.py ├── gem_pooling.py └── model.py ├── train.py └── transforms ├── __init__.py ├── mixup.py ├── overlay_emoji.py ├── overlay_text.py ├── repeated_augmentation.py ├── samplers.py ├── settings.py └── transforms.py /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | The main purpose of this repository is to share our method and allow others 4 | to reproduce our results. 5 | Since reproducibility is our main goal, we'd prefer to avoid unnecessary 6 | changes to this codebase. 7 | 8 | Please open an issue if you encounter a bug or would like to suggest improvements. 9 | 10 | We ask that potential contributors discuss prospective changes with us before 11 | sending a pull request. 12 | -------------------------------------------------------------------------------- /.github/sscd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/sscd-copy-detection/95902662f2217a5f4aa45f2a3fc70a01dfd3b66a/.github/sscd.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /venv*/ 2 | /output/ 3 | /data/ 4 | __pycache__ 5 | *.swp 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Self-Supervised Descriptor for Image Copy Detection (SSCD) 2 | 3 | This is the open-source codebase for 4 | "[A Self-Supervised Descriptor for Image Copy Detection](https://arxiv.org/abs/2202.10261)", 5 | recently accepted to [CVPR 2022](https://cvpr2022.thecvf.com/). 6 | 7 | This work uses self-supervised contrastive learning with strong 8 | differential entropy regularization to create a fingerprint for 9 | image copy detection. 10 | 11 |
12 | ![SSCD diagram](.github/sscd.jpg) 13 |
14 | 15 | ## About this codebase 16 | 17 | This implementation is built on [Pytorch Lightning](https://pytorchlightning.ai/), 18 | with some components from [Classy Vision](https://classyvision.ai/). 19 | 20 | Our original experiments were conducted in a proprietary codebase 21 | using data files (fonts and emoji) that are not licensed for 22 | redistribution. 23 | This version uses [Noto](https://fonts.google.com/noto) fonts and 24 | [Twemoji](https://twemoji.twitter.com/) emoji, via the 25 | [AugLy](https://github.com/facebookresearch/AugLy) project. 26 | As a result, models trained in this codebase perform slightly differently 27 | than our pretrained models. 28 | 29 | ## Pretrained models 30 | 31 | We provide trained models from our original experiments to allow 32 | others to reproduce our evaluation results. 33 | 34 | For convenience, we provide equivalent model files in a few formats: 35 | * Files ending in `.classy.pt` are weight files using Classy Vision ResNe(X)t backbones, 36 | which is how these models were trained. 37 | * Files ending in `.torchvision.pt` are weight files using Torchvision ResNet backbones. 38 | These files may be easier to integrate in Torchvision-based codebases. 39 | See [model.py](sscd/models/model.py) for how we integrate GeM pooling 40 | and L2 normalization into these models. 41 | * Files ending in `.torchscript.pt` are standalone [TorchScript](https://pytorch.org/docs/stable/jit.html) 42 | models that can be used in any pytorch project without any SSCD code. 43 | 44 | We provide the following models: 45 | 46 | | name | dataset | trunk | augmentations | dimensions | classy vision | torchvision | torchscript | 47 | |------------------------|----------|-----------------|------------------|------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------| 48 | | sscd_disc_blur | DISC | ResNet50 | strong blur | 512 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_blur.classy.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_blur.torchvision.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_blur.torchscript.pt) | 49 | | sscd_disc_advanced | DISC | ResNet50 | advanced | 512 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_advanced.classy.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_advanced.torchvision.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_advanced.torchscript.pt) | 50 | | sscd_disc_mixup | DISC | ResNet50 | advanced + mixup | 512 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.classy.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchvision.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt) | 51 | | sscd_disc_large | DISC | ResNeXt101 32x4 | advanced + mixup | 1024 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_large.classy.pt) | | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_large.torchscript.pt) | 52 | | sscd_imagenet_blur | ImageNet | ResNet50 | strong blur | 512 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_blur.classy.pt) | 
[link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_blur.torchvision.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_blur.torchscript.pt) | 53 | | sscd_imagenet_advanced | ImageNet | ResNet50 | advanced | 512 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_advanced.classy.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_advanced.torchvision.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_advanced.torchscript.pt) | 54 | | sscd_imagenet_mixup | ImageNet | ResNet50 | advanced + mixup | 512 | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.classy.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.torchvision.pt) | [link](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_imagenet_mixup.torchscript.pt) | 55 | 56 | We recommend `sscd_disc_mixup` (ResNet50) as a default SSCD model, 57 | especially when comparing to other standard ResNet50 models, 58 | and `sscd_disc_large` (ResNeXt101) as a higher accuracy alternative 59 | using a bit more compute. 60 | 61 | Classy Vision and Torchvision use different default cardinality settings 62 | for ResNeXt101. We do not provide a Torchvision version of the 63 | `sscd_disc_large` model for this reason. 64 | 65 | ## Installation 66 | 67 | If you only plan to use torchscript models for inference, 68 | **no installation steps are necessary**, and any environment with 69 | a recent version of pytorch installed can run our torchscript 70 | models. 71 | 72 | For all other uses, see installation steps below. 73 | 74 | The code is written for pytorch-lightning 1.5 (the latest version 75 | at time of writing), and may need changes for future Lightning 76 | versions. 77 | 78 | ### Option 1: Install dependencies using Conda 79 | 80 | Install and activate conda, then create a conda environment for SSCD as follows: 81 | 82 | ```bash 83 | # Create conda environment 84 | conda create --name sscd -c pytorch -c conda-forge \ 85 | pytorch torchvision cudatoolkit=11.3 \ 86 | "pytorch-lightning>=1.5,<1.6" lightning-bolts \ 87 | faiss python-magic pandas numpy 88 | 89 | # Activate environment 90 | conda activate sscd 91 | 92 | # Install Classy Vision and AugLy from PIP: 93 | python -m pip install classy_vision augly 94 | ``` 95 | 96 | You may need to select a `cudatoolkit` version that corresponds 97 | to the system CUDA library version you have installed. 98 | See [PyTorch documentation](https://pytorch.org/) for supported 99 | combinations of pytorch, torchvision and cudatoolkit versions. 100 | 101 | For a non-CUDA (CPU only) installation, replace `cudatoolkit=...` with `cpuonly`. 102 | 103 | ### Option 2: Install dependencies using PIP 104 | 105 | ```bash 106 | # Create environment 107 | python3 -m virtualenv ./venv 108 | 109 | # Activate environment 110 | source ./venv/bin/activate 111 | 112 | # Install dependencies in this environment 113 | python -m pip install -r ./requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 114 | ``` 115 | The `--extra-index-url` option selects a newer version of CUDA 116 | libraries, required for NVidia A100 GPUs. This can be omitted 117 | if A100 support is not needed. 118 | 119 | ## Inference using SSCD models 120 | 121 | This section describes how to use pretrained SSCD models for inference. 122 | To perform inference for DISC and Copydays evaluations, see 123 | [Evaluation](docs/Evaluation.md). 
124 | 125 | ### Preprocessing 126 | 127 | We recommend preprocessing images for inference either resizing 128 | the small edge to 288 or resizing the image to a square tensor. 129 | 130 | Using fixed-sized square tensors is more efficient on GPUs, to make 131 | better use of batching. 132 | Copy detection using square tensors benefits from directly resizing 133 | to the target tensor size. This skews the image, and does not preserve 134 | aspect ratio. This differs from the common practice for 135 | classification inference. 136 | 137 | ```python 138 | from torchvision import transforms 139 | 140 | normalize = transforms.Normalize( 141 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], 142 | ) 143 | small_288 = transforms.Compose([ 144 | transforms.Resize(288), 145 | transforms.ToTensor(), 146 | normalize, 147 | ]) 148 | skew_320 = transforms.Compose([ 149 | transforms.Resize([320, 320]), 150 | transforms.ToTensor(), 151 | normalize, 152 | ]) 153 | ``` 154 | 155 | ### Inference using Torchscript 156 | 157 | Torchscript files can be loaded directly in other projects without any SSCD code or dependencies. 158 | 159 | ```python 160 | import torch 161 | from PIL import Image 162 | 163 | model = torch.jit.load("/path/to/sscd_disc_mixup.torchscript.pt") 164 | img = Image.open("/path/to/image.png").convert('RGB') 165 | batch = small_288(img).unsqueeze(0) 166 | embedding = model(batch)[0, :] 167 | ``` 168 | 169 | These Torchscript models are prepared for inference. For other uses (eg. fine-tuning), 170 | use model weight files, as described below. 171 | 172 | ### Load model weight files 173 | 174 | To load model weight files, first construct the `Model` object, 175 | then load the weights using the standard `torch.load` and `load_state_dict` 176 | methods. 177 | 178 | ```python 179 | import torch 180 | from sscd.models.model import Model 181 | 182 | model = Model("CV_RESNET50", 512, 3.0) 183 | weights = torch.load("/path/to/sscd_disc_mixup.classy.pt") 184 | model.load_state_dict(weights) 185 | model.eval() 186 | ``` 187 | 188 | Once loaded, these models can be used interchangeably with Torchscript 189 | models for inference. 190 | 191 | Model backbone strings can be found in the `Backbone` 192 | enum in [model.py](sscd/models/model.py). Classy Vision models start 193 | with the prefix `CV_` and Torchvision models start with `TV_`. 194 | 195 | ### Using SSCD descriptors 196 | 197 | SSCD models produce 512 dimension (except the "large" model, which uses 1024 dimensions) 198 | L2 normalized descriptors for each input image. 199 | The similarity of two images with descriptors `a` and `b` can be measured 200 | by descriptor cosine similarity (`a.dot(b)`; higher is more similar), 201 | or equivalently using euclidean distance (`(a-b).norm()`; lower is more similar). 202 | 203 | For the `sscd_disc_mixup` model, DISC image pairs with embedding cosine similarity greater 204 | than `0.75` are copies with 90% precision, for example. This corresponds to a euclidean 205 | distance less than `0.7`, or squared euclidean distance less than `0.5`. 206 | 207 | #### Descriptor post-processing 208 | 209 | For best results, we recommend additional descriptor processing 210 | when sample images from the target distribution are available. 211 | Centering (subtracting the mean) followed by L2 normalization, 212 | or whitening followed by L2 normalization, can improve accuracy. 
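For illustration, the following is a minimal sketch of centering followed by L2 normalization using NumPy. The `postprocess` helper and its argument names are only for this example and are not part of the SSCD codebase.

```python
import numpy as np

def postprocess(descriptors: np.ndarray, reference_descriptors: np.ndarray) -> np.ndarray:
    """Center descriptors using a sample of the target distribution, then L2 normalize."""
    # Estimate the mean from descriptors of sample images drawn from the target distribution.
    mu = reference_descriptors.mean(axis=0, keepdims=True)
    centered = descriptors - mu
    # Re-normalize each centered descriptor to unit L2 norm.
    return centered / np.linalg.norm(centered, axis=1, keepdims=True)
```

Whitening can be applied in the same way by fitting a whitening transform (for example, a FAISS `PCAW` codec, as used in [Evaluation](docs/Evaluation.md)) on the reference sample and L2 normalizing the output.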
213 | 214 | Score normalization can make similarity more consistent and 215 | improve global accuracy metrics (but has no effect on ranking metrics). 216 | 217 | ### Other model formats 218 | 219 | If pretrained models in another format (eg. ONYX) would be useful for you, 220 | let us know by filing a feature request. 221 | 222 | ## Reproducing evaluation results 223 | 224 | To reproduce evaluation results, see [Evaluation](docs/Evaluation.md). 225 | 226 | ## Training SSCD models 227 | 228 | For information on how to train SSCD models, see 229 | [Training](docs/Training.md). 230 | 231 | ## License 232 | 233 | The SSCD codebase uses the [MIT license](LICENSE). 234 | 235 | ## Citation 236 | 237 | If you find our codebase useful, please consider giving a star :star: and cite as: 238 | 239 | ``` 240 | @article{pizzi2022self, 241 | title={A Self-Supervised Descriptor for Image Copy Detection}, 242 | author={Pizzi, Ed and Roy, Sreya Dutta and Ravindra, Sugosh Nagavara and Goyal, Priya and Douze, Matthijs}, 243 | journal={Proc. CVPR}, 244 | year={2022} 245 | } 246 | ``` 247 | -------------------------------------------------------------------------------- /bin/train_slurm_wrapper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Wraps train.py to set pytorch distributed environment variables with 9 | # values from slurm. 10 | 11 | if [ ! -e ./sscd/train.py ]; then 12 | echo "Run from the top-level sscd directory." 13 | exit 1 14 | fi 15 | 16 | echo "Running on $(hostname)" 17 | 18 | # Choose a primary node for distributed coordination 19 | if (( SLURM_STEP_NUM_NODES == 1)); then 20 | primary="localhost" 21 | else 22 | primary="$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)" 23 | fi 24 | echo "Using $primary as primary for $SLURM_STEP_NUM_NODES node training run." 25 | 26 | MASTER_ADDR="$primary" MASTER_PORT="20285" NODE_RANK="$SLURM_NODEID" WORLD_SIZE=$(( 8 * $SLURM_STEP_NUM_NODES )) \ 27 | exec python ./sscd/train.py --nodes="$SLURM_STEP_NUM_NODES" --gpus=8 "$@" 28 | -------------------------------------------------------------------------------- /docs/Evaluation.md: -------------------------------------------------------------------------------- 1 | # Reproduce evaluation results 2 | 3 | ## DISC evaluations 4 | 5 | On an 8-GPU server, run the following to evaluate the `sscd_disc_mixup` model 6 | using our default preprocessing (resize the short edge to 288, preserving 7 | the aspect ratio). 8 | 9 | ```bash 10 | sscd/disc_eval.py --disc_path /path/to/disc2021 --gpus=8 \ 11 | --output_path=/path/to/eval/output \ 12 | --size=288 --preserve_aspect_ratio=true \ 13 | --backbone=CV_RESNET50 --dims=512 --model_state=/path/to/sscd_disc_mixup.classy.pt 14 | ``` 15 | 16 | After ~2 hours (on an 8 GPU machine), this command produces a CSV file, 17 | `disc_metrics.csv`, in the configured `--output_path`: 18 | 19 | ``` 20 | codec,score_norm,uAP,accuracy-at-1,recall-at-p90 21 | Flat,None,0.6093859526781344,0.7757,0.3766 22 | "PCAW512,L2norm,Flat",None,0.6142708346057645,0.782,0.3825 23 | Flat,"1.00[0,2]",0.7180449145415637,0.7757,0.6251 24 | "PCAW512,L2norm,Flat","1.00[0,2]",0.7242754343531326,0.782,0.6308 25 | ``` 26 | 27 | The columns of this file are: 28 | * `codec`: the descriptor postprocessing codec. 
Our paper results use 29 | `PCAW<dims>,L2norm,Flat`, where `<dims>` is the descriptor dimensionality. 30 | * `score_norm`: we evaluate with and without score normalization. The 31 | three numbers are: `<weight>[<first>, <last>]`, where `<weight>` is β 32 | in equation 8, and `<first>` and `<last>` are zero-based indices of the first 33 | and last neighbors to include (inclusive). 34 | 35 | The two `PCAW512,L2norm,Flat` rows correspond to µAP 61.4% 36 | and µAP_SN (with score normalization) of 72.4%, similar to the 37 | "SSCD DISC adv.+mixup" row in Table 2 (61.5% and 72.5%, respectively). 38 | 39 | Most of our models are 512d ResNet50 models, and can be evaluated by changing just 40 | the `--model_state` argument. 41 | To evaluate the `sscd_disc_large` model, use: `--backbone=CV_RESNEXT101 --dims=1024` 42 | 43 | Note that our results use the DISC2021 validation set, as the test set was not 44 | yet available when the paper was written. 45 | 46 | ## Copydays + 10k distractors (CD10K) evaluations 47 | 48 | We evaluate using 10K distractors and a 20K whitening set from YFCC100M. 49 | 50 | To evaluate the `sscd_disc_mixup` model using default settings, run: 51 | 52 | ```bash 53 | sscd/copydays_eval.py --gpus=8 --copydays_path /path/to/copydays \ 54 | --distractor_path /path/to/distractors \ 55 | --codec_train_path /path/to/whitening \ 56 | --output_path=/path/to/eval/output \ 57 | --backbone=CV_RESNET50 --dims=512 \ 58 | --model_state=/path/to/sscd_disc_mixup.classy.pt \ 59 | --size=288 --preserve_aspect_ratio=true \ 60 | --codecs="PCAW512,L2norm,Flat;PCA512,L2norm,Flat" 61 | ``` 62 | 63 | The command above produces evaluation metrics with two embedding 64 | postprocessings: with whitening (`PCAW512,L2norm,Flat`, as reported 65 | in the paper), and with simple centering (`PCA512,L2norm,Flat`). 66 | 67 | After ~10 minutes (on an 8 GPU machine), this command will produce a 68 | CSV file, `copydays_metrics.csv`, within the configured `--output_path`: 69 | 70 | ``` 71 | codec,strong_mAP,overall_uAP 72 | "PCAW512,L2norm,Flat",0.865810403059819,0.9823856615515755 73 | "PCA512,L2norm,Flat",0.8843126681663799,0.9820800446423634 74 | ``` 75 | 76 | The `strong_mAP` column shows mAP on the strong subset of Copydays, 77 | and `overall_uAP` contains µAP over the full dataset. 78 | This shows that, with whitening, we get 86.58% mAP and 98.24% µAP, 79 | similar to the 86.6 mAP and 98.1 µAP values reported in the 80 | SSCD row in Table 3. 81 | 82 | Note that SSCD has better strong subset mAP metrics (although often 83 | reduced µAP) when evaluated on CD10K without whitening, 84 | using simple centering and L2 normalization (the `PCA512,L2norm,Flat` 85 | row in the CSV): 88.4% here versus 86.5%. 86 | 87 | We explore image preprocessing methods used by various baselines on Copydays, 88 | such as resizing to square tensors (`--preserve_aspect_ratio=false`) 89 | and resizing the long edge to a specified size 90 | (`--size=800 --resize_long_edge=true --preserve_aspect_ratio=false`). 91 | 92 | ## FAISS codec strings 93 | 94 | The evaluations above specify descriptor postprocessing methods 95 | (eg. whitening and L2 normalization) using 96 | [FAISS codec strings](https://github.com/facebookresearch/faiss/wiki/The-index-factory). 97 | 98 | For instance, if we start with 512d L2 normalized descriptors 99 | (as our ResNet50 models produce), and use the codec string 100 | `PCAW512,L2norm,Flat`, 101 | this means that descriptors will be whitened, then L2 normalized again, 102 | before being used for retrieval.
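As an illustration (this is not one of the evaluation scripts), such a codec string can be applied directly with `faiss.index_factory`; the random arrays below are only placeholders for real descriptors and a whitening/training sample:

```python
import faiss
import numpy as np

d = 512
index = faiss.index_factory(d, "PCAW512,L2norm,Flat")

# The PCAW transform must be trained on descriptors drawn from the
# target distribution (e.g. the whitening set) before the index is usable.
whitening_sample = np.random.rand(20000, d).astype("float32")
index.train(whitening_sample)

references = np.random.rand(1000, d).astype("float32")
index.add(references)

queries = np.random.rand(10, d).astype("float32")
distances, ids = index.search(queries, 10)
```

(`copydays_eval.py` builds its index the same way, passing the codec string to `faiss.index_factory` and training it on the `--codec_train_path` descriptors when training is required.)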
103 | 104 | Evaluation scripts can evaluate features using multiple codecs separated 105 | by `;`, such as `PCAW512,L2norm,Flat;PCA512,L2norm,Flat` to evaluate 106 | both whitening followed by L2 normalization and centering followed by 107 | L2 normalization. 108 | 109 | (`PCA512` centers a 512 dimensional representation by subtracting the mean 110 | before applying an orthogonal PCA projection. That projection does not change 111 | descriptor distances, so the only distance-changing effect is the centering.) 112 | -------------------------------------------------------------------------------- /docs/Training.md: -------------------------------------------------------------------------------- 1 | # Training SSCD models 2 | 3 | We use four 8-GPU nodes to train SSCD models. 4 | We run `sscd/train.py` once on each training machine, passing 5 | [PyTorch distributed](https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization) 6 | environment variables to each command to coordinate workers. 7 | 8 | The train command on each worker is as follows: 9 | ``` 10 | MASTER_ADDR="<primary node hostname>" MASTER_PORT="20285" NODE_RANK="<this node's rank>" WORLD_SIZE=32 \ 11 | ./sscd/train.py --nodes="<number of nodes>" --gpus=8 \ 12 | --train_dataset_path=/path/to/disc/training \ 13 | --entropy_weight=30 --augmentations=ADVANCED --mixup=true \ 14 | --output_path=/path/to/train/output 15 | ``` 16 | 17 | ## Training using Slurm 18 | 19 | We orchestrate this using [Slurm](https://slurm.schedmd.com/documentation.html), 20 | and provide a wrapper script that translates Slurm environment variables to 21 | PyTorch distributed environment variables. 22 | (The next release of PyTorch Lightning should detect environment variables from Slurm and other cluster 23 | environments automatically.) 24 | 25 | ``` 26 | srun --nodes=4 --gpus-per-node=8 --mem=0 \ 27 | --cpus-per-task=<cpus per node> --ntasks-per-node=1 \ 28 | ./bin/train_slurm_wrapper.sh --train_dataset_path=/path/to/disc/training \ 29 | --entropy_weight=30 --augmentations=ADVANCED --mixup=true \ 30 | --output_path=/path/to/train/output 31 | ``` 32 | 33 | ### Evaluating models trained using this codebase 34 | 35 | Training produces a checkpoint file within the provided `--output_path`, 36 | for instance at `<output path>/lightning_logs/version_<N>/checkpoints/epoch=99-step=24399.ckpt`, 37 | where `<N>` is an integer ID chosen by the Lightning framework. 38 | 39 | Our evaluation commands can load model settings and weights from 40 | these checkpoints via the `--checkpoint=<path to checkpoint>` parameter. 41 | When using `--checkpoint`, omit other model parameters 42 | (i.e. don't set `--backbone`, `--dims` or `--model_state`). 43 | 44 | ## Advice for extending SSCD 45 | 46 | To extend SSCD, for instance using different trunks, 47 | batch size, image augmentations, or optimizers, it may be necessary 48 | to reduce the entropy weight (λ, via the `--entropy_weight` 49 | argument). 50 | 51 | The setting we use in the paper, λ = 30, is a very strong 52 | weight, and is not stable for all configurations. 53 | When the entropy weight is too large, the repulsive force from 54 | entropy regularization may prevent InfoNCE from aligning matches. 55 | 56 | As an example, when training SSCD using Torchvision in this 57 | codebase, we discovered that our λ = 30 results relied 58 | on Classy Vision's default ResNet initialization, equivalent to 59 | TorchVision ResNet's `zero_init_residual=True` option, which puts all 60 | the energy into the residual connections at initialization. 61 | 62 | We recommend using a lower initial weight (eg. 
λ = 10) for new 63 | experiments. 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning>=1.5.7,<1.6 2 | lightning-bolts>=0.4.0,<0.6 3 | classy_vision 4 | torch 5 | torchvision 6 | torchmetrics 7 | faiss-gpu 8 | augly 9 | pandas 10 | numpy 11 | tensorboard 12 | -------------------------------------------------------------------------------- /sscd/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | July 18, 2022 2 | 3 | * Update license from CC-NC 4.0 International license to the MIT license. 4 | * Add a `--global_candidates` option to evaluate with a global candidate 5 | limit K, rather than K per query. 6 | This style of retrieval was used for the ISC descriptor track evaluations. 7 | -------------------------------------------------------------------------------- /sscd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | 9 | from classy_vision.generic.registry_utils import import_all_modules 10 | 11 | FILE_ROOT = Path(__file__).parent 12 | 13 | 14 | def import_subdir(name): 15 | path = Path(FILE_ROOT, name) 16 | import_all_modules(path, f"sscd.{name}") 17 | 18 | 19 | # Automatically import any Python files in selected directories. 20 | import_subdir("transforms") 21 | -------------------------------------------------------------------------------- /sscd/copydays_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import argparse 9 | import logging 10 | import os 11 | 12 | import faiss 13 | import numpy as np 14 | import pandas as pd 15 | 16 | from lib import initialize # noqa 17 | from lib.inference import Inference 18 | from sscd.datasets.copydays import Copydays 19 | from sscd.datasets.image_folder import ImageFolder 20 | from sscd.lib.util import parse_bool 21 | 22 | # After initialize import to silence an error. 23 | from classy_vision.dataset.transforms import build_transforms 24 | 25 | parser = argparse.ArgumentParser() 26 | inference_parser = parser.add_argument_group("Inference") 27 | Inference.add_parser_args(inference_parser) 28 | inference_parser.add_argument( 29 | "--resize_long_edge", 30 | default=False, 31 | type=parse_bool, 32 | help=( 33 | "Preprocess images by resizing the long edge to --size. " 34 | "Has no effect if --preserve_aspect_ratio is not set." 
35 | ), 36 | ) 37 | 38 | cd_parser = parser.add_argument_group("Copydays") 39 | cd_parser.add_argument("--copydays_path", required=True) 40 | cd_parser.add_argument("--distractor_path", required=True) 41 | cd_parser.add_argument("--codec_train_path") 42 | cd_parser.add_argument( 43 | "--codecs", 44 | default="Flat", 45 | help="FAISS codecs for postprocessing embeddings as ';' separated strings" 46 | "in index_factory format", 47 | ) 48 | cd_parser.add_argument("--metadata", help="Metadata column to put in the result CSV") 49 | 50 | logging.basicConfig( 51 | format="%(asctime)s %(levelname)-8s %(message)s", 52 | level=logging.WARNING, 53 | datefmt="%Y-%m-%d %H:%M:%S", 54 | ) 55 | logger = logging.getLogger("copydays_eval.py") 56 | logger.setLevel(logging.INFO) 57 | 58 | 59 | def get_transforms(size, preserve_aspect_ratio, resize_long_edge): 60 | resize_long_edge = preserve_aspect_ratio and resize_long_edge 61 | resize_name = "ResizeLongEdge" if resize_long_edge else "Resize" 62 | resize_size = size if preserve_aspect_ratio else [size, size] 63 | return build_transforms( 64 | [ 65 | {"name": resize_name, "size": resize_size}, 66 | {"name": "ToTensor"}, 67 | { 68 | "name": "Normalize", 69 | "mean": [0.485, 0.456, 0.406], 70 | "std": [0.229, 0.224, 0.225], 71 | }, 72 | ] 73 | ) 74 | 75 | 76 | def evaluate( 77 | embeddings, distractors, train, copydays: Copydays, index_type: str, k=100 78 | ): 79 | D = embeddings.shape[1] 80 | faiss_index = faiss.index_factory(D, index_type) 81 | if not faiss_index.is_trained: 82 | assert ( 83 | train is not None 84 | ), f"A training dataset must be provided for {index_type} search" 85 | faiss_index.train(train) 86 | faiss_index.add(copydays.get_block_embeddings(embeddings, "original")) 87 | faiss_index.add(distractors) 88 | distances, ids = faiss_index.search(embeddings, k) 89 | assert np.isfinite( 90 | distances 91 | ).all(), "Non-finite distances found; this often means whitening failed" 92 | metrics = copydays.eval_result(ids, distances) 93 | metrics = dict(strong_mAP=metrics["strong mAP"], overall_uAP=metrics["macro AP"]) 94 | return metrics 95 | 96 | 97 | def evaluate_all( 98 | copydays_outputs, 99 | distractor_outputs, 100 | train_outputs, 101 | copydays: Copydays, 102 | codecs: str, 103 | metadata=None, 104 | k=100, 105 | ): 106 | codecs = codecs.split(";") 107 | instance_ids = copydays_outputs["instance_id"] 108 | embeddings = copydays_outputs["embeddings"] 109 | order = np.argsort(instance_ids) 110 | instance_ids = instance_ids[order] 111 | embeddings = embeddings[order, :] 112 | N, D = embeddings.shape 113 | assert N == len(copydays) 114 | assert np.all(instance_ids == np.arange(N, dtype=np.int64)) 115 | distractors = distractor_outputs["embeddings"] 116 | train = train_outputs["embeddings"] if train_outputs else None 117 | records = [] 118 | for codec in codecs: 119 | record = {"codec": codec} 120 | metrics = evaluate(embeddings, distractors, train, copydays, codec, k=k) 121 | record.update(metrics) 122 | logger.info(f"Metrics: {record}") 123 | records.append(record) 124 | df = pd.DataFrame(records) 125 | if metadata: 126 | df["metadata"] = metadata 127 | return df 128 | 129 | 130 | def main(args): 131 | logger.info("Setting up dataset") 132 | transforms = get_transforms( 133 | args.size, args.preserve_aspect_ratio, args.resize_long_edge 134 | ) 135 | copydays = Copydays(args.copydays_path, transforms) 136 | copydays_embeddings = Inference.inference(args, copydays, "copydays") 137 | distractors = ImageFolder(args.distractor_path, 
img_transform=transforms) 138 | distractor_embeddings = Inference.inference(args, distractors, "distractors") 139 | if args.codec_train_path: 140 | codec_train = ImageFolder(args.codec_train_path, img_transform=transforms) 141 | train_embeddings = Inference.inference(args, codec_train, "codec_train") 142 | else: 143 | train_embeddings = None 144 | df = evaluate_all( 145 | copydays_embeddings, 146 | distractor_embeddings, 147 | train_embeddings, 148 | copydays, 149 | args.codecs, 150 | metadata=args.metadata, 151 | ) 152 | csv_filename = os.path.join(args.output_path, "copydays_metrics.csv") 153 | df.to_csv(csv_filename, index=False) 154 | with open(csv_filename, "r") as f: 155 | logger.info("Metric CSV:\n%s", f.read()) 156 | 157 | 158 | if __name__ == "__main__": 159 | args = parser.parse_args() 160 | main(args) 161 | -------------------------------------------------------------------------------- /sscd/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sscd/datasets/copydays.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | from torchvision.datasets.folder import default_loader 10 | 11 | 12 | def score_ap_from_ranks_1(ranks, nres): 13 | """Compute the average precision of one search. 
14 | ranks = ordered list of ranks of true positives 15 | nres = total number of positives in dataset 16 | """ 17 | 18 | # accumulate trapezoids in PR-plot 19 | ap = 0.0 20 | 21 | # All have an x-size of: 22 | recall_step = 1.0 / nres 23 | 24 | for ntp, rank in enumerate(ranks): 25 | 26 | # y-size on left side of trapezoid: 27 | # ntp = nb of true positives so far 28 | # rank = nb of retrieved items so far 29 | if rank == 0: 30 | precision_0 = 1.0 31 | else: 32 | precision_0 = ntp / float(rank) 33 | 34 | # y-size on right side of trapezoid: 35 | # ntp and rank are increased by one 36 | precision_1 = (ntp + 1) / float(rank + 1) 37 | 38 | ap += (precision_1 + precision_0) * recall_step / 2.0 39 | 40 | return ap 41 | 42 | 43 | # Compute the macro-MAP, precision and recall (threshold based) from 44 | # the groundtruth defined as nquery x results matrices ( 45 | # 0 for match and 1 for non-matches) 46 | def score_macro_AP(gnd, dis, totpos=None): 47 | # Default is one possible positive result per query 48 | if totpos is None: 49 | totpos = gnd.shape[0] 50 | # Interleave the results from all queries and sort them by distances 51 | gnd = np.reshape(gnd, (-1)) 52 | a = np.reshape(dis, (-1)).argsort() 53 | gnd = gnd[a] 54 | ntp = gnd.cumsum().astype("float") 55 | recall = ntp / float(totpos) 56 | precision = ntp / (np.arange(ntp.shape[0]) + 1) 57 | 58 | # Compute the macroMAP now 59 | ranks_pos = [i for i in range(gnd.shape[0]) if gnd[i] == 1] 60 | MAP = score_ap_from_ranks_1(ranks_pos, totpos) 61 | return MAP, recall, precision 62 | 63 | 64 | def blocks_from_directories(imnames): 65 | """splits a list of filenames according to their direcotry""" 66 | imnames.sort() 67 | prev_dirname = None 68 | block_names = [] 69 | per_block_images = [] 70 | for name in imnames: 71 | if name.startswith("./"): 72 | name = name[2:] 73 | if "/" in name: 74 | dirname = name[: name.rfind("/")] 75 | else: 76 | dirname = "" 77 | if dirname != prev_dirname: 78 | prev_dirname = dirname 79 | block_names.append(dirname) 80 | block_images = [] 81 | per_block_images.append(block_images) 82 | block_images.append(name) 83 | 84 | return block_names, per_block_images 85 | 86 | 87 | def cluster_pr(imnos, ids): 88 | """ 89 | The images in the list imnos are from a cluster. 90 | Return the recall @ cluster size for the the results ids. 91 | 1 = perfect result list, each false positive costs 1/len(imnos). 
92 | """ 93 | prs = [] 94 | npos = len(imnos) 95 | for qno in imnos: 96 | ranks = [rank for rank, rno in enumerate(ids[qno]) if rno in imnos] 97 | # print ' ', ranks, 98 | recall = len(ranks) / float(npos) 99 | precision = len(ranks) / float(len(ids[qno])) 100 | prs.append((precision, recall)) 101 | # print 102 | return np.array(prs) 103 | 104 | 105 | class CopydaysBlock: 106 | 107 | STRONG = "strong" 108 | STANDARD_SIZE = 157 109 | STRONG_SIZE = 229 110 | 111 | def __init__(self, name, path, start_id, transforms=None): 112 | self.name = name 113 | self.path = path 114 | self.size = self.STRONG_SIZE if name == self.STRONG else self.STANDARD_SIZE 115 | self.start_id = start_id 116 | self.files = sorted([f for f in os.listdir(path) if f.endswith(".jpg")]) 117 | self.transforms = transforms 118 | 119 | def __getitem__(self, i): 120 | file = self.files[i] 121 | img = default_loader(os.path.join(self.path, file)) 122 | if self.transforms: 123 | img = self.transforms(img) 124 | return dict( 125 | input=img, block=self.name, filename=file, instance_id=i + self.start_id 126 | ) 127 | 128 | def __len__(self): 129 | return self.size 130 | 131 | 132 | class Copydays: 133 | def __init__(self, path, transforms=None): 134 | self.basedir = path 135 | self.block_names = ( 136 | ["original", "strong"] 137 | + ["jpegqual/%d" % i for i in [3, 5, 8, 10, 15, 20, 30, 50, 75]] 138 | + ["crops/%d" % i for i in [10, 15, 20, 30, 40, 50, 60, 70, 80]] 139 | ) 140 | self.blocks = [] 141 | instance_id = 0 142 | for name in self.block_names: 143 | block = CopydaysBlock( 144 | name, os.path.join(path, name), instance_id, transforms 145 | ) 146 | self.blocks.append(block) 147 | instance_id += len(block) 148 | self.size = instance_id 149 | self.nblocks = len(self.block_names) 150 | self.query_blocks = list(range(self.nblocks)) 151 | self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157 152 | self.q_block_sizes[1] = 229 153 | # search only among originals 154 | self.database_blocks = [0] 155 | 156 | def __len__(self): 157 | return self.size 158 | 159 | def __getitem__(self, i): 160 | assert i < self.size 161 | for block in self.blocks: 162 | if i < len(block): 163 | return block[i] 164 | i -= len(block) 165 | raise AssertionError("unreachable") 166 | 167 | def get_block(self, i): 168 | dirname = self.basedir + "/" + self.block_names[i] 169 | fnames = [ 170 | dirname + "/" + fname 171 | for fname in sorted(os.listdir(dirname)) 172 | if fname.endswith(".jpg") 173 | ] 174 | res = ("image_list", fnames) 175 | return res 176 | 177 | def get_block_filenames(self, subdir_name, absolute=False): 178 | dirname = self.basedir + "/" + subdir_name 179 | relative = [ 180 | fname for fname in sorted(os.listdir(dirname)) if fname.endswith(".jpg") 181 | ] 182 | if absolute: 183 | return [os.path.join(dirname, fname) for fname in relative] 184 | return relative 185 | 186 | def get_block_embeddings(self, embeddings, block_name): 187 | assert embeddings.shape[0] == len(self) 188 | assert block_name in self.block_names 189 | i = self.block_names.index(block_name) 190 | block = self.blocks[i] 191 | start = block.start_id 192 | end = start + len(block) 193 | return embeddings[start:end, :] 194 | 195 | def eval_result(self, ids, distances): 196 | metrics = {} 197 | j0 = 0 198 | for i in range(self.nblocks): 199 | j1 = j0 + self.q_block_sizes[i] 200 | block_name = self.block_names[i] 201 | I = ids[j0:j1] # block size 202 | assert I.shape[0] == self.q_block_sizes[i] # check for partial slice 203 | sum_AP = 0 204 | if block_name != "strong": 
205 | # 1:1 mapping of files to names 206 | positives_per_query = [[i] for i in range(j1 - j0)] 207 | else: 208 | originals = self.get_block_filenames("original") 209 | strongs = self.get_block_filenames("strong") 210 | 211 | # check if prefixes match 212 | positives_per_query = [ 213 | [j for j, bname in enumerate(originals) if bname[:4] == qname[:4]] 214 | for qname in strongs 215 | ] 216 | 217 | for qno, Iline in enumerate(I): 218 | positives = positives_per_query[qno] 219 | ranks = [] 220 | for rank, bno in enumerate(Iline): 221 | if bno in positives: 222 | ranks.append(rank) 223 | sum_AP += score_ap_from_ranks_1(ranks, len(positives)) 224 | 225 | mAP = sum_AP / (j1 - j0) 226 | print("eval on %s mAP=%.3f" % (block_name, mAP)) 227 | metrics["%s mAP" % block_name] = mAP 228 | j0 = j1 229 | 230 | self.eval_result_alt(ids, distances, metrics) 231 | return metrics 232 | 233 | def eval_result_alt(self, ids, distances, metrics): 234 | gnd = np.zeros(ids.shape, dtype="int") 235 | j0 = 0 236 | for i in range(self.nblocks): 237 | j1 = j0 + self.q_block_sizes[i] 238 | block_name = self.block_names[i] 239 | I = ids[j0:j1] # block size 240 | if block_name != "strong": 241 | # 1:1 mapping of files to names 242 | positives_per_query = [[i] for i in range(j1 - j0)] 243 | else: 244 | originals = self.get_block_filenames("original") 245 | strongs = self.get_block_filenames("strong") 246 | 247 | # check if prefixes match 248 | positives_per_query = [ 249 | [j for j, bname in enumerate(originals) if bname[:4] == qname[:4]] 250 | for qname in strongs 251 | ] 252 | 253 | for qno, Iline in enumerate(I): 254 | positives = positives_per_query[qno] 255 | for rank, bno in enumerate(Iline): 256 | if bno in positives: 257 | gnd[j0 + qno][rank] = 1 258 | j0 = j1 259 | MAP, recall, precision = score_macro_AP(gnd, distances) 260 | print("Macro-AP = %.4f" % MAP) 261 | metrics["macro AP"] = MAP 262 | -------------------------------------------------------------------------------- /sscd/datasets/disc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os.path 8 | from typing import Callable, Dict, Optional 9 | from torchvision.datasets.folder import default_loader 10 | 11 | from sscd.datasets.image_folder import get_image_paths 12 | from sscd.datasets.isc.descriptor_matching import ( 13 | knn_match_and_make_predictions, 14 | match_and_make_predictions, 15 | ) 16 | from sscd.datasets.isc.io import read_ground_truth 17 | from sscd.datasets.isc.metrics import evaluate, Metrics 18 | 19 | 20 | class DISCEvalDataset: 21 | """DISC2021 evaluation dataset.""" 22 | 23 | SPLIT_REF = 0 24 | SPLIT_QUERY = 1 25 | SPLIT_TRAIN = 2 26 | 27 | def __init__( 28 | self, 29 | path: str, 30 | transform: Callable = None, 31 | include_train: bool = True, 32 | # Specific paths for each part of the dataset. If not set, inferred from `path`. 
33 | query_path: Optional[str] = None, 34 | ref_path: Optional[str] = None, 35 | train_path: Optional[str] = None, 36 | gt_path: Optional[str] = None, 37 | ): 38 | def get_path(override_path, relative_path): 39 | return override_path if override_path else os.path.join(path, relative_path) 40 | 41 | query_path = get_path(query_path, "val_query") 42 | ref_path = get_path(ref_path, "val_ref") 43 | train_path = get_path(train_path, "training") if include_train else None 44 | gt_path = get_path(gt_path, "val_groundtruth_matches.csv") 45 | self.files, self.metadata = self.read_files(ref_path, self.SPLIT_REF) 46 | query_files, query_metadata = self.read_files(query_path, self.SPLIT_QUERY) 47 | self.files.extend(query_files) 48 | self.metadata.extend(query_metadata) 49 | if train_path: 50 | train_files, train_metadata = self.read_files(train_path, self.SPLIT_TRAIN) 51 | self.files.extend(train_files) 52 | self.metadata.extend(train_metadata) 53 | self.gt = read_ground_truth(gt_path) 54 | self.transform = transform 55 | 56 | def __getitem__(self, idx: int): 57 | filename = self.files[idx] 58 | img = default_loader(filename) 59 | if self.transform: 60 | img = self.transform(img) 61 | sample = {"input": img, "instance_id": idx} 62 | sample.update(self.metadata[idx]) 63 | return sample 64 | 65 | def __len__(self): 66 | return len(self.files) 67 | 68 | @classmethod 69 | def read_files(cls, path, split): 70 | files = get_image_paths(path) 71 | names = [os.path.splitext(os.path.basename(file))[0] for file in files] 72 | metadata = [ 73 | dict(name=name, split=split, image_num=int(name[1:]), target=-1) 74 | for name in names 75 | ] 76 | return files, metadata 77 | 78 | def retrieval_eval( 79 | self, embedding_array, targets, split, **kwargs 80 | ) -> Dict[str, float]: 81 | query_mask = split == self.SPLIT_QUERY 82 | ref_mask = split == self.SPLIT_REF 83 | query_ids = targets[query_mask] 84 | query_embeddings = embedding_array[query_mask, :] 85 | ref_ids = targets[ref_mask] 86 | ref_embeddings = embedding_array[ref_mask, :] 87 | return self.retrieval_eval_splits( 88 | query_ids, query_embeddings, ref_ids, ref_embeddings, **kwargs 89 | ) 90 | 91 | def retrieval_eval_splits( 92 | self, 93 | query_ids, 94 | query_embeddings, 95 | ref_ids, 96 | ref_embeddings, 97 | use_gpu=False, 98 | k=10, 99 | global_candidates=False, 100 | **kwargs 101 | ) -> Dict[str, float]: 102 | query_names = ["Q%05d" % i for i in query_ids] 103 | ref_names = ["R%06d" % i for i in ref_ids] 104 | if global_candidates: 105 | predictions = match_and_make_predictions( 106 | query_embeddings, 107 | query_names, 108 | ref_embeddings, 109 | ref_names, 110 | num_results=k * len(query_names), 111 | ngpu=-1 if use_gpu else 0, 112 | **kwargs, 113 | ) 114 | else: 115 | predictions = knn_match_and_make_predictions( 116 | query_embeddings, 117 | query_names, 118 | ref_embeddings, 119 | ref_names, 120 | k=k, 121 | ngpu=-1 if use_gpu else 0, 122 | **kwargs, 123 | ) 124 | results: Metrics = evaluate(self.gt, predictions) 125 | return { 126 | "uAP": results.average_precision, 127 | "accuracy-at-1": results.recall_at_rank1, 128 | "recall-at-p90": results.recall_at_p90 or 0.0, 129 | } 130 | -------------------------------------------------------------------------------- /sscd/datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import functools 8 | import logging 9 | import os.path 10 | 11 | from torchvision.datasets.folder import is_image_file 12 | from torchvision.datasets.folder import default_loader 13 | 14 | 15 | @functools.lru_cache() 16 | def get_image_paths(path): 17 | logging.info(f"Resolving files in: {path}") 18 | filenames = [f"{path}/{file}" for file in os.listdir(path)] 19 | return sorted([fn for fn in filenames if is_image_file(fn)]) 20 | 21 | 22 | class ImageFolder: 23 | """An image folder dataset intended for self-supervised learning.""" 24 | 25 | def __init__(self, path, transform=None, img_transform=None, loader=default_loader): 26 | self.files = get_image_paths(path) 27 | self.loader = loader 28 | self.transform = transform 29 | self.img_transform = img_transform 30 | 31 | def __getitem__(self, idx: int): 32 | assert 0 <= idx < len(self) 33 | img = self.loader(self.files[idx]) 34 | record = {"input": img, "instance_id": idx} 35 | if self.img_transform: 36 | record["input"] = self.img_transform(record["input"]) 37 | if self.transform: 38 | record = self.transform(record) 39 | return record 40 | 41 | def __len__(self): 42 | return len(self.files) 43 | -------------------------------------------------------------------------------- /sscd/datasets/isc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sscd/datasets/isc/descriptor_matching.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Original source: https://github.com/facebookresearch/isc2021 8 | 9 | import numpy as np 10 | import faiss 11 | from faiss.contrib import exhaustive_search 12 | import logging 13 | 14 | from .metrics import PredictedMatch 15 | 16 | 17 | def query_iterator(xq): 18 | """produces batches of progressively increasing sizes""" 19 | nq = len(xq) 20 | bs = 32 21 | i = 0 22 | while i < nq: 23 | xqi = xq[i : i + bs] 24 | yield xqi 25 | if bs < 20000: 26 | bs *= 2 27 | i += len(xqi) 28 | 29 | 30 | ######################### 31 | # These two functions are there because current Faiss contrib 32 | # does not proporly support IP search 33 | ######################### 34 | 35 | 36 | def threshold_radius_nres_IP(nres, dis, ids, thresh): 37 | """select a set of results""" 38 | mask = dis > thresh 39 | new_nres = np.zeros_like(nres) 40 | o = 0 41 | for i, nr in enumerate(nres): 42 | nr = int(nr) # avoid issues with int64 + uint64 43 | new_nres[i] = mask[o : o + nr].sum() 44 | o += nr 45 | return new_nres, dis[mask], ids[mask] 46 | 47 | 48 | def apply_maxres_IP(res_batches, target_nres): 49 | """find radius that reduces number of results to target_nres, and 50 | applies it in-place to the result batches used in range_search_max_results""" 51 | alldis = np.hstack([dis for _, dis, _ in res_batches]) 52 | alldis.partition(len(alldis) - target_nres) 53 | radius = alldis[-target_nres] 54 | 55 | LOG = logging.getLogger(exhaustive_search.__name__) 56 | 57 | if alldis.dtype == "float32": 58 | radius = float(radius) 59 | else: 60 | radius = int(radius) 61 | LOG.debug(" setting radius to %s" % radius) 62 | totres = 0 63 | for i, (nres, dis, ids) in enumerate(res_batches): 64 | nres, dis, ids = threshold_radius_nres_IP(nres, dis, ids, radius) 65 | totres += len(dis) 66 | res_batches[i] = nres, dis, ids 67 | LOG.debug(" updated previous results, new nb results %d" % totres) 68 | return radius, totres 69 | 70 | 71 | def search_with_capped_res(xq, xb, num_results, metric=faiss.METRIC_L2): 72 | """ 73 | Searches xq into xb, with a maximum total number of results 74 | """ 75 | index = faiss.IndexFlat(xb.shape[1], metric) 76 | index.add(xb) 77 | # logging.basicConfig() 78 | # logging.getLogger(exhaustive_search.__name__).setLevel(logging.DEBUG) 79 | 80 | if metric == faiss.METRIC_INNER_PRODUCT: 81 | # this is a very ugly hack because contrib.exhaustive_search does 82 | # not support IP search correctly. Do not use in a multithreaded env. 
83 | apply_maxres_saved = exhaustive_search.apply_maxres 84 | exhaustive_search.apply_maxres = apply_maxres_IP 85 | 86 | radius, lims, dis, ids = exhaustive_search.range_search_max_results( 87 | index, 88 | query_iterator(xq), 89 | 1e10 90 | if metric == faiss.METRIC_L2 91 | else -1e10, # initial radius does not filter anything 92 | max_results=2 * num_results, 93 | min_results=num_results, 94 | ngpu=-1, # use GPU if available 95 | ) 96 | 97 | if metric == faiss.METRIC_INNER_PRODUCT: 98 | exhaustive_search.apply_maxres = apply_maxres_saved 99 | 100 | n = len(dis) 101 | nq = len(xq) 102 | if n > num_results: 103 | # crop to num_results exactly 104 | if metric == faiss.METRIC_L2: 105 | o = dis.argpartition(num_results)[:num_results] 106 | else: 107 | o = dis.argpartition(len(dis) - num_results)[-num_results:] 108 | mask = np.zeros(n, bool) 109 | mask[o] = True 110 | new_dis = dis[mask] 111 | new_ids = ids[mask] 112 | nres = [0] + [mask[lims[i] : lims[i + 1]].sum() for i in range(nq)] 113 | new_lims = np.cumsum(nres) 114 | lims, dis, ids = new_lims, new_dis, new_ids 115 | 116 | return lims, dis, ids 117 | 118 | 119 | def match_and_make_predictions( 120 | xq, query_image_ids, xb, db_image_ids, num_results, ngpu=-1, metric=faiss.METRIC_L2 121 | ): 122 | lims, dis, ids = search_with_capped_res(xq, xb, num_results, metric=metric) 123 | nq = len(xq) 124 | 125 | if metric == faiss.METRIC_L2: 126 | # use negated distances as scores 127 | dis = -dis 128 | 129 | predictions = [ 130 | PredictedMatch(query_image_ids[i], db_image_ids[ids[j]], dis[j]) 131 | for i in range(nq) 132 | for j in range(lims[i], lims[i + 1]) 133 | ] 134 | return predictions 135 | 136 | 137 | def knn_match_and_make_predictions( 138 | xq, query_image_ids, xb, db_image_ids, k, ngpu=-1, metric=faiss.METRIC_L2 139 | ): 140 | if ngpu == 0 or faiss.get_num_gpus() == 0: 141 | D, I = faiss.knn(xq, xb, k, metric) 142 | else: 143 | d = xq.shape[1] 144 | index = faiss.IndexFlat(d, metric) 145 | index.add(xb) 146 | index = faiss.index_cpu_to_all_gpus(index) 147 | D, I = index.search(xq, k=k) 148 | nq = len(xq) 149 | 150 | if metric == faiss.METRIC_L2: 151 | # use negated distances as scores 152 | D = -D 153 | 154 | predictions = [ 155 | PredictedMatch(query_image_ids[i], db_image_ids[I[i, j]], D[i, j]) 156 | for i in range(nq) 157 | for j in range(k) 158 | ] 159 | return predictions 160 | 161 | 162 | def range_result_read(fname): 163 | """read the range search result file format""" 164 | f = open(fname, "rb") 165 | nq, total_res = np.fromfile(f, count=2, dtype="int32") 166 | nres = np.fromfile(f, count=nq, dtype="int32") 167 | assert nres.sum() == total_res 168 | I = np.fromfile(f, count=total_res, dtype="int32") 169 | return nres, I 170 | -------------------------------------------------------------------------------- /sscd/datasets/isc/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Original source: https://github.com/facebookresearch/isc2021 8 | 9 | from typing import List 10 | 11 | import numpy as np 12 | 13 | from .metrics import GroundTruthMatch 14 | 15 | 16 | def read_ground_truth(filename: str) -> List[GroundTruthMatch]: 17 | """ 18 | Read groundtruth csv file. 19 | Must contain query_image_id,db_image_id on each line. 
20 | Handles both the headerless format and the variant with a "query_id,reference_id" header row. 21 | """ 22 | gt_pairs = [] 23 | with open(filename, "r") as cfile: 24 | for line in cfile: 25 | line = line.strip() 26 | if line == "query_id,reference_id": 27 | continue 28 | q, db = line.split(",") 29 | if db == "": 30 | continue 31 | gt_pairs.append(GroundTruthMatch(q, db)) 32 | return gt_pairs 33 | -------------------------------------------------------------------------------- /sscd/datasets/isc/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Original source: https://github.com/facebookresearch/isc2021 8 | 9 | from dataclasses import astuple, dataclass 10 | from typing import List, Optional, Tuple 11 | from collections import defaultdict 12 | import numpy as np 13 | 14 | 15 | @dataclass 16 | class GroundTruthMatch: 17 | query: str 18 | db: str 19 | 20 | 21 | @dataclass 22 | class PredictedMatch: 23 | query: str 24 | db: str 25 | score: float 26 | 27 | 28 | @dataclass 29 | class Metrics: 30 | average_precision: float 31 | precisions: np.ndarray 32 | recalls: np.ndarray 33 | thresholds: np.ndarray 34 | recall_at_p90: float 35 | threshold_at_p90: float 36 | recall_at_rank1: float 37 | recall_at_rank10: float 38 | 39 | 40 | def argsort(seq): 41 | # from https://stackoverflow.com/a/3382369/3853462 42 | return sorted(range(len(seq)), key=seq.__getitem__) 43 | 44 | 45 | def precision_recall( 46 | y_true: np.ndarray, probas_pred: np.ndarray, num_positives: int 47 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 48 | """ 49 | Compute precisions, recalls and thresholds. 50 | 51 | Parameters 52 | ---------- 53 | y_true : np.ndarray 54 | Binary label of each prediction (0 or 1). Shape [n, k] or [n*k, ] 55 | probas_pred : np.ndarray 56 | Score of each prediction (higher score == images more similar, i.e. not a distance) 57 | Shape [n, k] or [n*k, ] 58 | num_positives : int 59 | Number of positives in the ground truth. 60 | 61 | Returns 62 | ------- 63 | precisions, recalls, thresholds 64 | ordered by increasing recall. 65 | """ 66 | probas_pred = probas_pred.flatten() 67 | y_true = y_true.flatten() 68 | # To handle duplicate scores, we sort predictions by (score, NOT(judgement)); 69 | # e.g. the final order will be (0.5, False), (0.5, False), (0.5, True), (0.4, False), ... 70 | # This yields the worst possible AP for tied scores. 71 | # It prevents participants from assigning the same score to all predictions to get a good AP. 72 | order = argsort(list(zip(probas_pred, ~y_true))) 73 | order = order[::-1] # sort by decreasing score 74 | probas_pred = probas_pred[order] 75 | y_true = y_true[order] 76 | 77 | ntp = np.cumsum(y_true) # number of true positives <= threshold 78 | nres = np.arange(len(y_true)) + 1 # number of results 79 | 80 | precisions = ntp / nres 81 | recalls = ntp / num_positives 82 | return precisions, recalls, probas_pred 83 | 84 | 85 | def average_precision_old(recalls: np.ndarray, precisions: np.ndarray): 86 | """ 87 | Compute the micro average-precision score (uAP). 88 | 89 | Parameters 90 | ---------- 91 | recalls : np.ndarray 92 | Recalls, can be in any order. 93 | precisions : np.ndarray 94 | Precisions for each recall value.
95 | 96 | Returns 97 | ------- 98 | uAP: float 99 | """ 100 | 101 | # Order by increasing recall 102 | order = np.argsort(recalls) 103 | recalls = recalls[order] 104 | precisions = precisions[order] 105 | return ((recalls[1:] - recalls[:-1]) * precisions[:-1]).sum() 106 | 107 | 108 | # Jay Qi's version 109 | def average_precision(recalls: np.ndarray, precisions: np.ndarray): 110 | # Order by increasing recall 111 | # order = np.argsort(recalls) 112 | # recalls = recalls[order] 113 | # precisions = precisions[order] 114 | 115 | # Check that it's ordered by increasing recall 116 | if not np.all(recalls[:-1] <= recalls[1:]): 117 | raise Exception("recalls array must be sorted before passing in") 118 | 119 | return ((recalls - np.concatenate([[0], recalls[:-1]])) * precisions).sum() 120 | 121 | 122 | def find_operating_point( 123 | x: np.ndarray, y: np.ndarray, z: np.ndarray, required_x: float 124 | ) -> Tuple[float, Optional[float], Optional[float]]: 125 | """ 126 | Find the highest y with x at least `required_x`. 127 | 128 | Returns 129 | ------- 130 | x, y, z 131 | The best operating point (highest y) with x at least `required_x`. 132 | If we can't find a point with the required x value, return 133 | x=required_x, y=None, z=None 134 | """ 135 | valid_points = x >= required_x 136 | if not np.any(valid_points): 137 | return required_x, None, None 138 | 139 | valid_x = x[valid_points] 140 | valid_y = y[valid_points] 141 | valid_z = z[valid_points] 142 | best_idx = np.argmax(valid_y) 143 | return valid_x[best_idx], valid_y[best_idx], valid_z[best_idx] 144 | 145 | 146 | def check_duplicates(predictions: List[PredictedMatch]) -> List[PredictedMatch]: 147 | """ 148 | Raise an exception if predictions contains duplicates 149 | (ie several predictions for the same (query, db) pair). 
150 | """ 151 | unique_pairs = set((p.query, p.db) for p in predictions) 152 | if len(unique_pairs) != len(predictions): 153 | raise ValueError("Predictions contains duplicates.") 154 | 155 | 156 | def sanitize_predictions(predictions: List[PredictedMatch]) -> List[PredictedMatch]: 157 | # TODO(lowik) check for other possible loopholes 158 | check_duplicates(predictions) 159 | return predictions 160 | 161 | 162 | def to_arrays(gt_matches: List[GroundTruthMatch], predictions: List[PredictedMatch]): 163 | """Convert from list of matches to arrays""" 164 | predictions = sanitize_predictions(predictions) 165 | 166 | gt_set = {astuple(g) for g in gt_matches} 167 | probas_pred = np.array([p.score for p in predictions]) 168 | y_true = np.array([(p.query, p.db) in gt_set for p in predictions], dtype=bool) 169 | return y_true, probas_pred 170 | 171 | 172 | def find_tp_ranks( 173 | gt_matches: List[GroundTruthMatch], predictions: List[PredictedMatch] 174 | ): 175 | q_to_res = defaultdict(list) 176 | for p in predictions: 177 | q_to_res[p.query].append(p) 178 | ranks = [] 179 | not_found = int(1 << 35) 180 | for m in gt_matches: 181 | if m.query not in q_to_res: 182 | ranks.append(not_found) 183 | continue 184 | res = q_to_res[m.query] 185 | res = np.array([(p.score, m.db == p.db) for p in res]) 186 | (i,) = np.where(res[:, 1] == 1) 187 | if i.size == 0: 188 | ranks.append(not_found) 189 | else: 190 | i = i[0] 191 | rank = (res[:, 0] >= res[i, 0]).sum() - 1 192 | ranks.append(rank) 193 | return np.array(ranks) 194 | 195 | 196 | def evaluate( 197 | gt_matches: List[GroundTruthMatch], predictions: List[PredictedMatch] 198 | ) -> Metrics: 199 | predictions = sanitize_predictions(predictions) 200 | y_true, probas_pred = to_arrays(gt_matches, predictions) 201 | p, r, t = precision_recall(y_true, probas_pred, len(gt_matches)) 202 | ap = average_precision(r, p) 203 | pp90, rp90, tp90 = find_operating_point(p, r, t, required_x=0.9) # @Precision=90% 204 | ranks = find_tp_ranks(gt_matches, predictions) 205 | recall_at_rank1 = (ranks == 0).sum() / ranks.size 206 | recall_at_rank10 = (ranks < 10).sum() / ranks.size 207 | 208 | return Metrics( 209 | average_precision=ap, 210 | precisions=p, 211 | recalls=r, 212 | thresholds=t, 213 | recall_at_p90=rp90, 214 | threshold_at_p90=tp90, 215 | recall_at_rank1=recall_at_rank1, 216 | recall_at_rank10=recall_at_rank10, 217 | ) 218 | 219 | 220 | def print_metrics(metrics: Metrics): 221 | print(f"Average Precision: {metrics.average_precision:.5f}") 222 | if metrics.recall_at_p90 is None: 223 | print("Does not reach P90") 224 | else: 225 | print(f"Recall at P90 : {metrics.recall_at_p90:.5f}") 226 | print(f"Threshold at P90 : {metrics.threshold_at_p90:g}") 227 | print(f"Recall at rank 1: {metrics.recall_at_rank1:.5f}") 228 | print(f"Recall at rank 10: {metrics.recall_at_rank10:.5f}") 229 | -------------------------------------------------------------------------------- /sscd/disc_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import argparse 9 | import dataclasses 10 | import json 11 | import logging 12 | import os 13 | from typing import Optional 14 | 15 | import faiss 16 | import torch 17 | import numpy as np 18 | from numpy import linalg 19 | import pandas as pd 20 | 21 | from lib import initialize # noqa 22 | from lib.inference import Inference 23 | from sscd.train import DISCData 24 | from sscd.datasets.disc import DISCEvalDataset 25 | from sscd.lib.util import parse_bool 26 | 27 | parser = argparse.ArgumentParser() 28 | inference_parser = parser.add_argument_group("Inference") 29 | Inference.add_parser_args(inference_parser) 30 | 31 | disc_parser = parser.add_argument_group("DISC") 32 | disc_parser.add_argument("--disc_path", required=True) 33 | disc_parser.add_argument( 34 | "--codecs", 35 | default=None, 36 | help="FAISS codecs for postprocessing embeddings as ';' separated strings " 37 | "in index_factory format", 38 | ) 39 | disc_parser.add_argument( 40 | "--score_norm", 41 | default="1.0[0,2]", 42 | help="Score normalization settings, ';' separated, in format: " 43 | "<weight>[<start_index>,<end_index>]", 44 | ) 45 | disc_parser.add_argument("--k", default=10, type=int) 46 | disc_parser.add_argument( 47 | "--global_candidates", 48 | default=False, 49 | type=parse_bool, 50 | help="Use a global set of KNN candidates, instead of k per query. Uses CPU KNN.", 51 | ) 52 | disc_parser.add_argument("--metadata", help="Metadata column to put in the result CSV") 53 | 54 | logging.basicConfig( 55 | format="%(asctime)s %(levelname)-8s %(message)s", 56 | level=logging.WARNING, 57 | datefmt="%Y-%m-%d %H:%M:%S", 58 | ) 59 | logger = logging.getLogger("disc_eval.py") 60 | logger.setLevel(logging.INFO) 61 | 62 | 63 | class ProjectionError(Exception): 64 | """Projection returned non-finite values.""" 65 | 66 | 67 | def get_codecs(dims, is_l2_normalized, codecs_arg): 68 | if codecs_arg: 69 | return codecs_arg.split(";") 70 | if is_l2_normalized: 71 | return ["Flat", f"PCAW{dims},L2norm,Flat"] 72 | return ["Flat", "L2norm,Flat", f"PCAW{dims},L2norm,Flat", f"L2norm,PCAW{dims},Flat"] 73 | 74 | 75 | def is_l2_normalized(embeddings): 76 | norms = linalg.norm(embeddings, axis=1) 77 | return np.abs(norms - 1).mean() < 0.01 78 | 79 | 80 | @dataclasses.dataclass 81 | class ScoreNormalization: 82 | weight: float 83 | start_index: int 84 | end_index: int 85 | 86 | @classmethod 87 | def parse(cls, spec): 88 | weight, spec = spec.split("[", 1) 89 | assert spec.endswith("]") 90 | spec = spec[:-1] 91 | if "," in spec: 92 | start, end = spec.split(",", 1) 93 | else: 94 | start = spec 95 | end = spec 96 | return cls(weight=float(weight), start_index=int(start), end_index=int(end)) 97 | 98 | def __str__(self): 99 | return f"{self.weight:.2f}[{self.start_index},{self.end_index}]" 100 | 101 | __repr__ = __str__ 102 | 103 | 104 | @dataclasses.dataclass 105 | class Embeddings: 106 | ids: np.ndarray 107 | embeddings: np.ndarray 108 | 109 | @property 110 | def size(self): 111 | return self.embeddings.shape[0] 112 | 113 | @property 114 | def dims(self): 115 | return self.embeddings.shape[1] 116 | 117 | def project(self, codec_index, codec_str) -> "Embeddings": 118 | projected = codec_index.sa_encode(self.embeddings) 119 | projected = np.frombuffer(projected, dtype=np.float32).reshape(self.size, -1) 120 | if not np.isfinite(projected).all(): 121 | raise ProjectionError( 122 | f"Projection to {codec_str} resulted in non-finite values" 123 | ) 124 | return dataclasses.replace(self, embeddings=projected) 125 | 126 | 127 | def dataset_split(outputs, split_id) ->
Embeddings: 128 | split = outputs["split"] 129 | this_split = split == split_id 130 | embeddings = outputs["embeddings"][this_split, :] 131 | image_num = outputs["image_num"][this_split] 132 | order = np.argsort(image_num) 133 | embeddings = embeddings[order, :] 134 | image_num = image_num[order] 135 | return Embeddings(ids=image_num, embeddings=embeddings) 136 | 137 | 138 | def evaluate_all(dataset, outputs, codecs_arg, score_norm_arg, **kwargs): 139 | embeddings = outputs["embeddings"] 140 | codecs = get_codecs(embeddings.shape[1], is_l2_normalized(embeddings), codecs_arg) 141 | logger.info("Using codecs: %s", codecs) 142 | score_norms = [None] 143 | if score_norm_arg: 144 | score_norms.extend( 145 | [ScoreNormalization.parse(spec) for spec in score_norm_arg.split(";")] 146 | ) 147 | logger.info("Using score_norm: %s", score_norms) 148 | queries = dataset_split(outputs, DISCEvalDataset.SPLIT_QUERY) 149 | refs = dataset_split(outputs, DISCEvalDataset.SPLIT_REF) 150 | training = dataset_split(outputs, DISCEvalDataset.SPLIT_TRAIN) 151 | logger.info( 152 | "Dataset size: %d query, %d ref, %d train", 153 | queries.size, 154 | refs.size, 155 | training.size, 156 | ) 157 | all_metrics = [] 158 | for score_norm in score_norms: 159 | for codec in codecs: 160 | record = dict(codec=codec, score_norm=str(score_norm)) 161 | metrics = evaluate( 162 | dataset, queries, refs, training, score_norm, codec, **kwargs 163 | ) 164 | if metrics: 165 | record.update(metrics) 166 | all_metrics.append(record) 167 | return all_metrics 168 | 169 | 170 | def project( 171 | codec_str: str, queries: Embeddings, refs: Embeddings, training: Embeddings 172 | ): 173 | if codec_str != "Flat": 174 | assert codec_str.endswith(",Flat") 175 | codec = faiss.index_factory(training.dims, codec_str) 176 | codec.train(training.embeddings) 177 | queries = queries.project(codec, codec_str) 178 | refs = refs.project(codec, codec_str) 179 | training = training.project(codec, codec_str) 180 | return queries, refs, training 181 | 182 | 183 | def evaluate( 184 | dataset: DISCEvalDataset, 185 | queries: Embeddings, 186 | refs: Embeddings, 187 | training: Embeddings, 188 | score_norm: Optional[ScoreNormalization], 189 | codec, 190 | **kwargs, 191 | ): 192 | try: 193 | queries, refs, training = project(codec, queries, refs, training) 194 | except ProjectionError as e: 195 | logger.error(f"DISC eval {codec}: {e}") 196 | return None 197 | eval_kwargs = dict(kwargs) 198 | use_gpu = torch.cuda.is_available() 199 | if score_norm: 200 | queries, refs = apply_score_norm( 201 | queries, refs, training, score_norm, use_gpu=use_gpu 202 | ) 203 | eval_kwargs["metric"] = faiss.METRIC_INNER_PRODUCT 204 | metrics = dataset.retrieval_eval_splits( 205 | queries.ids, 206 | queries.embeddings, 207 | refs.ids, 208 | refs.embeddings, 209 | use_gpu=use_gpu, 210 | **eval_kwargs, 211 | ) 212 | logger.info( 213 | f"DISC eval ({score_norm or 'no norm'}, {codec}): {json.dumps(metrics)}" 214 | ) 215 | return metrics 216 | 217 | 218 | def apply_score_norm( 219 | queries, refs, training, score_norm: ScoreNormalization, use_gpu=False 220 | ): 221 | index = faiss.IndexFlatIP(training.dims) 222 | index.add(training.embeddings) 223 | if use_gpu: 224 | index = faiss.index_cpu_to_all_gpus(index) 225 | D, I = index.search(queries.embeddings, score_norm.end_index + 1) 226 | adjustment = -score_norm.weight * np.mean( 227 | D[:, score_norm.start_index : score_norm.end_index + 1], 228 | axis=1, 229 | keepdims=True, 230 | ) 231 | ones = np.ones_like(refs.embeddings[:, :1]) 232 | 
adjusted_queries = np.concatenate([queries.embeddings, adjustment], axis=1) 233 | adjusted_refs = np.concatenate([refs.embeddings, ones], axis=1) 234 | queries = dataclasses.replace(queries, embeddings=adjusted_queries) 235 | refs = dataclasses.replace(refs, embeddings=adjusted_refs) 236 | return queries, refs 237 | 238 | 239 | def main(args): 240 | logger.info("Setting up dataset") 241 | dataset = DISCData.make_validation_dataset( 242 | args.disc_path, 243 | size=args.size, 244 | include_train=True, 245 | preserve_aspect_ratio=args.preserve_aspect_ratio, 246 | ) 247 | outputs = Inference.inference(args, dataset) 248 | logger.info("Retrieval eval") 249 | eval_options = dict(k=args.k, global_candidates=args.global_candidates) 250 | records = evaluate_all( 251 | dataset, outputs, args.codecs, args.score_norm, **eval_options 252 | ) 253 | df = pd.DataFrame(records) 254 | if args.metadata: 255 | df["metadata"] = args.metadata 256 | csv_filename = os.path.join(args.output_path, "disc_metrics.csv") 257 | df.to_csv(csv_filename, index=False) 258 | with open(csv_filename, "r") as f: 259 | logger.info("DISC metrics:\n%s", f.read()) 260 | 261 | 262 | if __name__ == "__main__": 263 | args = parser.parse_args() 264 | main(args) 265 | -------------------------------------------------------------------------------- /sscd/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sscd/lib/distributed_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import enum 8 | 9 | import torch 10 | from classy_vision.generic.distributed_util import ( 11 | all_reduce_mean, 12 | all_reduce_sum, 13 | get_world_size, 14 | get_rank, 15 | ) 16 | from torch import autograd, distributed 17 | 18 | 19 | def multi_gather_batch(*tensors): 20 | """Gather tensors across nodes / GPUs. 21 | 22 | Tensors must have the same shape on all devices. Gathering for each 23 | tensor happens in parallel. 24 | """ 25 | 26 | world_size = distributed.get_world_size() 27 | out = [] 28 | handles = [] 29 | 30 | for tensor in tensors: 31 | gathered_shape = (world_size * tensor.shape[0],) + tensor.shape[1:] 32 | 33 | gathered = torch.empty(gathered_shape, dtype=tensor.dtype, device=tensor.device) 34 | # Non-contiguous tensors seem to get scrambled. Source and dest memory layouts 35 | # may have to match. 36 | tensor = tensor.contiguous() 37 | handle = distributed.all_gather( 38 | list(torch.chunk(gathered, world_size)), tensor, async_op=True 39 | ) 40 | 41 | out.append(gathered) 42 | handles.append(handle) 43 | 44 | for handle in handles: 45 | handle.wait() 46 | 47 | return out 48 | 49 | 50 | class ReduceMethod(enum.Enum): 51 | SUM = enum.auto() 52 | MEAN = enum.auto() 53 | # No gradient aggregation (eg. where all GPUs compute the same loss) 54 | NONE = enum.auto() 55 | 56 | 57 | class _CrossGPUBatch(autograd.Function): 58 | """Aggregates embeddings and labels across GPUs. 59 | 60 | This requires that batches have the same size on each GPU. 
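Typical use goes through the cross_gpu_batch wrapper defined below, as in
SSCD.cross_gpu_similarity (sscd/train.py): all_embeddings, all_labels = cross_gpu_batch(embeddings, labels).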
61 | """ 62 | 63 | @staticmethod 64 | def forward(ctx, embeddings, target, reduce_method: ReduceMethod): 65 | ctx.n = embeddings.size(0) 66 | ctx.reduce_method = reduce_method 67 | ctx.world_size = get_world_size() 68 | if target is None: 69 | if ctx.world_size == 1: 70 | return embeddings 71 | else: 72 | return multi_gather_batch(embeddings)[0] 73 | 74 | assert ctx.n == target.size(0) 75 | if ctx.world_size == 1: 76 | ctx.mark_non_differentiable(target) 77 | return embeddings, target 78 | all_embeddings, all_target = multi_gather_batch(embeddings, target) 79 | ctx.mark_non_differentiable(all_target) 80 | return all_embeddings, all_target 81 | 82 | @staticmethod 83 | def backward(ctx, all_embeddings_gradient, ignored_target_grad=None): 84 | if ctx.world_size == 1: 85 | embeddings_gradient = all_embeddings_gradient 86 | else: 87 | # Aggregate gradients across nodes. 88 | if ctx.reduce_method == ReduceMethod.MEAN: 89 | all_reduce_mean(all_embeddings_gradient) 90 | elif ctx.reduce_method == ReduceMethod.SUM: 91 | all_reduce_sum(all_embeddings_gradient) 92 | else: 93 | # Do not accumulate. 94 | assert ctx.reduce_method == ReduceMethod.NONE 95 | rank = get_rank() 96 | start = ctx.n * rank 97 | end = start + ctx.n 98 | # Slice gradient for embeddings that belong to this node. 99 | embeddings_gradient = all_embeddings_gradient[start:end] 100 | return (embeddings_gradient, None, None) 101 | 102 | 103 | cross_gpu_batch = _CrossGPUBatch.apply 104 | 105 | 106 | def cross_gpu_batch(embeddings, targets, reduce_method=ReduceMethod.SUM): 107 | return _CrossGPUBatch.apply(embeddings, targets, reduce_method) 108 | -------------------------------------------------------------------------------- /sscd/lib/fix_paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | import sys 9 | 10 | # Make sure SSCD is in PYTHONPATH. 11 | base_path = str(Path(__file__).parent.parent.parent) 12 | 13 | if base_path not in sys.path: 14 | sys.path.append(base_path) 15 | -------------------------------------------------------------------------------- /sscd/lib/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import os 9 | import torch 10 | import pytorch_lightning as pl 11 | from classy_vision.generic.distributed_util import get_rank, get_world_size, barrier 12 | from pytorch_lightning.callbacks import BasePredictionWriter 13 | from pytorch_lightning.plugins import DDPSpawnPlugin 14 | from torch.utils.data import DataLoader 15 | from sscd.train import SSCD 16 | from sscd.models.model import Model 17 | from sscd.lib.util import call_using_args, parse_bool 18 | 19 | 20 | logger = logging.getLogger("inference.py") 21 | logger.setLevel(logging.INFO) 22 | 23 | 24 | class InferenceModel(pl.LightningModule): 25 | """Wraps a model for inference.""" 26 | 27 | def __init__(self, model, metadata_keys): 28 | super().__init__() 29 | self.model = model 30 | self.metadata_keys = metadata_keys 31 | 32 | def forward(self, x): 33 | return self.model(x) 34 | 35 | def predict_step(self, batch, batch_idx): 36 | input = batch["input"] 37 | batch = {k: v for (k, v) in batch.items() if k in self.metadata_keys} 38 | batch["embeddings"] = self(input) 39 | 40 | # Workaround for a CUDA synchronization bug in PyTorch Lightning. 41 | # Fixed upstream: 42 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/11287 43 | batch = {k: v.cpu() for (k, v) in batch.items()} 44 | 45 | return batch 46 | 47 | 48 | class Inference: 49 | @classmethod 50 | def add_parser_args(cls, parser): 51 | parser.add_argument("--checkpoint") 52 | parser.add_argument("--features") 53 | parser.add_argument("--model_state") 54 | parser.add_argument("--output_path", required=True) 55 | parser.add_argument("--gpus", default=1, type=int) 56 | parser.add_argument("--accelerator", default="auto") 57 | parser.add_argument("--nodes", default=1, type=int) 58 | parser.add_argument("--workers", default=10, type=int) 59 | parser.add_argument( 60 | "--size", default=288, type=int, help="Image size for inference" 61 | ) 62 | parser.add_argument("--preserve_aspect_ratio", default=False, type=parse_bool) 63 | # These options are only used if --model_state is provided. 
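# Specifically --backbone, --dims and --pool_param (see Model.add_arguments), which are
# needed to rebuild the architecture when loading a bare state dict rather than a
# Lightning checkpoint.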
64 | Model.add_arguments(parser) 65 | 66 | @classmethod 67 | def inference(cls, args, dataset, base_name="predictions"): 68 | if args.features: 69 | logger.info("Loading features") 70 | if os.path.exists(args.features): 71 | features_fn = args.features 72 | else: 73 | features_fn = f"{args.features}/{base_name}.pt" 74 | outputs = torch.load(features_fn, map_location=torch.device("cpu")) 75 | elif args.checkpoint or args.model_state: 76 | logger.info("Loading model") 77 | if args.checkpoint: 78 | pl_model = SSCD.load_from_checkpoint( 79 | args.checkpoint, map_location=torch.device("cpu") 80 | ) 81 | else: 82 | model = call_using_args(Model, args) 83 | state = torch.load(args.model_state, map_location=torch.device("cpu")) 84 | model.load_state_dict(state) 85 | pl_model = InferenceModel(model, ["image_num", "split", "instance_id"]) 86 | logger.info("Creating dataloader") 87 | dataloader = DataLoader( 88 | dataset, 89 | batch_size=1 if args.preserve_aspect_ratio else 256, 90 | num_workers=args.workers, 91 | persistent_workers=( 92 | args.workers > 0 93 | ), # unnecessary here, but silences warning 94 | ) 95 | writer = InferenceWriter(args.output_path, base_name) 96 | trainer = pl.Trainer( 97 | devices=args.gpus, 98 | num_nodes=args.nodes, 99 | accelerator=args.accelerator, 100 | default_root_dir=args.output_path, 101 | strategy=DDPSpawnPlugin(find_unused_parameters=False), 102 | callbacks=[writer], 103 | log_every_n_steps=1, 104 | ) 105 | logger.info("Starting inference") 106 | trainer.predict(pl_model, dataloaders=dataloader) 107 | logger.info("Loading features") 108 | outputs = writer.read() 109 | else: 110 | raise ValueError("Either --checkpoint or --features is required") 111 | 112 | logger.info("Deduplication") 113 | outputs = SSCD.dedup_outputs(outputs) 114 | return outputs 115 | 116 | 117 | def coalesce_outputs(outputs): 118 | keys = outputs[0].keys() 119 | return {k: torch.cat([out[k] for out in outputs]) for k in keys} 120 | 121 | 122 | class InferenceWriter(BasePredictionWriter): 123 | def __init__(self, output_path: str, filename: str): 124 | super().__init__("epoch") 125 | self.output_path = output_path 126 | self.filename = filename 127 | self.output_file = os.path.join(self.output_path, f"{filename}.pt") 128 | 129 | def _rank_fn(self, i): 130 | return os.path.join(self.output_path, f"{self.filename}_rank_{i}.pt") 131 | 132 | def write_on_epoch_end(self, trainer, module, predictions, batch_indices): 133 | rank = get_rank() 134 | assert len(predictions) == 1 135 | predictions = predictions[0] 136 | outputs = coalesce_outputs(predictions) 137 | logger.info( 138 | "Writing %d outputs for worker %d", outputs["embeddings"].size(0), rank 139 | ) 140 | torch.save(outputs, self._rank_fn(rank)) 141 | del outputs 142 | logger.info("Rank %d done. 
Waiting for peers.", rank) 143 | barrier() 144 | if rank == 0: 145 | logger.info("Combining prediction outputs.") 146 | worker_output_fns = [self._rank_fn(i) for i in range(get_world_size())] 147 | worker_outputs = [torch.load(fn) for fn in worker_output_fns] 148 | outputs = coalesce_outputs(worker_outputs) 149 | del worker_outputs 150 | torch.save(outputs, self.output_file) 151 | logger.info("Save completed.") 152 | for fn in worker_output_fns: 153 | os.remove(fn) 154 | 155 | def read(self): 156 | return torch.load(self.output_file) 157 | -------------------------------------------------------------------------------- /sscd/lib/initialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from . import fix_paths # noqa 8 | from . import suppress_warnings # noqa 9 | -------------------------------------------------------------------------------- /sscd/lib/suppress_warnings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | # This relates to Classy Vision transforms that we don't use. 10 | warnings.filterwarnings("ignore", module=".*_(functional|transforms)_video") 11 | # Upstream Classy Vision issue; fix hasn't reached released package. 12 | # https://github.com/facebookresearch/ClassyVision/pull/770 13 | warnings.filterwarnings("ignore", message=".*To copy construct from a tensor.*") 14 | # Lightning non-issue (warning false positive). 15 | warnings.filterwarnings("ignore", message=".*overridden after .* initialization.*") 16 | -------------------------------------------------------------------------------- /sscd/lib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import inspect 9 | from typing import Callable 10 | from distutils.util import strtobool 11 | 12 | 13 | def call_using_args(function: Callable, args: argparse.Namespace): 14 | """Calls the callable using arguments from an argparse container.""" 15 | signature = inspect.signature(function) 16 | arguments = {key: getattr(args, key) for key in signature.parameters} 17 | return function(**arguments) 18 | 19 | 20 | def parse_bool(bool_str): 21 | return bool(strtobool(bool_str)) 22 | -------------------------------------------------------------------------------- /sscd/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /sscd/models/gem_pooling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch import nn 8 | 9 | 10 | class GlobalGeMPool2d(nn.Module): 11 | """Generalized mean pooling. 12 | 13 | Inputs should be non-negative. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | pooling_param: float, 19 | ): 20 | """ 21 | Args: 22 | pooling_param: the GeM pooling parameter 23 | """ 24 | super().__init__() 25 | self.pooling_param = pooling_param 26 | 27 | def forward(self, x): 28 | N, C, H, W = x.size() 29 | x = x.reshape(N, C, H * W) # Combine spatial dimensions 30 | mean = x.clamp(min=1e-6).pow(self.pooling_param).mean(dim=2) 31 | r = 1.0 / self.pooling_param 32 | return mean.pow(r) 33 | -------------------------------------------------------------------------------- /sscd/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import enum 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torchvision.models.resnet import resnet18, resnet50, resnext101_32x8d 12 | from classy_vision.models import build_model 13 | from .gem_pooling import GlobalGeMPool2d 14 | 15 | 16 | class Implementation(enum.Enum): 17 | CLASSY_VISION = enum.auto() 18 | TORCHVISION = enum.auto() 19 | 20 | 21 | class Backbone(enum.Enum): 22 | CV_RESNET18 = ("resnet18", 512, Implementation.CLASSY_VISION) 23 | CV_RESNET50 = ("resnet50", 2048, Implementation.CLASSY_VISION) 24 | CV_RESNEXT101 = ("resnext101_32x4d", 2048, Implementation.CLASSY_VISION) 25 | 26 | TV_RESNET18 = (resnet18, 512, Implementation.TORCHVISION) 27 | TV_RESNET50 = (resnet50, 2048, Implementation.TORCHVISION) 28 | TV_RESNEXT101 = (resnext101_32x8d, 2048, Implementation.TORCHVISION) 29 | 30 | def build(self, dims: int): 31 | impl = self.value[2] 32 | if impl == Implementation.CLASSY_VISION: 33 | model = build_model({"name": self.value[0]}) 34 | # Remove head exec wrapper, which we don't need, and breaks pickling 35 | # (needed for spawn dataloaders). 
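# The returned trunk is expected to produce a 4-D (N, C, H, W) feature map; for Classy
# Vision backbones, Model.__init__ below attaches GlobalGeMPool2d, a Linear projection
# and L2Norm on top of it.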
36 | return model.classy_model 37 | if impl == Implementation.TORCHVISION: 38 | return self.value[0](num_classes=dims, zero_init_residual=True) 39 | raise AssertionError("Model implementation not handled: %s" % (self.name,)) 40 | 41 | 42 | class L2Norm(nn.Module): 43 | def forward(self, x): 44 | return F.normalize(x) 45 | 46 | 47 | class Model(nn.Module): 48 | def __init__(self, backbone: str, dims: int, pool_param: float): 49 | super().__init__() 50 | self.backbone_type = Backbone[backbone] 51 | self.backbone = self.backbone_type.build(dims=dims) 52 | impl = self.backbone_type.value[2] 53 | if impl == Implementation.CLASSY_VISION: 54 | self.embeddings = nn.Sequential( 55 | GlobalGeMPool2d(pool_param), 56 | nn.Linear(self.backbone_type.value[1], dims), 57 | L2Norm(), 58 | ) 59 | elif impl == Implementation.TORCHVISION: 60 | if pool_param > 1: 61 | self.backbone.avgpool = GlobalGeMPool2d(pool_param) 62 | fc = self.backbone.fc 63 | nn.init.xavier_uniform_(fc.weight) 64 | nn.init.constant_(fc.bias, 0) 65 | self.embeddings = L2Norm() 66 | 67 | def forward(self, x): 68 | x = self.backbone(x) 69 | return self.embeddings(x) 70 | 71 | @classmethod 72 | def add_arguments(cls, parser: argparse.ArgumentParser): 73 | parser = parser.add_argument_group("Model") 74 | parser.add_argument( 75 | "--backbone", default="TV_RESNET50", choices=[b.name for b in Backbone] 76 | ) 77 | parser.add_argument("--dims", default=512, type=int) 78 | parser.add_argument("--pool_param", default=3, type=float) 79 | -------------------------------------------------------------------------------- /sscd/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from argparse import ArgumentParser 9 | import logging 10 | import os 11 | import numpy as np 12 | import pytorch_lightning as pl 13 | import torch 14 | 15 | # Set up our python environment (eg. PYTHONPATH). 
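# (lib.initialize simply imports lib.fix_paths, which appends the repository root to
# sys.path, and lib.suppress_warnings, which filters a few known-noisy warnings.)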
16 | from lib import initialize # noqa 17 | 18 | from classy_vision.dataset.transforms import build_transforms 19 | from pytorch_lightning.callbacks import LearningRateMonitor 20 | from pytorch_lightning.plugins import DDPSpawnPlugin 21 | from pytorch_lightning.utilities.apply_func import move_data_to_device 22 | from pl_bolts.optimizers.lars import LARS 23 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 24 | from torch.utils.data import DataLoader 25 | 26 | from sscd.datasets.disc import DISCEvalDataset 27 | from sscd.datasets.image_folder import ImageFolder 28 | from sscd.lib.distributed_util import cross_gpu_batch 29 | from sscd.transforms.repeated_augmentation import RepeatedAugmentationTransform 30 | from sscd.transforms.mixup import ContrastiveMixup 31 | from sscd.models.model import Model 32 | from sscd.transforms.settings import AugmentationSetting 33 | from sscd.lib.util import call_using_args, parse_bool 34 | 35 | DEBUG = False 36 | 37 | if DEBUG: 38 | logging.basicConfig( 39 | format="%(asctime)s %(levelname)-8s %(message)s", 40 | level=logging.INFO, 41 | datefmt="%Y-%m-%d %H:%M:%S", 42 | ) 43 | 44 | 45 | def add_train_args(parser: ArgumentParser, required=True): 46 | parser = parser.add_argument_group("Train") 47 | parser.add_argument("--gpus", default=1, type=int) 48 | parser.add_argument("--accelerator", default="auto") 49 | parser.add_argument("--nodes", default=1, type=int) 50 | parser.add_argument( 51 | "--batch_size", 52 | default=4096, 53 | type=int, 54 | help="The global batch size (across all nodes/GPUs, before repeated augmentation)", 55 | ) 56 | parser.add_argument("--infonce_temperature", default=0.05, type=float) 57 | parser.add_argument("--entropy_weight", default=30, type=float) 58 | parser.add_argument("--epochs", default=100, type=int) 59 | parser.add_argument( 60 | "--augmentations", 61 | default="STRONG_BLUR", 62 | choices=[x.name for x in AugmentationSetting], 63 | ) 64 | parser.add_argument("--warmup_epochs", default=5, type=int) 65 | parser.add_argument( 66 | "--base_learning_rate", 67 | default=0.3, 68 | type=float, 69 | help="Base learning rate, for a batch size of 256. Linear scaling is applied.", 70 | ) 71 | parser.add_argument( 72 | "--absolute_learning_rate", 73 | default=None, 74 | type=float, 75 | help="Absolute learning rate (overrides --base_learning_rate).", 76 | ) 77 | parser.add_argument("--weight_decay", default=1e-6, type=float) 78 | parser.add_argument("--momentum", default=0.9, type=float) 79 | parser.add_argument("--sync_bn", default=True, type=parse_bool) 80 | parser.add_argument("--mixup", default=False, type=parse_bool) 81 | parser.add_argument( 82 | "--train_image_size", default=224, type=int, help="Image size for training" 83 | ) 84 | parser.add_argument( 85 | "--val_image_size", default=288, type=int, help="Image size for validation" 86 | ) 87 | parser.add_argument( 88 | "--workers", default=10, type=int, help="Data loader workers per GPU process." 
89 | ) 90 | parser.add_argument("--num_sanity_val_steps", default=-1, type=int) 91 | 92 | 93 | def add_data_args(parser, required=True): 94 | parser = parser.add_argument_group("Data") 95 | parser.add_argument("--output_path", required=required, type=str) 96 | parser.add_argument("--train_dataset_path", required=required, type=str) 97 | parser.add_argument("--val_dataset_path", required=False, type=str) 98 | 99 | 100 | parser = ArgumentParser() 101 | Model.add_arguments(parser) 102 | add_train_args(parser) 103 | add_data_args(parser) 104 | 105 | 106 | class DISCData(pl.LightningDataModule): 107 | """A data module describing datasets used during training.""" 108 | 109 | def __init__( 110 | self, 111 | *, 112 | train_dataset_path, 113 | val_dataset_path, 114 | train_batch_size, 115 | augmentations: AugmentationSetting, 116 | train_image_size=224, 117 | val_image_size=288, 118 | val_batch_size=256, 119 | workers=10, 120 | ): 121 | super().__init__() 122 | self.train_batch_size = train_batch_size 123 | self.val_batch_size = val_batch_size 124 | self.train_dataset_path = train_dataset_path 125 | self.val_dataset_path = val_dataset_path 126 | self.workers = workers 127 | transforms = augmentations.get_transformations(train_image_size) 128 | transforms = RepeatedAugmentationTransform(transforms, copies=2) 129 | self.train_dataset = ImageFolder(self.train_dataset_path, transform=transforms) 130 | if val_dataset_path: 131 | self.val_dataset = self.make_validation_dataset( 132 | self.val_dataset_path, 133 | self.val_batch_size, 134 | size=val_image_size, 135 | ) 136 | else: 137 | self.val_dataset = None 138 | 139 | @classmethod 140 | def make_validation_dataset( 141 | cls, 142 | path, 143 | include_train=False, 144 | size=288, 145 | preserve_aspect_ratio=False, 146 | ): 147 | transforms = build_transforms( 148 | [ 149 | { 150 | "name": "Resize", 151 | "size": size if preserve_aspect_ratio else [size, size], 152 | }, 153 | {"name": "ToTensor"}, 154 | { 155 | "name": "Normalize", 156 | "mean": [0.485, 0.456, 0.406], 157 | "std": [0.229, 0.224, 0.225], 158 | }, 159 | ] 160 | ) 161 | return DISCEvalDataset(path, transform=transforms, include_train=include_train) 162 | 163 | def train_dataloader(self): 164 | return DataLoader( 165 | self.train_dataset, 166 | batch_size=self.train_batch_size, 167 | num_workers=self.workers, 168 | persistent_workers=True, 169 | shuffle=True, 170 | drop_last=True, 171 | ) 172 | 173 | def val_dataloader(self): 174 | if not self.val_dataset: 175 | return None 176 | return DataLoader( 177 | self.val_dataset, 178 | batch_size=self.val_batch_size, 179 | num_workers=self.workers, 180 | persistent_workers=True, 181 | ) 182 | 183 | 184 | class SSCD(pl.LightningModule): 185 | """Training class for SSCD models.""" 186 | 187 | def __init__(self, args, train_steps: int): 188 | super().__init__() 189 | self.save_hyperparameters() 190 | self.model = call_using_args(Model, args) 191 | use_mixup = args.mixup 192 | self.mixup = ( 193 | ContrastiveMixup.from_config( 194 | { 195 | "name": "contrastive_mixup", 196 | "mix_prob": 0.05, 197 | "mixup_alpha": 2, 198 | "cutmix_alpha": 2, 199 | "switch_prob": 0.5, 200 | "repeated_augmentations": 2, 201 | "target_column": "instance_id", 202 | } 203 | ) 204 | if use_mixup 205 | else None 206 | ) 207 | self.infonce_temperature = args.infonce_temperature 208 | self.entropy_weight = args.entropy_weight 209 | self.epochs = args.epochs 210 | self.warmup_epochs = args.warmup_epochs 211 | self.lr = args.absolute_learning_rate or ( 212 | 
args.base_learning_rate * args.batch_size / 256 213 | ) 214 | self.weight_decay = args.weight_decay 215 | self.momentum = args.momentum 216 | self.train_steps = train_steps 217 | 218 | def forward(self, x): 219 | return self.model(x) 220 | 221 | def configure_optimizers(self): 222 | optimizer = LARS( 223 | self.parameters(), 224 | lr=self.lr, 225 | weight_decay=self.weight_decay, 226 | momentum=self.momentum, 227 | trust_coefficient=0.001, 228 | eps=1e-8, 229 | ) 230 | scheduler = { 231 | "scheduler": LinearWarmupCosineAnnealingLR( 232 | optimizer, 233 | warmup_epochs=self.warmup_epochs * self.train_steps, 234 | max_epochs=self.epochs * self.train_steps, 235 | ), 236 | "interval": "step", 237 | "frequency": 1, 238 | } 239 | return [optimizer], [scheduler] 240 | 241 | def training_step(self, batch, batch_idx): 242 | # Concatenate copies. 243 | input = torch.cat([batch["input0"], batch["input1"]]) 244 | instance_ids = torch.cat([batch["instance_id"], batch["instance_id"]]) 245 | if self.mixup: 246 | batch = self.mixup({"input": input, "instance_id": instance_ids}) 247 | input = batch["input"] 248 | instance_ids = batch["instance_id"] 249 | embeddings = self(input) 250 | return self.loss(embeddings, instance_ids) 251 | 252 | def cross_gpu_similarity(self, embeddings, instance_labels, mixup: bool): 253 | """Compute a cross-GPU embedding similarity matrix. 254 | 255 | Embeddings are gathered differentiably, via an autograd function. 256 | 257 | Returns a tuple of similarity, match_matrix, and indentity tensors, 258 | defined as follows, where N is the batch size (including copies), and 259 | W is the world size: 260 | similarity (N, W*N), float: embedding inner-product similarity between 261 | this GPU's embeddings, and all embeddings in the global 262 | (cross-GPU) batch. 263 | match_matrix (N, W*N), bool: cell [i][j] is True when batch[i] and 264 | global_batch[j] share input content (take any content from the 265 | same original image), including trivial pairs (comparing a 266 | copy with itself). 267 | identity (N, W*N), bool: cell [i][j] is True only for trivial pairs 268 | (comparing a copy with itself). This identifies the "diagonal" 269 | in the global (virtual) W*N x W*N similarity matrix. Since each 270 | GPU has only a slice of this matrix (to avoid N^2 memory use), 271 | the "diagonal" requires a bit of logic to identify. 272 | """ 273 | all_embeddings, all_instance_labels = cross_gpu_batch( 274 | embeddings, instance_labels 275 | ) 276 | N = embeddings.size(0) 277 | M = all_embeddings.size(0) 278 | R = self.global_rank 279 | similarity = embeddings.matmul(all_embeddings.transpose(0, 1)) 280 | if mixup: 281 | # In the mixup case, instance_labels are a NxN distribution 282 | # describing similarity within a per-GPU batch. We infer all inputs 283 | # from other GPUs are negatives, and use any inputs with nonzero 284 | # similarity as positives. 285 | match_matrix = torch.zeros( 286 | (N, M), dtype=torch.bool, device=embeddings.device 287 | ) 288 | match_matrix[:, R * N : (R + 1) * N] = ( 289 | instance_labels.matmul(instance_labels.transpose(0, 1)) > 0 290 | ) 291 | else: 292 | # In the non-mixup case, instance_labels are instance ID long ints. 293 | # We broadcast a `==` operation to translate this to a match matrix. 
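# For example (hypothetical values): with instance_labels = [3, 7] on this GPU and
# all_instance_labels = [3, 7, 7, 9] gathered across GPUs, the comparison below yields
# match_matrix = [[True, False, False, False], [False, True, True, False]].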
294 | match_matrix = instance_labels.unsqueeze( 295 | 1 296 | ) == all_instance_labels.unsqueeze(0) 297 | identity = torch.zeros_like(match_matrix) 298 | identity[:, R * N : (R + 1) * N] = torch.eye(N).to(identity) 299 | return similarity, match_matrix, identity 300 | 301 | def loss(self, embeddings, instance_labels): 302 | similarity, match_matrix, identity = self.cross_gpu_similarity( 303 | embeddings, instance_labels, self.mixup 304 | ) 305 | non_matches = match_matrix == 0 306 | nontrivial_matches = match_matrix * (~identity) 307 | 308 | # InfoNCE loss 309 | small_value = torch.tensor(-100.0).to( 310 | similarity 311 | ) # any value > max L2 normalized distance 312 | max_non_match_sim, _ = torch.where(non_matches, similarity, small_value).max( 313 | dim=1, keepdim=True 314 | ) 315 | logits = (similarity / self.infonce_temperature).exp() 316 | partitions = logits + ((non_matches * logits).sum(dim=1) + 1e-6).unsqueeze(1) 317 | probabilities = logits / partitions 318 | if self.mixup: 319 | infonce_loss = ( 320 | (-probabilities.log() * nontrivial_matches).sum(dim=1) 321 | / nontrivial_matches.sum(dim=1) 322 | ).mean() 323 | else: 324 | infonce_loss = ( 325 | -probabilities.log() * nontrivial_matches 326 | ).sum() / similarity.size(0) 327 | 328 | components = {"InfoNCE": infonce_loss} 329 | loss = infonce_loss 330 | if self.entropy_weight: 331 | # Differential entropy regularization loss. 332 | closest_distance = (2 - (2 * max_non_match_sim)).clamp(min=1e-6).sqrt() 333 | entropy_loss = -closest_distance.log().mean() * self.entropy_weight 334 | components["entropy"] = entropy_loss 335 | loss = infonce_loss + entropy_loss 336 | 337 | # Log stats and loss components. 338 | with torch.no_grad(): 339 | stats = { 340 | "positive_sim": (similarity * nontrivial_matches).sum() 341 | / nontrivial_matches.sum(), 342 | "negative_sim": (similarity * non_matches).sum() / non_matches.sum(), 343 | "nearest_negative_sim": max_non_match_sim.mean(), 344 | "center_l2_norm": embeddings.mean(dim=0).pow(2).sum().sqrt(), 345 | } 346 | self.log_dict(stats, on_step=False, on_epoch=True) 347 | self.log_dict(components, on_step=True, on_epoch=True, prog_bar=True) 348 | if self.logger: 349 | self.logger.experiment.add_scalars( 350 | "loss components", 351 | components, 352 | global_step=self.global_step, 353 | ) 354 | self.logger.experiment.add_scalars( 355 | "similarity stats", 356 | stats, 357 | global_step=self.global_step, 358 | ) 359 | 360 | return loss 361 | 362 | def validation_step(self, batch, batch_idx): 363 | input = batch["input"] 364 | metadata_keys = ["image_num", "split", "instance_id"] 365 | batch = {k: v for (k, v) in batch.items() if k in metadata_keys} 366 | batch["embeddings"] = self(input) 367 | return batch 368 | 369 | def validation_epoch_end(self, outputs): 370 | keys = ["embeddings", "image_num", "split", "instance_id"] 371 | outputs = {k: torch.cat([out[k] for out in outputs]) for k in keys} 372 | outputs = self._gather(outputs) 373 | outputs = self.dedup_outputs(outputs) 374 | if self.current_epoch == 0: 375 | self.print( 376 | "Eval dataset size: %d (%d queries, %d index)" 377 | % ( 378 | outputs["split"].shape[0], 379 | (outputs["split"] == DISCEvalDataset.SPLIT_QUERY).sum(), 380 | (outputs["split"] == DISCEvalDataset.SPLIT_REF).sum(), 381 | ) 382 | ) 383 | dataset: DISCEvalDataset = self.trainer.datamodule.val_dataset 384 | metrics = dataset.retrieval_eval( 385 | outputs["embeddings"], 386 | outputs["image_num"], 387 | outputs["split"], 388 | ) 389 | metrics = {k: 0.0 if v is None else v 
for (k, v) in metrics.items()} 390 | self.log_dict(metrics, on_epoch=True) 391 | 392 | def on_train_epoch_end(self): 393 | metrics = [] 394 | for k, v in self.trainer.logged_metrics.items(): 395 | if k.endswith("_step"): 396 | continue 397 | if k.endswith("_epoch"): 398 | k = k[: -len("_epoch")] 399 | if torch.is_tensor(v): 400 | v = v.item() 401 | metrics.append(f"{k}: {v:.3f}") 402 | metrics = ", ".join(metrics) 403 | self.print(f"Epoch {self.current_epoch}: {metrics}") 404 | 405 | def on_train_end(self): 406 | if self.global_rank != 0: 407 | return 408 | if not self.logger: 409 | return 410 | path = os.path.join(self.logger.log_dir, "model_torchscript.pt") 411 | self.save_torchscript(path) 412 | 413 | def save_torchscript(self, filename): 414 | self.eval() 415 | input = torch.randn((1, 3, 64, 64), device=self.device) 416 | script = torch.jit.trace(self.model, input) 417 | torch.jit.save(script, filename) 418 | 419 | def _gather(self, batch): 420 | batch = self.all_gather(move_data_to_device(batch, self.device)) 421 | return { 422 | k: v.reshape([-1] + list(v.size()[2:])).cpu() for (k, v) in batch.items() 423 | } 424 | 425 | @staticmethod 426 | def dedup_outputs(outputs, key="instance_id"): 427 | """Deduplicate dataset on instance_id.""" 428 | idx = np.unique(outputs[key].numpy(), return_index=True)[1] 429 | outputs = {k: v.numpy()[idx] for (k, v) in outputs.items()} 430 | assert np.unique(outputs[key]).size == outputs["instance_id"].size 431 | return outputs 432 | 433 | def predict_step(self, batch, batch_idx): 434 | batch = self.validation_step(batch, batch_idx) 435 | 436 | # Workaround for a CUDA synchronization bug in PyTorch Lightning. 437 | # Fixed upstream: 438 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/11287 439 | batch = {k: v.cpu() for (k, v) in batch.items()} 440 | 441 | return batch 442 | 443 | 444 | def main(args): 445 | world_size = args.nodes * args.gpus 446 | if args.batch_size % world_size != 0: 447 | raise ValueError( 448 | f"Global batch size ({args.batch_size}) must be a multiple of " 449 | f"the number of GPUs ({world_size})." 450 | ) 451 | data = DISCData( 452 | train_dataset_path=args.train_dataset_path, 453 | val_dataset_path=args.val_dataset_path, 454 | train_batch_size=args.batch_size // world_size, 455 | train_image_size=args.train_image_size, 456 | val_image_size=args.val_image_size, 457 | augmentations=AugmentationSetting[args.augmentations], 458 | workers=args.workers, 459 | ) 460 | model = SSCD( 461 | args, 462 | train_steps=len(data.train_dataset) // args.batch_size, 463 | ) 464 | trainer = pl.Trainer( 465 | devices=args.gpus, 466 | num_nodes=args.nodes, 467 | accelerator=args.accelerator, 468 | max_epochs=args.epochs, 469 | sync_batchnorm=args.sync_bn, 470 | default_root_dir=args.output_path, 471 | strategy=DDPSpawnPlugin(find_unused_parameters=False), 472 | check_val_every_n_epoch=1, 473 | log_every_n_steps=1, 474 | num_sanity_val_steps=args.num_sanity_val_steps, 475 | callbacks=[LearningRateMonitor()], 476 | ) 477 | trainer.fit(model, datamodule=data) 478 | 479 | 480 | if __name__ == "__main__": 481 | args = parser.parse_args() 482 | main(args) 483 | -------------------------------------------------------------------------------- /sscd/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sscd/transforms/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | from typing import Any, Dict 9 | 10 | import torch 11 | from classy_vision.dataset.transforms import ( 12 | ClassyTransform, 13 | register_transform, 14 | ) 15 | from classy_vision.dataset.transforms.mixup import MixupTransform 16 | 17 | 18 | @register_transform("contrastive_mixup") 19 | class ContrastiveMixup(ClassyTransform): 20 | """Mixup / cutmix augmentations for contrastive learning.""" 21 | 22 | MIXUP_DEFAULTS = { 23 | "mixup_alpha": 1.0, 24 | "cutmix_alpha": 0.0, 25 | "cutmix_minmax": None, 26 | "switch_prob": 0.5, 27 | "mode": "elem", 28 | "correct_lam": True, 29 | "label_smoothing": 0.0, 30 | } 31 | 32 | def __init__( 33 | self, 34 | image_column: str, 35 | target_column: str, 36 | mixup_args: Dict[str, Any], 37 | repeated_augmentations: int, 38 | ): 39 | self.image_column = image_column 40 | self.target_column = target_column 41 | self.repeated_augmentations = repeated_augmentations 42 | self.mixup_args = mixup_args 43 | 44 | def transcode_targets(self, target): 45 | R = self.repeated_augmentations 46 | M = target.size(0) 47 | assert ( 48 | M % R == 0 49 | ), "Config error: Batch size not divisible by repeated augmentations" 50 | N = M // R 51 | transcoded = torch.arange(M, dtype=target.dtype, device=target.device) % N 52 | # Sanity checking 53 | old_matches = target.unsqueeze(1) == target.unsqueeze(0) 54 | new_matches = transcoded.unsqueeze(1) == transcoded.unsqueeze(0) 55 | mismatches = new_matches ^ old_matches 56 | if mismatches.any(): 57 | num_mismatches = mismatches.sum().item() 58 | logging.warning( 59 | f"Target transcoding introduced {num_mismatches} mismatches. " 60 | f"Batch size {N} * R={R} = {M}." 61 | ) 62 | return transcoded, N 63 | 64 | def __call__(self, batch): 65 | targets, num_classes = self.transcode_targets(batch[self.target_column]) 66 | # The mixup transform appears to mutate input tensors in place. This 67 | # produces surprising results for repeated augmentations. Clone the 68 | # input tensor before calling mixup. 
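# For example (hypothetical sizes): with repeated_augmentations=2 and a per-GPU batch of
# 8 inputs, transcode_targets above maps the instance ids to [0, 1, 2, 3, 0, 1, 2, 3], so
# MixupTransform is constructed with num_classes=4 and per-batch class targets.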
69 | mixup_batch = { 70 | "input": batch[self.image_column].clone(), 71 | "target": targets, 72 | } 73 | mixup = MixupTransform(**self.mixup_args, num_classes=num_classes) 74 | mixed = mixup(mixup_batch) 75 | batch = batch.copy() 76 | batch[self.image_column] = mixed["input"] 77 | batch[self.target_column] = mixed["target"] 78 | return batch 79 | 80 | @classmethod 81 | def from_config(cls, config: Dict[str, Any]): 82 | # Needed to recode targets to per-GPU-batch relative identifiers 83 | image_column = config.get("image_column", "input") 84 | target_column = config.get("target_column", "target") 85 | repeated_augmentations = config["repeated_augmentations"] 86 | mixup_args = { 87 | "mix_prob": config["mix_prob"], # required 88 | } 89 | for key, default in cls.MIXUP_DEFAULTS.items(): 90 | mixup_args[key] = config.get(key, default) 91 | 92 | return cls(image_column, target_column, mixup_args, repeated_augmentations) 93 | -------------------------------------------------------------------------------- /sscd/transforms/overlay_emoji.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import functools 9 | import random 10 | import glob 11 | from typing import Any, Dict 12 | 13 | import numpy as np 14 | from classy_vision.dataset.transforms import ClassyTransform, register_transform 15 | from PIL import Image 16 | from augly.utils import EMOJI_DIR 17 | 18 | from .samplers import Sampler 19 | 20 | 21 | class EmojiRepository: 22 | def __init__(self, path): 23 | self._emoji_fpaths = glob.glob(os.path.join(path, "*/*.png")) 24 | self._emoji_images = {} 25 | 26 | def map_path(self, path, local_path): 27 | path = path.strip() 28 | if local_path: 29 | local_mapped = os.path.join(local_path, os.path.basename(path)) 30 | if os.path.isfile(local_mapped): 31 | return local_mapped 32 | return path 33 | 34 | def random_emoji(self) -> Image.Image: 35 | emoji_fpath = random.choice(self._emoji_fpaths) 36 | return self.get_emoji(emoji_fpath) 37 | 38 | @functools.lru_cache(maxsize=None) 39 | def get_emoji(self, emoji_fpath: str) -> Image.Image: 40 | return Image.open(open(emoji_fpath, "rb")) 41 | 42 | @classmethod 43 | @functools.lru_cache(maxsize=None) 44 | def get(cls, path) -> "EmojiRepository": 45 | return cls(path) 46 | 47 | def size(self): 48 | return len(self._emoji_fpaths) 49 | 50 | 51 | @register_transform("OverlayEmoji") 52 | class OverlayEmojiTransform(ClassyTransform): 53 | """ 54 | Overlays (random) emoji on image 55 | """ 56 | 57 | def __init__( 58 | self, 59 | emoji_vault: str, 60 | emoji_size_sampler: Sampler, 61 | opacity_sampler: Sampler, 62 | fx_sampler: Sampler, 63 | fy_sampler: Sampler, 64 | ): 65 | self._emojis = EmojiRepository.get(emoji_vault) 66 | assert self._emojis.size() > 0 67 | self._emoji_size_sampler = emoji_size_sampler 68 | self._opacity_sampler = opacity_sampler 69 | self._fx_sampler = fx_sampler 70 | self._fy_sampler = fy_sampler 71 | 72 | def __call__(self, image: Image.Image): 73 | emoji: Image.Image = self._emojis.random_emoji() 74 | emoji_w, emoji_h = emoji.size 75 | image_w, image_h = image.size 76 | max_scale = min(image_w / emoji_w, image_h / emoji_h) 77 | emoji_size_frac = self._emoji_size_sampler() 78 | emoji_scale = max_scale * emoji_size_frac 79 | emoji = emoji.resize( 80 | (int(emoji_w * 
emoji_scale), int(emoji_h * emoji_scale)), 81 | resample=Image.BILINEAR, 82 | ) 83 | fx = self._fx_sampler() 84 | fy = self._fy_sampler() 85 | topleft_x = int(fx * (image_w - emoji.width)) 86 | topleft_y = int(fy * (image_h - emoji.height)) 87 | opacity = self._opacity_sampler() 88 | # perform overlay 89 | image_rgba = image.copy().convert("RGBA") 90 | # Get the mask of the emoji if it has one, otherwise create it 91 | try: 92 | mask = emoji.getchannel("A") 93 | mask = Image.fromarray((np.array(mask) * opacity).astype(np.uint8)) 94 | except ValueError: 95 | mask = Image.new(mode="L", size=emoji.size, color=int(opacity * 255)) 96 | image_rgba.paste(emoji, box=(topleft_x, topleft_y), mask=mask) 97 | return image_rgba.convert("RGB") 98 | 99 | @classmethod 100 | def from_config(cls, config: Dict[str, Any]) -> "OverlayEmojiTransform": 101 | emoji_vault = config.get("emoji_vault", EMOJI_DIR) 102 | emoji_size_sampler = Sampler.from_config(config["emoji_size"]) 103 | opacity_sampler = Sampler.from_config(config["opacity"]) 104 | fx_sampler = Sampler.from_config(config["fx"]) 105 | fy_sampler = Sampler.from_config(config["fy"]) 106 | transform = cls( 107 | emoji_vault, emoji_size_sampler, opacity_sampler, fx_sampler, fy_sampler 108 | ) 109 | return transform 110 | -------------------------------------------------------------------------------- /sscd/transforms/overlay_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import dataclasses 8 | import functools 9 | import io 10 | import logging 11 | import os 12 | import pickle 13 | import random 14 | import numpy as np 15 | from typing import Any, Dict, List 16 | from PIL import Image, ImageFont, ImageDraw 17 | 18 | from classy_vision.dataset.transforms import ClassyTransform, register_transform 19 | from .samplers import Sampler 20 | from augly.utils import FONTS_DIR 21 | 22 | 23 | @dataclasses.dataclass 24 | class Font: 25 | name: str 26 | path: str 27 | ttf_bytes: bytes 28 | charset: Any # numpy array 29 | 30 | def ttf(self): 31 | return io.BytesIO(self.ttf_bytes) 32 | 33 | def image_font(self, size) -> ImageFont: 34 | return ImageFont.truetype(self.ttf(), size) 35 | 36 | @classmethod 37 | def load(cls, path) -> "Font": 38 | prefix, ext = os.path.splitext(path) 39 | assert ext in [".ttf", ".pkl"] 40 | ttf_path = f"{prefix}.ttf" 41 | name = os.path.basename(ttf_path) 42 | with open(ttf_path, "rb") as f: 43 | ttf_bytes = f.read() 44 | with open(f"{prefix}.pkl", "rb") as f: 45 | charset = np.array(pickle.load(f), dtype=np.int64) 46 | return cls(name=name, path=ttf_path, ttf_bytes=ttf_bytes, charset=charset) 47 | 48 | def sample_chars(self, length) -> List[int]: 49 | return random.choices(self.charset, k=length) 50 | 51 | def sample_string(self, length) -> str: 52 | characters = self.sample_chars(length) 53 | return "".join(chr(x) for x in characters) 54 | 55 | 56 | class FontRepository: 57 | 58 | fonts = List[Font] 59 | 60 | def __init__(self, path): 61 | filenames = [ 62 | os.path.join(path, filename) 63 | for filename in os.listdir(path) 64 | if filename.endswith(".ttf") 65 | ] 66 | logging.info("Loading %d fonts from %s.", len(filenames), path) 67 | self.fonts = [Font.load(filename) for filename in filenames] 68 | logging.info("Finished loading %d fonts.", len(filenames)) 69 | 70 | 
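# random_font picks uniformly among the loaded fonts; FontRepository.get (below)
# is lru_cached, so each font directory is only read once per process.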
def random_font(self) -> Font: 71 | return random.choice(self.fonts) 72 | 73 | @classmethod 74 | @functools.lru_cache(maxsize=None) 75 | def get(cls, path) -> "FontRepository": 76 | return cls(path) 77 | 78 | def size(self): 79 | return len(self.fonts) 80 | 81 | 82 | @register_transform("OverlayText") 83 | class OverlayTextTransform(ClassyTransform): 84 | """ 85 | Overlays text on image 86 | """ 87 | 88 | def __init__( 89 | self, 90 | font_vault: str, 91 | font_size_sampler: Sampler, 92 | opacity_sampler: Sampler, 93 | color_sampler: Sampler, 94 | fx_sampler: Sampler, 95 | fy_sampler: Sampler, 96 | ): 97 | self._fonts = FontRepository.get(font_vault) 98 | assert self._fonts.size() > 0 99 | self._font_size_sampler = font_size_sampler 100 | self._opacity_sampler = opacity_sampler 101 | self._color_sampler = color_sampler 102 | self._fx_sampler = fx_sampler 103 | self._fy_sampler = fy_sampler 104 | 105 | def __call__(self, image: Image.Image): 106 | # instantiate font 107 | font: Font = self._fonts.random_font() 108 | font_size_frac = self._font_size_sampler() 109 | font_size = int(min(image.width, image.height) * font_size_frac) 110 | image_font = font.image_font(font_size) 111 | # sample a string of fixed length from charset 112 | _SAMPLE_STR_LEN = 100 113 | text_str = font.sample_string(_SAMPLE_STR_LEN) 114 | # compute maximum length that fits into image 115 | # TODO: binary search over a lazy list of fixed length 116 | # (tw and th are monotonically increasing) 117 | maxlen = 0 118 | for i in range(1, len(text_str)): 119 | substr = text_str[:i] 120 | try: 121 | tw, th = image_font.getsize(substr) 122 | except OSError as e: 123 | # Safeguard against invalid chars in charset 124 | # that produce "invalid composite glyph" error 125 | logging.warning(f"Error, font={font.path}, char_i={ord(substr[-1])}") 126 | logging.warning(e) 127 | # don't overlay text in case of invalid glyphs 128 | return image 129 | if (tw > image.width) or (th > image.height): 130 | maxlen = i - 1 131 | break 132 | if maxlen == 0: 133 | return image 134 | # sample text length and get definitive text size 135 | text_len = random.randint(1, maxlen) 136 | text_str = text_str[:text_len] 137 | text_width, text_height = image_font.getsize(text_str) 138 | assert (text_width <= image.width) and (text_height <= image.height), ( 139 | f"Text has size (H={text_height}, W={text_width}) which does " 140 | f"not fit into image of size (H={image.height}, W={image.width})" 141 | ) 142 | # sample text location 143 | fx = self._fx_sampler() 144 | fy = self._fy_sampler() 145 | topleft_x = fx * (image.width - text_width) 146 | topleft_y = fy * (image.height - text_height) 147 | opacity = self._opacity_sampler() 148 | alpha = int(opacity * 255 + 0.5) 149 | color = tuple(self._color_sampler()) 150 | color_w_opacity = color + (alpha,) 151 | # create output image 152 | image_base = image.convert("RGBA") 153 | image_txt = Image.new("RGBA", image_base.size, (255, 255, 255, 0)) 154 | draw = ImageDraw.Draw(image_txt) 155 | draw.text( 156 | xy=(topleft_x, topleft_y), 157 | text=text_str, 158 | fill=color_w_opacity, 159 | font=image_font, 160 | ) 161 | image_out = Image.alpha_composite(image_base, image_txt).convert("RGB") 162 | return image_out 163 | 164 | @classmethod 165 | def from_config(cls, config: Dict[str, Any]) -> "OverlayTextTransform": 166 | font_vault = config.get("font_vault", FONTS_DIR) 167 | font_size_sampler = Sampler.from_config(config["font_size"]) 168 | opacity_sampler = Sampler.from_config(config["opacity"]) 169 | color_sampler = 
Sampler.from_config(config["color"]) 170 | fx_sampler = Sampler.from_config(config["fx"]) 171 | fy_sampler = Sampler.from_config(config["fy"]) 172 | transform = cls( 173 | font_vault, 174 | font_size_sampler, 175 | opacity_sampler, 176 | color_sampler, 177 | fx_sampler, 178 | fy_sampler, 179 | ) 180 | return transform 181 | -------------------------------------------------------------------------------- /sscd/transforms/repeated_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | class RepeatedAugmentationTransform: 9 | """Applies a transform multiple times. 10 | 11 | Input: {"input": , ...} 12 | Output: {"input0": , "input1": , ...} 13 | """ 14 | 15 | def __init__(self, transform, copies=2, key="input"): 16 | self.transform = transform 17 | self.copies = copies 18 | self.key = key 19 | 20 | def __call__(self, record): 21 | record = record.copy() 22 | img = record.pop(self.key) 23 | for i in range(self.copies): 24 | record[f"{self.key}{i}"] = self.transform(img) 25 | return record 26 | -------------------------------------------------------------------------------- /sscd/transforms/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from collections.abc import Callable, Mapping 9 | from enum import Enum 10 | from typing import Dict, List, Any, Union, Tuple 11 | 12 | 13 | class SamplerType(str, Enum): 14 | FIXED = ("fixed",) 15 | CHOICE = ("choice",) 16 | TUPLE = "tuple" 17 | UNIFORM = "uniform" 18 | UNIFORMINT = "uniformint" 19 | 20 | 21 | class Sampler(Callable): 22 | @classmethod 23 | def from_config(cls, sampler_spec_or_value: Any) -> "Sampler": 24 | if not isinstance(sampler_spec_or_value, Mapping) or ( 25 | "sampler_type" not in sampler_spec_or_value 26 | ): 27 | return FixedValueSampler(sampler_spec_or_value) 28 | sampler_spec = sampler_spec_or_value 29 | sampler_type = SamplerType(sampler_spec["sampler_type"]) 30 | if sampler_type == SamplerType.FIXED: 31 | return FixedValueSampler.from_config(sampler_spec) 32 | elif sampler_type == SamplerType.CHOICE: 33 | return ChoiceSampler.from_config(sampler_spec) 34 | elif sampler_type == SamplerType.TUPLE: 35 | return TupleSampler.from_config(sampler_spec) 36 | elif sampler_type == SamplerType.UNIFORM: 37 | return UniformSampler.from_config(sampler_spec) 38 | elif sampler_type == SamplerType.UNIFORMINT: 39 | return UniformIntSampler.from_config(sampler_spec) 40 | else: 41 | raise ValueError(f"Unknown sampler type {sampler_type}") 42 | 43 | 44 | class FixedValueSampler(Sampler): 45 | """ 46 | Trivial sampler that returns a fixed value 47 | """ 48 | 49 | def __init__(self, value: Any): 50 | self._value = value 51 | 52 | def __call__(self): 53 | return self._value 54 | 55 | @classmethod 56 | def from_config(cls, sampler_spec: Dict[str, Any]) -> "FixedValueSampler": 57 | assert ( 58 | "value" in sampler_spec 59 | ), f"Fixed sampler value not specified: {sampler_spec}" 60 | value = sampler_spec["value"] 61 | return cls(value=value) 62 | 63 | 64 | class ChoiceSampler(Sampler): 65 | """ 66 
| Produces samples from the given population of values uniformly or 67 | based on provided weights. Config for uniform sampling should 68 | look like: 69 | { 70 | "sampler_type": "choice", 71 | "values": ["a", "b", "c"] 72 | } 73 | Config for categorical distribution sampling should look like: 74 | { 75 | "sampler_type": "choice", 76 | "values": { "a": 0.5, "b": 0.2, "c": 0.3 } 77 | } 78 | """ 79 | 80 | def __init__(self, values: Union[Dict, List]): 81 | if isinstance(values, list): 82 | self._values = values 83 | self._weights = None 84 | elif isinstance(values, dict): 85 | ks, vs = zip(*values.items()) 86 | self._values = list(ks) 87 | self._weights = list(vs) 88 | else: 89 | raise ValueError( 90 | f"Choice sampler values are of type {type(values)}," 91 | f" expected either list or dict" 92 | ) 93 | 94 | def __call__(self): 95 | sample = random.choices(self._values, self._weights)[0] 96 | return sample 97 | 98 | @classmethod 99 | def from_config(cls, config: Dict[str, Any]) -> "ChoiceSampler": 100 | assert ( 101 | "values" in config 102 | ), f"Choice sampler values not specified in config: {config}" 103 | values = config["values"] 104 | return cls(values=values) 105 | 106 | 107 | class UniformSampler(Sampler): 108 | """ 109 | Samples values from a uniform distribution given by lower and upper bounds 110 | """ 111 | 112 | def __init__(self, low: float, high: float): 113 | self._low = low 114 | self._high = high 115 | 116 | def __call__(self): 117 | sample = random.uniform(self._low, self._high) 118 | return sample 119 | 120 | @classmethod 121 | def from_config(cls, config: Dict[str, Any]) -> "UniformSampler": 122 | assert ( 123 | "low" in config 124 | ), f'Uniform sampler lower bound ("low") not specified in config: {config}' 125 | assert ( 126 | "high" in config 127 | ), f'Uniform sampler upper bound ("high") not specified in config: {config}' 128 | return cls(low=config["low"], high=config["high"]) 129 | 130 | 131 | class UniformIntSampler(Sampler): 132 | """ 133 | Samples values from a uniform int distribution given by lower and upper bounds 134 | """ 135 | 136 | def __init__(self, low: int, high: int): 137 | self._low = low 138 | self._high = high 139 | 140 | def __call__(self): 141 | sample = random.randint(self._low, self._high) 142 | return sample 143 | 144 | @classmethod 145 | def from_config(cls, config: Dict[str, Any]) -> "UniformIntSampler": 146 | assert "low" in config, ( 147 | f'Uniform int sampler lower bound ("low") not specified in ' 148 | f"config: {config}" 149 | ) 150 | assert "high" in config, ( 151 | f'Uniform int sampler upper bound ("high") not specified in ' 152 | f"config: {config}" 153 | ) 154 | return cls(low=config["low"], high=config["high"]) 155 | 156 | 157 | class TupleSampler(Sampler): 158 | """ 159 | Sample a tuple of values given a collection of samplers 160 | """ 161 | 162 | def __init__(self, samplers: Tuple): 163 | self._samplers = samplers 164 | 165 | def __call__(self): 166 | samples = tuple(sampler() for sampler in self._samplers) 167 | return samples 168 | 169 | @classmethod 170 | def from_config(cls, config: Dict[str, Any]) -> "TupleSampler": 171 | assert ( 172 | "samplers" in config 173 | ), f'Tuple samplers ("samplers") not specified in config: {config}' 174 | samplers = [ 175 | Sampler.from_config(sampler_spec) for sampler_spec in config["samplers"] 176 | ] 177 | return cls(samplers=samplers) 178 | -------------------------------------------------------------------------------- /sscd/transforms/settings.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import enum 8 | import json 9 | from classy_vision.dataset.transforms import build_transforms 10 | 11 | 12 | _simclr_config = """ 13 | [ 14 | {"name": "RandomResizedCrop", "size": %(train_image_size)d}, 15 | {"name": "RandomHorizontalFlip"}, 16 | { 17 | "name": "MaybeApply", 18 | "p": 0.8, 19 | "transform": { 20 | "name": "ColorJitter", 21 | "brightness": 0.8, 22 | "contrast": 0.8, 23 | "saturation": 0.8, 24 | "hue": 0.2 25 | } 26 | }, 27 | {"name": "RandomGrayscale", "p": 0.2}, 28 | { 29 | "name": "MaybeApply", 30 | "p": 0.5, 31 | "transform": { 32 | "name": "Blur", 33 | "radius": { 34 | "sampler_type": "uniform", 35 | "low": 0.1, 36 | "high": 2 37 | } 38 | } 39 | }, 40 | {"name": "ToTensor"}, 41 | {"name": "Normalize", "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} 42 | ] 43 | """ 44 | 45 | _strong_blur_config = """ 46 | [ 47 | {"name": "RandomResizedCrop", "size": %(train_image_size)d}, 48 | {"name": "RandomHorizontalFlip"}, 49 | { 50 | "name": "MaybeApply", 51 | "p": 0.8, 52 | "transform": { 53 | "name": "ColorJitter", 54 | "brightness": 0.8, 55 | "contrast": 0.8, 56 | "saturation": 0.8, 57 | "hue": 0.2 58 | } 59 | }, 60 | {"name": "RandomGrayscale", "p": 0.2}, 61 | { 62 | "name": "MaybeApply", 63 | "p": 0.5, 64 | "transform": { 65 | "name": "Blur", 66 | "radius": { 67 | "sampler_type": "uniform", 68 | "low": 1, 69 | "high": 5 70 | } 71 | } 72 | }, 73 | {"name": "ToTensor"}, 74 | {"name": "Normalize", "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} 75 | ] 76 | """ 77 | 78 | _advanced_augs_config = """ 79 | [ 80 | {"name": "RandomHorizontalFlip"}, 81 | { 82 | "name": "MaybeApply", 83 | "p": 0.05, 84 | "transform": { 85 | "name": "Rotate", 86 | "degrees_ccw": { 87 | "sampler_type": "choice", 88 | "values": [90, 180, 270] 89 | } 90 | } 91 | }, 92 | { 93 | "name": "MaybeApply", 94 | "p": 0.1, 95 | "transform": { 96 | "name": "OverlayText", 97 | "font_size": {"sampler_type": "uniform", "low": 0.1, "high": 0.3}, 98 | "opacity": {"sampler_type": "uniform", "low": 0.1, "high": 1}, 99 | "color": { 100 | "sampler_type": "tuple", 101 | "samplers": [ 102 | {"sampler_type": "uniformint", "low": 0, "high": 255}, 103 | {"sampler_type": "uniformint", "low": 0, "high": 255}, 104 | {"sampler_type": "uniformint", "low": 0, "high": 255} 105 | ] 106 | }, 107 | "fx": {"sampler_type": "uniform", "low": 0, "high": 1}, 108 | "fy": {"sampler_type": "uniform", "low": 0, "high": 1} 109 | } 110 | }, 111 | { 112 | "name": "MaybeApply", 113 | "p": 0.2, 114 | "transform": { 115 | "name": "OverlayEmoji", 116 | "emoji_size": {"sampler_type": "uniform", "low": 0.1, "high": 0.5}, 117 | "opacity": {"sampler_type": "uniform", "low": 0.7, "high": 1}, 118 | "fx": {"sampler_type": "uniform", "low": 0, "high": 1}, 119 | "fy": {"sampler_type": "uniform", "low": 0, "high": 1} 120 | } 121 | }, 122 | { 123 | "name": "MaybeApply", 124 | "p": 0.05, 125 | "transform": { 126 | "name": "Rotate", 127 | "degrees_ccw": { 128 | "sampler_type": "uniformint", 129 | "low": 0, 130 | "high": 359 131 | } 132 | } 133 | }, 134 | {"name": "RandomResizedCrop", "size": %(train_image_size)d}, 135 | { 136 | "name": "MaybeApply", 137 | "p": 0.8, 138 | "transform": { 139 | "name": "ColorJitter", 140 | "brightness": 0.8, 
141 | "contrast": 0.8, 142 | "saturation": 0.8, 143 | "hue": 0.2 144 | } 145 | }, 146 | {"name": "RandomGrayscale", "p": 0.2}, 147 | { 148 | "name": "MaybeApply", 149 | "p": 0.5, 150 | "transform": { 151 | "name": "Blur", 152 | "radius": { 153 | "sampler_type": "uniform", 154 | "low": 1, 155 | "high": 5 156 | } 157 | } 158 | }, 159 | { 160 | "name": "MaybeApply", 161 | "p": 0.2, 162 | "transform": { 163 | "name": "JpegCompress", 164 | "quality": { 165 | "sampler_type": "uniformint", 166 | "low": 0, 167 | "high": 100 168 | } 169 | } 170 | }, 171 | {"name": "ToTensor"}, 172 | {"name": "Normalize", "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} 173 | ] 174 | """ 175 | 176 | 177 | class AugmentationSetting(enum.Enum): 178 | """Augmentation configs explored in the SSCD paper.""" 179 | 180 | SIMCLR = enum.auto() 181 | STRONG_BLUR = enum.auto() 182 | ADVANCED = enum.auto() 183 | 184 | def get_transformations(self, image_size): 185 | config = self._get_config(self) % {"train_image_size": image_size} 186 | config = json.loads(config) 187 | return build_transforms(config) 188 | 189 | def _get_config(self, value): 190 | return { 191 | self.SIMCLR: _simclr_config, 192 | self.STRONG_BLUR: _strong_blur_config, 193 | self.ADVANCED: _advanced_augs_config, 194 | }[value] 195 | -------------------------------------------------------------------------------- /sscd/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from typing import Any, Dict 9 | from classy_vision.dataset.transforms import ( 10 | ClassyTransform, 11 | build_transform, 12 | register_transform, 13 | ) 14 | from augly.image import encoding_quality 15 | from PIL import Image, ImageFilter 16 | from torchvision import transforms 17 | from .samplers import Sampler 18 | 19 | 20 | @register_transform("Blur") 21 | class BlurTransform(ClassyTransform): 22 | """ 23 | Applies Gaussian blur to image 24 | """ 25 | 26 | def __init__(self, radius_sampler: Sampler): 27 | self._r_sampler = radius_sampler 28 | 29 | def __call__(self, image: Image.Image): 30 | im_filter = ImageFilter.GaussianBlur(radius=self._r_sampler()) 31 | image_filtered = image.filter(im_filter) 32 | return image_filtered 33 | 34 | @classmethod 35 | def from_config(cls, config: Dict[str, Any]) -> "BlurTransform": 36 | """ 37 | Load blur transform from a config 38 | Examples: 39 | 1. Deterministic blur with radius = 3: 40 | { 41 | "name": "Blur", 42 | "radius": 3.0 43 | } 44 | 2. 
Blur with uniform random radius: 45 | { 46 | "name": "Blur", 47 | "radius": { 48 | "sampler_type": "uniform", 49 | "low": 0.0, 50 | "high": 5.0 51 | } 52 | } 53 | """ 54 | radius_sampler = Sampler.from_config(config["radius"]) 55 | return cls(radius_sampler=radius_sampler) 56 | 57 | 58 | @register_transform("Rotate") 59 | class RotateTransform(ClassyTransform): 60 | """ 61 | Applies rotation to image 62 | """ 63 | 64 | def __init__(self, angle_deg_sampler: Sampler): 65 | self._angle_deg_sampler = angle_deg_sampler 66 | 67 | def __call__(self, image: Image.Image): 68 | image_rotated = image.rotate( 69 | angle=self._angle_deg_sampler(), resample=Image.BILINEAR 70 | ) 71 | return image_rotated 72 | 73 | @classmethod 74 | def from_config(cls, config: Dict[str, Any]) -> "RotateTransform": 75 | """ 76 | Load rotate transform from a config 77 | Examples: 78 | 1. Deterministic rotation on 45 degrees: 79 | { 80 | "name": "Rotate", 81 | "degrees_ccw": 45.0 82 | } 83 | 2. Random uniform rotation: 84 | { 85 | "name": "Rotate", 86 | "degrees_ccw": { 87 | "sampler_type": "uniform", 88 | "low": 0.0, 89 | "high": 90.0 90 | } 91 | } 92 | """ 93 | angle_sampler = Sampler.from_config(config["degrees_ccw"]) 94 | return cls(angle_deg_sampler=angle_sampler) 95 | 96 | 97 | @register_transform("JpegCompress") 98 | class JpegCompressTransform(ClassyTransform): 99 | """ 100 | Compresses an image with lower bitrate JPEG to make compression 101 | artifacts appear on the resulting image 102 | """ 103 | 104 | def __init__(self, quality_sampler: Sampler): 105 | """ 106 | Args: 107 | quality_sampler: sampler of JPEG quality values (integers in [0, 100]) 108 | """ 109 | self._q_sampler = quality_sampler 110 | 111 | def __call__(self, image: Image.Image): 112 | image_transformed = encoding_quality(image, quality=self._q_sampler()) 113 | return image_transformed 114 | 115 | @classmethod 116 | def from_config(cls, config: Dict[str, Any]) -> "JpegCompressTransform": 117 | """ 118 | Load JPEG compression transform from a config 119 | Examples: 120 | 1. Deterministic compression with quality = 15: 121 | { 122 | "name": "JpegCompress", 123 | "quality": 15 124 | } 125 | 2. Compression with uniformly sampled quality: 126 | { 127 | "name": "JpegCompress", 128 | "quality": { 129 | "sampler_type": "uniformint", 130 | "low": 0, 131 | "high": 100 132 | } 133 | } 134 | """ 135 | quality_sampler = Sampler.from_config(config["quality"]) 136 | return cls(quality_sampler=quality_sampler) 137 | 138 | 139 | @register_transform("MaybeApply") 140 | class MaybeApplyTransform(ClassyTransform): 141 | """A Classy version of RandomApply. 142 | 143 | This is just shorthand for the `n = 1` case of BinomialWrapper, which is a 144 | common case. 
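Example config (illustrative values):
{
    "name": "MaybeApply",
    "p": 0.5,
    "transform": {"name": "Blur", "radius": 2.0}
}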
145 | """ 146 | 147 | def __init__(self, p, transform): 148 | self.p = p 149 | self.transform = transform 150 | 151 | def __call__(self, image: Image.Image) -> Image.Image: 152 | if random.random() < self.p: 153 | image = self.transform(image) 154 | return image 155 | 156 | @classmethod 157 | def from_config(cls, config: Dict[str, Any]): 158 | p = config["p"] 159 | transform = build_transform(config["transform"]) 160 | return cls(p, transform) 161 | 162 | 163 | @register_transform("ResizeLongEdge") 164 | class ResizeLongEdge(ClassyTransform): 165 | """Resize the long edge of an image to a target size.""" 166 | 167 | def __init__(self, size): 168 | self.size = size 169 | 170 | def __call__(self, image: Image.Image) -> Image.Image: 171 | scale = self.size / max(image.size) 172 | h, w = image.size 173 | if h > w: 174 | h = self.size 175 | w = int(scale * w + 0.5) 176 | else: 177 | w = self.size 178 | h = int(scale * h + 0.5) 179 | return transforms.Resize((w, h))(image) 180 | 181 | @classmethod 182 | def from_config(cls, config: Dict[str, Any]): 183 | return cls(config["size"]) 184 | --------------------------------------------------------------------------------