├── .codecov.yml ├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ ├── bugs.md │ └── new_method.md └── workflows │ ├── autoblack.yml │ ├── dali_tests.yml │ └── tests.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── bash_files ├── knn │ └── imagenet-100 │ │ └── knn.sh ├── linear │ ├── cifar │ │ ├── byol_50.sh │ │ ├── byol_ica.sh │ │ ├── sdbyol.sh │ │ ├── sdbyol_50.sh │ │ └── sdmoco.sh │ ├── imagenet-100 │ │ ├── barlow_linear.sh │ │ ├── byol_linear.sh │ │ ├── deepclusterv2_linear.sh │ │ ├── dino_linear.sh │ │ ├── general_linear.sh │ │ ├── mocov2plus_linear.sh │ │ ├── nnclr_linear.sh │ │ ├── ressl_linear.sh │ │ ├── simclr_linear.sh │ │ ├── simsiam_linear.sh │ │ ├── swav_linear.sh │ │ ├── vibcreg_linear.sh │ │ └── vicreg_linear.sh │ └── imagenet │ │ ├── barlow.sh │ │ ├── byol.sh │ │ ├── mocov2plus.sh │ │ └── sdbyol.sh ├── pretrain │ ├── cifar │ │ ├── barlow.sh │ │ ├── byol.sh │ │ ├── byol_50.sh │ │ ├── byol_large_50.sh │ │ ├── byol_rnlarge2.sh │ │ ├── deepclusterv2.sh │ │ ├── dino.sh │ │ ├── launch_comparison.sh │ │ ├── mocov2plus.sh │ │ ├── nnbyol.sh │ │ ├── nnclr.sh │ │ ├── nnsiam.sh │ │ ├── ressl.sh │ │ ├── sdbarlow.sh │ │ ├── sdbyol.sh │ │ ├── sdbyol_50.sh │ │ ├── sddino.sh │ │ ├── sdmoco.sh │ │ ├── sdsimclr.sh │ │ ├── sdswav.sh │ │ ├── sdvicreg.sh │ │ ├── simclr.sh │ │ ├── simsiam.sh │ │ ├── supcon.sh │ │ ├── swav.sh │ │ ├── vibcreg.sh │ │ ├── vicreg.sh │ │ └── wmse.sh │ ├── custom │ │ └── byol.sh │ ├── imagenet-100 │ │ ├── barlow.sh │ │ ├── byol.sh │ │ ├── deepclusterv2.sh │ │ ├── dino.sh │ │ ├── dino_vit.sh │ │ ├── mocov2plus.sh │ │ ├── multicrop │ │ │ ├── byol.sh │ │ │ ├── simclr.sh │ │ │ └── supcon.sh │ │ ├── nnclr.sh │ │ ├── ressl.sh │ │ ├── simclr.sh │ │ ├── simsiam.sh │ │ ├── supcon.sh │ │ ├── swav.sh │ │ ├── vibcreg.sh │ │ ├── vicreg.sh │ │ └── wmse.sh │ └── imagenet │ │ ├── barlow.sh │ │ ├── byol.sh │ │ ├── byol_orion.sh │ │ ├── mocov2plus.sh │ │ ├── orion_config.yaml │ │ ├── sdbyol.sh │ │ ├── sdbyol_800.sh │ 
│ ├── sdbyol_test.sh │ │ └── sdmoco.sh └── umap │ └── imagenet-100 │ └── umap.sh ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── index.rst │ ├── solo │ ├── args.rst │ ├── data.rst │ ├── losses │ │ ├── barlow.rst │ │ ├── byol.rst │ │ ├── deepclusterv2.rst │ │ ├── dino.rst │ │ ├── mocov2plus.rst │ │ ├── nnclr.rst │ │ ├── ressl.rst │ │ ├── simclr.rst │ │ ├── simsiam.rst │ │ ├── swav.rst │ │ ├── vibcreg.rst │ │ ├── vicreg.rst │ │ └── wmse.rst │ ├── methods │ │ ├── barlow.rst │ │ ├── base.rst │ │ ├── byol.rst │ │ ├── dali.rst │ │ ├── deepclusterv2.rst │ │ ├── dino.rst │ │ ├── linear.rst │ │ ├── mocov2plus.rst │ │ ├── nnbyol.rst │ │ ├── nnclr.rst │ │ ├── nnsiam.rst │ │ ├── ressl.rst │ │ ├── simclr.rst │ │ ├── simsiam.rst │ │ ├── swav.rst │ │ ├── vibcreg.rst │ │ ├── vicreg.rst │ │ └── wmse.rst │ └── utils.rst │ ├── start │ ├── available.rst │ └── install.rst │ └── tutorials │ ├── add_new_method.rst │ ├── add_new_method_momentum.rst │ ├── imgs │ ├── im100_train_umap_barlow.png │ ├── im100_train_umap_byol.png │ ├── im100_train_umap_random.png │ ├── im100_val_umap_barlow.png │ ├── im100_val_umap_byol.png │ └── im100_val_umap_random.png │ ├── knn.rst │ ├── offline_linear_eval.rst │ ├── overview.rst │ └── umap.rst ├── downstream └── object_detection │ ├── README.md │ ├── configs │ ├── Base-RCNN-C4-BN.yaml │ ├── coco_R_50_C4_2x.yaml │ ├── coco_R_50_C4_2x_moco.yaml │ ├── pascal_voc_R_50_C4_24k.yaml │ └── pascal_voc_R_50_C4_24k_moco.yaml │ ├── convert_model_to_detectron2.py │ ├── run.sh │ └── train_object_detection.py ├── main_knn.py ├── main_linear.py ├── main_pretrain.py ├── main_transfer.py ├── main_umap.py ├── orion_config.yaml ├── prepare_data.sh ├── requirements.txt ├── setup.py ├── solo ├── __init__.py ├── args │ ├── __init__.py │ ├── dataset.py │ ├── setup.py │ └── utils.py ├── losses │ ├── __init__.py │ ├── barlow.py │ ├── byol.py │ ├── deepclusterv2.py │ ├── dino.py │ ├── moco.py │ ├── nnclr.py │ ├── ressl.py │ ├── simclr.py │ ├── 
simsiam.py │ ├── swav.py │ ├── vibcreg.py │ ├── vicreg.py │ └── wmse.py ├── methods │ ├── REINFORCEbyol.py │ ├── __init__.py │ ├── barlow_twins.py │ ├── base.py │ ├── byol.py │ ├── byol_ica.py │ ├── dali.py │ ├── deepclusterv2.py │ ├── dino.py │ ├── gstbyol.py │ ├── largebyol.py │ ├── linear.py │ ├── linear_control.py │ ├── linear_fine.py │ ├── linear_masked.py │ ├── mask_linear.py │ ├── mocov2plus.py │ ├── nnbyol.py │ ├── nnclr.py │ ├── nnsiam.py │ ├── partsdbyol.py │ ├── ressl.py │ ├── sdbt.py │ ├── sdbyol.py │ ├── sddino.py │ ├── sdmocov2plus.py │ ├── sdsimclr.py │ ├── sdswav.py │ ├── sdvicreg.py │ ├── simclr.py │ ├── simsiam.py │ ├── supcon.py │ ├── swav.py │ ├── vibcreg.py │ ├── vicreg.py │ └── wmse.py └── utils │ ├── __init__.py │ ├── auto_resumer.py │ ├── auto_umap.py │ ├── backbones.py │ ├── checkpointer.py │ ├── classification_dataloader.py │ ├── dali_dataloader.py │ ├── extract_features.py │ ├── kmeans.py │ ├── knn.py │ ├── lars.py │ ├── metrics.py │ ├── misc.py │ ├── momentum.py │ ├── pretrain_dataloader.py │ ├── sinkhorn_knopp.py │ └── whitening.py ├── start.sh ├── tests ├── __init__.py ├── args │ ├── __init__.py │ ├── test_datasets.py │ └── test_setup.py ├── dali │ ├── __init__.py │ ├── test_dali_dataloader.py │ ├── test_dali_model.py │ └── utils.py ├── losses │ ├── __init__.py │ ├── test_barlow.py │ ├── test_byol.py │ ├── test_dino.py │ ├── test_moco.py │ ├── test_nnclr.py │ ├── test_ressl.py │ ├── test_simclr.py │ ├── test_simsiam.py │ ├── test_swav.py │ ├── test_vibcreg.py │ ├── test_vicreg.py │ └── test_wmse.py ├── methods │ ├── __init__.py │ ├── test_barlow_twins.py │ ├── test_base.py │ ├── test_byol.py │ ├── test_deepclusterv2.py │ ├── test_dino.py │ ├── test_linear.py │ ├── test_mocov2plus.py │ ├── test_nnbyol.py │ ├── test_nnclr.py │ ├── test_nnsiam.py │ ├── test_ressl.py │ ├── test_simclr.py │ ├── test_simsiam.py │ ├── test_supcon.py │ ├── test_swav.py │ ├── test_vibcreg.py │ ├── test_vicreg.py │ ├── test_wmse.py │ └── utils.py └── utils │ ├── 
__init__.py │ ├── test_auto_resumer.py │ ├── test_auto_umap.py │ ├── test_backbones.py │ ├── test_checkpointer.py │ ├── test_classification_dataloader.py │ ├── test_filter.py │ ├── test_gather.py │ ├── test_kmeans.py │ ├── test_knn.py │ ├── test_metrics.py │ └── test_pretrain_dataloader.py └── zoo ├── cifar10.sh ├── cifar100.sh ├── imagenet.sh └── imagenet100.sh /.codecov.yml: -------------------------------------------------------------------------------- 1 | comment: 2 | layout: "flags, files" 3 | behavior: default 4 | require_changes: false 5 | require_base: no 6 | require_head: no 7 | show_carryforward_flags: true 8 | 9 | flag_management: 10 | default_rules: 11 | carryforward: true 12 | 13 | coverage: 14 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 15 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to solo-learn 2 | We want to make contributing to this project as easy and transparent as possible. Also, we encourage the community to help and improve the library. 3 | 4 | ## Our Development Process 5 | We fix issues as we discover them, either by running the code or by checking the issues tab on github. Major changes (e.g., new methods, other downstream tasks, additional features etc), are added periodically and based on how we judge its priority. 6 | 7 | ## Issues 8 | We use GitHub issues to track public bugs and questions. We will make templates for each type of issue in [issue templates](https://github.com/vturrisi/solo-learn/issues/new/choose), but for now, check if your issue has a fitting template or please be as verbose and clear as possible. Also, provide ways to reproduce the issue (e.g., bash script). 9 | 10 | 11 | ## Pull Requests 12 | We actively welcome pull requests. 
13 | 14 | Before you start implementing a new method, make sure that the method has an accompanying paper (arxiv versions are fine) with competitive experimental results. 15 | 16 | When implementing the new method: 17 | 1. Fork the repo and create your branch from `main`. 18 | 2. Extend the available `BaseMethod` and `BaseMomentumMethod` classes, or even the methods that are already implemented. Your new method should reuse our base structures or provide enough justification for changing/incrementing the base structures if needed. 19 | 3. Provide a clear and simple implementation following our code style. 20 | 4. Provide unit tests. 21 | 5. Modify the documentation to add the new method and accompanying features. We use Sphinx for that. 22 | 6. Provide bash files for running the method. 23 | 7. Reproduce competitive results on at least CIFAR-10 and CIFAR-100. Imagenet-100/Imagenet results are also desirable, but we can run those eventually. 24 | 8. Ensure that your tests pass by running `python -m pytest --cov=solo tests`. 25 | 26 | If you are fixing a bug, please open an issue beforehand. Your fix should also modify the library as minimally as possible. 27 | 28 | When implementing extensions for the library, detail why this is important by providing use cases. 29 | 30 | ## Coding Style 31 | 32 | Please follow our coding style and use [black](https://github.com/psf/black) with line length of 100. We will update this section with a more hands-on guide. 33 | 34 | ## License 35 | By contributing to solo-learn, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bugs.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | --- 4 | name: "🐛 Bugs" 5 | about: Report bugs in solo-learn 6 | title: Please read & provide the following 7 | 8 | --- 9 | 10 | ## Instructions To Reproduce the 🐛 Bug: 11 | 12 | 1. what changes you made (`git diff`) or what code you wrote 13 | ``` 14 | 15 | ``` 16 | 2. what exact command you run: 17 | 3. what you observed (including __full logs__): 18 | ``` 19 | 20 | ``` 21 | 4. please simplify the steps as much as possible so they do not require additional resources to 22 | run, such as a private dataset. 23 | 24 | ## Expected behavior: 25 | 26 | If there is no obvious error in "what you observed" provided above, 27 | please tell us the expected behavior. 28 | 29 | ## Environment: 30 | 31 | Provide information about your environment. 32 | 33 | ## When to expect Triage 34 | 35 | We try to answer as soon as possible, but time may vary according to other deadlines and the importance of the bug reported. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/new_method.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | --- 4 | name: "\U0001F31F New SSL method addition" 5 | about: Submit a proposal of implementation/request to implement a new SSL method in solo-learn. 6 | 7 | --- 8 | 9 | # 🌟 New SSL method addition 10 | 11 | ## Approach description 12 | 13 | Important information only describing the approach, link to the arXiv paper. 14 | 15 | ## Status of the current implementation 16 | 17 | Information regarding stuff that you have already implemented for the method.
18 | As a general checklist, you should have: 19 | * [ ] the method implemented as new file in `solo/methods` and added to `solo/methods/__init__.py` 20 | * [ ] the loss implemented as new file in `solo/losses` and added to `solo/losses/__init__.py` 21 | * [ ] bash files to run experiments 22 | * [ ] tests for the method and the loss(es) in `tests/methods` and `tests/losses`. 23 | * [ ] add documentation to `docs/source/solo/methods` and `docs/source/solo/losses`, and modify `docs/source/index.rst` to reference the new files (use alphabetical order). 24 | * [ ] check if all tests pass by running `python -m pytest --cov=solo tests`. 25 | * [ ] reproduce competitive results on at least CIFAR10/CIFAR100 (possibly Imagenet-100 and Imagenet). 26 | 27 | -------------------------------------------------------------------------------- /.github/workflows/autoblack.yml: -------------------------------------------------------------------------------- 1 | name: autoblack 2 | on: 3 | pull_request: 4 | branches: [main] 5 | 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v1 11 | - name: Set up Python 3.7 12 | uses: actions/setup-python@v1 13 | with: 14 | python-version: 3.7 15 | - name: Install Black 16 | run: pip install black 17 | 18 | - name: Run black check on solo 19 | run: black --check --line-length 100 solo 20 | 21 | - name: If needed, commit black changes to the pull request 22 | if: failure() 23 | run: | 24 | black --line-length 100 solo 25 | git config --global user.name 'autoblack' 26 | git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY 27 | git checkout $GITHUB_HEAD_REF 28 | git commit -am "fixup: format solo with Black" 29 | git push 30 | 31 | - name: Run black --check on tests 32 | run: black --check --line-length 100 tests 33 | 34 | - name: If needed, commit black changes to the pull request 35 | if: failure() 36 | run: | 37 | black --line-length 100 tests 38 | git 
config --global user.name 'autoblack' 39 | git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY 40 | git checkout $GITHUB_HEAD_REF 41 | git commit -am "fixup: format tests with Black" 42 | git push -------------------------------------------------------------------------------- /.github/workflows/dali_tests.yml: -------------------------------------------------------------------------------- 1 | name: dali-tests 2 | 3 | # never run automatically since it needs gpus 4 | on: 5 | push: 6 | branches-ignore: 7 | - '**' 8 | 9 | jobs: 10 | test: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python: [3.7, 3.8, 3.9] 16 | os: [ubuntu-latest] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | 21 | - name: Set up Python ${{ matrix.python }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python }} 25 | 26 | - name: Cache Python dependencies 27 | uses: actions/cache@v2 28 | with: 29 | path: ~/.cache/pip 30 | key: ${{ runner.os }}-${{ matrix.python }}-pip-${{ hashFiles('**/setup.py') }} 31 | restore-keys: | 32 | ${{ runner.os }}-${{ matrix.python }}-pip- 33 | 34 | - name: Install Python dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install -e .[umap] codecov mypy pytest-cov black 38 | 39 | - name: Cache datasets 40 | uses: actions/cache@v2 41 | with: 42 | path: ~/datasets 43 | key: ${{ runner.os }} 44 | 45 | - name: pytest 46 | run: pytest --cov=solo tests/dali 47 | 48 | - name: Statistics 49 | if: success() 50 | run: | 51 | coverage report 52 | coverage xml 53 | 54 | - name: Upload coverage to Codecov 55 | uses: codecov/codecov-action@v1 56 | if: always() 57 | # see: https://github.com/actions/toolkit/issues/399 58 | continue-on-error: true 59 | with: 60 | token: ${{ secrets.CODECOV_TOKEN }} 61 | file: coverage.xml 62 | flags: dali 63 | name: DALI-coverage 64 | fail_ci_if_error: false 
-------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | test: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python: [3.7, 3.8, 3.9] 16 | os: [ubuntu-latest, macos-latest, windows-latest] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | 21 | - name: Set up Python ${{ matrix.python }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python }} 25 | 26 | - name: Cache Python dependencies 27 | uses: actions/cache@v2 28 | with: 29 | path: ~/.cache/pip 30 | key: ${{ runner.os }}-${{ matrix.python }}-pip-${{ hashFiles('**/setup.py') }} 31 | restore-keys: | 32 | ${{ runner.os }}-${{ matrix.python }}-pip- 33 | 34 | - name: Install Python dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install -e .[umap] codecov mypy pytest-cov black 38 | 39 | - name: Cache datasets 40 | uses: actions/cache@v2 41 | with: 42 | path: ~/datasets 43 | key: ${{ runner.os }} 44 | 45 | - name: pytest 46 | run: pytest --cov=solo tests/args tests/losses tests/methods tests/utils 47 | 48 | - name: Statistics 49 | if: success() 50 | run: | 51 | coverage report 52 | coverage xml 53 | 54 | - name: Upload coverage to Codecov 55 | uses: codecov/codecov-action@v1 56 | if: always() 57 | # see: https://github.com/actions/toolkit/issues/399 58 | continue-on-error: true 59 | with: 60 | token: ${{ secrets.CODECOV_TOKEN }} 61 | file: coverage.xml 62 | flags: cpu 63 | name: Coverage 64 | fail_ci_if_error: false -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | image: latest 5 | 6 | # This part is 
necessary otherwise the project is not built 7 | # Optionally set the version of Python and requirements required to build your docs 8 | python: 9 | version: 3.8 10 | install: 11 | - requirements: docs/requirements.txt 12 | - method: setuptools 13 | path: . 14 | 15 | # By default readthedocs does not checkout git submodules 16 | submodules: 17 | include: all -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 solo-learn development team. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to use, 6 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies 11 | or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simplicial Embeddings in Self-Supervised Learning and Downstream Classification 2 | This repository is the companion code for the article Simplicial Embeddings in Self-Supervised Learning and Downstream Classification. It is a fork of the Self-Supervised learning library `solo-learn` to which we apply the necessary modifications to run the experiments in the paper. 3 | 4 | ## SEM module 5 | The SEM module can be implemented as follows (with `torch.nn` imported as `nn` and `torch.nn.functional` imported as `F`): 6 | ``` 7 | class SEM(nn.Module): 8 | def __init__(self, L, V, tau, **kwargs): 9 | super().__init__() 10 | self.L = L 11 | self.V = V 12 | self.tau = tau 13 | 14 | def forward(self, x): 15 | logits = x.view(-1, self.L, self.V) 16 | taus = self.tau 17 | return F.softmax(logits / taus, -1).view(x.shape[0], -1) 18 | ``` 19 | 20 | ## Citation 21 | To cite our article, please use: 22 | ``` 23 | @inproceedings{ 24 | lavoie2023simplicial, 25 | title={Simplicial Embeddings in Self-Supervised Learning and Downstream Classification}, 26 | author={Samuel Lavoie and Christos Tsirigotis and Max Schwarzer and Ankit Vani and Michael Noukhovitch and Kenji Kawaguchi and Aaron Courville}, 27 | booktitle={International Conference on Learning Representations}, 28 | year={2023}, 29 | url={https://openreview.net/forum?id=RWtGreRpovS} 30 | } 31 | ``` 32 | 33 | To cite `solo-learn`, please cite their [paper](https://jmlr.org/papers/v23/21-1155.html): 34 | ``` 35 | @article{JMLR:v23:21-1155, 36 | author = {Victor Guilherme Turrisi da Costa and Enrico Fini and Moin Nabi and Nicu Sebe and Elisa Ricci}, 37 | title = {solo-learn: A Library of Self-supervised Methods for Visual Representation Learning}, 38 | journal = {Journal of Machine Learning Research}, 39 | year = {2022}, 40 | volume = {23}, 41 | number = {56}, 42 | pages = {1-6}, 43 | url =
{http://jmlr.org/papers/v23/21-1155.html} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /bash_files/knn/imagenet-100/knn.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_knn.py \ 2 | --dataset imagenet100 \ 3 | --data_dir /datasets \ 4 | --train_dir imagenet-100/train \ 5 | --val_dir imagenet-100/val \ 6 | --batch_size 16 \ 7 | --num_workers 10 \ 8 | --pretrained_checkpoint_dir PATH \ 9 | --k 1 2 5 10 20 50 100 200 \ 10 | --temperature 0.01 0.02 0.05 0.07 0.1 0.2 0.5 1 \ 11 | --feature_type backbone projector \ 12 | --distance_function euclidean cosine 13 | -------------------------------------------------------------------------------- /bash_files/linear/cifar/byol_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source ~/env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python ../../../main_linear.py \ 9 | --linear_base True \ 10 | --dataset cifar100 \ 11 | --backbone resnet50 \ 12 | --data_dir $SLURM_TMPDIR/datasets \ 13 | --checkpoint_dir $2 \ 14 | --pretrained_feature_extractor $1 \ 15 | --max_epochs 100 \ 16 | --gpus 0 \ 17 | --accelerator gpu \ 18 | --save_checkpoint \ 19 | --precision 16 \ 20 | --method linear \ 21 | --optimizer sgd \ 22 | --scheduler step \ 23 | --lr_decay_steps 60 80 \ 24 | --weight_decay 0 \ 25 | --batch_size 256 \ 26 | --num_workers 5 \ 27 | --crop_size 32 \ 28 | --name class_sdbyol \ 29 | --class_base True \ 30 | --name byol \ 31 | ${@:3} 32 | 33 | -------------------------------------------------------------------------------- /bash_files/linear/cifar/byol_ica.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source ~/env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python ../../../main_linear.py \ 9 | --dataset cifar100 
\ 10 | --backbone resnet50 \ 11 | --data_dir $SLURM_TMPDIR/datasets \ 12 | --checkpoint_dir $2 \ 13 | --pretrained_feature_extractor $1 \ 14 | --max_epochs 100 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --save_checkpoint \ 18 | --precision 16 \ 19 | --method linear \ 20 | --ica True \ 21 | --optimizer sgd \ 22 | --scheduler step \ 23 | --lr_decay_steps 60 80 \ 24 | --weight_decay 0 \ 25 | --batch_size 256 \ 26 | --num_workers 5 \ 27 | --crop_size 32 \ 28 | --name class_sdbyol \ 29 | --class_base True \ 30 | --name byol \ 31 | --sem False \ 32 | ${@:3} 33 | 34 | -------------------------------------------------------------------------------- /bash_files/linear/cifar/sdbyol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source ~/class_env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python ../../../main_linear.py \ 9 | --dataset cifar100 \ 10 | --backbone resnet18 \ 11 | --data_dir $SLURM_TMPDIR/datasets \ 12 | --max_epochs 100 \ 13 | --gpus 0 \ 14 | --accelerator gpu \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --scheduler step \ 18 | --lr 0.1 \ 19 | --lr_decay_steps 60 80 \ 20 | --weight_decay 0 \ 21 | --batch_size 256 \ 22 | --num_workers 5 \ 23 | --crop_size 32 \ 24 | --name class_sdbyol \ 25 | --pretrained_feature_extractor $1 \ 26 | --taus 0.0001 0.001 0.01 0.1 1 10 \ 27 | --eval_taus 0.0001 0.001 0.01 0.1 1 10 \ 28 | --lrs 0.01 0.1 0.2 \ 29 | --wd1 0 1e-5 \ 30 | --wd2 0 \ 31 | --class_base True \ 32 | --name byol-resnet18-imagenet-linear-eval 33 | 34 | -------------------------------------------------------------------------------- /bash_files/linear/cifar/sdbyol_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source ~/env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python ../../../main_linear.py \ 9 | --dataset cifar100 \ 10 | --backbone 
resnet50 \ 11 | --data_dir $SLURM_TMPDIR/datasets \ 12 | --checkpoint_dir $2 \ 13 | --pretrained_feature_extractor $1 \ 14 | --max_epochs 100 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --save_checkpoint \ 18 | --precision 16 \ 19 | --optimizer sgd \ 20 | --scheduler step \ 21 | --lr 0.1 \ 22 | --method linear \ 23 | --lr_decay_steps 60 80 \ 24 | --weight_decay 0 \ 25 | --batch_size 256 \ 26 | --num_workers 5 \ 27 | --crop_size 32 \ 28 | --name class_sdbyol \ 29 | --lrs 0.01 \ 30 | --wd1 1e-5 \ 31 | --wd2 0 \ 32 | --class_base False \ 33 | ${@:3} 34 | 35 | -------------------------------------------------------------------------------- /bash_files/linear/cifar/sdmoco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source ~/class_env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python ../../../main_linear.py \ 9 | --dataset cifar100 \ 10 | --backbone resnet50 \ 11 | --data_dir $SLURM_TMPDIR/datasets \ 12 | --checkpoint_dir $2 \ 13 | --pretrained_feature_extractor $1 \ 14 | --max_epochs 100 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --save_checkpoint \ 18 | --precision 16 \ 19 | --optimizer sgd \ 20 | --scheduler step \ 21 | --lr 0.1 \ 22 | --method linear \ 23 | --lr_decay_steps 60 80 \ 24 | --weight_decay 0 \ 25 | --batch_size 256 \ 26 | --num_workers 5 \ 27 | --crop_size 32 \ 28 | --name class_sdmoco \ 29 | --lrs 0.05 \ 30 | --wd1 1e-6 \ 31 | --wd2 0 \ 32 | --class_base False \ 33 | ${@:3} 34 | 35 | #--wd1 0 1e-6 1e-5 1e-4 \ 36 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/barlow_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 
0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.1 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 4 \ 17 | --dali \ 18 | --name barlow-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/byol_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.3 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 4 \ 17 | --dali \ 18 | --name byol-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/deepclusterv2_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.15 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 5 \ 17 | --dali \ 18 | --name deepclusterv2-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH\ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | 
-------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/dino_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.3 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 4 \ 17 | --dali \ 18 | --name dino-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/general_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0,1 \ 9 | --strategy ddp \ 10 | --sync_batchnorm \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.1 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 128 \ 18 | --num_workers 10 \ 19 | --name simclr-linear-eval \ 20 | --pretrained_feature_extractor PATH \ 21 | --project contrastive_learning \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/mocov2plus_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir 
imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 3.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --dali \ 18 | --name mocov2plus-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/nnclr_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.3 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --dali \ 18 | --name nnclr-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/ressl_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/test \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 3.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --dali \ 18 | --name ressl-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug 
\ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/simclr_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 1.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --dali \ 18 | --name simclr-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/simsiam_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 30.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --dali \ 18 | --name simsiam-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/swav_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet-100/train \ 
6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.15 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 5 \ 17 | --dali \ 18 | --name swav-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/vibcreg_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.3 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 512 \ 16 | --num_workers 5 \ 17 | --dali \ 18 | --name vibcreg-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/vicreg_linear.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.3 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 512 \ 16 | --num_workers 5 \ 17 | --dali \ 18 | --name vicreg-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity 
unitn-mhug \ 22 | --wandb 23 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet \ 3 | --backbone resnet50 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet/train \ 6 | --val_dir imagenet/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lr 0.3 \ 15 | --scheduler cosine \ 16 | --weight_decay 1e-5 \ 17 | --batch_size 256 \ 18 | --num_workers 10 \ 19 | --dali \ 20 | --pretrained_feature_extractor PATH \ 21 | --name barlow-resnet50-imagenet-linear-eval \ 22 | --entity unitn-mhug \ 23 | --project solo-learn \ 24 | --wandb 25 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet/byol.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_linear.py \ 2 | --dataset imagenet \ 3 | --backbone resnet50 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet/train \ 6 | --val_dir imagenet/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --scheduler step \ 15 | --lr 0.1 \ 16 | --lr_decay_steps 60 80 \ 17 | --weight_decay 0 \ 18 | --batch_size 256 \ 19 | --num_workers 10 \ 20 | --dali \ 21 | --pretrained_feature_extractor PATH \ 22 | --name byol-resnet50-imagenet-linear-eval \ 23 | --entity unitn-mhug \ 24 | --project solo-learn \ 25 | --wandb 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet \ 3 | --backbone 
resnet50 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet/train \ 6 | --val_dir imagenet/val \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 3.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 5 \ 17 | --dali \ 18 | --name mocov2plus-imagenet-linear-eval \ 19 | --pretrained_feature_extractor PATH \ 20 | --project solo-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet/sdbyol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_PATH=$1 4 | 5 | ../../../prepare_data.sh TEST 6 | 7 | module load python/3.8 8 | source ~/env/bin/activate 9 | 10 | python3 ../../../main_linear.py \ 11 | --dataset imagenet \ 12 | --backbone resnet50 \ 13 | --checkpoint_dir=$ROOT_PATH \ 14 | --data_dir $SLURM_TMPDIR \ 15 | --train_dir $SLURM_TMPDIR/data/train \ 16 | --val_dir $SLURM_TMPDIR/data/val \ 17 | --max_epochs 100 \ 18 | --gpus 0,1 \ 19 | --accelerator gpu \ 20 | --strategy ddp \ 21 | --sync_batchnorm \ 22 | --dali \ 23 | --precision 16 \ 24 | --optimizer sgd \ 25 | --scheduler step \ 26 | --lr_decay_steps 60 80 \ 27 | --weight_decay 0 \ 28 | --batch_size 512 \ 29 | --num_workers 10 \ 30 | --pretrained_feature_extractor $2 \ 31 | --name byol-resnet50-imagenet-linear-eval \ 32 | --entity il_group \ 33 | --project VIL \ 34 | --save_checkpoint \ 35 | --taus 1.0 1.5 \ 36 | --lrs 0.2 0.5 \ 37 | --wd1 0 \ 38 | --wd2 1e-8 \ 39 | --wandb 40 | 41 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/barlow.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python 
../../../main_pretrain.py \ 9 | --dataset cifar100 \ 10 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 11 | --backbone resnet18 \ 12 | --data_dir $SLURM_TMPDIR/datasets \ 13 | --max_epochs 1000 \ 14 | --gpus 0 \ 15 | --accelerator gpu \ 16 | --precision 16 \ 17 | --optimizer sgd \ 18 | --lars \ 19 | --grad_clip_lars \ 20 | --eta_lars 0.02 \ 21 | --exclude_bias_n_norm \ 22 | --scheduler warmup_cosine \ 23 | --lr 0.3 \ 24 | --weight_decay 1e-4 \ 25 | --batch_size 256 \ 26 | --num_workers 4 \ 27 | --brightness 0.4 \ 28 | --contrast 0.4 \ 29 | --saturation 0.2 \ 30 | --hue 0.1 \ 31 | --gaussian_prob 0.0 0.0 \ 32 | --solarization_prob 0.0 0.2 \ 33 | --crop_size 32 \ 34 | --num_crops_per_aug 1 1 \ 35 | --name baselines-barlow \ 36 | --project iclr \ 37 | --entity il_group \ 38 | --wandb \ 39 | --save_checkpoint \ 40 | --method barlow_twins \ 41 | --proj_hidden_dim 2048 \ 42 | --proj_output_dim 2048 \ 43 | --scale_loss 0.1 44 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/byol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 1.0 \ 25 | --classifier_lr 0.1 \ 26 | --batch_size 256 \ 27 | --num_workers 4 \ 28 | --brightness 0.4 \ 29 | --contrast 0.4 \ 30 | --saturation 0.2 \ 31 | --hue 0.1 \ 32 | --gaussian_prob 0.0 0.0 \ 33 | --solarization_prob 0.0 0.2 \ 34 | --crop_size 32 \ 35 | --num_crops_per_aug 1 1 \ 36 | 
--name baselines-byol \ 37 | --project iclr \ 38 | --entity il_group \ 39 | --wandb \ 40 | --weight_decay 1e-5 \ 41 | --save_checkpoint \ 42 | --method byol \ 43 | --proj_output_dim 256 \ 44 | --proj_hidden_dim 4096 \ 45 | --pred_hidden_dim 4096 \ 46 | --base_tau_momentum 0.99 \ 47 | --final_tau_momentum 1.0 \ 48 | --momentum_classifier 49 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/byol_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source ~/env/bin/activate 6 | 7 | python3 ../../../main_pretrain.py \ 8 | --dataset $1 \ 9 | --backbone resnet50 \ 10 | --data_dir ./datasets \ 11 | --checkpoint_dir $SCRATCH/cifar_rn50_scale \ 12 | --max_epochs 1000 \ 13 | --gpus 0 \ 14 | --accelerator gpu \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.001 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.5 \ 23 | --classifier_lr 0.3 \ 24 | --weight_decay 1e-4 \ 25 | --batch_size 256 \ 26 | --num_workers 4 \ 27 | --brightness 0.4 \ 28 | --contrast 0.4 \ 29 | --saturation 0.2 \ 30 | --hue 0.1 \ 31 | --gaussian_prob 0.0 0.0 \ 32 | --solarization_prob 0.0 0.2 \ 33 | --crop_size 32 \ 34 | --num_crops_per_aug 1 1 \ 35 | --name byol-$1 \ 36 | --project PROJECT \ 37 | --entity ENTITY \ 38 | --wandb \ 39 | --offline \ 40 | --save_checkpoint \ 41 | --method byol \ 42 | --proj_output_dim 256 \ 43 | --proj_hidden_dim 4096 \ 44 | --pred_hidden_dim 4096 \ 45 | --base_tau_momentum 0.99 \ 46 | --final_tau_momentum 1.0 \ 47 | --momentum_classifier 48 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/byol_large_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source ~/class_env/bin/activate 6 | 7 | python 
../../../main_pretrain.py \ 8 | --dataset cifar100 \ 9 | --backbone resnet50 \ 10 | --data_dir ./datasets \ 11 | --checkpoint_dir $SCRATCH/cifar_rn50 \ 12 | --max_epochs 1000 \ 13 | --gpus 0 \ 14 | --accelerator gpu \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.001 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.5 \ 23 | --classifier_lr 0.3 \ 24 | --weight_decay 1e-4 \ 25 | --batch_size 256 \ 26 | --num_workers 4 \ 27 | --brightness 0.4 \ 28 | --contrast 0.4 \ 29 | --saturation 0.2 \ 30 | --hue 0.1 \ 31 | --gaussian_prob 0.0 0.0 \ 32 | --solarization_prob 0.0 0.2 \ 33 | --crop_size 32 \ 34 | --num_crops_per_aug 1 1 \ 35 | --name sdbyol \ 36 | --project sem_neurips \ 37 | --entity il_group \ 38 | --wandb \ 39 | --offline \ 40 | --save_checkpoint \ 41 | --method lbyol \ 42 | --proj_output_dim 256 \ 43 | --proj_hidden_dim 4096 \ 44 | --pred_hidden_dim 4096 \ 45 | --base_tau_momentum 0.99 \ 46 | --final_tau_momentum 1.0 \ 47 | --momentum_classifier \ 48 | --voc_size 13 \ 49 | --message_size $1 \ 50 | --tau_online 1 \ 51 | --tau_target 1 52 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/byol_rnlarge2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source /home/mila/l/lavoiems/env_1.12/bin/activate 6 | 7 | python ../../../main_pretrain.py \ 8 | --dataset cifar100 \ 9 | --backbone resnetlarge2 \ 10 | --data_dir ./datasets \ 11 | --checkpoint_dir $SCRATCH/cifar_rn50 \ 12 | --max_epochs 1000 \ 13 | --gpus 0 \ 14 | --accelerator gpu \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.001 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.5 \ 23 | --classifier_lr 0.3 \ 24 | --weight_decay 1e-4 \ 25 | --batch_size 256 \ 26 | --num_workers 4 \ 27 | --brightness 0.4 \ 
28 | --contrast 0.4 \ 29 | --saturation 0.2 \ 30 | --hue 0.1 \ 31 | --gaussian_prob 0.0 0.0 \ 32 | --solarization_prob 0.0 0.2 \ 33 | --crop_size 32 \ 34 | --num_crops_per_aug 1 1 \ 35 | --name byol_large_embed \ 36 | --project iclr_rebuttal \ 37 | --entity il_group \ 38 | --wandb \ 39 | --save_checkpoint \ 40 | --method byol \ 41 | --proj_output_dim 256 \ 42 | --proj_hidden_dim 4096 \ 43 | --pred_hidden_dim 4096 \ 44 | --base_tau_momentum 0.99 \ 45 | --final_tau_momentum 1.0 \ 46 | --out_size $1 47 | 48 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/deepclusterv2.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --lars \ 11 | --grad_clip_lars \ 12 | --eta_lars 0.02 \ 13 | --scheduler warmup_cosine \ 14 | --lr 0.6 \ 15 | --min_lr 0.0006 \ 16 | --warmup_start_lr 0.0 \ 17 | --warmup_epochs 11 \ 18 | --classifier_lr 0.1 \ 19 | --weight_decay 1e-6 \ 20 | --batch_size 256 \ 21 | --num_workers 4 \ 22 | --brightness 0.8 \ 23 | --contrast 0.8 \ 24 | --saturation 0.8 \ 25 | --hue 0.2 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --crop_size 32 \ 28 | --num_crops_per_aug 1 1 \ 29 | --name deepclusterv2-$1 \ 30 | --project solo-learn \ 31 | --entity unitn-mhug \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --method deepclusterv2 \ 35 | --proj_hidden_dim 2048 \ 36 | --proj_output_dim 128 \ 37 | --num_prototypes 3000 3000 3000 38 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/dino.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python 
../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 0.3 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-6 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --brightness 0.4 \ 30 | --contrast 0.4 \ 31 | --saturation 0.2 \ 32 | --hue 0.1 \ 33 | --gaussian_prob 0.0 0.0 \ 34 | --solarization_prob 0.0 0.2 \ 35 | --crop_size 32 \ 36 | --num_crops_per_aug 1 1 \ 37 | --name baselines-dino \ 38 | --project iclr \ 39 | --entity il_group \ 40 | --wandb \ 41 | --save_checkpoint \ 42 | --method dino \ 43 | --proj_output_dim 256 \ 44 | --proj_hidden_dim 2048 \ 45 | --num_prototypes 4096 \ 46 | --base_tau_momentum 0.9995 \ 47 | --final_tau_momentum 1.0 \ 48 | --momentum_classifier 49 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/launch_comparison.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create environment 4 | module load python/3.8 5 | virtualenv env 6 | 7 | source env/bin/activate 8 | 9 | pip install -r ../../../requirements.txt 10 | 11 | # Launch jobs 12 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./simclr.sh 13 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./mocov2plus.sh 14 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./byol.sh 15 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./barlow.sh 16 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./swav.sh 17 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./dino.sh 18 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./vicreg.sh 19 | 20 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sdsimclr.sh 21 | sbatch 
--gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sdmoco.sh 22 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sdbyol.sh 23 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sdbarlow.sh 24 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sdswav.sh 25 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sddino.sh 26 | sbatch --gres=gpu:1 -c 4 --mem=24G --array=1-5 ./sdvicreg.sh 27 | 28 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.3 \ 21 | --classifier_lr 0.3 \ 22 | --weight_decay 1e-4 \ 23 | --batch_size 256 \ 24 | --num_workers 4 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.4 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 0.0 0.0 \ 30 | --crop_size 32 \ 31 | --num_crops_per_aug 1 1 \ 32 | --name baselines-moco \ 33 | --project iclr \ 34 | --entity il_group \ 35 | --wandb \ 36 | --save_checkpoint \ 37 | --method mocov2plus \ 38 | --proj_hidden_dim 2048 \ 39 | --queue_size 32768 \ 40 | --temperature 0.2 \ 41 | --base_tau_momentum 0.99 \ 42 | --final_tau_momentum 0.999 \ 43 | --momentum_classifier 44 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/nnbyol.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | 
--gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --lars \ 11 | --grad_clip_lars \ 12 | --eta_lars 0.02 \ 13 | --exclude_bias_n_norm \ 14 | --scheduler warmup_cosine \ 15 | --lr 1.0 \ 16 | --classifier_lr 0.1 \ 17 | --weight_decay 1e-5 \ 18 | --batch_size 256 \ 19 | --num_workers 4 \ 20 | --brightness 0.4 \ 21 | --contrast 0.4 \ 22 | --saturation 0.2 \ 23 | --hue 0.1 \ 24 | --gaussian_prob 0.0 0.0 \ 25 | --solarization_prob 0.0 0.2 \ 26 | --crop_size 32 \ 27 | --num_crops_per_aug 1 1 \ 28 | --name nnbyol-$1 \ 29 | --project solo-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method nnbyol \ 34 | --proj_output_dim 256 \ 35 | --proj_hidden_dim 4096 \ 36 | --pred_hidden_dim 4096 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 1.0 \ 39 | --momentum_classifier 40 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/nnclr.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --lars \ 11 | --grad_clip_lars \ 12 | --eta_lars 0.02 \ 13 | --exclude_bias_n_norm \ 14 | --scheduler warmup_cosine \ 15 | --lr 0.4 \ 16 | --classifier_lr 0.1 \ 17 | --weight_decay 1e-5 \ 18 | --batch_size 256 \ 19 | --num_workers 4 \ 20 | --brightness 0.4 \ 21 | --contrast 0.4 \ 22 | --saturation 0.2 \ 23 | --hue 0.1 \ 24 | --gaussian_prob 0.0 0.0 \ 25 | --solarization_prob 0.0 0.2 \ 26 | --crop_size 32 \ 27 | --num_crops_per_aug 1 1 \ 28 | --name nnclr-$1 \ 29 | --project solo-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method nnclr \ 34 | --temperature 0.2 \ 35 | --proj_hidden_dim 2048 \ 36 | --pred_hidden_dim 4096 \ 37 | --proj_output_dim 256 \ 38 | --queue_size 65536 39 | 
-------------------------------------------------------------------------------- /bash_files/pretrain/cifar/nnsiam.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.5 \ 12 | --classifier_lr 0.1 \ 13 | --weight_decay 1e-5 \ 14 | --batch_size 256 \ 15 | --num_workers 4 \ 16 | --brightness 0.4 \ 17 | --contrast 0.4 \ 18 | --saturation 0.4 \ 19 | --hue 0.1 \ 20 | --gaussian_prob 0.0 0.0 \ 21 | --crop_size 32 \ 22 | --num_crops_per_aug 1 1 \ 23 | --zero_init_residual \ 24 | --name nnsiam-$1 \ 25 | --project solo-learn \ 26 | --entity unitn-mhug \ 27 | --wandb \ 28 | --save_checkpoint \ 29 | --method nnsiam \ 30 | --proj_hidden_dim 2048 \ 31 | --pred_hidden_dim 4096 \ 32 | --proj_output_dim 2048 33 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/ressl.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.05 \ 12 | --classifier_lr 0.1 \ 13 | --weight_decay 1e-4 \ 14 | --batch_size 256 \ 15 | --num_workers 4 \ 16 | --brightness 0.4 0.0 \ 17 | --contrast 0.4 0.0 \ 18 | --saturation 0.2 0.0 \ 19 | --hue 0.1 0.0 \ 20 | --gaussian_prob 0.0 0.0 \ 21 | --solarization_prob 0.0 0.0 \ 22 | --crop_size 32 \ 23 | --num_crops_per_aug 1 1 \ 24 | --name ressl-$1 \ 25 | --project solo-learn \ 26 | --entity unitn-mhug \ 27 | --wandb \ 28 | --save_checkpoint \ 29 | --method ressl \ 30 | --proj_output_dim 256 \ 31 | --proj_hidden_dim 4096 \ 32 | 
--base_tau_momentum 0.99 \ 33 | --final_tau_momentum 1.0 \ 34 | --momentum_classifier 35 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdbarlow.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source env/bin/activate 5 | 6 | cp -r datasets $SLURM_TMPDIR 7 | 8 | python ../../../main_pretrain.py \ 9 | --dataset cifar100 \ 10 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 11 | --backbone resnet18 \ 12 | --data_dir $SLURM_TMPDIR/datasets \ 13 | --max_epochs 1000 \ 14 | --gpus 0 \ 15 | --accelerator gpu \ 16 | --precision 16 \ 17 | --optimizer sgd \ 18 | --lars \ 19 | --grad_clip_lars \ 20 | --eta_lars 0.02 \ 21 | --exclude_bias_n_norm \ 22 | --scheduler warmup_cosine \ 23 | --lr 0.3 \ 24 | --weight_decay 1e-4 \ 25 | --batch_size 256 \ 26 | --num_workers 4 \ 27 | --brightness 0.4 \ 28 | --contrast 0.4 \ 29 | --saturation 0.2 \ 30 | --hue 0.1 \ 31 | --gaussian_prob 0.0 0.0 \ 32 | --solarization_prob 0.0 0.2 \ 33 | --crop_size 32 \ 34 | --num_crops_per_aug 1 1 \ 35 | --name sdbarlow \ 36 | --project iclr \ 37 | --entity il_group \ 38 | --wandb \ 39 | --save_checkpoint \ 40 | --method sdbarlow \ 41 | --proj_hidden_dim 2048 \ 42 | --proj_output_dim 2048 \ 43 | --scale_loss 0.1 \ 44 | --voc_size 13 \ 45 | --message_size 5000 \ 46 | --tau_online 1. 
\ 47 | --tau_target 0.99 48 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdbyol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 1.0 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-5 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --brightness 0.4 \ 30 | --contrast 0.4 \ 31 | --saturation 0.2 \ 32 | --hue 0.1 \ 33 | --gaussian_prob 0.0 0.0 \ 34 | --solarization_prob 0.0 0.2 \ 35 | --crop_size 32 \ 36 | --num_crops_per_aug 1 1 \ 37 | --name sdbyol \ 38 | --project iclr \ 39 | --entity il_group \ 40 | --wandb \ 41 | --save_checkpoint \ 42 | --method sdbyol \ 43 | --proj_output_dim 256 \ 44 | --proj_hidden_dim 4096 \ 45 | --pred_hidden_dim 4096 \ 46 | --base_tau_momentum 0.99 \ 47 | --final_tau_momentum 1.0 \ 48 | --momentum_classifier \ 49 | --voc_size 13 \ 50 | --message_size 5000 \ 51 | --tau_online 1 \ 52 | --tau_target 1 53 | 54 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdbyol_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source /home/mila/l/lavoiems/env/bin/activate 6 | 7 | python ../../../main_pretrain.py \ 8 | --dataset cifar100 \ 9 | --backbone resnet50 \ 10 | --data_dir ./datasets \ 11 | --checkpoint_dir $SCRATCH/cifar_rn50_scale \ 12 | 
--max_epochs 1000 \ 13 | --gpus 0 \ 14 | --accelerator gpu \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.001 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.5 \ 23 | --classifier_lr 0.3 \ 24 | --weight_decay 1e-4 \ 25 | --batch_size 256 \ 26 | --num_workers 6 \ 27 | --brightness 0.4 \ 28 | --contrast 0.4 \ 29 | --saturation 0.2 \ 30 | --hue 0.1 \ 31 | --gaussian_prob 0.0 0.0 \ 32 | --solarization_prob 0.0 0.2 \ 33 | --crop_size 32 \ 34 | --num_crops_per_aug 1 1 \ 35 | --name sdbyol-z:4096 \ 36 | --project syn_ssl \ 37 | --entity lavoiems \ 38 | --wandb \ 39 | --save_checkpoint \ 40 | --method sdbyol \ 41 | --proj_output_dim 256 \ 42 | --proj_hidden_dim 4096 \ 43 | --pred_hidden_dim 4096 \ 44 | --base_tau_momentum 0.99 \ 45 | --final_tau_momentum 1.0 \ 46 | --momentum_classifier \ 47 | --voc_size $2 \ 48 | --message_size $1 \ 49 | --tau_online $3 \ 50 | --tau_target $3 \ 51 | 52 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sddino.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 0.3 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-6 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --brightness 0.4 \ 30 | --contrast 0.4 \ 31 | --saturation 0.2 \ 32 | --hue 0.1 \ 33 | --gaussian_prob 0.0 0.0 \ 34 | --solarization_prob 0.0 0.2 \ 35 
| --crop_size 32 \ 36 | --num_crops_per_aug 1 1 \ 37 | --name sddino \ 38 | --project iclr \ 39 | --entity il_group \ 40 | --wandb \ 41 | --save_checkpoint \ 42 | --method sddino \ 43 | --proj_output_dim 256 \ 44 | --proj_hidden_dim 2048 \ 45 | --num_prototypes 4096 \ 46 | --base_tau_momentum 0.9995 \ 47 | --final_tau_momentum 1.0 \ 48 | --momentum_classifier \ 49 | --voc_size 13 \ 50 | --message_size 5000 \ 51 | --tau_online 1.0 \ 52 | --tau_target 1.0 53 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdmoco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.3 \ 21 | --classifier_lr 0.3 \ 22 | --weight_decay 1e-4 \ 23 | --batch_size 256 \ 24 | --num_workers 4 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.4 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 0.0 0.0 \ 30 | --crop_size 32 \ 31 | --num_crops_per_aug 1 1 \ 32 | --name sdmoco \ 33 | --project iclr \ 34 | --entity il_group \ 35 | --wandb \ 36 | --save_checkpoint \ 37 | --method sdmoco \ 38 | --proj_hidden_dim 2048 \ 39 | --queue_size 32768 \ 40 | --temperature 0.2 \ 41 | --base_tau_momentum 0.99 \ 42 | --final_tau_momentum 0.999 \ 43 | --momentum_classifier \ 44 | --voc_size 13 \ 45 | --message_size 5000 \ 46 | --tau_online 0.04 \ 47 | --tau_target 0.01 48 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdsimclr.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 0.4 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-5 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --crop_size 32 \ 30 | --brightness 0.8 \ 31 | --contrast 0.8 \ 32 | --saturation 0.8 \ 33 | --hue 0.2 \ 34 | --gaussian_prob 0.0 0.0 \ 35 | --crop_size 32 \ 36 | --num_crops_per_aug 1 1 \ 37 | --name sdsimclr \ 38 | --project iclr \ 39 | --entity il_group \ 40 | --wandb \ 41 | --save_checkpoint \ 42 | --method sdsimclr \ 43 | --temperature 0.2 \ 44 | --proj_hidden_dim 2048 \ 45 | --proj_output_dim 256 \ 46 | --voc_size 13 \ 47 | --message_size 5000 \ 48 | --tau_online 0.17 \ 49 | --tau_target 0.78 50 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdswav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --scheduler warmup_cosine \ 23 | --lr 0.6 \ 24 
| --min_lr 0.0006 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-6 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --crop_size 32 \ 30 | --brightness 0.8 \ 31 | --contrast 0.8 \ 32 | --saturation 0.8 \ 33 | --hue 0.2 \ 34 | --gaussian_prob 0.0 0.0 \ 35 | --num_crops_per_aug 1 1 \ 36 | --name sdswav \ 37 | --save_checkpoint \ 38 | --entity il_group \ 39 | --project iclr \ 40 | --wandb \ 41 | --method sdswav \ 42 | --proj_hidden_dim 2048 \ 43 | --queue_size 3840 \ 44 | --proj_output_dim 128 \ 45 | --num_prototypes 3000 \ 46 | --epoch_queue_starts 50 \ 47 | --freeze_prototypes_epochs 2 \ 48 | --voc_size 13 \ 49 | --message_size 5000 \ 50 | --tau_online 0.85 \ 51 | --tau_target 1.5 52 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/sdvicreg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 0.3 \ 25 | --weight_decay 1e-4 \ 26 | --batch_size 256 \ 27 | --num_workers 4 \ 28 | --crop_size 32 \ 29 | --min_scale 0.2 \ 30 | --brightness 0.4 \ 31 | --contrast 0.4 \ 32 | --saturation 0.2 \ 33 | --hue 0.1 \ 34 | --solarization_prob 0.1 \ 35 | --gaussian_prob 0.0 0.0 \ 36 | --crop_size 32 \ 37 | --num_crops_per_aug 1 1 \ 38 | --name sdvicreg \ 39 | --project iclr \ 40 | --entity il_group \ 41 | --wandb \ 42 | --save_checkpoint \ 43 | --method sdvicreg \ 44 | --proj_hidden_dim 2048 \ 45 | --proj_output_dim 2048 \ 46 |
--sim_loss_weight 25.0 \ 47 | --var_loss_weight 25.0 \ 48 | --cov_loss_weight 1.0 \ 49 | --message_size 5000 \ 50 | --voc_size 13 \ 51 | --tau_online 1 52 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/simclr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | source env/bin/activate 5 | 6 | echo "eta_lr 0.02" 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 0.4 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-5 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --crop_size 32 \ 30 | --brightness 0.8 \ 31 | --contrast 0.8 \ 32 | --saturation 0.8 \ 33 | --hue 0.2 \ 34 | --gaussian_prob 0.0 0.0 \ 35 | --crop_size 32 \ 36 | --num_crops_per_aug 1 1 \ 37 | --name baselines-simclr \ 38 | --project iclr \ 39 | --entity il_group \ 40 | --wandb \ 41 | --save_checkpoint \ 42 | --method simclr \ 43 | --temperature 0.2 \ 44 | --proj_hidden_dim 2048 \ 45 | --proj_output_dim 256 46 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/simsiam.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.5 \ 12 | --classifier_lr 0.1 \ 13 | --weight_decay 1e-5 \ 14 | 
--batch_size 256 \ 15 | --num_workers 4 \ 16 | --crop_size 32 \ 17 | --brightness 0.4 \ 18 | --contrast 0.4 \ 19 | --saturation 0.4 \ 20 | --hue 0.1 \ 21 | --gaussian_prob 0.0 0.0 \ 22 | --crop_size 32 \ 23 | --num_crops_per_aug 1 1 \ 24 | --zero_init_residual \ 25 | --name simsiam-$1 \ 26 | --project solo-learn \ 27 | --entity unitn-mhug \ 28 | --wandb \ 29 | --save_checkpoint \ 30 | --method simsiam \ 31 | --proj_hidden_dim 2048 \ 32 | --pred_hidden_dim 512 \ 33 | --proj_output_dim 2048 34 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/supcon.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --lars \ 11 | --grad_clip_lars \ 12 | --eta_lars 0.02 \ 13 | --exclude_bias_n_norm \ 14 | --scheduler warmup_cosine \ 15 | --lr 0.4 \ 16 | --classifier_lr 0.1 \ 17 | --weight_decay 1e-5 \ 18 | --batch_size 256 \ 19 | --num_workers 4 \ 20 | --crop_size 32 \ 21 | --brightness 0.8 \ 22 | --contrast 0.8 \ 23 | --saturation 0.8 \ 24 | --hue 0.2 \ 25 | --gaussian_prob 0.0 0.0 \ 26 | --crop_size 32 \ 27 | --num_crops_per_aug 1 1 \ 28 | --name supcon-$1 \ 29 | --project solo-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method supcon \ 34 | --temperature 0.2 \ 35 | --proj_hidden_dim 2048 \ 36 | --proj_output_dim 256 37 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/swav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir 
/network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --scheduler warmup_cosine \ 23 | --lr 0.6 \ 24 | --min_lr 0.0006 \ 25 | --classifier_lr 0.1 \ 26 | --weight_decay 1e-6 \ 27 | --batch_size 256 \ 28 | --num_workers 4 \ 29 | --crop_size 32 \ 30 | --brightness 0.8 \ 31 | --contrast 0.8 \ 32 | --saturation 0.8 \ 33 | --hue 0.2 \ 34 | --gaussian_prob 0.0 0.0 \ 35 | --num_crops_per_aug 1 1 \ 36 | --name baselines-swav \ 37 | --save_checkpoint \ 38 | --entity il_group \ 39 | --project iclr \ 40 | --wandb \ 41 | --method swav \ 42 | --proj_hidden_dim 2048 \ 43 | --queue_size 3840 \ 44 | --proj_output_dim 128 \ 45 | --num_prototypes 3000 \ 46 | --epoch_queue_starts 50 \ 47 | --freeze_prototypes_epochs 2 48 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/vibcreg.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --lars \ 11 | --grad_clip_lars \ 12 | --eta_lars 0.02 \ 13 | --exclude_bias_n_norm \ 14 | --scheduler warmup_cosine \ 15 | --lr 0.3 \ 16 | --weight_decay 1e-4 \ 17 | --batch_size 256 \ 18 | --num_workers 4 \ 19 | --crop_size 32 \ 20 | --min_scale 0.2 \ 21 | --brightness 0.4 \ 22 | --contrast 0.4 \ 23 | --saturation 0.2 \ 24 | --hue 0.1 \ 25 | --solarization_prob 0.1 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --crop_size 32 \ 28 | --num_crops_per_aug 1 1 \ 29 | --name vibcreg-$1 \ 30 | --project solo-learn \ 31 | --entity unitn-mhug \ 32 | --save_checkpoint \ 33 | --method vibcreg \ 34 | --proj_hidden_dim 2048 \ 35 | 
--proj_output_dim 2048 \ 36 | --sim_loss_weight 25.0 \ 37 | --var_loss_weight 25.0 \ 38 | --cov_loss_weight 200.0 \ 39 | --iternorm \ 40 | --wandb 41 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/vicreg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | module load python/3.8 4 | 5 | source env/bin/activate 6 | 7 | cp -r datasets $SLURM_TMPDIR 8 | 9 | python ../../../main_pretrain.py \ 10 | --dataset cifar100 \ 11 | --checkpoint_dir /network/scratch/l/lavoiems/baselines \ 12 | --backbone resnet18 \ 13 | --data_dir $SLURM_TMPDIR/datasets \ 14 | --max_epochs 1000 \ 15 | --gpus 0 \ 16 | --accelerator gpu \ 17 | --precision 16 \ 18 | --optimizer sgd \ 19 | --lars \ 20 | --grad_clip_lars \ 21 | --eta_lars 0.02 \ 22 | --exclude_bias_n_norm \ 23 | --scheduler warmup_cosine \ 24 | --lr 0.3 \ 25 | --weight_decay 1e-4 \ 26 | --batch_size 256 \ 27 | --num_workers 4 \ 28 | --crop_size 32 \ 29 | --min_scale 0.2 \ 30 | --brightness 0.4 \ 31 | --contrast 0.4 \ 32 | --saturation 0.2 \ 33 | --hue 0.1 \ 34 | --solarization_prob 0.1 \ 35 | --gaussian_prob 0.0 0.0 \ 36 | --crop_size 32 \ 37 | --num_crops_per_aug 1 1 \ 38 | --name baselines-vicreg \ 39 | --project iclr \ 40 | --entity il_group \ 41 | --wandb \ 42 | --save_checkpoint \ 43 | --method vicreg \ 44 | --proj_hidden_dim 2048 \ 45 | --proj_output_dim 2048 \ 46 | --sim_loss_weight 25.0 \ 47 | --var_loss_weight 25.0 \ 48 | --cov_loss_weight 1.0 49 | -------------------------------------------------------------------------------- /bash_files/pretrain/cifar/wmse.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --data_dir ./datasets \ 5 | --max_epochs 1000 \ 6 | --gpus 0 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --num_workers 4 \ 10 | --optimizer adam \ 11 | --scheduler 
warmup_cosine \ 12 | --warmup_epochs 2 \ 13 | --lr 3e-3 \ 14 | --warmup_start_lr 0 \ 15 | --classifier_lr 3e-3 \ 16 | --weight_decay 1e-6 \ 17 | --batch_size 256 \ 18 | --crop_size 32 \ 19 | --brightness 0.4 \ 20 | --contrast 0.4 \ 21 | --saturation 0.2 \ 22 | --hue 0.1 \ 23 | --gaussian_prob 0.0 0.0 \ 24 | --crop_size 32 \ 25 | --num_crops_per_aug 1 1 \ 26 | --min_scale 0.2 \ 27 | --name wmse-$1 \ 28 | --wandb \ 29 | --save_checkpoint \ 30 | --project solo-learn \ 31 | --entity unitn-mhug \ 32 | --method wmse \ 33 | --proj_output_dim 64 \ 34 | --whitening_size 128 35 | -------------------------------------------------------------------------------- /bash_files/pretrain/custom/byol.sh: -------------------------------------------------------------------------------- 1 | # Train without labels. 2 | # To train with labels, simply remove --no_labels 3 | # --val_dir is optional and will expect a directory with subfolder (classes) 4 | # --dali flag is also supported 5 | python3 ../../../main_pretrain.py \ 6 | --dataset custom \ 7 | --backbone resnet18 \ 8 | --data_dir PATH_TO_DIR \ 9 | --train_dir PATH_TO_TRAIN_DIR \ 10 | --no_labels \ 11 | --max_epochs 400 \ 12 | --gpus 0,1 \ 13 | --accelerator gpu \ 14 | --strategy ddp \ 15 | --sync_batchnorm \ 16 | --precision 16 \ 17 | --optimizer sgd \ 18 | --lars \ 19 | --grad_clip_lars \ 20 | --eta_lars 0.02 \ 21 | --exclude_bias_n_norm \ 22 | --scheduler warmup_cosine \ 23 | --lr 1.0 \ 24 | --classifier_lr 0.1 \ 25 | --weight_decay 1e-5 \ 26 | --batch_size 128 \ 27 | --num_workers 4 \ 28 | --brightness 0.4 \ 29 | --contrast 0.4 \ 30 | --saturation 0.2 \ 31 | --hue 0.1 \ 32 | --color_jitter_prob 0.8 \ 33 | --gray_scale_prob 0.2 \ 34 | --horizontal_flip_prob 0.5 \ 35 | --gaussian_prob 1.0 0.1 \ 36 | --solarization_prob 0.0 0.2 \ 37 | --num_crops_per_aug 1 1 \ 38 | --name byol-400ep-custom \ 39 | --entity unitn-mhug \ 40 | --project solo-learn \ 41 | --wandb \ 42 | --save_checkpoint \ 43 | --method byol \ 44 | --proj_output_dim 256 \ 45
| --proj_hidden_dim 4096 \ 46 | --pred_hidden_dim 8192 \ 47 | --base_tau_momentum 0.99 \ 48 | --final_tau_momentum 1.0 49 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 4 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.3 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 128 \ 23 | --dali \ 24 | --brightness 0.4 \ 25 | --contrast 0.4 \ 26 | --saturation 0.2 \ 27 | --hue 0.1 \ 28 | --gaussian_prob 1.0 0.1 \ 29 | --solarization_prob 0.0 0.2 \ 30 | --num_crops_per_aug 1 1 \ 31 | --name barlow-400ep-imagenet100 \ 32 | --entity unitn-mhug \ 33 | --project solo-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --scale_loss 0.1 \ 37 | --method barlow_twins \ 38 | --proj_hidden_dim 2048 \ 39 | --proj_output_dim 2048 40 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/byol.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler 
warmup_cosine \ 19 | --lr 0.5 \ 20 | --classifier_lr 0.1 \ 21 | --weight_decay 1e-5 \ 22 | --batch_size 128 \ 23 | --num_workers 4 \ 24 | --dali \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 1.0 0.1 \ 30 | --solarization_prob 0.0 0.2 \ 31 | --num_crops_per_aug 1 1 \ 32 | --name byol-400ep-imagenet100 \ 33 | --entity unitn-mhug \ 34 | --project solo-learn \ 35 | --wandb \ 36 | --save_checkpoint \ 37 | --method byol \ 38 | --proj_output_dim 256 \ 39 | --proj_hidden_dim 4096 \ 40 | --pred_hidden_dim 8192 \ 41 | --base_tau_momentum 0.99 \ 42 | --final_tau_momentum 1.0 43 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/deepclusterv2.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.6 \ 20 | --min_lr 0.0006 \ 21 | --warmup_start_lr 0.0 \ 22 | --warmup_epochs 11 \ 23 | --classifier_lr 0.1 \ 24 | --weight_decay 1e-6 \ 25 | --batch_size 128 \ 26 | --num_workers 4 \ 27 | --dali \ 28 | --encode_indexes_into_labels \ 29 | --brightness 0.8 \ 30 | --contrast 0.8 \ 31 | --saturation 0.8 \ 32 | --hue 0.2 \ 33 | --num_crops_per_aug 2 \ 34 | --name deepclusterv2-400ep-imagenet100 \ 35 | --entity unitn-mhug \ 36 | --project solo-learn \ 37 | --wandb \ 38 | --save_checkpoint \ 39 | --method deepclusterv2 \ 40 | --proj_hidden_dim 2048 \ 41 | --proj_output_dim 128 \ 42 | --num_prototypes 3000 3000 3000 43 | 
-------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/dino.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --classifier_lr 0.1 \ 21 | --weight_decay 1e-6 \ 22 | --batch_size 128 \ 23 | --num_workers 4 \ 24 | --dali \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 1.0 0.1 \ 30 | --solarization_prob 0.0 0.2 \ 31 | --num_crops_per_aug 1 1 \ 32 | --name dino-400ep-imagenet100 \ 33 | --entity unitn-mhug \ 34 | --project solo-learn \ 35 | --wandb \ 36 | --save_checkpoint \ 37 | --method dino \ 38 | --proj_output_dim 256 \ 39 | --proj_hidden_dim 2048 \ 40 | --num_prototypes 4096 \ 41 | --base_tau_momentum 0.9995 \ 42 | --final_tau_momentum 1.0 43 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/dino_vit.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone vit_tiny \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/test \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 32 \ 13 | --optimizer adamw \ 14 | --scheduler warmup_cosine \ 15 | --lr 0.005 \ 16 | --warmup_start_lr 1e-6 \ 17 | --classifier_lr 3e-3 \ 18 | --weight_decay 1e-4 \ 19 | 
--batch_size 64 \ 20 | --num_workers 4 \ 21 | --brightness 0.4 \ 22 | --contrast 0.4 \ 23 | --saturation 0.2 \ 24 | --hue 0.1 \ 25 | --gaussian_prob 1.0 0.1 \ 26 | --solarization_prob 0.0 0.2 \ 27 | --num_crops_per_aug 1 1 \ 28 | --name vit_tiny-dino-400ep-imagenet100 \ 29 | --project solo-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method dino \ 34 | --proj_output_dim 256 \ 35 | --proj_hidden_dim 2048 \ 36 | --num_prototypes 65536 \ 37 | --norm_last_layer false \ 38 | --base_tau_momentum 0.9995 \ 39 | --final_tau_momentum 1.0 40 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --scheduler cosine \ 15 | --lr 0.3 \ 16 | --classifier_lr 0.3 \ 17 | --weight_decay 1e-4 \ 18 | --batch_size 128 \ 19 | --num_workers 4 \ 20 | --dali \ 21 | --brightness 0.4 \ 22 | --contrast 0.4 \ 23 | --saturation 0.4 \ 24 | --hue 0.1 \ 25 | --num_crops_per_aug 2 \ 26 | --name mocov2plus-400ep \ 27 | --project solo-learn \ 28 | --entity unitn-mhug \ 29 | --wandb \ 30 | --save_checkpoint \ 31 | --method mocov2plus \ 32 | --proj_hidden_dim 2048 \ 33 | --queue_size 65536 \ 34 | --temperature 0.2 \ 35 | --base_tau_momentum 0.99 \ 36 | --final_tau_momentum 0.999 \ 37 | --momentum_classifier 38 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/multicrop/byol.sh: -------------------------------------------------------------------------------- 1 | # multicrop byol 2 | python3 
../../../../main_pretrain.py \ 3 | --dataset imagenet100 \ 4 | --backbone resnet18 \ 5 | --data_dir /datasets \ 6 | --train_dir imagenet-100/train \ 7 | --val_dir imagenet-100/val \ 8 | --max_epochs 400 \ 9 | --gpus 0,1 \ 10 | --accelerator gpu \ 11 | --strategy ddp \ 12 | --sync_batchnorm \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.5 \ 21 | --classifier_lr 0.1 \ 22 | --weight_decay 1e-5 \ 23 | --batch_size 128 \ 24 | --num_workers 4 \ 25 | --dali \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --gaussian_prob 1.0 0.1 0.0 \ 31 | --solarization_prob 0.0 0.2 0.0 \ 32 | --crop_size 224 224 96 \ 33 | --num_crops_per_aug 1 1 6 \ 34 | --name byol-multicrop-400ep-imagenet100 \ 35 | --project debug \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method byol \ 39 | --proj_output_dim 256 \ 40 | --proj_hidden_dim 4096 \ 41 | --pred_hidden_dim 8192 \ 42 | --base_tau_momentum 0.99 \ 43 | --final_tau_momentum 1.0 44 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/multicrop/simclr.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --num_workers 4 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --crop_size 224 96 \ 28 |
--num_crops_per_aug 2 6 \ 29 | --name multicrop-simclr-400ep-imagenet100 \ 30 | --dali \ 31 | --project solo-learn \ 32 | --entity unitn-mhug \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --method simclr \ 36 | --proj_hidden_dim 2048 \ 37 | --temperature 0.1 38 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/multicrop/supcon.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --num_workers 4 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --crop_size 224 96 \ 28 | --num_crops_per_aug 2 6 \ 29 | --name multicrop-supcon-400ep-imagenet100 \ 30 | --dali \ 31 | --project solo-learn \ 32 | --entity unitn-mhug \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --method supcon \ 36 | --temperature 0.1 \ 37 | --proj_hidden_dim 2048 \ 38 | --supervised 39 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/nnclr.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13
| --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.4 \ 20 | --weight_decay 1e-5 \ 21 | --batch_size 128 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 1.0 0.1 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --num_crops_per_aug 1 1 \ 29 | --num_workers 4 \ 30 | --dali \ 31 | --name nnclr-400ep-imagenet100 \ 32 | --entity unitn-mhug \ 33 | --project solo-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method nnclr \ 37 | --temperature 0.2 \ 38 | --proj_hidden_dim 2048 \ 39 | --pred_hidden_dim 4096 \ 40 | --proj_output_dim 256 \ 41 | --queue_size 65536 42 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/ressl.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler warmup_cosine \ 14 | --lr 0.3 \ 15 | --classifier_lr 0.1 \ 16 | --weight_decay 1e-4 \ 17 | --batch_size 128 \ 18 | --num_workers 4 \ 19 | --dali \ 20 | --brightness 0.4 0.0 \ 21 | --contrast 0.4 0.0 \ 22 | --saturation 0.4 0.0 \ 23 | --hue 0.1 0.0 \ 24 | --gaussian_prob 0.5 0.0 \ 25 | --solarization_prob 0.0 0.0 \ 26 | --num_crops_per_aug 1 1 \ 27 | --name ressl-400ep-imagenet100 \ 28 | --project solo-learn \ 29 | --entity unitn-mhug \ 30 | --wandb \ 31 | --save_checkpoint \ 32 | --method ressl \ 33 | --proj_output_dim 256 \ 34 | --proj_hidden_dim 4096 \ 35 | --base_tau_momentum 0.99 \ 36 | --final_tau_momentum 1.0 \ 37 | --momentum_classifier \ 38 | --temperature_q 0.1 \ 39 | --temperature_k 0.04 40 | 
-------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/simclr.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --num_workers 4 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --num_crops_per_aug 2 \ 28 | --name simclr-400ep-imagenet100 \ 29 | --dali \ 30 | --project solo-learn \ 31 | --entity unitn-mhug \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --method simclr \ 35 | --temperature 0.2 \ 36 | --proj_hidden_dim 2048 37 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/simsiam.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --scheduler warmup_cosine \ 15 | --lr 0.5 \ 16 | --classifier_lr 0.1 \ 17 | --weight_decay 1e-5 \ 18 | --batch_size 128 \ 19 | --num_workers 4 \ 20 | --brightness 0.4 \ 21 | --contrast 0.4 \ 22 | --saturation 0.4 \ 23 | --hue 0.1 \ 24 | --num_crops_per_aug 2 \ 25 | --zero_init_residual \ 26 | --name simsiam-400ep-imagenet100 \ 
27 | --dali \ 28 | --entity unitn-mhug \ 29 | --project solo-learn \ 30 | --wandb \ 31 | --save_checkpoint \ 32 | --method simsiam \ 33 | --proj_hidden_dim 2048 \ 34 | --pred_hidden_dim 512 \ 35 | --proj_output_dim 2048 36 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/supcon.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --num_workers 4 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --num_crops_per_aug 2 \ 28 | --name supcon-400ep-imagenet100 \ 29 | --dali \ 30 | --project solo-learn \ 31 | --entity unitn-mhug \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --method supcon \ 35 | --temperature 0.2 \ 36 | --proj_hidden_dim 2048 37 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/swav.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler 
warmup_cosine \ 19 | --lr 0.6 \ 20 | --min_lr 0.0006 \ 21 | --classifier_lr 0.1 \ 22 | --weight_decay 1e-6 \ 23 | --batch_size 128 \ 24 | --num_workers 4 \ 25 | --dali \ 26 | --brightness 0.8 \ 27 | --contrast 0.8 \ 28 | --saturation 0.8 \ 29 | --hue 0.2 \ 30 | --num_crops_per_aug 2 \ 31 | --name swav-400ep-imagenet100 \ 32 | --entity unitn-mhug \ 33 | --project solo-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method swav \ 37 | --proj_hidden_dim 2048 \ 38 | --queue_size 3840 \ 39 | --proj_output_dim 128 \ 40 | --num_prototypes 3000 \ 41 | --epoch_queue_starts 50 \ 42 | --freeze_prototypes_epochs 2 43 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/vibcreg.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --num_workers 4 \ 23 | --dali \ 24 | --min_scale 0.2 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --solarization_prob 0.1 \ 30 | --num_crops_per_aug 2 \ 31 | --name vibcreg-400ep-imagenet100 \ 32 | --entity unitn-mhug \ 33 | --project solo-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method vibcreg \ 37 | --proj_hidden_dim 2048 \ 38 | --proj_output_dim 2048 \ 39 | --sim_loss_weight 25.0 \ 40 | --var_loss_weight 25.0 \ 41 | --cov_loss_weight 200.0 \ 42 | --iternorm 43 | -------------------------------------------------------------------------------- 
/bash_files/pretrain/imagenet-100/vicreg.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --max_epochs 400 \ 8 | --gpus 0,1 \ 9 | --accelerator gpu \ 10 | --strategy ddp \ 11 | --sync_batchnorm \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --grad_clip_lars \ 16 | --eta_lars 0.02 \ 17 | --exclude_bias_n_norm \ 18 | --scheduler warmup_cosine \ 19 | --lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --num_workers 4 \ 23 | --dali \ 24 | --min_scale 0.2 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --solarization_prob 0.1 \ 30 | --num_crops_per_aug 2 \ 31 | --name vicreg-400ep-imagenet100 \ 32 | --entity unitn-mhug \ 33 | --project solo-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method vicreg \ 37 | --proj_hidden_dim 2048 \ 38 | --proj_output_dim 2048 \ 39 | --sim_loss_weight 25.0 \ 40 | --var_loss_weight 25.0 \ 41 | --cov_loss_weight 1.0 42 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet-100/wmse.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --data_dir /datasets \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/test \ 7 | --max_epochs 400 \ 8 | --precision 16 \ 9 | --gpus 0,1 \ 10 | --accelerator gpu \ 11 | --strategy ddp \ 12 | --num_workers 4 \ 13 | --optimizer adam \ 14 | --scheduler warmup_cosine \ 15 | --warmup_epochs 2 \ 16 | --lr 2e-3 \ 17 | --warmup_start_lr 0 \ 18 | --classifier_lr 3e-3 \ 19 | --weight_decay 1e-6 \ 20 | --batch_size 256 \ 21 | --brightness 0.8 \ 22 | --contrast 0.8 \ 23 | --saturation 0.8 \ 24 | --hue 0.2 \ 
25 | --gaussian_prob 0.2 \ 26 | --min_scale 0.08 \ 27 | --num_crops_per_aug 2 \ 28 | --dali \ 29 | --wandb \ 30 | --save_checkpoint \ 31 | --name wmse-imagenet100 \ 32 | --project solo-learn \ 33 | --entity unitn-mhug \ 34 | --method wmse \ 35 | --proj_output_dim 64 \ 36 | --whitening_size 128 37 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_pretrain.py \ 2 | --dataset imagenet \ 3 | --backbone resnet50 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet/train \ 6 | --val_dir imagenet/val \ 7 | --max_epochs 100 \ 8 | --gpus 0,1,2,3 \ 9 | --accelerator ddp \ 10 | --sync_batchnorm \ 11 | --num_workers 4 \ 12 | --precision 16 \ 13 | --optimizer sgd \ 14 | --lars \ 15 | --eta_lars 0.001 \ 16 | --exclude_bias_n_norm \ 17 | --scheduler warmup_cosine \ 18 | --lr 0.8 \ 19 | --weight_decay 1.5e-6 \ 20 | --batch_size 64 \ 21 | --dali \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 1.0 0.1 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --num_crops_per_aug 1 1 \ 29 | --name barlow-resnet50-imagenet-100ep \ 30 | --entity unitn-mhug \ 31 | --project solo-learn \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --lamb 0.0051 \ 35 | --scale_loss 0.048 \ 36 | --method barlow_twins \ 37 | --proj_hidden_dim 4096 \ 38 | --proj_output_dim 4096 \ 39 | --auto_resume 40 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/byol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_PATH=$1 4 | 5 | ../../../prepare_data.sh TEST 6 | 7 | source ~/env/bin/activate 8 | 9 | python3 ../../../main_pretrain.py \ 10 | --dataset imagenet \ 11 | --backbone resnet50 \ 12 | --checkpoint_dir=$ROOT_PATH \ 13 | --data_dir $SLURM_TMPDIR/data \ 14 | 
--train_dir $SLURM_TMPDIR/data/train/ \ 15 | --val_dir $SLURM_TMPDIR/data/val/ \ 16 | --max_epochs 512 \ 17 | --gpus 0,1,2,3 \ 18 | --accelerator gpu \ 19 | --strategy ddp \ 20 | --sync_batchnorm \ 21 | --precision 16 \ 22 | --optimizer sgd \ 23 | --lars \ 24 | --eta_lars 0.001 \ 25 | --exclude_bias_n_norm \ 26 | --scheduler warmup_cosine \ 27 | --lr 0.45 \ 28 | --accumulate_grad_batches 16 \ 29 | --classifier_lr 0.2 \ 30 | --weight_decay 1e-6 \ 31 | --batch_size 128 \ 32 | --num_workers 4 \ 33 | --dali \ 34 | --brightness 0.4 \ 35 | --contrast 0.4 \ 36 | --saturation 0.2 \ 37 | --hue 0.1 \ 38 | --gaussian_prob 1.0 0.1 \ 39 | --solarization_prob 0.0 0.2 \ 40 | --num_crops_per_aug 1 1 \ 41 | --name byol-resnet50-imagenet-1000epochs-FULL \ 42 | --entity il_group \ 43 | --project VIL \ 44 | --wandb \ 45 | --offline \ 46 | --save_checkpoint \ 47 | --method byol \ 48 | --proj_output_dim 256 \ 49 | --proj_hidden_dim 4096 \ 50 | --pred_hidden_dim 4096 \ 51 | --base_tau_momentum 0.99 \ 52 | --final_tau_momentum 1.0 \ 53 | --momentum_classifier 54 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/byol_orion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_PATH=$1 3 | ROOT_PATH=$2 4 | DATASET=$3 5 | EXPNAME=byol-resnet50-$DATASET 6 | 7 | orion -vv hunt -n $EXPNAME --config=../../../orion_config.yaml \ 8 | python3 ../../../main_pretrain.py \ 9 | --wandb \ 10 | --name="${EXPNAME}-{trial.hash_params}" \ 11 | --group=${EXPNAME} \ 12 | --entity il_group \ 13 | --project VIL \ 14 | --save_checkpoint \ 15 | --checkpoint_dir="${ROOT_PATH}/{trial.hash_params}" \ 16 | --auto_resume \ 17 | --data_dir=${DATA_PATH} \ 18 | --train_dir ILSVRC2012/train \ 19 | --val_dir ILSVRC2012/val \ 20 | --max_epochs~'fidelity(low=100,high=1000,base=4)' \ 21 | --dataset imagenet \ 22 | --backbone resnet50 \ 23 | --num_workers=${SLURM_CPUS_PER_TASK} \ 24 | --gpus 0,1 \ 25 | 
--accelerator gpu \ 26 | --strategy ddp \ 27 | --sync_batchnorm \ 28 | --precision 16 \ 29 | --optimizer sgd \ 30 | --lars \ 31 | --eta_lars 0.001 \ 32 | --exclude_bias_n_norm \ 33 | --scheduler warmup_cosine \ 34 | --lr 0.45 \ 35 | --accumulate_grad_batches 16 \ 36 | --classifier_lr 0.2 \ 37 | --weight_decay 1e-6 \ 38 | --batch_size 128 \ 39 | --num_workers 4 \ 40 | --dali \ 41 | --brightness 0.4 \ 42 | --contrast 0.4 \ 43 | --saturation 0.2 \ 44 | --hue 0.1 \ 45 | --gaussian_prob 1.0 0.1 \ 46 | --solarization_prob 0.0 0.2 \ 47 | --num_crops_per_aug 1 1 \ 48 | --method byol \ 49 | --proj_output_dim 256 \ 50 | --proj_hidden_dim 4096 \ 51 | --pred_hidden_dim 4096 \ 52 | --base_tau_momentum 0.99 \ 53 | --final_tau_momentum 1.0 \ 54 | --momentum_classifier 55 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | python3 main_pretrain.py \ 2 | --dataset imagenet \ 3 | --backbone resnet50 \ 4 | --data_dir /data/datasets \ 5 | --train_dir imagenet/train \ 6 | --val_dir imagenet/val \ 7 | --max_epochs 100 \ 8 | --gpus 0,1 \ 9 | --accelerator ddp \ 10 | --sync_batchnorm \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler warmup_cosine \ 14 | --lr 0.3 \ 15 | --classifier_lr 0.4 \ 16 | --weight_decay 3e-5 \ 17 | --batch_size 128 \ 18 | --num_workers 5 \ 19 | --dali \ 20 | --brightness 0.4 \ 21 | --contrast 0.4 \ 22 | --saturation 0.2 \ 23 | --hue 0.1 \ 24 | --gaussian_prob 1.0 0.1 \ 25 | --solarization_prob 0.0 0.2 \ 26 | --num_crops_per_aug 1 1 \ 27 | --name mocov2plus-resnet50-imagenet-100ep \ 28 | --project solo-learn \ 29 | --entity unitn-mhug \ 30 | --wandb \ 31 | --save_checkpoint \ 32 | --method mocov2plus \ 33 | --proj_hidden_dim 2048 \ 34 | --queue_size 65536 \ 35 | --temperature 0.1 \ 36 | --base_tau_momentum 0.99 \ 37 | --final_tau_momentum 0.999 \ 38 | --momentum_classifier \ 39 | --auto_resume 
40 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/orion_config.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | algorithms: 3 | random: 4 | seed: 421 5 | 6 | strategy: StubParallelStrategy 7 | max_broken: 200 8 | working_dir: . 9 | 10 | evc: 11 | enable: True 12 | algorithm_change: False 13 | cli_change_type: noeffect 14 | code_change_type: noeffect 15 | config_change_type: noeffect 16 | orion_version_change: False 17 | ignore_code_changes: True 18 | manual_resolution: False 19 | non_monitored_arguments: [] 20 | 21 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/sdbyol.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ROOT_PATH=$SCRATCH/sdbyol_search_IN 5 | 6 | ../../../prepare_data.sh TEST 7 | source ~/env/bin/activate 8 | 9 | orion hunt -n orion_sdbyol_y_mila6 --config=orion_config.yaml \ 10 | python3 ../../../main_pretrain.py \ 11 | --dataset imagenet \ 12 | --backbone resnet50 \ 13 | --checkpoint_dir="$ROOT_PATH/{trial.hash_params}" \ 14 | --data_dir $SLURM_TMPDIR/data \ 15 | --train_dir $SLURM_TMPDIR/data/train/ \ 16 | --val_dir $SLURM_TMPDIR/data/val/ \ 17 | --max_epochs 75 \ 18 | --gpus 0,1 \ 19 | --accelerator gpu \ 20 | --strategy ddp \ 21 | --sync_batchnorm \ 22 | --precision 16 \ 23 | --optimizer sgd \ 24 | --lars \ 25 | --eta_lars 0.001 \ 26 | --exclude_bias_n_norm \ 27 | --scheduler warmup_cosine \ 28 | --lr 0.45 \ 29 | --accumulate_grad_batches 16 \ 30 | --classifier_lr 0.2 \ 31 | --weight_decay 1e-6 \ 32 | --batch_size 128 \ 33 | --num_workers 4 \ 34 | --dali \ 35 | --brightness 0.4 \ 36 | --contrast 0.4 \ 37 | --saturation 0.2 \ 38 | --hue 0.1 \ 39 | --gaussian_prob 1.0 0.1 \ 40 | --solarization_prob 0.0 0.2 \ 41 | --num_crops_per_aug 1 1 \ 42 | --name "orion_sdbyol-{trial.hash_params}" \ 43 | --entity 
il_group \ 44 | --project sem_neurips \ 45 | --wandb \ 46 | --save_checkpoint \ 47 | --method sdbyol \ 48 | --proj_output_dim 256 \ 49 | --proj_hidden_dim 4096 \ 50 | --pred_hidden_dim 4096 \ 51 | --base_tau_momentum 0.99 \ 52 | --final_tau_momentum 1.0 \ 53 | --voc_size=24 \ 54 | --message_size=5000 \ 55 | --tau_online~"loguniform(0.1,5,precision=3)" \ 56 | --tau_target~"loguniform(0.1,5,precision=3)" \ 57 | --momentum_classifier 58 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/sdbyol_800.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ORION 4 | export ORION_DB_TYPE=pickleddb 5 | export ORION_DB_NAME=vil 6 | export ORION_DB_ADDRESS=${1:-$PWD}/orion_db.pkl 7 | 8 | # ENVIRONMENT 9 | ROOT_PATH=$SCRATCH/sdbyol_search_IN_400 10 | module load python/3.8 11 | source ~/env_1.12/bin/activate 12 | 13 | # LOADING DATA 14 | ../../../prepare_data.sh TEST 15 | 16 | # RUNNING SCRIPT 17 | orion hunt -n orion_sdbyol_800 --config=orion_config.yaml \ 18 | python3 ../../../main_pretrain.py \ 19 | --devices 2 \ 20 | --data_dir $SLURM_TMPDIR/data \ 21 | --train_dir $SLURM_TMPDIR/data/train/ \ 22 | --val_dir $SLURM_TMPDIR/data/val/ \ 23 | --checkpoint_dir=$ROOT_PATH \ 24 | --entity il_group \ 25 | --project sem_neurips \ 26 | --dataset imagenet \ 27 | --backbone resnet50 \ 28 | --accumulate_grad_batches 1 \ 29 | --name "sdbyol-ep:800" \ 30 | --max_epochs 800 \ 31 | --batch_size 32 \ 32 | --tau_online~"uniform(0.1,10,precision=2)" \ 33 | --tau_target~"uniform(0.1,10,precision=2)" \ 34 | --voc_size 21 \ 35 | --message_size 5000 \ 36 | --accelerator gpu \ 37 | --strategy ddp \ 38 | --sync_batchnorm \ 39 | --precision 16 \ 40 | --optimizer sgd \ 41 | --lars \ 42 | --eta_lars 0.001 \ 43 | --exclude_bias_n_norm \ 44 | --scheduler warmup_cosine \ 45 | --lr 0.2 \ 46 | --classifier_lr 0.2 \ 47 | --weight_decay 1.5e-6 \ 48 | --num_workers 16 \ 49 | --brightness 0.4 \ 50 
| --contrast 0.4 \ 51 | --saturation 0.2 \ 52 | --hue 0.1 \ 53 | --gaussian_prob 1.0 0.1 \ 54 | --solarization_prob 0.0 0.2 \ 55 | --num_crops_per_aug 1 1 \ 56 | --wandb \ 57 | --offline \ 58 | --save_checkpoint \ 59 | --method sdbyol \ 60 | --proj_output_dim 256 \ 61 | --proj_hidden_dim 4096 \ 62 | --pred_hidden_dim 4096 \ 63 | --base_tau_momentum 0.996 \ 64 | --final_tau_momentum 1.0 \ 65 | --auto_resume \ 66 | --momentum_classifier 67 | 68 | -------------------------------------------------------------------------------- /bash_files/pretrain/imagenet/sdbyol_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_PATH=$1 4 | 5 | ../../../prepare_data.sh TEST 6 | 7 | module load python/3.8 8 | source ~/env/bin/activate 9 | 10 | python3 ../../../main_pretrain.py \ 11 | --dataset imagenet \ 12 | --backbone resnet50 \ 13 | --checkpoint_dir=$ROOT_PATH \ 14 | --data_dir $SLURM_TMPDIR/data \ 15 | --train_dir $SLURM_TMPDIR/data/train/ \ 16 | --val_dir $SLURM_TMPDIR/data/val/ \ 17 | --max_epochs 200 \ 18 | --gpus 0,1 \ 19 | --accelerator gpu \ 20 | --strategy ddp \ 21 | --sync_batchnorm \ 22 | --precision 16 \ 23 | --optimizer sgd \ 24 | --lars \ 25 | --eta_lars 0.001 \ 26 | --exclude_bias_n_norm \ 27 | --scheduler warmup_cosine \ 28 | --lr 0.40 \ 29 | --accumulate_grad_batches 16 \ 30 | --classifier_lr 0.2 \ 31 | --weight_decay 1e-6 \ 32 | --batch_size 128 \ 33 | --num_workers 8 \ 34 | --brightness 0.4 \ 35 | --contrast 0.4 \ 36 | --saturation 0.2 \ 37 | --hue 0.1 \ 38 | --gaussian_prob 1.0 0.1 \ 39 | --solarization_prob 0.0 0.2 \ 40 | --num_crops_per_aug 1 1 \ 41 | --name "orion_sdbyol-FULL-200" \ 42 | --offline \ 43 | --entity ENTITY \ 44 | --project PROJECT \ 45 | --wandb \ 46 | --save_checkpoint \ 47 | --method sdbyol \ 48 | --proj_output_dim 256 \ 49 | --proj_hidden_dim 4096 \ 50 | --pred_hidden_dim 4096 \ 51 | --base_tau_momentum 0.99 \ 52 | --final_tau_momentum 1.0 \ 53 | --voc_size=21 \ 54 | 
--message_size=4500 \ 55 | --tau_online=0.872 \ 56 | --tau_target=0.159 \ 57 | --auto_resume \ 58 | --momentum_classifier 59 | -------------------------------------------------------------------------------- /bash_files/umap/imagenet-100/umap.sh: -------------------------------------------------------------------------------- 1 | python3 ../../../main_umap.py \ 2 | --dataset imagenet100 \ 3 | --data_dir /datasets \ 4 | --train_dir imagenet-100/train \ 5 | --val_dir imagenet-100/val \ 6 | --batch_size 16 \ 7 | --num_workers 10 \ 8 | --pretrained_checkpoint_dir PATH 9 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-argparse 3 | sphinx-rtd-theme 4 | sphinxcontrib-napoleon 5 | chardet==3.0.4 6 | # for umap 7 | matplotlib 8 | seaborn 9 | pandas 10 | umap-learn -------------------------------------------------------------------------------- /docs/source/solo/args.rst: -------------------------------------------------------------------------------- 1 | Args 2 | ==== 3 | 4 | dataset 5 | ------- 6 | 7 | dataset_args 8 | ~~~~~~~~~~~~ 9 | 10 | .. autofunction:: solo.args.dataset.dataset_args 11 | :noindex: 12 | 13 | 14 | augmentations_args 15 | ~~~~~~~~~~~~~~~~~~ 16 | 17 | .. autofunction:: solo.args.dataset.augmentations_args 18 | :noindex: 19 | 20 | 21 | setup 22 | ----- 23 | 24 | parse_args_pretrain 25 | ~~~~~~~~~~~~~~~~~~~ 26 | 27 | .. autofunction:: solo.args.setup.parse_args_pretrain 28 | :noindex: 29 | 30 | 31 | parse_args_linear 32 | ~~~~~~~~~~~~~~~~~ 33 | 34 | .. autofunction:: solo.args.setup.parse_args_linear 35 | :noindex: 36 | 37 | 38 | utils 39 | ----- 40 | 41 | additional_setup_pretrain 42 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 43 | 44 | .. autofunction:: solo.args.utils.additional_setup_pretrain 45 | :noindex: 46 | 47 | 48 | additional_setup_linear 49 | ~~~~~~~~~~~~~~~~~~~~~~~ 50 | 51 | .. 
autofunction:: solo.args.utils.additional_setup_linear 52 | :noindex: 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /docs/source/solo/losses/barlow.rst: -------------------------------------------------------------------------------- 1 | Barlow Twins 2 | ------------ 3 | 4 | .. autofunction:: solo.losses.barlow.barlow_loss_func 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/losses/byol.rst: -------------------------------------------------------------------------------- 1 | BYOL 2 | ---- 3 | 4 | .. autofunction:: solo.losses.byol.byol_loss_func 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/losses/deepclusterv2.rst: -------------------------------------------------------------------------------- 1 | DeepClusterV2 2 | ------------- 3 | 4 | .. autofunction:: solo.losses.deepclusterv2.deepclusterv2_loss_func 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/losses/dino.rst: -------------------------------------------------------------------------------- 1 | DINO 2 | ---- 3 | 4 | .. automethod:: solo.losses.dino.DINOLoss.__init__ 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/losses/mocov2plus.rst: -------------------------------------------------------------------------------- 1 | MoCo-V2 2 | ------- 3 | 4 | .. autofunction:: solo.losses.moco.moco_loss_func 5 | :noindex: -------------------------------------------------------------------------------- /docs/source/solo/losses/nnclr.rst: -------------------------------------------------------------------------------- 1 | NNCLR 2 | ----- 3 | 4 | .. 
autofunction:: solo.losses.nnclr.nnclr_loss_func 5 | :noindex: -------------------------------------------------------------------------------- /docs/source/solo/losses/ressl.rst: -------------------------------------------------------------------------------- 1 | ReSSL 2 | ----- 3 | 4 | .. autofunction:: solo.losses.ressl.ressl_loss_func 5 | :noindex: -------------------------------------------------------------------------------- /docs/source/solo/losses/simclr.rst: -------------------------------------------------------------------------------- 1 | SimCLR 2 | ------ 3 | 4 | .. autofunction:: solo.losses.simclr.simclr_loss_func 5 | :noindex: 6 | 7 | .. autofunction:: solo.losses.simclr.manual_simclr_loss_func 8 | -------------------------------------------------------------------------------- /docs/source/solo/losses/simsiam.rst: -------------------------------------------------------------------------------- 1 | SimSiam 2 | ------- 3 | 4 | .. autofunction:: solo.losses.simsiam.simsiam_loss_func 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/losses/swav.rst: -------------------------------------------------------------------------------- 1 | SwAV 2 | ---- 3 | 4 | .. autofunction:: solo.losses.swav.swav_loss_func 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/losses/vibcreg.rst: -------------------------------------------------------------------------------- 1 | VIbCReg 2 | ------- 3 | 4 | .. autofunction:: solo.losses.vibcreg.vibcreg_loss_func 5 | :noindex: -------------------------------------------------------------------------------- /docs/source/solo/losses/vicreg.rst: -------------------------------------------------------------------------------- 1 | VICReg 2 | ------ 3 | 4 | .. 
autofunction:: solo.losses.vicreg.vicreg_loss_func 5 | :noindex: -------------------------------------------------------------------------------- /docs/source/solo/losses/wmse.rst: -------------------------------------------------------------------------------- 1 | W-MSE 2 | ----- 3 | 4 | .. autofunction:: solo.losses.wmse.wmse_loss_func 5 | :noindex: 6 | -------------------------------------------------------------------------------- /docs/source/solo/methods/barlow.rst: -------------------------------------------------------------------------------- 1 | Barlow Twins 2 | ============ 3 | 4 | 5 | .. automethod:: solo.methods.barlow_twins.BarlowTwins.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.barlow_twins.BarlowTwins.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.barlow_twins.BarlowTwins.learnable_params 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. automethod:: solo.methods.barlow_twins.BarlowTwins.forward 21 | :noindex: 22 | 23 | training_step 24 | ~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.barlow_twins.BarlowTwins.training_step 26 | :noindex: 27 | -------------------------------------------------------------------------------- /docs/source/solo/methods/byol.rst: -------------------------------------------------------------------------------- 1 | BYOL 2 | ==== 3 | 4 | 5 | .. automethod:: solo.methods.byol.BYOL.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.byol.BYOL.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.byol.BYOL.learnable_params 16 | :noindex: 17 | 18 | momentum_pairs 19 | ~~~~~~~~~~~~~~ 20 | .. autoattribute:: solo.methods.byol.BYOL.momentum_pairs 21 | :noindex: 22 | 23 | forward 24 | ~~~~~~~ 25 | .. 
automethod:: solo.methods.byol.BYOL.forward 26 | :noindex: 27 | 28 | training_step 29 | ~~~~~~~~~~~~~ 30 | .. automethod:: solo.methods.byol.BYOL.training_step 31 | :noindex: 32 | -------------------------------------------------------------------------------- /docs/source/solo/methods/dali.rst: -------------------------------------------------------------------------------- 1 | DALI 2 | ==== 3 | 4 | DALI.PretrainABC 5 | ---------------- 6 | .. autoclass:: solo.methods.dali.PretrainABC 7 | 8 | 9 | DALI.ClassificationABC 10 | ---------------------- 11 | .. autoclass:: solo.methods.dali.ClassificationABC 12 | -------------------------------------------------------------------------------- /docs/source/solo/methods/deepclusterv2.rst: -------------------------------------------------------------------------------- 1 | DeepClusterV2 2 | ============= 3 | 4 | 5 | .. automethod:: solo.methods.deepclusterv2.DeepClusterV2.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.deepclusterv2.DeepClusterV2.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.deepclusterv2.DeepClusterV2.learnable_params 16 | :noindex: 17 | 18 | on_train_start 19 | ~~~~~~~~~~~~~~ 20 | .. automethod:: solo.methods.deepclusterv2.DeepClusterV2.on_train_start 21 | :noindex: 22 | 23 | on_train_epoch_start 24 | ~~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.deepclusterv2.DeepClusterV2.on_train_epoch_start 26 | :noindex: 27 | 28 | update_memory_banks 29 | ~~~~~~~~~~~~~~~~~~~ 30 | .. automethod:: solo.methods.deepclusterv2.DeepClusterV2.update_memory_banks 31 | :noindex: 32 | 33 | forward 34 | ~~~~~~~ 35 | .. automethod:: solo.methods.deepclusterv2.DeepClusterV2.forward 36 | :noindex: 37 | 38 | training_step 39 | ~~~~~~~~~~~~~ 40 | .. 
automethod:: solo.methods.deepclusterv2.DeepClusterV2.training_step 41 | :noindex: -------------------------------------------------------------------------------- /docs/source/solo/methods/dino.rst: -------------------------------------------------------------------------------- 1 | DINO 2 | ==== 3 | 4 | 5 | .. automethod:: solo.methods.dino.DINO.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.dino.DINO.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.dino.DINO.learnable_params 16 | :noindex: 17 | 18 | momentum_pairs 19 | ~~~~~~~~~~~~~~ 20 | .. autoattribute:: solo.methods.dino.DINO.momentum_pairs 21 | :noindex: 22 | 23 | clip_gradients 24 | ~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.dino.DINO.clip_gradients 26 | :noindex: 27 | 28 | 29 | forward 30 | ~~~~~~~ 31 | .. automethod:: solo.methods.dino.DINO.forward 32 | :noindex: 33 | 34 | training_step 35 | ~~~~~~~~~~~~~ 36 | .. automethod:: solo.methods.dino.DINO.training_step 37 | :noindex: 38 | 39 | on_after_backward 40 | ~~~~~~~~~~~~~~~~~ 41 | .. automethod:: solo.methods.dino.DINO.on_after_backward 42 | :noindex: 43 | -------------------------------------------------------------------------------- /docs/source/solo/methods/linear.rst: -------------------------------------------------------------------------------- 1 | LinearModel 2 | =========== 3 | 4 | 5 | .. automethod:: solo.methods.linear.LinearModel.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.linear.LinearModel.add_model_specific_args 11 | :noindex: 12 | 13 | configure_optimizers 14 | ~~~~~~~~~~~~~~~~~~~~ 15 | .. automethod:: solo.methods.linear.LinearModel.configure_optimizers 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. 
automethod:: solo.methods.linear.LinearModel.forward 21 | :noindex: 22 | 23 | shared_step 24 | ~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.linear.LinearModel.shared_step 26 | :noindex: 27 | 28 | training_step 29 | ~~~~~~~~~~~~~ 30 | .. automethod:: solo.methods.linear.LinearModel.training_step 31 | :noindex: 32 | 33 | validation_step 34 | ~~~~~~~~~~~~~~~ 35 | .. automethod:: solo.methods.linear.LinearModel.validation_step 36 | :noindex: 37 | 38 | validation_epoch_end 39 | ~~~~~~~~~~~~~~~~~~~~ 40 | .. automethod:: solo.methods.linear.LinearModel.validation_epoch_end 41 | :noindex: 42 | -------------------------------------------------------------------------------- /docs/source/solo/methods/mocov2plus.rst: -------------------------------------------------------------------------------- 1 | MoCo-V2 2 | ======= 3 | 4 | 5 | .. automethod:: solo.methods.mocov2plus.MoCoV2Plus.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.mocov2plus.MoCoV2Plus.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.mocov2plus.MoCoV2Plus.learnable_params 16 | :noindex: 17 | 18 | momentum_pairs 19 | ~~~~~~~~~~~~~~ 20 | .. autoattribute:: solo.methods.mocov2plus.MoCoV2Plus.momentum_pairs 21 | :noindex: 22 | 23 | _dequeue_and_enqueue 24 | ~~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.mocov2plus.MoCoV2Plus._dequeue_and_enqueue 26 | :noindex: 27 | 28 | forward 29 | ~~~~~~~ 30 | .. automethod:: solo.methods.mocov2plus.MoCoV2Plus.forward 31 | :noindex: 32 | 33 | training_step 34 | ~~~~~~~~~~~~~ 35 | .. automethod:: solo.methods.mocov2plus.MoCoV2Plus.training_step 36 | :noindex: 37 | -------------------------------------------------------------------------------- /docs/source/solo/methods/nnbyol.rst: -------------------------------------------------------------------------------- 1 | NNBYOL 2 | ====== 3 | 4 | .. 
automethod:: solo.methods.nnbyol.NNBYOL.__init__ 5 | :noindex: 6 | 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.nnbyol.NNBYOL.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.nnbyol.NNBYOL.learnable_params 16 | :noindex: 17 | 18 | dequeue_and_enqueue 19 | ~~~~~~~~~~~~~~~~~~~ 20 | .. automethod:: solo.methods.nnbyol.NNBYOL.dequeue_and_enqueue 21 | :noindex: 22 | 23 | find_nn 24 | ~~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.nnbyol.NNBYOL.find_nn 26 | :noindex: 27 | 28 | forward 29 | ~~~~~~~ 30 | .. automethod:: solo.methods.nnbyol.NNBYOL.forward 31 | :noindex: 32 | 33 | training_step 34 | ~~~~~~~~~~~~~ 35 | .. automethod:: solo.methods.nnbyol.NNBYOL.training_step 36 | :noindex: 37 | -------------------------------------------------------------------------------- /docs/source/solo/methods/nnclr.rst: -------------------------------------------------------------------------------- 1 | NNCLR 2 | ===== 3 | 4 | .. automethod:: solo.methods.nnclr.NNCLR.__init__ 5 | :noindex: 6 | 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.nnclr.NNCLR.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.nnclr.NNCLR.learnable_params 16 | :noindex: 17 | 18 | dequeue_and_enqueue 19 | ~~~~~~~~~~~~~~~~~~~ 20 | .. automethod:: solo.methods.nnclr.NNCLR.dequeue_and_enqueue 21 | :noindex: 22 | 23 | find_nn 24 | ~~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.nnclr.NNCLR.find_nn 26 | :noindex: 27 | 28 | forward 29 | ~~~~~~~ 30 | .. automethod:: solo.methods.nnclr.NNCLR.forward 31 | :noindex: 32 | 33 | training_step 34 | ~~~~~~~~~~~~~ 35 | .. 
automethod:: solo.methods.nnclr.NNCLR.training_step 36 | :noindex: 37 | -------------------------------------------------------------------------------- /docs/source/solo/methods/nnsiam.rst: -------------------------------------------------------------------------------- 1 | NNSiam 2 | ====== 3 | 4 | .. automethod:: solo.methods.nnsiam.NNSiam.__init__ 5 | :noindex: 6 | 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.nnsiam.NNSiam.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.nnsiam.NNSiam.learnable_params 16 | :noindex: 17 | 18 | dequeue_and_enqueue 19 | ~~~~~~~~~~~~~~~~~~~ 20 | .. automethod:: solo.methods.nnsiam.NNSiam.dequeue_and_enqueue 21 | :noindex: 22 | 23 | find_nn 24 | ~~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.nnsiam.NNSiam.find_nn 26 | :noindex: 27 | 28 | forward 29 | ~~~~~~~ 30 | .. automethod:: solo.methods.nnsiam.NNSiam.forward 31 | :noindex: 32 | 33 | training_step 34 | ~~~~~~~~~~~~~ 35 | .. automethod:: solo.methods.nnsiam.NNSiam.training_step 36 | :noindex: 37 | -------------------------------------------------------------------------------- /docs/source/solo/methods/ressl.rst: -------------------------------------------------------------------------------- 1 | ReSSL 2 | ===== 3 | 4 | .. automethod:: solo.methods.ressl.ReSSL.__init__ 5 | :noindex: 6 | 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.ressl.ReSSL.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.ressl.ReSSL.learnable_params 16 | :noindex: 17 | 18 | momentum_pairs 19 | ~~~~~~~~~~~~~~ 20 | .. autoattribute:: solo.methods.ressl.ReSSL.momentum_pairs 21 | :noindex: 22 | 23 | dequeue_and_enqueue 24 | ~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.ressl.ReSSL.dequeue_and_enqueue 26 | :noindex: 27 | 28 | forward 29 | ~~~~~~~ 30 | .. 
automethod:: solo.methods.ressl.ReSSL.forward 31 | :noindex: 32 | 33 | training_step 34 | ~~~~~~~~~~~~~ 35 | .. automethod:: solo.methods.ressl.ReSSL.training_step 36 | :noindex: 37 | -------------------------------------------------------------------------------- /docs/source/solo/methods/simclr.rst: -------------------------------------------------------------------------------- 1 | SimCLR 2 | ============ 3 | 4 | 5 | .. automethod:: solo.methods.simclr.SimCLR.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.simclr.SimCLR.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.simclr.SimCLR.learnable_params 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. automethod:: solo.methods.simclr.SimCLR.forward 21 | :noindex: 22 | 23 | gen_extra_positives_gt 24 | ~~~~~~~~~~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.simclr.SimCLR.gen_extra_positives_gt 26 | :noindex: 27 | 28 | training_step 29 | ~~~~~~~~~~~~~ 30 | .. automethod:: solo.methods.simclr.SimCLR.training_step 31 | :noindex: 32 | -------------------------------------------------------------------------------- /docs/source/solo/methods/simsiam.rst: -------------------------------------------------------------------------------- 1 | SimSiam 2 | ======= 3 | 4 | 5 | .. automethod:: solo.methods.simsiam.SimSiam.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.simsiam.SimSiam.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.simsiam.SimSiam.learnable_params 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. automethod:: solo.methods.simsiam.SimSiam.forward 21 | :noindex: 22 | 23 | training_step 24 | ~~~~~~~~~~~~~ 25 | .. 
automethod:: solo.methods.simsiam.SimSiam.training_step 26 | :noindex: 27 | -------------------------------------------------------------------------------- /docs/source/solo/methods/swav.rst: -------------------------------------------------------------------------------- 1 | SwAV 2 | ==== 3 | 4 | 5 | .. automethod:: solo.methods.swav.SwAV.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.swav.SwAV.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.swav.SwAV.learnable_params 16 | :noindex: 17 | 18 | on_train_start 19 | ~~~~~~~~~~~~~~ 20 | .. automethod:: solo.methods.swav.SwAV.on_train_start 21 | :noindex: 22 | 23 | forward 24 | ~~~~~~~ 25 | .. automethod:: solo.methods.swav.SwAV.forward 26 | :noindex: 27 | 28 | get_assignments 29 | ~~~~~~~~~~~~~~~ 30 | .. automethod:: solo.methods.swav.SwAV.get_assignments 31 | :noindex: 32 | 33 | training_step 34 | ~~~~~~~~~~~~~ 35 | .. automethod:: solo.methods.swav.SwAV.training_step 36 | :noindex: 37 | 38 | on_after_backward 39 | ~~~~~~~~~~~~~~~~~ 40 | .. automethod:: solo.methods.swav.SwAV.on_after_backward 41 | :noindex: 42 | -------------------------------------------------------------------------------- /docs/source/solo/methods/vibcreg.rst: -------------------------------------------------------------------------------- 1 | VIbCReg 2 | ======= 3 | 4 | 5 | .. automethod:: solo.methods.vibcreg.VIbCReg.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.vibcreg.VIbCReg.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.vibcreg.VIbCReg.learnable_params 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. automethod:: solo.methods.vibcreg.VIbCReg.forward 21 | :noindex: 22 | 23 | training_step 24 | ~~~~~~~~~~~~~ 25 | .. 
automethod:: solo.methods.vibcreg.VIbCReg.training_step 26 | :noindex: 27 | -------------------------------------------------------------------------------- /docs/source/solo/methods/vicreg.rst: -------------------------------------------------------------------------------- 1 | VICReg 2 | ====== 3 | 4 | 5 | .. automethod:: solo.methods.vicreg.VICReg.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.vicreg.VICReg.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.vicreg.VICReg.learnable_params 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. automethod:: solo.methods.vicreg.VICReg.forward 21 | :noindex: 22 | 23 | training_step 24 | ~~~~~~~~~~~~~ 25 | .. automethod:: solo.methods.vicreg.VICReg.training_step 26 | :noindex: 27 | -------------------------------------------------------------------------------- /docs/source/solo/methods/wmse.rst: -------------------------------------------------------------------------------- 1 | W-MSE 2 | ===== 3 | 4 | 5 | .. automethod:: solo.methods.wmse.WMSE.__init__ 6 | :noindex: 7 | 8 | add_model_specific_args 9 | ~~~~~~~~~~~~~~~~~~~~~~~ 10 | .. automethod:: solo.methods.wmse.WMSE.add_model_specific_args 11 | :noindex: 12 | 13 | learnable_params 14 | ~~~~~~~~~~~~~~~~ 15 | .. autoattribute:: solo.methods.wmse.WMSE.learnable_params 16 | :noindex: 17 | 18 | forward 19 | ~~~~~~~ 20 | .. automethod:: solo.methods.wmse.WMSE.forward 21 | :noindex: 22 | 23 | training_step 24 | ~~~~~~~~~~~~~ 25 | .. 
automethod:: solo.methods.wmse.WMSE.training_step 26 | :noindex: 27 | -------------------------------------------------------------------------------- /docs/source/start/available.rst: -------------------------------------------------------------------------------- 1 | ***************** 2 | Methods available 3 | ***************** 4 | 5 | * `Barlow Twins `_ 6 | * `BYOL `_ 7 | * `MoCo-V2 `_ 8 | * `NNCLR `_ 9 | * `SimCLR `_ + `Supervised Contrastive Learning `_ 10 | * `SimSiam `_ 11 | * `SwAV `_ 12 | * `VICReg `_ 13 | * `W-MSE `_ 14 | 15 | ************ 16 | Extra flavor 17 | ************ 18 | 19 | Data 20 | ==== 21 | 22 | * Increased data processing speed by up to 100% using `Nvidia Dali `_ 23 | * Asymmetric and symmetric augmentations 24 | 25 | Evaluation and logging 26 | ====================== 27 | 28 | 29 | * Online linear evaluation via stop-gradient for easier debugging and prototyping (optionally available for the momentum backbone as well) 30 | * Normal offline linear evaluation 31 | * All the perks of PyTorch Lightning (mixed precision, gradient accumulation, clipping, automatic logging and much more) 32 | * Easy-to-extend modular code structure 33 | * Custom model logging with a simpler file organization 34 | * Automatic feature space visualization with UMAP 35 | * Common metrics and more to come... 36 | 37 | 38 | Training tricks 39 | =============== 40 | 41 | * Multi-cropping dataloading following `SwAV `_: 42 | * **Note**: currently, only SimCLR supports this 43 | * Exclude batchnorm and biases from LARS 44 | * No LR scheduler for the projection head in SimSiam -------------------------------------------------------------------------------- /docs/source/start/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ************ 3 | 4 | To install the repository with Dali and/or UMAP support, use: 5 | 6 | .. 
code-block:: bash 7 | 8 | pip3 install .[dali,umap] 9 | 10 | If no Dali/UMAP support is needed, the repository can be installed as: 11 | 12 | .. code-block:: bash 13 | 14 | pip3 install . 15 | 16 | 17 | 18 | **NOTE:** If you want to modify the library, install it in dev mode with ``-e``. 19 | 20 | **NOTE 2:** Soon to be on pip. 21 | 22 | **Requirements:** 23 | 24 | .. code-block:: python 25 | 26 | torch 27 | tqdm 28 | einops 29 | wandb 30 | pytorch-lightning 31 | lightning-bolts 32 | 33 | # Optional 34 | nvidia-dali 35 | 36 | **NOTE:** if you are using CUDA 10.X use ``nvidia-dali-cuda100`` in ``requirements.txt``. -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/im100_train_umap_barlow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lavoiems/simplicial-embeddings/d8fcecee2e16df9c24dea85c82c2916b45b9ebcf/docs/source/tutorials/imgs/im100_train_umap_barlow.png -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/im100_train_umap_byol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lavoiems/simplicial-embeddings/d8fcecee2e16df9c24dea85c82c2916b45b9ebcf/docs/source/tutorials/imgs/im100_train_umap_byol.png -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/im100_train_umap_random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lavoiems/simplicial-embeddings/d8fcecee2e16df9c24dea85c82c2916b45b9ebcf/docs/source/tutorials/imgs/im100_train_umap_random.png -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/im100_val_umap_barlow.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lavoiems/simplicial-embeddings/d8fcecee2e16df9c24dea85c82c2916b45b9ebcf/docs/source/tutorials/imgs/im100_val_umap_barlow.png -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/im100_val_umap_byol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lavoiems/simplicial-embeddings/d8fcecee2e16df9c24dea85c82c2916b45b9ebcf/docs/source/tutorials/imgs/im100_val_umap_byol.png -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/im100_val_umap_random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lavoiems/simplicial-embeddings/d8fcecee2e16df9c24dea85c82c2916b45b9ebcf/docs/source/tutorials/imgs/im100_val_umap_random.png -------------------------------------------------------------------------------- /downstream/object_detection/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## Transferring to Detection 5 | 6 | The `train_object_detection.py` script reproduces the object detection experiments on Pascal VOC and COCO. 7 | 8 | ### Instruction 9 | 10 | 1. Install [detectron2](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). 11 | 12 | 1. Convert a pre-trained model to detectron2's format: 13 | ``` 14 | python3 convert_model_to_detectron2.py --pretrained_feature_extractor PATH_TO_CKPT --output_detectron_model ./detectron_model.pkl 15 | ``` 16 | 17 | 1. Put the dataset under the "./datasets" directory, 18 | following the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) 19 | required by detectron2. 20 | 21 | 1.
Run training: 22 | ``` 23 | python train_object_detection.py --config-file configs/pascal_voc_R_50_C4_24k_moco.yaml \ 24 | --num-gpus 8 MODEL.WEIGHTS ./detectron_model.pkl 25 | ``` 26 | -------------------------------------------------------------------------------- /downstream/object_detection/configs/Base-RCNN-C4-BN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | RPN: 4 | PRE_NMS_TOPK_TEST: 6000 5 | POST_NMS_TOPK_TEST: 1000 6 | ROI_HEADS: 7 | NAME: "Res5ROIHeadsExtraNorm" 8 | BACKBONE: 9 | FREEZE_AT: 0 10 | RESNETS: 11 | NORM: "SyncBN" 12 | TEST: 13 | PRECISE_BN: 14 | ENABLED: True 15 | SOLVER: 16 | IMS_PER_BATCH: 16 17 | BASE_LR: 0.02 18 | -------------------------------------------------------------------------------- /downstream/object_detection/configs/coco_R_50_C4_2x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-C4-BN.yaml" 2 | MODEL: 3 | MASK_ON: True 4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 5 | INPUT: 6 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 7 | MIN_SIZE_TEST: 800 8 | DATASETS: 9 | TRAIN: ("coco_2017_train",) 10 | TEST: ("coco_2017_val",) 11 | SOLVER: 12 | STEPS: (120000, 160000) 13 | MAX_ITER: 180000 14 | -------------------------------------------------------------------------------- /downstream/object_detection/configs/coco_R_50_C4_2x_moco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "coco_R_50_C4_2x.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "See Instructions" 6 | RESNETS: 7 | STRIDE_IN_1X1: False 8 | INPUT: 9 | FORMAT: "RGB" 10 | -------------------------------------------------------------------------------- /downstream/object_detection/configs/pascal_voc_R_50_C4_24k.yaml: -------------------------------------------------------------------------------- 1
| _BASE_: "Base-RCNN-C4-BN.yaml" 2 | MODEL: 3 | MASK_ON: False 4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 5 | ROI_HEADS: 6 | NUM_CLASSES: 20 7 | INPUT: 8 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 9 | MIN_SIZE_TEST: 800 10 | DATASETS: 11 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 12 | TEST: ('voc_2007_test',) 13 | SOLVER: 14 | STEPS: (18000, 22000) 15 | MAX_ITER: 24000 16 | WARMUP_ITERS: 100 17 | -------------------------------------------------------------------------------- /downstream/object_detection/configs/pascal_voc_R_50_C4_24k_moco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "pascal_voc_R_50_C4_24k.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "See Instructions" 6 | RESNETS: 7 | STRIDE_IN_1X1: False 8 | INPUT: 9 | FORMAT: "RGB" 10 | -------------------------------------------------------------------------------- /downstream/object_detection/convert_model_to_detectron2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from argparse import ArgumentParser 21 | import pickle as pkl 22 | import torch 23 | 24 | if __name__ == "__main__": 25 | parser = ArgumentParser() 26 | parser.add_argument("--pretrained_feature_extractor", type=str, required=True) 27 | parser.add_argument("--output_detectron_model", type=str, required=True) 28 | 29 | args = parser.parse_args() 30 | 31 | checkpoint = torch.load(args.pretrained_feature_extractor, map_location="cpu") 32 | checkpoint = checkpoint["state_dict"] 33 | 34 | newmodel = {} 35 | for k, v in checkpoint.items(): 36 | if not k.startswith("backbone"): 37 | continue 38 | 39 | old_k = k 40 | k = k.replace("backbone.", "") 41 | if "layer" not in k: 42 | k = "stem." 
@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    """
    As described in the MOCO paper, there is an extra BN layer
    following the res5 stage.

    The class is registered in ``ROI_HEADS_REGISTRY`` under its own name; the
    ``Base-RCNN-C4-BN.yaml`` config selects it via
    ``ROI_HEADS.NAME: "Res5ROIHeadsExtraNorm"``.
    """

    def _build_res5_block(self, cfg):
        """Build the parent's res5 block and append a normalization module to it.

        Args:
            cfg: configuration node; ``cfg.MODEL.RESNETS.NORM`` is read to choose
                the normalization passed to ``get_norm``.

        Returns:
            tuple: ``(seq, out_channels)`` as returned by the parent class, with an
            extra module named ``"norm"`` added to ``seq``.
        """
        seq, out_channels = super()._build_res5_block(cfg)
        norm = cfg.MODEL.RESNETS.NORM
        # NOTE(review): get_norm presumably maps the configured norm name to a layer
        # sized for out_channels -- confirm against detectron2.layers.get_norm.
        norm = get_norm(norm, out_channels)
        seq.add_module("norm", norm)
        return seq, out_channels


class Trainer(DefaultTrainer):
    """``DefaultTrainer`` subclass that chooses the evaluator from the dataset name."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return the evaluator matching ``dataset_name``.

        Args:
            cfg: configuration node; ``cfg.OUTPUT_DIR`` is used to derive the default
                output folder.
            dataset_name: dataset identifier; must contain ``"coco"`` or ``"voc"``.
            output_folder: where evaluation output goes. Defaults to
                ``<cfg.OUTPUT_DIR>/inference`` when ``None``.

        Returns:
            A ``COCOEvaluator`` when the name contains ``"coco"``, otherwise a
            ``PascalVOCDetectionEvaluator``.

        Raises:
            AssertionError: if the name contains neither ``"coco"`` nor ``"voc"``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        if "coco" in dataset_name:
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        else:
            # anything that is not COCO is expected to be a Pascal VOC split
            assert "voc" in dataset_name
            return PascalVOCDetectionEvaluator(dataset_name)


def setup(args):
    """Create the frozen config for a run.

    The default config is overridden first by ``args.config_file`` and then by the
    command-line ``args.opts`` list, frozen, and passed to ``default_setup``.

    Args:
        args: parsed command-line arguments (needs ``config_file`` and ``opts``).

    Returns:
        The frozen configuration node.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """Run evaluation only or full training, depending on ``args.eval_only``.

    Args:
        args: parsed command-line arguments (``eval_only`` and ``resume`` are read
            here, the rest by ``setup``).

    Returns:
        The result of ``Trainer.test`` in eval-only mode, otherwise the result of
        ``trainer.train()``.
    """
    cfg = setup(args)

    if args.eval_only:
        # build the model, load cfg.MODEL.WEIGHTS (or resume) and evaluate without training
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    # launch runs main(args) with the GPU/machine layout given on the command line
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
def main():
    """Plot offline UMAP projections of a pretrained backbone's features.

    Reads ``args.json`` and a ``.ckpt`` file from ``args.pretrained_checkpoint_dir``,
    rebuilds the method listed under ``"method"`` in ``args.json`` through ``METHODS``,
    keeps only its ``backbone`` and writes one UMAP plot for the training loader
    (``im100_train_umap.pdf``) and one for the validation loader
    (``im100_val_umap.pdf``).

    Raises:
        FileNotFoundError: if the checkpoint directory contains no ``.ckpt`` file.
    """
    args = parse_args_umap()

    # build paths
    ckpt_dir = Path(args.pretrained_checkpoint_dir)
    args_path = ckpt_dir / "args.json"

    # os.listdir returns entries in arbitrary order; sorting makes the choice
    # reproducible when the directory holds more than one checkpoint
    ckpt_names = sorted(name for name in os.listdir(ckpt_dir) if name.endswith(".ckpt"))
    if not ckpt_names:
        raise FileNotFoundError(f"No .ckpt file found in {ckpt_dir}")
    ckpt_path = ckpt_dir / ckpt_names[0]

    # load arguments
    with open(args_path) as f:
        method_args = json.load(f)

    # build the model; only the backbone is needed to extract features
    model = (
        METHODS[method_args["method"]]
        .load_from_checkpoint(ckpt_path, strict=False, **method_args)
        .backbone
    )

    # prepare data
    train_loader, val_loader = prepare_data(
        args.dataset,
        data_dir=args.data_dir,
        train_dir=args.train_dir,
        val_dir=args.val_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    umap = OfflineUMAP()

    # move model to the gpu (single transfer; the earlier model.cuda() was redundant)
    device = "cuda:0"
    model = model.to(device)

    umap.plot(device, model, train_loader, "im100_train_umap.pdf")
    umap.plot(device, model, val_loader, "im100_val_umap.pdf")


if __name__ == "__main__":
    main()
#!/bin/bash

####
#
# Usage: ./prepare_data.sh [VAL]
#
# Extracts the ImageNet training archive into $SLURM_TMPDIR/data/train, one
# directory per class archive.
#   VAL      move 10 files of every class from the training set into
#            $SLURM_TMPDIR/data/val and use them as validation set.
#   (other)  extract the validation archive into $SLURM_TMPDIR/data instead.
#
###

# Abort on failed commands and on unset variables instead of continuing with bad paths.
set -eu

MODE=${1:-}

# Without this guard an unset SLURM_TMPDIR would make the script write to /data and /tmp_data.
: "${SLURM_TMPDIR:?SLURM_TMPDIR must be set (run this script inside a SLURM job)}"

# Pick the archive locations for the current cluster.
if [[ $HOSTNAME == login* || ${SLURM_CLUSTER_NAME:-} == mila ]]; then
  mila=0
  echo "In Mila - $mila"
  SRC_TRAIN=/network/datasets/imagenet/ILSVRC2012_img_train.tar
  SRC_VAL=/network/scratch/l/lavoiems/data/imagenet_val.tar
else
  mila=1
  echo "In CC - $mila"
  SRC_TRAIN=/scratch/lavoiems/data/ILSVRC2012_img_train.tar
  SRC_VAL=/scratch/lavoiems/data/imagenet_val.tar
fi

TRG="$SLURM_TMPDIR/data"
TRG_TRAIN="$TRG/train"
TRG_VAL="$TRG/val"
TRG_TMP="$SLURM_TMPDIR/tmp_data"
mkdir -p "$TRG"
mkdir -p "$TRG_TRAIN"
mkdir -p "$TRG_TMP"

echo "Extracting IMAGENET from $SRC_TRAIN into $TRG_TMP"
tar -xf "$SRC_TRAIN" -C "$TRG_TMP"

echo "Extracting training set from $TRG_TMP into $TRG_TRAIN"
# The outer archive holds one archive per class; unpack each into a directory
# named after the archive without its extension. Extractions run in parallel.
for ARCHIVE in "$TRG_TMP"/*; do
  [ -e "$ARCHIVE" ] || continue
  NAME=$(basename "$ARCHIVE")
  mkdir -p "$TRG_TRAIN/${NAME%.*}"
  tar -xf "$ARCHIVE" -C "$TRG_TRAIN/${NAME%.*}" &
done
wait

if [[ $MODE == "VAL" ]]; then
  echo "Taking 10 samples/class as validation sample"
  mkdir -p "$TRG_VAL"
  for CLASS_PATH in "$TRG_TRAIN"/*; do
    [ -e "$CLASS_PATH" ] || continue
    D=$(basename "$CLASS_PATH")
    mkdir -p "$TRG_VAL/$D"
    # first 10 entries in ls order become the validation samples of this class
    ls "$TRG_TRAIN/$D/" | head -10 | xargs -I{} mv "$TRG_TRAIN/$D/{}" "$TRG_VAL/$D/" &
  done
  wait
else
  echo "Extracting test set from $SRC_VAL into $TRG"
  tar -xf "$SRC_VAL" -C "$TRG"
fi
-------------------------------------------------------------------------------- /solo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | from solo import args, losses, methods, utils 22 | 23 | __all__ = ["args", "losses", "methods", "utils"] 24 | -------------------------------------------------------------------------------- /solo/args/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.args import dataset, setup, utils 21 | 22 | __all__ = ["dataset", "setup", "utils"] 23 | -------------------------------------------------------------------------------- /solo/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | from solo.losses.barlow import barlow_loss_func 21 | from solo.losses.byol import byol_loss_func 22 | from solo.losses.deepclusterv2 import deepclusterv2_loss_func 23 | from solo.losses.dino import DINOLoss 24 | from solo.losses.moco import moco_loss_func 25 | from solo.losses.nnclr import nnclr_loss_func 26 | from solo.losses.ressl import ressl_loss_func 27 | from solo.losses.simclr import simclr_loss_func 28 | from solo.losses.simsiam import simsiam_loss_func 29 | from solo.losses.swav import swav_loss_func 30 | from solo.losses.vibcreg import vibcreg_loss_func 31 | from solo.losses.vicreg import vicreg_loss_func 32 | from solo.losses.wmse import wmse_loss_func 33 | 34 | __all__ = [ 35 | "barlow_loss_func", 36 | "byol_loss_func", 37 | "deepclusterv2_loss_func", 38 | "DINOLoss", 39 | "moco_loss_func", 40 | "nnclr_loss_func", 41 | "ressl_loss_func", 42 | "simclr_loss_func", 43 | "simsiam_loss_func", 44 | "swav_loss_func", 45 | "vibcreg_loss_func", 46 | "vicreg_loss_func", 47 | "wmse_loss_func", 48 | ] 49 | -------------------------------------------------------------------------------- /solo/losses/barlow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
def barlow_loss_func(
    z1: torch.Tensor, z2: torch.Tensor, lamb: float = 5e-3, scale_loss: float = 0.025
) -> torch.Tensor:
    """Barlow Twins objective for two batches of projected features.

    Both views are standardized per feature with a non-affine batch norm, their
    cross-correlation matrix is computed (and averaged across processes when
    torch.distributed is initialized), and its squared distance to the identity is
    summed, with off-diagonal terms weighted by ``lamb``.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        lamb (float, optional): weight of the off-diagonal terms of the
            cross-correlation matrix. Defaults to 5e-3.
        scale_loss (float, optional): multiplier applied to the final sum.
            Defaults to 0.025.

    Returns:
        torch.Tensor: Barlow Twins' loss.
    """

    batch_size, feat_dim = z1.size()

    # to match the original code: a fresh, parameter-free batch norm per call
    normalizer = torch.nn.BatchNorm1d(feat_dim, affine=False).to(z1.device)
    z1_std = normalizer(z1)
    z2_std = normalizer(z2)

    cross_corr = torch.einsum("bi, bj -> ij", z1_std, z2_std) / batch_size

    # average the matrix over all workers when running distributed
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(cross_corr)
        cross_corr /= dist.get_world_size()

    identity = torch.eye(feat_dim, device=cross_corr.device)
    sq_err = (cross_corr - identity).pow(2)
    off_diagonal = ~identity.bool()
    sq_err[off_diagonal] *= lamb
    return scale_loss * sq_err.sum()
def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Computes BYOL's loss given batch of predicted features p and projected momentum features z.

    The loss is ``2 - 2 * cos(p, z)`` averaged over all samples, where the cosine is taken
    along the last dimension. ``z`` is detached, so no gradient flows into it.

    Args:
        p (torch.Tensor): NxD Tensor containing predicted features from view 1.
        z (torch.Tensor): NxD Tensor containing projected momentum features from view 2.
        simplified (bool): use ``F.cosine_similarity`` directly instead of normalising and
            taking a dot product; both paths give the same result. Defaults to True.

    Returns:
        torch.Tensor: BYOL's loss.
    """

    target = z.detach()

    if simplified:
        return 2 - 2 * F.cosine_similarity(p, target, dim=-1).mean()

    p = F.normalize(p, dim=-1)
    target = F.normalize(target, dim=-1)

    # reduce over the same (last) dimension that was normalised, so that both paths agree
    # for inputs with more than two dimensions as well
    return 2 - 2 * (p * target).sum(dim=-1).mean()
def deepclusterv2_loss_func(
    outputs: torch.Tensor, assignments: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """Computes DeepClusterV2's loss given a tensor containing logits from multiple views
    and a tensor containing cluster assignments.

    For every prototype layer the logits of all views are flattened to (V*N)xC, divided by
    the temperature and scored with cross-entropy against that layer's assignments, which
    are repeated once per view. Targets equal to -1 are ignored. The result is the mean of
    the per-layer losses.

    Args:
        outputs (torch.Tensor): tensor of size PxVxNxC where P is the number of prototype
            layers and V is the number of views.
        assignments (torch.Tensor): tensor of size PxN containing, for each prototype
            layer, the cluster index assigned to every sample (-1 entries are ignored).
        temperature (float, optional): softmax temperature for the loss. Defaults to 0.1.

    Returns:
        torch.Tensor: DeepClusterV2 loss.
    """
    num_layers = outputs.size(0)
    num_views = outputs.size(1)
    num_prototypes = outputs.size(-1)

    loss = 0
    for h in range(num_layers):
        # reshape (rather than view) so that non-contiguous logits are accepted too
        scores = outputs[h].reshape(-1, num_prototypes) / temperature
        # the same per-sample assignment is the target for every view
        targets = assignments[h].repeat(num_views).to(outputs.device, non_blocking=True)
        loss += F.cross_entropy(scores, targets, ignore_index=-1)
    return loss / num_layers
def moco_loss_func(
    query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1
) -> torch.Tensor:
    """Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2
    and a queue of past elements.

    The positive logit of each query is its dot product with the key in the same row; the
    negative logits are its dot products with every column of the queue. Cross-entropy is
    applied with the positive placed at index 0.

    Args:
        query (torch.Tensor): NxD Tensor containing the queries from view 1.
        key (torch.Tensor): NxD Tensor containing the keys from view 2.
        queue (torch.Tensor): DxK Tensor of negative samples for the contrastive loss.
        temperature (float, optional): temperature of the softmax in the contrastive
            loss. Defaults to 0.1.

    Returns:
        torch.Tensor: MoCo loss.
    """

    positive_logits = torch.einsum("nc,nc->n", query, key).unsqueeze(-1)
    negative_logits = torch.einsum("nc,ck->nk", query, queue)
    all_logits = torch.cat([positive_logits, negative_logits], dim=1) / temperature
    # the positive always sits in column 0
    positive_index = torch.zeros(query.size(0), device=query.device, dtype=torch.long)
    return F.cross_entropy(all_logits, positive_index)


def nnclr_loss_func(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Computes NNCLR's loss given batch of nearest-neighbors nn from view 1 and
    predicted features p from view 2.

    Both inputs are L2-normalised along the last dimension; the temperature-scaled
    similarity matrix between them is scored with cross-entropy, the matching row/column
    pair being the correct class.

    Args:
        nn (torch.Tensor): NxD Tensor containing nearest neighbors' features from view 1.
        p (torch.Tensor): NxD Tensor containing predicted features from view 2.
        temperature (float, optional): temperature of the softmax in the contrastive loss.
            Defaults to 0.1.

    Returns:
        torch.Tensor: NNCLR loss.
    """

    neighbors = F.normalize(nn, dim=-1)
    predictions = F.normalize(p, dim=-1)

    similarity = neighbors @ predictions.T / temperature

    # row i has to pick column i
    targets = torch.arange(predictions.size(0), device=predictions.device)
    return F.cross_entropy(similarity, targets)
def ressl_loss_func(
    q: torch.Tensor,
    k: torch.Tensor,
    queue: torch.Tensor,
    temperature_q: float = 0.1,
    temperature_k: float = 0.04,
) -> torch.Tensor:
    """Computes ReSSL's loss given a batch of queries from view 1, a batch of keys from view 2
    and a queue of past elements.

    The similarities of the keys to the queue are turned into a (detached) target
    distribution, the similarities of the queries to the queue into log-probabilities, and
    the loss is the cross-entropy between the two, averaged over the batch.

    Args:
        q (torch.Tensor): NxD Tensor containing the queries from view 1.
        k (torch.Tensor): NxD Tensor containing the keys from view 2.
        queue (torch.Tensor): KxD Tensor holding the queue of past elements.
        temperature_q (float, optional): temperature of the softmax for the query.
            Defaults to 0.1.
        temperature_k (float, optional): temperature of the softmax for the key.
            Defaults to 0.04.

    Returns:
        torch.Tensor: ReSSL loss.
    """

    query_sim = torch.einsum("nc,kc->nk", q, queue)
    key_sim = torch.einsum("nc,kc->nk", k, queue)

    # no gradient flows through the key branch
    target_probs = F.softmax(key_sim.detach() / temperature_k, dim=1)
    query_log_probs = F.log_softmax(query_sim / temperature_q, dim=1)

    per_sample = torch.sum(target_probs * query_log_probs, dim=1)
    return -per_sample.mean()
def simclr_loss_func(
    z: torch.Tensor, indexes: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """Computes SimCLR's loss given batch of projected features z from different views
    and the index (or label) of every crop.

    Features are L2-normalised and compared against the output of ``gather`` (presumably
    the features collected from all processes; see solo.utils.misc). Two crops count as a
    positive pair when their indexes are equal and as a negative pair otherwise.

    Args:
        z (torch.Tensor): (N*views) x D Tensor containing projected features from the views.
        indexes (torch.Tensor): unique identifiers for each crop (unsupervised)
            or targets of each crop (supervised).
        temperature (float, optional): temperature dividing the similarities before the
            exponential. Defaults to 0.1.

    Return:
        torch.Tensor: SimCLR loss.
    """

    z = F.normalize(z, dim=-1)
    all_z = gather(z)

    exp_sim = torch.exp(torch.einsum("if, jf -> ij", z, all_z) / temperature)

    all_indexes = gather(indexes)

    local_col = indexes.unsqueeze(0).t()
    all_row = all_indexes.unsqueeze(0)

    # positives: same index; the diagonal of the sub-block that starts at this rank's
    # column offset is cleared (presumably so a crop is not its own positive)
    positives = local_col == all_row
    rank_offset = z.size(0) * get_rank()
    positives[:, rank_offset:].fill_diagonal_(0)

    # negatives: different index
    negatives = local_col != all_row

    pos_sum = torch.sum(exp_sim * positives, 1)
    neg_sum = torch.sum(exp_sim * negatives, 1)
    return -(torch.mean(torch.log(pos_sum / (pos_sum + neg_sum))))
def simsiam_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Computes SimSiam's loss given batch of predicted features p from view 1 and
    a batch of projected features z from view 2.

    The loss is the negative cosine similarity between p and z along the last dimension,
    averaged over all samples. ``z`` is detached, so no gradient flows into it.

    Args:
        p (torch.Tensor): Tensor containing predicted features from view 1.
        z (torch.Tensor): Tensor containing projected features from view 2.
        simplified (bool): use ``F.cosine_similarity`` directly instead of normalising and
            taking a dot product; both paths give the same result. Defaults to True.

    Returns:
        torch.Tensor: SimSiam loss.
    """

    target = z.detach()

    if simplified:
        return -F.cosine_similarity(p, target, dim=-1).mean()

    p = F.normalize(p, dim=-1)
    target = F.normalize(target, dim=-1)

    # reduce over the same (last) dimension that was normalised, so that both paths agree
    # for inputs with more than two dimensions as well
    return -(p * target).sum(dim=-1).mean()
def swav_loss_func(
    preds: List[torch.Tensor], assignments: List[torch.Tensor], temperature: float = 0.1
) -> torch.Tensor:
    """Computes SwAV's loss given list of batch predictions from multiple views
    and a list of cluster assignments from the same multiple views.

    For every ordered pair of distinct views (v1, v2), the assignments of v1 are used as
    targets for the temperature-scaled log-softmax of the predictions of v2. The loss is
    the mean of these cross-entropies over all pairs.

    Args:
        preds (List[torch.Tensor]): list of NxC Tensors containing the predictions
            (prototype scores) of each view.
        assignments (List[torch.Tensor]): list of NxC Tensors containing the cluster
            assignments of each view.
        temperature (float, optional): softmax temperature for the loss. Defaults to 0.1.

    Returns:
        torch.Tensor: SwAV loss.

    Raises:
        ValueError: if fewer than two views are given, since no cross-view pair exists.
    """

    num_views = len(preds)
    if num_views < 2:
        raise ValueError(f"swav_loss_func needs at least 2 views, got {num_views}")

    # the log-softmax of a view does not depend on the view it is paired with
    log_probs = [torch.log_softmax(p / temperature, dim=1) for p in preds]

    losses = [
        -torch.mean(torch.sum(assignments[v1] * log_probs[v2], dim=1))
        for v1 in range(num_views)
        for v2 in range(num_views)
        if v2 != v1
    ]
    return sum(losses) / len(losses)
def wmse_loss_func(z1: torch.Tensor, z2: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Computes W-MSE's loss given two batches of whitened features z1 and z2.

    The loss is ``2 - 2 * cos(z1, z2)`` along the last dimension, averaged over all samples.

    Args:
        z1 (torch.Tensor): NxD Tensor containing whitened features from view 1.
        z2 (torch.Tensor): NxD Tensor containing whitened features from view 2.
        simplified (bool): use ``F.cosine_similarity`` directly instead of normalising and
            taking a dot product. Defaults to True.

    Returns:
        torch.Tensor: W-MSE loss.
    """

    if simplified:
        # NOTE(review): only this path detaches z2, so the two paths return the same value
        # but propagate different gradients -- confirm which behaviour is intended.
        return 2 - 2 * F.cosine_similarity(z1, z2.detach(), dim=-1).mean()

    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)

    return 2 - 2 * (z1 * z2).sum(dim=-1).mean()


class BYOL_ICA(BYOL):
    """BYOL with extra bias-free square linear heads applied to its ``z`` and ``feats`` outputs.

    The outputs of every head are added to the dictionary returned by ``forward`` under the
    keys ``z_ica_{i}`` and ``feats_ica_{i}``.
    """

    def __init__(self, *args, z_dim=None, feats_dim=None, num_ica_heads=5, **kwargs):
        """Builds the parent method and the additional linear heads.

        Args:
            z_dim (int, optional): size of the ``z`` vectors returned by the parent's
                forward. When omitted it is read from the ``proj_output_dim`` keyword
                argument (NOTE(review): presumably the projector output size that BYOL
                receives -- verify against solo.methods.byol).
            feats_dim (int, optional): size of the ``feats`` vectors returned by the
                parent's forward. When omitted it is read from ``self.features_dim``
                (NOTE(review): presumably set by the parent class -- verify).
            num_ica_heads (int, optional): number of heads per output. Defaults to 5.

        Raises:
            ValueError: if one of the two sizes cannot be determined.
        """
        # local import: the names ``nn``/``torch`` were never imported by this module
        from torch import nn

        super().__init__(*args, **kwargs)

        if z_dim is None:
            z_dim = kwargs.get("proj_output_dim")
        if feats_dim is None:
            feats_dim = getattr(self, "features_dim", None)
        if z_dim is None or feats_dim is None:
            raise ValueError(
                "BYOL_ICA could not infer the head sizes; pass z_dim and feats_dim explicitly"
            )

        self.z_ica = nn.ModuleList(
            [nn.Linear(z_dim, z_dim, bias=False) for _ in range(num_ica_heads)]
        )
        self.feats_ica = nn.ModuleList(
            [nn.Linear(feats_dim, feats_dim, bias=False) for _ in range(num_ica_heads)]
        )

    def forward(self, X: "torch.Tensor", *args, **kwargs) -> "Dict[str, Any]":
        """Runs the parent's forward pass and appends the outputs of every extra head.

        Args:
            X: batch forwarded unchanged to the parent's ``forward``.

        Returns:
            The parent's output dictionary extended with ``z_ica_{i}`` and
            ``feats_ica_{i}`` entries, one pair per head.
        """
        out = super().forward(X, *args, **kwargs)
        for i, (z_head, feats_head) in enumerate(zip(self.z_ica, self.feats_ica)):
            out[f'z_ica_{i}'] = z_head(out['z'])
            out[f'feats_ica_{i}'] = feats_head(out['feats'])
        return out
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.utils import ( 21 | backbones, 22 | checkpointer, 23 | classification_dataloader, 24 | knn, 25 | lars, 26 | metrics, 27 | misc, 28 | momentum, 29 | pretrain_dataloader, 30 | sinkhorn_knopp, 31 | ) 32 | 33 | __all__ = [ 34 | "backbones", 35 | "classification_dataloader", 36 | "pretrain_dataloader", 37 | "checkpointer", 38 | "knn", 39 | "misc", 40 | "lars", 41 | "metrics", 42 | "momentum", 43 | "sinkhorn_knopp", 44 | ] 45 | 46 | try: 47 | from solo.utils import dali_dataloader # noqa: F401 48 | except ImportError: 49 | pass 50 | else: 51 | __all__.append("dali_dataloader") 52 | 53 | try: 54 | from solo.utils import auto_umap # noqa: F401 55 | except ImportError: 56 | pass 57 | else: 58 | __all__.append("auto_umap") 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /tests/args/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /tests/args/test_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
def test_argparse_dataset():
    """Checks that dataset_args registers every dataset-related option on the parser."""
    parser = argparse.ArgumentParser()
    dataset_args(parser)
    registered = [action.dest for action in parser._actions]

    for name in ("dataset", "data_dir", "train_dir", "val_dir", "dali", "dali_device"):
        assert name in registered


def test_argparse_augmentations():
    """Checks that augmentations_args registers every augmentation option on the parser."""
    parser = argparse.ArgumentParser()
    augmentations_args(parser)
    registered = [action.dest for action in parser._actions]

    assert "num_crops_per_aug" in registered

    expected = (
        "brightness",
        "contrast",
        "saturation",
        "hue",
        "color_jitter_prob",
        "gray_scale_prob",
        "horizontal_flip_prob",
        "gaussian_prob",
        "solarization_prob",
        "min_scale",
    )
    for name in expected:
        assert name in registered
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /tests/dali/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
class DummyDataset:
    """Context manager that writes a throw-away image-folder dataset to disk.

    On entry, ``size`` random JPEG images are written into one sub-directory per class
    (named ``0`` .. ``num_classes - 1``) under both the train and the validation
    directory. On exit, both directories are removed.
    """

    def __init__(self, train_dir, val_dir, size, num_classes):
        """Stores the configuration; nothing is written until the context is entered.

        Args:
            train_dir: root directory of the training split.
            val_dir: root directory of the validation split.
            size: number of images generated per class and split.
            num_classes: number of class sub-directories per split.
        """
        self.train_dir = Path(train_dir)
        self.val_dir = Path(val_dir)
        self.size = size
        self.num_classes = num_classes

    def __enter__(self):
        """Creates the directory tree, fills it with random images and returns self."""
        for split_dir in (self.train_dir, self.val_dir):
            for label in range(self.num_classes):
                class_dir = split_dir / str(label)
                # an existing directory is fine; any other failure should surface here
                # instead of later as a confusing error while saving an image
                class_dir.mkdir(parents=True, exist_ok=True)

                for idx in range(self.size):
                    # random image with both sides between 300 and 400 pixels
                    shape = (random.randint(300, 400), random.randint(300, 400))
                    pixels = np.random.rand(*shape, 3) * 255
                    image = Image.fromarray(pixels.astype("uint8")).convert("RGB")
                    image.save(class_dir / f"{idx}.jpg")
        return self

    def __exit__(self, *args):
        """Removes both dataset directories; cleanup is best-effort."""
        with contextlib.suppress(OSError):
            shutil.rmtree(self.train_dir)
        with contextlib.suppress(OSError):
            shutil.rmtree(self.val_dir)
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /tests/losses/test_barlow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
def test_barlow_loss():
    """Check that gradient steps on the inputs reduce the Barlow Twins loss."""
    batch_size, feat_dim = 32, 128
    loss_kwargs = {"lamb": 5e-3, "scale_loss": 0.025}

    z1 = torch.randn(batch_size, feat_dim).requires_grad_()
    z2 = torch.randn(batch_size, feat_dim).requires_grad_()

    loss = barlow_loss_func(z1, z2, **loss_kwargs)
    initial_loss = loss.item()
    assert loss != 0

    # plain gradient descent applied directly to both input tensors
    for _ in range(20):
        loss = barlow_loss_func(z1, z2, **loss_kwargs)
        loss.backward()

        for z in (z1, z2):
            z.data.sub_(0.5 * z.grad)
            z.grad = None

    # `loss` is the value computed in the last iteration, before its update
    assert loss < initial_loss
def test_byol_loss():
    """Check that the BYOL loss decreases under gradient descent on ``p``.

    Also checks that the default and ``simplified=False`` code paths agree
    to within 1e-6 on the optimized inputs.
    """
    batch_size, feat_dim = 32, 128
    p = torch.randn(batch_size, feat_dim).requires_grad_()
    z = torch.randn(batch_size, feat_dim)  # fixed target, no gradient

    loss = byol_loss_func(p, z)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = byol_loss_func(p, z)
        loss.backward()

        p.data.sub_(0.5 * p.grad)
        p.grad = None

    assert loss < initial_loss

    default_value = byol_loss_func(p, z)
    full_value = byol_loss_func(p, z, simplified=False)
    assert abs(default_value - full_value) < 1e-6
def test_dino_loss():
    """Check that the DINO loss decreases under gradient descent on ``p``.

    The loss object's ``epoch`` attribute is advanced on every iteration
    before the loss is evaluated.
    """
    batch_size, feat_dim, num_epochs = 32, 128, 20
    p = torch.randn(batch_size, feat_dim).requires_grad_()
    p_momentum = torch.randn(batch_size, feat_dim)

    dino_kwargs = {
        "num_prototypes": feat_dim,
        "warmup_teacher_temp": 0.4,
        "teacher_temp": 0.7,
        "warmup_teacher_temp_epochs": 10,
        "student_temp": 0.1,
        "num_epochs": num_epochs,
    }
    dino_loss = DINOLoss(**dino_kwargs)

    loss = dino_loss(p, p_momentum)
    initial_loss = loss.item()
    assert loss != 0

    for epoch in range(num_epochs):
        dino_loss.epoch = epoch

        loss = dino_loss(p, p_momentum)
        loss.backward()

        p.data.sub_(0.5 * p.grad)
        p.grad = None

    assert loss < initial_loss
def test_moco_loss():
    """Check that the MoCo loss decreases under gradient descent on query and key."""
    batch_size, feat_dim, queue_size = 32, 128, 15000
    query = torch.randn(batch_size, feat_dim).requires_grad_()
    key = torch.randn(batch_size, feat_dim).requires_grad_()
    queue = torch.randn(feat_dim, queue_size)  # constructed as (features, queue entries)

    loss = moco_loss_func(query, key, queue, temperature=0.1)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = moco_loss_func(query, key, queue, temperature=0.1)
        loss.backward()

        for param in (query, key):
            param.data.sub_(0.5 * param.grad)
            param.grad = None

    assert loss < initial_loss
def test_nnclr_loss():
    """Check that the NNCLR loss decreases under gradient descent on ``p``."""
    batch_size, feat_dim = 32, 128
    nn = torch.randn(batch_size, feat_dim)  # passed as the first argument, no gradient
    p = torch.randn(batch_size, feat_dim).requires_grad_()

    loss = nnclr_loss_func(nn, p, temperature=0.1)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = nnclr_loss_func(nn, p, temperature=0.1)
        loss.backward()

        p.data.sub_(0.5 * p.grad)
        p.grad = None

    assert loss < initial_loss
def test_moco_loss():
    """Check that the ReSSL loss decreases under gradient descent on the query.

    NOTE(review): despite its name this test exercises ``ressl_loss_func``;
    the name looks like a copy-paste leftover and is kept only so the test
    identifier stays stable.
    """
    batch_size, feat_dim, queue_size = 32, 128, 15000
    temperatures = {"temperature_q": 0.1, "temperature_k": 0.04}

    query = torch.randn(batch_size, feat_dim).requires_grad_()
    key = torch.randn(batch_size, feat_dim)
    queue = torch.randn(queue_size, feat_dim)  # constructed as (queue entries, features)

    loss = ressl_loss_func(query, key, queue, **temperatures)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = ressl_loss_func(query, key, queue, **temperatures)
        loss.backward()

        query.data.sub_(0.5 * query.grad)
        query.grad = None

    assert loss < initial_loss
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
def test_simclr_loss():
    """Check that the SimCLR loss decreases under gradient descent on both views."""
    batch_size, feat_dim = 32, 128
    z1 = torch.randn(batch_size, feat_dim).requires_grad_()
    z2 = torch.randn(batch_size, feat_dim).requires_grad_()
    # row i of z1 and row i of z2 share the same index
    indexes = torch.arange(batch_size).repeat(2)

    loss = simclr_loss_func(torch.cat((z1, z2)), indexes, temperature=0.1)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        # re-concatenate every step so the graph sees the updated views
        stacked = torch.cat((z1, z2))
        loss = simclr_loss_func(stacked, indexes, temperature=0.1)
        loss.backward()

        for view in (z1, z2):
            view.data.sub_(0.5 * view.grad)
            view.grad = None

    assert loss < initial_loss
def test_simsiam_loss():
    """Check that the SimSiam loss decreases under gradient descent on ``p``.

    Also checks that the default and ``simplified=False`` code paths agree
    to within 1e-6 on the optimized inputs.
    """
    batch_size, feat_dim = 32, 128
    p = torch.randn(batch_size, feat_dim).requires_grad_()
    z = torch.randn(batch_size, feat_dim)

    loss = simsiam_loss_func(p, z)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = simsiam_loss_func(p, z)
        loss.backward()

        p.data.sub_(0.5 * p.grad)
        p.grad = None

    assert loss < initial_loss

    default_value = simsiam_loss_func(p, z)
    full_value = simsiam_loss_func(p, z, simplified=False)
    assert abs(default_value - full_value) < 1e-6
def get_assignments(preds):
    """Run Sinkhorn-Knopp on every element of ``preds``.

    Each result is truncated to the first ``preds[0].size(0)`` rows.
    Returns a list with one entry per element of ``preds``.
    """
    batch_size = preds[0].size(0)
    sinkhorn = SinkhornKnopp(10, 0.05, 1)
    return [sinkhorn(pred)[:batch_size] for pred in preds]


def test_swav_loss():
    """Check that the SwAV loss decreases under gradient descent on ``z``."""
    batch_size, feat_dim = 256, 128
    prototypes = nn.utils.weight_norm(torch.nn.Linear(feat_dim, feat_dim, bias=False))

    # two stacked sets of features, sampled uniformly from [-2, 2)
    z = torch.zeros(2, batch_size, feat_dim).uniform_(-2, 2).requires_grad_()

    def evaluate():
        # forward through the prototypes, compute assignments, then the loss
        preds = prototypes(z)
        assignments = get_assignments(preds)
        return swav_loss_func(preds, assignments, temperature=0.1)

    loss = evaluate()
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = evaluate()
        loss.backward()

        z.data.sub_(0.5 * z.grad)
        z.grad = None

    assert loss < initial_loss
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
def test_vibcreg_loss_func():
    """Check that the VIbCReg loss decreases under gradient descent on both views."""
    batch_size, feat_dim = 32, 128
    weights = {
        "sim_loss_weight": 25.0,
        "var_loss_weight": 25.0,
        "cov_loss_weight": 200.0,
    }

    z1 = torch.randn(batch_size, feat_dim).requires_grad_()
    z2 = torch.randn(batch_size, feat_dim).requires_grad_()

    loss = vibcreg_loss_func(z1, z2, **weights)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = vibcreg_loss_func(z1, z2, **weights)
        loss.backward()

        for z in (z1, z2):
            z.data.sub_(0.5 * z.grad)
            z.grad = None

    assert loss < initial_loss
def test_vicreg_loss():
    """Check that the VICReg loss decreases under gradient descent on both views."""
    batch_size, feat_dim = 32, 128
    weights = {
        "sim_loss_weight": 25.0,
        "var_loss_weight": 25.0,
        "cov_loss_weight": 1.0,
    }

    z1 = torch.randn(batch_size, feat_dim).requires_grad_()
    z2 = torch.randn(batch_size, feat_dim).requires_grad_()

    loss = vicreg_loss_func(z1, z2, **weights)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = vicreg_loss_func(z1, z2, **weights)
        loss.backward()

        for z in (z1, z2):
            z.data.sub_(0.5 * z.grad)
            z.grad = None

    assert loss < initial_loss
def test_simsiam_loss():
    """Check that the W-MSE loss decreases under gradient descent on ``p``.

    Also checks that the default and ``simplified=False`` code paths agree
    to within 1e-6 on the optimized inputs.

    NOTE(review): despite its name this test exercises ``wmse_loss_func``;
    the name looks like a copy-paste leftover and is kept only so the test
    identifier stays stable.
    """
    batch_size, feat_dim = 32, 128
    p = torch.randn(batch_size, feat_dim).requires_grad_()
    z = torch.randn(batch_size, feat_dim)

    loss = wmse_loss_func(p, z)
    initial_loss = loss.item()
    assert loss != 0

    for _ in range(20):
        loss = wmse_loss_func(p, z)
        loss.backward()

        p.data.sub_(0.5 * p.grad)
        p.grad = None

    assert loss < initial_loss

    default_value = wmse_loss_func(p, z)
    full_value = wmse_loss_func(p, z, simplified=False)
    assert abs(default_value - full_value) < 1e-6
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | -------------------------------------------------------------------------------- /tests/utils/test_filter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
def test_filter():
    """Exercise ``filter_inf_n_nan`` and the ``FilterInfNNan`` wrapper.

    Covers a 1-D tensor with a NaN, a 2-D tensor with an ``inf`` in one row,
    lists of tensors, and a wrapped layer that injects ``-inf`` itself.
    """
    # 1-D input with a single NaN: exactly one element must be dropped
    tensor = torch.randn(100)
    tensor[10] = math.nan
    filtered, selected = filter_inf_n_nan(tensor, return_indexes=True)
    assert filtered.size(0) == 99
    assert selected.sum() == 99

    # list input: `all(...)` is required, since asserting the bare list of
    # booleans is always true for a non-empty list
    tensor2 = torch.randn(100)
    assert all(t.size(0) == 99 for t in filter_inf_n_nan([tensor, tensor2]))

    # 2-D input with one non-finite entry: the whole row must be dropped
    tensor = torch.randn(100, 30)
    tensor[10, 0] = math.inf
    assert filter_inf_n_nan(tensor).size(0) == 99

    tensor2 = torch.randn(100, 30)
    assert all(t.size(0) == 99 for t in filter_inf_n_nan([tensor, tensor2]))

    class DummyNanLayer(nn.Module):
        """Layer that overwrites row 10 of its input with ``-inf`` in place."""

        def forward(self, x):
            x[10] = -math.inf
            return x

    dummy_nan_layer = DummyNanLayer()
    filter_layer = FilterInfNNan(dummy_nan_layer)
    # feed a clean tensor so the only non-finite row is the one the wrapped
    # layer injects; this assumes the wrapper filters the layer's output
    clean = torch.randn(100, 30)
    assert filter_layer(clean).size(0) == 99
def test_gather_layer():
    """Check that ``gather`` returns a tensor and lets gradients reach its input."""
    X = torch.randn(10, 30, requires_grad=True)
    X_gathered = gather(X)
    # check the gathered result; the input X is a tensor by construction
    assert isinstance(X_gathered, torch.Tensor)

    dummy_loss = torch.mm(X_gathered, X_gathered.T).sum()
    dummy_loss.backward()
    assert X.grad is not None


def test_kmeans():
    """Cluster 500 random 128-d embeddings into 30 clusters and check shapes."""
    k = [30]
    # positional arguments kept exactly as in the original call
    kmeans = KMeans(1, 0, 1, 500, 128, k)
    local_memory_index = torch.arange(0, 500)
    local_memory_embeddings = torch.randn((1, 500, 128))
    assignments, centroids_list = kmeans.cluster_memory(
        local_memory_index, local_memory_embeddings
    )
    assert assignments.size() == (1, 500)
    assert len(assignments.unique()) == 30
    assert centroids_list[0].size() == (30, 128)
def test_accuracy_at_k():
    """Smoke-test ``accuracy_at_k`` on random scores and random labels.

    Only the number of returned values (two) and their types are checked.
    """
    batch_size, num_classes = 32, 100
    scores = torch.randn(batch_size, num_classes)
    labels = torch.randint(low=0, high=num_classes, size=(batch_size,))

    acc1, acc5 = accuracy_at_k(scores, labels)

    for value in (acc1, acc5):
        assert isinstance(value, torch.Tensor)
# BYOL
# -p keeps the script re-runnable when the directory already exists.
mkdir -p byol
(
    # Run inside a subshell so a failed cd skips these downloads without
    # changing the working directory of the sections that follow.
    cd byol || exit 1
    # URLs are quoted so the '?' is never treated as a glob character.
    gdown "https://drive.google.com/uc?id=1TheL_4tmDWByCxg8XHke5VEz_lcYHH64"
    gdown "https://drive.google.com/uc?id=18gG0Jo59cFVX4qNUO119jIhzHcJkAmGz"
)

# MoCo V2+
mkdir -p mocov2plus
(
    cd mocov2plus || exit 1
    gdown "https://drive.google.com/uc?id=1BBauwWTJV38BCf56KtOK9TJWLyjNH-mP"
    gdown "https://drive.google.com/uc?id=1JMpGSYjefFzxT5GTbEc_2d4THxOxC3Ca"
)