├── images ├── teaser1.png ├── teaser2.png ├── officehome.png └── architecture.png ├── docs └── AD-CLIP_poster.pdf ├── clip ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py ├── clip.py └── model.py ├── requirements.txt ├── configs ├── datasets │ ├── visda17.yaml │ ├── mini_domainnet.yaml │ └── officehome.yaml └── trainer │ ├── rn50.yaml │ ├── vitB16.yaml │ └── vitL14.yaml ├── scripts ├── main.sh └── eval.sh ├── LICENSE ├── datasets ├── visda17.py ├── mini_domainnet.py └── office_home.py ├── README.md ├── train.py └── trainers ├── adclip_vitB16.py ├── adclip_vitL14.py └── adclip_rn50.py /images/teaser1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/teaser1.png -------------------------------------------------------------------------------- /images/teaser2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/teaser2.png -------------------------------------------------------------------------------- /images/officehome.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/officehome.png -------------------------------------------------------------------------------- /docs/AD-CLIP_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/docs/AD-CLIP_poster.pdf -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/architecture.png -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | tqdm 3 | ftfy 4 | regex 5 | yacs 6 | einops 7 | h5py 8 | tb-nightly 9 | future 10 | six -------------------------------------------------------------------------------- /configs/datasets/visda17.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "center_crop", "normalize"] 4 | 5 | DATASET: 6 | NAME: "VisDA17" 7 | SOURCE_DOMAINS: ["synthetic"] 8 | TARGET_DOMAINS: ["synthetic"] 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "ViT-B/16" 13 | 14 | TEST: 15 | PER_CLASS_RESULT: True -------------------------------------------------------------------------------- /configs/datasets/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (96, 96) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "miniDomainNet" 7 | # SOURCE_DOMAINS: ["clipart"] 8 | # SOURCE_DOMAINS: ["painting"] 9 | # SOURCE_DOMAINS: ["real"] 10 | SOURCE_DOMAINS: ["sketch"] 11 | 12 | # TARGET_DOMAINS: ["clipart"] 13 | TARGET_DOMAINS: ["painting"] 14 | # TARGET_DOMAINS: ["real"] 15 | # TARGET_DOMAINS: ["sketch"] 16 
| 17 | MODEL: 18 | BACKBONE: 19 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/datasets/officehome.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "center_crop", "normalize"] 4 | 5 | DATASET: 6 | NAME: "OfficeHome" 7 | # SOURCE_DOMAINS: ["real_world"] 8 | # SOURCE_DOMAINS: ["art"] 9 | # SOURCE_DOMAINS: ["clipart"] 10 | SOURCE_DOMAINS: ["product"] 11 | 12 | TARGET_DOMAINS: ["clipart"] 13 | # TARGET_DOMAINS: ["art"] 14 | # TARGET_DOMAINS: ["product"] 15 | # TARGET_DOMAINS: ["real_world"] 16 | # you can modify the code to explore four domains 17 | 18 | MODEL: 19 | BACKBONE: 20 | NAME: "ViT-B/16" 21 | -------------------------------------------------------------------------------- /configs/trainer/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TRAIN_U: 5 | BATCH_SIZE: 16 6 | TEST: 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | 10 | INPUT: 11 | SIZE: (224, 224) 12 | INTERPOLATION: "bicubic" 13 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 14 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 15 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 16 | 17 | OPTIM: 18 | NAME: "adam" 19 | LR: 0.01 20 | MAX_EPOCH: 50 21 | LR_SCHEDULER: "cosine" 22 | WARMUP_EPOCH: 1 23 | WARMUP_TYPE: "linear" 24 | WARMUP_MIN_LR: 1e-5 25 | 26 | TRAIN: 27 | PRINT_FREQ: 100 28 | 29 | MODEL: 30 | BACKBONE: 31 | NAME: "RN50" 32 | 33 | TRAINER: 34 | ADCLIPRN50: 35 | PREC: "amp" 36 | -------------------------------------------------------------------------------- /configs/trainer/vitB16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TRAIN_U: 5 | BATCH_SIZE: 16 6 | TEST: 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | 10 | INPUT: 11 | SIZE: (224, 224) 12 | INTERPOLATION: "bicubic" 13 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 14 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 15 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 16 | 17 | OPTIM: 18 | NAME: "adam" 19 | LR: 0.01 20 | MAX_EPOCH: 50 21 | LR_SCHEDULER: "cosine" 22 | WARMUP_EPOCH: 1 23 | WARMUP_TYPE: "linear" 24 | WARMUP_MIN_LR: 1e-5 25 | 26 | TRAIN: 27 | PRINT_FREQ: 100 28 | 29 | MODEL: 30 | BACKBONE: 31 | NAME: "ViT-B/16" 32 | 33 | TRAINER: 34 | ADCLIPB16: 35 | PREC: "amp" 36 | -------------------------------------------------------------------------------- /configs/trainer/vitL14.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TRAIN_U: 5 | BATCH_SIZE: 16 6 | TEST: 7 | BATCH_SIZE: 128 8 | NUM_WORKERS: 4 9 | 10 | INPUT: 11 | SIZE: (224, 224) 12 | INTERPOLATION: "bicubic" 13 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 14 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 15 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 16 | 17 | OPTIM: 18 | NAME: "adam" 19 | LR: 0.01 20 | MAX_EPOCH: 50 21 | LR_SCHEDULER: "cosine" 22 | WARMUP_EPOCH: 1 23 | WARMUP_TYPE: "linear" 24 | WARMUP_MIN_LR: 1e-5 25 | 26 | TRAIN: 27 | PRINT_FREQ: 100 28 | 29 | MODEL: 30 | BACKBONE: 31 | NAME: "ViT-L/14" 32 | 33 | TRAINER: 34 | ADCLIPL14: 35 | PREC: "amp" 36 | -------------------------------------------------------------------------------- /scripts/main.sh: 
-------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | DATA=data # change your data path here 4 | MODE=train 5 | 6 | DATASET=$1 # dataset name; officehome, visda17, mini_domainnet 7 | TRAINER=$2 # ADCLIPRN50, ADCLIPB16, ADCLIPL14 8 | CFG=$3 # config file; rn50, vitB16, vitL14 9 | #SEED=$4 10 | 11 | for SEED in 1 2 3 4 5 12 | do 13 | DIR=output/${DATASET}/${MODE}/${TRAINER}/${CFG}/seed_${SEED} 14 | if [ -d "$DIR" ]; then 15 | echo "Results are available in ${DIR}. Skip this job" 16 | else 17 | echo "Run this job and save the output to ${DIR}" 18 | python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainer/${CFG}.yaml \ 24 | --output-dir ${DIR} 25 | fi 26 | done -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | DATA=data # change your data path here 4 | MODE=test 5 | 6 | DATASET=$1 # dataset name; officehome, visda17, mini_domainnet 7 | TRAINER=$2 # ADCLIPRN50, ADCLIPB16, ADCLIPL14 8 | CFG=$3 # config file; rn50, vitB16, vitL14 9 | # SEED=$4 10 | 11 | 12 | for SEED in 1 2 3 4 5 13 | do 14 | MODEL_DIR=output/${DATASET}/train/${TRAINER}/${CFG}/seed_${SEED} # trained weights are saved by main.sh under MODE=train 15 | DIR=output/${DATASET}/${MODE}/${TRAINER}/${CFG}/seed_${SEED} 16 | if false; then 17 | echo "The results already exist in ${DIR}" 18 | else 19 | python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainer/${CFG}.yaml \ 25 | --output-dir ${DIR} \ 26 | --model-dir ${MODEL_DIR} \ 27 | --eval-only 28 | # add --load-epoch ${LOADEP} above to evaluate a specific checkpoint 29 | fi 30 | done -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mainak Singha 4 | Copyright (c) 2021 Kaiyang Zhou 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE.
23 | -------------------------------------------------------------------------------- /datasets/visda17.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class VisDA17(DatasetBase): 9 | """VisDA17. 10 | 11 | Focusing on simulation-to-reality domain shift. 12 | 13 | URL: http://ai.bu.edu/visda-2017/. 14 | 15 | Reference: 16 | - Peng et al. VisDA: The Visual Domain Adaptation 17 | Challenge. ArXiv 2017. 18 | """ 19 | 20 | dataset_dir = "visda17" 21 | domains = ["synthetic", "real"] 22 | 23 | def __init__(self, cfg): 24 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | 27 | self.check_input_domains( 28 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 29 | ) 30 | 31 | train_x = self._read_data("synthetic") 32 | train_u = self._read_data("real") 33 | test = self._read_data("real") 34 | 35 | super().__init__(train_x=train_x, train_u=train_u, test=test) 36 | 37 | def _read_data(self, dname): 38 | filedir = "train" if dname == "synthetic" else "validation" 39 | image_list = osp.join(self.dataset_dir, filedir, "image_list.txt") 40 | items = [] 41 | # There is only one source domain 42 | domain = 0 43 | 44 | with open(image_list, "r") as f: 45 | lines = f.readlines() 46 | 47 | for line in lines: 48 | line = line.strip() 49 | impath, label = line.split(" ") 50 | classname = impath.split("/")[0] 51 | impath = osp.join(self.dataset_dir, filedir, impath) 52 | label = int(label) 53 | item = Datum( 54 | impath=impath, 55 | label=label, 56 | domain=domain, 57 | classname=classname 58 | ) 59 | items.append(item) 60 | 61 | return items 62 | -------------------------------------------------------------------------------- /datasets/mini_domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class miniDomainNet(DatasetBase): 9 | """A subset of DomainNet. 10 | 11 | Reference: 12 | - Peng et al. Moment Matching for Multi-Source Domain 13 | Adaptation. ICCV 2019. 14 | - Zhou et al. Domain Adaptive Ensemble Learning. 
15 | """ 16 | 17 | dataset_dir = "domainnet" 18 | domains = ["clipart", "painting", "real", "sketch"] 19 | 20 | def __init__(self, cfg): 21 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.split_dir = osp.join(self.dataset_dir, "splits_mini") 24 | 25 | self.check_input_domains( 26 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 27 | ) 28 | 29 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 30 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 31 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 32 | 33 | super().__init__(train_x=train_x, train_u=train_u, test=test) 34 | 35 | def _read_data(self, input_domains, split="train"): 36 | items = [] 37 | 38 | for domain, dname in enumerate(input_domains): 39 | filename = dname + "_" + split + ".txt" 40 | split_file = osp.join(self.split_dir, filename) 41 | 42 | with open(split_file, "r") as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip() 46 | impath, label = line.split(" ") 47 | classname = impath.split("/")[1] 48 | impath = osp.join(self.dataset_dir, impath) 49 | label = int(label) 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | domain=domain, 54 | classname=classname 55 | ) 56 | items.append(item) 57 | 58 | return items 59 | -------------------------------------------------------------------------------- /datasets/office_home.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class OfficeHome(DatasetBase): 11 | """Office-Home. 12 | 13 | Statistics: 14 | - Around 15,500 images. 15 | - 65 classes related to office and home objects. 16 | - 4 domains: Art, Clipart, Product, Real World. 17 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 18 | 19 | Reference: 20 | - Venkateswara et al. Deep Hashing Network for Unsupervised 21 | Domain Adaptation. CVPR 2017. 
22 | """ 23 | 24 | dataset_dir = "office_home" 25 | domains = ["art", "clipart", "product", "real_world"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name.lower(), 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AD-CLIP: Adapting Domains in Prompt Space Using CLIP 2 | Official repository of AD-CLIP, which is focused on domain adaptation using *prompt learning* by adapting pre-trained vision-language models (VLM) like CLIP. 3 | 4 | ## **ICCVw 2023** 5 | 6 | [![paper](https://img.shields.io/badge/Conference-Paper-blue)](https://openaccess.thecvf.com/content/ICCV2023W/OODCV/papers/Singha_AD-CLIP_Adapting_Domains_in_Prompt_Space_Using_CLIP_ICCVW_2023_paper.pdf) 7 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-brightgreen)](https://arxiv.org/pdf/2308.05659.pdf) 8 | [![poster](https://img.shields.io/badge/Poster-yellow)](https://github.com/mainaksingha01/AD-CLIP/blob/master/docs/AD-CLIP_poster.pdf) 9 | 10 | ## Abstract 11 | 12 | 13 | Although deep learning models have shown impressive performance on supervised learning tasks, they often struggle to generalize well when the training (source) and test (target) domains differ. Unsupervised domain adaptation (DA) has emerged as a popular solution to this problem. However, current DA techniques rely on visual backbones, which may lack semantic richness. Despite the potential of large-scale vision-language foundation models like CLIP, their effectiveness for DA has yet to be fully explored. To address this gap, we introduce AD-CLIP, a domain-agnostic prompt learning strategy for CLIP that aims to solve the DA problem in the prompt space. We leverage the frozen vision backbone of CLIP to extract both image style (domain) and content information, which we apply to learn prompt tokens. Our prompts are designed to be domain-invariant and class-generalizable, by conditioning prompt learning on image style and content features simultaneously. We use standard supervised contrastive learning in the source domain, while proposing an entropy minimization strategy to align domains in the embedding space given the target domain data. We also consider a scenario where only target domain samples are available during testing, without any source domain data, and propose a cross-domain style mapping network to hallucinate domain-agnostic tokens. 
Our extensive experiments on three benchmark DA datasets demonstrate the effectiveness of AD-CLIP compared to existing literature. 14 | 15 | ## Architecture 16 | 17 | 18 | 19 | ## How to install 20 | 21 | ### Create your environment: 22 | 23 | ```bash 24 | $ conda create -n adclip python=3.8 25 | $ conda activate adclip 26 | $ conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=10.2 -c pytorch 27 | $ pip install -r requirements.txt 28 | ``` 29 | 30 | ## Code 31 | 32 | - `datasets` folder contains the dataloader files of each datasets. 33 | - `trainers` folder contains the code of our model in three variants ResNet50, ViT-B/16 and ViT-L/14. 34 | - Clone the awesome toolbox of [dassl](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl) inside this repo. 35 | - In line 464 of `dassl.engine.trainer` file, replace the output by the returns of the `CustomCLIP` class of the trainers (e.g. adclip_vitB16) file for evaluation. 36 | - `scripts` folder holds the scripts of for training and testing. 37 | - Put data path in `main.sh` and `eval.sh`. 38 | - Choose the source and target domains from `configs.datasets` files. 39 | 40 | ```shell (for example) 41 | $ cd scripts 42 | $ bash main.sh officehome ADCLIPB16 vitB16 43 | $ bash eval.sh officehome ADCLIPB16 vitB16 44 | ``` 45 | 46 | ## Bibtex 47 | 48 | Please cite the paper if you use our work . Thanks. 49 | 50 | ``` 51 | @inproceedings{singha2023ad, 52 | title={Ad-clip: Adapting domains in prompt space using clip}, 53 | author={Singha, Mainak and Pal, Harsh and Jha, Ankit and Banerjee, Biplab}, 54 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 55 | pages={4355--4364}, 56 | year={2023} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | 5 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 6 | from dassl.config import get_cfg_default 7 | from dassl.engine import build_trainer 8 | 9 | from dassl.data.datasets import VisDA17 10 | from dassl.data.datasets import OfficeHome 11 | from dassl.data.datasets import miniDomainNet 12 | 13 | import 
trainers.adclip_rn50 14 | import trainers.adclip_vitB16 15 | import trainers.adclip_vitL14 16 | 17 | 18 | def print_args(args, cfg): 19 | print("***************") 20 | print("** Arguments **") 21 | print("***************") 22 | optkeys = list(args.__dict__.keys()) 23 | optkeys.sort() 24 | for key in optkeys: 25 | print("{}: {}".format(key, args.__dict__[key])) 26 | print("************") 27 | print("** Config **") 28 | print("************") 29 | print(cfg) 30 | 31 | 32 | def reset_cfg(cfg, args): 33 | if args.root: 34 | cfg.DATASET.ROOT = args.root 35 | 36 | if args.output_dir: 37 | cfg.OUTPUT_DIR = args.output_dir 38 | 39 | if args.resume: 40 | cfg.RESUME = args.resume 41 | 42 | if args.seed: 43 | cfg.SEED = args.seed 44 | 45 | if args.source_domains: 46 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 47 | 48 | if args.target_domains: 49 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 50 | 51 | if args.transforms: 52 | cfg.INPUT.TRANSFORMS = args.transforms 53 | 54 | if args.trainer: 55 | cfg.TRAINER.NAME = args.trainer 56 | 57 | if args.backbone: 58 | cfg.MODEL.BACKBONE.NAME = args.backbone 59 | 60 | if args.head: 61 | cfg.MODEL.HEAD.NAME = args.head 62 | 63 | 64 | def extend_cfg(cfg): 65 | """ 66 | Add new config variables for DAPL. 67 | 68 | E.g. 69 | from yacs.config import CfgNode as CN 70 | cfg.TRAINER.MY_MODEL = CN() 71 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 72 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 73 | cfg.TRAINER.MY_MODEL.PARAM_C = False 74 | """ 75 | from yacs.config import CfgNode as CN 76 | 77 | cfg.MODEL.BACKBONE.PATH = "./assets" 78 | 79 | 80 | cfg.TRAINER.ADCLIPRN50 = CN() 81 | cfg.TRAINER.ADCLIPRN50.PREC = "amp" # fp16, fp32, amp 82 | 83 | cfg.TRAINER.ADCLIPB16 = CN() 84 | cfg.TRAINER.ADCLIPB16.PREC = "amp" # fp16, fp32, amp 85 | 86 | cfg.TRAINER.ADCLIPL14 = CN() 87 | cfg.TRAINER.ADCLIPL14.PREC = "amp" # fp16, fp32, amp 88 | 89 | 90 | def setup_cfg(args): 91 | cfg = get_cfg_default() 92 | extend_cfg(cfg) 93 | print(cfg) 94 | 95 | # 1. From the dataset config file 96 | if args.dataset_config_file: 97 | cfg.merge_from_file(args.dataset_config_file) 98 | 99 | # 2. From the method config file 100 | if args.config_file: 101 | cfg.merge_from_file(args.config_file) 102 | 103 | # 3. From input arguments 104 | reset_cfg(cfg, args) 105 | 106 | # 4. From optional input arguments 107 | cfg.merge_from_list(args.opts) 108 | 109 | cfg.freeze() 110 | 111 | return cfg 112 | 113 | 114 | def main(args): 115 | cfg = setup_cfg(args) 116 | if cfg.SEED >= 0: 117 | print("Setting fixed seed: {}".format(cfg.SEED)) 118 | set_random_seed(cfg.SEED) 119 | setup_logger(cfg.OUTPUT_DIR) 120 | 121 | if torch.cuda.is_available() and cfg.USE_CUDA: 122 | torch.backends.cudnn.benchmark = True 123 | 124 | print_args(args, cfg) 125 | print("Collecting env info ...") 126 | print("** System info **\n{}\n".format(collect_env_info())) 127 | 128 | trainer = build_trainer(cfg) 129 | 130 | if args.eval_only: 131 | # if True: 132 | print("Yess testing") 133 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 134 | trainer.test() 135 | return 136 | 137 | if not args.no_train: 138 | print("No! 
Training") 139 | trainer.train() 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--root", type=str, default="", help="path to dataset") 145 | parser.add_argument("--output-dir", 146 | type=str, 147 | default="", 148 | help="output directory") 149 | parser.add_argument( 150 | "--resume", 151 | type=str, 152 | default="", 153 | help="checkpoint directory (from which the training resumes)", 154 | ) 155 | parser.add_argument("--seed", 156 | type=int, 157 | default=-1, 158 | help="only positive value enables a fixed seed") 159 | parser.add_argument("--source-domains", 160 | type=str, 161 | nargs="+", 162 | help="source domains for DA/DG") 163 | parser.add_argument("--target-domains", 164 | type=str, 165 | nargs="+", 166 | help="target domains for DA/DG") 167 | parser.add_argument("--transforms", 168 | type=str, 169 | nargs="+", 170 | help="data augmentation methods") 171 | parser.add_argument("--config-file", 172 | type=str, 173 | default="", 174 | help="path to config file") 175 | parser.add_argument( 176 | "--dataset-config-file", 177 | type=str, 178 | default="", 179 | help="path to config file for dataset setup", 180 | ) 181 | parser.add_argument("--trainer", 182 | type=str, 183 | default="", 184 | help="name of trainer") 185 | parser.add_argument("--backbone", 186 | type=str, 187 | default="", 188 | help="name of CNN backbone") 189 | parser.add_argument("--head", type=str, default="", help="name of head") 190 | parser.add_argument("--eval-only", 191 | action="store_true", 192 | help="evaluation only") 193 | parser.add_argument( 194 | "--model-dir", 195 | type=str, 196 | default="", 197 | help="load model from this directory for eval-only mode", 198 | ) 199 | parser.add_argument("--load-epoch", 200 | type=int, 201 | help="load model weights at this epoch for evaluation") 202 | parser.add_argument("--no-train", 203 | action="store_true", 204 | help="do not call trainer.train()") 205 | parser.add_argument( 206 | "opts", 207 | default=None, 208 | nargs=argparse.REMAINDER, 209 | help="modify config options using the command-line", 210 | ) 211 | args = parser.parse_args() 212 | main(args) 213 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": 
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 37 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 38 | } 39 | 40 | 41 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 42 | os.makedirs(root, exist_ok=True) 43 | filename = os.path.basename(url) 44 | 45 | expected_sha256 = url.split("/")[-2] 46 | download_target = os.path.join(root, filename) 47 | 48 | if os.path.exists(download_target) and not os.path.isfile(download_target): 49 | raise RuntimeError(f"{download_target} exists and is not a regular file") 50 | 51 | if os.path.isfile(download_target): 52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 53 | return download_target 54 | else: 55 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 56 | 57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 59 | while True: 60 | buffer = source.read(8192) 61 | if not buffer: 62 | break 63 | 64 | output.write(buffer) 65 | loop.update(len(buffer)) 66 | 67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 68 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 69 | 70 | return download_target 71 | 72 | 73 | def _transform(n_px): 74 | return Compose([ 75 | Resize(n_px, interpolation=BICUBIC), 76 | CenterCrop(n_px), 77 | lambda image: image.convert("RGB"), 78 | ToTensor(), 79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 80 | ]) 81 | 82 | 83 | def available_models() -> List[str]: 84 | """Returns the names of available CLIP models""" 85 | return list(_MODELS.keys()) 86 | 87 | 88 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 89 | """Load a CLIP model 90 | 91 | Parameters 92 | ---------- 93 | name : str 94 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 95 | 96 | device : Union[str, torch.device] 97 | The device to put the loaded model 98 | 99 | jit : bool 100 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
101 | 102 | Returns 103 | ------- 104 | model : torch.nn.Module 105 | The CLIP model 106 | 107 | preprocess : Callable[[PIL.Image], torch.Tensor] 108 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 109 | """ 110 | if name in _MODELS: 111 | model_path = _download(_MODELS[name]) 112 | elif os.path.isfile(name): 113 | model_path = name 114 | else: 115 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 116 | 117 | try: 118 | # loading JIT archive 119 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 120 | state_dict = None 121 | except RuntimeError: 122 | # loading saved state dict 123 | if jit: 124 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 125 | jit = False 126 | state_dict = torch.load(model_path, map_location="cpu") 127 | 128 | if not jit: 129 | model = build_model(state_dict or model.state_dict()).to(device) 130 | if str(device) == "cpu": 131 | model.float() 132 | return model, _transform(model.visual.input_resolution) 133 | 134 | # patch the device names 135 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 136 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 137 | 138 | def patch_device(module): 139 | try: 140 | graphs = [module.graph] if hasattr(module, "graph") else [] 141 | except RuntimeError: 142 | graphs = [] 143 | 144 | if hasattr(module, "forward1"): 145 | graphs.append(module.forward1.graph) 146 | 147 | for graph in graphs: 148 | for node in graph.findAllNodes("prim::Constant"): 149 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 150 | node.copyAttributes(device_node) 151 | 152 | model.apply(patch_device) 153 | patch_device(model.encode_image) 154 | patch_device(model.encode_text) 155 | 156 | # patch dtype to float32 on CPU 157 | if str(device) == "cpu": 158 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 159 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 160 | float_node = float_input.node() 161 | 162 | def patch_float(module): 163 | try: 164 | graphs = [module.graph] if hasattr(module, "graph") else [] 165 | except RuntimeError: 166 | graphs = [] 167 | 168 | if hasattr(module, "forward1"): 169 | graphs.append(module.forward1.graph) 170 | 171 | for graph in graphs: 172 | for node in graph.findAllNodes("aten::to"): 173 | inputs = list(node.inputs()) 174 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 175 | if inputs[i].node()["value"] == 5: 176 | inputs[i].node().copyAttributes(float_node) 177 | 178 | model.apply(patch_float) 179 | patch_float(model.encode_image) 180 | patch_float(model.encode_text) 181 | 182 | model.float() 183 | 184 | return model, _transform(model.input_resolution.item()) 185 | 186 | 187 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 188 | """ 189 | Returns the tokenized representation of given input string(s) 190 | 191 | Parameters 192 | ---------- 193 | texts : Union[str, List[str]] 194 | An input string or a list of input strings to tokenize 195 | 196 | context_length : int 197 | The context length to use; all CLIP models use 77 as the context length 198 | 199 | truncate: bool 200 | Whether to truncate the text in case its encoding is longer than the context length 201 | 202 | 
Returns 203 | ------- 204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 205 | """ 206 | if isinstance(texts, str): 207 | texts = [texts] 208 | 209 | sot_token = _tokenizer.encoder["<|startoftext|>"] 210 | eot_token = _tokenizer.encoder["<|endoftext|>"] 211 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 212 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 213 | 214 | for i, tokens in enumerate(all_tokens): 215 | if len(tokens) > context_length: 216 | if truncate: 217 | tokens = tokens[:context_length] 218 | tokens[-1] = eot_token 219 | else: 220 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 221 | result[i, :len(tokens)] = torch.tensor(tokens) 222 | 223 | return result 224 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, 
None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = conv(x) 139 | x = bn(x) 140 | x = self.relu(x) 141 | # x = self.relu(bn(conv(x))) 142 | x = self.avgpool(x) 143 | return x 144 | 145 | x = x.type(self.conv1.weight.dtype) 146 | x = stem(x) 147 | data = [] 148 | x1 = self.layer1(x) 149 | data.append(x1) 150 | x2 = self.layer2(x1) 151 | data.append(x2) 152 | x3 = self.layer3(x2) 153 | data.append(x3) 154 | x4 = self.layer4(x3) 155 | data.append(x4) 156 | feat = self.attnpool(x4) 157 | return feat, data 158 | 159 | 160 | class LayerNorm(nn.LayerNorm): 161 | """Subclass torch's LayerNorm to handle fp16.""" 162 | 163 | def forward(self, x: torch.Tensor): 164 | orig_type = x.dtype 165 | ret = super().forward(x.type(torch.float32)) 166 | return ret.type(orig_type) 
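# ---------------------------------------------------------------------------
# Editor's note (hedged sketch, not part of the upstream OpenAI CLIP file):
# unlike stock CLIP, the Transformer / VisionTransformer defined below and the
# ModifiedResNet above also return their intermediate layer activations
# ("data"), which AD-CLIP's prompt projectors consume. Assuming the ViT-B/16
# backbone (width 768, 12 blocks, 224x224 input, patch size 16) and using
# placeholder names clip_model / images, the shapes are roughly:
#
#     feats, data = clip_model.visual(images.type(clip_model.dtype))
#     # feats: [B, 512]           pooled image embedding (after self.proj)
#     # data:  [12, B, 197, 768]  detached token features, one slice per block
#
# For the ModifiedResNet backbone, "data" is instead a Python list holding the
# four residual-stage feature maps.
# ---------------------------------------------------------------------------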
167 | 168 | 169 | class QuickGELU(nn.Module): 170 | def forward(self, x: torch.Tensor): 171 | return x * torch.sigmoid(1.702 * x) 172 | 173 | 174 | class ResidualAttentionBlock(nn.Module): 175 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 176 | super().__init__() 177 | 178 | self.attn = nn.MultiheadAttention(d_model, n_head) 179 | self.ln_1 = LayerNorm(d_model) 180 | self.mlp = nn.Sequential(OrderedDict([ 181 | ("c_fc", nn.Linear(d_model, d_model * 4)), 182 | ("gelu", QuickGELU()), 183 | ("c_proj", nn.Linear(d_model * 4, d_model)) 184 | ])) 185 | self.ln_2 = LayerNorm(d_model) 186 | self.attn_mask = attn_mask 187 | 188 | def attention(self, x: torch.Tensor): 189 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 190 | y = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 191 | return y 192 | 193 | def forward(self, x: torch.Tensor): 194 | x = x + self.attention(self.ln_1(x)) 195 | x = x + self.mlp(self.ln_2(x)) 196 | return x 197 | 198 | 199 | class Transformer(nn.Module): 200 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 201 | super().__init__() 202 | self.width = width 203 | self.layers = layers 204 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 205 | 206 | def forward(self, x: torch.Tensor): 207 | data=[] 208 | for layer in self.resblocks: 209 | x = layer(x) 210 | data.append(x.detach().permute(1,0,2)) 211 | data = torch.stack(data) 212 | return x, data 213 | 214 | 215 | class VisionTransformer(nn.Module): 216 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 217 | super().__init__() 218 | self.input_resolution = input_resolution 219 | self.output_dim = output_dim 220 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 221 | 222 | scale = width ** -0.5 223 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 224 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 225 | self.ln_pre = LayerNorm(width) 226 | 227 | self.transformer = Transformer(width, layers, heads) 228 | 229 | self.ln_post = LayerNorm(width) 230 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 231 | 232 | def forward(self, x: torch.Tensor): 233 | x = self.conv1(x) # shape = [*, width, grid, grid] 234 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 235 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 236 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 237 | x = x + self.positional_embedding.to(x.dtype) 238 | x = self.ln_pre(x) 239 | 240 | x = x.permute(1, 0, 2) # NLD -> LND 241 | x, data = self.transformer(x) 242 | x = x.permute(1, 0, 2) # LND -> NLD 243 | 244 | x = self.ln_post(x[:, 0, :]) 245 | 246 | if self.proj is not None: 247 | x = x @ self.proj 248 | return x, data 249 | 250 | 251 | class CLIP(nn.Module): 252 | def __init__(self, 253 | embed_dim: int, 254 | # vision 255 | image_resolution: int, 256 | vision_layers: Union[Tuple[int, int, int, int], int], 257 | vision_width: int, 258 | vision_patch_size: int, 259 | # text 260 | context_length: int, 261 | vocab_size: int, 262 | transformer_width: int, 263 | transformer_heads: int, 
264 | transformer_layers: int 265 | ): 266 | super().__init__() 267 | 268 | self.context_length = context_length 269 | 270 | if isinstance(vision_layers, (tuple, list)): 271 | vision_heads = vision_width * 32 // 64 272 | self.visual = ModifiedResNet( 273 | layers=vision_layers, 274 | output_dim=embed_dim, 275 | heads=vision_heads, 276 | input_resolution=image_resolution, 277 | width=vision_width 278 | ) 279 | else: 280 | vision_heads = vision_width // 64 281 | self.visual = VisionTransformer( 282 | input_resolution=image_resolution, 283 | patch_size=vision_patch_size, 284 | width=vision_width, 285 | layers=vision_layers, 286 | heads=vision_heads, 287 | output_dim=embed_dim 288 | ) 289 | 290 | self.transformer = Transformer( 291 | width=transformer_width, 292 | layers=transformer_layers, 293 | heads=transformer_heads, 294 | attn_mask=self.build_attention_mask() 295 | ) 296 | 297 | self.vocab_size = vocab_size 298 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 299 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 300 | self.ln_final = LayerNorm(transformer_width) 301 | 302 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 303 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 304 | 305 | self.initialize_parameters() 306 | 307 | def initialize_parameters(self): 308 | nn.init.normal_(self.token_embedding.weight, std=0.02) 309 | nn.init.normal_(self.positional_embedding, std=0.01) 310 | 311 | if isinstance(self.visual, ModifiedResNet): 312 | if self.visual.attnpool is not None: 313 | std = self.visual.attnpool.c_proj.in_features ** -0.5 314 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 315 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 316 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 317 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 318 | 319 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 320 | for name, param in resnet_block.named_parameters(): 321 | if name.endswith("bn3.weight"): 322 | nn.init.zeros_(param) 323 | 324 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 325 | attn_std = self.transformer.width ** -0.5 326 | fc_std = (2 * self.transformer.width) ** -0.5 327 | for block in self.transformer.resblocks: 328 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 329 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 330 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 331 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 332 | 333 | if self.text_projection is not None: 334 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 335 | 336 | def build_attention_mask(self): 337 | # lazily create causal attention mask, with full attention between the vision tokens 338 | # pytorch uses additive attention mask; fill with -inf 339 | mask = torch.empty(self.context_length, self.context_length) 340 | mask.fill_(float("-inf")) 341 | mask.triu_(1) # zero out the lower diagonal 342 | return mask 343 | 344 | @property 345 | def dtype(self): 346 | return self.visual.conv1.weight.dtype 347 | 348 | def encode_image(self, image): 349 | return self.visual(image.type(self.dtype)) 350 | 351 | def encode_text(self, text): 352 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 353 | 354 | x = x + self.positional_embedding.type(self.dtype) 355 | x = 
x.permute(1, 0, 2) # NLD -> LND 356 | x, temp = self.transformer(x) 357 | x = x.permute(1, 0, 2) # LND -> NLD 358 | x = self.ln_final(x).type(self.dtype) 359 | 360 | # x.shape = [batch_size, n_ctx, transformer.width] 361 | # take features from the eot embedding (eot_token is the highest number in each sequence) 362 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 363 | 364 | return x 365 | 366 | def forward(self, image, text): 367 | image_features = self.encode_image(image) 368 | text_features = self.encode_text(text) 369 | 370 | # normalized features 371 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 372 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 373 | 374 | # cosine similarity as logits 375 | logit_scale = self.logit_scale.exp() 376 | logits_per_image = logit_scale * image_features @ text_features.t() 377 | logits_per_text = logit_scale * text_features @ image_features.t() 378 | 379 | # shape = [global_batch_size, global_batch_size] 380 | return logits_per_image, logits_per_text 381 | 382 | 383 | def convert_weights(model: nn.Module): 384 | """Convert applicable model parameters to fp16""" 385 | 386 | def _convert_weights_to_fp16(l): 387 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 388 | l.weight.data = l.weight.data.half() 389 | if l.bias is not None: 390 | l.bias.data = l.bias.data.half() 391 | 392 | if isinstance(l, nn.MultiheadAttention): 393 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 394 | tensor = getattr(l, attr) 395 | if tensor is not None: 396 | tensor.data = tensor.data.half() 397 | 398 | for name in ["text_projection", "proj"]: 399 | if hasattr(l, name): 400 | attr = getattr(l, name) 401 | if attr is not None: 402 | attr.data = attr.data.half() 403 | 404 | model.apply(_convert_weights_to_fp16) 405 | 406 | 407 | def build_model(state_dict: dict): 408 | vit = "visual.proj" in state_dict 409 | 410 | if vit: 411 | vision_width = state_dict["visual.conv1.weight"].shape[0] 412 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 413 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 414 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 415 | image_resolution = vision_patch_size * grid_size 416 | else: 417 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 418 | vision_layers = tuple(counts) 419 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 420 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 421 | vision_patch_size = None 422 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 423 | image_resolution = output_width * 32 424 | 425 | embed_dim = state_dict["text_projection"].shape[1] 426 | context_length = state_dict["positional_embedding"].shape[0] 427 | vocab_size = state_dict["token_embedding.weight"].shape[0] 428 | transformer_width = state_dict["ln_final.weight"].shape[0] 429 | transformer_heads = transformer_width // 64 430 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 431 | 432 | model = CLIP( 433 | embed_dim, 434 | image_resolution, vision_layers, vision_width, vision_patch_size, 435 | context_length, vocab_size, transformer_width, transformer_heads, 
transformer_layers 436 | ) 437 | 438 | for key in ["input_resolution", "context_length", "vocab_size"]: 439 | if key in state_dict: 440 | del state_dict[key] 441 | 442 | convert_weights(model) 443 | model.load_state_dict(state_dict) 444 | return model.eval() 445 | -------------------------------------------------------------------------------- /trainers/adclip_vitB16.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | import datetime 4 | import time 5 | from collections import OrderedDict 6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from torch.cuda.amp import GradScaler, autocast 12 | from tqdm import tqdm 13 | 14 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 15 | from dassl.metrics import compute_accuracy 16 | from dassl.utils import MetricMeter, AverageMeter, load_pretrained_weights, load_checkpoint, save_checkpoint 17 | from dassl.optim import build_optimizer, build_lr_scheduler 18 | 19 | from clip import clip 20 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 21 | 22 | _tokenizer = _Tokenizer() 23 | 24 | 25 | def load_clip_to_cpu(cfg): 26 | backbone_name = cfg.MODEL.BACKBONE.NAME 27 | url = clip._MODELS[backbone_name] 28 | model_path = clip._download(url, cfg.MODEL.BACKBONE.PATH) 29 | 30 | 31 | try: 32 | model = torch.jit.load(model_path, map_location="cpu").eval() 33 | state_dict = None 34 | 35 | except RuntimeError: 36 | state_dict = torch.load(model_path, map_location="cpu") 37 | 38 | model = clip.build_model(state_dict or model.state_dict()) 39 | 40 | return model 41 | 42 | class AdaIN(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | def mu(self, x): 47 | return torch.sum(x,(1))/(x.shape[1]) 48 | 49 | def sigma(self, x): 50 | return torch.sqrt((torch.sum((x.permute([1,0,2])-self.mu(x)).permute([1,0,2])**2,(1))+0.000000023)/(x.shape[1])) 51 | 52 | 53 | class domain_projector(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | self.linear1 = nn.ModuleList(nn.Linear(768,256) for _ in range (12)) 57 | self.linear2 = nn.ModuleList(nn.Linear(256,512) for _ in range (12)) 58 | self.adain=AdaIN() 59 | self.gap=nn.AdaptiveAvgPool2d((1,768)) 60 | def forward(self, data): 61 | data_prompt=[] 62 | for i in range(len(data)): 63 | x_mu=self.adain.mu(data[i]).unsqueeze(1).to(torch.float32) 64 | x_sigma=self.adain.sigma(data[i]).unsqueeze(1).to(torch.float32) 65 | x_cat = torch.cat((x_mu, x_sigma),1) 66 | x_cat = self.gap(x_cat).squeeze(1) 67 | x_out = self.linear1[i](x_cat) 68 | x_final = self.linear2[i](x_out) 69 | data_prompt.append(x_final) 70 | output = torch.stack(data_prompt, dim=1) 71 | return output 72 | 73 | class image_projector(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | self.linear = nn.ModuleList(nn.Linear(768,512) for _ in range (12)) 77 | self.adain=AdaIN() 78 | self.lin = nn.Linear(12,1) 79 | self.gap=nn.AdaptiveAvgPool2d((1,768)) 80 | 81 | def forward(self, data, n_imgctx): 82 | data_prompt=[] 83 | for i in range(len(data)): 84 | x_gap = self.gap(data[i]).squeeze(1) 85 | x_lin=self.linear[i](x_gap) 86 | data_prompt.append(x_lin) 87 | feat = torch.stack(data_prompt, dim=1) 88 | output = [] 89 | for i in range(n_imgctx): # L decoders 90 | x = self.lin(feat.permute(0,2,1)) 91 | x = x.permute(0,2,1) 92 | output.append(x) 93 | feat_tokens = torch.stack(output, dim=1).squeeze(2) 94 | return feat_tokens 95 | 96 | class 
style_mapping_projector(nn.Module): 97 | def __init__(self): 98 | super().__init__() 99 | self.linear1 = nn.ModuleList(nn.Linear(768,384) for _ in range (12)) 100 | self.linear2 = nn.ModuleList(nn.Linear(384,512) for _ in range (12)) 101 | self.adain=AdaIN() 102 | self.relu = nn.ReLU() 103 | self.gap=nn.AdaptiveAvgPool1d((768)) 104 | def forward(self, data): 105 | data_prompt=[] 106 | for i in range(len(data)): 107 | x_mu=self.adain.mu(data[i]).to(torch.float32) 108 | x_sigma=self.adain.sigma(data[i]).to(torch.float32) 109 | x_cat = torch.cat((x_mu, x_sigma),1) 110 | x_gap = self.gap(x_cat) 111 | x_out = self.linear1[i](x_gap) 112 | x_relu = self.relu(x_out) 113 | x_final = self.linear2[i](x_relu) 114 | data_prompt.append(x_final) 115 | output = torch.stack(data_prompt, dim=1) 116 | return output 117 | 118 | class TextEncoder(nn.Module): 119 | def __init__(self, clip_model): 120 | super().__init__() 121 | self.transformer = clip_model.transformer 122 | self.positional_embedding = clip_model.positional_embedding 123 | self.ln_final = clip_model.ln_final 124 | self.text_projection = clip_model.text_projection 125 | self.dtype = clip_model.dtype 126 | 127 | @autocast() 128 | def forward(self, prompts, tokenized_prompts): 129 | x = prompts + self.positional_embedding.type(self.dtype) 130 | x = x.permute(1, 0, 2) 131 | x = self.transformer(x) 132 | 133 | x = x[0].permute(1, 0, 2) 134 | x = self.ln_final(x).type(self.dtype) 135 | x = x[torch.arange(x.shape[0]), 136 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection 137 | 138 | return x 139 | 140 | 141 | class PromptLearner(nn.Module): 142 | def __init__(self, cfg, classnames, clip_model): 143 | super().__init__() 144 | n_cls = len(classnames) 145 | n_imgctx = 4 146 | n_ctx = 24 + n_imgctx 147 | 148 | dtype = clip_model.dtype 149 | clip_imsize = clip_model.visual.input_resolution 150 | cfg_imsize = cfg.INPUT.SIZE[0] 151 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 152 | 153 | self.domain_tokens = domain_projector() 154 | self.image_tokens = image_projector() 155 | self.style_mapping_tokens = style_mapping_projector() 156 | 157 | prompt_prefix = " ".join(["X"] * n_ctx) 158 | classnames = [name.replace("_", " ") for name in classnames] 159 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 160 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 161 | 162 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 163 | with torch.no_grad(): 164 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 165 | 166 | # These token vectors will be saved when in save_model(), 167 | # but they should be ignored in load_model() as we want to use 168 | # those computed using the current class names 169 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 170 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 171 | 172 | self.n_cls = n_cls 173 | self.n_ctx = n_ctx 174 | self.n_imgctx = n_imgctx 175 | self.tokenized_prompts = tokenized_prompts 176 | self.name_lens = name_lens 177 | 178 | def construct_prompts(self, ctx, prefix, suffix, label=None): 179 | # dim0 is either batch_size (during training) or n_cls (during testing) 180 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 181 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 182 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 183 | 184 | if label is not None: 185 | prefix = prefix[label] 186 | suffix = suffix[label] 187 | 188 | 189 | prompts = torch.cat( 190 | [ 191 | prefix, 192 | ctx, 193 | suffix, 194 | ], 195 | dim=1, 196 | ) 197 | 198 | return prompts 199 | @autocast() 200 | def forward(self, source_data, target_data): 201 | prefix = self.token_prefix 202 | suffix = self.token_suffix 203 | n_imgctx = self.n_imgctx 204 | 205 | source_domaintokens = self.domain_tokens(source_data) 206 | source_imagetokens = self.image_tokens(source_data, n_imgctx) 207 | source_style_mappingtokens = self.style_mapping_tokens(source_data) 208 | 209 | target_domaintokens = self.domain_tokens(target_data) 210 | target_imagetokens = self.image_tokens(target_data, n_imgctx) 211 | 212 | source_tokens = torch.cat((source_domaintokens, target_domaintokens, source_imagetokens), dim=1) 213 | target_tokens = torch.cat((source_domaintokens, target_domaintokens, target_imagetokens), dim=1) 214 | 215 | source_prompts = [] 216 | for tokens_i in source_tokens: 217 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1) 218 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) 219 | source_prompts.append(pts_i) 220 | source_prompts = torch.stack(source_prompts) 221 | 222 | target_prompts = [] 223 | for tokens_i in target_tokens: 224 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1) 225 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) 226 | target_prompts.append(pts_i) 227 | target_prompts = torch.stack(target_prompts) 228 | 229 | return source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens 230 | 231 | 232 | class CustomCLIP(nn.Module): 233 | def __init__(self, cfg, classnames, clip_model): 234 | super().__init__() 235 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 236 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 237 | self.image_encoder = clip_model.visual 238 | self.text_encoder = TextEncoder(clip_model) 239 | self.logit_scale = clip_model.logit_scale 240 | self.dtype = clip_model.dtype 241 | 242 | @autocast() 243 | def forward(self, s_image, t_image): 244 | source_image_features, source_data = self.image_encoder(s_image.type(self.dtype)) 245 | target_image_features, target_data = self.image_encoder(t_image.type(self.dtype)) 246 | 247 | source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens = self.prompt_learner(source_data, target_data) 248 | tokenized_prompts = 
self.tokenized_prompts 249 | 250 | source_image_features = source_image_features / source_image_features.norm(dim=-1, 251 | keepdim=True) 252 | target_image_features = target_image_features / target_image_features.norm(dim=-1, 253 | keepdim=True) 254 | logit_scale = self.logit_scale.exp() 255 | 256 | source_text_features = [] 257 | for pts_i in source_prompts: 258 | tf = self.text_encoder(pts_i, tokenized_prompts) 259 | source_text_features.append(tf) 260 | source_text_features=torch.stack(source_text_features) 261 | source_text_features = source_text_features / source_text_features.norm(dim=-1, keepdim=True) 262 | 263 | target_text_features = [] 264 | for pts_i in target_prompts: 265 | tf = self.text_encoder(pts_i, tokenized_prompts) 266 | target_text_features.append(tf) 267 | target_text_features=torch.stack(target_text_features) 268 | target_text_features = target_text_features / target_text_features.norm(dim=-1, keepdim=True) 269 | 270 | 271 | source_logits = [] 272 | 273 | for txt, im in zip(source_text_features, source_image_features): 274 | l_i = logit_scale * im @ txt.t() 275 | source_logits.append(l_i) 276 | source_logits = torch.stack(source_logits) 277 | 278 | target_logits = [] 279 | 280 | for txt, im in zip(target_text_features, target_image_features): 281 | l_i = logit_scale * im @ txt.t() 282 | target_logits.append(l_i) 283 | target_logits = torch.stack(target_logits) 284 | 285 | target_probs = torch.nn.functional.softmax(target_logits, dim=1) 286 | 287 | return source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features 288 | 289 | 290 | class entropy_loss(nn.Module): 291 | def __init__(self): 292 | super(entropy_loss, self).__init__() 293 | 294 | def forward(self, target_prob): 295 | full_enp = torch.zeros(target_prob.shape[0]) 296 | target_prob = nn.functional.normalize(target_prob, dim=0) 297 | 298 | for i in range(len(target_prob)): 299 | total_en = 0 300 | for j in range(target_prob.shape[1]): 301 | total_en = total_en - target_prob[i][j] * torch.log(target_prob[i][j] + 1e-8) 302 | full_enp[i] = total_en 303 | avg_full_enp = torch.mean(full_enp) 304 | return avg_full_enp 305 | 306 | 307 | @TRAINER_REGISTRY.register() 308 | class ADCLIPB16(TrainerXU): 309 | def check_cfg(self, cfg): 310 | assert cfg.TRAINER.ADCLIPB16.PREC in ["fp16", "fp32", "amp"] 311 | 312 | def build_model(self): 313 | cfg = self.cfg 314 | classnames = self.dm.dataset.classnames 315 | 316 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 317 | clip_model = load_clip_to_cpu(cfg) 318 | 319 | if cfg.TRAINER.ADCLIPB16.PREC == "fp32" or cfg.TRAINER.ADCLIPB16.PREC == "amp": 320 | # CLIP's default precision is fp16 321 | clip_model.float() 322 | 323 | print("Building custom CLIP") 324 | self.model = CustomCLIP(cfg, classnames, clip_model) 325 | 326 | self.n_cls = self.model.prompt_learner.n_cls 327 | 328 | name_to_update = "prompt_learner" 329 | 330 | for name, param in self.model.named_parameters(): 331 | if name_to_update not in name: 332 | param.requires_grad_(False) 333 | 334 | # Double check 335 | enabled = set() 336 | for name, param in self.model.named_parameters(): 337 | if param.requires_grad: 338 | enabled.add(name) 339 | print(f"Parameters to be updated: {enabled}") 340 | 341 | 342 | if cfg.MODEL.INIT_WEIGHTS: 343 | load_pretrained_weights(self.model.prompt_learner, 344 | cfg.MODEL.INIT_WEIGHTS) 345 | 346 | self.model.to(self.device) 347 | 348 | # transform the epoch to step schedule 349 | len_train_loader_x = 
len(self.train_loader_x) 350 | len_train_loader_u = len(self.train_loader_u) 351 | if self.cfg.TRAIN.COUNT_ITER == "train_x": 352 | self.num_batches = len_train_loader_x 353 | elif self.cfg.TRAIN.COUNT_ITER == "train_u": 354 | self.num_batches = len_train_loader_u 355 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": 356 | self.num_batches = min(len_train_loader_x, len_train_loader_u) 357 | else: 358 | raise ValueError 359 | 360 | # NOTE: only give prompt_learner to the optimizer 361 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 362 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 363 | ''' 364 | register model could be updated. When new module needs to be updated 365 | register the module before use 366 | ''' 367 | self.register_model("prompt_learner", self.model.prompt_learner, 368 | self.optim, self.sched) 369 | 370 | self.scaler = GradScaler() if cfg.TRAINER.ADCLIPB16.PREC == "amp" else None 371 | 372 | device_count = torch.cuda.device_count() 373 | if device_count > 1: 374 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 375 | self.model = nn.DataParallel(self.model) 376 | 377 | def save_model(self, epoch, directory, is_best=False, model_name=""): 378 | names = self.get_model_names() 379 | 380 | for name in names: 381 | model_dict = self._models[name].state_dict() 382 | 383 | optim_dict = None 384 | if self._optims[name] is not None: 385 | optim_dict = self._optims[name].state_dict() 386 | 387 | sched_dict = None 388 | if self._scheds[name] is not None: 389 | sched_dict = self._scheds[name].state_dict() 390 | 391 | save_checkpoint( 392 | { 393 | "state_dict": model_dict, 394 | "epoch": epoch + 1, 395 | "optimizer": optim_dict, 396 | "scheduler": sched_dict, 397 | }, 398 | osp.join(directory, name), 399 | is_best=is_best, 400 | model_name=model_name, 401 | ) 402 | 403 | def train(self): 404 | """Generic training loops.""" 405 | 406 | self.before_train() 407 | for self.epoch in range(self.start_epoch, self.max_epoch): 408 | self.before_epoch() 409 | self.run_epoch() 410 | self.after_epoch() 411 | self.after_train() 412 | 413 | def run_epoch(self): 414 | self.set_model_mode("train") 415 | losses = MetricMeter() 416 | batch_time = AverageMeter() 417 | data_time = AverageMeter() 418 | 419 | # Decide to iterate over labeled or unlabeled dataset 420 | len_train_loader_x = len(self.train_loader_x) 421 | len_train_loader_u = len(self.train_loader_u) 422 | if self.cfg.TRAIN.COUNT_ITER == "train_x": 423 | self.num_batches = len_train_loader_x 424 | elif self.cfg.TRAIN.COUNT_ITER == "train_u": 425 | self.num_batches = len_train_loader_u 426 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": 427 | self.num_batches = min(len_train_loader_x, len_train_loader_u) 428 | else: 429 | raise ValueError 430 | 431 | train_loader_x_iter = iter(self.train_loader_x) 432 | train_loader_u_iter = iter(self.train_loader_u) 433 | 434 | 435 | end = time.time() 436 | for self.batch_idx in range(self.num_batches): 437 | try: 438 | batch_x = next(train_loader_x_iter) 439 | except StopIteration: 440 | train_loader_x_iter = iter(self.train_loader_x) 441 | batch_x = next(train_loader_x_iter) 442 | 443 | try: 444 | batch_u = next(train_loader_u_iter) 445 | except StopIteration: 446 | train_loader_u_iter = iter(self.train_loader_u) 447 | batch_u = next(train_loader_u_iter) 448 | 449 | data_time.update(time.time() - end) 450 | loss_summary = self.forward_backward(batch_x, batch_u) 451 | batch_time.update(time.time() - end) 452 | losses.update(loss_summary) 453 | 454 
| if ( 455 | self.batch_idx + 1 456 | ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ: 457 | nb_remain = 0 458 | nb_remain += self.num_batches - self.batch_idx - 1 459 | nb_remain += (self.max_epoch - self.epoch - 460 | 1) * self.num_batches 461 | eta_seconds = batch_time.avg * nb_remain 462 | eta = str(datetime.timedelta(seconds=int(eta_seconds))) 463 | print("epoch [{0}/{1}][{2}/{3}]\t" 464 | "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 465 | "data {data_time.val:.3f} ({data_time.avg:.3f})\t" 466 | "eta {eta}\t" 467 | "{losses}\t" 468 | "lr {lr:.6e}".format( 469 | self.epoch + 1, 470 | self.max_epoch, 471 | self.batch_idx + 1, 472 | self.num_batches, 473 | batch_time=batch_time, 474 | data_time=data_time, 475 | eta=eta, 476 | losses=losses, 477 | lr=self.get_current_lr(), 478 | )) 479 | 480 | n_iter = self.epoch * self.num_batches + self.batch_idx 481 | for name, meter in losses.meters.items(): 482 | self.write_scalar("train/" + name, meter.avg, n_iter) 483 | self.write_scalar("train/lr", self.get_current_lr(), n_iter) 484 | 485 | end = time.time() 486 | 487 | def forward_backward(self, batch_x, batch_u): 488 | self.entropy = entropy_loss() 489 | kl_loss = nn.KLDivLoss(reduction="batchmean") 490 | image_x, label, image_u = self.parse_batch_train(batch_x, batch_u) 491 | prec = self.cfg.TRAINER.ADCLIPB16.PREC 492 | if prec == "amp": 493 | with autocast(): 494 | source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model(image_x, image_u) 495 | 496 | loss_ce = F.cross_entropy(source_logits, label) 497 | source_textfeat = F.log_softmax(source_text_features, dim=1) 498 | target_textfeat = F.softmax(target_text_features, dim=1) 499 | loss_kl = kl_loss(source_textfeat, target_textfeat) 500 | loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens) 501 | loss_entropy = self.entropy(target_probs) 502 | 503 | loss = loss_ce + 0.1*loss_smn + 0.01*loss_entropy + loss_kl 504 | 505 | self.optim.zero_grad() 506 | self.scaler.scale(loss).backward() 507 | self.scaler.step(self.optim) 508 | self.scaler.update() 509 | 510 | 511 | loss_summary = { 512 | "loss": 513 | loss.item(), 514 | "loss_ce": 515 | loss_ce.item(), 516 | "loss_smn": 517 | loss_smn.item(), 518 | "loss_entropy": 519 | loss_entropy.item(), 520 | "loss_kl": 521 | loss_kl.item(), 522 | "acc_x": 523 | compute_accuracy(source_logits[:, :self.n_cls], label)[0].item(), 524 | } 525 | 526 | self.update_lr() 527 | 528 | return loss_summary 529 | 530 | def after_epoch(self): 531 | last_epoch = (self.epoch + 1) == self.max_epoch 532 | do_test = not self.cfg.TEST.NO_TEST 533 | meet_checkpoint_freq = ((self.epoch + 1) % 534 | self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if 535 | self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False) 536 | 537 | if do_test: 538 | curr_result = self.test() 539 | is_best = curr_result > self.best_result 540 | if is_best: 541 | self.best_result = curr_result 542 | self.save_model(self.epoch, 543 | self.output_dir, 544 | model_name="model-best.pth.tar") 545 | 546 | self.set_model_mode("train") 547 | 548 | if meet_checkpoint_freq or last_epoch: 549 | self.save_model(self.epoch, self.output_dir) 550 | 551 | def parse_batch_train(self, batch_x, batch_u): 552 | input = batch_x["img"] 553 | label = batch_x["label"] 554 | input_u = batch_u["img"] 555 | input = input.to(self.device) 556 | label = label.to(self.device) 557 | input_u = input_u.to(self.device) 558 | return input, label, input_u 559 | 560 | def 
load_model(self, directory, epoch=None): 561 | if not directory: 562 | print( 563 | "Note that load_model() is skipped as no pretrained model is given" 564 | ) 565 | return 566 | 567 | names = self.get_model_names() 568 | 569 | # By default, the best model is loaded 570 | model_file = "model-best.pth.tar" 571 | 572 | if epoch is not None: 573 | model_file = "model.pth.tar-" + str(epoch) 574 | 575 | for name in names: 576 | model_path = osp.join(directory, name, model_file) 577 | 578 | if not osp.exists(model_path): 579 | raise FileNotFoundError( 580 | 'Model not found at "{}"'.format(model_path)) 581 | 582 | checkpoint = load_checkpoint(model_path) 583 | state_dict = checkpoint["state_dict"] 584 | epoch = checkpoint["epoch"] 585 | 586 | # Ignore fixed token vectors 587 | if "token_prefix" in state_dict: 588 | del state_dict["token_prefix"] 589 | 590 | if "token_suffix" in state_dict: 591 | del state_dict["token_suffix"] 592 | 593 | print("Loading weights to {} " 594 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 595 | # set strict=False 596 | self._models[name].load_state_dict(state_dict, strict=False) 597 | 598 | @torch.no_grad() 599 | def test(self, split=None): 600 | """A generic testing pipeline.""" 601 | self.set_model_mode("eval") 602 | self.evaluator.reset() 603 | 604 | if split is None: 605 | split = self.cfg.TEST.SPLIT 606 | 607 | split = "test" 608 | data_loader = self.test_loader 609 | print(f"Evaluate on the *{split}* set") 610 | 611 | 612 | for batch_idx, batch in enumerate(tqdm(data_loader)): 613 | input, label = self.parse_batch_test(batch) 614 | output, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model_inference(input) 615 | self.evaluator.process(output, label) 616 | 617 | results = self.evaluator.evaluate() 618 | 619 | for k, v in results.items(): 620 | tag = f"{split}/{k}" 621 | self.write_scalar(tag, v, self.epoch) 622 | 623 | return list(results.values())[0] 624 | -------------------------------------------------------------------------------- /trainers/adclip_vitL14.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | import datetime 4 | import time 5 | from collections import OrderedDict 6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from torch.cuda.amp import GradScaler, autocast 12 | from tqdm import tqdm 13 | 14 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 15 | from dassl.metrics import compute_accuracy 16 | from dassl.utils import MetricMeter, AverageMeter, load_pretrained_weights, load_checkpoint, save_checkpoint 17 | from dassl.optim import build_optimizer, build_lr_scheduler 18 | 19 | from clip import clip 20 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 21 | 22 | _tokenizer = _Tokenizer() 23 | 24 | 25 | def load_clip_to_cpu(cfg): 26 | backbone_name = cfg.MODEL.BACKBONE.NAME 27 | url = clip._MODELS[backbone_name] 28 | model_path = clip._download(url, cfg.MODEL.BACKBONE.PATH) 29 | 30 | 31 | try: 32 | model = torch.jit.load(model_path, map_location="cpu").eval() 33 | state_dict = None 34 | 35 | except RuntimeError: 36 | state_dict = torch.load(model_path, map_location="cpu") 37 | 38 | model = clip.build_model(state_dict or model.state_dict()) 39 | 40 | return model 41 | 42 | class AdaIN(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | def mu(self, x): 47 | return 
torch.sum(x,(1))/(x.shape[1]) 48 | 49 | def sigma(self, x): 50 | return torch.sqrt((torch.sum((x.permute([1,0,2])-self.mu(x)).permute([1,0,2])**2,(1))+0.000000023)/(x.shape[1])) 51 | 52 | 53 | class domain_projector(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | self.linear1 = nn.ModuleList(nn.Linear(1024,512) for _ in range (24)) 57 | self.linear2 = nn.ModuleList(nn.Linear(512,768) for _ in range (24)) 58 | self.adain=AdaIN() 59 | self.gap=nn.AdaptiveAvgPool2d((1,1024)) 60 | def forward(self, data): 61 | data_prompt=[] 62 | for i in range(len(data)): 63 | x_mu=self.adain.mu(data[i]).unsqueeze(1).to(torch.float32) 64 | x_sigma=self.adain.sigma(data[i]).unsqueeze(1).to(torch.float32) 65 | x_cat = torch.cat((x_mu, x_sigma),1) 66 | x_cat = self.gap(x_cat).squeeze(1) 67 | x_out = self.linear1[i](x_cat) 68 | x_final = self.linear2[i](x_out) 69 | data_prompt.append(x_final) 70 | output = torch.stack(data_prompt, dim=1) 71 | return output 72 | 73 | class image_projector(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | self.linear = nn.ModuleList(nn.Linear(1024,768) for _ in range (24)) 77 | self.adain=AdaIN() 78 | self.lin = nn.Linear(24,1) 79 | self.gap=nn.AdaptiveAvgPool2d((1,1024)) 80 | 81 | def forward(self, data, n_imgctx): 82 | data_prompt=[] 83 | for i in range(len(data)): 84 | x_gap = self.gap(data[i]).squeeze(1) 85 | x_lin=self.linear[i](x_gap) 86 | data_prompt.append(x_lin) 87 | feat = torch.stack(data_prompt, dim=1) 88 | output = [] 89 | for i in range(n_imgctx): # L decoders 90 | x = self.lin(feat.permute(0,2,1)) 91 | x = x.permute(0,2,1) 92 | output.append(x) 93 | feat_tokens = torch.stack(output, dim=1).squeeze(2) 94 | return feat_tokens 95 | 96 | class style_mapping_projector(nn.Module): 97 | def __init__(self): 98 | super().__init__() 99 | self.linear1 = nn.ModuleList(nn.Linear(1024,640) for _ in range (24)) 100 | self.linear2 = nn.ModuleList(nn.Linear(640,768) for _ in range (24)) 101 | self.adain=AdaIN() 102 | self.relu = nn.ReLU() 103 | self.gap=nn.AdaptiveAvgPool1d((1024)) 104 | def forward(self, data): 105 | data_prompt=[] 106 | for i in range(len(data)): 107 | x_mu=self.adain.mu(data[i]).to(torch.float32) 108 | x_sigma=self.adain.sigma(data[i]).to(torch.float32) 109 | x_cat = torch.cat((x_mu, x_sigma),1) 110 | x_gap = self.gap(x_cat) 111 | x_out = self.linear1[i](x_gap) 112 | x_relu = self.relu(x_out) 113 | x_final = self.linear2[i](x_relu) 114 | data_prompt.append(x_final) 115 | output = torch.stack(data_prompt, dim=1) 116 | return output 117 | 118 | class TextEncoder(nn.Module): 119 | def __init__(self, clip_model): 120 | super().__init__() 121 | self.transformer = clip_model.transformer 122 | self.positional_embedding = clip_model.positional_embedding 123 | self.ln_final = clip_model.ln_final 124 | self.text_projection = clip_model.text_projection 125 | self.dtype = clip_model.dtype 126 | 127 | @autocast() 128 | def forward(self, prompts, tokenized_prompts): 129 | x = prompts + self.positional_embedding.type(self.dtype) 130 | x = x.permute(1, 0, 2) # NLD -> LND 131 | x = self.transformer(x) 132 | 133 | x = x[0].permute(1, 0, 2) # LND -> NLD 134 | x = self.ln_final(x).type(self.dtype) 135 | x = x[torch.arange(x.shape[0]), 136 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection 137 | 138 | return x 139 | 140 | 141 | class PromptLearner(nn.Module): 142 | def __init__(self, cfg, classnames, clip_model): 143 | super().__init__() 144 | n_cls = len(classnames) 145 | n_imgctx = 4 146 | n_ctx = 48 + n_imgctx 147 | 148 | dtype = 
clip_model.dtype 149 | ctx_dim = clip_model.ln_final.weight.shape[0] 150 | vis_dim = clip_model.visual.output_dim 151 | clip_imsize = clip_model.visual.input_resolution 152 | cfg_imsize = cfg.INPUT.SIZE[0] 153 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 154 | 155 | self.domain_tokens = domain_projector() 156 | self.image_tokens = image_projector() 157 | self.style_mapping_tokens = style_mapping_projector() 158 | 159 | prompt_prefix = " ".join(["X"] * n_ctx) 160 | classnames = [name.replace("_", " ") for name in classnames] 161 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 162 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 163 | 164 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 165 | with torch.no_grad(): 166 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 167 | 168 | # These token vectors will be saved when in save_model(), 169 | # but they should be ignored in load_model() as we want to use 170 | # those computed using the current class names 171 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 172 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 173 | 174 | self.n_cls = n_cls 175 | self.n_ctx = n_ctx 176 | self.n_imgctx = n_imgctx 177 | self.tokenized_prompts = tokenized_prompts 178 | self.name_lens = name_lens 179 | 180 | def construct_prompts(self, ctx, prefix, suffix, label=None): 181 | # dim0 is either batch_size (during training) or n_cls (during testing) 182 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 183 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 184 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 185 | 186 | if label is not None: 187 | prefix = prefix[label] 188 | suffix = suffix[label] 189 | 190 | 191 | prompts = torch.cat( 192 | [ 193 | prefix, 194 | ctx, 195 | suffix, 196 | ], 197 | dim=1, 198 | ) 199 | 200 | return prompts 201 | @autocast() 202 | def forward(self, source_data, target_data): 203 | prefix = self.token_prefix 204 | suffix = self.token_suffix 205 | n_imgctx = self.n_imgctx 206 | 207 | source_domaintokens = self.domain_tokens(source_data) 208 | source_imagetokens = self.image_tokens(source_data, n_imgctx) 209 | source_style_mappingtokens = self.style_mapping_tokens(source_data) 210 | 211 | target_domaintokens = self.domain_tokens(target_data) 212 | target_imagetokens = self.image_tokens(target_data, n_imgctx) 213 | 214 | source_tokens = torch.cat((source_domaintokens, target_domaintokens, source_imagetokens), dim=1) 215 | target_tokens = torch.cat((source_domaintokens, target_domaintokens, target_imagetokens), dim=1) 216 | 217 | source_prompts = [] 218 | for tokens_i in source_tokens: 219 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1) 220 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) 221 | source_prompts.append(pts_i) 222 | source_prompts = torch.stack(source_prompts) 223 | 224 | target_prompts = [] 225 | for tokens_i in target_tokens: 226 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1) 227 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) 228 | target_prompts.append(pts_i) 229 | target_prompts = torch.stack(target_prompts) 230 | 231 | return source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens 232 | 233 | class CustomCLIP(nn.Module): 234 | def __init__(self, cfg, classnames, clip_model): 235 | super().__init__() 236 | self.prompt_learner = 
PromptLearner(cfg, classnames, clip_model) 237 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 238 | self.image_encoder = clip_model.visual 239 | self.text_encoder = TextEncoder(clip_model) 240 | self.logit_scale = clip_model.logit_scale 241 | self.dtype = clip_model.dtype 242 | 243 | @autocast() 244 | def forward(self, s_image, t_image): 245 | source_image_features, source_data = self.image_encoder(s_image.type(self.dtype)) 246 | target_image_features, target_data = self.image_encoder(t_image.type(self.dtype)) 247 | 248 | source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens = self.prompt_learner(source_data, target_data) 249 | tokenized_prompts = self.tokenized_prompts 250 | 251 | source_image_features = source_image_features / source_image_features.norm(dim=-1, 252 | keepdim=True) 253 | target_image_features = target_image_features / target_image_features.norm(dim=-1, 254 | keepdim=True) 255 | logit_scale = self.logit_scale.exp() 256 | 257 | source_text_features = [] 258 | for pts_i in source_prompts: 259 | tf = self.text_encoder(pts_i, tokenized_prompts) 260 | source_text_features.append(tf) 261 | source_text_features=torch.stack(source_text_features) 262 | source_text_features = source_text_features / source_text_features.norm(dim=-1, keepdim=True) 263 | 264 | target_text_features = [] 265 | for pts_i in target_prompts: 266 | tf = self.text_encoder(pts_i, tokenized_prompts) 267 | target_text_features.append(tf) 268 | target_text_features=torch.stack(target_text_features) 269 | target_text_features = target_text_features / target_text_features.norm(dim=-1, keepdim=True) 270 | 271 | 272 | source_logits = [] 273 | 274 | for txt, im in zip(source_text_features, source_image_features): 275 | l_i = logit_scale * im @ txt.t() 276 | source_logits.append(l_i) 277 | source_logits = torch.stack(source_logits) 278 | 279 | target_logits = [] 280 | 281 | for txt, im in zip(target_text_features, target_image_features): 282 | l_i = logit_scale * im @ txt.t() 283 | target_logits.append(l_i) 284 | target_logits = torch.stack(target_logits) 285 | 286 | target_probs = torch.nn.functional.softmax(target_logits, dim=1) 287 | 288 | return source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features 289 | 290 | 291 | class entropy_loss(nn.Module): 292 | def __init__(self): 293 | super(entropy_loss, self).__init__() 294 | 295 | def forward(self, target_prob): 296 | full_enp = torch.zeros(target_prob.shape[0]) 297 | target_prob = nn.functional.normalize(target_prob, dim=0) 298 | 299 | for i in range(len(target_prob)): 300 | total_en = 0 301 | for j in range(target_prob.shape[1]): 302 | total_en = total_en - target_prob[i][j] * torch.log(target_prob[i][j] + 1e-8) 303 | full_enp[i] = total_en 304 | avg_full_enp = torch.mean(full_enp) 305 | return avg_full_enp 306 | 307 | 308 | @TRAINER_REGISTRY.register() 309 | class ADCLIPL14(TrainerXU): 310 | def check_cfg(self, cfg): 311 | assert cfg.TRAINER.ADCLIPL14.PREC in ["fp16", "fp32", "amp"] 312 | 313 | def build_model(self): 314 | cfg = self.cfg 315 | classnames = self.dm.dataset.classnames 316 | 317 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 318 | clip_model = load_clip_to_cpu(cfg) 319 | 320 | if cfg.TRAINER.ADCLIPL14.PREC == "fp32" or cfg.TRAINER.ADCLIPL14.PREC == "amp": 321 | # CLIP's default precision is fp16 322 | clip_model.float() 323 | 324 | print("Building custom CLIP") 325 | self.model = CustomCLIP(cfg, classnames, clip_model) 326 | 327 
| self.n_cls = self.model.prompt_learner.n_cls 328 | 329 | name_to_update = "prompt_learner" 330 | 331 | for name, param in self.model.named_parameters(): 332 | if name_to_update not in name: 333 | param.requires_grad_(False) 334 | 335 | enabled = set() 336 | for name, param in self.model.named_parameters(): 337 | if param.requires_grad: 338 | enabled.add(name) 339 | print(f"Parameters to be updated: {enabled}") 340 | 341 | if cfg.MODEL.INIT_WEIGHTS: 342 | load_pretrained_weights(self.model.prompt_learner, 343 | cfg.MODEL.INIT_WEIGHTS) 344 | 345 | self.model.to(self.device) 346 | 347 | # transform the epoch to step schedule 348 | len_train_loader_x = len(self.train_loader_x) 349 | len_train_loader_u = len(self.train_loader_u) 350 | if self.cfg.TRAIN.COUNT_ITER == "train_x": 351 | self.num_batches = len_train_loader_x 352 | elif self.cfg.TRAIN.COUNT_ITER == "train_u": 353 | self.num_batches = len_train_loader_u 354 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": 355 | self.num_batches = min(len_train_loader_x, len_train_loader_u) 356 | else: 357 | raise ValueError 358 | 359 | # NOTE: only give prompt_learner to the optimizer 360 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 361 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 362 | ''' 363 | register model could be updated. When new module needs to be updated 364 | register the module before use 365 | ''' 366 | self.register_model("prompt_learner", self.model.prompt_learner, 367 | self.optim, self.sched) 368 | 369 | self.scaler = GradScaler() if cfg.TRAINER.ADCLIPL14.PREC == "amp" else None 370 | 371 | device_count = torch.cuda.device_count() 372 | if device_count > 1: 373 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 374 | self.model = nn.DataParallel(self.model) 375 | 376 | def save_model(self, epoch, directory, is_best=False, model_name=""): 377 | names = self.get_model_names() 378 | 379 | for name in names: 380 | model_dict = self._models[name].state_dict() 381 | 382 | optim_dict = None 383 | if self._optims[name] is not None: 384 | optim_dict = self._optims[name].state_dict() 385 | 386 | sched_dict = None 387 | if self._scheds[name] is not None: 388 | sched_dict = self._scheds[name].state_dict() 389 | 390 | save_checkpoint( 391 | { 392 | "state_dict": model_dict, 393 | "epoch": epoch + 1, 394 | "optimizer": optim_dict, 395 | "scheduler": sched_dict, 396 | }, 397 | osp.join(directory, name), 398 | is_best=is_best, 399 | model_name=model_name, 400 | ) 401 | 402 | def train(self): 403 | """Generic training loops.""" 404 | 405 | self.before_train() 406 | for self.epoch in range(self.start_epoch, self.max_epoch): 407 | self.before_epoch() 408 | self.run_epoch() 409 | self.after_epoch() 410 | self.after_train() 411 | 412 | def run_epoch(self): 413 | self.set_model_mode("train") 414 | losses = MetricMeter() 415 | batch_time = AverageMeter() 416 | data_time = AverageMeter() 417 | 418 | # Decide to iterate over labeled or unlabeled dataset 419 | len_train_loader_x = len(self.train_loader_x) 420 | len_train_loader_u = len(self.train_loader_u) 421 | if self.cfg.TRAIN.COUNT_ITER == "train_x": 422 | self.num_batches = len_train_loader_x 423 | elif self.cfg.TRAIN.COUNT_ITER == "train_u": 424 | self.num_batches = len_train_loader_u 425 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": 426 | self.num_batches = min(len_train_loader_x, len_train_loader_u) 427 | else: 428 | raise ValueError 429 | 430 | train_loader_x_iter = iter(self.train_loader_x) 431 | train_loader_u_iter = 
iter(self.train_loader_u) 432 | 433 | 434 | end = time.time() 435 | for self.batch_idx in range(self.num_batches): 436 | try: 437 | batch_x = next(train_loader_x_iter) 438 | except StopIteration: 439 | train_loader_x_iter = iter(self.train_loader_x) 440 | batch_x = next(train_loader_x_iter) 441 | 442 | try: 443 | batch_u = next(train_loader_u_iter) 444 | except StopIteration: 445 | train_loader_u_iter = iter(self.train_loader_u) 446 | batch_u = next(train_loader_u_iter) 447 | 448 | data_time.update(time.time() - end) 449 | loss_summary = self.forward_backward(batch_x, batch_u) 450 | batch_time.update(time.time() - end) 451 | losses.update(loss_summary) 452 | 453 | if ( 454 | self.batch_idx + 1 455 | ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ: 456 | nb_remain = 0 457 | nb_remain += self.num_batches - self.batch_idx - 1 458 | nb_remain += (self.max_epoch - self.epoch - 459 | 1) * self.num_batches 460 | eta_seconds = batch_time.avg * nb_remain 461 | eta = str(datetime.timedelta(seconds=int(eta_seconds))) 462 | print("epoch [{0}/{1}][{2}/{3}]\t" 463 | "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 464 | "data {data_time.val:.3f} ({data_time.avg:.3f})\t" 465 | "eta {eta}\t" 466 | "{losses}\t" 467 | "lr {lr:.6e}".format( 468 | self.epoch + 1, 469 | self.max_epoch, 470 | self.batch_idx + 1, 471 | self.num_batches, 472 | batch_time=batch_time, 473 | data_time=data_time, 474 | eta=eta, 475 | losses=losses, 476 | lr=self.get_current_lr(), 477 | )) 478 | 479 | n_iter = self.epoch * self.num_batches + self.batch_idx 480 | for name, meter in losses.meters.items(): 481 | self.write_scalar("train/" + name, meter.avg, n_iter) 482 | self.write_scalar("train/lr", self.get_current_lr(), n_iter) 483 | 484 | end = time.time() 485 | 486 | def forward_backward(self, batch_x, batch_u): 487 | self.entropy = entropy_loss() 488 | kl_loss = nn.KLDivLoss(reduction="batchmean") 489 | image_x, label, image_u = self.parse_batch_train(batch_x, batch_u) 490 | prec = self.cfg.TRAINER.ADCLIPL14.PREC 491 | # alpha_wt = self.alpha 492 | if prec == "amp": 493 | with autocast(): 494 | source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model(image_x, image_u) 495 | 496 | loss_ce = F.cross_entropy(source_logits, label) 497 | source_textfeat = F.log_softmax(source_text_features, dim=1) 498 | target_textfeat = F.softmax(target_text_features, dim=1) 499 | loss_kl = kl_loss(source_textfeat, target_textfeat) 500 | loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens) 501 | loss_entropy = self.entropy(target_probs) 502 | 503 | loss = loss_ce + 0.1*loss_smn + 0.01*loss_entropy + loss_kl 504 | 505 | self.optim.zero_grad() 506 | self.scaler.scale(loss).backward() 507 | self.scaler.step(self.optim) 508 | self.scaler.update() 509 | 510 | 511 | loss_summary = { 512 | "loss": 513 | loss.item(), 514 | "loss_ce": 515 | loss_ce.item(), 516 | "loss_smn": 517 | loss_smn.item(), 518 | "loss_entropy": 519 | loss_entropy.item(), 520 | "loss_kl": 521 | loss_kl.item(), 522 | "acc_x": 523 | compute_accuracy(source_logits[:, :self.n_cls], label)[0].item(), 524 | } 525 | 526 | self.update_lr() 527 | 528 | return loss_summary 529 | 530 | def after_epoch(self): 531 | last_epoch = (self.epoch + 1) == self.max_epoch 532 | do_test = not self.cfg.TEST.NO_TEST 533 | meet_checkpoint_freq = ((self.epoch + 1) % 534 | self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if 535 | self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False) 536 | 537 | if 
do_test: 538 | curr_result = self.test() 539 | is_best = curr_result > self.best_result 540 | if is_best: 541 | self.best_result = curr_result 542 | self.save_model(self.epoch, 543 | self.output_dir, 544 | model_name="model-best.pth.tar") 545 | 546 | self.set_model_mode("train") 547 | 548 | if meet_checkpoint_freq or last_epoch: 549 | self.save_model(self.epoch, self.output_dir) 550 | 551 | def parse_batch_train(self, batch_x, batch_u): 552 | input = batch_x["img"] 553 | label = batch_x["label"] 554 | input_u = batch_u["img"] 555 | input = input.to(self.device) 556 | label = label.to(self.device) 557 | input_u = input_u.to(self.device) 558 | return input, label, input_u 559 | 560 | def load_model(self, directory, epoch=None): 561 | if not directory: 562 | print( 563 | "Note that load_model() is skipped as no pretrained model is given" 564 | ) 565 | return 566 | 567 | names = self.get_model_names() 568 | 569 | # By default, the best model is loaded 570 | model_file = "model-best.pth.tar" 571 | 572 | if epoch is not None: 573 | model_file = "model.pth.tar-" + str(epoch) 574 | 575 | for name in names: 576 | model_path = osp.join(directory, name, model_file) 577 | 578 | if not osp.exists(model_path): 579 | raise FileNotFoundError( 580 | 'Model not found at "{}"'.format(model_path)) 581 | 582 | checkpoint = load_checkpoint(model_path) 583 | state_dict = checkpoint["state_dict"] 584 | epoch = checkpoint["epoch"] 585 | 586 | # Ignore fixed token vectors 587 | if "token_prefix" in state_dict: 588 | del state_dict["token_prefix"] 589 | 590 | if "token_suffix" in state_dict: 591 | del state_dict["token_suffix"] 592 | 593 | print("Loading weights to {} " 594 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 595 | # set strict=False 596 | self._models[name].load_state_dict(state_dict, strict=False) 597 | 598 | @torch.no_grad() 599 | def test(self, split=None): 600 | """A generic testing pipeline.""" 601 | self.set_model_mode("eval") 602 | self.evaluator.reset() 603 | 604 | if split is None: 605 | split = self.cfg.TEST.SPLIT 606 | 607 | split = "test" 608 | data_loader = self.test_loader 609 | print(f"Evaluate on the *{split}* set") 610 | 611 | 612 | for batch_idx, batch in enumerate(tqdm(data_loader)): 613 | input, label = self.parse_batch_test(batch) 614 | output, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model_inference(input) 615 | self.evaluator.process(output, label) 616 | 617 | results = self.evaluator.evaluate() 618 | 619 | for k, v in results.items(): 620 | tag = f"{split}/{k}" 621 | self.write_scalar(tag, v, self.epoch) 622 | 623 | return list(results.values())[0] 624 | -------------------------------------------------------------------------------- /trainers/adclip_rn50.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | import datetime 4 | import time 5 | from collections import OrderedDict 6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from torch.cuda.amp import GradScaler, autocast 12 | from tqdm import tqdm 13 | 14 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 15 | from dassl.metrics import compute_accuracy 16 | from dassl.utils import MetricMeter, AverageMeter, load_pretrained_weights, load_checkpoint, save_checkpoint 17 | from dassl.optim import build_optimizer, build_lr_scheduler 18 | 19 | from clip import clip 20 | from 
clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 21 | 22 | _tokenizer = _Tokenizer() 23 | 24 | device_cuda = "cuda" 25 | 26 | 27 | def load_clip_to_cpu(cfg): 28 | backbone_name = cfg.MODEL.BACKBONE.NAME 29 | url = clip._MODELS[backbone_name] 30 | model_path = clip._download(url, cfg.MODEL.BACKBONE.PATH) 31 | 32 | 33 | try: 34 | model = torch.jit.load(model_path, map_location="cpu").eval() 35 | state_dict = None 36 | 37 | except RuntimeError: 38 | state_dict = torch.load(model_path, map_location="cpu") 39 | 40 | model = clip.build_model(state_dict or model.state_dict()) 41 | 42 | return model 43 | 44 | class AdaIN(nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | def mu(self, x): 48 | return torch.mean(x, dim=(2, 3)) 49 | 50 | def sigma(self, x): 51 | mean = torch.mean(x, dim=(2, 3), keepdim=True) 52 | squared_diff = (x - mean) ** 2 53 | sum_squared_diff = torch.sum(squared_diff, dim=(2, 3)) 54 | epsilon = 1e-8 55 | std_dev = torch.sqrt((sum_squared_diff + epsilon) / (x.shape[2] * x.shape[3])) 56 | return std_dev 57 | 58 | class domain_projector(nn.Module): 59 | def __init__(self): 60 | super().__init__() 61 | self.linear1 = [] 62 | self.linear1.append(nn.Linear(256,256).to(device_cuda)) 63 | self.linear1.append(nn.Linear(512,256).to(device_cuda)) 64 | self.linear1.append(nn.Linear(1024,256).to(device_cuda)) 65 | self.linear1.append(nn.Linear(2048,256).to(device_cuda)) 66 | self.adain=AdaIN() 67 | self.gap = [] 68 | self.gap.append(nn.AdaptiveAvgPool2d((1,256))) 69 | self.gap.append(nn.AdaptiveAvgPool2d((1,512))) 70 | self.gap.append(nn.AdaptiveAvgPool2d((1,1024))) 71 | self.gap.append(nn.AdaptiveAvgPool2d((1,2048))) 72 | self.linear2 = nn.ModuleList(nn.Linear(256,512) for _ in range (4)) 73 | def forward(self, data): 74 | data_prompt=[] 75 | for i in range(len(data)): 76 | x_mu=self.adain.mu(data[i]).unsqueeze(1).to(torch.float32) 77 | x_sigma=self.adain.sigma(data[i]).unsqueeze(1).to(torch.float32) 78 | x_cat = torch.cat((x_mu, x_sigma),1) 79 | x_cat = self.gap[i](x_cat).squeeze(1) 80 | x_out = self.linear1[i](x_cat) 81 | x_final = self.linear2[i](x_out) 82 | data_prompt.append(x_final) 83 | output = torch.stack(data_prompt, dim=1) 84 | return output 85 | 86 | class image_projector(nn.Module): 87 | def __init__(self): 88 | super().__init__() 89 | self.linear = [] 90 | self.linear.append(nn.Linear(256,512).to(device_cuda)) 91 | self.linear.append(nn.Linear(512,512).to(device_cuda)) 92 | self.linear.append(nn.Linear(1024,512).to(device_cuda)) 93 | self.linear.append(nn.Linear(2048,512).to(device_cuda)) 94 | self.adain=AdaIN() 95 | self.lin = nn.Linear(4,1) 96 | self.gap=nn.AdaptiveAvgPool2d((1,1)) 97 | 98 | def forward(self, data, n_imgctx): 99 | data_prompt=[] 100 | for i in range(len(data)): 101 | x_gap = self.gap(data[i]).squeeze(3).squeeze(2) 102 | x_lin=self.linear[i](x_gap) 103 | data_prompt.append(x_lin) 104 | feat = torch.stack(data_prompt, dim=1) 105 | output = [] 106 | for i in range(n_imgctx): # L decoders 107 | x = self.lin(feat.permute(0,2,1)) 108 | x = x.permute(0,2,1) 109 | output.append(x) 110 | feat_tokens = torch.stack(output, dim=1).squeeze(2) 111 | return feat_tokens 112 | 113 | class style_mapping_projector(nn.Module): 114 | def __init__(self): 115 | super().__init__() 116 | self.linear1 = [] 117 | self.linear1.append(nn.Linear(256,384).to(device_cuda)) 118 | self.linear1.append(nn.Linear(512,384).to(device_cuda)) 119 | self.linear1.append(nn.Linear(1024,384).to(device_cuda)) 120 | self.linear1.append(nn.Linear(2048,384).to(device_cuda)) 121 | 
self.adain=AdaIN() 122 | self.relu = nn.ReLU() 123 | self.gap = [] 124 | self.gap.append(nn.AdaptiveAvgPool1d((256))) 125 | self.gap.append(nn.AdaptiveAvgPool1d((512))) 126 | self.gap.append(nn.AdaptiveAvgPool1d((1024))) 127 | self.gap.append(nn.AdaptiveAvgPool1d((2048))) 128 | self.linear2 = nn.ModuleList(nn.Linear(384,512) for _ in range (4)) 129 | def forward(self, data): 130 | data_prompt=[] 131 | for i in range(len(data)): 132 | x_mu=self.adain.mu(data[i]).to(torch.float32) 133 | x_sigma=self.adain.sigma(data[i]).to(torch.float32) 134 | x_cat = torch.cat((x_mu, x_sigma),1) 135 | x_gap = self.gap[i](x_cat) 136 | x_out = self.linear1[i](x_gap) 137 | x_relu = self.relu(x_out) 138 | x_final = self.linear2[i](x_relu) 139 | data_prompt.append(x_final) 140 | output = torch.stack(data_prompt, dim=1) 141 | return output 142 | 143 | class TextEncoder(nn.Module): 144 | def __init__(self, clip_model): 145 | super().__init__() 146 | self.transformer = clip_model.transformer 147 | self.positional_embedding = clip_model.positional_embedding 148 | self.ln_final = clip_model.ln_final 149 | self.text_projection = clip_model.text_projection 150 | self.dtype = clip_model.dtype 151 | 152 | @autocast() 153 | def forward(self, prompts, tokenized_prompts): 154 | x = prompts + self.positional_embedding.type(self.dtype) 155 | x = x.permute(1, 0, 2) # NLD -> LND 156 | x = self.transformer(x) 157 | 158 | x = x[0].permute(1, 0, 2) # LND -> NLD 159 | x = self.ln_final(x).type(self.dtype) 160 | x = x[torch.arange(x.shape[0]), 161 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection 162 | 163 | return x 164 | 165 | 166 | class PromptLearner(nn.Module): 167 | def __init__(self, cfg, classnames, clip_model): 168 | super().__init__() 169 | n_cls = len(classnames) 170 | n_imgctx = 4 171 | n_ctx = 8 + n_imgctx 172 | 173 | dtype = clip_model.dtype 174 | ctx_dim = clip_model.ln_final.weight.shape[0] 175 | vis_dim = clip_model.visual.output_dim 176 | clip_imsize = clip_model.visual.input_resolution 177 | cfg_imsize = cfg.INPUT.SIZE[0] 178 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 179 | 180 | self.domain_tokens = domain_projector() 181 | self.image_tokens = image_projector() 182 | self.style_mapping_tokens = style_mapping_projector() 183 | 184 | prompt_prefix = " ".join(["X"] * n_ctx) 185 | classnames = [name.replace("_", " ") for name in classnames] 186 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 187 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 188 | 189 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 190 | with torch.no_grad(): 191 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 192 | 193 | # These token vectors will be saved when in save_model(), 194 | # but they should be ignored in load_model() as we want to use 195 | # those computed using the current class names 196 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 197 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 198 | 199 | self.n_cls = n_cls 200 | self.n_ctx = n_ctx 201 | self.n_imgctx = n_imgctx 202 | self.tokenized_prompts = tokenized_prompts 203 | self.name_lens = name_lens 204 | 205 | def construct_prompts(self, ctx, prefix, suffix, label=None): 206 | # dim0 is either batch_size (during training) or n_cls (during testing) 207 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 208 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 209 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 210 | 211 | if label is not None: 212 | prefix = prefix[label] 213 | suffix = suffix[label] 214 | 215 | 216 | prompts = torch.cat( 217 | [ 218 | prefix, 219 | ctx, 220 | suffix, 221 | ], 222 | dim=1, 223 | ) 224 | 225 | return prompts 226 | @autocast() 227 | def forward(self, source_data, target_data): 228 | prefix = self.token_prefix 229 | suffix = self.token_suffix 230 | n_imgctx = self.n_imgctx 231 | 232 | source_domaintokens = self.domain_tokens(source_data) 233 | source_imagetokens = self.image_tokens(source_data, n_imgctx) 234 | source_style_mappingtokens = self.style_mapping_tokens(source_data) 235 | 236 | target_domaintokens = self.domain_tokens(target_data) 237 | target_imagetokens = self.image_tokens(target_data, n_imgctx) 238 | 239 | source_tokens = torch.cat((source_domaintokens, target_domaintokens, source_imagetokens), dim=1) 240 | target_tokens = torch.cat((source_domaintokens, target_domaintokens, target_imagetokens), dim=1) 241 | 242 | source_prompts = [] 243 | for tokens_i in source_tokens: 244 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1) 245 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) 246 | source_prompts.append(pts_i) 247 | source_prompts = torch.stack(source_prompts) 248 | 249 | target_prompts = [] 250 | for tokens_i in target_tokens: 251 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1) 252 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) 253 | target_prompts.append(pts_i) 254 | target_prompts = torch.stack(target_prompts) 255 | 256 | return source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens 257 | 258 | class CustomCLIP(nn.Module): 259 | def __init__(self, cfg, classnames, clip_model): 260 | super().__init__() 261 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 262 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 263 | self.image_encoder = clip_model.visual 264 | self.text_encoder = TextEncoder(clip_model) 265 | self.logit_scale = clip_model.logit_scale 266 | self.dtype = clip_model.dtype 267 | 268 | @autocast() 269 | def forward(self, s_image, t_image): 270 | source_image_features, source_data = self.image_encoder(s_image.type(self.dtype)) 271 | target_image_features, target_data = self.image_encoder(t_image.type(self.dtype)) 272 | 273 | source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens = self.prompt_learner(source_data, target_data) 274 | tokenized_prompts = self.tokenized_prompts 
275 | 276 | source_image_features = source_image_features / source_image_features.norm(dim=-1, 277 | keepdim=True) 278 | target_image_features = target_image_features / target_image_features.norm(dim=-1, 279 | keepdim=True) 280 | logit_scale = self.logit_scale.exp() 281 | 282 | source_text_features = [] 283 | for pts_i in source_prompts: 284 | tf = self.text_encoder(pts_i, tokenized_prompts) 285 | source_text_features.append(tf) 286 | source_text_features=torch.stack(source_text_features) 287 | source_text_features = source_text_features / source_text_features.norm(dim=-1, keepdim=True) 288 | 289 | target_text_features = [] 290 | for pts_i in target_prompts: 291 | tf = self.text_encoder(pts_i, tokenized_prompts) 292 | target_text_features.append(tf) 293 | target_text_features=torch.stack(target_text_features) 294 | target_text_features = target_text_features / target_text_features.norm(dim=-1, keepdim=True) 295 | 296 | 297 | source_logits = [] 298 | 299 | for txt, im in zip(source_text_features, source_image_features): 300 | l_i = logit_scale * im @ txt.t() 301 | source_logits.append(l_i) 302 | source_logits = torch.stack(source_logits) 303 | 304 | target_logits = [] 305 | 306 | for txt, im in zip(target_text_features, target_image_features): 307 | l_i = logit_scale * im @ txt.t() 308 | target_logits.append(l_i) 309 | target_logits = torch.stack(target_logits) 310 | 311 | target_probs = torch.nn.functional.softmax(target_logits, dim=1) 312 | 313 | return source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features 314 | 315 | 316 | class entropy_loss(nn.Module): 317 | def __init__(self): 318 | super(entropy_loss, self).__init__() 319 | 320 | def forward(self, target_prob): 321 | full_enp = torch.zeros(target_prob.shape[0]) 322 | target_prob = nn.functional.normalize(target_prob, dim=0) 323 | 324 | for i in range(len(target_prob)): 325 | total_en = 0 326 | for j in range(target_prob.shape[1]): 327 | total_en = total_en - target_prob[i][j] * torch.log(target_prob[i][j] + 1e-8) 328 | full_enp[i] = total_en 329 | avg_full_enp = torch.mean(full_enp) 330 | return avg_full_enp 331 | 332 | 333 | @TRAINER_REGISTRY.register() 334 | class ADCLIPRN50(TrainerXU): 335 | def check_cfg(self, cfg): 336 | assert cfg.TRAINER.ADCLIPRN50.PREC in ["fp16", "fp32", "amp"] 337 | 338 | def build_model(self): 339 | cfg = self.cfg 340 | classnames = self.dm.dataset.classnames 341 | 342 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 343 | clip_model = load_clip_to_cpu(cfg) 344 | 345 | if cfg.TRAINER.ADCLIPRN50.PREC == "fp32" or cfg.TRAINER.ADCLIPRN50.PREC == "amp": 346 | # CLIP's default precision is fp16 347 | clip_model.float() 348 | 349 | print("Building custom CLIP") 350 | self.model = CustomCLIP(cfg, classnames, clip_model) 351 | 352 | self.n_cls = self.model.prompt_learner.n_cls 353 | 354 | name_to_update = "prompt_learner" 355 | 356 | for name, param in self.model.named_parameters(): 357 | if name_to_update not in name: 358 | param.requires_grad_(False) 359 | 360 | enabled = set() 361 | for name, param in self.model.named_parameters(): 362 | if param.requires_grad: 363 | enabled.add(name) 364 | print(f"Parameters to be updated: {enabled}") 365 | 366 | if cfg.MODEL.INIT_WEIGHTS: 367 | load_pretrained_weights(self.model.prompt_learner, 368 | cfg.MODEL.INIT_WEIGHTS) 369 | 370 | self.model.to(self.device) 371 | 372 | # transform the epoch to step schedule 373 | len_train_loader_x = len(self.train_loader_x) 374 | len_train_loader_u = 
len(self.train_loader_u) 375 | if self.cfg.TRAIN.COUNT_ITER == "train_x": 376 | self.num_batches = len_train_loader_x 377 | elif self.cfg.TRAIN.COUNT_ITER == "train_u": 378 | self.num_batches = len_train_loader_u 379 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": 380 | self.num_batches = min(len_train_loader_x, len_train_loader_u) 381 | else: 382 | raise ValueError 383 | 384 | # NOTE: only give prompt_learner to the optimizer 385 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 386 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 387 | ''' 388 | register model could be updated. When new module needs to be updated 389 | register the module before use 390 | ''' 391 | self.register_model("prompt_learner", self.model.prompt_learner, 392 | self.optim, self.sched) 393 | 394 | self.scaler = GradScaler() if cfg.TRAINER.ADCLIPRN50.PREC == "amp" else None 395 | 396 | device_count = torch.cuda.device_count() 397 | if device_count > 1: 398 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 399 | self.model = nn.DataParallel(self.model) 400 | 401 | def save_model(self, epoch, directory, is_best=False, model_name=""): 402 | names = self.get_model_names() 403 | 404 | for name in names: 405 | model_dict = self._models[name].state_dict() 406 | 407 | optim_dict = None 408 | if self._optims[name] is not None: 409 | optim_dict = self._optims[name].state_dict() 410 | 411 | sched_dict = None 412 | if self._scheds[name] is not None: 413 | sched_dict = self._scheds[name].state_dict() 414 | 415 | save_checkpoint( 416 | { 417 | "state_dict": model_dict, 418 | "epoch": epoch + 1, 419 | "optimizer": optim_dict, 420 | "scheduler": sched_dict, 421 | }, 422 | osp.join(directory, name), 423 | is_best=is_best, 424 | model_name=model_name, 425 | ) 426 | 427 | def train(self): 428 | """Generic training loops.""" 429 | 430 | self.before_train() 431 | for self.epoch in range(self.start_epoch, self.max_epoch): 432 | self.before_epoch() 433 | self.run_epoch() 434 | self.after_epoch() 435 | self.after_train() 436 | 437 | def run_epoch(self): 438 | self.set_model_mode("train") 439 | losses = MetricMeter() 440 | batch_time = AverageMeter() 441 | data_time = AverageMeter() 442 | 443 | # Decide to iterate over labeled or unlabeled dataset 444 | len_train_loader_x = len(self.train_loader_x) 445 | len_train_loader_u = len(self.train_loader_u) 446 | if self.cfg.TRAIN.COUNT_ITER == "train_x": 447 | self.num_batches = len_train_loader_x 448 | elif self.cfg.TRAIN.COUNT_ITER == "train_u": 449 | self.num_batches = len_train_loader_u 450 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": 451 | self.num_batches = min(len_train_loader_x, len_train_loader_u) 452 | else: 453 | raise ValueError 454 | 455 | train_loader_x_iter = iter(self.train_loader_x) 456 | train_loader_u_iter = iter(self.train_loader_u) 457 | 458 | 459 | end = time.time() 460 | for self.batch_idx in range(self.num_batches): 461 | try: 462 | batch_x = next(train_loader_x_iter) 463 | except StopIteration: 464 | train_loader_x_iter = iter(self.train_loader_x) 465 | batch_x = next(train_loader_x_iter) 466 | 467 | try: 468 | batch_u = next(train_loader_u_iter) 469 | except StopIteration: 470 | train_loader_u_iter = iter(self.train_loader_u) 471 | batch_u = next(train_loader_u_iter) 472 | 473 | data_time.update(time.time() - end) 474 | loss_summary = self.forward_backward(batch_x, batch_u) 475 | batch_time.update(time.time() - end) 476 | losses.update(loss_summary) 477 | 478 | if ( 479 | self.batch_idx + 1 480 | ) % 
481 |                 nb_remain = 0
482 |                 nb_remain += self.num_batches - self.batch_idx - 1
483 |                 nb_remain += (self.max_epoch - self.epoch -
484 |                               1) * self.num_batches
485 |                 eta_seconds = batch_time.avg * nb_remain
486 |                 eta = str(datetime.timedelta(seconds=int(eta_seconds)))
487 |                 print("epoch [{0}/{1}][{2}/{3}]\t"
488 |                       "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
489 |                       "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
490 |                       "eta {eta}\t"
491 |                       "{losses}\t"
492 |                       "lr {lr:.6e}".format(
493 |                           self.epoch + 1,
494 |                           self.max_epoch,
495 |                           self.batch_idx + 1,
496 |                           self.num_batches,
497 |                           batch_time=batch_time,
498 |                           data_time=data_time,
499 |                           eta=eta,
500 |                           losses=losses,
501 |                           lr=self.get_current_lr(),
502 |                       ))
503 | 
504 |             n_iter = self.epoch * self.num_batches + self.batch_idx
505 |             for name, meter in losses.meters.items():
506 |                 self.write_scalar("train/" + name, meter.avg, n_iter)
507 |             self.write_scalar("train/lr", self.get_current_lr(), n_iter)
508 | 
509 |             end = time.time()
510 | 
511 |     def forward_backward(self, batch_x, batch_u):
512 |         self.entropy = entropy_loss()
513 |         kl_loss = nn.KLDivLoss(reduction="batchmean")
514 |         image_x, label, image_u = self.parse_batch_train(batch_x, batch_u)
515 |         prec = self.cfg.TRAINER.ADCLIPRN50.PREC
516 |         # alpha_wt = self.alpha
517 |         if prec == "amp":
518 |             with autocast():
519 |                 source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model(image_x, image_u)
520 | 
521 |                 loss_ce = F.cross_entropy(source_logits, label)
522 |                 source_textfeat = F.log_softmax(source_text_features, dim=1)
523 |                 target_textfeat = F.softmax(target_text_features, dim=1)
524 |                 loss_kl = kl_loss(source_textfeat, target_textfeat)
525 |                 loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens)
526 |                 loss_entropy = self.entropy(target_probs)
527 | 
528 |                 loss = loss_ce + 0.1 * loss_smn + 0.01 * loss_entropy + loss_kl
529 | 
530 |             self.optim.zero_grad()
531 |             self.scaler.scale(loss).backward()
532 |             self.scaler.step(self.optim)
533 |             self.scaler.update()
534 | 
535 | 
536 |         loss_summary = {
537 |             "loss":
538 |             loss.item(),
539 |             "loss_ce":
540 |             loss_ce.item(),
541 |             "loss_smn":
542 |             loss_smn.item(),
543 |             "loss_entropy":
544 |             loss_entropy.item(),
545 |             "loss_kl":
546 |             loss_kl.item(),
547 |             "acc_x":
548 |             compute_accuracy(source_logits[:, :self.n_cls], label)[0].item(),
549 |         }
550 | 
551 |         self.update_lr()
552 | 
553 |         return loss_summary
554 | 
555 |     def after_epoch(self):
556 |         last_epoch = (self.epoch + 1) == self.max_epoch
557 |         do_test = not self.cfg.TEST.NO_TEST
558 |         meet_checkpoint_freq = ((self.epoch + 1) %
559 |                                 self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if
560 |                                 self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False)
561 | 
562 |         if do_test:
563 |             curr_result = self.test()
564 |             is_best = curr_result > self.best_result
565 |             if is_best:
566 |                 self.best_result = curr_result
567 |                 self.save_model(self.epoch,
568 |                                 self.output_dir,
569 |                                 model_name="model-best.pth.tar")
570 | 
571 |             self.set_model_mode("train")
572 | 
573 |         if meet_checkpoint_freq or last_epoch:
574 |             self.save_model(self.epoch, self.output_dir)
575 | 
576 |     def parse_batch_train(self, batch_x, batch_u):
577 |         input = batch_x["img"]
578 |         label = batch_x["label"]
579 |         input_u = batch_u["img"]
580 |         input = input.to(self.device)
581 |         label = label.to(self.device)
582 |         input_u = input_u.to(self.device)
583 |         return input, label, input_u
584 | 
585 |     def load_model(self, directory, epoch=None):
586 |         if not directory:
587 |             print(
588 |                 "Note that load_model() is skipped as no pretrained model is given"
589 |             )
590 |             return
591 | 
592 |         names = self.get_model_names()
593 | 
594 |         # By default, the best model is loaded
595 |         model_file = "model-best.pth.tar"
596 | 
597 |         if epoch is not None:
598 |             model_file = "model.pth.tar-" + str(epoch)
599 | 
600 |         for name in names:
601 |             model_path = osp.join(directory, name, model_file)
602 | 
603 |             if not osp.exists(model_path):
604 |                 raise FileNotFoundError(
605 |                     'Model not found at "{}"'.format(model_path))
606 | 
607 |             checkpoint = load_checkpoint(model_path)
608 |             state_dict = checkpoint["state_dict"]
609 |             epoch = checkpoint["epoch"]
610 | 
611 |             # Ignore fixed token vectors
612 |             if "token_prefix" in state_dict:
613 |                 del state_dict["token_prefix"]
614 | 
615 |             if "token_suffix" in state_dict:
616 |                 del state_dict["token_suffix"]
617 | 
618 |             print("Loading weights to {} "
619 |                   'from "{}" (epoch = {})'.format(name, model_path, epoch))
620 |             # set strict=False
621 |             self._models[name].load_state_dict(state_dict, strict=False)
622 | 
623 |     @torch.no_grad()
624 |     def test(self, split=None):
625 |         """A generic testing pipeline."""
626 |         self.set_model_mode("eval")
627 |         self.evaluator.reset()
628 | 
629 |         if split is None:
630 |             split = self.cfg.TEST.SPLIT
631 | 
632 |         split = "test"
633 |         data_loader = self.test_loader
634 |         print(f"Evaluate on the *{split}* set")
635 | 
636 | 
637 |         for batch_idx, batch in enumerate(tqdm(data_loader)):
638 |             input, label = self.parse_batch_test(batch)
639 |             output, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model_inference(input)
640 |             self.evaluator.process(output, label)
641 | 
642 |         results = self.evaluator.evaluate()
643 | 
644 |         for k, v in results.items():
645 |             tag = f"{split}/{k}"
646 |             self.write_scalar(tag, v, self.epoch)
647 | 
648 |         return list(results.values())[0]
649 | 
--------------------------------------------------------------------------------
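The forward pass in CustomCLIP.forward above (lines 276-311) scores each image against its per-sample text features with scaled cosine similarity, as in zero-shot CLIP: both feature sets are L2-normalized, multiplied by logit_scale, and combined with a dot product. Below is a minimal, self-contained sketch of that pattern; the function name clip_style_logits and the tensor shapes are illustrative assumptions, not repository code.

import torch

def clip_style_logits(image_features, text_features, logit_scale=100.0):
    # image_features: (batch, dim); text_features: (batch, n_cls, dim), one prompt set per image
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logits = torch.stack([
        logit_scale * im @ txt.t()  # scaled cosine similarity against each class prompt
        for txt, im in zip(text_features, image_features)
    ])
    return logits  # (batch, n_cls)

if __name__ == "__main__":
    torch.manual_seed(0)
    images = torch.randn(4, 512)
    prompts = torch.randn(4, 10, 512)
    print(clip_style_logits(images, prompts).shape)  # torch.Size([4, 10])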
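entropy_loss (lines 316-330) computes the average per-sample entropy of the target probabilities with two Python loops, after an extra L2 normalization across the batch dimension. For comparison, the sketch below shows a plain vectorized mean-entropy computation; it deliberately omits that normalization step, so it is an approximation of the loss rather than a drop-in replacement.

import torch

def mean_entropy(target_prob, eps=1e-8):
    # per-row entropy -sum_j p_ij * log(p_ij), averaged over the batch
    ent = -(target_prob * torch.log(target_prob + eps)).sum(dim=1)
    return ent.mean()

if __name__ == "__main__":
    probs = torch.softmax(torch.randn(8, 12), dim=1)
    print(mean_entropy(probs))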
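build_model freezes every parameter whose name does not contain "prompt_learner", so only the prompt learner is optimized while the CLIP backbone stays fixed, and only that module is handed to build_optimizer and register_model. A minimal sketch of the freezing pattern follows; freeze_all_but and the Dummy module are illustrative names, not part of the repository.

import torch.nn as nn

def freeze_all_but(model: nn.Module, keyword: str = "prompt_learner"):
    # keep gradients only for parameters whose name contains `keyword`
    for name, param in model.named_parameters():
        param.requires_grad_(keyword in name)
    return sorted(n for n, p in model.named_parameters() if p.requires_grad)

if __name__ == "__main__":
    class Dummy(nn.Module):
        def __init__(self):
            super().__init__()
            self.image_encoder = nn.Linear(8, 8)
            self.prompt_learner = nn.Linear(8, 8)

    print(freeze_all_but(Dummy()))  # ['prompt_learner.bias', 'prompt_learner.weight']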
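run_epoch draws one labeled (source) batch and one unlabeled (target) batch per step, restarting whichever loader is exhausted first, and TRAIN.COUNT_ITER controls how many steps make up an epoch. The sketch below reproduces that pairing pattern in isolation; paired_batches and the toy datasets are placeholders, not repository code.

import torch
from torch.utils.data import DataLoader, TensorDataset

def paired_batches(loader_x, loader_u, num_batches):
    # yield (labeled, unlabeled) batch pairs, restarting an exhausted loader
    iter_x, iter_u = iter(loader_x), iter(loader_u)
    for _ in range(num_batches):
        try:
            batch_x = next(iter_x)
        except StopIteration:
            iter_x = iter(loader_x)
            batch_x = next(iter_x)
        try:
            batch_u = next(iter_u)
        except StopIteration:
            iter_u = iter(loader_u)
            batch_u = next(iter_u)
        yield batch_x, batch_u

if __name__ == "__main__":
    loader_x = DataLoader(TensorDataset(torch.randn(10, 3)), batch_size=4)
    loader_u = DataLoader(TensorDataset(torch.randn(30, 3)), batch_size=4)
    # COUNT_ITER == "smaller_one" corresponds to min(len(loader_x), len(loader_u))
    for bx, bu in paired_batches(loader_x, loader_u, min(len(loader_x), len(loader_u))):
        print(bx[0].shape, bu[0].shape)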
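forward_backward combines four terms with the weights on line 528: source cross-entropy, an MSE style-mapping term (loss_smn), a target-entropy term scaled by 0.01, and a KL term between source and target text features. The sketch below assembles the same weighted sum as a standalone function, assuming tensors of matching shapes; it substitutes a plain mean-entropy term for entropy_loss, and the shapes in the usage block are invented for illustration.

import torch
import torch.nn.functional as F

def adclip_style_objective(source_logits, label,
                           source_domaintokens, source_style_mappingtokens,
                           target_probs, source_text_features, target_text_features):
    loss_ce = F.cross_entropy(source_logits, label)                         # supervised source loss
    loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens)  # style-mapping alignment
    loss_entropy = -(target_probs * torch.log(target_probs + 1e-8)).sum(dim=1).mean()
    loss_kl = F.kl_div(F.log_softmax(source_text_features, dim=1),
                       F.softmax(target_text_features, dim=1),
                       reduction="batchmean")                               # source/target text alignment
    return loss_ce + 0.1 * loss_smn + 0.01 * loss_entropy + loss_kl

if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 10)
    label = torch.randint(0, 10, (4,))
    tok_a, tok_b = torch.randn(4, 16), torch.randn(4, 16)
    probs = torch.softmax(torch.randn(4, 10), dim=1)
    txt_s, txt_t = torch.randn(4, 10), torch.randn(4, 10)
    print(adclip_style_objective(logits, label, tok_a, tok_b, probs, txt_s, txt_t))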
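load_model drops the frozen token_prefix and token_suffix entries from the checkpoint and loads the rest with strict=False, since those fixed token vectors are rebuilt from the class names when the prompt learner is constructed. The sketch below mirrors that convention using plain torch.load in place of Dassl's load_checkpoint helper; model and model_path are placeholders, and the checkpoint layout is assumed to match save_model above.

import torch

def load_prompt_learner_weights(model, model_path):
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    for key in ("token_prefix", "token_suffix"):
        state_dict.pop(key, None)  # fixed token vectors are re-created at build time
    model.load_state_dict(state_dict, strict=False)
    return checkpoint.get("epoch")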