├── images
│   ├── teaser1.png
│   ├── teaser2.png
│   ├── officehome.png
│   └── architecture.png
├── docs
│   └── AD-CLIP_poster.pdf
├── clip
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── simple_tokenizer.py
│   ├── clip.py
│   └── model.py
├── requirements.txt
├── configs
│   ├── datasets
│   │   ├── visda17.yaml
│   │   ├── mini_domainnet.yaml
│   │   └── officehome.yaml
│   └── trainer
│       ├── rn50.yaml
│       ├── vitB16.yaml
│       └── vitL14.yaml
├── scripts
│   ├── main.sh
│   └── eval.sh
├── LICENSE
├── datasets
│   ├── visda17.py
│   ├── mini_domainnet.py
│   └── office_home.py
├── README.md
├── train.py
└── trainers
    ├── adclip_vitB16.py
    ├── adclip_vitL14.py
    └── adclip_rn50.py
/images/teaser1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/teaser1.png
--------------------------------------------------------------------------------
/images/teaser2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/teaser2.png
--------------------------------------------------------------------------------
/images/officehome.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/officehome.png
--------------------------------------------------------------------------------
/docs/AD-CLIP_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/docs/AD-CLIP_poster.pdf
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/images/architecture.png
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mainaksingha01/AD-CLIP/HEAD/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn
2 | tqdm
3 | ftfy
4 | regex
5 | yacs
6 | einops
7 | h5py
8 | tb-nightly
9 | future
10 | six
--------------------------------------------------------------------------------
/configs/datasets/visda17.yaml:
--------------------------------------------------------------------------------
1 | INPUT:
2 | SIZE: (224, 224)
3 | TRANSFORMS: ["random_flip", "center_crop", "normalize"]
4 |
5 | DATASET:
6 | NAME: "VisDA17"
7 | SOURCE_DOMAINS: ["synthetic"]
8 | TARGET_DOMAINS: ["real"]  # VisDA17 adapts synthetic -> real (see datasets/visda17.py)
9 |
10 | MODEL:
11 | BACKBONE:
12 | NAME: "ViT-B/16"
13 |
14 | TEST:
15 | PER_CLASS_RESULT: True
--------------------------------------------------------------------------------
/configs/datasets/mini_domainnet.yaml:
--------------------------------------------------------------------------------
1 | INPUT:
2 | SIZE: (96, 96)
3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"]
4 |
5 | DATASET:
6 | NAME: "miniDomainNet"
7 | # SOURCE_DOMAINS: ["clipart"]
8 | # SOURCE_DOMAINS: ["painting"]
9 | # SOURCE_DOMAINS: ["real"]
10 | SOURCE_DOMAINS: ["sketch"]
11 |
12 | # TARGET_DOMAINS: ["clipart"]
13 | TARGET_DOMAINS: ["painting"]
14 | # TARGET_DOMAINS: ["real"]
15 | # TARGET_DOMAINS: ["sketch"]
16 |
17 | MODEL:
18 | BACKBONE:
19 | NAME: "ViT-B/16"
--------------------------------------------------------------------------------
/configs/datasets/officehome.yaml:
--------------------------------------------------------------------------------
1 | INPUT:
2 | SIZE: (224, 224)
3 | TRANSFORMS: ["random_flip", "center_crop", "normalize"]
4 |
5 | DATASET:
6 | NAME: "OfficeHome"
7 | # SOURCE_DOMAINS: ["real_world"]
8 | # SOURCE_DOMAINS: ["art"]
9 | # SOURCE_DOMAINS: ["clipart"]
10 | SOURCE_DOMAINS: ["product"]
11 |
12 | TARGET_DOMAINS: ["clipart"]
13 | # TARGET_DOMAINS: ["art"]
14 | # TARGET_DOMAINS: ["product"]
15 | # TARGET_DOMAINS: ["real_world"]
16 | # you can modify the code to explore four domains
17 |
18 | MODEL:
19 | BACKBONE:
20 | NAME: "ViT-B/16"
21 |
--------------------------------------------------------------------------------
/configs/trainer/rn50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 16
4 | TRAIN_U:
5 | BATCH_SIZE: 16
6 | TEST:
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 |
10 | INPUT:
11 | SIZE: (224, 224)
12 | INTERPOLATION: "bicubic"
13 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
14 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
15 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
16 |
17 | OPTIM:
18 | NAME: "adam"
19 | LR: 0.01
20 | MAX_EPOCH: 50
21 | LR_SCHEDULER: "cosine"
22 | WARMUP_EPOCH: 1
23 | WARMUP_TYPE: "linear"
24 | WARMUP_MIN_LR: 1e-5
25 |
26 | TRAIN:
27 | PRINT_FREQ: 100
28 |
29 | MODEL:
30 | BACKBONE:
31 | NAME: "RN50"
32 |
33 | TRAINER:
34 | ADCLIPRN50:
35 | PREC: "amp"
36 |
--------------------------------------------------------------------------------
/configs/trainer/vitB16.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 16
4 | TRAIN_U:
5 | BATCH_SIZE: 16
6 | TEST:
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 |
10 | INPUT:
11 | SIZE: (224, 224)
12 | INTERPOLATION: "bicubic"
13 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
14 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
15 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
16 |
17 | OPTIM:
18 | NAME: "adam"
19 | LR: 0.01
20 | MAX_EPOCH: 50
21 | LR_SCHEDULER: "cosine"
22 | WARMUP_EPOCH: 1
23 | WARMUP_TYPE: "linear"
24 | WARMUP_MIN_LR: 1e-5
25 |
26 | TRAIN:
27 | PRINT_FREQ: 100
28 |
29 | MODEL:
30 | BACKBONE:
31 | NAME: "ViT-B/16"
32 |
33 | TRAINER:
34 | ADCLIPB16:
35 | PREC: "amp"
36 |
--------------------------------------------------------------------------------
/configs/trainer/vitL14.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 16
4 | TRAIN_U:
5 | BATCH_SIZE: 16
6 | TEST:
7 | BATCH_SIZE: 128
8 | NUM_WORKERS: 4
9 |
10 | INPUT:
11 | SIZE: (224, 224)
12 | INTERPOLATION: "bicubic"
13 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
14 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
15 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
16 |
17 | OPTIM:
18 | NAME: "adam"
19 | LR: 0.01
20 | MAX_EPOCH: 50
21 | LR_SCHEDULER: "cosine"
22 | WARMUP_EPOCH: 1
23 | WARMUP_TYPE: "linear"
24 | WARMUP_MIN_LR: 1e-5
25 |
26 | TRAIN:
27 | PRINT_FREQ: 100
28 |
29 | MODEL:
30 | BACKBONE:
31 | NAME: "ViT-L/14"
32 |
33 | TRAINER:
34 | ADCLIPL14:
35 | PREC: "amp"
36 |
--------------------------------------------------------------------------------
/scripts/main.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 |
3 | DATA=data # change your data path here
4 | MODE=train
5 |
6 | DATASET=$1 # dataset name; officehome, visda17, mini_domainnet
7 | TRAINER=$2 # ADCLIPRN50, ADCLIPB16, ADCLIPL14
8 | CFG=$3 # config file; rn50, vitB16, vitL14
9 | #SEED=$4
10 |
11 | for SEED in 1 2 3 4 5
12 | do
13 | DIR=output/${DATASET}/${MODE}/${TRAINER}/${CFG}/seed_${SEED}
14 | if [ -d "$DIR" ]; then
15 | echo "Results are available in ${DIR}. Skip this job"
16 | else
17 | echo "Run this job and save the output to ${DIR}"
18 | python train.py \
19 | --root ${DATA} \
20 | --seed ${SEED} \
21 | --trainer ${TRAINER} \
22 | --dataset-config-file configs/datasets/${DATASET}.yaml \
23 | --config-file configs/trainer/${CFG}.yaml \
24 | --output-dir ${DIR}
25 | fi
26 | done
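27 |
28 | # Example invocation (run from the scripts/ directory, as in the README):
29 | #   bash main.sh officehome ADCLIPB16 vitB16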
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | cd ..
2 |
3 | DATA=data # change your data path here
4 | MODE=test
5 |
6 | DATASET=$1 # dataset name; officehome, visda17, mini_domainnet
7 | TRAINER=$2 # ADCLIPRN50, ADCLIPB16, ADCLIPL14
8 | CFG=$3 # config file; rn50, vitB16, vitL14
9 | # SEED=$4
10 |
11 |
12 | for SEED in 1 2 3 4 5
13 | do
14 | MODEL_DIR=output/${DATASET}/train/${TRAINER}/${CFG}/seed_${SEED} # trained weights saved by main.sh
15 | DIR=output/${DATASET}/${MODE}/${TRAINER}/${CFG}/seed_${SEED}
16 | if false; then # always (re-)run evaluation
17 | echo "The results already exist in ${DIR}"
18 | else
19 | python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainer/${CFG}.yaml \
25 | --output-dir ${DIR} \
26 | --model-dir ${MODEL_DIR} \
27 | --eval-only
28 | # add "--load-epoch ${LOADEP}" above to evaluate a specific checkpoint
29 | fi
30 | done
31 |
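32 | # Example invocation (run after training with main.sh):
33 | #   bash eval.sh officehome ADCLIPB16 vitB16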
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Mainak Singha
4 | Copyright (c) 2021 Kaiyang Zhou
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/datasets/visda17.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | from ..build import DATASET_REGISTRY
4 | from ..base_dataset import Datum, DatasetBase
5 |
6 |
7 | @DATASET_REGISTRY.register()
8 | class VisDA17(DatasetBase):
9 | """VisDA17.
10 |
11 | Focusing on simulation-to-reality domain shift.
12 |
13 | URL: http://ai.bu.edu/visda-2017/.
14 |
15 | Reference:
16 | - Peng et al. VisDA: The Visual Domain Adaptation
17 | Challenge. ArXiv 2017.
18 | """
19 |
20 | dataset_dir = "visda17"
21 | domains = ["synthetic", "real"]
22 |
23 | def __init__(self, cfg):
24 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
25 | self.dataset_dir = osp.join(root, self.dataset_dir)
26 |
27 | self.check_input_domains(
28 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
29 | )
30 |
31 | train_x = self._read_data("synthetic")
32 | train_u = self._read_data("real")
33 | test = self._read_data("real")
34 |
35 | super().__init__(train_x=train_x, train_u=train_u, test=test)
36 |
37 | def _read_data(self, dname):
38 | filedir = "train" if dname == "synthetic" else "validation"
39 | image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
40 | items = []
41 | # There is only one source domain
42 | domain = 0
43 |
44 | with open(image_list, "r") as f:
45 | lines = f.readlines()
46 |
47 | for line in lines:
48 | line = line.strip()
49 | impath, label = line.split(" ")
50 | classname = impath.split("/")[0]
51 | impath = osp.join(self.dataset_dir, filedir, impath)
52 | label = int(label)
53 | item = Datum(
54 | impath=impath,
55 | label=label,
56 | domain=domain,
57 | classname=classname
58 | )
59 | items.append(item)
60 |
61 | return items
62 |
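63 | # image_list.txt format (inferred from _read_data above): one
64 | # "<relative_path> <label>" entry per line, where the first path component is
65 | # the class name; synthetic images live under train/, real images under validation/.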
--------------------------------------------------------------------------------
/datasets/mini_domainnet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | from ..build import DATASET_REGISTRY
4 | from ..base_dataset import Datum, DatasetBase
5 |
6 |
7 | @DATASET_REGISTRY.register()
8 | class miniDomainNet(DatasetBase):
9 | """A subset of DomainNet.
10 |
11 | Reference:
12 | - Peng et al. Moment Matching for Multi-Source Domain
13 | Adaptation. ICCV 2019.
14 | - Zhou et al. Domain Adaptive Ensemble Learning.
15 | """
16 |
17 | dataset_dir = "domainnet"
18 | domains = ["clipart", "painting", "real", "sketch"]
19 |
20 | def __init__(self, cfg):
21 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
22 | self.dataset_dir = osp.join(root, self.dataset_dir)
23 | self.split_dir = osp.join(self.dataset_dir, "splits_mini")
24 |
25 | self.check_input_domains(
26 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
27 | )
28 |
29 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
30 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
31 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
32 |
33 | super().__init__(train_x=train_x, train_u=train_u, test=test)
34 |
35 | def _read_data(self, input_domains, split="train"):
36 | items = []
37 |
38 | for domain, dname in enumerate(input_domains):
39 | filename = dname + "_" + split + ".txt"
40 | split_file = osp.join(self.split_dir, filename)
41 |
42 | with open(split_file, "r") as f:
43 | lines = f.readlines()
44 | for line in lines:
45 | line = line.strip()
46 | impath, label = line.split(" ")
47 | classname = impath.split("/")[1]
48 | impath = osp.join(self.dataset_dir, impath)
49 | label = int(label)
50 | item = Datum(
51 | impath=impath,
52 | label=label,
53 | domain=domain,
54 | classname=classname
55 | )
56 | items.append(item)
57 |
58 | return items
59 |
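60 | # Split-file format (inferred from _read_data above): each line of
61 | # splits_mini/<domain>_<split>.txt is "<domain>/<class>/<image> <label>",
62 | # with image paths resolved relative to <root>/domainnet/.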
--------------------------------------------------------------------------------
/datasets/office_home.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | from dassl.utils import listdir_nohidden
4 |
5 | from ..build import DATASET_REGISTRY
6 | from ..base_dataset import Datum, DatasetBase
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class OfficeHome(DatasetBase):
11 | """Office-Home.
12 |
13 | Statistics:
14 | - Around 15,500 images.
15 | - 65 classes related to office and home objects.
16 | - 4 domains: Art, Clipart, Product, Real World.
17 | - URL: http://hemanthdv.org/OfficeHome-Dataset/.
18 |
19 | Reference:
20 | - Venkateswara et al. Deep Hashing Network for Unsupervised
21 | Domain Adaptation. CVPR 2017.
22 | """
23 |
24 | dataset_dir = "office_home"
25 | domains = ["art", "clipart", "product", "real_world"]
26 |
27 | def __init__(self, cfg):
28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
29 | self.dataset_dir = osp.join(root, self.dataset_dir)
30 |
31 | self.check_input_domains(
32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
33 | )
34 |
35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
38 |
39 | super().__init__(train_x=train_x, train_u=train_u, test=test)
40 |
41 | def _read_data(self, input_domains):
42 | items = []
43 |
44 | for domain, dname in enumerate(input_domains):
45 | domain_dir = osp.join(self.dataset_dir, dname)
46 | class_names = listdir_nohidden(domain_dir)
47 | class_names.sort()
48 |
49 | for label, class_name in enumerate(class_names):
50 | class_path = osp.join(domain_dir, class_name)
51 | imnames = listdir_nohidden(class_path)
52 |
53 | for imname in imnames:
54 | impath = osp.join(class_path, imname)
55 | item = Datum(
56 | impath=impath,
57 | label=label,
58 | domain=domain,
59 | classname=class_name.lower(),
60 | )
61 | items.append(item)
62 |
63 | return items
64 |
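65 | # Expected on-disk layout (inferred from _read_data above):
66 | #   <root>/office_home/<domain>/<class_name>/<image>
67 | # with <domain> in {art, clipart, product, real_world}; labels follow the
68 | # sorted order of the class folders.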
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AD-CLIP: Adapting Domains in Prompt Space Using CLIP
2 | Official repository of AD-CLIP, which tackles unsupervised domain adaptation through *prompt learning* with pre-trained vision-language models (VLMs) such as CLIP.
3 |
4 | ## **ICCVW 2023**
5 |
6 | [Paper (CVF Open Access)](https://openaccess.thecvf.com/content/ICCV2023W/OODCV/papers/Singha_AD-CLIP_Adapting_Domains_in_Prompt_Space_Using_CLIP_ICCVW_2023_paper.pdf)
7 | [arXiv](https://arxiv.org/pdf/2308.05659.pdf)
8 | [Poster](https://github.com/mainaksingha01/AD-CLIP/blob/master/docs/AD-CLIP_poster.pdf)
9 |
10 | ## Abstract
11 |
12 |
13 | Although deep learning models have shown impressive performance on supervised learning tasks, they often struggle to generalize well when the training (source) and test (target) domains differ. Unsupervised domain adaptation (DA) has emerged as a popular solution to this problem. However, current DA techniques rely on visual backbones, which may lack semantic richness. Despite the potential of large-scale vision-language foundation models like CLIP, their effectiveness for DA has yet to be fully explored. To address this gap, we introduce AD-CLIP, a domain-agnostic prompt learning strategy for CLIP that aims to solve the DA problem in the prompt space. We leverage the frozen vision backbone of CLIP to extract both image style (domain) and content information, which we apply to learn prompt tokens. Our prompts are designed to be domain-invariant and class-generalizable, by conditioning prompt learning on image style and content features simultaneously. We use standard supervised contrastive learning in the source domain, while proposing an entropy minimization strategy to align domains in the embedding space given the target domain data. We also consider a scenario where only target domain samples are available during testing, without any source domain data, and propose a cross-domain style mapping network to hallucinate domain-agnostic tokens. Our extensive experiments on three benchmark DA datasets demonstrate the effectiveness of AD-CLIP compared to existing literature.
14 |
15 | ## Architecture
16 |
17 | ![architecture](images/architecture.png)
18 |
19 | ## How to install
20 |
21 | ### Create your environment:
22 |
23 | ```bash
24 | $ conda create -n adclip python=3.8
25 | $ conda activate adclip
26 | $ conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=10.2 -c pytorch
27 | $ pip install -r requirements.txt
28 | ```
29 |
30 | ## Code
31 |
32 | - The `datasets` folder contains the dataloader file for each dataset.
33 | - The `trainers` folder contains our model in three backbone variants: ResNet-50, ViT-B/16 and ViT-L/14.
34 | - Clone the awesome [dassl](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl) toolbox inside this repo.
35 | - For evaluation, replace the output in line 464 of the `dassl.engine.trainer` file with the values returned by the `CustomCLIP` class of the corresponding trainer file (e.g. `adclip_vitB16.py`).
36 | - The `scripts` folder holds the scripts for training and testing.
37 | - Set your data path in `main.sh` and `eval.sh`.
38 | - Choose the source and target domains in the `configs/datasets` files. For example:
39 |
40 | ```shell
41 | $ cd scripts
42 | $ bash main.sh officehome ADCLIPB16 vitB16
43 | $ bash eval.sh officehome ADCLIPB16 vitB16
44 | ```
45 |
46 | ## Bibtex
47 |
48 | Please cite the paper if you use our work. Thanks.
49 |
50 | ```bibtex
51 | @inproceedings{singha2023ad,
52 | title={Ad-clip: Adapting domains in prompt space using clip},
53 | author={Singha, Mainak and Pal, Harsh and Jha, Ankit and Banerjee, Biplab},
54 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
55 | pages={4355--4364},
56 | year={2023}
57 | }
58 | ```
59 |
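60 | ## Data preparation
61 |
62 | A minimal sketch of the directory layout the dataset loaders expect under the `DATA` path set in the scripts (names inferred from `datasets/office_home.py`, `datasets/visda17.py` and `datasets/mini_domainnet.py`; adjust to your local copies):
63 |
64 | ```
65 | data/
66 | ├── office_home/                     # art/, clipart/, product/, real_world/, one folder per class inside each
67 | ├── visda17/
68 | │   ├── train/image_list.txt         # synthetic images, one "<relative_path> <label>" per line
69 | │   └── validation/image_list.txt    # real images
70 | └── domainnet/
71 |     ├── splits_mini/                 # <domain>_train.txt and <domain>_test.txt split files
72 |     └── <domain>/<class>/<image>     # image paths referenced by the split files
73 | ```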
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 |
5 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
6 | from dassl.config import get_cfg_default
7 | from dassl.engine import build_trainer
8 |
9 | from dassl.data.datasets import VisDA17
10 | from dassl.data.datasets import OfficeHome
11 | from dassl.data.datasets import miniDomainNet
12 |
13 | import trainers.adclip_rn50
14 | import trainers.adclip_vitB16
15 | import trainers.adclip_vitL14
16 |
17 |
18 | def print_args(args, cfg):
19 | print("***************")
20 | print("** Arguments **")
21 | print("***************")
22 | optkeys = list(args.__dict__.keys())
23 | optkeys.sort()
24 | for key in optkeys:
25 | print("{}: {}".format(key, args.__dict__[key]))
26 | print("************")
27 | print("** Config **")
28 | print("************")
29 | print(cfg)
30 |
31 |
32 | def reset_cfg(cfg, args):
33 | if args.root:
34 | cfg.DATASET.ROOT = args.root
35 |
36 | if args.output_dir:
37 | cfg.OUTPUT_DIR = args.output_dir
38 |
39 | if args.resume:
40 | cfg.RESUME = args.resume
41 |
42 | if args.seed:
43 | cfg.SEED = args.seed
44 |
45 | if args.source_domains:
46 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
47 |
48 | if args.target_domains:
49 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
50 |
51 | if args.transforms:
52 | cfg.INPUT.TRANSFORMS = args.transforms
53 |
54 | if args.trainer:
55 | cfg.TRAINER.NAME = args.trainer
56 |
57 | if args.backbone:
58 | cfg.MODEL.BACKBONE.NAME = args.backbone
59 |
60 | if args.head:
61 | cfg.MODEL.HEAD.NAME = args.head
62 |
63 |
64 | def extend_cfg(cfg):
65 | """
66 | Add new config variables for AD-CLIP.
67 |
68 | E.g.
69 | from yacs.config import CfgNode as CN
70 | cfg.TRAINER.MY_MODEL = CN()
71 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
72 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
73 | cfg.TRAINER.MY_MODEL.PARAM_C = False
74 | """
75 | from yacs.config import CfgNode as CN
76 |
77 | cfg.MODEL.BACKBONE.PATH = "./assets"
78 |
79 |
80 | cfg.TRAINER.ADCLIPRN50 = CN()
81 | cfg.TRAINER.ADCLIPRN50.PREC = "amp" # fp16, fp32, amp
82 |
83 | cfg.TRAINER.ADCLIPB16 = CN()
84 | cfg.TRAINER.ADCLIPB16.PREC = "amp" # fp16, fp32, amp
85 |
86 | cfg.TRAINER.ADCLIPL14 = CN()
87 | cfg.TRAINER.ADCLIPL14.PREC = "amp" # fp16, fp32, amp
88 |
89 |
90 | def setup_cfg(args):
91 | cfg = get_cfg_default()
92 | extend_cfg(cfg)
93 | print(cfg)
94 |
95 | # 1. From the dataset config file
96 | if args.dataset_config_file:
97 | cfg.merge_from_file(args.dataset_config_file)
98 |
99 | # 2. From the method config file
100 | if args.config_file:
101 | cfg.merge_from_file(args.config_file)
102 |
103 | # 3. From input arguments
104 | reset_cfg(cfg, args)
105 |
106 | # 4. From optional input arguments
107 | cfg.merge_from_list(args.opts)
108 |
109 | cfg.freeze()
110 |
111 | return cfg
112 |
113 |
114 | def main(args):
115 | cfg = setup_cfg(args)
116 | if cfg.SEED >= 0:
117 | print("Setting fixed seed: {}".format(cfg.SEED))
118 | set_random_seed(cfg.SEED)
119 | setup_logger(cfg.OUTPUT_DIR)
120 |
121 | if torch.cuda.is_available() and cfg.USE_CUDA:
122 | torch.backends.cudnn.benchmark = True
123 |
124 | print_args(args, cfg)
125 | print("Collecting env info ...")
126 | print("** System info **\n{}\n".format(collect_env_info()))
127 |
128 | trainer = build_trainer(cfg)
129 |
130 | if args.eval_only:
131 | # if True:
132 | print("Yess testing")
133 | trainer.load_model(args.model_dir, epoch=args.load_epoch)
134 | trainer.test()
135 | return
136 |
137 | if not args.no_train:
138 | print("No! Training")
139 | trainer.train()
140 |
141 |
142 | if __name__ == "__main__":
143 | parser = argparse.ArgumentParser()
144 | parser.add_argument("--root", type=str, default="", help="path to dataset")
145 | parser.add_argument("--output-dir",
146 | type=str,
147 | default="",
148 | help="output directory")
149 | parser.add_argument(
150 | "--resume",
151 | type=str,
152 | default="",
153 | help="checkpoint directory (from which the training resumes)",
154 | )
155 | parser.add_argument("--seed",
156 | type=int,
157 | default=-1,
158 | help="only positive value enables a fixed seed")
159 | parser.add_argument("--source-domains",
160 | type=str,
161 | nargs="+",
162 | help="source domains for DA/DG")
163 | parser.add_argument("--target-domains",
164 | type=str,
165 | nargs="+",
166 | help="target domains for DA/DG")
167 | parser.add_argument("--transforms",
168 | type=str,
169 | nargs="+",
170 | help="data augmentation methods")
171 | parser.add_argument("--config-file",
172 | type=str,
173 | default="",
174 | help="path to config file")
175 | parser.add_argument(
176 | "--dataset-config-file",
177 | type=str,
178 | default="",
179 | help="path to config file for dataset setup",
180 | )
181 | parser.add_argument("--trainer",
182 | type=str,
183 | default="",
184 | help="name of trainer")
185 | parser.add_argument("--backbone",
186 | type=str,
187 | default="",
188 | help="name of CNN backbone")
189 | parser.add_argument("--head", type=str, default="", help="name of head")
190 | parser.add_argument("--eval-only",
191 | action="store_true",
192 | help="evaluation only")
193 | parser.add_argument(
194 | "--model-dir",
195 | type=str,
196 | default="",
197 | help="load model from this directory for eval-only mode",
198 | )
199 | parser.add_argument("--load-epoch",
200 | type=int,
201 | help="load model weights at this epoch for evaluation")
202 | parser.add_argument("--no-train",
203 | action="store_true",
204 | help="do not call trainer.train()")
205 | parser.add_argument(
206 | "opts",
207 | default=None,
208 | nargs=argparse.REMAINDER,
209 | help="modify config options using the command-line",
210 | )
211 | args = parser.parse_args()
212 | main(args)
213 |
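214 | # Example invocation (mirrors scripts/main.sh; run from the repo root):
215 | #   python train.py --root data --seed 1 --trainer ADCLIPB16 \
216 | #       --dataset-config-file configs/datasets/officehome.yaml \
217 | #       --config-file configs/trainer/vitB16.yaml \
218 | #       --output-dir output/officehome/train/ADCLIPB16/vitB16/seed_1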
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 |
22 | if torch.__version__.split(".") < ["1", "7", "1"]:
23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 |
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
37 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
38 | }
39 |
40 |
41 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
42 | os.makedirs(root, exist_ok=True)
43 | filename = os.path.basename(url)
44 |
45 | expected_sha256 = url.split("/")[-2]
46 | download_target = os.path.join(root, filename)
47 |
48 | if os.path.exists(download_target) and not os.path.isfile(download_target):
49 | raise RuntimeError(f"{download_target} exists and is not a regular file")
50 |
51 | if os.path.isfile(download_target):
52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
53 | return download_target
54 | else:
55 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
56 |
57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
59 | while True:
60 | buffer = source.read(8192)
61 | if not buffer:
62 | break
63 |
64 | output.write(buffer)
65 | loop.update(len(buffer))
66 |
67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
68 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
69 |
70 | return download_target
71 |
72 |
73 | def _transform(n_px):
74 | return Compose([
75 | Resize(n_px, interpolation=BICUBIC),
76 | CenterCrop(n_px),
77 | lambda image: image.convert("RGB"),
78 | ToTensor(),
79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
80 | ])
81 |
82 |
83 | def available_models() -> List[str]:
84 | """Returns the names of available CLIP models"""
85 | return list(_MODELS.keys())
86 |
87 |
88 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
89 | """Load a CLIP model
90 |
91 | Parameters
92 | ----------
93 | name : str
94 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
95 |
96 | device : Union[str, torch.device]
97 | The device to put the loaded model
98 |
99 | jit : bool
100 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
101 |
102 | Returns
103 | -------
104 | model : torch.nn.Module
105 | The CLIP model
106 |
107 | preprocess : Callable[[PIL.Image], torch.Tensor]
108 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
109 | """
110 | if name in _MODELS:
111 | model_path = _download(_MODELS[name])
112 | elif os.path.isfile(name):
113 | model_path = name
114 | else:
115 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
116 |
117 | try:
118 | # loading JIT archive
119 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
120 | state_dict = None
121 | except RuntimeError:
122 | # loading saved state dict
123 | if jit:
124 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
125 | jit = False
126 | state_dict = torch.load(model_path, map_location="cpu")
127 |
128 | if not jit:
129 | model = build_model(state_dict or model.state_dict()).to(device)
130 | if str(device) == "cpu":
131 | model.float()
132 | return model, _transform(model.visual.input_resolution)
133 |
134 | # patch the device names
135 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
136 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
137 |
138 | def patch_device(module):
139 | try:
140 | graphs = [module.graph] if hasattr(module, "graph") else []
141 | except RuntimeError:
142 | graphs = []
143 |
144 | if hasattr(module, "forward1"):
145 | graphs.append(module.forward1.graph)
146 |
147 | for graph in graphs:
148 | for node in graph.findAllNodes("prim::Constant"):
149 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
150 | node.copyAttributes(device_node)
151 |
152 | model.apply(patch_device)
153 | patch_device(model.encode_image)
154 | patch_device(model.encode_text)
155 |
156 | # patch dtype to float32 on CPU
157 | if str(device) == "cpu":
158 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
159 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
160 | float_node = float_input.node()
161 |
162 | def patch_float(module):
163 | try:
164 | graphs = [module.graph] if hasattr(module, "graph") else []
165 | except RuntimeError:
166 | graphs = []
167 |
168 | if hasattr(module, "forward1"):
169 | graphs.append(module.forward1.graph)
170 |
171 | for graph in graphs:
172 | for node in graph.findAllNodes("aten::to"):
173 | inputs = list(node.inputs())
174 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
175 | if inputs[i].node()["value"] == 5:
176 | inputs[i].node().copyAttributes(float_node)
177 |
178 | model.apply(patch_float)
179 | patch_float(model.encode_image)
180 | patch_float(model.encode_text)
181 |
182 | model.float()
183 |
184 | return model, _transform(model.input_resolution.item())
185 |
186 |
187 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
188 | """
189 | Returns the tokenized representation of given input string(s)
190 |
191 | Parameters
192 | ----------
193 | texts : Union[str, List[str]]
194 | An input string or a list of input strings to tokenize
195 |
196 | context_length : int
197 | The context length to use; all CLIP models use 77 as the context length
198 |
199 | truncate: bool
200 | Whether to truncate the text in case its encoding is longer than the context length
201 |
202 | Returns
203 | -------
204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
205 | """
206 | if isinstance(texts, str):
207 | texts = [texts]
208 |
209 | sot_token = _tokenizer.encoder["<|startoftext|>"]
210 | eot_token = _tokenizer.encoder["<|endoftext|>"]
211 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
212 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
213 |
214 | for i, tokens in enumerate(all_tokens):
215 | if len(tokens) > context_length:
216 | if truncate:
217 | tokens = tokens[:context_length]
218 | tokens[-1] = eot_token
219 | else:
220 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
221 | result[i, :len(tokens)] = torch.tensor(tokens)
222 |
223 | return result
224 |
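225 | # Minimal usage sketch (the AD-CLIP trainers bypass load() and call _download()
226 | # and build_model() directly; see trainers/adclip_*.py):
227 | #   model, preprocess = load("ViT-B/16", device="cpu")
228 | #   tokens = tokenize(["a photo of a dog", "a photo of a cat"])
229 | # Note that in this fork encode_image() returns (features, per-layer activations)
230 | # rather than a single tensor.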
--------------------------------------------------------------------------------
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.input_resolution = input_resolution
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 | self.bn2 = nn.BatchNorm2d(width // 2)
111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 | self.bn3 = nn.BatchNorm2d(width)
113 | self.avgpool = nn.AvgPool2d(2)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | # residual layers
117 | self._inplanes = width # this is a *mutable* variable used during construction
118 | self.layer1 = self._make_layer(width, layers[0])
119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 |
123 | embed_dim = width * 32 # the ResNet feature dimension
124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 |
126 | def _make_layer(self, planes, blocks, stride=1):
127 | layers = [Bottleneck(self._inplanes, planes, stride)]
128 |
129 | self._inplanes = planes * Bottleneck.expansion
130 | for _ in range(1, blocks):
131 | layers.append(Bottleneck(self._inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | def stem(x):
137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 | x = conv(x)
139 | x = bn(x)
140 | x = self.relu(x)
141 | # x = self.relu(bn(conv(x)))
142 | x = self.avgpool(x)
143 | return x
144 |
145 | x = x.type(self.conv1.weight.dtype)
146 | x = stem(x)
147 | data = []
148 | x1 = self.layer1(x)
149 | data.append(x1)
150 | x2 = self.layer2(x1)
151 | data.append(x2)
152 | x3 = self.layer3(x2)
153 | data.append(x3)
154 | x4 = self.layer4(x3)
155 | data.append(x4)
156 | feat = self.attnpool(x4)
157 | return feat, data
158 |
159 |
160 | class LayerNorm(nn.LayerNorm):
161 | """Subclass torch's LayerNorm to handle fp16."""
162 |
163 | def forward(self, x: torch.Tensor):
164 | orig_type = x.dtype
165 | ret = super().forward(x.type(torch.float32))
166 | return ret.type(orig_type)
167 |
168 |
169 | class QuickGELU(nn.Module):
170 | def forward(self, x: torch.Tensor):
171 | return x * torch.sigmoid(1.702 * x)
172 |
173 |
174 | class ResidualAttentionBlock(nn.Module):
175 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
176 | super().__init__()
177 |
178 | self.attn = nn.MultiheadAttention(d_model, n_head)
179 | self.ln_1 = LayerNorm(d_model)
180 | self.mlp = nn.Sequential(OrderedDict([
181 | ("c_fc", nn.Linear(d_model, d_model * 4)),
182 | ("gelu", QuickGELU()),
183 | ("c_proj", nn.Linear(d_model * 4, d_model))
184 | ]))
185 | self.ln_2 = LayerNorm(d_model)
186 | self.attn_mask = attn_mask
187 |
188 | def attention(self, x: torch.Tensor):
189 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
190 | y = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
191 | return y
192 |
193 | def forward(self, x: torch.Tensor):
194 | x = x + self.attention(self.ln_1(x))
195 | x = x + self.mlp(self.ln_2(x))
196 | return x
197 |
198 |
199 | class Transformer(nn.Module):
200 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
201 | super().__init__()
202 | self.width = width
203 | self.layers = layers
204 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
205 |
206 | def forward(self, x: torch.Tensor):
207 | data=[]
208 | for layer in self.resblocks:
209 | x = layer(x)
210 | data.append(x.detach().permute(1,0,2))
211 | data = torch.stack(data)
212 | return x, data
213 |
214 |
215 | class VisionTransformer(nn.Module):
216 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
217 | super().__init__()
218 | self.input_resolution = input_resolution
219 | self.output_dim = output_dim
220 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
221 |
222 | scale = width ** -0.5
223 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
224 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
225 | self.ln_pre = LayerNorm(width)
226 |
227 | self.transformer = Transformer(width, layers, heads)
228 |
229 | self.ln_post = LayerNorm(width)
230 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
231 |
232 | def forward(self, x: torch.Tensor):
233 | x = self.conv1(x) # shape = [*, width, grid, grid]
234 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
235 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
236 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
237 | x = x + self.positional_embedding.to(x.dtype)
238 | x = self.ln_pre(x)
239 |
240 | x = x.permute(1, 0, 2) # NLD -> LND
241 | x, data = self.transformer(x)
242 | x = x.permute(1, 0, 2) # LND -> NLD
243 |
244 | x = self.ln_post(x[:, 0, :])
245 |
246 | if self.proj is not None:
247 | x = x @ self.proj
248 | return x, data
249 |
250 |
251 | class CLIP(nn.Module):
252 | def __init__(self,
253 | embed_dim: int,
254 | # vision
255 | image_resolution: int,
256 | vision_layers: Union[Tuple[int, int, int, int], int],
257 | vision_width: int,
258 | vision_patch_size: int,
259 | # text
260 | context_length: int,
261 | vocab_size: int,
262 | transformer_width: int,
263 | transformer_heads: int,
264 | transformer_layers: int
265 | ):
266 | super().__init__()
267 |
268 | self.context_length = context_length
269 |
270 | if isinstance(vision_layers, (tuple, list)):
271 | vision_heads = vision_width * 32 // 64
272 | self.visual = ModifiedResNet(
273 | layers=vision_layers,
274 | output_dim=embed_dim,
275 | heads=vision_heads,
276 | input_resolution=image_resolution,
277 | width=vision_width
278 | )
279 | else:
280 | vision_heads = vision_width // 64
281 | self.visual = VisionTransformer(
282 | input_resolution=image_resolution,
283 | patch_size=vision_patch_size,
284 | width=vision_width,
285 | layers=vision_layers,
286 | heads=vision_heads,
287 | output_dim=embed_dim
288 | )
289 |
290 | self.transformer = Transformer(
291 | width=transformer_width,
292 | layers=transformer_layers,
293 | heads=transformer_heads,
294 | attn_mask=self.build_attention_mask()
295 | )
296 |
297 | self.vocab_size = vocab_size
298 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
299 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
300 | self.ln_final = LayerNorm(transformer_width)
301 |
302 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
303 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
304 |
305 | self.initialize_parameters()
306 |
307 | def initialize_parameters(self):
308 | nn.init.normal_(self.token_embedding.weight, std=0.02)
309 | nn.init.normal_(self.positional_embedding, std=0.01)
310 |
311 | if isinstance(self.visual, ModifiedResNet):
312 | if self.visual.attnpool is not None:
313 | std = self.visual.attnpool.c_proj.in_features ** -0.5
314 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
315 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
316 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
317 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
318 |
319 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
320 | for name, param in resnet_block.named_parameters():
321 | if name.endswith("bn3.weight"):
322 | nn.init.zeros_(param)
323 |
324 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
325 | attn_std = self.transformer.width ** -0.5
326 | fc_std = (2 * self.transformer.width) ** -0.5
327 | for block in self.transformer.resblocks:
328 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
329 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
330 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
331 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
332 |
333 | if self.text_projection is not None:
334 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
335 |
336 | def build_attention_mask(self):
337 | # lazily create causal attention mask, with full attention between the vision tokens
338 | # pytorch uses additive attention mask; fill with -inf
339 | mask = torch.empty(self.context_length, self.context_length)
340 | mask.fill_(float("-inf"))
341 | mask.triu_(1) # zero out the lower diagonal
342 | return mask
343 |
344 | @property
345 | def dtype(self):
346 | return self.visual.conv1.weight.dtype
347 |
348 | def encode_image(self, image):
349 | return self.visual(image.type(self.dtype))
350 |
351 | def encode_text(self, text):
352 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
353 |
354 | x = x + self.positional_embedding.type(self.dtype)
355 | x = x.permute(1, 0, 2) # NLD -> LND
356 | x, temp = self.transformer(x)
357 | x = x.permute(1, 0, 2) # LND -> NLD
358 | x = self.ln_final(x).type(self.dtype)
359 |
360 | # x.shape = [batch_size, n_ctx, transformer.width]
361 | # take features from the eot embedding (eot_token is the highest number in each sequence)
362 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
363 |
364 | return x
365 |
366 | def forward(self, image, text):
367 | image_features = self.encode_image(image)
368 | text_features = self.encode_text(text)
369 |
370 | # normalized features
371 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
372 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
373 |
374 | # cosine similarity as logits
375 | logit_scale = self.logit_scale.exp()
376 | logits_per_image = logit_scale * image_features @ text_features.t()
377 | logits_per_text = logit_scale * text_features @ image_features.t()
378 |
379 | # shape = [global_batch_size, global_batch_size]
380 | return logits_per_image, logits_per_text
381 |
382 |
383 | def convert_weights(model: nn.Module):
384 | """Convert applicable model parameters to fp16"""
385 |
386 | def _convert_weights_to_fp16(l):
387 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
388 | l.weight.data = l.weight.data.half()
389 | if l.bias is not None:
390 | l.bias.data = l.bias.data.half()
391 |
392 | if isinstance(l, nn.MultiheadAttention):
393 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
394 | tensor = getattr(l, attr)
395 | if tensor is not None:
396 | tensor.data = tensor.data.half()
397 |
398 | for name in ["text_projection", "proj"]:
399 | if hasattr(l, name):
400 | attr = getattr(l, name)
401 | if attr is not None:
402 | attr.data = attr.data.half()
403 |
404 | model.apply(_convert_weights_to_fp16)
405 |
406 |
407 | def build_model(state_dict: dict):
408 | vit = "visual.proj" in state_dict
409 |
410 | if vit:
411 | vision_width = state_dict["visual.conv1.weight"].shape[0]
412 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
413 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
414 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
415 | image_resolution = vision_patch_size * grid_size
416 | else:
417 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
418 | vision_layers = tuple(counts)
419 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
420 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
421 | vision_patch_size = None
422 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
423 | image_resolution = output_width * 32
424 |
425 | embed_dim = state_dict["text_projection"].shape[1]
426 | context_length = state_dict["positional_embedding"].shape[0]
427 | vocab_size = state_dict["token_embedding.weight"].shape[0]
428 | transformer_width = state_dict["ln_final.weight"].shape[0]
429 | transformer_heads = transformer_width // 64
430 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
431 |
432 | model = CLIP(
433 | embed_dim,
434 | image_resolution, vision_layers, vision_width, vision_patch_size,
435 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
436 | )
437 |
438 | for key in ["input_resolution", "context_length", "vocab_size"]:
439 | if key in state_dict:
440 | del state_dict[key]
441 |
442 | convert_weights(model)
443 | model.load_state_dict(state_dict)
444 | return model.eval()
445 |
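446 | # Note: this file follows OpenAI's CLIP model definition, with one AD-CLIP-specific
447 | # change: the Transformer, VisionTransformer and ModifiedResNet forward passes also
448 | # return the intermediate per-layer (or per-stage) features, which the AD-CLIP
449 | # trainers turn into style (mean/std) and content tokens for prompt learning.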
--------------------------------------------------------------------------------
/trainers/adclip_vitB16.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 | import datetime
4 | import time
5 | from collections import OrderedDict
6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from torch.cuda.amp import GradScaler, autocast
12 | from tqdm import tqdm
13 |
14 | from dassl.engine import TRAINER_REGISTRY, TrainerXU
15 | from dassl.metrics import compute_accuracy
16 | from dassl.utils import MetricMeter, AverageMeter, load_pretrained_weights, load_checkpoint, save_checkpoint
17 | from dassl.optim import build_optimizer, build_lr_scheduler
18 |
19 | from clip import clip
20 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
21 |
22 | _tokenizer = _Tokenizer()
23 |
24 |
25 | def load_clip_to_cpu(cfg):
26 | backbone_name = cfg.MODEL.BACKBONE.NAME
27 | url = clip._MODELS[backbone_name]
28 | model_path = clip._download(url, cfg.MODEL.BACKBONE.PATH)
29 |
30 |
31 | try:
32 | model = torch.jit.load(model_path, map_location="cpu").eval()
33 | state_dict = None
34 |
35 | except RuntimeError:
36 | state_dict = torch.load(model_path, map_location="cpu")
37 |
38 | model = clip.build_model(state_dict or model.state_dict())
39 |
40 | return model
41 |
42 | class AdaIN(nn.Module):
43 | def __init__(self):
44 | super().__init__()
45 |
46 |     def mu(self, x):
47 |         return torch.sum(x,(1))/(x.shape[1]) # mean over the token dimension: (B, L, D) -> (B, D)
48 |
49 |     def sigma(self, x):
50 |         return torch.sqrt((torch.sum((x.permute([1,0,2])-self.mu(x)).permute([1,0,2])**2,(1))+2.3e-8)/(x.shape[1])) # std over the token dimension; small epsilon for numerical stability
51 |
52 |
53 | class domain_projector(nn.Module):
54 | def __init__(self):
55 | super().__init__()
56 | self.linear1 = nn.ModuleList(nn.Linear(768,256) for _ in range (12))
57 | self.linear2 = nn.ModuleList(nn.Linear(256,512) for _ in range (12))
58 | self.adain=AdaIN()
59 | self.gap=nn.AdaptiveAvgPool2d((1,768))
60 | def forward(self, data):
61 | data_prompt=[]
62 | for i in range(len(data)):
63 | x_mu=self.adain.mu(data[i]).unsqueeze(1).to(torch.float32)
64 | x_sigma=self.adain.sigma(data[i]).unsqueeze(1).to(torch.float32)
65 | x_cat = torch.cat((x_mu, x_sigma),1)
66 | x_cat = self.gap(x_cat).squeeze(1)
67 | x_out = self.linear1[i](x_cat)
68 | x_final = self.linear2[i](x_out)
69 | data_prompt.append(x_final)
70 | output = torch.stack(data_prompt, dim=1)
71 | return output
72 |
73 | class image_projector(nn.Module):
74 | def __init__(self):
75 | super().__init__()
76 | self.linear = nn.ModuleList(nn.Linear(768,512) for _ in range (12))
77 | self.adain=AdaIN()
78 | self.lin = nn.Linear(12,1)
79 | self.gap=nn.AdaptiveAvgPool2d((1,768))
80 |
81 | def forward(self, data, n_imgctx):
82 | data_prompt=[]
83 | for i in range(len(data)):
84 | x_gap = self.gap(data[i]).squeeze(1)
85 | x_lin=self.linear[i](x_gap)
86 | data_prompt.append(x_lin)
87 | feat = torch.stack(data_prompt, dim=1)
88 | output = []
89 | for i in range(n_imgctx): # L decoders
90 | x = self.lin(feat.permute(0,2,1))
91 | x = x.permute(0,2,1)
92 | output.append(x)
93 | feat_tokens = torch.stack(output, dim=1).squeeze(2)
94 | return feat_tokens
95 |
96 | class style_mapping_projector(nn.Module):
97 | def __init__(self):
98 | super().__init__()
99 | self.linear1 = nn.ModuleList(nn.Linear(768,384) for _ in range (12))
100 | self.linear2 = nn.ModuleList(nn.Linear(384,512) for _ in range (12))
101 | self.adain=AdaIN()
102 | self.relu = nn.ReLU()
103 | self.gap=nn.AdaptiveAvgPool1d((768))
104 | def forward(self, data):
105 | data_prompt=[]
106 | for i in range(len(data)):
107 | x_mu=self.adain.mu(data[i]).to(torch.float32)
108 | x_sigma=self.adain.sigma(data[i]).to(torch.float32)
109 | x_cat = torch.cat((x_mu, x_sigma),1)
110 | x_gap = self.gap(x_cat)
111 | x_out = self.linear1[i](x_gap)
112 | x_relu = self.relu(x_out)
113 | x_final = self.linear2[i](x_relu)
114 | data_prompt.append(x_final)
115 | output = torch.stack(data_prompt, dim=1)
116 | return output
117 |
118 | class TextEncoder(nn.Module):
119 | def __init__(self, clip_model):
120 | super().__init__()
121 | self.transformer = clip_model.transformer
122 | self.positional_embedding = clip_model.positional_embedding
123 | self.ln_final = clip_model.ln_final
124 | self.text_projection = clip_model.text_projection
125 | self.dtype = clip_model.dtype
126 |
127 | @autocast()
128 | def forward(self, prompts, tokenized_prompts):
129 | x = prompts + self.positional_embedding.type(self.dtype)
130 |         x = x.permute(1, 0, 2)  # NLD -> LND
131 |         x = self.transformer(x)
132 |
133 |         x = x[0].permute(1, 0, 2)  # LND -> NLD
134 | x = self.ln_final(x).type(self.dtype)
135 | x = x[torch.arange(x.shape[0]),
136 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection
137 |
138 | return x
139 |
140 |
141 | class PromptLearner(nn.Module):
142 | def __init__(self, cfg, classnames, clip_model):
143 | super().__init__()
144 | n_cls = len(classnames)
145 | n_imgctx = 4
146 | n_ctx = 24 + n_imgctx
147 |
148 | dtype = clip_model.dtype
149 | clip_imsize = clip_model.visual.input_resolution
150 | cfg_imsize = cfg.INPUT.SIZE[0]
151 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
152 |
153 | self.domain_tokens = domain_projector()
154 | self.image_tokens = image_projector()
155 | self.style_mapping_tokens = style_mapping_projector()
156 |
157 | prompt_prefix = " ".join(["X"] * n_ctx)
158 | classnames = [name.replace("_", " ") for name in classnames]
159 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
160 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
161 |
162 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
163 | with torch.no_grad():
164 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
165 |
166 |         # These token vectors will be saved by save_model(),
167 | # but they should be ignored in load_model() as we want to use
168 | # those computed using the current class names
169 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
170 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
171 |
172 | self.n_cls = n_cls
173 | self.n_ctx = n_ctx
174 | self.n_imgctx = n_imgctx
175 | self.tokenized_prompts = tokenized_prompts
176 | self.name_lens = name_lens
177 |
178 | def construct_prompts(self, ctx, prefix, suffix, label=None):
179 | # dim0 is either batch_size (during training) or n_cls (during testing)
180 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
181 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
182 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
183 |
184 | if label is not None:
185 | prefix = prefix[label]
186 | suffix = suffix[label]
187 |
188 |
189 | prompts = torch.cat(
190 | [
191 | prefix,
192 | ctx,
193 | suffix,
194 | ],
195 | dim=1,
196 | )
197 |
198 | return prompts
199 | @autocast()
200 | def forward(self, source_data, target_data):
201 | prefix = self.token_prefix
202 | suffix = self.token_suffix
203 | n_imgctx = self.n_imgctx
204 |
205 | source_domaintokens = self.domain_tokens(source_data)
206 | source_imagetokens = self.image_tokens(source_data, n_imgctx)
207 | source_style_mappingtokens = self.style_mapping_tokens(source_data)
208 |
209 | target_domaintokens = self.domain_tokens(target_data)
210 | target_imagetokens = self.image_tokens(target_data, n_imgctx)
211 |
212 | source_tokens = torch.cat((source_domaintokens, target_domaintokens, source_imagetokens), dim=1)
213 | target_tokens = torch.cat((source_domaintokens, target_domaintokens, target_imagetokens), dim=1)
214 |
215 | source_prompts = []
216 | for tokens_i in source_tokens:
217 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1)
218 | pts_i = self.construct_prompts(ctx_i, prefix, suffix)
219 | source_prompts.append(pts_i)
220 | source_prompts = torch.stack(source_prompts)
221 |
222 | target_prompts = []
223 | for tokens_i in target_tokens:
224 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1)
225 | pts_i = self.construct_prompts(ctx_i, prefix, suffix)
226 | target_prompts.append(pts_i)
227 | target_prompts = torch.stack(target_prompts)
228 |
229 | return source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens
230 |
231 |
232 | class CustomCLIP(nn.Module):
233 | def __init__(self, cfg, classnames, clip_model):
234 | super().__init__()
235 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
236 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
237 | self.image_encoder = clip_model.visual
238 | self.text_encoder = TextEncoder(clip_model)
239 | self.logit_scale = clip_model.logit_scale
240 | self.dtype = clip_model.dtype
241 |
242 | @autocast()
243 | def forward(self, s_image, t_image):
244 | source_image_features, source_data = self.image_encoder(s_image.type(self.dtype))
245 | target_image_features, target_data = self.image_encoder(t_image.type(self.dtype))
246 |
247 | source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens = self.prompt_learner(source_data, target_data)
248 | tokenized_prompts = self.tokenized_prompts
249 |
250 | source_image_features = source_image_features / source_image_features.norm(dim=-1,
251 | keepdim=True)
252 | target_image_features = target_image_features / target_image_features.norm(dim=-1,
253 | keepdim=True)
254 | logit_scale = self.logit_scale.exp()
255 |
256 | source_text_features = []
257 | for pts_i in source_prompts:
258 | tf = self.text_encoder(pts_i, tokenized_prompts)
259 | source_text_features.append(tf)
260 | source_text_features=torch.stack(source_text_features)
261 | source_text_features = source_text_features / source_text_features.norm(dim=-1, keepdim=True)
262 |
263 | target_text_features = []
264 | for pts_i in target_prompts:
265 | tf = self.text_encoder(pts_i, tokenized_prompts)
266 | target_text_features.append(tf)
267 | target_text_features=torch.stack(target_text_features)
268 | target_text_features = target_text_features / target_text_features.norm(dim=-1, keepdim=True)
269 |
270 |
271 | source_logits = []
272 |
273 | for txt, im in zip(source_text_features, source_image_features):
274 | l_i = logit_scale * im @ txt.t()
275 | source_logits.append(l_i)
276 | source_logits = torch.stack(source_logits)
277 |
278 | target_logits = []
279 |
280 | for txt, im in zip(target_text_features, target_image_features):
281 | l_i = logit_scale * im @ txt.t()
282 | target_logits.append(l_i)
283 | target_logits = torch.stack(target_logits)
284 |
285 | target_probs = torch.nn.functional.softmax(target_logits, dim=1)
286 |
287 | return source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features
288 |
289 |
290 | class entropy_loss(nn.Module):
291 | def __init__(self):
292 | super(entropy_loss, self).__init__()
293 |
294 | def forward(self, target_prob):
295 | full_enp = torch.zeros(target_prob.shape[0])
296 | target_prob = nn.functional.normalize(target_prob, dim=0)
297 |
298 | for i in range(len(target_prob)):
299 | total_en = 0
300 | for j in range(target_prob.shape[1]):
301 | total_en = total_en - target_prob[i][j] * torch.log(target_prob[i][j] + 1e-8)
302 | full_enp[i] = total_en
303 | avg_full_enp = torch.mean(full_enp)
304 | return avg_full_enp
305 |
306 |
307 | @TRAINER_REGISTRY.register()
308 | class ADCLIPB16(TrainerXU):
309 | def check_cfg(self, cfg):
310 | assert cfg.TRAINER.ADCLIPB16.PREC in ["fp16", "fp32", "amp"]
311 |
312 | def build_model(self):
313 | cfg = self.cfg
314 | classnames = self.dm.dataset.classnames
315 |
316 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
317 | clip_model = load_clip_to_cpu(cfg)
318 |
319 | if cfg.TRAINER.ADCLIPB16.PREC == "fp32" or cfg.TRAINER.ADCLIPB16.PREC == "amp":
320 | # CLIP's default precision is fp16
321 | clip_model.float()
322 |
323 | print("Building custom CLIP")
324 | self.model = CustomCLIP(cfg, classnames, clip_model)
325 |
326 | self.n_cls = self.model.prompt_learner.n_cls
327 |
328 | name_to_update = "prompt_learner"
329 |
330 | for name, param in self.model.named_parameters():
331 | if name_to_update not in name:
332 | param.requires_grad_(False)
333 |
334 | # Double check
335 | enabled = set()
336 | for name, param in self.model.named_parameters():
337 | if param.requires_grad:
338 | enabled.add(name)
339 | print(f"Parameters to be updated: {enabled}")
340 |
341 |
342 | if cfg.MODEL.INIT_WEIGHTS:
343 | load_pretrained_weights(self.model.prompt_learner,
344 | cfg.MODEL.INIT_WEIGHTS)
345 |
346 | self.model.to(self.device)
347 |
348 | # transform the epoch to step schedule
349 | len_train_loader_x = len(self.train_loader_x)
350 | len_train_loader_u = len(self.train_loader_u)
351 | if self.cfg.TRAIN.COUNT_ITER == "train_x":
352 | self.num_batches = len_train_loader_x
353 | elif self.cfg.TRAIN.COUNT_ITER == "train_u":
354 | self.num_batches = len_train_loader_u
355 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
356 | self.num_batches = min(len_train_loader_x, len_train_loader_u)
357 | else:
358 | raise ValueError
359 |
360 | # NOTE: only give prompt_learner to the optimizer
361 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
362 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
363 |         '''
364 |         Registered models can be updated; when a new module needs to be
365 |         updated, register it before use.
366 |         '''
367 | self.register_model("prompt_learner", self.model.prompt_learner,
368 | self.optim, self.sched)
369 |
370 | self.scaler = GradScaler() if cfg.TRAINER.ADCLIPB16.PREC == "amp" else None
371 |
372 | device_count = torch.cuda.device_count()
373 | if device_count > 1:
374 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
375 | self.model = nn.DataParallel(self.model)
376 |
377 | def save_model(self, epoch, directory, is_best=False, model_name=""):
378 | names = self.get_model_names()
379 |
380 | for name in names:
381 | model_dict = self._models[name].state_dict()
382 |
383 | optim_dict = None
384 | if self._optims[name] is not None:
385 | optim_dict = self._optims[name].state_dict()
386 |
387 | sched_dict = None
388 | if self._scheds[name] is not None:
389 | sched_dict = self._scheds[name].state_dict()
390 |
391 | save_checkpoint(
392 | {
393 | "state_dict": model_dict,
394 | "epoch": epoch + 1,
395 | "optimizer": optim_dict,
396 | "scheduler": sched_dict,
397 | },
398 | osp.join(directory, name),
399 | is_best=is_best,
400 | model_name=model_name,
401 | )
402 |
403 | def train(self):
404 | """Generic training loops."""
405 |
406 | self.before_train()
407 | for self.epoch in range(self.start_epoch, self.max_epoch):
408 | self.before_epoch()
409 | self.run_epoch()
410 | self.after_epoch()
411 | self.after_train()
412 |
413 | def run_epoch(self):
414 | self.set_model_mode("train")
415 | losses = MetricMeter()
416 | batch_time = AverageMeter()
417 | data_time = AverageMeter()
418 |
419 |         # Decide whether to iterate over the labeled or unlabeled dataset
420 | len_train_loader_x = len(self.train_loader_x)
421 | len_train_loader_u = len(self.train_loader_u)
422 | if self.cfg.TRAIN.COUNT_ITER == "train_x":
423 | self.num_batches = len_train_loader_x
424 | elif self.cfg.TRAIN.COUNT_ITER == "train_u":
425 | self.num_batches = len_train_loader_u
426 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
427 | self.num_batches = min(len_train_loader_x, len_train_loader_u)
428 | else:
429 | raise ValueError
430 |
431 | train_loader_x_iter = iter(self.train_loader_x)
432 | train_loader_u_iter = iter(self.train_loader_u)
433 |
434 |
435 | end = time.time()
436 | for self.batch_idx in range(self.num_batches):
437 | try:
438 | batch_x = next(train_loader_x_iter)
439 | except StopIteration:
440 | train_loader_x_iter = iter(self.train_loader_x)
441 | batch_x = next(train_loader_x_iter)
442 |
443 | try:
444 | batch_u = next(train_loader_u_iter)
445 | except StopIteration:
446 | train_loader_u_iter = iter(self.train_loader_u)
447 | batch_u = next(train_loader_u_iter)
448 |
449 | data_time.update(time.time() - end)
450 | loss_summary = self.forward_backward(batch_x, batch_u)
451 | batch_time.update(time.time() - end)
452 | losses.update(loss_summary)
453 |
454 | if (
455 | self.batch_idx + 1
456 | ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
457 | nb_remain = 0
458 | nb_remain += self.num_batches - self.batch_idx - 1
459 | nb_remain += (self.max_epoch - self.epoch -
460 | 1) * self.num_batches
461 | eta_seconds = batch_time.avg * nb_remain
462 | eta = str(datetime.timedelta(seconds=int(eta_seconds)))
463 | print("epoch [{0}/{1}][{2}/{3}]\t"
464 | "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
465 | "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
466 | "eta {eta}\t"
467 | "{losses}\t"
468 | "lr {lr:.6e}".format(
469 | self.epoch + 1,
470 | self.max_epoch,
471 | self.batch_idx + 1,
472 | self.num_batches,
473 | batch_time=batch_time,
474 | data_time=data_time,
475 | eta=eta,
476 | losses=losses,
477 | lr=self.get_current_lr(),
478 | ))
479 |
480 | n_iter = self.epoch * self.num_batches + self.batch_idx
481 | for name, meter in losses.meters.items():
482 | self.write_scalar("train/" + name, meter.avg, n_iter)
483 | self.write_scalar("train/lr", self.get_current_lr(), n_iter)
484 |
485 | end = time.time()
486 |
487 | def forward_backward(self, batch_x, batch_u):
488 | self.entropy = entropy_loss()
489 | kl_loss = nn.KLDivLoss(reduction="batchmean")
490 | image_x, label, image_u = self.parse_batch_train(batch_x, batch_u)
491 | prec = self.cfg.TRAINER.ADCLIPB16.PREC
492 | if prec == "amp":
493 | with autocast():
494 | source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model(image_x, image_u)
495 |
496 | loss_ce = F.cross_entropy(source_logits, label)
497 | source_textfeat = F.log_softmax(source_text_features, dim=1)
498 | target_textfeat = F.softmax(target_text_features, dim=1)
499 | loss_kl = kl_loss(source_textfeat, target_textfeat)
500 | loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens)
501 | loss_entropy = self.entropy(target_probs)
502 |
503 | loss = loss_ce + 0.1*loss_smn + 0.01*loss_entropy + loss_kl
504 |
505 | self.optim.zero_grad()
506 | self.scaler.scale(loss).backward()
507 | self.scaler.step(self.optim)
508 | self.scaler.update()
509 |
510 |
511 | loss_summary = {
512 | "loss":
513 | loss.item(),
514 | "loss_ce":
515 | loss_ce.item(),
516 | "loss_smn":
517 | loss_smn.item(),
518 | "loss_entropy":
519 | loss_entropy.item(),
520 | "loss_kl":
521 | loss_kl.item(),
522 | "acc_x":
523 | compute_accuracy(source_logits[:, :self.n_cls], label)[0].item(),
524 | }
525 |
526 | self.update_lr()
527 |
528 | return loss_summary
529 |
530 | def after_epoch(self):
531 | last_epoch = (self.epoch + 1) == self.max_epoch
532 | do_test = not self.cfg.TEST.NO_TEST
533 | meet_checkpoint_freq = ((self.epoch + 1) %
534 | self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if
535 | self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False)
536 |
537 | if do_test:
538 | curr_result = self.test()
539 | is_best = curr_result > self.best_result
540 | if is_best:
541 | self.best_result = curr_result
542 | self.save_model(self.epoch,
543 | self.output_dir,
544 | model_name="model-best.pth.tar")
545 |
546 | self.set_model_mode("train")
547 |
548 | if meet_checkpoint_freq or last_epoch:
549 | self.save_model(self.epoch, self.output_dir)
550 |
551 | def parse_batch_train(self, batch_x, batch_u):
552 | input = batch_x["img"]
553 | label = batch_x["label"]
554 | input_u = batch_u["img"]
555 | input = input.to(self.device)
556 | label = label.to(self.device)
557 | input_u = input_u.to(self.device)
558 | return input, label, input_u
559 |
560 | def load_model(self, directory, epoch=None):
561 | if not directory:
562 | print(
563 | "Note that load_model() is skipped as no pretrained model is given"
564 | )
565 | return
566 |
567 | names = self.get_model_names()
568 |
569 | # By default, the best model is loaded
570 | model_file = "model-best.pth.tar"
571 |
572 | if epoch is not None:
573 | model_file = "model.pth.tar-" + str(epoch)
574 |
575 | for name in names:
576 | model_path = osp.join(directory, name, model_file)
577 |
578 | if not osp.exists(model_path):
579 | raise FileNotFoundError(
580 | 'Model not found at "{}"'.format(model_path))
581 |
582 | checkpoint = load_checkpoint(model_path)
583 | state_dict = checkpoint["state_dict"]
584 | epoch = checkpoint["epoch"]
585 |
586 | # Ignore fixed token vectors
587 | if "token_prefix" in state_dict:
588 | del state_dict["token_prefix"]
589 |
590 | if "token_suffix" in state_dict:
591 | del state_dict["token_suffix"]
592 |
593 | print("Loading weights to {} "
594 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
595 | # set strict=False
596 | self._models[name].load_state_dict(state_dict, strict=False)
597 |
598 | @torch.no_grad()
599 | def test(self, split=None):
600 | """A generic testing pipeline."""
601 | self.set_model_mode("eval")
602 | self.evaluator.reset()
603 |
604 | if split is None:
605 | split = self.cfg.TEST.SPLIT
606 |
607 | split = "test"
608 | data_loader = self.test_loader
609 | print(f"Evaluate on the *{split}* set")
610 |
611 |
612 | for batch_idx, batch in enumerate(tqdm(data_loader)):
613 | input, label = self.parse_batch_test(batch)
614 | output, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model_inference(input)
615 | self.evaluator.process(output, label)
616 |
617 | results = self.evaluator.evaluate()
618 |
619 | for k, v in results.items():
620 | tag = f"{split}/{k}"
621 | self.write_scalar(tag, v, self.epoch)
622 |
623 | return list(results.values())[0]
624 |
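
As a reference for the objective assembled in forward_backward() above, the snippet below is a minimal, self-contained sketch of the same four-term combination, with entropy_loss rewritten in an equivalent vectorized form. The loss weights (0.1 and 0.01) match the code above; the function name and argument list are illustrative assumptions, not symbols defined in this repository.

import torch
import torch.nn.functional as F

def adclip_objective(source_logits, target_probs, source_domaintokens,
                     source_style_mappingtokens, source_text_features,
                     target_text_features, label):
    # Supervised cross-entropy on the labeled source batch
    loss_ce = F.cross_entropy(source_logits, label)
    # KL divergence between source and target prompt-conditioned text features
    loss_kl = F.kl_div(F.log_softmax(source_text_features, dim=1),
                       F.softmax(target_text_features, dim=1),
                       reduction="batchmean")
    # Style-mapping consistency between domain tokens and the style projector output
    loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens)
    # Target prediction entropy: vectorized equivalent of the entropy_loss module above
    p = F.normalize(target_probs, dim=0)
    loss_entropy = -(p * torch.log(p + 1e-8)).sum(dim=1).mean()
    return loss_ce + 0.1 * loss_smn + 0.01 * loss_entropy + loss_kl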
--------------------------------------------------------------------------------
/trainers/adclip_vitL14.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 | import datetime
4 | import time
5 | from collections import OrderedDict
6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from torch.cuda.amp import GradScaler, autocast
12 | from tqdm import tqdm
13 |
14 | from dassl.engine import TRAINER_REGISTRY, TrainerXU
15 | from dassl.metrics import compute_accuracy
16 | from dassl.utils import MetricMeter, AverageMeter, load_pretrained_weights, load_checkpoint, save_checkpoint
17 | from dassl.optim import build_optimizer, build_lr_scheduler
18 |
19 | from clip import clip
20 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
21 |
22 | _tokenizer = _Tokenizer()
23 |
24 |
25 | def load_clip_to_cpu(cfg):
26 | backbone_name = cfg.MODEL.BACKBONE.NAME
27 | url = clip._MODELS[backbone_name]
28 | model_path = clip._download(url, cfg.MODEL.BACKBONE.PATH)
29 |
30 |
31 | try:
32 | model = torch.jit.load(model_path, map_location="cpu").eval()
33 | state_dict = None
34 |
35 | except RuntimeError:
36 | state_dict = torch.load(model_path, map_location="cpu")
37 |
38 | model = clip.build_model(state_dict or model.state_dict())
39 |
40 | return model
41 |
42 | class AdaIN(nn.Module):
43 | def __init__(self):
44 | super().__init__()
45 |
46 |     def mu(self, x):
47 |         return torch.sum(x,(1))/(x.shape[1]) # mean over the token dimension: (B, L, D) -> (B, D)
48 |
49 |     def sigma(self, x):
50 |         return torch.sqrt((torch.sum((x.permute([1,0,2])-self.mu(x)).permute([1,0,2])**2,(1))+2.3e-8)/(x.shape[1])) # std over the token dimension; small epsilon for numerical stability
51 |
52 |
53 | class domain_projector(nn.Module):
54 | def __init__(self):
55 | super().__init__()
56 | self.linear1 = nn.ModuleList(nn.Linear(1024,512) for _ in range (24))
57 | self.linear2 = nn.ModuleList(nn.Linear(512,768) for _ in range (24))
58 | self.adain=AdaIN()
59 | self.gap=nn.AdaptiveAvgPool2d((1,1024))
60 | def forward(self, data):
61 | data_prompt=[]
62 | for i in range(len(data)):
63 | x_mu=self.adain.mu(data[i]).unsqueeze(1).to(torch.float32)
64 | x_sigma=self.adain.sigma(data[i]).unsqueeze(1).to(torch.float32)
65 | x_cat = torch.cat((x_mu, x_sigma),1)
66 | x_cat = self.gap(x_cat).squeeze(1)
67 | x_out = self.linear1[i](x_cat)
68 | x_final = self.linear2[i](x_out)
69 | data_prompt.append(x_final)
70 | output = torch.stack(data_prompt, dim=1)
71 | return output
72 |
73 | class image_projector(nn.Module):
74 | def __init__(self):
75 | super().__init__()
76 | self.linear = nn.ModuleList(nn.Linear(1024,768) for _ in range (24))
77 | self.adain=AdaIN()
78 | self.lin = nn.Linear(24,1)
79 | self.gap=nn.AdaptiveAvgPool2d((1,1024))
80 |
81 | def forward(self, data, n_imgctx):
82 | data_prompt=[]
83 | for i in range(len(data)):
84 | x_gap = self.gap(data[i]).squeeze(1)
85 | x_lin=self.linear[i](x_gap)
86 | data_prompt.append(x_lin)
87 | feat = torch.stack(data_prompt, dim=1)
88 | output = []
89 | for i in range(n_imgctx): # L decoders
90 | x = self.lin(feat.permute(0,2,1))
91 | x = x.permute(0,2,1)
92 | output.append(x)
93 | feat_tokens = torch.stack(output, dim=1).squeeze(2)
94 | return feat_tokens
95 |
96 | class style_mapping_projector(nn.Module):
97 | def __init__(self):
98 | super().__init__()
99 | self.linear1 = nn.ModuleList(nn.Linear(1024,640) for _ in range (24))
100 | self.linear2 = nn.ModuleList(nn.Linear(640,768) for _ in range (24))
101 | self.adain=AdaIN()
102 | self.relu = nn.ReLU()
103 | self.gap=nn.AdaptiveAvgPool1d((1024))
104 | def forward(self, data):
105 | data_prompt=[]
106 | for i in range(len(data)):
107 | x_mu=self.adain.mu(data[i]).to(torch.float32)
108 | x_sigma=self.adain.sigma(data[i]).to(torch.float32)
109 | x_cat = torch.cat((x_mu, x_sigma),1)
110 | x_gap = self.gap(x_cat)
111 | x_out = self.linear1[i](x_gap)
112 | x_relu = self.relu(x_out)
113 | x_final = self.linear2[i](x_relu)
114 | data_prompt.append(x_final)
115 | output = torch.stack(data_prompt, dim=1)
116 | return output
117 |
118 | class TextEncoder(nn.Module):
119 | def __init__(self, clip_model):
120 | super().__init__()
121 | self.transformer = clip_model.transformer
122 | self.positional_embedding = clip_model.positional_embedding
123 | self.ln_final = clip_model.ln_final
124 | self.text_projection = clip_model.text_projection
125 | self.dtype = clip_model.dtype
126 |
127 | @autocast()
128 | def forward(self, prompts, tokenized_prompts):
129 | x = prompts + self.positional_embedding.type(self.dtype)
130 | x = x.permute(1, 0, 2) # NLD -> LND
131 | x = self.transformer(x)
132 |
133 | x = x[0].permute(1, 0, 2) # LND -> NLD
134 | x = self.ln_final(x).type(self.dtype)
135 | x = x[torch.arange(x.shape[0]),
136 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection
137 |
138 | return x
139 |
140 |
141 | class PromptLearner(nn.Module):
142 | def __init__(self, cfg, classnames, clip_model):
143 | super().__init__()
144 | n_cls = len(classnames)
145 | n_imgctx = 4
146 | n_ctx = 48 + n_imgctx
147 |
148 | dtype = clip_model.dtype
149 | ctx_dim = clip_model.ln_final.weight.shape[0]
150 | vis_dim = clip_model.visual.output_dim
151 | clip_imsize = clip_model.visual.input_resolution
152 | cfg_imsize = cfg.INPUT.SIZE[0]
153 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
154 |
155 | self.domain_tokens = domain_projector()
156 | self.image_tokens = image_projector()
157 | self.style_mapping_tokens = style_mapping_projector()
158 |
159 | prompt_prefix = " ".join(["X"] * n_ctx)
160 | classnames = [name.replace("_", " ") for name in classnames]
161 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
162 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
163 |
164 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
165 | with torch.no_grad():
166 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
167 |
168 |         # These token vectors will be saved by save_model(),
169 | # but they should be ignored in load_model() as we want to use
170 | # those computed using the current class names
171 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
172 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
173 |
174 | self.n_cls = n_cls
175 | self.n_ctx = n_ctx
176 | self.n_imgctx = n_imgctx
177 | self.tokenized_prompts = tokenized_prompts
178 | self.name_lens = name_lens
179 |
180 | def construct_prompts(self, ctx, prefix, suffix, label=None):
181 | # dim0 is either batch_size (during training) or n_cls (during testing)
182 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
183 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
184 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
185 |
186 | if label is not None:
187 | prefix = prefix[label]
188 | suffix = suffix[label]
189 |
190 |
191 | prompts = torch.cat(
192 | [
193 | prefix,
194 | ctx,
195 | suffix,
196 | ],
197 | dim=1,
198 | )
199 |
200 | return prompts
201 | @autocast()
202 | def forward(self, source_data, target_data):
203 | prefix = self.token_prefix
204 | suffix = self.token_suffix
205 | n_imgctx = self.n_imgctx
206 |
207 | source_domaintokens = self.domain_tokens(source_data)
208 | source_imagetokens = self.image_tokens(source_data, n_imgctx)
209 | source_style_mappingtokens = self.style_mapping_tokens(source_data)
210 |
211 | target_domaintokens = self.domain_tokens(target_data)
212 | target_imagetokens = self.image_tokens(target_data, n_imgctx)
213 |
214 | source_tokens = torch.cat((source_domaintokens, target_domaintokens, source_imagetokens), dim=1)
215 | target_tokens = torch.cat((source_domaintokens, target_domaintokens, target_imagetokens), dim=1)
216 |
217 | source_prompts = []
218 | for tokens_i in source_tokens:
219 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1)
220 | pts_i = self.construct_prompts(ctx_i, prefix, suffix)
221 | source_prompts.append(pts_i)
222 | source_prompts = torch.stack(source_prompts)
223 |
224 | target_prompts = []
225 | for tokens_i in target_tokens:
226 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1)
227 | pts_i = self.construct_prompts(ctx_i, prefix, suffix)
228 | target_prompts.append(pts_i)
229 | target_prompts = torch.stack(target_prompts)
230 |
231 | return source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens
232 |
233 | class CustomCLIP(nn.Module):
234 | def __init__(self, cfg, classnames, clip_model):
235 | super().__init__()
236 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
237 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
238 | self.image_encoder = clip_model.visual
239 | self.text_encoder = TextEncoder(clip_model)
240 | self.logit_scale = clip_model.logit_scale
241 | self.dtype = clip_model.dtype
242 |
243 | @autocast()
244 | def forward(self, s_image, t_image):
245 | source_image_features, source_data = self.image_encoder(s_image.type(self.dtype))
246 | target_image_features, target_data = self.image_encoder(t_image.type(self.dtype))
247 |
248 | source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens = self.prompt_learner(source_data, target_data)
249 | tokenized_prompts = self.tokenized_prompts
250 |
251 | source_image_features = source_image_features / source_image_features.norm(dim=-1,
252 | keepdim=True)
253 | target_image_features = target_image_features / target_image_features.norm(dim=-1,
254 | keepdim=True)
255 | logit_scale = self.logit_scale.exp()
256 |
257 | source_text_features = []
258 | for pts_i in source_prompts:
259 | tf = self.text_encoder(pts_i, tokenized_prompts)
260 | source_text_features.append(tf)
261 | source_text_features=torch.stack(source_text_features)
262 | source_text_features = source_text_features / source_text_features.norm(dim=-1, keepdim=True)
263 |
264 | target_text_features = []
265 | for pts_i in target_prompts:
266 | tf = self.text_encoder(pts_i, tokenized_prompts)
267 | target_text_features.append(tf)
268 | target_text_features=torch.stack(target_text_features)
269 | target_text_features = target_text_features / target_text_features.norm(dim=-1, keepdim=True)
270 |
271 |
272 | source_logits = []
273 |
274 | for txt, im in zip(source_text_features, source_image_features):
275 | l_i = logit_scale * im @ txt.t()
276 | source_logits.append(l_i)
277 | source_logits = torch.stack(source_logits)
278 |
279 | target_logits = []
280 |
281 | for txt, im in zip(target_text_features, target_image_features):
282 | l_i = logit_scale * im @ txt.t()
283 | target_logits.append(l_i)
284 | target_logits = torch.stack(target_logits)
285 |
286 | target_probs = torch.nn.functional.softmax(target_logits, dim=1)
287 |
288 | return source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features
289 |
290 |
291 | class entropy_loss(nn.Module):
292 | def __init__(self):
293 | super(entropy_loss, self).__init__()
294 |
295 | def forward(self, target_prob):
296 | full_enp = torch.zeros(target_prob.shape[0])
297 | target_prob = nn.functional.normalize(target_prob, dim=0)
298 |
299 | for i in range(len(target_prob)):
300 | total_en = 0
301 | for j in range(target_prob.shape[1]):
302 | total_en = total_en - target_prob[i][j] * torch.log(target_prob[i][j] + 1e-8)
303 | full_enp[i] = total_en
304 | avg_full_enp = torch.mean(full_enp)
305 | return avg_full_enp
306 |
307 |
308 | @TRAINER_REGISTRY.register()
309 | class ADCLIPL14(TrainerXU):
310 | def check_cfg(self, cfg):
311 | assert cfg.TRAINER.ADCLIPL14.PREC in ["fp16", "fp32", "amp"]
312 |
313 | def build_model(self):
314 | cfg = self.cfg
315 | classnames = self.dm.dataset.classnames
316 |
317 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
318 | clip_model = load_clip_to_cpu(cfg)
319 |
320 | if cfg.TRAINER.ADCLIPL14.PREC == "fp32" or cfg.TRAINER.ADCLIPL14.PREC == "amp":
321 | # CLIP's default precision is fp16
322 | clip_model.float()
323 |
324 | print("Building custom CLIP")
325 | self.model = CustomCLIP(cfg, classnames, clip_model)
326 |
327 | self.n_cls = self.model.prompt_learner.n_cls
328 |
329 | name_to_update = "prompt_learner"
330 |
331 | for name, param in self.model.named_parameters():
332 | if name_to_update not in name:
333 | param.requires_grad_(False)
334 |
335 | enabled = set()
336 | for name, param in self.model.named_parameters():
337 | if param.requires_grad:
338 | enabled.add(name)
339 | print(f"Parameters to be updated: {enabled}")
340 |
341 | if cfg.MODEL.INIT_WEIGHTS:
342 | load_pretrained_weights(self.model.prompt_learner,
343 | cfg.MODEL.INIT_WEIGHTS)
344 |
345 | self.model.to(self.device)
346 |
347 | # transform the epoch to step schedule
348 | len_train_loader_x = len(self.train_loader_x)
349 | len_train_loader_u = len(self.train_loader_u)
350 | if self.cfg.TRAIN.COUNT_ITER == "train_x":
351 | self.num_batches = len_train_loader_x
352 | elif self.cfg.TRAIN.COUNT_ITER == "train_u":
353 | self.num_batches = len_train_loader_u
354 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
355 | self.num_batches = min(len_train_loader_x, len_train_loader_u)
356 | else:
357 | raise ValueError
358 |
359 | # NOTE: only give prompt_learner to the optimizer
360 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
361 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
362 |         '''
363 |         Registered models can be updated; when a new module needs to be
364 |         updated, register it before use.
365 |         '''
366 | self.register_model("prompt_learner", self.model.prompt_learner,
367 | self.optim, self.sched)
368 |
369 | self.scaler = GradScaler() if cfg.TRAINER.ADCLIPL14.PREC == "amp" else None
370 |
371 | device_count = torch.cuda.device_count()
372 | if device_count > 1:
373 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
374 | self.model = nn.DataParallel(self.model)
375 |
376 | def save_model(self, epoch, directory, is_best=False, model_name=""):
377 | names = self.get_model_names()
378 |
379 | for name in names:
380 | model_dict = self._models[name].state_dict()
381 |
382 | optim_dict = None
383 | if self._optims[name] is not None:
384 | optim_dict = self._optims[name].state_dict()
385 |
386 | sched_dict = None
387 | if self._scheds[name] is not None:
388 | sched_dict = self._scheds[name].state_dict()
389 |
390 | save_checkpoint(
391 | {
392 | "state_dict": model_dict,
393 | "epoch": epoch + 1,
394 | "optimizer": optim_dict,
395 | "scheduler": sched_dict,
396 | },
397 | osp.join(directory, name),
398 | is_best=is_best,
399 | model_name=model_name,
400 | )
401 |
402 | def train(self):
403 | """Generic training loops."""
404 |
405 | self.before_train()
406 | for self.epoch in range(self.start_epoch, self.max_epoch):
407 | self.before_epoch()
408 | self.run_epoch()
409 | self.after_epoch()
410 | self.after_train()
411 |
412 | def run_epoch(self):
413 | self.set_model_mode("train")
414 | losses = MetricMeter()
415 | batch_time = AverageMeter()
416 | data_time = AverageMeter()
417 |
418 |         # Decide whether to iterate over the labeled or unlabeled dataset
419 | len_train_loader_x = len(self.train_loader_x)
420 | len_train_loader_u = len(self.train_loader_u)
421 | if self.cfg.TRAIN.COUNT_ITER == "train_x":
422 | self.num_batches = len_train_loader_x
423 | elif self.cfg.TRAIN.COUNT_ITER == "train_u":
424 | self.num_batches = len_train_loader_u
425 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
426 | self.num_batches = min(len_train_loader_x, len_train_loader_u)
427 | else:
428 | raise ValueError
429 |
430 | train_loader_x_iter = iter(self.train_loader_x)
431 | train_loader_u_iter = iter(self.train_loader_u)
432 |
433 |
434 | end = time.time()
435 | for self.batch_idx in range(self.num_batches):
436 | try:
437 | batch_x = next(train_loader_x_iter)
438 | except StopIteration:
439 | train_loader_x_iter = iter(self.train_loader_x)
440 | batch_x = next(train_loader_x_iter)
441 |
442 | try:
443 | batch_u = next(train_loader_u_iter)
444 | except StopIteration:
445 | train_loader_u_iter = iter(self.train_loader_u)
446 | batch_u = next(train_loader_u_iter)
447 |
448 | data_time.update(time.time() - end)
449 | loss_summary = self.forward_backward(batch_x, batch_u)
450 | batch_time.update(time.time() - end)
451 | losses.update(loss_summary)
452 |
453 | if (
454 | self.batch_idx + 1
455 | ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
456 | nb_remain = 0
457 | nb_remain += self.num_batches - self.batch_idx - 1
458 | nb_remain += (self.max_epoch - self.epoch -
459 | 1) * self.num_batches
460 | eta_seconds = batch_time.avg * nb_remain
461 | eta = str(datetime.timedelta(seconds=int(eta_seconds)))
462 | print("epoch [{0}/{1}][{2}/{3}]\t"
463 | "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
464 | "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
465 | "eta {eta}\t"
466 | "{losses}\t"
467 | "lr {lr:.6e}".format(
468 | self.epoch + 1,
469 | self.max_epoch,
470 | self.batch_idx + 1,
471 | self.num_batches,
472 | batch_time=batch_time,
473 | data_time=data_time,
474 | eta=eta,
475 | losses=losses,
476 | lr=self.get_current_lr(),
477 | ))
478 |
479 | n_iter = self.epoch * self.num_batches + self.batch_idx
480 | for name, meter in losses.meters.items():
481 | self.write_scalar("train/" + name, meter.avg, n_iter)
482 | self.write_scalar("train/lr", self.get_current_lr(), n_iter)
483 |
484 | end = time.time()
485 |
486 | def forward_backward(self, batch_x, batch_u):
487 | self.entropy = entropy_loss()
488 | kl_loss = nn.KLDivLoss(reduction="batchmean")
489 | image_x, label, image_u = self.parse_batch_train(batch_x, batch_u)
490 | prec = self.cfg.TRAINER.ADCLIPL14.PREC
491 | # alpha_wt = self.alpha
492 | if prec == "amp":
493 | with autocast():
494 | source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model(image_x, image_u)
495 |
496 | loss_ce = F.cross_entropy(source_logits, label)
497 | source_textfeat = F.log_softmax(source_text_features, dim=1)
498 | target_textfeat = F.softmax(target_text_features, dim=1)
499 | loss_kl = kl_loss(source_textfeat, target_textfeat)
500 | loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens)
501 | loss_entropy = self.entropy(target_probs)
502 |
503 | loss = loss_ce + 0.1*loss_smn + 0.01*loss_entropy + loss_kl
504 |
505 | self.optim.zero_grad()
506 | self.scaler.scale(loss).backward()
507 | self.scaler.step(self.optim)
508 | self.scaler.update()
509 |
510 |
511 | loss_summary = {
512 | "loss":
513 | loss.item(),
514 | "loss_ce":
515 | loss_ce.item(),
516 | "loss_smn":
517 | loss_smn.item(),
518 | "loss_entropy":
519 | loss_entropy.item(),
520 | "loss_kl":
521 | loss_kl.item(),
522 | "acc_x":
523 | compute_accuracy(source_logits[:, :self.n_cls], label)[0].item(),
524 | }
525 |
526 | self.update_lr()
527 |
528 | return loss_summary
529 |
530 | def after_epoch(self):
531 | last_epoch = (self.epoch + 1) == self.max_epoch
532 | do_test = not self.cfg.TEST.NO_TEST
533 | meet_checkpoint_freq = ((self.epoch + 1) %
534 | self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if
535 | self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False)
536 |
537 | if do_test:
538 | curr_result = self.test()
539 | is_best = curr_result > self.best_result
540 | if is_best:
541 | self.best_result = curr_result
542 | self.save_model(self.epoch,
543 | self.output_dir,
544 | model_name="model-best.pth.tar")
545 |
546 | self.set_model_mode("train")
547 |
548 | if meet_checkpoint_freq or last_epoch:
549 | self.save_model(self.epoch, self.output_dir)
550 |
551 | def parse_batch_train(self, batch_x, batch_u):
552 | input = batch_x["img"]
553 | label = batch_x["label"]
554 | input_u = batch_u["img"]
555 | input = input.to(self.device)
556 | label = label.to(self.device)
557 | input_u = input_u.to(self.device)
558 | return input, label, input_u
559 |
560 | def load_model(self, directory, epoch=None):
561 | if not directory:
562 | print(
563 | "Note that load_model() is skipped as no pretrained model is given"
564 | )
565 | return
566 |
567 | names = self.get_model_names()
568 |
569 | # By default, the best model is loaded
570 | model_file = "model-best.pth.tar"
571 |
572 | if epoch is not None:
573 | model_file = "model.pth.tar-" + str(epoch)
574 |
575 | for name in names:
576 | model_path = osp.join(directory, name, model_file)
577 |
578 | if not osp.exists(model_path):
579 | raise FileNotFoundError(
580 | 'Model not found at "{}"'.format(model_path))
581 |
582 | checkpoint = load_checkpoint(model_path)
583 | state_dict = checkpoint["state_dict"]
584 | epoch = checkpoint["epoch"]
585 |
586 | # Ignore fixed token vectors
587 | if "token_prefix" in state_dict:
588 | del state_dict["token_prefix"]
589 |
590 | if "token_suffix" in state_dict:
591 | del state_dict["token_suffix"]
592 |
593 | print("Loading weights to {} "
594 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
595 | # set strict=False
596 | self._models[name].load_state_dict(state_dict, strict=False)
597 |
598 | @torch.no_grad()
599 | def test(self, split=None):
600 | """A generic testing pipeline."""
601 | self.set_model_mode("eval")
602 | self.evaluator.reset()
603 |
604 | if split is None:
605 | split = self.cfg.TEST.SPLIT
606 |
607 | split = "test"
608 | data_loader = self.test_loader
609 | print(f"Evaluate on the *{split}* set")
610 |
611 |
612 | for batch_idx, batch in enumerate(tqdm(data_loader)):
613 | input, label = self.parse_batch_test(batch)
614 | output, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model_inference(input)
615 | self.evaluator.process(output, label)
616 |
617 | results = self.evaluator.evaluate()
618 |
619 | for k, v in results.items():
620 | tag = f"{split}/{k}"
621 | self.write_scalar(tag, v, self.epoch)
622 |
623 | return list(results.values())[0]
624 |
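
The three trainer files (adclip_vitB16.py above, adclip_vitL14.py here, and adclip_rn50.py below) implement the same training recipe and differ mainly in backbone-specific dimensions. The summary below is read off the projector and PromptLearner definitions in those files; it is a reference note for readers, not a configuration object consumed anywhere in this code.

# Backbone-specific settings as hard-coded in the three trainers:
BACKBONE_SETTINGS = {
    "ViT-B/16": {"encoder_layers": 12, "vision_width": 768,
                 "text_width": 512, "n_ctx": 24 + 4},   # adclip_vitB16.py
    "ViT-L/14": {"encoder_layers": 24, "vision_width": 1024,
                 "text_width": 768, "n_ctx": 48 + 4},   # adclip_vitL14.py
    "RN50":     {"stage_widths": (256, 512, 1024, 2048),
                 "text_width": 512, "n_ctx": 8 + 4},    # adclip_rn50.py
}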
--------------------------------------------------------------------------------
/trainers/adclip_rn50.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 | import datetime
4 | import time
5 | from collections import OrderedDict
6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from torch.cuda.amp import GradScaler, autocast
12 | from tqdm import tqdm
13 |
14 | from dassl.engine import TRAINER_REGISTRY, TrainerXU
15 | from dassl.metrics import compute_accuracy
16 | from dassl.utils import MetricMeter, AverageMeter, load_pretrained_weights, load_checkpoint, save_checkpoint
17 | from dassl.optim import build_optimizer, build_lr_scheduler
18 |
19 | from clip import clip
20 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
21 |
22 | _tokenizer = _Tokenizer()
23 |
24 | device_cuda = "cuda"
25 |
26 |
27 | def load_clip_to_cpu(cfg):
28 | backbone_name = cfg.MODEL.BACKBONE.NAME
29 | url = clip._MODELS[backbone_name]
30 | model_path = clip._download(url, cfg.MODEL.BACKBONE.PATH)
31 |
32 |
33 | try:
34 | model = torch.jit.load(model_path, map_location="cpu").eval()
35 | state_dict = None
36 |
37 | except RuntimeError:
38 | state_dict = torch.load(model_path, map_location="cpu")
39 |
40 | model = clip.build_model(state_dict or model.state_dict())
41 |
42 | return model
43 |
44 | class AdaIN(nn.Module):
45 | def __init__(self):
46 | super().__init__()
47 | def mu(self, x):
48 | return torch.mean(x, dim=(2, 3))
49 |
50 | def sigma(self, x):
51 | mean = torch.mean(x, dim=(2, 3), keepdim=True)
52 | squared_diff = (x - mean) ** 2
53 | sum_squared_diff = torch.sum(squared_diff, dim=(2, 3))
54 | epsilon = 1e-8
55 | std_dev = torch.sqrt((sum_squared_diff + epsilon) / (x.shape[2] * x.shape[3]))
56 | return std_dev
57 |
58 | class domain_projector(nn.Module):
59 | def __init__(self):
60 | super().__init__()
61 |         # nn.ModuleList registers these layers as submodules so they are trained and moved with the model (plain Python lists are not)
62 |         self.linear1 = nn.ModuleList([
63 |             nn.Linear(256,256), nn.Linear(512,256),
64 |             nn.Linear(1024,256), nn.Linear(2048,256)])
65 |         self.adain=AdaIN()
66 |         self.gap = nn.ModuleList([
67 |             nn.AdaptiveAvgPool2d((1,256)),
68 |             nn.AdaptiveAvgPool2d((1,512)),
69 |             nn.AdaptiveAvgPool2d((1,1024)),
70 |             nn.AdaptiveAvgPool2d((1,2048))])
71 |
72 | self.linear2 = nn.ModuleList(nn.Linear(256,512) for _ in range (4))
73 | def forward(self, data):
74 | data_prompt=[]
75 | for i in range(len(data)):
76 | x_mu=self.adain.mu(data[i]).unsqueeze(1).to(torch.float32)
77 | x_sigma=self.adain.sigma(data[i]).unsqueeze(1).to(torch.float32)
78 | x_cat = torch.cat((x_mu, x_sigma),1)
79 | x_cat = self.gap[i](x_cat).squeeze(1)
80 | x_out = self.linear1[i](x_cat)
81 | x_final = self.linear2[i](x_out)
82 | data_prompt.append(x_final)
83 | output = torch.stack(data_prompt, dim=1)
84 | return output
85 |
86 | class image_projector(nn.Module):
87 | def __init__(self):
88 | super().__init__()
89 |         # nn.ModuleList registers these layers as submodules so they are trained and moved with the model (plain Python lists are not)
90 |         self.linear = nn.ModuleList([
91 |             nn.Linear(256,512), nn.Linear(512,512),
92 |             nn.Linear(1024,512), nn.Linear(2048,512),
93 |         ])
94 | self.adain=AdaIN()
95 | self.lin = nn.Linear(4,1)
96 | self.gap=nn.AdaptiveAvgPool2d((1,1))
97 |
98 | def forward(self, data, n_imgctx):
99 | data_prompt=[]
100 | for i in range(len(data)):
101 | x_gap = self.gap(data[i]).squeeze(3).squeeze(2)
102 | x_lin=self.linear[i](x_gap)
103 | data_prompt.append(x_lin)
104 | feat = torch.stack(data_prompt, dim=1)
105 | output = []
106 | for i in range(n_imgctx): # L decoders
107 | x = self.lin(feat.permute(0,2,1))
108 | x = x.permute(0,2,1)
109 | output.append(x)
110 | feat_tokens = torch.stack(output, dim=1).squeeze(2)
111 | return feat_tokens
112 |
113 | class style_mapping_projector(nn.Module):
114 | def __init__(self):
115 | super().__init__()
116 |         # nn.ModuleList registers these layers as submodules so they are trained and moved with the model (plain Python lists are not)
117 |         self.linear1 = nn.ModuleList([
118 |             nn.Linear(256,384), nn.Linear(512,384),
119 |             nn.Linear(1024,384), nn.Linear(2048,384)])
120 |         self.adain=AdaIN()
121 |         self.relu = nn.ReLU()
122 |         self.gap = nn.ModuleList([
123 |             nn.AdaptiveAvgPool1d((256)),
124 |             nn.AdaptiveAvgPool1d((512)),
125 |             nn.AdaptiveAvgPool1d((1024)),
126 |             nn.AdaptiveAvgPool1d((2048))])
127 |
128 | self.linear2 = nn.ModuleList(nn.Linear(384,512) for _ in range (4))
129 | def forward(self, data):
130 | data_prompt=[]
131 | for i in range(len(data)):
132 | x_mu=self.adain.mu(data[i]).to(torch.float32)
133 | x_sigma=self.adain.sigma(data[i]).to(torch.float32)
134 | x_cat = torch.cat((x_mu, x_sigma),1)
135 | x_gap = self.gap[i](x_cat)
136 | x_out = self.linear1[i](x_gap)
137 | x_relu = self.relu(x_out)
138 | x_final = self.linear2[i](x_relu)
139 | data_prompt.append(x_final)
140 | output = torch.stack(data_prompt, dim=1)
141 | return output
142 |
143 | class TextEncoder(nn.Module):
144 | def __init__(self, clip_model):
145 | super().__init__()
146 | self.transformer = clip_model.transformer
147 | self.positional_embedding = clip_model.positional_embedding
148 | self.ln_final = clip_model.ln_final
149 | self.text_projection = clip_model.text_projection
150 | self.dtype = clip_model.dtype
151 |
152 | @autocast()
153 | def forward(self, prompts, tokenized_prompts):
154 | x = prompts + self.positional_embedding.type(self.dtype)
155 | x = x.permute(1, 0, 2) # NLD -> LND
156 | x = self.transformer(x)
157 |
158 | x = x[0].permute(1, 0, 2) # LND -> NLD
159 | x = self.ln_final(x).type(self.dtype)
160 | x = x[torch.arange(x.shape[0]),
161 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection
162 |
163 | return x
164 |
165 |
166 | class PromptLearner(nn.Module):
167 | def __init__(self, cfg, classnames, clip_model):
168 | super().__init__()
169 | n_cls = len(classnames)
170 | n_imgctx = 4
171 | n_ctx = 8 + n_imgctx
172 |
173 | dtype = clip_model.dtype
174 | ctx_dim = clip_model.ln_final.weight.shape[0]
175 | vis_dim = clip_model.visual.output_dim
176 | clip_imsize = clip_model.visual.input_resolution
177 | cfg_imsize = cfg.INPUT.SIZE[0]
178 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
179 |
180 | self.domain_tokens = domain_projector()
181 | self.image_tokens = image_projector()
182 | self.style_mapping_tokens = style_mapping_projector()
183 |
184 | prompt_prefix = " ".join(["X"] * n_ctx)
185 | classnames = [name.replace("_", " ") for name in classnames]
186 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
187 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
188 |
189 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
190 | with torch.no_grad():
191 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
192 |
193 |         # These token vectors will be saved by save_model(),
194 | # but they should be ignored in load_model() as we want to use
195 | # those computed using the current class names
196 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
197 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
198 |
199 | self.n_cls = n_cls
200 | self.n_ctx = n_ctx
201 | self.n_imgctx = n_imgctx
202 | self.tokenized_prompts = tokenized_prompts
203 | self.name_lens = name_lens
204 |
205 | def construct_prompts(self, ctx, prefix, suffix, label=None):
206 | # dim0 is either batch_size (during training) or n_cls (during testing)
207 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
208 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
209 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
210 |
211 | if label is not None:
212 | prefix = prefix[label]
213 | suffix = suffix[label]
214 |
215 |
216 | prompts = torch.cat(
217 | [
218 | prefix,
219 | ctx,
220 | suffix,
221 | ],
222 | dim=1,
223 | )
224 |
225 | return prompts
226 | @autocast()
227 | def forward(self, source_data, target_data):
228 | prefix = self.token_prefix
229 | suffix = self.token_suffix
230 | n_imgctx = self.n_imgctx
231 |
232 | source_domaintokens = self.domain_tokens(source_data)
233 | source_imagetokens = self.image_tokens(source_data, n_imgctx)
234 | source_style_mappingtokens = self.style_mapping_tokens(source_data)
235 |
236 | target_domaintokens = self.domain_tokens(target_data)
237 | target_imagetokens = self.image_tokens(target_data, n_imgctx)
238 |
239 | source_tokens = torch.cat((source_domaintokens, target_domaintokens, source_imagetokens), dim=1)
240 | target_tokens = torch.cat((source_domaintokens, target_domaintokens, target_imagetokens), dim=1)
241 |
242 | source_prompts = []
243 | for tokens_i in source_tokens:
244 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1)
245 | pts_i = self.construct_prompts(ctx_i, prefix, suffix)
246 | source_prompts.append(pts_i)
247 | source_prompts = torch.stack(source_prompts)
248 |
249 | target_prompts = []
250 | for tokens_i in target_tokens:
251 | ctx_i = tokens_i.unsqueeze(0).expand(self.n_cls, -1, -1)
252 | pts_i = self.construct_prompts(ctx_i, prefix, suffix)
253 | target_prompts.append(pts_i)
254 | target_prompts = torch.stack(target_prompts)
255 |
256 | return source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens
257 |
258 | class CustomCLIP(nn.Module):
259 | def __init__(self, cfg, classnames, clip_model):
260 | super().__init__()
261 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
262 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
263 | self.image_encoder = clip_model.visual
264 | self.text_encoder = TextEncoder(clip_model)
265 | self.logit_scale = clip_model.logit_scale
266 | self.dtype = clip_model.dtype
267 |
268 | @autocast()
269 | def forward(self, s_image, t_image):
270 | source_image_features, source_data = self.image_encoder(s_image.type(self.dtype))
271 | target_image_features, target_data = self.image_encoder(t_image.type(self.dtype))
272 |
273 | source_prompts, target_prompts, source_domaintokens, source_style_mappingtokens = self.prompt_learner(source_data, target_data)
274 | tokenized_prompts = self.tokenized_prompts
275 |
276 | source_image_features = source_image_features / source_image_features.norm(dim=-1,
277 | keepdim=True)
278 | target_image_features = target_image_features / target_image_features.norm(dim=-1,
279 | keepdim=True)
280 | logit_scale = self.logit_scale.exp()
281 |
282 | source_text_features = []
283 | for pts_i in source_prompts:
284 | tf = self.text_encoder(pts_i, tokenized_prompts)
285 | source_text_features.append(tf)
286 | source_text_features=torch.stack(source_text_features)
287 | source_text_features = source_text_features / source_text_features.norm(dim=-1, keepdim=True)
288 |
289 | target_text_features = []
290 | for pts_i in target_prompts:
291 | tf = self.text_encoder(pts_i, tokenized_prompts)
292 | target_text_features.append(tf)
293 | target_text_features=torch.stack(target_text_features)
294 | target_text_features = target_text_features / target_text_features.norm(dim=-1, keepdim=True)
295 |
296 |
297 | source_logits = []
298 |
299 | for txt, im in zip(source_text_features, source_image_features):
300 | l_i = logit_scale * im @ txt.t()
301 | source_logits.append(l_i)
302 | source_logits = torch.stack(source_logits)
303 |
304 | target_logits = []
305 |
306 | for txt, im in zip(target_text_features, target_image_features):
307 | l_i = logit_scale * im @ txt.t()
308 | target_logits.append(l_i)
309 | target_logits = torch.stack(target_logits)
310 |
311 | target_probs = torch.nn.functional.softmax(target_logits, dim=1)
312 |
313 | return source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features
314 |
315 |
316 | class entropy_loss(nn.Module):
317 | def __init__(self):
318 | super(entropy_loss, self).__init__()
319 |
320 | def forward(self, target_prob):
321 | full_enp = torch.zeros(target_prob.shape[0])
322 | target_prob = nn.functional.normalize(target_prob, dim=0)
323 |
324 | for i in range(len(target_prob)):
325 | total_en = 0
326 | for j in range(target_prob.shape[1]):
327 | total_en = total_en - target_prob[i][j] * torch.log(target_prob[i][j] + 1e-8)
328 | full_enp[i] = total_en
329 | avg_full_enp = torch.mean(full_enp)
330 | return avg_full_enp
331 |
332 |
333 | @TRAINER_REGISTRY.register()
334 | class ADCLIPRN50(TrainerXU):
335 | def check_cfg(self, cfg):
336 | assert cfg.TRAINER.ADCLIPRN50.PREC in ["fp16", "fp32", "amp"]
337 |
338 | def build_model(self):
339 | cfg = self.cfg
340 | classnames = self.dm.dataset.classnames
341 |
342 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
343 | clip_model = load_clip_to_cpu(cfg)
344 |
345 | if cfg.TRAINER.ADCLIPRN50.PREC == "fp32" or cfg.TRAINER.ADCLIPRN50.PREC == "amp":
346 | # CLIP's default precision is fp16
347 | clip_model.float()
348 |
349 | print("Building custom CLIP")
350 | self.model = CustomCLIP(cfg, classnames, clip_model)
351 |
352 | self.n_cls = self.model.prompt_learner.n_cls
353 |
354 | name_to_update = "prompt_learner"
355 |
356 | for name, param in self.model.named_parameters():
357 | if name_to_update not in name:
358 | param.requires_grad_(False)
359 |
360 | enabled = set()
361 | for name, param in self.model.named_parameters():
362 | if param.requires_grad:
363 | enabled.add(name)
364 | print(f"Parameters to be updated: {enabled}")
365 |
366 | if cfg.MODEL.INIT_WEIGHTS:
367 | load_pretrained_weights(self.model.prompt_learner,
368 | cfg.MODEL.INIT_WEIGHTS)
369 |
370 | self.model.to(self.device)
371 |
372 | # determine how many iterations make up one epoch, following TRAIN.COUNT_ITER
373 | len_train_loader_x = len(self.train_loader_x)
374 | len_train_loader_u = len(self.train_loader_u)
375 | if self.cfg.TRAIN.COUNT_ITER == "train_x":
376 | self.num_batches = len_train_loader_x
377 | elif self.cfg.TRAIN.COUNT_ITER == "train_u":
378 | self.num_batches = len_train_loader_u
379 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
380 | self.num_batches = min(len_train_loader_x, len_train_loader_u)
381 | else:
382 | raise ValueError(f"Unknown TRAIN.COUNT_ITER: {self.cfg.TRAIN.COUNT_ITER}")
383 |
384 | # NOTE: only give prompt_learner to the optimizer
385 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
386 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
387 | '''
388 | only registered models are updated by the trainer; when a new module
389 | needs to be optimized, register it here before it is used
390 | '''
391 | self.register_model("prompt_learner", self.model.prompt_learner,
392 | self.optim, self.sched)
393 |
394 | self.scaler = GradScaler() if cfg.TRAINER.ADCLIPRN50.PREC == "amp" else None
395 |
396 | device_count = torch.cuda.device_count()
397 | if device_count > 1:
398 | print(f"Multiple GPUs detected (n_gpus={device_count}), using all of them!")
399 | self.model = nn.DataParallel(self.model)
400 |
401 | def save_model(self, epoch, directory, is_best=False, model_name=""):
402 | names = self.get_model_names()
403 |
404 | for name in names:
405 | model_dict = self._models[name].state_dict()
406 |
407 | optim_dict = None
408 | if self._optims[name] is not None:
409 | optim_dict = self._optims[name].state_dict()
410 |
411 | sched_dict = None
412 | if self._scheds[name] is not None:
413 | sched_dict = self._scheds[name].state_dict()
414 |
415 | save_checkpoint(
416 | {
417 | "state_dict": model_dict,
418 | "epoch": epoch + 1,
419 | "optimizer": optim_dict,
420 | "scheduler": sched_dict,
421 | },
422 | osp.join(directory, name),
423 | is_best=is_best,
424 | model_name=model_name,
425 | )
426 |
427 | def train(self):
428 | """Generic training loop."""
429 |
430 | self.before_train()
431 | for self.epoch in range(self.start_epoch, self.max_epoch):
432 | self.before_epoch()
433 | self.run_epoch()
434 | self.after_epoch()
435 | self.after_train()
436 |
437 | def run_epoch(self):
438 | self.set_model_mode("train")
439 | losses = MetricMeter()
440 | batch_time = AverageMeter()
441 | data_time = AverageMeter()
442 |
443 | # Decide whether the epoch length follows the labeled or the unlabeled loader
444 | len_train_loader_x = len(self.train_loader_x)
445 | len_train_loader_u = len(self.train_loader_u)
446 | if self.cfg.TRAIN.COUNT_ITER == "train_x":
447 | self.num_batches = len_train_loader_x
448 | elif self.cfg.TRAIN.COUNT_ITER == "train_u":
449 | self.num_batches = len_train_loader_u
450 | elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
451 | self.num_batches = min(len_train_loader_x, len_train_loader_u)
452 | else:
453 | raise ValueError(f"Unknown TRAIN.COUNT_ITER: {self.cfg.TRAIN.COUNT_ITER}")
454 |
455 | train_loader_x_iter = iter(self.train_loader_x)
456 | train_loader_u_iter = iter(self.train_loader_u)
457 |
458 |
459 | end = time.time()
460 | for self.batch_idx in range(self.num_batches):
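# cycle an iterator when its loader runs out so that every step sees one source and one target batch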
461 | try:
462 | batch_x = next(train_loader_x_iter)
463 | except StopIteration:
464 | train_loader_x_iter = iter(self.train_loader_x)
465 | batch_x = next(train_loader_x_iter)
466 |
467 | try:
468 | batch_u = next(train_loader_u_iter)
469 | except StopIteration:
470 | train_loader_u_iter = iter(self.train_loader_u)
471 | batch_u = next(train_loader_u_iter)
472 |
473 | data_time.update(time.time() - end)
474 | loss_summary = self.forward_backward(batch_x, batch_u)
475 | batch_time.update(time.time() - end)
476 | losses.update(loss_summary)
477 |
478 | if (
479 | self.batch_idx + 1
480 | ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
481 | nb_remain = 0
482 | nb_remain += self.num_batches - self.batch_idx - 1
483 | nb_remain += (self.max_epoch - self.epoch -
484 | 1) * self.num_batches
485 | eta_seconds = batch_time.avg * nb_remain
486 | eta = str(datetime.timedelta(seconds=int(eta_seconds)))
487 | print("epoch [{0}/{1}][{2}/{3}]\t"
488 | "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
489 | "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
490 | "eta {eta}\t"
491 | "{losses}\t"
492 | "lr {lr:.6e}".format(
493 | self.epoch + 1,
494 | self.max_epoch,
495 | self.batch_idx + 1,
496 | self.num_batches,
497 | batch_time=batch_time,
498 | data_time=data_time,
499 | eta=eta,
500 | losses=losses,
501 | lr=self.get_current_lr(),
502 | ))
503 |
504 | n_iter = self.epoch * self.num_batches + self.batch_idx
505 | for name, meter in losses.meters.items():
506 | self.write_scalar("train/" + name, meter.avg, n_iter)
507 | self.write_scalar("train/lr", self.get_current_lr(), n_iter)
508 |
509 | end = time.time()
510 |
511 | def forward_backward(self, batch_x, batch_u):
512 | self.entropy = entropy_loss()
513 | kl_loss = nn.KLDivLoss(reduction="batchmean")
514 | image_x, label, image_u = self.parse_batch_train(batch_x, batch_u)
515 | prec = self.cfg.TRAINER.ADCLIPRN50.PREC
516 | assert prec == "amp", "only the amp branch of forward_backward is implemented"
517 | if prec == "amp":
518 | with autocast():
519 | source_logits, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model(image_x, image_u)
520 |
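# supervised cross-entropy on labelled source images; the KL term pulls the source and target prompt text features towards a common distribution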
521 | loss_ce = F.cross_entropy(source_logits, label)
522 | source_textfeat = F.log_softmax(source_text_features, dim=1)
523 | target_textfeat = F.softmax(target_text_features, dim=1)
524 | loss_kl = kl_loss(source_textfeat, target_textfeat)
525 | loss_smn = F.mse_loss(source_domaintokens, source_style_mappingtokens)
526 | loss_entropy = self.entropy(target_probs)
527 |
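# total objective: cross-entropy + 0.1 * style-mapping MSE + 0.01 * target entropy + prompt KL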
528 | loss = loss_ce + 0.1*loss_smn + 0.01*loss_entropy + loss_kl
529 |
530 | self.optim.zero_grad()
531 | self.scaler.scale(loss).backward()
532 | self.scaler.step(self.optim)
533 | self.scaler.update()
534 |
535 |
536 | loss_summary = {
537 | "loss":
538 | loss.item(),
539 | "loss_ce":
540 | loss_ce.item(),
541 | "loss_smn":
542 | loss_smn.item(),
543 | "loss_entropy":
544 | loss_entropy.item(),
545 | "loss_kl":
546 | loss_kl.item(),
547 | "acc_x":
548 | compute_accuracy(source_logits[:, :self.n_cls], label)[0].item(),
549 | }
550 |
551 | self.update_lr()
552 |
553 | return loss_summary
554 |
555 | def after_epoch(self):
556 | last_epoch = (self.epoch + 1) == self.max_epoch
557 | do_test = not self.cfg.TEST.NO_TEST
558 | meet_checkpoint_freq = ((self.epoch + 1) %
559 | self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if
560 | self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False)
561 |
562 | if do_test:
563 | curr_result = self.test()
564 | is_best = curr_result > self.best_result
565 | if is_best:
566 | self.best_result = curr_result
567 | self.save_model(self.epoch,
568 | self.output_dir,
569 | model_name="model-best.pth.tar")
570 |
571 | self.set_model_mode("train")
572 |
573 | if meet_checkpoint_freq or last_epoch:
574 | self.save_model(self.epoch, self.output_dir)
575 |
576 | def parse_batch_train(self, batch_x, batch_u):
577 | input = batch_x["img"]
578 | label = batch_x["label"]
579 | input_u = batch_u["img"]
580 | input = input.to(self.device)
581 | label = label.to(self.device)
582 | input_u = input_u.to(self.device)
583 | return input, label, input_u
584 |
585 | def load_model(self, directory, epoch=None):
586 | if not directory:
587 | print(
588 | "Note that load_model() is skipped as no pretrained model is given"
589 | )
590 | return
591 |
592 | names = self.get_model_names()
593 |
594 | # By default, the best model is loaded
595 | model_file = "model-best.pth.tar"
596 |
597 | if epoch is not None:
598 | model_file = "model.pth.tar-" + str(epoch)
599 |
600 | for name in names:
601 | model_path = osp.join(directory, name, model_file)
602 |
603 | if not osp.exists(model_path):
604 | raise FileNotFoundError(
605 | 'Model not found at "{}"'.format(model_path))
606 |
607 | checkpoint = load_checkpoint(model_path)
608 | state_dict = checkpoint["state_dict"]
609 | epoch = checkpoint["epoch"]
610 |
611 | # Ignore fixed token vectors
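# the fixed prompt buffers are recreated when the model is built, so stale copies in the checkpoint are dropped before loading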
612 | if "token_prefix" in state_dict:
613 | del state_dict["token_prefix"]
614 |
615 | if "token_suffix" in state_dict:
616 | del state_dict["token_suffix"]
617 |
618 | print("Loading weights to {} "
619 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
620 | # set strict=False
621 | self._models[name].load_state_dict(state_dict, strict=False)
622 |
623 | @torch.no_grad()
624 | def test(self, split=None):
625 | """A generic testing pipeline."""
626 | self.set_model_mode("eval")
627 | self.evaluator.reset()
628 |
629 | if split is None:
630 | split = self.cfg.TEST.SPLIT
631 |
632 | split = "test"  # evaluation always runs on the target-domain test loader
633 | data_loader = self.test_loader
634 | print(f"Evaluate on the *{split}* set")
635 |
636 |
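# only the classification logits are fed to the evaluator; the auxiliary outputs are discarded at test time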
637 | for batch_idx, batch in enumerate(tqdm(data_loader)):
638 | input, label = self.parse_batch_test(batch)
639 | output, target_probs, source_domaintokens, source_style_mappingtokens, source_text_features, target_text_features = self.model_inference(input)
640 | self.evaluator.process(output, label)
641 |
642 | results = self.evaluator.evaluate()
643 |
644 | for k, v in results.items():
645 | tag = f"{split}/{k}"
646 | self.write_scalar(tag, v, self.epoch)
647 |
648 | return list(results.values())[0]
649 |
--------------------------------------------------------------------------------