├── ACKNOWLEDGEMENTS.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── clip ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── clip.cpython-39.pyc │ ├── model.cpython-39.pyc │ └── simple_tokenizer.cpython-39.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── datasets │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── eurosat.yaml │ ├── fgvc_aircraft.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenet_sketch.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── oxford_pets.yaml │ ├── stanford_cars.yaml │ ├── sun397.yaml │ └── ucf101.yaml └── trainers │ ├── CoCoOp │ ├── vit_b16_c16_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ └── vit_b16_c8_ep10_batch1.yaml │ ├── CoOp │ ├── ipynb_checkpoints │ │ ├── vit_b16_ep10_ctxv1-checkpoint.yaml │ │ └── vit_b16_ep50_ctxv1-checkpoint.yaml │ ├── rn101.yaml │ ├── rn101_ep50.yaml │ ├── rn50.yaml │ ├── rn50_ctxv1.yaml │ ├── rn50_ep100.yaml │ ├── rn50_ep50.yaml │ ├── rn50_ep50_ctxv1.yaml │ ├── rn50_val.yaml │ ├── vit_b16.yaml │ ├── vit_b16_ctxv1.yaml │ ├── vit_b16_ep100.yaml │ ├── vit_b16_ep100_ctxv1.yaml │ ├── vit_b16_ep10_ctxv1.yaml │ ├── vit_b16_ep50.yaml │ ├── vit_b16_ep50_ctxv1.yaml │ ├── vit_b32.yaml │ └── vit_b32_ep50.yaml │ └── OGEN │ ├── ipynb_checkpoints │ ├── vit_b16_ep10_ctxv1-checkpoint.yaml │ └── vit_b16_ep200_ctxv1-checkpoint.yaml │ ├── vit_b16_ep10_ctxv1.yaml │ ├── vit_b16_ep200_ctxv1.yaml │ └── vit_b16_ep50_ctxv1.yaml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── caltech101.cpython-39.pyc │ ├── dtd.cpython-39.pyc │ ├── eurosat.cpython-39.pyc │ ├── fgvc_aircraft.cpython-39.pyc │ ├── food101.cpython-39.pyc │ ├── imagenet.cpython-39.pyc │ ├── imagenet_a.cpython-39.pyc │ ├── imagenet_r.cpython-39.pyc │ ├── imagenet_sketch.cpython-39.pyc │ ├── imagenetv2.cpython-39.pyc │ ├── oxford_flowers.cpython-39.pyc │ ├── oxford_pets.cpython-39.pyc │ ├── stanford_cars.cpython-39.pyc │ ├── sun397.cpython-39.pyc │ └── ucf101.cpython-39.pyc ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc_aircraft.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenetv2.py ├── ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── caltech101-checkpoint.py │ ├── dtd-checkpoint.py │ ├── eurosat-checkpoint.py │ ├── fgvc_aircraft-checkpoint.py │ ├── food101-checkpoint.py │ ├── imagenet-checkpoint.py │ ├── oxford_flowers-checkpoint.py │ ├── oxford_pets-checkpoint.py │ ├── stanford_cars-checkpoint.py │ ├── sun397-checkpoint.py │ └── ucf101-checkpoint.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py └── ucf101.py ├── requirements.txt ├── scripts ├── cocoop │ ├── README.md │ ├── base2new_test.sh │ ├── base2new_train.sh │ ├── ipynb_checkpoints │ │ └── base2new_test-checkpoint.sh │ ├── xd_test.sh │ └── xd_train.sh ├── coop │ ├── README.md │ ├── base2new_train_ep10.sh │ ├── eval.sh │ ├── ipynb_checkpoints │ │ ├── base2new_train_ep10-checkpoint.sh │ │ └── eval-checkpoint.sh │ ├── main.sh │ └── zeroshot.sh └── ogen │ ├── base2new_eval_ep10.sh │ ├── base2new_eval_ep200.sh │ ├── base2new_eval_ep50.sh │ ├── base2new_train_ep10.sh │ ├── base2new_train_ep200.sh │ ├── base2new_train_ep50.sh │ └── ipynb_checkpoints │ ├── base2new_eval_ep10-checkpoint.sh │ ├── base2new_eval_ep200-checkpoint.sh │ ├── base2new_eval_ep50-checkpoint.sh │ ├── base2new_train_ep10-checkpoint.sh │ ├── 
base2new_train_ep200-checkpoint.sh │ └── base2new_train_ep50-checkpoint.sh ├── train.py └── trainers ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc ├── cocoop.cpython-39.pyc ├── coop.cpython-39.pyc ├── imagenet_templates.cpython-39.pyc ├── ogen.cpython-39.pyc └── zsclip.cpython-39.pyc ├── cocoop.py ├── coop.py ├── imagenet_templates.py ├── ipynb_checkpoints ├── __init__-checkpoint.py ├── coop-checkpoint.py ├── imagenet_templates-checkpoint.py ├── ogen-checkpoint.py └── zsclip-checkpoint.py ├── ogen.py └── zsclip.py /ACKNOWLEDGEMENTS.txt: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this OGEN Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | PyTorch (https://pytorch.org) 7 | We use PyTorch as the training framework. 8 | 9 | From PyTorch: 10 | 11 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 12 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 13 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 14 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 15 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 16 | Copyright (c) 2011-2013 NYU (Clement Farabet) 17 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 18 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 19 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 20 | 21 | From Caffe2: 22 | 23 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 24 | 25 | All contributions by Facebook: 26 | Copyright (c) 2016 Facebook Inc. 27 | 28 | All contributions by Google: 29 | Copyright (c) 2015 Google Inc. 30 | All rights reserved. 31 | 32 | All contributions by Yangqing Jia: 33 | Copyright (c) 2015 Yangqing Jia 34 | All rights reserved. 35 | 36 | All contributions from Caffe: 37 | Copyright(c) 2013, 2014, 2015, the respective contributors 38 | All rights reserved. 39 | 40 | All other contributions: 41 | Copyright(c) 2015, 2016 the respective contributors 42 | All rights reserved. 43 | 44 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 45 | copyright over their contributions to Caffe2. The project versioning records 46 | all such contribution and copyright details. If a contributor wants to further 47 | mark their specific copyright on a particular contribution, they should 48 | indicate their copyright solely in the commit message of the change when it is 49 | committed. 50 | 51 | All rights reserved. 52 | 53 | Redistribution and use in source and binary forms, with or without 54 | modification, are permitted provided that the following conditions are met: 55 | 56 | 1. Redistributions of source code must retain the above copyright 57 | notice, this list of conditions and the following disclaimer. 58 | 59 | 2. Redistributions in binary form must reproduce the above copyright 60 | notice, this list of conditions and the following disclaimer in the 61 | documentation and/or other materials provided with the distribution. 62 | 63 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 64 | and IDIAP Research Institute nor the names of its contributors may be 65 | used to endorse or promote products derived from this software without 66 | specific prior written permission. 
67 | 68 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 69 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 70 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 71 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 72 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 73 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 74 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 75 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 76 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 77 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 78 | POSSIBILITY OF SUCH DAMAGE. 79 | 80 | _____________________ 81 | Soumith Chintala (TorchVision: https://github.com/pytorch/vision/) 82 | TorchVision is used for handling data loading and IO utilities, which is distributed under BSD 3-Clause License. 83 | 84 | Copyright (c) Soumith Chintala 2016, 85 | All rights reserved. 86 | 87 | Redistribution and use in source and binary forms, with or without 88 | modification, are permitted provided that the following conditions are met: 89 | 90 | * Redistributions of source code must retain the above copyright notice, this 91 | list of conditions and the following disclaimer. 92 | 93 | * Redistributions in binary form must reproduce the above copyright notice, 94 | this list of conditions and the following disclaimer in the documentation 95 | and/or other materials provided with the distribution. 96 | 97 | * Neither the name of the copyright holder nor the names of its 98 | contributors may be used to endorse or promote products derived from 99 | this software without specific prior written permission. 100 | 101 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 102 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 103 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 104 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 105 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 106 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 107 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 108 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 109 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 110 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 
40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH OGEN: 43 | 44 | The OGEN software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overcoming the Pitfalls of Vision-Language Model Finetuning for OOD Generalization 2 | 3 | This OGEN repository includes sample code that can be used to finetune the CLIP model via prompt learning on various downstream datasets, with the main focus on improving the OOD GENeralization of finetuned models. 4 | 5 | See the accompanying paper on arXiv for more details: [Overcoming the Pitfalls of Vision-Language Model Finetuning for OOD Generalization](https://arxiv.org/pdf/2401.15914.pdf) 6 | 7 | 8 | 9 | ## Getting Started 10 | 11 | **Dependencies.** We have tested on: 12 | - CUDA 11.8 13 | - torch 2.0.1 14 | - torchvision 0.15.2 15 | - dassl 0.6.3 16 | 17 | If PyTorch with CUDA support is already installed, simply set up the rest of the environment with pip. 18 | 19 | ```shell 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | **Datasets.** To prepare all the downstream datasets (train/val/test splits, etc.), please refer to [DATASETS.md](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) in the CoOp repository and follow the instructions therein. 24 | 25 | ## Running the code 26 | 27 | * Base-to-new generalization: training for 10 epochs 28 | ```shell 29 | # bash scripts/ogen/base2new_train_ep10.sh <dataset name> <random seed> 30 | # dataset name: imagenet, caltech101, oxford_pets, stanford_cars, oxford_flowers, sun397, food101, fgvc_aircraft, eurosat, dtd, ucf101 31 | # random seed: 1, 2, 3 32 | bash scripts/ogen/base2new_train_ep10.sh caltech101 1 33 | ``` 34 | 35 | * Base-to-new generalization: evaluation after 10 epochs 36 | ```shell 37 | # bash scripts/ogen/base2new_eval_ep10.sh <dataset name> <random seed> 38 | bash scripts/ogen/base2new_eval_ep10.sh caltech101 1 39 | ``` 40 | 41 | ## Citation 42 | ``` 43 | @inproceedings{zang2023overcoming, 44 | title={Overcoming the Pitfalls of Vision-Language Model Finetuning for OOD Generalization}, 45 | author={Zang, Yuhang and Goh, Hanlin and Susskind, Josh and Huang, Chen}, 46 | booktitle={ICLR}, 47 | year={2024} 48 | } 49 | ``` 50 | 51 | ## License 52 | 53 | This sample code is released under the terms set forth in LICENSE. 54 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | from .clip import * 6 | -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/clip/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/clip/__pycache__/clip.cpython-39.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/clip/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/clip/__pycache__/simple_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import hashlib 6 | import os 7 | import urllib 8 | import warnings 9 | from typing import Union, List 10 | 11 | import torch 12 | from PIL import Image 13 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 14 | from tqdm import tqdm 15 | 16 | from .model import build_model 17 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 18 | 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | if torch.__version__.split(".") < ["1", "7", "1"]: 27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 28 | 29 | 30 | __all__ = ["available_models", "load", "tokenize"] 31 | _tokenizer = _Tokenizer() 32 | 33 | _MODELS = { 34 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 35 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 36 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 37 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | lambda image: image.convert("RGB"), 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | 
---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | Returns 105 | ------- 106 | model : torch.nn.Module 107 | The CLIP model 108 | 109 | preprocess : Callable[[PIL.Image], torch.Tensor] 110 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 111 | """ 112 | if name in _MODELS: 113 | model_path = _download(_MODELS[name]) 114 | elif os.path.isfile(name): 115 | model_path = name 116 | else: 117 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 118 | 119 | try: 120 | # loading JIT archive 121 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 122 | state_dict = None 123 | except RuntimeError: 124 | # loading saved state dict 125 | if jit: 126 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 127 | jit = False 128 | state_dict = torch.load(model_path, map_location="cpu") 129 | 130 | if not jit: 131 | model = build_model(state_dict or model.state_dict()).to(device) 132 | if str(device) == "cpu": 133 | model.float() 134 | return model, _transform(model.visual.input_resolution) 135 | 136 | # patch the device names 137 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 138 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 139 | 140 | def patch_device(module): 141 | try: 142 | graphs = [module.graph] if hasattr(module, "graph") else [] 143 | except RuntimeError: 144 | graphs = [] 145 | 146 | if hasattr(module, "forward1"): 147 | graphs.append(module.forward1.graph) 148 | 149 | for graph in graphs: 150 | for node in graph.findAllNodes("prim::Constant"): 151 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 152 | node.copyAttributes(device_node) 153 | 154 | model.apply(patch_device) 155 | patch_device(model.encode_image) 156 | patch_device(model.encode_text) 157 | 158 | # patch dtype to float32 on CPU 159 | if str(device) == "cpu": 160 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 161 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 162 | float_node = float_input.node() 163 | 164 | def patch_float(module): 165 | try: 166 | graphs = [module.graph] if hasattr(module, "graph") else [] 167 | except RuntimeError: 168 | graphs = [] 169 | 170 | if hasattr(module, "forward1"): 171 | graphs.append(module.forward1.graph) 172 | 173 | for graph in graphs: 174 | for node in graph.findAllNodes("aten::to"): 175 | inputs = list(node.inputs()) 176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 177 | if inputs[i].node()["value"] == 5: 178 | inputs[i].node().copyAttributes(float_node) 179 | 180 | model.apply(patch_float) 181 | patch_float(model.encode_image) 182 | patch_float(model.encode_text) 183 | 184 | model.float() 185 | 186 | return model, _transform(model.input_resolution.item()) 187 | 188 | 189 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 190 | """ 191 | Returns the tokenized representation of given input string(s) 192 | 193 | Parameters 194 | ---------- 
---------- 195 | texts : Union[str, List[str]] 196 | An input string or a list of input strings to tokenize 197 | 198 | context_length : int 199 | The context length to use; all CLIP models use 77 as the context length 200 | 201 | truncate: bool 202 | Whether to truncate the text in case its encoding is longer than the context length 203 | 204 | Returns 205 | ------- 206 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 207 | """ 208 | if isinstance(texts, str): 209 | texts = [texts] 210 | 211 | sot_token = _tokenizer.encoder["<|startoftext|>"] 212 | eot_token = _tokenizer.encoder["<|endoftext|>"] 213 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 214 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 215 | 216 | for i, tokens in enumerate(all_tokens): 217 | if len(tokens) > context_length: 218 | if truncate: 219 | tokens = tokens[:context_length] 220 | tokens[-1] = eot_token 221 | else: 222 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 223 | result[i, :len(tokens)] = torch.tensor(tokens) 224 | 225 | return result 226 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | 10 | import ftfy 11 | import regex as re 12 | 13 | 14 | @lru_cache() 15 | def default_bpe(): 16 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 17 | 18 | 19 | @lru_cache() 20 | def bytes_to_unicode(): 21 | """ 22 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 23 | The reversible bpe codes work on unicode strings. 24 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 25 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 26 | This is a significant percentage of your normal, say, 32K bpe vocab. 27 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 28 | It also avoids mapping to whitespace/control characters the bpe code barfs on. 29 | """ 30 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 31 | cs = bs[:] 32 | n = 0 33 | for b in range(2**8): 34 | if b not in bs: 35 | bs.append(b) 36 | cs.append(2**8+n) 37 | n += 1 38 | cs = [chr(n) for n in cs] 39 | return dict(zip(bs, cs)) 40 | 41 | 42 | def get_pairs(word): 43 | """Return set of symbol pairs in a word. 44 | Word is represented as tuple of symbols (symbols being variable-length strings).
45 | """ 46 | pairs = set() 47 | prev_char = word[0] 48 | for char in word[1:]: 49 | pairs.add((prev_char, char)) 50 | prev_char = char 51 | return pairs 52 | 53 | 54 | def basic_clean(text): 55 | text = ftfy.fix_text(text) 56 | text = html.unescape(html.unescape(text)) 57 | return text.strip() 58 | 59 | 60 | def whitespace_clean(text): 61 | text = re.sub(r'\s+', ' ', text) 62 | text = text.strip() 63 | return text 64 | 65 | 66 | class SimpleTokenizer(object): 67 | def __init__(self, bpe_path: str = default_bpe()): 68 | self.byte_encoder = bytes_to_unicode() 69 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 70 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 71 | merges = merges[1:49152-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + ( token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token+'' 92 | 93 | while True: 94 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text 137 | -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | 
-------------------------------------------------------------------------------- /configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 
| TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 8 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/ipynb_checkpoints/vit_b16_ep10_ctxv1-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: 
[0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/ipynb_checkpoints/vit_b16_ep50_ctxv1-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 
| LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 
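The `OPTIM` block repeated across these CoOp trainer configs (SGD, LR 0.002, cosine schedule, one constant-LR warmup epoch) is the part the training framework turns into an optimizer and LR scheduler. A minimal sketch, assuming Dassl's standard helpers `build_optimizer` and `build_lr_scheduler` and a stand-in module in place of the actual prompt learner:

```python
# Hypothetical illustration only: turning an OPTIM block like the ones above
# into an optimizer and LR scheduler with Dassl's helpers (assumed API).
import torch.nn as nn
from dassl.config import get_cfg_default
from dassl.optim import build_optimizer, build_lr_scheduler

cfg = get_cfg_default()
cfg.merge_from_file("configs/trainers/CoOp/rn50_ep50.yaml")  # sgd, lr 0.002, cosine, 50 epochs
prompt_learner = nn.Linear(512, 16)   # stand-in for the learnable context vectors
optim = build_optimizer(prompt_learner, cfg.OPTIM)
sched = build_lr_scheduler(optim, cfg.OPTIM)

for epoch in range(cfg.OPTIM.MAX_EPOCH):
    # ... one training epoch over the few-shot data would go here ...
    sched.step()  # cosine decay after the constant-LR warmup epoch
```

The `_ctxv1` variants of these configs differ from the plain ones only in initializing the learnable context from the phrase "a photo of a" via `TRAINER.COOP.CTX_INIT`, rather than randomly.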
-------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | 
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep10_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 
| LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/OGEN/ipynb_checkpoints/vit_b16_ep10_ctxv1-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | CHECKPOINT_FREQ: 0 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | COOP: 34 | CTX_INIT: "a photo of a" 35 | -------------------------------------------------------------------------------- /configs/trainers/OGEN/ipynb_checkpoints/vit_b16_ep200_ctxv1-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | CHECKPOINT_FREQ: 20 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | COOP: 34 | CTX_INIT: "a photo of a" 35 | 36 | TEST: 37 | FINAL_MODEL: "last_step" # last_step, best_val -------------------------------------------------------------------------------- /configs/trainers/OGEN/vit_b16_ep10_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: 
"cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | CHECKPOINT_FREQ: 0 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | COOP: 34 | CTX_INIT: "a photo of a" 35 | -------------------------------------------------------------------------------- /configs/trainers/OGEN/vit_b16_ep200_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | CHECKPOINT_FREQ: 0 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | COOP: 34 | CTX_INIT: "a photo of a" 35 | 36 | TEST: 37 | FINAL_MODEL: "last_step" # last_step, best_val 38 | -------------------------------------------------------------------------------- /configs/trainers/OGEN/vit_b16_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | CHECKPOINT_FREQ: 0 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | COOP: 34 | CTX_INIT: "a photo of a" 35 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/caltech101.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/dtd.cpython-39.pyc -------------------------------------------------------------------------------- 
/datasets/__pycache__/eurosat.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/eurosat.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc_aircraft.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/fgvc_aircraft.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/food101.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/food101.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/imagenet.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_a.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/imagenet_a.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_r.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/imagenet_r.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_sketch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/imagenet_sketch.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenetv2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/imagenetv2.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_flowers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/oxford_flowers.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/oxford_pets.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stanford_cars.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/stanford_cars.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/sun397.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/__pycache__/ucf101.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | from .dtd import DescribableTextures as DTD 13 | 14 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 15 | NEW_CNAMES = { 16 | "airplanes": "airplane", 17 | "Faces": "face", 18 | "Leopards": "leopard", 19 | "Motorbikes": "motorbike", 20 | } 21 | 22 | 23 | @DATASET_REGISTRY.register() 24 | class Caltech101(DatasetBase): 25 | 26 | dataset_dir = "caltech-101" 27 | 28 | def __init__(self, cfg): 29 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = os.path.join(root, self.dataset_dir) 31 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 32 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 33 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 34 | mkdir_if_missing(self.split_fewshot_dir) 35 | 36 | if os.path.exists(self.split_path): 37 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 38 | else: 39 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 40 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 41 | 42 | num_shots = cfg.DATASET.NUM_SHOTS 43 | if num_shots >= 1: 44 | seed = cfg.SEED 45 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 46 | 47 | if os.path.exists(preprocessed): 48 | print(f"Loading preprocessed few-shot data from {preprocessed}") 49 | with open(preprocessed, "rb") as file: 50 | data = pickle.load(file) 51 | train, val = data["train"], data["val"] 52 | else: 53 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 54 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 55 | data = {"train": train, "val": val} 56 | print(f"Saving preprocessed few-shot data to {preprocessed}") 57 | with open(preprocessed, "wb") as file: 58 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 59 | 60 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 61 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 62 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, 
test, subsample='new') 63 | 64 | if subsample == 'base': 65 | train = train_base 66 | val = val_base 67 | test = test_base 68 | elif subsample == 'new': 69 | train = train_new 70 | val = val_new 71 | test = test_new 72 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 73 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import random 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import listdir_nohidden, mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class DescribableTextures(DatasetBase): 17 | 18 | dataset_dir = "dtd" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_and_split_data(self.image_dir) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 54 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 55 | 56 | if subsample == 'base': 57 | train = train_base 58 | val = val_base 59 | test = test_base 60 | elif subsample == 'new': 61 | train = train_new 62 | val = val_new 63 | test = test_new 64 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 65 | 66 | @staticmethod 67 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): 68 | # The data are supposed to be organized into the following structure 69 | # ============= 70 | # images/ 71 | # dog/ 72 | # cat/ 73 | # horse/ 74 | # ============= 75 | categories = listdir_nohidden(image_dir) 76 | categories = [c for c in categories if c not in ignored] 77 | categories.sort() 78 | 79 | p_tst = 1 - p_trn - p_val 80 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 81 | 82 | def _collate(ims, 
y, c): 83 | items = [] 84 | for im in ims: 85 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 86 | items.append(item) 87 | return items 88 | 89 | train, val, test = [], [], [] 90 | for label, category in enumerate(categories): 91 | category_dir = os.path.join(image_dir, category) 92 | images = listdir_nohidden(category_dir) 93 | images = [os.path.join(category_dir, im) for im in images] 94 | random.shuffle(images) 95 | n_total = len(images) 96 | n_train = round(n_total * p_trn) 97 | n_val = round(n_total * p_val) 98 | n_test = n_total - n_train - n_val 99 | assert n_train > 0 and n_val > 0 and n_test > 0 100 | 101 | if new_cnames is not None and category in new_cnames: 102 | category = new_cnames[category] 103 | 104 | train.extend(_collate(images[:n_train], label, category)) 105 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 106 | test.extend(_collate(images[n_train + n_val :], label, category)) 107 | 108 | return train, val, test 109 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | from .dtd import DescribableTextures as DTD 13 | 14 | NEW_CNAMES = { 15 | "AnnualCrop": "Annual Crop Land", 16 | "Forest": "Forest", 17 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 18 | "Highway": "Highway or Road", 19 | "Industrial": "Industrial Buildings", 20 | "Pasture": "Pasture Land", 21 | "PermanentCrop": "Permanent Crop Land", 22 | "Residential": "Residential Buildings", 23 | "River": "River", 24 | "SeaLake": "Sea or Lake", 25 | } 26 | 27 | 28 | @DATASET_REGISTRY.register() 29 | class EuroSAT(DatasetBase): 30 | 31 | dataset_dir = "eurosat" 32 | 33 | def __init__(self, cfg): 34 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 35 | self.dataset_dir = os.path.join(root, self.dataset_dir) 36 | self.image_dir = os.path.join(self.dataset_dir, "2750") 37 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 38 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 39 | mkdir_if_missing(self.split_fewshot_dir) 40 | 41 | if os.path.exists(self.split_path): 42 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 43 | else: 44 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 45 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 46 | 47 | num_shots = cfg.DATASET.NUM_SHOTS 48 | if num_shots >= 1: 49 | seed = cfg.SEED 50 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 51 | 52 | if os.path.exists(preprocessed): 53 | print(f"Loading preprocessed few-shot data from {preprocessed}") 54 | with open(preprocessed, "rb") as file: 55 | data = pickle.load(file) 56 | train, val = data["train"], data["val"] 57 | else: 58 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 59 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 60 | data = {"train": train, "val": val} 61 | print(f"Saving preprocessed few-shot data to {preprocessed}") 62 | with open(preprocessed, "wb") as file: 63 | pickle.dump(data, file, 
protocol=pickle.HIGHEST_PROTOCOL) 64 | 65 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 66 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 67 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 68 | 69 | if subsample == 'base': 70 | train = train_base 71 | val = val_base 72 | test = test_base 73 | elif subsample == 'new': 74 | train = train_new 75 | val = val_new 76 | test = test_new 77 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 78 | 79 | def update_classname(self, dataset_old): 80 | dataset_new = [] 81 | for item_old in dataset_old: 82 | cname_old = item_old.classname 83 | cname_new = NEW_CLASSNAMES[cname_old] 84 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 85 | dataset_new.append(item_new) 86 | return dataset_new 87 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class FGVCAircraft(DatasetBase): 16 | 17 | dataset_dir = "fgvc_aircraft" 18 | 19 | def __init__(self, cfg): 20 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, "images") 23 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 24 | mkdir_if_missing(self.split_fewshot_dir) 25 | 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | classnames.append(line.strip()) 31 | cname2lab = {c: i for i, c in enumerate(classnames)} 32 | 33 | train = self.read_data(cname2lab, "images_variant_train.txt") 34 | val = self.read_data(cname2lab, "images_variant_val.txt") 35 | test = self.read_data(cname2lab, "images_variant_test.txt") 36 | 37 | num_shots = cfg.DATASET.NUM_SHOTS 38 | if num_shots >= 1: 39 | seed = cfg.SEED 40 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 41 | 42 | if os.path.exists(preprocessed): 43 | print(f"Loading preprocessed few-shot data from {preprocessed}") 44 | with open(preprocessed, "rb") as file: 45 | data = pickle.load(file) 46 | train, val = data["train"], data["val"] 47 | else: 48 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 50 | data = {"train": train, "val": val} 51 | print(f"Saving preprocessed few-shot data to {preprocessed}") 52 | with open(preprocessed, "wb") as file: 53 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 57 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 58 | 59 | if subsample == 'base': 60 | train = train_base 61 | val = val_base 62 | test = test_base 63 | elif subsample == 'new': 64 | train = train_new 65 | val = val_new 66 | test = 
test_new 67 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 68 | 69 | def read_data(self, cname2lab, split_file): 70 | filepath = os.path.join(self.dataset_dir, split_file) 71 | items = [] 72 | 73 | with open(filepath, "r") as f: 74 | lines = f.readlines() 75 | for line in lines: 76 | line = line.strip().split(" ") 77 | imname = line[0] + ".jpg" 78 | classname = " ".join(line[1:]) 79 | impath = os.path.join(self.image_dir, imname) 80 | label = cname2lab[classname] 81 | item = Datum(impath=impath, label=label, classname=classname) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | from .dtd import DescribableTextures as DTD 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class Food101(DatasetBase): 17 | 18 | dataset_dir = "food-101" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = DTD.read_and_split_data(self.image_dir) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 54 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 55 | 56 | if subsample == 'base': 57 | train = train_base 58 | val = val_base 59 | test = test_base 60 | elif subsample == 'new': 61 | train = train_new 62 | val = val_new 63 | test = test_new 64 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 65 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 
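Every loader in this directory follows the same recipe: read (or build and cache) a train/val/test split, optionally sample a few-shot training subset, and finally carve the label space into "base" and "new" halves with OxfordPets.subsample_classes before calling the DatasetBase constructor. A self-contained illustration of that base/new split follows; it is a sketch of the pattern, not a copy of the implementation in oxford_pets.py.

import math
from collections import namedtuple

# Illustration of the base/new class split performed by subsample_classes:
# sort the labels, give the first half to "base" and the rest to "new",
# then re-index each group from zero. Item is a stand-in for dassl's Datum.
Item = namedtuple("Item", ["impath", "label", "classname"])

def split_base_new(items):
    labels = sorted({it.label for it in items})
    m = math.ceil(len(labels) / 2)
    groups = {"base": labels[:m], "new": labels[m:]}
    out = {}
    for name, selected in groups.items():
        relabel = {y: i for i, y in enumerate(selected)}
        out[name] = [Item(it.impath, relabel[it.label], it.classname)
                     for it in items if it.label in relabel]
    return out["base"], out["new"]

items = [Item(f"img_{y}_{k}.jpg", y, f"class_{y}")
         for y in range(4) for k in range(2)]
base, new = split_base_new(items)
print(len(base), len(new))  # 4 4  (classes 0-1 go to base, 2-3 to new)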
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | from collections import OrderedDict 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import listdir_nohidden, mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class ImageNet(DatasetBase): 17 | 18 | dataset_dir = "imagenet" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.preprocessed): 29 | with open(self.preprocessed, "rb") as f: 30 | preprocessed = pickle.load(f) 31 | train = preprocessed["train"] 32 | test = preprocessed["test"] 33 | else: 34 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 35 | classnames = self.read_classnames(text_file) 36 | train = self.read_data(classnames, "train") 37 | # Follow standard practice to perform evaluation on the val set 38 | # Also used as the val set (so evaluate the last-step model) 39 | test = self.read_data(classnames, "val") 40 | 41 | preprocessed = {"train": train, "test": test} 42 | with open(self.preprocessed, "wb") as f: 43 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 44 | 45 | num_shots = cfg.DATASET.NUM_SHOTS 46 | if num_shots >= 1: 47 | seed = cfg.SEED 48 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 49 | 50 | if os.path.exists(preprocessed): 51 | print(f"Loading preprocessed few-shot data from {preprocessed}") 52 | with open(preprocessed, "rb") as file: 53 | data = pickle.load(file) 54 | train = data["train"] 55 | else: 56 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 57 | data = {"train": train} 58 | print(f"Saving preprocessed few-shot data to {preprocessed}") 59 | with open(preprocessed, "wb") as file: 60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 61 | 62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 63 | train_base, test_base = OxfordPets.subsample_classes(train, test, subsample='base') 64 | train_new, test_new = OxfordPets.subsample_classes(train, test, subsample='new') 65 | 66 | if subsample == 'base': 67 | train = train_base 68 | test = test_base 69 | elif subsample == 'new': 70 | train = train_new 71 | test = test_new 72 | super().__init__(train_x=train, val=test, test=test, test_new=test_new) 73 | 74 | @staticmethod 75 | def read_classnames(text_file): 76 | """Return a dictionary containing 77 | key-value pairs of : . 
78 | """ 79 | classnames = OrderedDict() 80 | with open(text_file, "r") as f: 81 | lines = f.readlines() 82 | for line in lines: 83 | line = line.strip().split(" ") 84 | folder = line[0] 85 | classname = " ".join(line[1:]) 86 | classnames[folder] = classname 87 | return classnames 88 | 89 | def read_data(self, classnames, split_dir): 90 | split_dir = os.path.join(self.image_dir, split_dir) 91 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 92 | items = [] 93 | 94 | for label, folder in enumerate(folders): 95 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 96 | classname = classnames[folder] 97 | for imname in imnames: 98 | impath = os.path.join(split_dir, folder, imname) 99 | item = Datum(impath=impath, label=label, classname=classname) 100 | items.append(item) 101 | 102 | return items 103 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import listdir_nohidden 9 | 10 | from .imagenet import ImageNet 11 | 12 | TO_BE_IGNORED = ["README.txt"] 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class ImageNetA(DatasetBase): 17 | """ImageNet-A(dversarial). 18 | 19 | This dataset is used for testing only. 20 | """ 21 | 22 | dataset_dir = "imagenet-adversarial" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 28 | 29 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 30 | classnames = ImageNet.read_classnames(text_file) 31 | 32 | data = self.read_data(classnames) 33 | 34 | super().__init__(train_x=data, test=data) 35 | 36 | def read_data(self, classnames): 37 | image_dir = self.image_dir 38 | folders = listdir_nohidden(image_dir, sort=True) 39 | folders = [f for f in folders if f not in TO_BE_IGNORED] 40 | items = [] 41 | 42 | for label, folder in enumerate(folders): 43 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 44 | classname = classnames[folder] 45 | for imname in imnames: 46 | impath = os.path.join(image_dir, folder, imname) 47 | item = Datum(impath=impath, label=label, classname=classname) 48 | items.append(item) 49 | 50 | return items 51 | -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import listdir_nohidden 9 | 10 | from .imagenet import ImageNet 11 | 12 | TO_BE_IGNORED = ["README.txt"] 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class ImageNetR(DatasetBase): 17 | """ImageNet-R(endition). 18 | 19 | This dataset is used for testing only. 
20 | """ 21 | 22 | dataset_dir = "imagenet-rendition" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 28 | 29 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 30 | classnames = ImageNet.read_classnames(text_file) 31 | 32 | data = self.read_data(classnames) 33 | 34 | super().__init__(train_x=data, test=data) 35 | 36 | def read_data(self, classnames): 37 | image_dir = self.image_dir 38 | folders = listdir_nohidden(image_dir, sort=True) 39 | folders = [f for f in folders if f not in TO_BE_IGNORED] 40 | items = [] 41 | 42 | for label, folder in enumerate(folders): 43 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 44 | classname = classnames[folder] 45 | for imname in imnames: 46 | impath = os.path.join(image_dir, folder, imname) 47 | item = Datum(impath=impath, label=label, classname=classname) 48 | items.append(item) 49 | 50 | return items 51 | -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import listdir_nohidden 9 | 10 | from .imagenet import ImageNet 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class ImageNetSketch(DatasetBase): 15 | """ImageNet-Sketch. 16 | 17 | This dataset is used for testing only. 18 | """ 19 | 20 | dataset_dir = "imagenet-sketch" 21 | 22 | def __init__(self, cfg): 23 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | self.image_dir = os.path.join(self.dataset_dir, "images") 26 | 27 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 28 | classnames = ImageNet.read_classnames(text_file) 29 | 30 | data = self.read_data(classnames) 31 | 32 | super().__init__(train_x=data, test=data) 33 | 34 | def read_data(self, classnames): 35 | image_dir = self.image_dir 36 | folders = listdir_nohidden(image_dir, sort=True) 37 | items = [] 38 | 39 | for label, folder in enumerate(folders): 40 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 41 | classname = classnames[folder] 42 | for imname in imnames: 43 | impath = os.path.join(image_dir, folder, imname) 44 | item = Datum(impath=impath, label=label, classname=classname) 45 | items.append(item) 46 | 47 | return items 48 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import listdir_nohidden 9 | 10 | from .imagenet import ImageNet 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class ImageNetV2(DatasetBase): 15 | """ImageNetV2. 16 | 17 | This dataset is used for testing only. 
18 | """ 19 | 20 | dataset_dir = "imagenetv2" 21 | 22 | def __init__(self, cfg): 23 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | image_dir = "imagenetv2-matched-frequency-format-val" 26 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 27 | 28 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 29 | classnames = ImageNet.read_classnames(text_file) 30 | 31 | data = self.read_data(classnames) 32 | 33 | super().__init__(train_x=data, test=data) 34 | 35 | def read_data(self, classnames): 36 | image_dir = self.image_dir 37 | folders = list(classnames.keys()) 38 | items = [] 39 | 40 | for label in range(1000): 41 | class_dir = os.path.join(image_dir, str(label)) 42 | imnames = listdir_nohidden(class_dir) 43 | folder = folders[label] 44 | classname = classnames[folder] 45 | for imname in imnames: 46 | impath = os.path.join(class_dir, imname) 47 | item = Datum(impath=impath, label=label, classname=classname) 48 | items.append(item) 49 | 50 | return items 51 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/datasets/ipynb_checkpoints/__init__-checkpoint.py -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/caltech101-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | from .dtd import DescribableTextures as DTD 13 | 14 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 15 | NEW_CNAMES = { 16 | "airplanes": "airplane", 17 | "Faces": "face", 18 | "Leopards": "leopard", 19 | "Motorbikes": "motorbike", 20 | } 21 | 22 | 23 | @DATASET_REGISTRY.register() 24 | class Caltech101(DatasetBase): 25 | 26 | dataset_dir = "caltech-101" 27 | 28 | def __init__(self, cfg): 29 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = os.path.join(root, self.dataset_dir) 31 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 32 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 33 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 34 | mkdir_if_missing(self.split_fewshot_dir) 35 | 36 | if os.path.exists(self.split_path): 37 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 38 | else: 39 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 40 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 41 | 42 | num_shots = cfg.DATASET.NUM_SHOTS 43 | if num_shots >= 1: 44 | seed = cfg.SEED 45 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 46 | 47 | if os.path.exists(preprocessed): 48 | print(f"Loading preprocessed few-shot data from {preprocessed}") 49 | with open(preprocessed, "rb") as file: 50 | data = pickle.load(file) 51 | train, val = data["train"], data["val"] 52 | else: 53 | train = 
self.generate_fewshot_dataset(train, num_shots=num_shots) 54 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 55 | data = {"train": train, "val": val} 56 | print(f"Saving preprocessed few-shot data to {preprocessed}") 57 | with open(preprocessed, "wb") as file: 58 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 59 | 60 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 61 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 62 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 63 | 64 | if subsample == 'base': 65 | train = train_base 66 | val = val_base 67 | test = test_base 68 | elif subsample == 'new': 69 | train = train_new 70 | val = val_new 71 | test = test_new 72 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 73 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/dtd-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import random 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import listdir_nohidden, mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class DescribableTextures(DatasetBase): 17 | 18 | dataset_dir = "dtd" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_and_split_data(self.image_dir) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 54 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 55 | 56 | if subsample == 'base': 57 | train = train_base 58 | val = val_base 59 | test = test_base 60 | elif subsample == 'new': 61 | train = train_new 62 | val = val_new 63 | test = test_new 64 | 
super().__init__(train_x=train, val=val, test=test, test_new=test_new) 65 | 66 | @staticmethod 67 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): 68 | # The data are supposed to be organized into the following structure 69 | # ============= 70 | # images/ 71 | # dog/ 72 | # cat/ 73 | # horse/ 74 | # ============= 75 | categories = listdir_nohidden(image_dir) 76 | categories = [c for c in categories if c not in ignored] 77 | categories.sort() 78 | 79 | p_tst = 1 - p_trn - p_val 80 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 81 | 82 | def _collate(ims, y, c): 83 | items = [] 84 | for im in ims: 85 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 86 | items.append(item) 87 | return items 88 | 89 | train, val, test = [], [], [] 90 | for label, category in enumerate(categories): 91 | category_dir = os.path.join(image_dir, category) 92 | images = listdir_nohidden(category_dir) 93 | images = [os.path.join(category_dir, im) for im in images] 94 | random.shuffle(images) 95 | n_total = len(images) 96 | n_train = round(n_total * p_trn) 97 | n_val = round(n_total * p_val) 98 | n_test = n_total - n_train - n_val 99 | assert n_train > 0 and n_val > 0 and n_test > 0 100 | 101 | if new_cnames is not None and category in new_cnames: 102 | category = new_cnames[category] 103 | 104 | train.extend(_collate(images[:n_train], label, category)) 105 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 106 | test.extend(_collate(images[n_train + n_val :], label, category)) 107 | 108 | return train, val, test 109 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/eurosat-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
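Another idiom shared by these loaders (and repeated verbatim in the ipynb_checkpoints copies) is the per-(shots, seed) cache: the few-shot training subset is generated once, pickled to split_fewshot/shot_{k}-seed_{s}.pkl, and reloaded on later runs so repeated experiments see the same sample. A stripped-down sketch of that caching logic; generate_fewshot and the cache directory are placeholders, not the repo's generate_fewshot_dataset.

import os
import pickle
import random

def generate_fewshot(items, num_shots):
    # Placeholder: the real generate_fewshot_dataset samples num_shots
    # items per class rather than num_shots items overall.
    return random.sample(items, min(num_shots, len(items)))

def load_or_build_fewshot(items, num_shots, seed, cache_dir="split_fewshot"):
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, f"shot_{num_shots}-seed_{seed}.pkl")
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)["train"]
    random.seed(seed)
    train = generate_fewshot(items, num_shots)
    with open(path, "wb") as f:
        pickle.dump({"train": train}, f, protocol=pickle.HIGHEST_PROTOCOL)
    return train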
4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | from .dtd import DescribableTextures as DTD 13 | 14 | NEW_CNAMES = { 15 | "AnnualCrop": "Annual Crop Land", 16 | "Forest": "Forest", 17 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 18 | "Highway": "Highway or Road", 19 | "Industrial": "Industrial Buildings", 20 | "Pasture": "Pasture Land", 21 | "PermanentCrop": "Permanent Crop Land", 22 | "Residential": "Residential Buildings", 23 | "River": "River", 24 | "SeaLake": "Sea or Lake", 25 | } 26 | 27 | 28 | @DATASET_REGISTRY.register() 29 | class EuroSAT(DatasetBase): 30 | 31 | dataset_dir = "eurosat" 32 | 33 | def __init__(self, cfg): 34 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 35 | self.dataset_dir = os.path.join(root, self.dataset_dir) 36 | self.image_dir = os.path.join(self.dataset_dir, "2750") 37 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 38 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 39 | mkdir_if_missing(self.split_fewshot_dir) 40 | 41 | if os.path.exists(self.split_path): 42 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 43 | else: 44 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 45 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 46 | 47 | num_shots = cfg.DATASET.NUM_SHOTS 48 | if num_shots >= 1: 49 | seed = cfg.SEED 50 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 51 | 52 | if os.path.exists(preprocessed): 53 | print(f"Loading preprocessed few-shot data from {preprocessed}") 54 | with open(preprocessed, "rb") as file: 55 | data = pickle.load(file) 56 | train, val = data["train"], data["val"] 57 | else: 58 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 59 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 60 | data = {"train": train, "val": val} 61 | print(f"Saving preprocessed few-shot data to {preprocessed}") 62 | with open(preprocessed, "wb") as file: 63 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 64 | 65 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 66 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 67 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 68 | 69 | if subsample == 'base': 70 | train = train_base 71 | val = val_base 72 | test = test_base 73 | elif subsample == 'new': 74 | train = train_new 75 | val = val_new 76 | test = test_new 77 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 78 | 79 | def update_classname(self, dataset_old): 80 | dataset_new = [] 81 | for item_old in dataset_old: 82 | cname_old = item_old.classname 83 | cname_new = NEW_CLASSNAMES[cname_old] 84 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 85 | dataset_new.append(item_new) 86 | return dataset_new 87 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/fgvc_aircraft-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
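One detail worth flagging in eurosat.py (and this checkpoint copy): update_classname looks up NEW_CLASSNAMES, but the renaming table defined at the top of the module is NEW_CNAMES, so the method would raise a NameError if it were ever called; __init__ never calls it, which is why the mismatch goes unnoticed. A corrected, self-contained sketch of the intended behaviour, with Datum stubbed by a namedtuple:

from collections import namedtuple

# Corrected sketch: reference NEW_CNAMES, the dict eurosat.py actually defines.
Datum = namedtuple("Datum", ["impath", "label", "classname"])
NEW_CNAMES = {"AnnualCrop": "Annual Crop Land", "SeaLake": "Sea or Lake"}

def update_classname(dataset_old):
    dataset_new = []
    for item_old in dataset_old:
        cname_new = NEW_CNAMES.get(item_old.classname, item_old.classname)
        dataset_new.append(Datum(item_old.impath, item_old.label, cname_new))
    return dataset_new

print(update_classname([Datum("river_1.jpg", 8, "SeaLake")])[0].classname)  # Sea or Lake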
4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class FGVCAircraft(DatasetBase): 16 | 17 | dataset_dir = "fgvc_aircraft" 18 | 19 | def __init__(self, cfg): 20 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, "images") 23 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 24 | mkdir_if_missing(self.split_fewshot_dir) 25 | 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | classnames.append(line.strip()) 31 | cname2lab = {c: i for i, c in enumerate(classnames)} 32 | 33 | train = self.read_data(cname2lab, "images_variant_train.txt") 34 | val = self.read_data(cname2lab, "images_variant_val.txt") 35 | test = self.read_data(cname2lab, "images_variant_test.txt") 36 | 37 | num_shots = cfg.DATASET.NUM_SHOTS 38 | if num_shots >= 1: 39 | seed = cfg.SEED 40 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 41 | 42 | if os.path.exists(preprocessed): 43 | print(f"Loading preprocessed few-shot data from {preprocessed}") 44 | with open(preprocessed, "rb") as file: 45 | data = pickle.load(file) 46 | train, val = data["train"], data["val"] 47 | else: 48 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 50 | data = {"train": train, "val": val} 51 | print(f"Saving preprocessed few-shot data to {preprocessed}") 52 | with open(preprocessed, "wb") as file: 53 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 57 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 58 | 59 | if subsample == 'base': 60 | train = train_base 61 | val = val_base 62 | test = test_base 63 | elif subsample == 'new': 64 | train = train_new 65 | val = val_new 66 | test = test_new 67 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 68 | 69 | def read_data(self, cname2lab, split_file): 70 | filepath = os.path.join(self.dataset_dir, split_file) 71 | items = [] 72 | 73 | with open(filepath, "r") as f: 74 | lines = f.readlines() 75 | for line in lines: 76 | line = line.strip().split(" ") 77 | imname = line[0] + ".jpg" 78 | classname = " ".join(line[1:]) 79 | impath = os.path.join(self.image_dir, imname) 80 | label = cname2lab[classname] 81 | item = Datum(impath=impath, label=label, classname=classname) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/food101-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
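FGVCAircraft.read_data (above, and in this checkpoint copy) builds its splits from images_variant_train/val/test.txt, where each line is an image id followed by the aircraft variant name, and the variant name itself may contain spaces. An illustrative parse of that line format; the sample ids and variants below are made up in the style of the official files.

# Illustrative images_variant_*.txt lines: "<image id> <variant name...>".
sample_lines = [
    "1025794 707-320",
    "0056978 A300B4",
    "2129450 Boeing 717",
]

for line in sample_lines:
    tokens = line.strip().split(" ")
    imname = tokens[0] + ".jpg"        # images/<id>.jpg
    classname = " ".join(tokens[1:])   # variant name, may contain spaces
    print(imname, "->", classname)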
4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | from .dtd import DescribableTextures as DTD 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class Food101(DatasetBase): 17 | 18 | dataset_dir = "food-101" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = DTD.read_and_split_data(self.image_dir) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 54 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 55 | 56 | if subsample == 'base': 57 | train = train_base 58 | val = val_base 59 | test = test_base 60 | elif subsample == 'new': 61 | train = train_new 62 | val = val_new 63 | test = test_new 64 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 65 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/imagenet-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
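All of these dataset classes register themselves with Dassl's DATASET_REGISTRY through the @DATASET_REGISTRY.register() decorator, which is what lets a dataset config select a loader purely by class name (cfg.DATASET.NAME). A toy registry showing the shape of that decorator pattern; this is an illustration, not Dassl's actual implementation.

# Toy decorator-based registry illustrating the DATASET_REGISTRY pattern.
class Registry:
    def __init__(self):
        self._classes = {}

    def register(self):
        def wrap(cls):
            self._classes[cls.__name__] = cls
            return cls
        return wrap

    def get(self, name):
        return self._classes[name]

DATASET_REGISTRY = Registry()

@DATASET_REGISTRY.register()
class Food101:  # stand-in class; the real one subclasses DatasetBase
    pass

# A trainer can then build the dataset chosen by cfg.DATASET.NAME:
print(DATASET_REGISTRY.get("Food101"))  # <class '__main__.Food101'>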
4 | # 5 | import os 6 | import pickle 7 | from collections import OrderedDict 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import listdir_nohidden, mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class ImageNet(DatasetBase): 17 | 18 | dataset_dir = "imagenet" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.preprocessed): 29 | with open(self.preprocessed, "rb") as f: 30 | preprocessed = pickle.load(f) 31 | train = preprocessed["train"] 32 | test = preprocessed["test"] 33 | else: 34 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 35 | classnames = self.read_classnames(text_file) 36 | train = self.read_data(classnames, "train") 37 | # Follow standard practice to perform evaluation on the val set 38 | # Also used as the val set (so evaluate the last-step model) 39 | test = self.read_data(classnames, "val") 40 | 41 | preprocessed = {"train": train, "test": test} 42 | with open(self.preprocessed, "wb") as f: 43 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 44 | 45 | num_shots = cfg.DATASET.NUM_SHOTS 46 | if num_shots >= 1: 47 | seed = cfg.SEED 48 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 49 | 50 | if os.path.exists(preprocessed): 51 | print(f"Loading preprocessed few-shot data from {preprocessed}") 52 | with open(preprocessed, "rb") as file: 53 | data = pickle.load(file) 54 | train = data["train"] 55 | else: 56 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 57 | data = {"train": train} 58 | print(f"Saving preprocessed few-shot data to {preprocessed}") 59 | with open(preprocessed, "wb") as file: 60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 61 | 62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 63 | train_base, test_base = OxfordPets.subsample_classes(train, test, subsample='base') 64 | train_new, test_new = OxfordPets.subsample_classes(train, test, subsample='new') 65 | 66 | if subsample == 'base': 67 | train = train_base 68 | test = test_base 69 | elif subsample == 'new': 70 | train = train_new 71 | test = test_new 72 | super().__init__(train_x=train, val=test, test=test, test_new=test_new) 73 | 74 | @staticmethod 75 | def read_classnames(text_file): 76 | """Return a dictionary containing 77 | key-value pairs of : . 
78 | """ 79 | classnames = OrderedDict() 80 | with open(text_file, "r") as f: 81 | lines = f.readlines() 82 | for line in lines: 83 | line = line.strip().split(" ") 84 | folder = line[0] 85 | classname = " ".join(line[1:]) 86 | classnames[folder] = classname 87 | return classnames 88 | 89 | def read_data(self, classnames, split_dir): 90 | split_dir = os.path.join(self.image_dir, split_dir) 91 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 92 | items = [] 93 | 94 | for label, folder in enumerate(folders): 95 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 96 | classname = classnames[folder] 97 | for imname in imnames: 98 | impath = os.path.join(split_dir, folder, imname) 99 | item = Datum(impath=impath, label=label, classname=classname) 100 | items.append(item) 101 | 102 | return items 103 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/oxford_flowers-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import random 8 | from scipy.io import loadmat 9 | from collections import defaultdict 10 | 11 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 12 | from dassl.utils import read_json, mkdir_if_missing 13 | 14 | from .oxford_pets import OxfordPets 15 | 16 | 17 | @DATASET_REGISTRY.register() 18 | class OxfordFlowers(DatasetBase): 19 | 20 | dataset_dir = "oxford_flowers" 21 | 22 | def __init__(self, cfg): 23 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 26 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 27 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = self.read_data() 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 58 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 59 | 60 | if subsample == 'base': 61 | train = train_base 62 | 
val = val_base 63 | test = test_base 64 | elif subsample == 'new': 65 | train = train_new 66 | val = val_new 67 | test = test_new 68 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 69 | 70 | def read_data(self): 71 | tracker = defaultdict(list) 72 | label_file = loadmat(self.label_file)["labels"][0] 73 | for i, label in enumerate(label_file): 74 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 75 | impath = os.path.join(self.image_dir, imname) 76 | label = int(label) 77 | tracker[label].append(impath) 78 | 79 | print("Splitting data into 50% train, 20% val, and 30% test") 80 | 81 | def _collate(ims, y, c): 82 | items = [] 83 | for im in ims: 84 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 85 | items.append(item) 86 | return items 87 | 88 | lab2cname = read_json(self.lab2cname_file) 89 | train, val, test = [], [], [] 90 | for label, impaths in tracker.items(): 91 | random.shuffle(impaths) 92 | n_total = len(impaths) 93 | n_train = round(n_total * 0.5) 94 | n_val = round(n_total * 0.2) 95 | n_test = n_total - n_train - n_val 96 | assert n_train > 0 and n_val > 0 and n_test > 0 97 | cname = lab2cname[str(label)] 98 | train.extend(_collate(impaths[:n_train], label, cname)) 99 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 100 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 101 | 102 | return train, val, test 103 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/oxford_pets-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import math 8 | import random 9 | from collections import defaultdict 10 | 11 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 12 | from dassl.utils import read_json, write_json, mkdir_if_missing 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class OxfordPets(DatasetBase): 17 | 18 | dataset_dir = "oxford_pets" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 25 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 26 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 27 | mkdir_if_missing(self.split_fewshot_dir) 28 | 29 | if os.path.exists(self.split_path): 30 | train, val, test = self.read_split(self.split_path, self.image_dir) 31 | else: 32 | trainval = self.read_data(split_file="trainval.txt") 33 | test = self.read_data(split_file="test.txt") 34 | train, val = self.split_trainval(trainval) 35 | self.save_split(train, val, test, self.split_path, self.image_dir) 36 | 37 | num_shots = cfg.DATASET.NUM_SHOTS 38 | if num_shots >= 1: 39 | seed = cfg.SEED 40 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 41 | 42 | if os.path.exists(preprocessed): 43 | print(f"Loading preprocessed few-shot data from {preprocessed}") 44 | with open(preprocessed, "rb") as file: 45 | data = pickle.load(file) 46 | train, val = data["train"], data["val"] 47 | else: 48 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, 
num_shots=min(num_shots, 4)) 50 | data = {"train": train, "val": val} 51 | print(f"Saving preprocessed few-shot data to {preprocessed}") 52 | with open(preprocessed, "wb") as file: 53 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 57 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 58 | 59 | if subsample == 'base': 60 | train = train_base 61 | val = val_base 62 | test = test_base 63 | elif subsample == 'new': 64 | train = train_new 65 | val = val_new 66 | test = test_new 67 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 68 | 69 | def read_data(self, split_file): 70 | filepath = os.path.join(self.anno_dir, split_file) 71 | items = [] 72 | 73 | with open(filepath, "r") as f: 74 | lines = f.readlines() 75 | for line in lines: 76 | line = line.strip() 77 | imname, label, species, _ = line.split(" ") 78 | breed = imname.split("_")[:-1] 79 | breed = "_".join(breed) 80 | breed = breed.lower() 81 | imname += ".jpg" 82 | impath = os.path.join(self.image_dir, imname) 83 | label = int(label) - 1 # convert to 0-based index 84 | item = Datum(impath=impath, label=label, classname=breed) 85 | items.append(item) 86 | 87 | return items 88 | 89 | @staticmethod 90 | def split_trainval(trainval, p_val=0.2): 91 | p_trn = 1 - p_val 92 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 93 | tracker = defaultdict(list) 94 | for idx, item in enumerate(trainval): 95 | label = item.label 96 | tracker[label].append(idx) 97 | 98 | train, val = [], [] 99 | for label, idxs in tracker.items(): 100 | n_val = round(len(idxs) * p_val) 101 | assert n_val > 0 102 | random.shuffle(idxs) 103 | for n, idx in enumerate(idxs): 104 | item = trainval[idx] 105 | if n < n_val: 106 | val.append(item) 107 | else: 108 | train.append(item) 109 | 110 | return train, val 111 | 112 | @staticmethod 113 | def save_split(train, val, test, filepath, path_prefix): 114 | def _extract(items): 115 | out = [] 116 | for item in items: 117 | impath = item.impath 118 | label = item.label 119 | classname = item.classname 120 | impath = impath.replace(path_prefix, "") 121 | if impath.startswith("/"): 122 | impath = impath[1:] 123 | out.append((impath, label, classname)) 124 | return out 125 | 126 | train = _extract(train) 127 | val = _extract(val) 128 | test = _extract(test) 129 | 130 | split = {"train": train, "val": val, "test": test} 131 | 132 | write_json(split, filepath) 133 | print(f"Saved split to {filepath}") 134 | 135 | @staticmethod 136 | def read_split(filepath, path_prefix): 137 | def _convert(items): 138 | out = [] 139 | for impath, label, classname in items: 140 | impath = os.path.join(path_prefix, impath) 141 | item = Datum(impath=impath, label=int(label), classname=classname) 142 | out.append(item) 143 | return out 144 | 145 | print(f"Reading split from {filepath}") 146 | split = read_json(filepath) 147 | train = _convert(split["train"]) 148 | val = _convert(split["val"]) 149 | test = _convert(split["test"]) 150 | 151 | return train, val, test 152 | 153 | @staticmethod 154 | def subsample_classes(*args, subsample="all"): 155 | """Divide classes into two groups. The first group 156 | represents base classes while the second group represents 157 | new classes. 158 | 159 | Args: 160 | args: a list of datasets, e.g. train, val and test. 
161 | subsample (str): what classes to subsample. 162 | """ 163 | assert subsample in ["all", "base", "new"] 164 | 165 | if subsample == "all": 166 | return args 167 | 168 | dataset = args[0] 169 | labels = set() 170 | for item in dataset: 171 | labels.add(item.label) 172 | labels = list(labels) 173 | labels.sort() 174 | n = len(labels) 175 | # Divide classes into two halves 176 | m = math.ceil(n / 2) 177 | 178 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 179 | if subsample == "base": 180 | selected = labels[:m] # take the first half 181 | else: 182 | selected = labels[m:] # take the second half 183 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 184 | 185 | output = [] 186 | for dataset in args: 187 | dataset_new = [] 188 | for item in dataset: 189 | if item.label not in selected: 190 | continue 191 | item_new = Datum( 192 | impath=item.impath, 193 | label=relabeler[item.label], 194 | classname=item.classname 195 | ) 196 | dataset_new.append(item_new) 197 | output.append(dataset_new) 198 | 199 | return output 200 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/stanford_cars-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | from scipy.io import loadmat 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class StanfordCars(DatasetBase): 17 | 18 | dataset_dir = "stanford_cars" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 24 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 25 | mkdir_if_missing(self.split_fewshot_dir) 26 | 27 | if os.path.exists(self.split_path): 28 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 29 | else: 30 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 31 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 32 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 33 | trainval = self.read_data("cars_train", trainval_file, meta_file) 34 | test = self.read_data("cars_test", test_file, meta_file) 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 
55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 58 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 59 | 60 | if subsample == 'base': 61 | train = train_base 62 | val = val_base 63 | test = test_base 64 | elif subsample == 'new': 65 | train = train_new 66 | val = val_new 67 | test = test_new 68 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 69 | 70 | def read_data(self, image_dir, anno_file, meta_file): 71 | anno_file = loadmat(anno_file)["annotations"][0] 72 | meta_file = loadmat(meta_file)["class_names"][0] 73 | items = [] 74 | 75 | for i in range(len(anno_file)): 76 | imname = anno_file[i]["fname"][0] 77 | impath = os.path.join(self.dataset_dir, image_dir, imname) 78 | label = anno_file[i]["class"][0, 0] 79 | label = int(label) - 1 # convert to 0-based index 80 | classname = meta_file[label][0] 81 | names = classname.split(" ") 82 | year = names.pop(-1) 83 | names.insert(0, year) 84 | classname = " ".join(names) 85 | item = Datum(impath=impath, label=label, classname=classname) 86 | items.append(item) 87 | 88 | return items 89 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/sun397-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class SUN397(DatasetBase): 16 | 17 | dataset_dir = "sun397" 18 | 19 | def __init__(self, cfg): 20 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 23 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 24 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 25 | mkdir_if_missing(self.split_fewshot_dir) 26 | 27 | if os.path.exists(self.split_path): 28 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 29 | else: 30 | classnames = [] 31 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | line = line.strip()[1:] # remove / 35 | classnames.append(line) 36 | cname2lab = {c: i for i, c in enumerate(classnames)} 37 | trainval = self.read_data(cname2lab, "Training_01.txt") 38 | test = self.read_data(cname2lab, "Testing_01.txt") 39 | train, val = OxfordPets.split_trainval(trainval) 40 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 41 | 42 | num_shots = cfg.DATASET.NUM_SHOTS 43 | if num_shots >= 1: 44 | seed = cfg.SEED 45 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 46 | 47 | if os.path.exists(preprocessed): 48 | print(f"Loading preprocessed few-shot data from {preprocessed}") 49 | with open(preprocessed, "rb") as file: 50 | data = pickle.load(file) 51 | train, val = data["train"], data["val"] 52 | else: 53 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 54 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 55 | data = {"train": 
train, "val": val} 56 | print(f"Saving preprocessed few-shot data to {preprocessed}") 57 | with open(preprocessed, "wb") as file: 58 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 59 | 60 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 61 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 62 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 63 | 64 | if subsample == 'base': 65 | train = train_base 66 | val = val_base 67 | test = test_base 68 | elif subsample == 'new': 69 | train = train_new 70 | val = val_new 71 | test = test_new 72 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 73 | 74 | def read_data(self, cname2lab, text_file): 75 | text_file = os.path.join(self.dataset_dir, text_file) 76 | items = [] 77 | 78 | with open(text_file, "r") as f: 79 | lines = f.readlines() 80 | for line in lines: 81 | imname = line.strip()[1:] # remove / 82 | classname = os.path.dirname(imname) 83 | label = cname2lab[classname] 84 | impath = os.path.join(self.image_dir, imname) 85 | 86 | names = classname.split("/")[1:] # remove 1st letter 87 | names = names[::-1] # put words like indoor/outdoor at first 88 | classname = " ".join(names) 89 | 90 | item = Datum(impath=impath, label=label, classname=classname) 91 | items.append(item) 92 | 93 | return items 94 | -------------------------------------------------------------------------------- /datasets/ipynb_checkpoints/ucf101-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import re 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class UCF101(DatasetBase): 17 | 18 | dataset_dir = "ucf101" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | cname2lab = {} 32 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 33 | with open(filepath, "r") as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | label, classname = line.strip().split(" ") 37 | label = int(label) - 1 # conver to 0-based index 38 | cname2lab[classname] = label 39 | 40 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 41 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 42 | train, val = OxfordPets.split_trainval(trainval) 43 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 44 | 45 | num_shots = cfg.DATASET.NUM_SHOTS 46 | if num_shots >= 1: 47 | seed = cfg.SEED 48 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 49 | 50 | if os.path.exists(preprocessed): 51 | print(f"Loading preprocessed few-shot data from {preprocessed}") 52 | with 
open(preprocessed, "rb") as file: 53 | data = pickle.load(file) 54 | train, val = data["train"], data["val"] 55 | else: 56 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 57 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 58 | data = {"train": train, "val": val} 59 | print(f"Saving preprocessed few-shot data to {preprocessed}") 60 | with open(preprocessed, "wb") as file: 61 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 62 | 63 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 64 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 65 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 66 | 67 | if subsample == 'base': 68 | train = train_base 69 | val = val_base 70 | test = test_base 71 | elif subsample == 'new': 72 | train = train_new 73 | val = val_new 74 | test = test_new 75 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 76 | 77 | def read_data(self, cname2lab, text_file): 78 | text_file = os.path.join(self.dataset_dir, text_file) 79 | items = [] 80 | 81 | with open(text_file, "r") as f: 82 | lines = f.readlines() 83 | for line in lines: 84 | line = line.strip().split(" ")[0] # trainlist: filename, label 85 | action, filename = line.split("/") 86 | label = cname2lab[action] 87 | 88 | elements = re.findall("[A-Z][^A-Z]*", action) 89 | renamed_action = "_".join(elements) 90 | 91 | filename = filename.replace(".avi", ".jpg") 92 | impath = os.path.join(self.image_dir, renamed_action, filename) 93 | 94 | item = Datum(impath=impath, label=label, classname=renamed_action) 95 | items.append(item) 96 | 97 | return items 98 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | import pickle 7 | import random 8 | from scipy.io import loadmat 9 | from collections import defaultdict 10 | 11 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 12 | from dassl.utils import read_json, mkdir_if_missing 13 | 14 | from .oxford_pets import OxfordPets 15 | 16 | 17 | @DATASET_REGISTRY.register() 18 | class OxfordFlowers(DatasetBase): 19 | 20 | dataset_dir = "oxford_flowers" 21 | 22 | def __init__(self, cfg): 23 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 26 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 27 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = self.read_data() 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 58 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 59 | 60 | if subsample == 'base': 61 | train = train_base 62 | val = val_base 63 | test = test_base 64 | elif subsample == 'new': 65 | train = train_new 66 | val = val_new 67 | test = test_new 68 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 69 | 70 | def read_data(self): 71 | tracker = defaultdict(list) 72 | label_file = loadmat(self.label_file)["labels"][0] 73 | for i, label in enumerate(label_file): 74 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 75 | impath = os.path.join(self.image_dir, imname) 76 | label = int(label) 77 | tracker[label].append(impath) 78 | 79 | print("Splitting data into 50% train, 20% val, and 30% test") 80 | 81 | def _collate(ims, y, c): 82 | items = [] 83 | for im in ims: 84 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 85 | items.append(item) 86 | return items 87 | 88 | lab2cname = read_json(self.lab2cname_file) 89 | train, val, test = [], [], [] 90 | for label, impaths in tracker.items(): 91 | random.shuffle(impaths) 92 | n_total = len(impaths) 93 | n_train = round(n_total * 0.5) 94 | n_val = round(n_total * 0.2) 95 | n_test = n_total - n_train - n_val 96 | assert n_train > 0 and n_val > 0 and n_test > 0 97 | cname = 
lab2cname[str(label)] 98 | train.extend(_collate(impaths[:n_train], label, cname)) 99 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 100 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 101 | 102 | return train, val, test 103 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import math 8 | import random 9 | from collections import defaultdict 10 | 11 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 12 | from dassl.utils import read_json, write_json, mkdir_if_missing 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class OxfordPets(DatasetBase): 17 | 18 | dataset_dir = "oxford_pets" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 25 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 26 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 27 | mkdir_if_missing(self.split_fewshot_dir) 28 | 29 | if os.path.exists(self.split_path): 30 | train, val, test = self.read_split(self.split_path, self.image_dir) 31 | else: 32 | trainval = self.read_data(split_file="trainval.txt") 33 | test = self.read_data(split_file="test.txt") 34 | train, val = self.split_trainval(trainval) 35 | self.save_split(train, val, test, self.split_path, self.image_dir) 36 | 37 | num_shots = cfg.DATASET.NUM_SHOTS 38 | if num_shots >= 1: 39 | seed = cfg.SEED 40 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 41 | 42 | if os.path.exists(preprocessed): 43 | print(f"Loading preprocessed few-shot data from {preprocessed}") 44 | with open(preprocessed, "rb") as file: 45 | data = pickle.load(file) 46 | train, val = data["train"], data["val"] 47 | else: 48 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 50 | data = {"train": train, "val": val} 51 | print(f"Saving preprocessed few-shot data to {preprocessed}") 52 | with open(preprocessed, "wb") as file: 53 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 57 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 58 | 59 | if subsample == 'base': 60 | train = train_base 61 | val = val_base 62 | test = test_base 63 | elif subsample == 'new': 64 | train = train_new 65 | val = val_new 66 | test = test_new 67 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 68 | 69 | def read_data(self, split_file): 70 | filepath = os.path.join(self.anno_dir, split_file) 71 | items = [] 72 | 73 | with open(filepath, "r") as f: 74 | lines = f.readlines() 75 | for line in lines: 76 | line = line.strip() 77 | imname, label, species, _ = line.split(" ") 78 | breed = imname.split("_")[:-1] 79 | breed = "_".join(breed) 80 | breed = breed.lower() 81 | imname += ".jpg" 82 | impath = 
os.path.join(self.image_dir, imname) 83 | label = int(label) - 1 # convert to 0-based index 84 | item = Datum(impath=impath, label=label, classname=breed) 85 | items.append(item) 86 | 87 | return items 88 | 89 | @staticmethod 90 | def split_trainval(trainval, p_val=0.2): 91 | p_trn = 1 - p_val 92 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 93 | tracker = defaultdict(list) 94 | for idx, item in enumerate(trainval): 95 | label = item.label 96 | tracker[label].append(idx) 97 | 98 | train, val = [], [] 99 | for label, idxs in tracker.items(): 100 | n_val = round(len(idxs) * p_val) 101 | assert n_val > 0 102 | random.shuffle(idxs) 103 | for n, idx in enumerate(idxs): 104 | item = trainval[idx] 105 | if n < n_val: 106 | val.append(item) 107 | else: 108 | train.append(item) 109 | 110 | return train, val 111 | 112 | @staticmethod 113 | def save_split(train, val, test, filepath, path_prefix): 114 | def _extract(items): 115 | out = [] 116 | for item in items: 117 | impath = item.impath 118 | label = item.label 119 | classname = item.classname 120 | impath = impath.replace(path_prefix, "") 121 | if impath.startswith("/"): 122 | impath = impath[1:] 123 | out.append((impath, label, classname)) 124 | return out 125 | 126 | train = _extract(train) 127 | val = _extract(val) 128 | test = _extract(test) 129 | 130 | split = {"train": train, "val": val, "test": test} 131 | 132 | write_json(split, filepath) 133 | print(f"Saved split to {filepath}") 134 | 135 | @staticmethod 136 | def read_split(filepath, path_prefix): 137 | def _convert(items): 138 | out = [] 139 | for impath, label, classname in items: 140 | impath = os.path.join(path_prefix, impath) 141 | item = Datum(impath=impath, label=int(label), classname=classname) 142 | out.append(item) 143 | return out 144 | 145 | print(f"Reading split from {filepath}") 146 | split = read_json(filepath) 147 | train = _convert(split["train"]) 148 | val = _convert(split["val"]) 149 | test = _convert(split["test"]) 150 | 151 | return train, val, test 152 | 153 | @staticmethod 154 | def subsample_classes(*args, subsample="all"): 155 | """Divide classes into two groups. The first group 156 | represents base classes while the second group represents 157 | new classes. 158 | 159 | Args: 160 | args: a list of datasets, e.g. train, val and test. 161 | subsample (str): what classes to subsample. 
162 | """ 163 | assert subsample in ["all", "base", "new"] 164 | 165 | if subsample == "all": 166 | return args 167 | 168 | dataset = args[0] 169 | labels = set() 170 | for item in dataset: 171 | labels.add(item.label) 172 | labels = list(labels) 173 | labels.sort() 174 | n = len(labels) 175 | # Divide classes into two halves 176 | m = math.ceil(n / 2) 177 | 178 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 179 | if subsample == "base": 180 | selected = labels[:m] # take the first half 181 | else: 182 | selected = labels[m:] # take the second half 183 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 184 | 185 | output = [] 186 | for dataset in args: 187 | dataset_new = [] 188 | for item in dataset: 189 | if item.label not in selected: 190 | continue 191 | item_new = Datum( 192 | impath=item.impath, 193 | label=relabeler[item.label], 194 | classname=item.classname 195 | ) 196 | dataset_new.append(item_new) 197 | output.append(dataset_new) 198 | 199 | return output 200 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | from scipy.io import loadmat 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class StanfordCars(DatasetBase): 17 | 18 | dataset_dir = "stanford_cars" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 24 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 25 | mkdir_if_missing(self.split_fewshot_dir) 26 | 27 | if os.path.exists(self.split_path): 28 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 29 | else: 30 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 31 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 32 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 33 | trainval = self.read_data("cars_train", trainval_file, meta_file) 34 | test = self.read_data("cars_test", test_file, meta_file) 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train_base, val_base, 
test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 58 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 59 | 60 | if subsample == 'base': 61 | train = train_base 62 | val = val_base 63 | test = test_base 64 | elif subsample == 'new': 65 | train = train_new 66 | val = val_new 67 | test = test_new 68 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 69 | 70 | def read_data(self, image_dir, anno_file, meta_file): 71 | anno_file = loadmat(anno_file)["annotations"][0] 72 | meta_file = loadmat(meta_file)["class_names"][0] 73 | items = [] 74 | 75 | for i in range(len(anno_file)): 76 | imname = anno_file[i]["fname"][0] 77 | impath = os.path.join(self.dataset_dir, image_dir, imname) 78 | label = anno_file[i]["class"][0, 0] 79 | label = int(label) - 1 # convert to 0-based index 80 | classname = meta_file[label][0] 81 | names = classname.split(" ") 82 | year = names.pop(-1) 83 | names.insert(0, year) 84 | classname = " ".join(names) 85 | item = Datum(impath=impath, label=label, classname=classname) 86 | items.append(item) 87 | 88 | return items 89 | -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | 8 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 9 | from dassl.utils import mkdir_if_missing 10 | 11 | from .oxford_pets import OxfordPets 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class SUN397(DatasetBase): 16 | 17 | dataset_dir = "sun397" 18 | 19 | def __init__(self, cfg): 20 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 23 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 24 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 25 | mkdir_if_missing(self.split_fewshot_dir) 26 | 27 | if os.path.exists(self.split_path): 28 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 29 | else: 30 | classnames = [] 31 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | line = line.strip()[1:] # remove / 35 | classnames.append(line) 36 | cname2lab = {c: i for i, c in enumerate(classnames)} 37 | trainval = self.read_data(cname2lab, "Training_01.txt") 38 | test = self.read_data(cname2lab, "Testing_01.txt") 39 | train, val = OxfordPets.split_trainval(trainval) 40 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 41 | 42 | num_shots = cfg.DATASET.NUM_SHOTS 43 | if num_shots >= 1: 44 | seed = cfg.SEED 45 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 46 | 47 | if os.path.exists(preprocessed): 48 | print(f"Loading preprocessed few-shot data from {preprocessed}") 49 | with open(preprocessed, "rb") as file: 50 | data = pickle.load(file) 51 | train, val = data["train"], data["val"] 52 | else: 53 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 54 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 55 | data = {"train": train, "val": val} 56 | print(f"Saving preprocessed few-shot data to {preprocessed}") 57 | with 
open(preprocessed, "wb") as file: 58 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 59 | 60 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 61 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 62 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 63 | 64 | if subsample == 'base': 65 | train = train_base 66 | val = val_base 67 | test = test_base 68 | elif subsample == 'new': 69 | train = train_new 70 | val = val_new 71 | test = test_new 72 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 73 | 74 | def read_data(self, cname2lab, text_file): 75 | text_file = os.path.join(self.dataset_dir, text_file) 76 | items = [] 77 | 78 | with open(text_file, "r") as f: 79 | lines = f.readlines() 80 | for line in lines: 81 | imname = line.strip()[1:] # remove / 82 | classname = os.path.dirname(imname) 83 | label = cname2lab[classname] 84 | impath = os.path.join(self.image_dir, imname) 85 | 86 | names = classname.split("/")[1:] # remove 1st letter 87 | names = names[::-1] # put words like indoor/outdoor at first 88 | classname = " ".join(names) 89 | 90 | item = Datum(impath=impath, label=label, classname=classname) 91 | items.append(item) 92 | 93 | return items 94 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pickle 7 | import re 8 | 9 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 10 | from dassl.utils import mkdir_if_missing 11 | 12 | from .oxford_pets import OxfordPets 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class UCF101(DatasetBase): 17 | 18 | dataset_dir = "ucf101" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | cname2lab = {} 32 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 33 | with open(filepath, "r") as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | label, classname = line.strip().split(" ") 37 | label = int(label) - 1 # conver to 0-based index 38 | cname2lab[classname] = label 39 | 40 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 41 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 42 | train, val = OxfordPets.split_trainval(trainval) 43 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 44 | 45 | num_shots = cfg.DATASET.NUM_SHOTS 46 | if num_shots >= 1: 47 | seed = cfg.SEED 48 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 49 | 50 | if os.path.exists(preprocessed): 51 | print(f"Loading preprocessed few-shot data from {preprocessed}") 52 | with open(preprocessed, "rb") as file: 53 | data = pickle.load(file) 54 | train, val = data["train"], data["val"] 55 | else: 56 | train = 
self.generate_fewshot_dataset(train, num_shots=num_shots) 57 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 58 | data = {"train": train, "val": val} 59 | print(f"Saving preprocessed few-shot data to {preprocessed}") 60 | with open(preprocessed, "wb") as file: 61 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 62 | 63 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 64 | train_base, val_base, test_base = OxfordPets.subsample_classes(train, val, test, subsample='base') 65 | train_new, val_new, test_new = OxfordPets.subsample_classes(train, val, test, subsample='new') 66 | 67 | if subsample == 'base': 68 | train = train_base 69 | val = val_base 70 | test = test_base 71 | elif subsample == 'new': 72 | train = train_new 73 | val = val_new 74 | test = test_new 75 | super().__init__(train_x=train, val=val, test=test, test_new=test_new) 76 | 77 | def read_data(self, cname2lab, text_file): 78 | text_file = os.path.join(self.dataset_dir, text_file) 79 | items = [] 80 | 81 | with open(text_file, "r") as f: 82 | lines = f.readlines() 83 | for line in lines: 84 | line = line.strip().split(" ")[0] # trainlist: filename, label 85 | action, filename = line.split("/") 86 | label = cname2lab[action] 87 | 88 | elements = re.findall("[A-Z][^A-Z]*", action) 89 | renamed_action = "_".join(elements) 90 | 91 | filename = filename.replace(".avi", ".jpg") 92 | impath = os.path.join(self.image_dir, renamed_action, filename) 93 | 94 | item = Datum(impath=impath, label=label, classname=renamed_action) 95 | items.append(item) 96 | 97 | return items 98 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | -------------------------------------------------------------------------------- /scripts/cocoop/README.md: -------------------------------------------------------------------------------- 1 | These scripts are only for reproducing the results on the CVPR'22 paper. -------------------------------------------------------------------------------- /scripts/cocoop/base2new_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | # TRAINER=CoOp 9 | 10 | DATASET=$1 11 | SEED=$2 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | # CFG=vit_b16_ctxv1 # uncomment this when TRAINER=CoOp 15 | SHOTS=16 16 | LOADEP=10 17 | SUB=new 18 | 19 | 20 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 21 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 22 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 23 | if [ -d "$DIR" ]; then 24 | echo "Oops! The results exist at ${DIR} (so skip this job)" 25 | else 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | --model-dir ${MODEL_DIR} \ 34 | --load-epoch ${LOADEP} \ 35 | --eval-only \ 36 | DATASET.NUM_SHOTS ${SHOTS} \ 37 | DATASET.SUBSAMPLE_CLASSES ${SUB} 38 | fi -------------------------------------------------------------------------------- /scripts/cocoop/base2new_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | # TRAINER=CoOp 9 | 10 | DATASET=$1 11 | SEED=$2 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | # CFG=vit_b16_ctxv1 # uncomment this when TRAINER=CoOp 15 | # CFG=vit_b16_ep50_ctxv1 # uncomment this when TRAINER=CoOp and DATASET=imagenet 16 | SHOTS=16 17 | 18 | 19 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 20 | if [ -d "$DIR" ]; then 21 | echo "Oops! The results exist at ${DIR} (so skip this job)" 22 | else 23 | python train.py \ 24 | --root ${DATA} \ 25 | --seed ${SEED} \ 26 | --trainer ${TRAINER} \ 27 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 29 | --output-dir ${DIR} \ 30 | DATASET.NUM_SHOTS ${SHOTS} \ 31 | DATASET.SUBSAMPLE_CLASSES base 32 | fi -------------------------------------------------------------------------------- /scripts/cocoop/ipynb_checkpoints/base2new_test-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | # TRAINER=CoOp 9 | 10 | DATASET=$1 11 | SEED=$2 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | # CFG=vit_b16_ctxv1 # uncomment this when TRAINER=CoOp 15 | SHOTS=16 16 | LOADEP=10 17 | SUB=new 18 | 19 | 20 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 21 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 22 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 23 | if [ -d "$DIR" ]; then 24 | echo "Oops! The results exist at ${DIR} (so skip this job)" 25 | else 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | --model-dir ${MODEL_DIR} \ 34 | --load-epoch ${LOADEP} \ 35 | --eval-only \ 36 | DATASET.NUM_SHOTS ${SHOTS} \ 37 | DATASET.SUBSAMPLE_CLASSES ${SUB} 38 | fi -------------------------------------------------------------------------------- /scripts/cocoop/xd_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | # TRAINER=CoOp 9 | 10 | DATASET=$1 11 | SEED=$2 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | # CFG=vit_b16_ep50_ctxv1 # uncomment this when TRAINER=CoOp and DATASET=imagenet 15 | SHOTS=16 16 | 17 | 18 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 19 | if [ -d "$DIR" ]; then 20 | echo "Oops! The results exist at ${DIR} (so skip this job)" 21 | else 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 10 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/cocoop/xd_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | # TRAINER=CoOp 9 | 10 | DATASET=imagenet 11 | SEED=$1 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | # CFG=vit_b16_ep50_ctxv1 # uncomment this when TRAINER=CoOp and DATASET=imagenet 15 | SHOTS=16 16 | 17 | 18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 19 | if [ -d "$DIR" ]; then 20 | echo "Oops! The results exist at ${DIR} (so skip this job)" 21 | else 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/coop/README.md: -------------------------------------------------------------------------------- 1 | These scripts are only for reproducing the results on the IJCV'22 paper. -------------------------------------------------------------------------------- /scripts/coop/base2new_train_ep10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=CoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep10_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | 28 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 29 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 30 | 31 | SUB=new 32 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 33 | CUDA_VISIBLE_DEVICES=0 python train.py \ 34 | --root ${DATA} \ 35 | --seed ${SEED} \ 36 | --trainer ${TRAINER} \ 37 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 38 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 39 | --output-dir ${DIR} \ 40 | --model-dir ${MODEL_DIR} \ 41 | --load-epoch ${LOADEP} \ 42 | --eval-only \ 43 | DATASET.NUM_SHOTS ${SHOTS} \ 44 | DATASET.SUBSAMPLE_CLASSES ${SUB} 45 | -------------------------------------------------------------------------------- /scripts/coop/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=CoOp 6 | SHOTS=16 7 | NCTX=16 8 | CSC=False 9 | CTP=end 10 | 11 | DATASET=$1 12 | CFG=$2 13 | 14 | for SEED in 1 2 3 15 | do 16 | python train.py \ 17 | --root ${DATA} \ 18 | --seed ${SEED} \ 19 | --trainer ${TRAINER} \ 20 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 21 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 22 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \ 23 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \ 24 | --load-epoch 50 \ 25 | --eval-only \ 26 | TRAINER.COOP.N_CTX ${NCTX} \ 27 | TRAINER.COOP.CSC ${CSC} \ 28 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} 29 | done 
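For reference, the CoOp scripts above appear to be meant to run from the repository root and take positional arguments: base2new_train_ep10.sh expects a dataset name (matching a YAML file under configs/datasets) and a seed, while eval.sh expects a dataset name and a trainer config name and evaluates an ImageNet-trained CoOp model on that dataset. A minimal, illustrative invocation, assuming each script's DATA variable already points at the local dataset root and the referenced checkpoints exist (the dataset, seed, and config below are arbitrary examples, not prescribed values):

# hypothetical usage sketch; substitute any dataset config under configs/datasets and any CoOp config under configs/trainers/CoOp
bash scripts/coop/base2new_train_ep10.sh oxford_pets 1     # train on base classes with seed 1, then test on new classes
bash scripts/coop/eval.sh oxford_pets vit_b16_ep50         # cross-dataset evaluation of a CoOp model trained on ImageNet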
-------------------------------------------------------------------------------- /scripts/coop/ipynb_checkpoints/base2new_train_ep10-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=CoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep10_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | 28 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 29 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 30 | 31 | SUB=new 32 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 33 | python train.py \ 34 | --root ${DATA} \ 35 | --seed ${SEED} \ 36 | --trainer ${TRAINER} \ 37 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 38 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 39 | --output-dir ${DIR} \ 40 | --model-dir ${MODEL_DIR} \ 41 | --load-epoch ${LOADEP} \ 42 | --eval-only \ 43 | DATASET.NUM_SHOTS ${SHOTS} \ 44 | DATASET.SUBSAMPLE_CLASSES ${SUB} 45 | -------------------------------------------------------------------------------- /scripts/coop/ipynb_checkpoints/eval-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=CoOp 6 | SHOTS=16 7 | NCTX=16 8 | CSC=False 9 | CTP=end 10 | 11 | DATASET=$1 12 | CFG=$2 13 | 14 | for SEED in 1 2 3 15 | do 16 | python train.py \ 17 | --root ${DATA} \ 18 | --seed ${SEED} \ 19 | --trainer ${TRAINER} \ 20 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 21 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 22 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \ 23 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \ 24 | --load-epoch 50 \ 25 | --eval-only \ 26 | TRAINER.COOP.N_CTX ${NCTX} \ 27 | TRAINER.COOP.CSC ${CSC} \ 28 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} 29 | done -------------------------------------------------------------------------------- /scripts/coop/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=CoOp 6 | 7 | DATASET=$1 8 | CFG=$2 # config file 9 | CTP=$3 # class token position (end or middle) 10 | NCTX=$4 # number of context tokens 11 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 12 | CSC=$6 # class-specific context (False or True) 13 | 14 | for SEED in 1 2 3 15 | do 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 19 | else 20 | python train.py \ 21 | --root ${DATA} \ 22 | --seed ${SEED} \ 23 | --trainer ${TRAINER} \ 24 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 26 | --output-dir ${DIR} \ 27 | TRAINER.COOP.N_CTX ${NCTX} \ 28 | TRAINER.COOP.CSC ${CSC} \ 29 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 30 | DATASET.NUM_SHOTS ${SHOTS} 31 | fi 32 | done -------------------------------------------------------------------------------- /scripts/coop/zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=ZeroshotCLIP 6 | DATASET=$1 7 | CFG=$2 # rn50, rn101, vit_b32 or vit_b16 8 | 9 | python train.py \ 10 | --root ${DATA} \ 11 | --trainer ${TRAINER} \ 12 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 13 | --config-file configs/trainers/CoOp/${CFG}.yaml \ 14 | --output-dir output/${TRAINER}/${CFG}/${DATASET} \ 15 | --eval-only -------------------------------------------------------------------------------- /scripts/ogen/base2new_eval_ep10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep10_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 19 | 20 | SUB=base 21 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 22 | CUDA_VISIBLE_DEVICES=0 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir ${MODEL_DIR} \ 30 | --load-epoch ${LOADEP} \ 31 | --eval-only \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES ${SUB} 34 | 35 | SUB=new 36 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 37 | CUDA_VISIBLE_DEVICES=0 python train.py \ 38 | --root ${DATA} \ 39 | --seed ${SEED} \ 40 | --trainer ${TRAINER} \ 41 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 42 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 43 | --output-dir ${DIR} \ 44 | --model-dir ${MODEL_DIR} \ 45 | --load-epoch ${LOADEP} \ 46 | --eval-only \ 47 | DATASET.NUM_SHOTS ${SHOTS} \ 48 | DATASET.SUBSAMPLE_CLASSES ${SUB} -------------------------------------------------------------------------------- /scripts/ogen/base2new_eval_ep200.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 
4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep200_ctxv1 13 | SHOTS=16 14 | LOADEP=200 15 | 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 19 | 20 | SUB=base 21 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 22 | CUDA_VISIBLE_DEVICES=0 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir ${MODEL_DIR} \ 30 | --load-epoch ${LOADEP} \ 31 | --eval-only \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES ${SUB} 34 | 35 | SUB=new 36 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 37 | CUDA_VISIBLE_DEVICES=0 python train.py \ 38 | --root ${DATA} \ 39 | --seed ${SEED} \ 40 | --trainer ${TRAINER} \ 41 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 42 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 43 | --output-dir ${DIR} \ 44 | --model-dir ${MODEL_DIR} \ 45 | --load-epoch ${LOADEP} \ 46 | --eval-only \ 47 | DATASET.NUM_SHOTS ${SHOTS} \ 48 | DATASET.SUBSAMPLE_CLASSES ${SUB} -------------------------------------------------------------------------------- /scripts/ogen/base2new_eval_ep50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep50_ctxv1 13 | SHOTS=16 14 | LOADEP=50 15 | 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 19 | 20 | SUB=base 21 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 22 | CUDA_VISIBLE_DEVICES=0 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir ${MODEL_DIR} \ 30 | --load-epoch ${LOADEP} \ 31 | --eval-only \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES ${SUB} 34 | 35 | SUB=new 36 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 37 | CUDA_VISIBLE_DEVICES=0 python train.py \ 38 | --root ${DATA} \ 39 | --seed ${SEED} \ 40 | --trainer ${TRAINER} \ 41 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 42 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 43 | --output-dir ${DIR} \ 44 | --model-dir ${MODEL_DIR} \ 45 | --load-epoch ${LOADEP} \ 46 | --eval-only \ 47 | DATASET.NUM_SHOTS ${SHOTS} \ 48 | DATASET.SUBSAMPLE_CLASSES ${SUB} -------------------------------------------------------------------------------- /scripts/ogen/base2new_train_ep10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 
4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep10_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | 28 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 29 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 30 | 31 | SUB=new 32 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 33 | CUDA_VISIBLE_DEVICES=0 python train.py \ 34 | --root ${DATA} \ 35 | --seed ${SEED} \ 36 | --trainer ${TRAINER} \ 37 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 38 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 39 | --output-dir ${DIR} \ 40 | --model-dir ${MODEL_DIR} \ 41 | --load-epoch ${LOADEP} \ 42 | --eval-only \ 43 | DATASET.NUM_SHOTS ${SHOTS} \ 44 | DATASET.SUBSAMPLE_CLASSES ${SUB} 45 | -------------------------------------------------------------------------------- /scripts/ogen/base2new_train_ep200.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep200_ctxv1 13 | SHOTS=16 14 | LOADEP=200 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | -------------------------------------------------------------------------------- /scripts/ogen/base2new_train_ep50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep50_ctxv1 13 | SHOTS=16 14 | LOADEP=50 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | -------------------------------------------------------------------------------- /scripts/ogen/ipynb_checkpoints/base2new_eval_ep10-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 
4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep10_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 19 | 20 | SUB=base 21 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 22 | CUDA_VISIBLE_DEVICES=0 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir ${MODEL_DIR} \ 30 | --load-epoch ${LOADEP} \ 31 | --eval-only \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES ${SUB} 34 | 35 | SUB=new 36 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 37 | CUDA_VISIBLE_DEVICES=0 python train.py \ 38 | --root ${DATA} \ 39 | --seed ${SEED} \ 40 | --trainer ${TRAINER} \ 41 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 42 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 43 | --output-dir ${DIR} \ 44 | --model-dir ${MODEL_DIR} \ 45 | --load-epoch ${LOADEP} \ 46 | --eval-only \ 47 | DATASET.NUM_SHOTS ${SHOTS} \ 48 | DATASET.SUBSAMPLE_CLASSES ${SUB} -------------------------------------------------------------------------------- /scripts/ogen/ipynb_checkpoints/base2new_eval_ep200-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep200_ctxv1 13 | SHOTS=16 14 | LOADEP=200 15 | 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 19 | 20 | SUB=base 21 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 22 | CUDA_VISIBLE_DEVICES=0 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir ${MODEL_DIR} \ 30 | --load-epoch ${LOADEP} \ 31 | --eval-only \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES ${SUB} 34 | 35 | SUB=new 36 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 37 | CUDA_VISIBLE_DEVICES=0 python train.py \ 38 | --root ${DATA} \ 39 | --seed ${SEED} \ 40 | --trainer ${TRAINER} \ 41 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 42 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 43 | --output-dir ${DIR} \ 44 | --model-dir ${MODEL_DIR} \ 45 | --load-epoch ${LOADEP} \ 46 | --eval-only \ 47 | DATASET.NUM_SHOTS ${SHOTS} \ 48 | DATASET.SUBSAMPLE_CLASSES ${SUB} -------------------------------------------------------------------------------- /scripts/ogen/ipynb_checkpoints/base2new_eval_ep50-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 
4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep50_ctxv1 13 | SHOTS=16 14 | LOADEP=50 15 | 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 19 | 20 | SUB=base 21 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 22 | CUDA_VISIBLE_DEVICES=0 python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir ${MODEL_DIR} \ 30 | --load-epoch ${LOADEP} \ 31 | --eval-only \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES ${SUB} 34 | 35 | SUB=new 36 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 37 | CUDA_VISIBLE_DEVICES=0 python train.py \ 38 | --root ${DATA} \ 39 | --seed ${SEED} \ 40 | --trainer ${TRAINER} \ 41 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 42 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 43 | --output-dir ${DIR} \ 44 | --model-dir ${MODEL_DIR} \ 45 | --load-epoch ${LOADEP} \ 46 | --eval-only \ 47 | DATASET.NUM_SHOTS ${SHOTS} \ 48 | DATASET.SUBSAMPLE_CLASSES ${SUB} -------------------------------------------------------------------------------- /scripts/ogen/ipynb_checkpoints/base2new_train_ep10-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep10_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | 28 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 29 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 30 | 31 | SUB=new 32 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 33 | CUDA_VISIBLE_DEVICES=0 python train.py \ 34 | --root ${DATA} \ 35 | --seed ${SEED} \ 36 | --trainer ${TRAINER} \ 37 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 38 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 39 | --output-dir ${DIR} \ 40 | --model-dir ${MODEL_DIR} \ 41 | --load-epoch ${LOADEP} \ 42 | --eval-only \ 43 | DATASET.NUM_SHOTS ${SHOTS} \ 44 | DATASET.SUBSAMPLE_CLASSES ${SUB} 45 | -------------------------------------------------------------------------------- /scripts/ogen/ipynb_checkpoints/base2new_train_ep200-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 
4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep200_ctxv1 13 | SHOTS=16 14 | LOADEP=200 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | -------------------------------------------------------------------------------- /scripts/ogen/ipynb_checkpoints/base2new_train_ep50-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cd ../.. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=OGEN 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep50_ctxv1 13 | SHOTS=16 14 | LOADEP=50 15 | 16 | SUB=base 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | CUDA_VISIBLE_DEVICES=0 python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATASET.NUM_SHOTS ${SHOTS} \ 26 | DATASET.SUBSAMPLE_CLASSES ${SUB} 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | import torch 7 | 8 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 9 | from dassl.config import get_cfg_default 10 | from dassl.engine import build_trainer 11 | 12 | # custom 13 | import datasets.oxford_pets 14 | import datasets.oxford_flowers 15 | import datasets.fgvc_aircraft 16 | import datasets.dtd 17 | import datasets.eurosat 18 | import datasets.stanford_cars 19 | import datasets.food101 20 | import datasets.sun397 21 | import datasets.caltech101 22 | import datasets.ucf101 23 | import datasets.imagenet 24 | 25 | import datasets.imagenet_sketch 26 | import datasets.imagenetv2 27 | import datasets.imagenet_a 28 | import datasets.imagenet_r 29 | 30 | import trainers.coop 31 | import trainers.cocoop 32 | import trainers.zsclip 33 | import trainers.ogen 34 | 35 | 36 | def print_args(args, cfg): 37 | print("***************") 38 | print("** Arguments **") 39 | print("***************") 40 | optkeys = list(args.__dict__.keys()) 41 | optkeys.sort() 42 | for key in optkeys: 43 | print("{}: {}".format(key, args.__dict__[key])) 44 | print("************") 45 | print("** Config **") 46 | print("************") 47 | print(cfg) 48 | 49 | 50 | def reset_cfg(cfg, args): 51 | if args.root: 52 | cfg.DATASET.ROOT = args.root 53 | 54 | if args.output_dir: 55 | cfg.OUTPUT_DIR = args.output_dir 56 | 57 | if args.resume: 58 | cfg.RESUME = args.resume 59 | 60 | if args.seed: 61 | cfg.SEED = args.seed 62 | 63 | if args.source_domains: 64 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 65 | 66 | if args.target_domains: 67 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 68 | 69 | if args.transforms: 70 | cfg.INPUT.TRANSFORMS = args.transforms 71 | 72 | if args.trainer: 73 | cfg.TRAINER.NAME = args.trainer 74 | 75 | if 
args.backbone: 76 | cfg.MODEL.BACKBONE.NAME = args.backbone 77 | 78 | if args.head: 79 | cfg.MODEL.HEAD.NAME = args.head 80 | 81 | 82 | def extend_cfg(cfg): 83 | """ 84 | Add new config variables. 85 | 86 | E.g. 87 | from yacs.config import CfgNode as CN 88 | cfg.TRAINER.MY_MODEL = CN() 89 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 90 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 91 | cfg.TRAINER.MY_MODEL.PARAM_C = False 92 | """ 93 | from yacs.config import CfgNode as CN 94 | 95 | cfg.TRAINER.COOP = CN() 96 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors 97 | cfg.TRAINER.COOP.CSC = False # class-specific context 98 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words 99 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp 100 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 101 | 102 | cfg.TRAINER.COCOOP = CN() 103 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors 104 | cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words 105 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp 106 | 107 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 108 | 109 | 110 | def setup_cfg(args): 111 | cfg = get_cfg_default() 112 | extend_cfg(cfg) 113 | 114 | # 1. From the dataset config file 115 | if args.dataset_config_file: 116 | cfg.merge_from_file(args.dataset_config_file) 117 | 118 | # 2. From the method config file 119 | if args.config_file: 120 | cfg.merge_from_file(args.config_file) 121 | 122 | # 3. From input arguments 123 | reset_cfg(cfg, args) 124 | 125 | # 4. From optional input arguments 126 | cfg.merge_from_list(args.opts) 127 | 128 | cfg.freeze() 129 | 130 | return cfg 131 | 132 | 133 | def main(args): 134 | cfg = setup_cfg(args) 135 | if cfg.SEED >= 0: 136 | print("Setting fixed seed: {}".format(cfg.SEED)) 137 | set_random_seed(cfg.SEED) 138 | setup_logger(cfg.OUTPUT_DIR) 139 | 140 | if torch.cuda.is_available() and cfg.USE_CUDA: 141 | torch.backends.cudnn.benchmark = True 142 | 143 | print_args(args, cfg) 144 | print("Collecting env info ...") 145 | print("** System info **\n{}\n".format(collect_env_info())) 146 | 147 | trainer = build_trainer(cfg) 148 | 149 | if args.eval_only: 150 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 151 | trainer.test() 152 | return 153 | 154 | if not args.no_train: 155 | trainer.train() 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("--root", type=str, default="", help="path to dataset") 161 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 162 | parser.add_argument( 163 | "--resume", 164 | type=str, 165 | default="", 166 | help="checkpoint directory (from which the training resumes)", 167 | ) 168 | parser.add_argument( 169 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed" 170 | ) 171 | parser.add_argument( 172 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG" 173 | ) 174 | parser.add_argument( 175 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG" 176 | ) 177 | parser.add_argument( 178 | "--transforms", type=str, nargs="+", help="data augmentation methods" 179 | ) 180 | parser.add_argument( 181 | "--config-file", type=str, default="", help="path to config file" 182 | ) 183 | parser.add_argument( 184 | "--dataset-config-file", 185 | type=str, 186 | default="", 187 | help="path to config file for dataset setup", 188 | ) 189 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 190 | 
parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 191 | parser.add_argument("--head", type=str, default="", help="name of head") 192 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 193 | parser.add_argument( 194 | "--model-dir", 195 | type=str, 196 | default="", 197 | help="load model from this directory for eval-only mode", 198 | ) 199 | parser.add_argument( 200 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 201 | ) 202 | parser.add_argument( 203 | "--no-train", action="store_true", help="do not call trainer.train()" 204 | ) 205 | parser.add_argument( 206 | "opts", 207 | default=None, 208 | nargs=argparse.REMAINDER, 209 | help="modify config options using the command-line", 210 | ) 211 | args = parser.parse_args() 212 | main(args) 213 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/cocoop.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__pycache__/cocoop.cpython-39.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/coop.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__pycache__/coop.cpython-39.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/imagenet_templates.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__pycache__/imagenet_templates.cpython-39.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/ogen.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__pycache__/ogen.cpython-39.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/zsclip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/__pycache__/zsclip.cpython-39.pyc -------------------------------------------------------------------------------- /trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 6 | 7 | IMAGENET_TEMPLATES = [ 8 | "a bad photo of a {}.", 9 | "a photo of many {}.", 10 | "a sculpture of a {}.", 11 | "a photo of the hard to see {}.", 12 | "a low resolution photo of the {}.", 13 | "a rendering of a {}.", 14 | "graffiti of a {}.", 15 | "a bad photo of the {}.", 16 | "a cropped photo of the {}.", 17 | "a tattoo of a {}.", 18 | "the embroidered {}.", 19 | "a photo of a hard to see {}.", 20 | "a bright photo of a {}.", 21 | "a photo of a clean {}.", 22 | "a photo of a dirty {}.", 23 | "a dark photo of the {}.", 24 | "a drawing of a {}.", 25 | "a photo of my {}.", 26 | "the plastic {}.", 27 | "a photo of the cool {}.", 28 | "a close-up photo of a {}.", 29 | "a black and white photo of the {}.", 30 | "a painting of the {}.", 31 | "a painting of a {}.", 32 | "a pixelated photo of the {}.", 33 | "a sculpture of the {}.", 34 | "a bright photo of the {}.", 35 | "a cropped photo of a {}.", 36 | "a plastic {}.", 37 | "a photo of the dirty {}.", 38 | "a jpeg corrupted photo of a {}.", 39 | "a blurry photo of the {}.", 40 | "a photo of the {}.", 41 | "a good photo of the {}.", 42 | "a rendering of the {}.", 43 | "a {} in a video game.", 44 | "a photo of one {}.", 45 | "a doodle of a {}.", 46 | "a close-up photo of the {}.", 47 | "a photo of a {}.", 48 | "the origami {}.", 49 | "the {} in a video game.", 50 | "a sketch of a {}.", 51 | "a doodle of the {}.", 52 | "a origami {}.", 53 | "a low resolution photo of a {}.", 54 | "the toy {}.", 55 | "a rendition of the {}.", 56 | "a photo of the clean {}.", 57 | "a photo of a large {}.", 58 | "a rendition of a {}.", 59 | "a photo of a nice {}.", 60 | "a photo of a weird {}.", 61 | "a blurry photo of a {}.", 62 | "a cartoon {}.", 63 | "art of a {}.", 64 | "a sketch of the {}.", 65 | "a embroidered {}.", 66 | "a pixelated photo of a {}.", 67 | "itap of the {}.", 68 | "a jpeg corrupted photo of the {}.", 69 | "a good photo of a {}.", 70 | "a plushie {}.", 71 | "a photo of the nice {}.", 72 | "a photo of the small {}.", 73 | "a photo of the weird {}.", 74 | "the cartoon {}.", 75 | "art of the {}.", 76 | "a drawing of the {}.", 77 | "a photo of the large {}.", 78 | "a black and white photo of a {}.", 79 | "the plushie {}.", 80 | "a dark photo of a {}.", 81 | "itap of a {}.", 82 | "graffiti of the {}.", 83 | "a toy {}.", 84 | "itap of my {}.", 85 | "a photo of a cool {}.", 86 | "a photo of a small {}.", 87 | "a tattoo of the {}.", 88 | ] 89 | 90 | IMAGENET_TEMPLATES_SELECT = [ 91 | "itap of a {}.", 92 | "a bad photo of the {}.", 93 | "a origami {}.", 94 | "a photo of the large {}.", 95 | "a {} in a video game.", 96 | "art of the {}.", 97 | "a photo of the small {}.", 98 | ] 99 | -------------------------------------------------------------------------------- /trainers/ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-ogen/ccfa0df57ecac7c3214f82c32291401c51a746ac/trainers/ipynb_checkpoints/__init__-checkpoint.py -------------------------------------------------------------------------------- /trainers/ipynb_checkpoints/imagenet_templates-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 6 | 7 | IMAGENET_TEMPLATES = [ 8 | "a bad photo of a {}.", 9 | "a photo of many {}.", 10 | "a sculpture of a {}.", 11 | "a photo of the hard to see {}.", 12 | "a low resolution photo of the {}.", 13 | "a rendering of a {}.", 14 | "graffiti of a {}.", 15 | "a bad photo of the {}.", 16 | "a cropped photo of the {}.", 17 | "a tattoo of a {}.", 18 | "the embroidered {}.", 19 | "a photo of a hard to see {}.", 20 | "a bright photo of a {}.", 21 | "a photo of a clean {}.", 22 | "a photo of a dirty {}.", 23 | "a dark photo of the {}.", 24 | "a drawing of a {}.", 25 | "a photo of my {}.", 26 | "the plastic {}.", 27 | "a photo of the cool {}.", 28 | "a close-up photo of a {}.", 29 | "a black and white photo of the {}.", 30 | "a painting of the {}.", 31 | "a painting of a {}.", 32 | "a pixelated photo of the {}.", 33 | "a sculpture of the {}.", 34 | "a bright photo of the {}.", 35 | "a cropped photo of a {}.", 36 | "a plastic {}.", 37 | "a photo of the dirty {}.", 38 | "a jpeg corrupted photo of a {}.", 39 | "a blurry photo of the {}.", 40 | "a photo of the {}.", 41 | "a good photo of the {}.", 42 | "a rendering of the {}.", 43 | "a {} in a video game.", 44 | "a photo of one {}.", 45 | "a doodle of a {}.", 46 | "a close-up photo of the {}.", 47 | "a photo of a {}.", 48 | "the origami {}.", 49 | "the {} in a video game.", 50 | "a sketch of a {}.", 51 | "a doodle of the {}.", 52 | "a origami {}.", 53 | "a low resolution photo of a {}.", 54 | "the toy {}.", 55 | "a rendition of the {}.", 56 | "a photo of the clean {}.", 57 | "a photo of a large {}.", 58 | "a rendition of a {}.", 59 | "a photo of a nice {}.", 60 | "a photo of a weird {}.", 61 | "a blurry photo of a {}.", 62 | "a cartoon {}.", 63 | "art of a {}.", 64 | "a sketch of the {}.", 65 | "a embroidered {}.", 66 | "a pixelated photo of a {}.", 67 | "itap of the {}.", 68 | "a jpeg corrupted photo of the {}.", 69 | "a good photo of a {}.", 70 | "a plushie {}.", 71 | "a photo of the nice {}.", 72 | "a photo of the small {}.", 73 | "a photo of the weird {}.", 74 | "the cartoon {}.", 75 | "art of the {}.", 76 | "a drawing of the {}.", 77 | "a photo of the large {}.", 78 | "a black and white photo of a {}.", 79 | "the plushie {}.", 80 | "a dark photo of a {}.", 81 | "itap of a {}.", 82 | "graffiti of the {}.", 83 | "a toy {}.", 84 | "itap of my {}.", 85 | "a photo of a cool {}.", 86 | "a photo of a small {}.", 87 | "a tattoo of the {}.", 88 | ] 89 | 90 | IMAGENET_TEMPLATES_SELECT = [ 91 | "itap of a {}.", 92 | "a bad photo of the {}.", 93 | "a origami {}.", 94 | "a photo of the large {}.", 95 | "a {} in a video game.", 96 | "art of the {}.", 97 | "a photo of the small {}.", 98 | ] 99 | -------------------------------------------------------------------------------- /trainers/ipynb_checkpoints/zsclip-checkpoint.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import torch 6 | import torch.nn as nn 7 | 8 | from dassl.engine import TRAINER_REGISTRY, TrainerX 9 | from dassl.optim import build_optimizer, build_lr_scheduler 10 | 11 | from clip import clip 12 | from clip.model import convert_weights 13 | 14 | from .coop import load_clip_to_cpu 15 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT 16 | 17 | CUSTOM_TEMPLATES = { 18 | "OxfordPets": "a photo of a {}, a type of pet.", 19 | "OxfordFlowers": "a photo of a {}, a type of flower.", 20 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 21 | "DescribableTextures": "{} texture.", 22 | "EuroSAT": "a centered satellite photo of {}.", 23 | "StanfordCars": "a photo of a {}.", 24 | "Food101": "a photo of {}, a type of food.", 25 | "SUN397": "a photo of a {}.", 26 | "Caltech101": "a photo of a {}.", 27 | "UCF101": "a photo of a person doing {}.", 28 | "ImageNet": "a photo of a {}.", 29 | "ImageNetSketch": "a photo of a {}.", 30 | "ImageNetV2": "a photo of a {}.", 31 | "ImageNetA": "a photo of a {}.", 32 | "ImageNetR": "a photo of a {}.", 33 | } 34 | 35 | 36 | @TRAINER_REGISTRY.register() 37 | class ZeroshotCLIP(TrainerX): 38 | def build_model(self): 39 | cfg = self.cfg 40 | classnames = self.dm.dataset.classnames 41 | 42 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 43 | clip_model = load_clip_to_cpu(cfg) 44 | clip_model.to(self.device) 45 | 46 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 47 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 48 | print(f"Prompts: {prompts}") 49 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 50 | prompts = prompts.to(self.device) 51 | 52 | with torch.no_grad(): 53 | text_features = clip_model.encode_text(prompts) 54 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 55 | 56 | self.text_features = text_features 57 | self.clip_model = clip_model 58 | 59 | def model_inference(self, image): 60 | image_features = self.clip_model.encode_image(image) 61 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 62 | logit_scale = self.clip_model.logit_scale.exp() 63 | logits = logit_scale * image_features @ self.text_features.t() 64 | return logits 65 | 66 | 67 | @TRAINER_REGISTRY.register() 68 | class ZeroshotCLIP2(ZeroshotCLIP): 69 | """Prompt ensembling.""" 70 | 71 | # templates = IMAGENET_TEMPLATES 72 | templates = IMAGENET_TEMPLATES_SELECT 73 | 74 | def build_model(self): 75 | cfg = self.cfg 76 | classnames = self.dm.dataset.classnames 77 | 78 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 79 | clip_model = load_clip_to_cpu(cfg) 80 | clip_model.to(self.device) 81 | 82 | for params in clip_model.parameters(): 83 | params.requires_grad_(False) 84 | 85 | # add custom-made prompt 86 | if cfg.DATASET.NAME != "ImageNet": 87 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 88 | 89 | num_temp = len(self.templates) 90 | print(f"Prompt ensembling (n={num_temp})") 91 | 92 | mean_text_features = 0 93 | for i, temp in enumerate(self.templates): 94 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 95 | prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device) 96 | text_features = clip_model.encode_text(prompts) 97 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 98 | mean_text_features = mean_text_features + text_features 99 | mean_text_features = mean_text_features / num_temp 100 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True) 101 | 
102 | self.text_features = mean_text_features 103 | self.clip_model = clip_model 104 | -------------------------------------------------------------------------------- /trainers/zsclip.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | import torch.nn as nn 7 | 8 | from dassl.engine import TRAINER_REGISTRY, TrainerX 9 | from dassl.optim import build_optimizer, build_lr_scheduler 10 | 11 | from clip import clip 12 | from clip.model import convert_weights 13 | 14 | from .coop import load_clip_to_cpu 15 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT 16 | 17 | CUSTOM_TEMPLATES = { 18 | "OxfordPets": "a photo of a {}, a type of pet.", 19 | "OxfordFlowers": "a photo of a {}, a type of flower.", 20 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 21 | "DescribableTextures": "{} texture.", 22 | "EuroSAT": "a centered satellite photo of {}.", 23 | "StanfordCars": "a photo of a {}.", 24 | "Food101": "a photo of {}, a type of food.", 25 | "SUN397": "a photo of a {}.", 26 | "Caltech101": "a photo of a {}.", 27 | "UCF101": "a photo of a person doing {}.", 28 | "ImageNet": "a photo of a {}.", 29 | "ImageNetSketch": "a photo of a {}.", 30 | "ImageNetV2": "a photo of a {}.", 31 | "ImageNetA": "a photo of a {}.", 32 | "ImageNetR": "a photo of a {}.", 33 | } 34 | 35 | 36 | @TRAINER_REGISTRY.register() 37 | class ZeroshotCLIP(TrainerX): 38 | def build_model(self): 39 | cfg = self.cfg 40 | classnames = self.dm.dataset.classnames 41 | 42 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 43 | clip_model = load_clip_to_cpu(cfg) 44 | clip_model.to(self.device) 45 | 46 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 47 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 48 | print(f"Prompts: {prompts}") 49 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 50 | prompts = prompts.to(self.device) 51 | 52 | with torch.no_grad(): 53 | text_features = clip_model.encode_text(prompts) 54 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 55 | 56 | self.text_features = text_features 57 | self.clip_model = clip_model 58 | 59 | def model_inference(self, image): 60 | image_features = self.clip_model.encode_image(image) 61 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 62 | logit_scale = self.clip_model.logit_scale.exp() 63 | logits = logit_scale * image_features @ self.text_features.t() 64 | return logits 65 | 66 | 67 | @TRAINER_REGISTRY.register() 68 | class ZeroshotCLIP2(ZeroshotCLIP): 69 | """Prompt ensembling.""" 70 | 71 | # templates = IMAGENET_TEMPLATES 72 | templates = IMAGENET_TEMPLATES_SELECT 73 | 74 | def build_model(self): 75 | cfg = self.cfg 76 | classnames = self.dm.dataset.classnames 77 | 78 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 79 | clip_model = load_clip_to_cpu(cfg) 80 | clip_model.to(self.device) 81 | 82 | for params in clip_model.parameters(): 83 | params.requires_grad_(False) 84 | 85 | # add custom-made prompt 86 | if cfg.DATASET.NAME != "ImageNet": 87 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 88 | 89 | num_temp = len(self.templates) 90 | print(f"Prompt ensembling (n={num_temp})") 91 | 92 | mean_text_features = 0 93 | for i, temp in enumerate(self.templates): 94 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 95 | prompts = 
torch.cat([clip.tokenize(p) for p in prompts]).to(self.device) 96 | text_features = clip_model.encode_text(prompts) 97 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 98 | mean_text_features = mean_text_features + text_features 99 | mean_text_features = mean_text_features / num_temp 100 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True) 101 | 102 | self.text_features = mean_text_features 103 | self.clip_model = clip_model 104 | --------------------------------------------------------------------------------
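A minimal end-to-end usage sketch for the scripts above. The dataset name, seed, and backbone config are illustrative assumptions, not values mandated by the repository, and the data root variables inside the scripts (DATA=./data, or the /path/to/datasets placeholder in zeroshot.sh) must point at your local dataset directory first:

    # Train OGEN on the base split of Caltech101 with the 10-epoch schedule, then evaluate on the new split.
    bash scripts/ogen/base2new_train_ep10.sh caltech101 1
    # Re-evaluate the saved epoch-10 checkpoint on both the base and new splits.
    bash scripts/ogen/base2new_eval_ep10.sh caltech101 1
    # Zero-shot CLIP baseline on the same dataset with a ViT-B/16 backbone.
    bash scripts/coop/zeroshot.sh caltech101 vit_b16

All three entry points call train.py, which merges the dataset YAML, the trainer YAML, the command-line flags, and any trailing KEY VALUE pairs (the argparse REMAINDER "opts") into a single yacs config before building and running the trainer.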