├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── demo.sh
├── list
│   └── DOMAINNET
│       ├── sketch_test.txt
│       └── sketch_train.txt
└── task.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Tim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Official implementation for **UEO**

## [**[ICML-2024] Realistic Unsupervised CLIP Fine-tuning with Universal Entropy Optimization**](https://arxiv.org/abs/2308.12919)

### Framework:

![Framework](./task.png)

### Citation

If you find our paper and repository useful for your research, please consider citing our paper:
```bibtex
@inproceedings{liang2024realistic,
  title={Realistic Unsupervised CLIP Fine-tuning with Universal Entropy Optimization},
  author={Liang, Jian and Sheng, Lijun and Wang, Zhengbo and He, Ran and Tan, Tieniu},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2024}
}
```

### Contact

- [liangjian92@gmail.com](mailto:liangjian92@gmail.com)

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from data_list import ImageList, ImageList_idx
import random, pdb, math, copy
from tqdm import tqdm
import sklearn.metrics as skm

import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

class CoOp_PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model):
        super().__init__()
        self.dtype = clip_model.dtype
        self.ctx_dim = clip_model.ln_final.weight.shape[0]
        self.n_cls = len(classnames)
        ctx_init = 'a photo of a'
        self.n_ctx = len(ctx_init.split(" "))

        if ctx_init:
            # use given words to initialize context vectors
            prompt = clip.tokenize(ctx_init).cuda()
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1 : 1 + self.n_ctx, :].cuda()
            self.prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(self.n_ctx, self.ctx_dim, dtype=self.dtype).cuda()
            nn.init.normal_(ctx_vectors, std=0.02)
            self.prompt_prefix = " ".join(["X"] * self.n_ctx)

        self.ctx = nn.Parameter(ctx_vectors)
        self.get_prefix_suffix_token(classnames, clip_model)

        print('Initial context: {:}, Number of context words (tokens): {:}'.format(self.prompt_prefix, self.n_ctx))

    def get_prefix_suffix_token(self, classnames, clip_model):
        prompt_prefix = self.prompt_prefix
        classnames = [name.replace("_", " ") for name in classnames]
        _tokenizer = _Tokenizer()
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).cuda()  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + self.n_ctx :, :])  # CLS, EOS

        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prompts = torch.cat([self.token_prefix, ctx, self.token_suffix], dim=1)

        return prompts

class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection.type(self.dtype)

        return x
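# The two modules above implement CoOp-style prompt tuning: only the context
# vectors `ctx` are learnable, while the frozen text transformer turns the
# assembled (prefix, ctx, suffix) embeddings into per-class text features.
# A minimal sketch of the flow, mirroring the evaluation code in train_clip
# below (illustrative only, never called; assumes a CUDA device and the
# official `clip` package):
def _sketch_text_features(classnames=('bike', 'mug')):
    clip_model, _ = clip.load('RN50')
    learner = CoOp_PromptLearner(list(classnames), clip_model)
    encoder = TextEncoder(clip_model).cuda()
    with torch.no_grad():
        prompts = learner()                                  # (n_cls, n_tkn, ctx_dim)
        feats = encoder(prompts, learner.tokenized_prompts)  # (n_cls, feat_dim)
    return feats / feats.norm(dim=-1, keepdim=True)          # unit-norm rows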
def image_clip_train(resize_size=224):
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    return transforms.Compose([
        transforms.RandomResizedCrop(size=resize_size, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        normalize
    ])

def image_clip_test(resize_size=224):
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    return transforms.Compose([
        transforms.Resize(size=resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(resize_size),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        normalize
    ])

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def modify_list(args):
    new_tar = []
    txt_tar = open(args.train_dset_path).readlines()
    for i in range(len(txt_tar)):
        rec = txt_tar[i]
        reci = rec.strip().split(' ')
        if int(reci[1]) in args.tar_classes:
            new_tar.append(rec)
    txt_tar1 = new_tar.copy()

    new_tar = []
    txt_tar = open(args.test_dset_path).readlines()
    for i in range(len(txt_tar)):
        rec = txt_tar[i]
        reci = rec.strip().split(' ')
        if int(reci[1]) in args.tst_classes:
            new_tar.append(rec)
    txt_tar2 = new_tar.copy()
    return txt_tar1, txt_tar2
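# modify_list filters the image lists by label. Each line of a list file is
# expected to be "<relative_image_path> <integer_label>"; e.g. a line of
# list/DOMAINNET/sketch_train.txt might read
# "sketch/zebra/sketch_344_000042.jpg 343" (file name illustrative only).
# A self-contained sketch of the same filtering logic:
def _sketch_filter_list():
    lines = ['a/img1.jpg 0\n', 'b/img2.jpg 7\n', 'c/img3.jpg 2\n']
    keep_classes = {0, 1, 2}
    kept = [rec for rec in lines if int(rec.strip().split(' ')[1]) in keep_classes]
    return kept  # ['a/img1.jpg 0\n', 'c/img3.jpg 2\n']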
def prepare_dataset(args):
    ## prepare data split
    if args.dset == 'OFFICE':
        domains = ['amazon', 'dslr', 'webcam']
        domain = domains[args.tid] + '/images'
        args.src_classes = [i for i in range(25)]
        args.tst_classes = [i for i in range(31)]

        if args.da == 'CSDA':
            args.src_classes = [i for i in range(31)]
            args.tar_classes = [i for i in range(31)]
            args.tst_classes = [i for i in range(31)]
        elif args.da == 'CDA':
            args.tar_classes = [i for i in range(25)]
        elif args.da == 'PDA':
            args.tar_classes = [i for i in range(20)]
        elif args.da == 'ODA':
            args.tar_classes = [i for i in range(28)]
        elif args.da == 'OPDA':
            args.tar_classes = [i for i in range(15)] + [i for i in range(25, 28)]

        allclassnames = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle', 'calculator', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet',
                         'headphones', 'keyboard', 'laptop_computer', 'letter_tray', 'mobile_phone', 'monitor', 'mouse', 'mug', 'paper_notebook', 'pen',
                         'phone', 'printer', 'projector', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser',
                         'trash_can']

    elif args.dset == 'OFFICEHOME':
        domains = ['Art', 'Clipart', 'Product', 'RealWorld']
        domain = domains[args.tid]

        args.src_classes = [i for i in range(50)]
        args.tst_classes = [i for i in range(65)]

        if args.da == 'CSDA':
            args.src_classes = [i for i in range(65)]
            args.tar_classes = [i for i in range(65)]
            args.tst_classes = [i for i in range(65)]
        elif args.da == 'CDA':
            args.tar_classes = [i for i in range(50)]
        elif args.da == 'PDA':
            args.tar_classes = [i for i in range(35)]
        elif args.da == 'ODA':
            args.tar_classes = [i for i in range(60)]
        elif args.da == 'OPDA':
            args.tar_classes = [i for i in range(35)] + [i for i in range(50, 60)]

        allclassnames = ['alarm_clock', 'backpack', 'batteries', 'bed', 'bike', 'bottle', 'bucket', 'calculator', 'calendar', 'candles',
                         'chair', 'clipboards', 'computer', 'couch', 'curtains', 'desk_lamp', 'drill', 'eraser', 'exit_sign', 'fan',
                         'file_cabinet', 'flipflops', 'flowers', 'folder', 'fork', 'glasses', 'hammer', 'helmet', 'kettle', 'keyboard',
                         'knives', 'lamp_shade', 'laptop', 'marker', 'monitor', 'mop', 'mouse', 'mug', 'notebook', 'oven',
                         'pan', 'paper_clip', 'pen', 'pencil', 'postit_notes', 'printer', 'push_pin', 'radio', 'refrigerator', 'ruler',
                         'scissors', 'screwdriver', 'shelf', 'sink', 'sneakers', 'soda', 'speaker', 'spoon', 'table', 'telephone',
                         'toothbrush', 'toys', 'trash_can', 'tv', 'webcam']

    elif args.dset == 'VISDAC':
        domains = ['train', 'validation']
        domain = domains[args.tid]

        args.src_classes = [i for i in range(8)]
        args.tst_classes = [i for i in range(12)]

        if args.da == 'CSDA':
            args.src_classes = [i for i in range(12)]
            args.tar_classes = [i for i in range(12)]
            args.tst_classes = [i for i in range(12)]
        elif args.da == 'CDA':
            args.tar_classes = [i for i in range(8)]
        elif args.da == 'PDA':
            args.tar_classes = [i for i in range(6)]
        elif args.da == 'ODA':
            args.tar_classes = [i for i in range(10)]
        elif args.da == 'OPDA':
            args.tar_classes = [i for i in range(6)] + [i for i in range(8, 10)]

        allclassnames = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife', 'motorcycle', 'person', 'plant', 'skateboard',
                         'train', 'truck']

    elif args.dset == 'DOMAINNET':
        domains = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
        domain = domains[args.tid]

        args.src_classes = [i for i in range(300)]
        args.tst_classes = [i for i in range(345)]

        if args.da == 'CSDA':
            args.src_classes = [i for i in range(345)]
            args.tar_classes = [i for i in range(345)]
            args.tst_classes = [i for i in range(345)]
        elif args.da == 'CDA':
            args.tar_classes = [i for i in range(300)]
        elif args.da == 'PDA':
            args.tar_classes = [i for i in range(250)]
        elif args.da == 'ODA':
            args.tar_classes = [i for i in range(330)]
        elif args.da == 'OPDA':
            args.tar_classes = [i for i in range(250)] + [i for i in range(300, 330)]

        allclassnames = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil', 'apple', 'arm',
                         'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat', 'basket', 'basketball',
                         'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle',
                         'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet',
                         'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly',
                         'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon',
                         'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan', 'cello', 'cell_phone', 'chair', 'chandelier',
                         'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee_cup', 'compass', 'computer', 'cookie', 'cooler',
                         'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher',
                         'diving_board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck',
                         'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan',
                         'feather', 'fence', 'finger', 'fire_hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops',
                         'floor_lamp', 'flower', 'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe',
                         'goatee', 'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat',
                         'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital', 'hot_air_balloon',
                         'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream', 'jacket', 'jail', 'kangaroo',
                         'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light_bulb',
                         'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map',
                         'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike',
                         'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean',
                         'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush', 'paint_can', 'palm_tree', 'panda', 'pants',
                         'paper_clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano',
                         'pickup_truck', 'picture_frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool',
                         'popsicle', 'postcard', 'potato', 'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow',
                         'rake', 'remote_control', 'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw',
                         'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark', 'sheep', 'shoe',
                         'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag', 'smiley_face', 'snail', 'snake',
                         'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square',
                         'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop_sign', 'stove',
                         'strawberry', 'streetlight', 'string_bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword',
                         'syringe', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China',
                         'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor',
                         'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt', 'umbrella', 'underwear',
                         'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine_bottle',
                         'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']

    args.allclassnames = allclassnames
    if args.dset == 'DOMAINNET':
        args.train_dset_path = os.path.join(args.list_root, args.dset, domains[args.tid] + '_train.txt')
        args.test_dset_path = os.path.join(args.list_root, args.dset, domains[args.tid] + '_test.txt')
    else:
        args.train_dset_path = os.path.join(args.list_root, args.dset, domains[args.tid] + '_list.txt')
        args.test_dset_path = os.path.join(args.list_root, args.dset, domains[args.tid] + '_list.txt')

    txt_tar, txt_test = modify_list(args)

    ## prepare dataloader
    dsets = {}
    dset_loaders = {}
    dsets['target'] = ImageList_idx(txt_tar, root=os.path.join(args.data_root, args.dset), transform=image_clip_train())
    dset_loaders['target'] = DataLoader(dsets['target'], batch_size=args.bs, shuffle=True, num_workers=args.worker, drop_last=False)  # pin_memory=True
    dsets["test"] = ImageList(txt_test, root=os.path.join(args.data_root, args.dset), transform=image_clip_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=args.bs*3, shuffle=False, num_workers=args.worker, drop_last=False)

    dsets["tar_test"] = ImageList(txt_tar, root=os.path.join(args.data_root, args.dset), transform=image_clip_test())
    dset_loaders["tar_test"] = DataLoader(dsets["tar_test"], batch_size=args.bs*3, shuffle=False, num_workers=args.worker, drop_last=False)

    return args, dset_loaders
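# The branches above derive the five adaptation settings from a single
# source-class list. As a worked example, for OFFICEHOME (65 classes,
# src_classes = 0..49) the target label sets read directly off the code above are:
_officehome_tar_classes = {
    'CSDA': list(range(65)),                        # closed set over all 65 classes
    'CDA':  list(range(50)),                        # target = source classes exactly
    'PDA':  list(range(35)),                        # target is a subset of source
    'ODA':  list(range(60)),                        # 50..59 are target-private unknowns
    'OPDA': list(range(35)) + list(range(50, 60)),  # source- and target-private classes
}
# tst_classes stays 0..64 in every setting, so evaluation always spans the full label space.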
def get_score(logits, labels, id_label):
    outputs = torch.nn.Softmax(dim=1)(logits)
    scores, p_labels = torch.max(outputs, dim=1)

    matrix = skm.confusion_matrix(labels[labels < id_label].numpy(), p_labels[labels < id_label].numpy())
    global_acc = matrix.diagonal().sum() / matrix.sum()
    class_acc = matrix.diagonal() / (matrix.sum(axis=1) + 1e-10)
    id_score = class_acc[matrix.sum(axis=1) > 0].mean()

    if sum(labels >= id_label) > 0:
        ood_labels = (labels < id_label).float()
        fpr, tpr, thresholds = skm.roc_curve(ood_labels, scores.numpy())
        auc_score = skm.auc(fpr, tpr)
        return [global_acc, id_score, auc_score]
    else:
        return [id_score, global_acc]

def norm_feature(features):
    features = features / features.norm(dim=-1, keepdim=True)
    return features
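# get_score reports GACC (overall accuracy over known-class samples), PACC
# (mean per-class accuracy), and, when unknown samples (label >= id_label)
# are present, the AUROC of the max-softmax score for known-vs-unknown
# detection. A toy sanity check (illustrative values, never called):
def _sketch_get_score():
    logits = torch.tensor([[4., 1., 0.],   # predicted 0, true 0
                           [0., 3., 1.],   # predicted 1, true 1
                           [1., 0., 2.],   # predicted 2, true 2
                           [1., 1., 1.]])  # low-confidence sample, true label 3 (unknown)
    labels = torch.tensor([0, 1, 2, 3])
    gacc, pacc, auroc = get_score(logits, labels, id_label=3)
    return gacc, pacc, auroc  # 1.0, 1.0, 1.0 on this toy example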
def load_features(vmodel, loader, text_features):
    logits_, labels_ = [], []
    with torch.no_grad():
        for _, data in enumerate(loader):
            images = data[0].cuda()
            labels = data[1]
            image_features = vmodel(images)
            image_features = norm_feature(image_features)
            # clip_logits = 100. * image_features @ text_features.T
            clip_logits = image_features @ text_features.T
            logits_.append(clip_logits)
            labels_.append(labels)
        logits_, labels_ = torch.cat(logits_), torch.cat(labels_)

    return logits_, labels_

def loss_entropy(input_, average=True):
    epsilon = 1e-5
    entropy = -input_ * torch.log(input_ + epsilon)
    if entropy.dim() == 1:
        entropy = torch.sum(entropy)
        return entropy

    if average:
        entropy = torch.sum(entropy, dim=1).mean()
    else:
        entropy = torch.sum(entropy, dim=1)
    return entropy

def loss_entropy_wei(input_, weight):
    epsilon = 1e-5
    entropy = torch.sum(-input_ * torch.log(input_ + epsilon), dim=1)
    entropy = torch.sum(entropy * weight) / torch.sum(weight)
    return entropy

def compute_transport_loss(logits, sim_t):
    s_dist = torch.nn.Softmax(dim=1)(logits)
    t_dist = torch.nn.Softmax(dim=0)(logits)
    cost = 1 - sim_t
    s_cost = (cost * s_dist).sum(1).mean()
    t_cost = (cost * t_dist).sum(0).mean()
    return s_cost + t_cost
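# The pieces above form the UEO objective used in train_clip: minimize a
# confidence-weighted conditional entropy (the weight w favors likely-known
# samples) while maximizing the entropy of a 1/w-weighted marginal
# prediction, traded off by args.trade. A minimal numeric sketch (random
# inputs; in train_clip the weights are max-softmax scores of zero-shot CLIP):
def _sketch_ueo_loss(trade=1.0):
    probs = torch.nn.Softmax(dim=1)(torch.randn(8, 4))
    w = torch.rand(8) + 1e-3                  # per-sample confidence weights
    cond_ent = loss_entropy_wei(probs, w)     # weighted entropy, to be minimized
    mean_p = torch.mm(torch.diag(1 / w), probs).sum(dim=0) / torch.sum(1 / w)
    marg_ent = loss_entropy(mean_p)           # diversity term, to be maximized
    return cond_ent - trade * marg_ent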
class VisualEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.model = clip_model.visual
        self.dtype = clip_model.dtype

    def forward(self, x):
        x = self.model(x.type(self.dtype))
        return x
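# VisualEncoder exposes CLIP's frozen image tower. Logits are cosine
# similarities between unit-norm image and text features; train_clip scales
# them by 100. (roughly clip_model.logit_scale.exp() for the released CLIP
# checkpoints), which sharpens the softmax, whereas load_features above keeps
# the raw cosines. A shape sketch with dummy features (1024-d matches RN50):
def _sketch_logits():
    img_f = norm_feature(torch.randn(2, 1024))  # 2 images
    txt_f = norm_feature(torch.randn(5, 1024))  # 5 classes
    return 100. * img_f @ txt_f.T               # (2, 5) scaled similarities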
def train_clip(args, dset_loaders):
    # load network
    clip_model, _ = clip.load(args.net)
    text_encoder = TextEncoder(clip_model).cuda()
    visual_encoder = VisualEncoder(clip_model).cuda()
    known_classnames = [args.allclassnames[i] for i in range(len(args.src_classes))]
    coop_prompt_learner = CoOp_PromptLearner(known_classnames, clip_model)

    text_encoder.eval()
    for p in text_encoder.parameters():
        p.requires_grad = False

    visual_encoder.eval()
    for p in visual_encoder.parameters():
        p.requires_grad = False

    coop_prompt_learner.eval()
    for p in coop_prompt_learner.parameters():
        p.requires_grad = False

    # --------------------------------------
    # remap labels so the known classes occupy ids 0..K-1 and test-only
    # classes are appended after them
    label_maps = torch.zeros(1 + max(max(args.src_classes), max(args.tst_classes)))
    for i in range(len(args.src_classes)):
        label_maps[args.src_classes[i]] = i
    k = 0

    for i in range(len(args.tst_classes)):
        if not args.tst_classes[i] in args.src_classes:
            label_maps[args.tst_classes[i]] = len(args.src_classes) + k
            k += 1

    args.label_maps = label_maps

    print('src private classes: {:}'.format(set(args.src_classes) - set(args.tar_classes)))
    print('tar private classes: {:}'.format(set(args.tar_classes) - set(args.src_classes)))
    print('shared classes: {:}'.format(set(args.tar_classes) & set(args.src_classes)))

    log_str = ('Training Epoch: {:} / {:}'.format(0, args.epochs))
    with torch.no_grad():
        prompts = coop_prompt_learner()
        tokenized_prompts = coop_prompt_learner.tokenized_prompts
        text_features = text_encoder(prompts, tokenized_prompts)
        text_features = norm_feature(text_features)

        clip_logits, images_labels = load_features(visual_encoder, dset_loaders['test'], text_features)
        images_labels = label_maps[images_labels].long()

    _score = get_score(clip_logits.cpu().float(), images_labels, len(args.src_classes))

    if len(set(args.tst_classes) - set(args.src_classes)):
        log_str += ('\n(Tar-test) GACC:{:.2f} PACC:{:.2f} AUROC:{:.2f}'.format(_score[0]*100, _score[1]*100, _score[2]*100))
    else:
        log_str += ('\n(Tar-test) PACC:{:.2f} GACC:{:.2f}'.format(_score[0]*100, _score[1]*100))

    with torch.no_grad():
        prompts = coop_prompt_learner()
        tokenized_prompts = coop_prompt_learner.tokenized_prompts
        text_features = text_encoder(prompts, tokenized_prompts)
        text_features = norm_feature(text_features)

        clip_logits, images_labels = load_features(visual_encoder, dset_loaders['tar_test'], text_features)
        images_labels = label_maps[images_labels].long()

    _score = get_score(clip_logits.cpu().float(), images_labels, len(args.src_classes))

    if len(set(args.tar_classes) - set(args.src_classes)):
        log_str += ('\n(Tar-train) GACC:{:.2f} PACC:{:.2f} AUROC:{:.2f}'.format(_score[0]*100, _score[1]*100, _score[2]*100))
    else:
        log_str += ('\n(Tar-train) PACC:{:.2f} GACC:{:.2f}'.format(_score[0]*100, _score[1]*100))

    print(log_str)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    param_group = []
    if args.plr > 0:
        for k, v in coop_prompt_learner.named_parameters():
            v.requires_grad = True
            param_group += [{'params': v, 'lr': args.plr}]
    if args.vlr > 0:
        for k, v in visual_encoder.named_parameters():
            if 'bn' in k or 'ln' in k:
                v.requires_grad = True
                param_group += [{'params': v, 'lr': args.vlr}]

    optimizer = torch.optim.SGD(param_group)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(dset_loaders['target']))

    # sample weights: zero-shot max-softmax confidence on the target data
    target_outputs = torch.nn.Softmax(dim=1)(clip_logits)
    target_maxpro, _ = torch.max(target_outputs, dim=1)

    for epoch_idx in range(args.epochs):
        visual_encoder.eval()

        for i, data in enumerate(dset_loaders['target']):
            images = data[0].cuda()
            index_t = data[2].cuda()

            image_features = visual_encoder(images)
            image_features = norm_feature(image_features)

            prompts = coop_prompt_learner()
            tokenized_prompts = coop_prompt_learner.tokenized_prompts
            text_features = text_encoder(prompts, tokenized_prompts)
            text_features = norm_feature(text_features)
            logits = 100. * image_features @ text_features.T

            outputs = torch.nn.Softmax(dim=1)(logits)
            if args.noweight:
                mean_outputs = torch.mean(outputs, dim=0)
                loss = loss_entropy(outputs) - args.trade * loss_entropy(mean_outputs)
            else:
                if args.oracle:
                    images_labels = label_maps[data[1]].long().cuda()
                    weight = (images_labels < len(args.src_classes)).type(target_maxpro.dtype) + 1e-3
                    mean_outputs = torch.mm(torch.diag(1 / weight), outputs).sum(dim=0) / torch.sum(1 / weight)
                    loss = loss_entropy_wei(outputs, weight) - args.trade * loss_entropy(mean_outputs)
                else:
                    weight = target_maxpro[index_t]
                    mean_outputs = torch.mm(torch.diag(1 / weight), outputs).sum(dim=0) / torch.sum(1 / weight)
                    loss = loss_entropy_wei(outputs, weight) - args.trade * loss_entropy(mean_outputs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        if (epoch_idx + 1) % args.eval_epoch == 0:
            log_str = ('Training Epoch: {:} / {:}'.format(epoch_idx + 1, args.epochs))
            visual_encoder.eval()
            with torch.no_grad():
                prompts = coop_prompt_learner()
                tokenized_prompts = coop_prompt_learner.tokenized_prompts
                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = norm_feature(text_features)

                clip_logits, images_labels = load_features(visual_encoder, dset_loaders['test'], text_features)
                images_labels = label_maps[images_labels].long()

            _score = get_score(clip_logits.cpu().float(), images_labels, len(args.src_classes))

            if len(set(args.tst_classes) - set(args.src_classes)):
                log_str += ('\n(Tar-test) GACC:{:.2f} PACC:{:.2f} AUROC:{:.2f}'.format(_score[0]*100, _score[1]*100, _score[2]*100))
            else:
                log_str += ('\n(Tar-test) PACC:{:.2f} GACC:{:.2f}'.format(_score[0]*100, _score[1]*100))

            with torch.no_grad():
                prompts = coop_prompt_learner()
                tokenized_prompts = coop_prompt_learner.tokenized_prompts
                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = norm_feature(text_features)

                clip_logits, images_labels = load_features(visual_encoder, dset_loaders['tar_test'], text_features)
                images_labels = label_maps[images_labels].long()

            _score = get_score(clip_logits.cpu().float(), images_labels, len(args.src_classes))

            if len(set(args.tar_classes) - set(args.src_classes)):
                log_str += ('\n(Tar-train) GACC:{:.2f} PACC:{:.2f} AUROC:{:.2f}'.format(_score[0]*100, _score[1]*100, _score[2]*100))
            else:
                log_str += ('\n(Tar-train) PACC:{:.2f} GACC:{:.2f}'.format(_score[0]*100, _score[1]*100))

            print(log_str)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
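# label_maps (built at the top of train_clip) renumbers labels so the
# len(src_classes) known classes occupy ids 0..K-1 and any test-only classes
# are appended after them; get_score can then treat label >= K as unknown.
# A tiny worked example with a gap in the source classes (hypothetical ids):
def _sketch_label_maps(src=(0, 2), tst=(0, 1, 2, 3)):
    maps = torch.zeros(1 + max(max(src), max(tst)))
    for i, c in enumerate(src):
        maps[c] = i
    k = 0
    for c in tst:
        if c not in src:
            maps[c] = len(src) + k
            k += 1
    return maps  # tensor([0., 2., 1., 3.]): class 2 -> known id 1; classes 1, 3 -> unknown ids 2, 3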
def print_args(args):
    s = "==========================================\n"
    for arg, content in args.__dict__.items():
        s += "{}:{}\n".format(arg, content)
    return s

class LinearAverage(nn.Module):
    def __init__(self, inputSize, outputSize, T=0.05, momentum=0.0):
        super(LinearAverage, self).__init__()
        self.nLem = outputSize
        self.momentum = momentum
        self.register_buffer('params', torch.tensor([T, momentum]))
        self.register_buffer('memory', torch.zeros(outputSize, inputSize))
        self.flag = 0
        self.T = T
        self.memory = self.memory.cuda()

    def forward(self, x, y):
        # pdb.set_trace()
        out = torch.mm(x.float(), self.memory.t()) / self.T
        return out

    def update_weight(self, features, index):
        if not self.flag:
            weight_pos = self.memory.index_select(0, index.data.view(-1)).resize_as_(features)
            weight_pos.mul_(0.0)
            weight_pos.add_(torch.mul(features.data, 1.0))

            w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
            updated_weight = weight_pos.div(w_norm)
            self.memory.index_copy_(0, index, updated_weight)
            self.flag = 1
        else:
            weight_pos = self.memory.index_select(0, index.data.view(-1)).resize_as_(features)
            weight_pos.mul_(self.momentum)
            weight_pos.add_(torch.mul(features.data, 1 - self.momentum))

            w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
            updated_weight = weight_pos.div(w_norm)
            self.memory.index_copy_(0, index, updated_weight)
            self.memory = torch.nn.functional.normalize(self.memory)  # .cuda()

    def set_weight(self, features, index):
        self.memory.index_copy_(0, index, features)
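# LinearAverage is a temperature-scaled memory bank of L2-normalized features:
# forward returns similarities to every stored entry, and update_weight applies
# a momentum update followed by renormalization. Note that demo.py never
# instantiates it; it appears to be kept for neighborhood-based variants.
# A usage sketch (illustrative sizes; assumes a CUDA device):
def _sketch_memory_bank():
    bank = LinearAverage(inputSize=512, outputSize=100)
    feats = torch.nn.functional.normalize(torch.randn(4, 512)).cuda()
    idx = torch.arange(4).cuda()
    sims = bank(feats, idx)         # (4, 100) similarities divided by T
    bank.update_weight(feats, idx)  # write the batch into the memory
    return sims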
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='xxx')
    parser.add_argument('--dset', type=str, default='OFFICEHOME', choices=['VISDAC', 'OFFICE', 'OFFICEHOME', 'miniDOMAINNET', 'DOMAINNET'])
    parser.add_argument('--tid', type=int, default=0, help="target")
    parser.add_argument('--da', type=str, default='OPDA', choices=['CDA', 'PDA', 'ODA', 'OPDA', 'CSDA'])
    parser.add_argument('--net', type=str, default='RN50', choices=['RN50', 'RN101', 'ViT-B/32', 'ViT-B/16'])
    parser.add_argument('--seed', type=int, default=2023, help="random seed")
    parser.add_argument('--data_root', type=str, default='/data1/xxx/datasets/cls/')
    parser.add_argument('--list_root', type=str, default='./list/')

    parser.add_argument('--log', type=str, default='logszz/')
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--plr', type=float, default=1e-4, help="learning rate of the prompt learner")
    parser.add_argument('--vlr', type=float, default=1e-4, help="learning rate of the visual encoder's norm layers")
    parser.add_argument('--bs', type=int, default=64, help="batch size")
    parser.add_argument('--epochs', type=int, default=50, help="number of epochs")
    parser.add_argument('--eval_epoch', type=int, default=10, help="the interval of evaluation epochs")
    parser.add_argument('--trade', type=float, default=1.0)
    parser.add_argument('--noweight', action="store_true")
    parser.add_argument('--oracle', action="store_true")

    args = parser.parse_args()
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    args, dset_loaders = prepare_dataset(args)
    envs = print_args(args)

    name = args.net.replace("/", "")
    if args.noweight:
        ff = '_noweight'
    elif args.oracle:
        ff = '_oracle'
    else:
        ff = '_ours'

    output_dir_src = osp.join(args.log, str(args.seed) + ff, name + '_vlr_' + str(args.vlr) + '_plr_' + str(args.plr), args.dset)
    if not osp.exists(output_dir_src):
        os.system('mkdir -p ' + output_dir_src)
    if not osp.exists(output_dir_src):
        os.mkdir(output_dir_src)

    args.out_file = open(osp.join(output_dir_src, '@{:}_{:}_trade_{:.1f}.txt'.format(args.tid, args.da, args.trade)), 'w')

    train_clip(args, dset_loaders)

    args.out_file.write('\n' + envs)
    args.out_file.flush()
    args.out_file.close()
--------------------------------------------------------------------------------
/demo.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# usage: bash demo.sh <gpu_id>

CUDA_VISIBLE_DEVICES=$1 python demo.py --dset DOMAINNET --tid 5 --da OPDA --seed 2023 --epochs 5 --eval_epoch 1 --vlr 1e-4 --plr 1e-4 --log logs --trade 1.0
CUDA_VISIBLE_DEVICES=$1 python demo.py --dset DOMAINNET --tid 5 --da ODA --seed 2023 --epochs 5 --eval_epoch 1 --vlr 1e-4 --plr 1e-4 --log logs --trade 1.0
CUDA_VISIBLE_DEVICES=$1 python demo.py --dset DOMAINNET --tid 5 --da CDA --seed 2023 --epochs 5 --eval_epoch 1 --vlr 1e-4 --plr 1e-4 --log logs --trade 1.0
CUDA_VISIBLE_DEVICES=$1 python demo.py --dset DOMAINNET --tid 5 --da PDA --seed 2023 --epochs 5 --eval_epoch 1 --vlr 1e-4 --plr 1e-4 --log logs --trade 1.0

--------------------------------------------------------------------------------
/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/UEO/a5e749ecbe753081b061d2201f798f75ac02cade/task.png
--------------------------------------------------------------------------------