├── dawin_rft ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── simple_tokenizer.py │ └── clip.py ├── datasets │ ├── imagenet_sketch.py │ ├── __init__.py │ ├── imagenetv2.py │ ├── common.py │ ├── get_objectnet_classnames.py │ ├── imagenet_r.py │ ├── imagenet_a.py │ ├── objectnet_metadata │ │ ├── objectnet_to_imagenet_1k.json │ │ └── folder_to_objectnet_label.json │ ├── objectnet_beta_metadata │ │ └── objectnet_to_imagenet_1k.json │ ├── imagenet.py │ └── objectnet.py ├── script │ ├── dawin.sh │ ├── dawin_baselines.sh │ ├── dawin_samplewise.sh │ ├── dawin_applications.sh │ └── dawin_ablation.sh ├── zeroshot.py ├── openai_imagenet_template.py ├── main.py ├── README.md └── mixturemodel.py ├── dawin_mtl ├── src │ ├── datasets │ │ ├── sun397.py │ │ ├── dtd.py │ │ ├── mnist.py │ │ ├── svhn.py │ │ ├── eurosat.py │ │ ├── registry.py │ │ ├── common.py │ │ ├── cars.py │ │ ├── templates.py │ │ └── gtsrb.py │ ├── main_individuals.py │ ├── heads.py │ ├── main_task_arithmetic.py │ ├── main_weight_avg.py │ ├── task_vectors.py │ ├── main_ties_merging.py │ ├── args.py │ ├── modeling.py │ ├── ties_merging_utils.py │ ├── utils.py │ ├── eval.py │ └── main_layer_wise_adamerging.py ├── download_ckpt.py ├── README.md └── split_dataset.py ├── dawin.yaml ├── NOTICE └── README.md /dawin_rft/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /dawin_rft/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/dawin/HEAD/dawin_rft/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /dawin_rft/datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .imagenet import ImageNet 3 | 4 | class ImageNetSketch(ImageNet): 5 | 6 | def populate_train(self): 7 | pass 8 | 9 | def get_test_path(self): 10 | return os.path.join(self.location, 'sketch') 11 | -------------------------------------------------------------------------------- /dawin_rft/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagenet import * 2 | from .imagenetv2 import ImageNetV2 3 | from .imagenet_a import ImageNetAValClasses, ImageNetA 4 | from .imagenet_r import ImageNetRValClasses, ImageNetR 5 | from .objectnet import ObjectNetValClasses, ObjectNet 6 | from .imagenet_sketch import ImageNetSketch -------------------------------------------------------------------------------- /dawin_rft/script/dawin.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 4 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 5 | --model ViT-B/32 \ 6 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 7 | --eval-batch-size 1024 --bmm_ncluster 3 \ 8 | --offset-adjustment \ 9 | --expertise negexp_ent_ratio -------------------------------------------------------------------------------- /dawin_rft/script/dawin_baselines.sh: -------------------------------------------------------------------------------- 1 | cd .. 
2 | 3 | for method in uniform wise_gaussian gaussian 4 | do 5 | 6 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 7 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 8 | --model ViT-B/32 \ 9 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 10 | --eval-batch-size 1024 --bmm_ncluster 3 \ 11 | --baseline $method 12 | 13 | done -------------------------------------------------------------------------------- /dawin_rft/script/dawin_samplewise.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 4 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 5 | --model ViT-B/32 \ 6 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 7 | --eval-batch-size 1 --bmm_ncluster 0 \ 8 | --expertise negexp_loss_ratio 9 | 10 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 11 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 12 | --model ViT-B/32 \ 13 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 14 | --eval-batch-size 1 --bmm_ncluster 0 \ 15 | --expertise negexp_ent_ratio --offset-adjustment -------------------------------------------------------------------------------- /dawin_rft/script/dawin_applications.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 4 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 5 | --model ViT-B/32 \ 6 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 7 | --eval-batch-size 1024 \ 8 | --expertise selective_entropy 9 | 10 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 11 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 12 | --model ViT-B/32 \ 13 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 14 | --eval-batch-size 1024 \ 15 | --expertise negexp_ent_ratio --offset-adjustment \ 16 | --ensemble_coef dynamic -------------------------------------------------------------------------------- /dawin_rft/datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from imagenetv2_pytorch import ImageNetV2Dataset 4 | 5 | from .imagenet import ImageNet 6 | 7 | class ImageNetV2DatasetWithPaths(ImageNetV2Dataset): 8 | def __getitem__(self, i): 9 | img, label = Image.open(self.fnames[i]), int(self.fnames[i].parent.name) 10 | if self.transform is not None: 11 | img = self.transform(img) 12 | return { 13 | 'images': img, 14 | 'labels': label, 15 | 'image_paths': str(self.fnames[i]) 16 | } 17 | 18 | class ImageNetV2(ImageNet): 19 | def get_test_dataset(self): 20 | return ImageNetV2DatasetWithPaths(transform=self.preprocess, location=self.location) 21 | -------------------------------------------------------------------------------- /dawin_rft/script/dawin_ablation.sh: -------------------------------------------------------------------------------- 1 | cd .. 
2 | 3 | for psl in 1 2 3 4 4 | do 5 | 6 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 7 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 8 | --model ViT-B/32 \ 9 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 10 | --eval-batch-size 1024 --bmm_ncluster 3 \ 11 | --offset-adjustment \ 12 | --expertise negexp_loss_ratio --pseudo_label $psl 13 | 14 | done 15 | 16 | for expt in maxlogit_ratio expconf_ratio expconfdiff_ratio 17 | do 18 | 19 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 20 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 21 | --model ViT-B/32 \ 22 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 23 | --eval-batch-size 1024 --bmm_ncluster 3 \ 24 | --expertise $expt 25 | 26 | done -------------------------------------------------------------------------------- /dawin_rft/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | class ImageFolderWithPaths(datasets.ImageFolder): 27 | def __init__(self, path, transform): 28 | super().__init__(path, transform) 29 | 30 | def __getitem__(self, index): 31 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 32 | return { 33 | 'images': image, 34 | 'labels': label, 35 | 'image_paths': self.samples[index][0] 36 | } 37 | -------------------------------------------------------------------------------- /dawin_rft/datasets/get_objectnet_classnames.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | 5 | if __name__ == '__main__': 6 | METADATA = Path(__file__).parent / 'objectnet_metadata' 7 | 8 | with open(METADATA / 'folder_to_objectnet_label.json', 'r') as f: 9 | folder_map = json.load(f) 10 | folder_map = {v: k for k, v in folder_map.items()} 11 | 12 | with open(METADATA / 'objectnet_to_imagenet_1k.json', 'r') as f: 13 | objectnet_map = json.load(f) 14 | 15 | with open(METADATA / 'imagenet_to_labels.json', 'r') as f: 16 | imagenet_map = json.load(f) 17 | imagenet_map = {v: k for k, v in imagenet_map.items()} 18 | 19 | folder_to_ids, class_sublist = {}, [] 20 | for objectnet_name, imagenet_names in objectnet_map.items(): 21 | imagenet_names = imagenet_names.split('; ') 22 | imagenet_ids = [int(imagenet_map[imagenet_name]) for imagenet_name in imagenet_names] 23 | class_sublist.extend(imagenet_ids) 24 | folder_to_ids[folder_map[objectnet_name]] = imagenet_ids 25 | 26 | class_sublist = sorted(class_sublist) 27 | class_sublist_mask = [(i in class_sublist) for i in range(1000)] 28 | folder_map = {k: [class_sublist.index(x) for x in v] for k, v in folder_to_ids.items()} 29 | 30 | print('here') -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import 
torchvision.datasets as datasets 4 | 5 | class SUN397: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=32, 10 | num_workers=0): 11 | # Data loading code 12 | traindir = os.path.join(location, 'sun397', 'train') 13 | valdir = os.path.join(location, 'sun397', 'val') 14 | 15 | 16 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 17 | self.train_loader = torch.utils.data.DataLoader( 18 | self.train_dataset, 19 | shuffle=True, 20 | batch_size=batch_size, 21 | num_workers=num_workers, 22 | ) 23 | 24 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 25 | self.test_loader = torch.utils.data.DataLoader( 26 | self.test_dataset, 27 | batch_size=batch_size, 28 | num_workers=num_workers 29 | ) 30 | self.test_loader_shuffle = torch.utils.data.DataLoader( 31 | self.test_dataset, 32 | shuffle=True, 33 | batch_size=batch_size, 34 | num_workers=num_workers 35 | ) 36 | idx_to_class = dict((v, k) 37 | for k, v in self.train_dataset.class_to_idx.items()) 38 | self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))] 39 | -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | class DTD: 7 | def __init__(self, 8 | preprocess, 9 | location=os.path.expanduser('~/data'), 10 | batch_size=32, 11 | num_workers=0): 12 | # Data loading code 13 | traindir = os.path.join(location, 'dtd', 'train') 14 | valdir = os.path.join(location, 'dtd', 'val') 15 | 16 | self.train_dataset = datasets.ImageFolder( 17 | traindir, transform=preprocess) 18 | self.train_loader = torch.utils.data.DataLoader( 19 | self.train_dataset, 20 | shuffle=True, 21 | batch_size=batch_size, 22 | num_workers=num_workers, 23 | ) 24 | 25 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 26 | self.test_loader = torch.utils.data.DataLoader( 27 | self.test_dataset, 28 | batch_size=batch_size, 29 | num_workers=num_workers 30 | ) 31 | 32 | self.test_loader_shuffle = torch.utils.data.DataLoader( 33 | self.test_dataset, 34 | shuffle=True, 35 | batch_size=batch_size, 36 | num_workers=num_workers 37 | ) 38 | 39 | idx_to_class = dict((v, k) 40 | for k, v in self.train_dataset.class_to_idx.items()) 41 | self.classnames = [idx_to_class[i].replace( 42 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class MNIST: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=0): 11 | 12 | 13 | self.train_dataset = datasets.MNIST( 14 | root=location, 15 | #download=True, 16 | download=False, 17 | train=True, 18 | transform=preprocess 19 | ) 20 | 21 | self.train_loader = torch.utils.data.DataLoader( 22 | self.train_dataset, 23 | batch_size=batch_size, 24 | shuffle=True, 25 | num_workers=num_workers 26 | ) 27 | 28 | self.test_dataset = datasets.MNIST( 29 | root=location, 30 | #download=True, 31 | download=False, 32 | train=False, 33 | transform=preprocess 34 | ) 35 | 36 | self.test_loader = torch.utils.data.DataLoader( 37 | self.test_dataset, 38 | batch_size=batch_size, 39 
| shuffle=False, 40 | num_workers=num_workers 41 | ) 42 | 43 | self.test_loader_shuffle = torch.utils.data.DataLoader( 44 | self.test_dataset, 45 | batch_size=batch_size, 46 | shuffle=True, 47 | num_workers=num_workers 48 | ) 49 | 50 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import SVHN as PyTorchSVHN 4 | import numpy as np 5 | 6 | 7 | class SVHN: 8 | def __init__(self, 9 | preprocess, 10 | location=os.path.expanduser('~/data'), 11 | batch_size=128, 12 | num_workers=0): 13 | 14 | # to fit with repo conventions for location 15 | modified_location = os.path.join(location, 'svhn') 16 | 17 | self.train_dataset = PyTorchSVHN( 18 | root=modified_location, 19 | #download=True, 20 | download=False, 21 | split='train', 22 | transform=preprocess 23 | ) 24 | 25 | self.train_loader = torch.utils.data.DataLoader( 26 | self.train_dataset, 27 | batch_size=batch_size, 28 | shuffle=True, 29 | num_workers=num_workers 30 | ) 31 | 32 | self.test_dataset = PyTorchSVHN( 33 | root=modified_location, 34 | #download=True, 35 | download=False, 36 | split='test', 37 | transform=preprocess 38 | ) 39 | 40 | self.test_loader = torch.utils.data.DataLoader( 41 | self.test_dataset, 42 | batch_size=batch_size, 43 | shuffle=False, 44 | num_workers=num_workers 45 | ) 46 | self.test_loader_shuffle = torch.utils.data.DataLoader( 47 | self.test_dataset, 48 | batch_size=batch_size, 49 | shuffle=True, 50 | num_workers=num_workers 51 | ) 52 | 53 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 54 | -------------------------------------------------------------------------------- /dawin_rft/datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses 6 | import numpy as np 7 | 8 | 9 | CLASS_SUBLIST = [ 10 | 1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 11 | 113, 122, 12 | 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 13 | 207, 208, 219, 14 | 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 15 | 291, 292, 293, 16 | 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 17 | 353, 355, 361, 18 | 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 19 | 448, 457, 462, 20 | 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 21 | 617, 621, 629, 22 | 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 23 | 866, 875, 883, 24 | 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 25 | 967, 980, 981, 26 | 983, 988] 27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)] 28 | 29 | 30 | class ImageNetRValClasses(ImageNetSubsampleValClasses): 31 | def get_class_sublist_and_mask(self): 32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 33 | 34 | class ImageNetR(ImageNetSubsample): 35 | def 
get_class_sublist_and_mask(self): 36 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 37 | 38 | def get_test_path(self): 39 | return os.path.join(self.location, 'imagenet-r') -------------------------------------------------------------------------------- /dawin_rft/datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses 6 | import numpy as np 7 | 8 | 9 | CLASS_SUBLIST = [ 10 | 6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107, 11 | 108, 110, 12 | 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307, 13 | 308, 309, 14 | 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 15 | 400, 401, 16 | 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488, 17 | 492, 496, 18 | 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 19 | 626, 627, 20 | 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 21 | 774, 776, 22 | 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 23 | 879, 880, 24 | 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 25 | 981, 984, 26 | 986, 987, 988] 27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)] 28 | 29 | 30 | class ImageNetAValClasses(ImageNetSubsampleValClasses): 31 | def get_class_sublist_and_mask(self): 32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 33 | 34 | 35 | class ImageNetA(ImageNetSubsample): 36 | def get_class_sublist_and_mask(self): 37 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 38 | 39 | def get_test_path(self): 40 | return os.path.join(self.location, 'imagenet-a') 41 | -------------------------------------------------------------------------------- /dawin_mtl/download_ckpt.py: -------------------------------------------------------------------------------- 1 | ''' 2 | official checkpoints provided in the codebase of "Editing Models with Task Arithmetic (ICLR23)" 3 | ''' 4 | import gdown 5 | import os 6 | prefix='/YOUR_ROOT_PATH/dawin/dawin_mtl/checkpoints/ViT-B-32' 7 | 8 | ftdict={ 9 | 'GTSRB':'1BDEpJSL0GWGfD93OvymW_oZZeWfyYclM', 10 | 'SUN397':'1HWs-n0GEewLQ9hXnOJYtm85GtJpJl35Q', 11 | 'EuroSAT':'1EZIkpSox2tuXRDADFw2Ih9Rfown94BKU', 12 | 'MNIST':'1IYE99VnUcKso2UOUYRfuts8MtrZ7IdqU', 13 | 'Cars':'101U8_jLvsDg6WePgsDK9QvEFw1_9jnbC', 14 | 'DTD':'1octqDdrX8vOSRfMWKHSBodPiC5OS4XBA', 15 | 'RESISC45':'1Glu2Hky3qa58LJgtWhwDa50CLkCPdkuo', 16 | 'SVHN':'1_xM_tfAJPm_0YoonnaSn4rVeek2qrzk9' 17 | } 18 | 19 | headdict={ 20 | 'GTSRB':'1j2yVXOGdzHzcT6Zp5kc8xold77VRxwIC', 21 | 'SUN397':'1kXAWZZ9P4p_fsqQ8UjOKEWGViksF8rsx', 22 | 'EuroSAT':'1-aqbjz5rKJDPPtduQk1L7n0OimWkpRM6', 23 | 'MNIST':'1_9lBNw877gb9fGcXQTZljVANgU8e_9Z4', 24 | 'Cars':'1NZmJHeELIwGDa0vhoPHp89J4iXpsVhb5', 25 | 'DTD':'1hWP7hN_inUnuZIin_4yooD74iKQfNq5q', 26 | 'RESISC45':'1HrUa1eDtG_CFd_TYox-tO_8BHqME1YR3', 27 | 'SVHN':'1AXWsTQjk5KKaRsIJK3wszZT4RvEWJ3cM' 28 | } 29 | 30 | for ds, fid in ftdict.items(): 31 | print(f'start download {ds} to {prefix}') 32 | if not os.path.exists(prefix+f'/{ds}'): os.makedirs(prefix+f'/{ds}',exist_ok=True) 33 | output = prefix + 
f'/{ds}/finetuned.pt' 34 | #url = f'https://drive.google.com/file/d/{fid}/view?usp=drive_link' 35 | url = f'https://drive.google.com/uc?id={fid}' 36 | gdown.download(url, output, quiet=False) 37 | 38 | for ds, fid in headdict.items(): 39 | print(f'start download {ds} head to {prefix}') 40 | output = prefix + f'/head_{ds}.pt' 41 | #url = f'https://drive.google.com/file/d/{fid}/view?usp=drive_link' 42 | url = f'https://drive.google.com/uc?id={fid}' 43 | gdown.download(url, output, quiet=False) 44 | 45 | print('start download zeroshot.pt') 46 | output = prefix + f'/zeroshot.pt' 47 | fid='145ZjznF8HyTQvtlK1Mw9ybArVhm4b-8C' 48 | url = f'https://drive.google.com/uc?id={fid}' 49 | gdown.download(url, output, quiet=False) 50 | -------------------------------------------------------------------------------- /dawin_mtl/src/main_individuals.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | 7 | import time 8 | import sys 9 | sys.path.append('..') 10 | 11 | from task_vectors import TaskVector 12 | from eval import eval_single_dataset 13 | from args import parse_arguments 14 | import wandb 15 | 16 | def create_log_dir(path, filename='log.txt'): 17 | import logging 18 | if not os.path.exists(path): 19 | os.makedirs(path) 20 | logger = logging.getLogger(path) 21 | logger.setLevel(logging.DEBUG) 22 | fh = logging.FileHandler(path+'/'+filename) 23 | fh.setLevel(logging.DEBUG) 24 | ch = logging.StreamHandler() 25 | ch.setLevel(logging.DEBUG) 26 | logger.addHandler(fh) 27 | logger.addHandler(ch) 28 | return logger 29 | 30 | 31 | model = 'ViT-B-32' 32 | args = parse_arguments() 33 | 34 | if args.train_dataset is None: 35 | args.train_dataset = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 36 | if args.eval_datasets is None: 37 | args.eval_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 38 | train_datasets = args.train_dataset 39 | exam_datasets = args.eval_datasets 40 | 41 | if args.wb_project: 42 | wandb_args = {"project": args.wb_project} 43 | wandb_args["name"] = args.wb_runname if args.wb_runname else None 44 | wandb.init(**wandb_args, config=vars(args), save_code=False) 45 | 46 | args.data_location = '../data' 47 | args.model = model 48 | args.save = '../checkpoints/' + model 49 | args.logs_path = '../logs/' + model 50 | pretrained_checkpoint = '../checkpoints/'+model+'/zeroshot.pt' 51 | 52 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 53 | log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_)) 54 | 55 | print(f'Use {args.device}') 56 | accs = [] 57 | metric_dict = {} 58 | for dataset in exam_datasets: 59 | task_vector = TaskVector(pretrained_checkpoint, '../checkpoints/'+model+'/'+dataset+'/finetuned.pt') 60 | image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) 61 | 62 | metrics = eval_single_dataset(image_encoder, dataset, args) 63 | log.info(str(dataset) + ':' + str(metrics.get('top1')*100)+'%') 64 | accs.append(metrics.get('top1')*100) 65 | metric_dict[f"{str(dataset)}_Acc"] = metrics.get('top1')*100 66 | log.info('Avg ACC:' + str(np.mean(accs)) + '%') 67 | metric_dict[f"Avg_Acc"] = np.mean(accs) -------------------------------------------------------------------------------- /dawin_mtl/src/heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | 5 | import open_clip 6 | 7 
| from datasets.templates import get_templates 8 | from datasets.registry import get_dataset 9 | 10 | from modeling import ClassificationHead, ImageEncoder 11 | 12 | 13 | def build_classification_head(model, dataset_name, template, data_location, device): 14 | template = get_templates(dataset_name) 15 | 16 | logit_scale = model.logit_scale 17 | dataset = get_dataset( 18 | dataset_name, 19 | None, 20 | location=data_location 21 | ) 22 | model.eval() 23 | model.to(device) 24 | 25 | print('Building classification head.') 26 | with torch.no_grad(): 27 | zeroshot_weights = [] 28 | for classname in tqdm(dataset.classnames): 29 | texts = [] 30 | for t in template: 31 | texts.append(t(classname)) 32 | texts = open_clip.tokenize(texts).to(device) # tokenize 33 | embeddings = model.encode_text(texts) # embed with text encoder 34 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 35 | 36 | embeddings = embeddings.mean(dim=0, keepdim=True) 37 | embeddings /= embeddings.norm() 38 | 39 | zeroshot_weights.append(embeddings) 40 | 41 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 42 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 43 | 44 | zeroshot_weights *= logit_scale.exp() 45 | 46 | zeroshot_weights = zeroshot_weights.squeeze().float() 47 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 48 | 49 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 50 | 51 | return classification_head 52 | 53 | 54 | def get_classification_head(args, dataset): 55 | filename = os.path.join(args.save, f'head_{dataset}.pt') 56 | if os.path.exists(filename): 57 | print(f'Classification head for {args.model} on {dataset} exists at {filename}') 58 | return ClassificationHead.load(filename) 59 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.') 60 | model = ImageEncoder(args, keep_lang=True).model 61 | template = get_templates(dataset) 62 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device) 63 | os.makedirs(args.save, exist_ok=True) 64 | classification_head.save(filename) 65 | return classification_head 66 | 67 | -------------------------------------------------------------------------------- /dawin_mtl/src/main_task_arithmetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | from task_vectors import TaskVector 7 | from eval import eval_single_dataset 8 | from args import parse_arguments 9 | import wandb 10 | 11 | def create_log_dir(path, filename='log.txt'): 12 | import logging 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | logger = logging.getLogger(path) 16 | logger.setLevel(logging.DEBUG) 17 | fh = logging.FileHandler(path+'/'+filename) 18 | fh.setLevel(logging.DEBUG) 19 | ch = logging.StreamHandler() 20 | ch.setLevel(logging.DEBUG) 21 | logger.addHandler(fh) 22 | logger.addHandler(ch) 23 | return logger 24 | 25 | args = parse_arguments() 26 | 27 | model = 'ViT-B-32' 28 | 29 | if args.train_dataset is None: 30 | args.train_dataset = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 31 | if args.eval_datasets is None: 32 | args.eval_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 33 | train_datasets = args.train_dataset 34 | exam_datasets = args.eval_datasets 35 | 36 | if args.wb_project: 37 | wandb_args = {"project": args.wb_project} 38 | 
wandb_args["name"] = args.wb_runname if args.wb_runname else None 39 | wandb.init(**wandb_args, config=vars(args), save_code=False) 40 | 41 | args.data_location = '../data' 42 | args.model = model 43 | args.save = '../checkpoints/' + model 44 | args.logs_path = '../logs/' + model 45 | pretrained_checkpoint = '../checkpoints/'+model+'/zeroshot.pt' 46 | 47 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 48 | log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_)) 49 | 50 | task_vectors = [ 51 | TaskVector(pretrained_checkpoint, '../checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in train_datasets 52 | ] 53 | 54 | 55 | startt = time.time() 56 | task_vector_sum = sum(task_vectors) 57 | scaling_coef_ = 0.3 58 | image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef_) 59 | log.info('*'*20 + 'scaling_coef:' + str(scaling_coef_) + '*'*20) 60 | 61 | print(f'Use {args.device}') 62 | accs = [] 63 | metric_dict = {} 64 | for dataset in exam_datasets: 65 | metrics = eval_single_dataset(image_encoder, dataset, args) 66 | log.info(str(dataset) + ':' + str(metrics.get('top1')*100)+'%') 67 | accs.append(metrics.get('top1')*100) 68 | metric_dict[f"{str(dataset)}_Acc"] = metrics.get('top1')*100 69 | log.info('Avg ACC:' + str(np.mean(accs)) + '%') 70 | metric_dict[f"Avg_Acc"] = np.mean(accs) 71 | runtime = time.time() - startt 72 | if args.wb_project: 73 | wandb.log(metric_dict) -------------------------------------------------------------------------------- /dawin_mtl/README.md: -------------------------------------------------------------------------------- 1 | # DaWin - Multi-task Learning Experiments 2 | 3 | Set the path first 4 | ``` 5 | cd dawin_mtl 6 | mkdir cache checkpoints 7 | export PYTHONPATH="$PYTHONPATH:$PWD" 8 | ``` 9 | 10 | ## Prepare datasets 11 | * Following previous works, we evaluate DaWin on eight visual recognition benchmarks: SUN397, Stanford Cars, RESISC45, EuroSAT, SVHN, GTSRB, MNIST, and DTD. 12 | * Refer to the dataset processing in the [task_vectors](https://github.com/mlfoundations/task_vectors) repository. 13 | * Or you can download the processed data from [Baidu Cloud disk](https://pan.baidu.com/s/1w0Z2UVv3NVmqDhjH8WTOJQ?pwd=kvg6) or [HuggingFace](https://huggingface.co/collections/tanganke/image-classification-datasets-662abda7d75efe6b0e6b43da). 14 | * After downloading the datasets, you may need to run `split_dataset.py` to set up the splits for each dataset. 15 | * Then, make a symbolic link here: `ln -s YOURDATAPATH ./dawin/dawin_mtl/data` 16 | 17 | ## Prepare checkpoints 18 | * You can manually download the eight fine-tuned models' checkpoints (and corresponding classifier heads) from [here](https://github.com/mlfoundations/task_vectors#checkpoints). 19 | * [Public Google Drive URL of the Task Arithmetic authors](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw) 20 | * Or you can run the `python download_ckpt.py` script to download all the required checkpoints. 21 | 22 | ## Run 23 | Run the following commands from `src/` to reproduce Table 6 -- the main MTL experiments.
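Each command below loads its benchmarks by name through `get_dataset` in `src/datasets/registry.py` (shown later in this listing); appending `Val` to a dataset name either picks a dedicated validation class (e.g. `EuroSATVal`) or carves a validation split out of the training set. A minimal sketch of that lookup, run from `src/` with the data layout prepared above (the torchvision transform is only a stand-in for the CLIP preprocessing the scripts actually use):
```
from torchvision import transforms
from datasets.registry import get_dataset

# stand-in preprocessing; the actual scripts pass the CLIP image transform
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# evaluation split used by the merging scripts
dtd = get_dataset('DTD', preprocess, location='../data', batch_size=128)

# 'DTDVal' holds out part of the training set instead
# (val_fraction=0.1, capped at 5000 samples)
dtd_val = get_dataset('DTDVal', preprocess, location='../data', batch_size=128)

print(len(dtd.test_dataset), len(dtd_val.test_dataset), dtd.classnames[:3])
```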
24 | 25 | * Run DaWin (Ours) 26 | > `python main_dawin.py` 27 | * Run Individual Fine-tuned Models' Evaluation 28 | > `python main_individuals.py` 29 | * Run Weight Averaging [paper1](https://arxiv.org/abs/2203.05482), [paper2](https://arxiv.org/abs/2208.05592) 30 | > `python main_weight_avg.py` 31 | * Run Task Arithmetic [paper](https://arxiv.org/abs/2212.04089) 32 | > `python main_task_arithmetic.py` 33 | * Run TIES-MERGING [paper](https://arxiv.org/abs/2306.01708) 34 | > `python main_ties_merging.py` 35 | * Run Task-wise AdaMerging [paper](https://arxiv.org/abs/2310.02575) 36 | > `python main_task_wise_adamerging.py` 37 | * Run Task-wise AdaMerging++ [paper](https://arxiv.org/abs/2310.02575) 38 | > `python main_task_wise_adamergingpp.py` 39 | * Run Layer-wise AdaMerging [paper](https://arxiv.org/abs/2310.02575) 40 | > `python main_layer_wise_adamerging.py` 41 | * Run Layer-wise AdaMerging++ [paper](https://arxiv.org/abs/2310.02575) 42 | > `python main_layer_wise_adamergingpp.py` 43 | 44 | ## Acknowledgement 45 | This repository is built on top of [AdaMerging](https://github.com/EnnengYang/AdaMerging), and some code blocks are borrowed from the [ModelSoup](https://github.com/mlfoundations/model-soups) project; we appreciate the authors' endeavors. -------------------------------------------------------------------------------- /dawin_mtl/src/main_weight_avg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | 7 | import time 8 | import sys 9 | 10 | from task_vectors import TaskVector 11 | from eval import eval_single_dataset 12 | from args import parse_arguments 13 | import wandb 14 | 15 | def create_log_dir(path, filename='log.txt'): 16 | import logging 17 | if not os.path.exists(path): 18 | os.makedirs(path) 19 | logger = logging.getLogger(path) 20 | logger.setLevel(logging.DEBUG) 21 | fh = logging.FileHandler(path+'/'+filename) 22 | fh.setLevel(logging.DEBUG) 23 | ch = logging.StreamHandler() 24 | ch.setLevel(logging.DEBUG) 25 | logger.addHandler(fh) 26 | logger.addHandler(ch) 27 | return logger 28 | 29 | 30 | model = 'ViT-B-32' 31 | args = parse_arguments() 32 | 33 | if args.train_dataset is None: 34 | args.train_dataset = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 35 | if args.eval_datasets is None: 36 | args.eval_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 37 | train_datasets = args.train_dataset 38 | exam_datasets = args.eval_datasets 39 | 40 | if args.wb_project: 41 | wandb_args = {"project": args.wb_project} 42 | wandb_args["name"] = args.wb_runname if args.wb_runname else None 43 | wandb.init(**wandb_args, config=vars(args), save_code=False) 44 | 45 | args.data_location = '../data' 46 | args.model = model 47 | args.save = '../checkpoints/' + model 48 | args.logs_path = '../logs/' + model 49 | pretrained_checkpoint = '../checkpoints/'+model+'/zeroshot.pt' 50 | 51 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 52 | log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_)) 53 | 54 | task_vectors = [ 55 | TaskVector(pretrained_checkpoint, '../checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets 56 | ] 57 | 58 | task_vector_sum = sum(task_vectors) 59 | 60 | #if args.ft_only: 61 | scaling_coef_ = 1/8 62 | ptcoef = 1.0 63 | # else: 64 | # scaling_coef_ = 1/9 65 | # ptcoef = 1.0 66 | 67 | # if args.pt_only: 68 | # scaling_coef_ = 0.0
69 | # else: 70 | # pass 71 | 72 | 73 | image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef_,pt_scaling_coef=ptcoef) 74 | log.info('*'*20 + 'scaling_coef:' + str(scaling_coef_) + '*'*20) 75 | 76 | print(f'Use {args.device}') 77 | accs = [] 78 | metric_dict = {} 79 | for dataset in exam_datasets: 80 | metrics = eval_single_dataset(image_encoder, dataset, args) 81 | log.info(str(dataset) + ':' + str(metrics.get('top1')*100)+'%') 82 | accs.append(metrics.get('top1')*100) 83 | metric_dict[f"{str(dataset)}_Acc"] = metrics.get('top1')*100 84 | log.info('Avg ACC:' + str(np.mean(accs)) + '%') 85 | metric_dict[f"Avg_Acc"] = np.mean(accs) 86 | if args.wb_project: 87 | wandb.log(metric_dict) -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | 6 | def pretify_classname(classname): 7 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) 8 | l = [i.lower() for i in l] 9 | out = ' '.join(l) 10 | if out.endswith('al'): 11 | return out + ' area' 12 | return out 13 | 14 | class EuroSATBase: 15 | def __init__(self, 16 | preprocess, 17 | test_split, 18 | location='~/data', 19 | batch_size=32, 20 | num_workers=0): 21 | # Data loading code 22 | traindir = os.path.join(location, 'EuroSAT_splits', 'train') 23 | testdir = os.path.join(location, 'EuroSAT_splits', test_split) 24 | 25 | 26 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 27 | self.train_loader = torch.utils.data.DataLoader( 28 | self.train_dataset, 29 | shuffle=True, 30 | batch_size=batch_size, 31 | num_workers=num_workers, 32 | ) 33 | 34 | self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) 35 | self.test_loader = torch.utils.data.DataLoader( 36 | self.test_dataset, 37 | batch_size=batch_size, 38 | num_workers=num_workers 39 | ) 40 | self.test_loader_shuffle = torch.utils.data.DataLoader( 41 | self.test_dataset, 42 | shuffle=True, 43 | batch_size=batch_size, 44 | num_workers=num_workers 45 | ) 46 | idx_to_class = dict((v, k) 47 | for k, v in self.train_dataset.class_to_idx.items()) 48 | 49 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] 50 | self.classnames = [pretify_classname(c) for c in self.classnames] 51 | ours_to_open_ai = { 52 | 'annual crop': 'annual crop land', 53 | 'forest': 'forest', 54 | 'herbaceous vegetation': 'brushland or shrubland', 55 | 'highway': 'highway or road', 56 | 'industrial area': 'industrial buildings or commercial buildings', 57 | 'pasture': 'pasture land', 58 | 'permanent crop': 'permanent crop land', 59 | 'residential area': 'residential buildings or homes or apartments', 60 | 'river': 'river', 61 | 'sea lake': 'lake or sea', 62 | } 63 | for i in range(len(self.classnames)): 64 | self.classnames[i] = ours_to_open_ai[self.classnames[i]] 65 | 66 | 67 | class EuroSAT(EuroSATBase): 68 | def __init__(self, 69 | preprocess, 70 | location='~/data', 71 | batch_size=32, 72 | num_workers=0): 73 | super().__init__(preprocess, 'test', location, batch_size, num_workers) 74 | 75 | 76 | class EuroSATVal(EuroSATBase): 77 | def __init__(self, 78 | preprocess, 79 | location='~/data', 80 | batch_size=32, 81 | num_workers=0): 82 | super().__init__(preprocess, 'val', location, batch_size, num_workers) 83 | 
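The next file in this listing, `src/task_vectors.py`, defines the `TaskVector` abstraction used by the merging scripts (`main_task_arithmetic.py`, `main_weight_avg.py`, `main_ties_merging.py`). A minimal, illustrative usage sketch of that API, assuming the checkpoints downloaded by `download_ckpt.py` are in place (the dataset choices and coefficient below are placeholders):

from task_vectors import TaskVector

pretrained_ckpt = '../checkpoints/ViT-B-32/zeroshot.pt'
finetuned_ckpts = ['../checkpoints/ViT-B-32/MNIST/finetuned.pt',
                   '../checkpoints/ViT-B-32/DTD/finetuned.pt']

# Each task vector stores, key by key, the difference between fine-tuned and pretrained weights.
task_vectors = [TaskVector(pretrained_ckpt, ft) for ft in finetuned_ckpts]

# __radd__ absorbs the implicit 0 that sum() starts from, so the vectors can be summed directly.
merged = sum(task_vectors)

# Task arithmetic: theta = theta_pre + 0.3 * sum_i tau_i (pt_scaling_coef defaults to 1.0).
image_encoder = merged.apply_to(pretrained_ckpt, scaling_coef=0.3)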
-------------------------------------------------------------------------------- /dawin_mtl/src/task_vectors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TaskVector(): 5 | def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None): 6 | """Initializes the task vector from a pretrained and a finetuned checkpoint. 7 | 8 | This can either be done by passing two state dicts (one corresponding to the 9 | pretrained model, and another to the finetuned model), or by directly passing in 10 | the task vector state dict. 11 | """ 12 | if vector is not None: 13 | self.vector = vector 14 | else: 15 | assert pretrained_checkpoint is not None and finetuned_checkpoint is not None 16 | with torch.no_grad(): 17 | print('TaskVector:' + finetuned_checkpoint) 18 | pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict() 19 | finetuned_state_dict = torch.load(finetuned_checkpoint).state_dict() 20 | self.vector = {} 21 | for key in pretrained_state_dict: 22 | if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]: 23 | continue 24 | self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key] 25 | 26 | def __add__(self, other): 27 | """Add two task vectors together.""" 28 | with torch.no_grad(): 29 | new_vector = {} 30 | for key in self.vector: 31 | if key not in other.vector: 32 | print(f'Warning, key {key} is not present in both task vectors.') 33 | continue 34 | new_vector[key] = self.vector[key] + other.vector[key] 35 | return TaskVector(vector=new_vector) 36 | 37 | def __radd__(self, other): 38 | if other is None or isinstance(other, int): 39 | return self 40 | return self.__add__(other) 41 | 42 | def __neg__(self): 43 | """Negate a task vector.""" 44 | with torch.no_grad(): 45 | new_vector = {} 46 | for key in self.vector: 47 | new_vector[key] = - self.vector[key] 48 | return TaskVector(vector=new_vector) 49 | 50 | def weightmerging(self, taskvectors, coefficients): 51 | with torch.no_grad(): 52 | new_vector = {} 53 | for key in taskvectors[0].vector: 54 | new_vector[key] = sum(coefficients[k] * taskvectors[k].vector[key] for k in range(len(taskvectors))) 55 | return TaskVector(vector=new_vector) 56 | 57 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0, pt_scaling_coef=1.0): 58 | """Apply a task vector to a pretrained model.""" 59 | with torch.no_grad(): 60 | pretrained_model = torch.load(pretrained_checkpoint) 61 | new_state_dict = {} 62 | pretrained_state_dict = pretrained_model.state_dict() 63 | for key in pretrained_state_dict: 64 | if key not in self.vector: 65 | print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector') 66 | continue 67 | new_state_dict[key] = pt_scaling_coef * pretrained_state_dict[key] + scaling_coef * self.vector[key] 68 | pretrained_model.load_state_dict(new_state_dict, strict=False) 69 | return pretrained_model 70 | 71 | -------------------------------------------------------------------------------- /dawin_mtl/src/main_ties_merging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | 5 | from eval import eval_single_dataset 6 | from args import parse_arguments 7 | import wandb 8 | 9 | def create_log_dir(path, filename='log.txt'): 10 | import logging 11 | if not os.path.exists(path): 12 | os.makedirs(path) 13 | logger = logging.getLogger(path) 14 | logger.setLevel(logging.DEBUG) 15 | fh =
logging.FileHandler(path+'/'+filename) 16 | fh.setLevel(logging.DEBUG) 17 | ch = logging.StreamHandler() 18 | ch.setLevel(logging.DEBUG) 19 | logger.addHandler(fh) 20 | logger.addHandler(ch) 21 | return logger 22 | 23 | args = parse_arguments() 24 | model = 'ViT-B-32' 25 | 26 | if args.train_dataset is None: 27 | args.train_dataset = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 28 | if args.eval_datasets is None: 29 | args.eval_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 30 | train_datasets = args.train_dataset 31 | exam_datasets = args.eval_datasets 32 | 33 | if args.wb_project: 34 | wandb_args = {"project": args.wb_project} 35 | wandb_args["name"] = args.wb_runname if args.wb_runname else None 36 | wandb.init(**wandb_args, config=vars(args), save_code=False) 37 | 38 | args.data_location = '../data' 39 | args.model = model 40 | args.save = '../checkpoints/' + model 41 | args.logs_path = '../logs/' + model 42 | pretrained_checkpoint = '../checkpoints/'+model+'/zeroshot.pt' 43 | 44 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 45 | log = create_log_dir(args.logs_path, 'log_{}_ties_merging.txt'.format(str_time_)) 46 | 47 | from ties_merging_utils import * 48 | ft_checks = [torch.load('../checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in train_datasets] 49 | ptm_check = torch.load(pretrained_checkpoint).state_dict() 50 | check_parameterNamesMatch(ft_checks + [ptm_check]) 51 | 52 | remove_keys = [] 53 | print(f"Flattening out Checkpoints") 54 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks]) 55 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys) 56 | 57 | tv_flat_checks = flat_ft - flat_ptm 58 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check) 59 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i])for i in range(len(ft_checks))]) 60 | 61 | 62 | K = 20 63 | merge_func = "dis-sum" 64 | scaling_coef_ = 0.3 65 | 66 | startt = time.time() 67 | merged_tv = ties_merging(tv_flat_checks, reset_thresh=K, merge_func=merge_func,) 68 | merged_check = flat_ptm + scaling_coef_ * merged_tv 69 | 70 | merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys) 71 | 72 | image_encoder = torch.load(pretrained_checkpoint) 73 | image_encoder.load_state_dict(merged_state_dict, strict=False) 74 | 75 | metric_dict = {} 76 | Total_ACC = 0. 
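# Recap of the TIES-Merging pipeline above, whose output the evaluation loop below consumes.
# The trim / elect-sign / merge details live in ties_merging_utils (not shown in this listing),
# so the step descriptions here are indicative rather than exact:
#   1. every fine-tuned state dict and the pretrained state dict are flattened into vectors;
#   2. task vectors are formed as tv_flat_checks = flat_ft - flat_ptm;
#   3. ties_merging() trims each task vector to its largest-magnitude entries (reset_thresh K=20),
#      elects a per-parameter sign, and merges the surviving entries with merge_func="dis-sum";
#   4. the merged vector is rescaled by scaling_coef_=0.3, added back to flat_ptm, and unflattened
#      into a state dict that is loaded into the pretrained image encoder.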
77 | for dataset in exam_datasets: 78 | metrics = eval_single_dataset(image_encoder, dataset, args) 79 | Total_ACC += metrics['top1'] * 100 80 | log.info(str(dataset) + ':' + str(metrics)) 81 | metric_dict[f"{str(dataset)}_Acc"] = metrics['top1']*100 82 | log.info('Final: ' + 'Avg ACC:' + str(Total_ACC / len(exam_datasets))) 83 | metric_dict[f"Avg_Acc"] = Total_ACC / len(exam_datasets) 84 | if args.wb_project: 85 | wandb.log(metric_dict) -------------------------------------------------------------------------------- /dawin_rft/zeroshot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import clip 5 | import os 6 | from tqdm import tqdm 7 | 8 | import datasets 9 | from utils import ModelWrapper, test_model_on_dataset 10 | from openai_imagenet_template import openai_imagenet_template 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--data-location", 16 | type=str, 17 | default=os.path.expanduser('~/data'), 18 | help="The root directory for the datasets.", 19 | ) 20 | parser.add_argument( 21 | "--model-location", 22 | type=str, 23 | default=os.path.expanduser('~/ssd/checkpoints/soups'), 24 | help="Where to download the models.", 25 | ) 26 | parser.add_argument( 27 | "--batch-size", 28 | type=int, 29 | default=512, 30 | ) 31 | parser.add_argument( 32 | "--custom-template", action="store_true", default=False, 33 | ) 34 | parser.add_argument( 35 | "--dataset", default="ImageNet", 36 | help=f"Must be one of {','.join(['ImageNet', 'ImageNetV2', 'ImageNetR', 'ObjectNet', 'ImageNetA'])}" 37 | 38 | ) 39 | parser.add_argument( 40 | "--workers", 41 | type=int, 42 | default=4, 43 | ) 44 | parser.add_argument( 45 | "--model", 46 | default='ViT-B/32', 47 | help='Model to use -- you can try another like ViT-L/14' 48 | ) 49 | return parser.parse_args() 50 | 51 | def zeroshot_classifier(model, classnames, templates, device): 52 | print('Building zero-shot classifier.') 53 | with torch.no_grad(): 54 | zeroshot_weights = [] 55 | for classname in tqdm(classnames): 56 | texts = [template(classname) for template in templates] #format with class 57 | texts = clip.tokenize(texts).to(device) #tokenize 58 | class_embeddings = model.encode_text(texts) 59 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 60 | class_embedding = class_embeddings.mean(dim=0) 61 | class_embedding /= class_embedding.norm() 62 | zeroshot_weights.append(class_embedding) 63 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 64 | return 100*zeroshot_weights.t() 65 | 66 | 67 | if __name__ == '__main__': 68 | args = parse_arguments() 69 | DEVICE = 'cuda' 70 | assert args.dataset in ['ImageNet', 'ImageNetV2', 'ImageNetR', 'ObjectNet', 'ImageNetA'] 71 | 72 | if args.custom_template: 73 | template = [lambda x : f"a photo of a {x}."] 74 | else: 75 | template = openai_imagenet_template 76 | 77 | base_model, preprocess = clip.load(args.model, 'cuda', jit=False) 78 | dset = getattr(datasets, args.dataset)(preprocess, location=args.data_location, batch_size=args.batch_size, num_workers=args.workers) 79 | clf = zeroshot_classifier(base_model, dset.classnames, template, DEVICE) 80 | NUM_CLASSES = len(dset.classnames) 81 | feature_dim = base_model.visual.output_dim 82 | 83 | model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True, initial_weights=clf) 84 | for p in model.parameters(): 85 | p.data = p.data.float() 86 | 87 | model = model.cuda() 88 | devices = [x 
for x in range(torch.cuda.device_count())] 89 | model = torch.nn.DataParallel(model, device_ids=devices) 90 | 91 | accuracy = test_model_on_dataset(model, dset) 92 | 93 | print(f'Accuracy is {round(100 * accuracy, 2)}.') 94 | -------------------------------------------------------------------------------- /dawin_rft/openai_imagenet_template.py: -------------------------------------------------------------------------------- 1 | openai_imagenet_template = [ 2 | lambda c: f'a bad photo of a {c}.', 3 | lambda c: f'a photo of many {c}.', 4 | lambda c: f'a sculpture of a {c}.', 5 | lambda c: f'a photo of the hard to see {c}.', 6 | lambda c: f'a low resolution photo of the {c}.', 7 | lambda c: f'a rendering of a {c}.', 8 | lambda c: f'graffiti of a {c}.', 9 | lambda c: f'a bad photo of the {c}.', 10 | lambda c: f'a cropped photo of the {c}.', 11 | lambda c: f'a tattoo of a {c}.', 12 | lambda c: f'the embroidered {c}.', 13 | lambda c: f'a photo of a hard to see {c}.', 14 | lambda c: f'a bright photo of a {c}.', 15 | lambda c: f'a photo of a clean {c}.', 16 | lambda c: f'a photo of a dirty {c}.', 17 | lambda c: f'a dark photo of the {c}.', 18 | lambda c: f'a drawing of a {c}.', 19 | lambda c: f'a photo of my {c}.', 20 | lambda c: f'the plastic {c}.', 21 | lambda c: f'a photo of the cool {c}.', 22 | lambda c: f'a close-up photo of a {c}.', 23 | lambda c: f'a black and white photo of the {c}.', 24 | lambda c: f'a painting of the {c}.', 25 | lambda c: f'a painting of a {c}.', 26 | lambda c: f'a pixelated photo of the {c}.', 27 | lambda c: f'a sculpture of the {c}.', 28 | lambda c: f'a bright photo of the {c}.', 29 | lambda c: f'a cropped photo of a {c}.', 30 | lambda c: f'a plastic {c}.', 31 | lambda c: f'a photo of the dirty {c}.', 32 | lambda c: f'a jpeg corrupted photo of a {c}.', 33 | lambda c: f'a blurry photo of the {c}.', 34 | lambda c: f'a photo of the {c}.', 35 | lambda c: f'a good photo of the {c}.', 36 | lambda c: f'a rendering of the {c}.', 37 | lambda c: f'a {c} in a video game.', 38 | lambda c: f'a photo of one {c}.', 39 | lambda c: f'a doodle of a {c}.', 40 | lambda c: f'a close-up photo of the {c}.', 41 | lambda c: f'a photo of a {c}.', 42 | lambda c: f'the origami {c}.', 43 | lambda c: f'the {c} in a video game.', 44 | lambda c: f'a sketch of a {c}.', 45 | lambda c: f'a doodle of the {c}.', 46 | lambda c: f'a origami {c}.', 47 | lambda c: f'a low resolution photo of a {c}.', 48 | lambda c: f'the toy {c}.', 49 | lambda c: f'a rendition of the {c}.', 50 | lambda c: f'a photo of the clean {c}.', 51 | lambda c: f'a photo of a large {c}.', 52 | lambda c: f'a rendition of a {c}.', 53 | lambda c: f'a photo of a nice {c}.', 54 | lambda c: f'a photo of a weird {c}.', 55 | lambda c: f'a blurry photo of a {c}.', 56 | lambda c: f'a cartoon {c}.', 57 | lambda c: f'art of a {c}.', 58 | lambda c: f'a sketch of the {c}.', 59 | lambda c: f'a embroidered {c}.', 60 | lambda c: f'a pixelated photo of a {c}.', 61 | lambda c: f'itap of the {c}.', 62 | lambda c: f'a jpeg corrupted photo of the {c}.', 63 | lambda c: f'a good photo of a {c}.', 64 | lambda c: f'a plushie {c}.', 65 | lambda c: f'a photo of the nice {c}.', 66 | lambda c: f'a photo of the small {c}.', 67 | lambda c: f'a photo of the weird {c}.', 68 | lambda c: f'the cartoon {c}.', 69 | lambda c: f'art of the {c}.', 70 | lambda c: f'a drawing of the {c}.', 71 | lambda c: f'a photo of the large {c}.', 72 | lambda c: f'a black and white photo of a {c}.', 73 | lambda c: f'the plushie {c}.', 74 | lambda c: f'a dark photo of a {c}.', 75 | lambda c: 
f'itap of a {c}.', 76 | lambda c: f'graffiti of the {c}.', 77 | lambda c: f'a toy {c}.', 78 | lambda c: f'itap of my {c}.', 79 | lambda c: f'a photo of a cool {c}.', 80 | lambda c: f'a photo of a small {c}.', 81 | lambda c: f'a tattoo of the {c}.', 82 | ] -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import random 4 | import torch 5 | import copy 6 | 7 | from torch.utils.data.dataset import random_split 8 | 9 | from datasets.cars import Cars 10 | from datasets.dtd import DTD 11 | from datasets.eurosat import EuroSAT, EuroSATVal 12 | from datasets.gtsrb import GTSRB 13 | from datasets.mnist import MNIST 14 | from datasets.resisc45 import RESISC45 15 | from datasets.svhn import SVHN 16 | from datasets.sun397 import SUN397 17 | 18 | registry = { 19 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) 20 | } 21 | 22 | 23 | class GenericDataset(object): 24 | def __init__(self): 25 | self.train_dataset = None 26 | self.train_loader = None 27 | self.test_dataset = None 28 | self.test_loader = None 29 | self.classnames = None 30 | 31 | 32 | def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0): 33 | assert val_fraction > 0. and val_fraction < 1. 34 | total_size = len(dataset.train_dataset) 35 | val_size = int(total_size * val_fraction) 36 | if max_val_samples is not None: 37 | val_size = min(val_size, max_val_samples) 38 | train_size = total_size - val_size 39 | 40 | assert val_size > 0 41 | assert train_size > 0 42 | 43 | lengths = [train_size, val_size] 44 | 45 | trainset, valset = random_split( 46 | dataset.train_dataset, 47 | lengths, 48 | generator=torch.Generator().manual_seed(seed) 49 | ) 50 | if new_dataset_class_name == 'MNISTVal': 51 | assert trainset.indices[0] == 36044 52 | 53 | 54 | new_dataset = None 55 | 56 | new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {}) 57 | new_dataset = new_dataset_class() 58 | 59 | new_dataset.train_dataset = trainset 60 | new_dataset.train_loader = torch.utils.data.DataLoader( 61 | new_dataset.train_dataset, 62 | shuffle=True, 63 | batch_size=batch_size, 64 | num_workers=num_workers, 65 | ) 66 | 67 | new_dataset.test_dataset = valset 68 | new_dataset.test_loader = torch.utils.data.DataLoader( 69 | new_dataset.test_dataset, 70 | batch_size=batch_size, 71 | num_workers=num_workers 72 | ) 73 | 74 | new_dataset.test_loader_shuffle = torch.utils.data.DataLoader( 75 | new_dataset.test_dataset, 76 | shuffle=True, 77 | batch_size=batch_size, 78 | num_workers=num_workers 79 | ) 80 | 81 | new_dataset.classnames = copy.copy(dataset.classnames) 82 | 83 | return new_dataset 84 | 85 | 86 | def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=0, val_fraction=0.1, max_val_samples=5000): 87 | if dataset_name.endswith('Val'): 88 | # Handle val splits 89 | if dataset_name in registry: 90 | dataset_class = registry[dataset_name] 91 | else: 92 | base_dataset_name = dataset_name.split('Val')[0] 93 | base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers) 94 | dataset = split_train_into_train_val( 95 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples) 96 | return dataset 97 | else: 98 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. 
Supported datasets: {list(registry.keys())}' 99 | dataset_class = registry[dataset_name] 100 | dataset = dataset_class( 101 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 102 | ) 103 | return dataset 104 | -------------------------------------------------------------------------------- /dawin_mtl/src/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--data-location", 10 | type=str, 11 | default=os.path.expanduser('~/data'), 12 | help="The root directory for the datasets.", 13 | ) 14 | parser.add_argument( 15 | "--eval-datasets", 16 | default=None, 17 | type=lambda x: x.split(","), 18 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. " 19 | ) 20 | parser.add_argument( 21 | "--train-dataset", 22 | default=None, 23 | type=lambda x: x.split(","), 24 | help="Which dataset(s) to patch on.", 25 | ) 26 | parser.add_argument( 27 | "--exp_name", 28 | type=str, 29 | default=None, 30 | help="Name of the experiment, for organization purposes only." 31 | ) 32 | parser.add_argument( 33 | "--results-db", 34 | type=str, 35 | default=None, 36 | help="Where to store the results, else does not store", 37 | ) 38 | parser.add_argument( 39 | "--model", 40 | type=str, 41 | default=None, 42 | help="The type of model (e.g. RN50, ViT-B-32).", 43 | ) 44 | parser.add_argument( 45 | "--batch-size", 46 | type=int, 47 | default=2048, 48 | ) 49 | parser.add_argument( 50 | "--lr", 51 | type=float, 52 | default=0.001, 53 | help="Learning rate." 54 | ) 55 | parser.add_argument( 56 | "--wd", 57 | type=float, 58 | default=0.1, 59 | help="Weight decay" 60 | ) 61 | parser.add_argument( 62 | "--ls", 63 | type=float, 64 | default=0.0, 65 | help="Label smoothing." 66 | ) 67 | parser.add_argument( 68 | "--warmup_length", 69 | type=int, 70 | default=500, 71 | ) 72 | parser.add_argument( 73 | "--epochs", 74 | type=int, 75 | default=10, 76 | ) 77 | parser.add_argument( 78 | "--load", 79 | type=lambda x: x.split(","), 80 | default=None, 81 | help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.", 82 | ) 83 | parser.add_argument( 84 | "--save", 85 | type=str, 86 | default=None, 87 | help="Optionally save a _classifier_, e.g. 
a zero shot classifier or probe.", 88 | ) 89 | parser.add_argument( 90 | "--save-coef", 91 | type=int, 92 | default=1, 93 | help="Optionally save DaWin's coefficients", 94 | ) 95 | parser.add_argument( 96 | "--analysis", 97 | type=int, 98 | default=0, 99 | ) 100 | parser.add_argument( 101 | "--cache-dir", 102 | type=str, 103 | default=None, 104 | help="Directory for caching features and encoder", 105 | ) 106 | parser.add_argument( 107 | "--openclip-cachedir", 108 | type=str, 109 | default='/gscratch/efml/gamaga/.cache/open_clip', 110 | help='Directory for caching models from OpenCLIP' 111 | ) 112 | parser.add_argument( 113 | "--set-alphas", 114 | type=str, 115 | default='pre-trained', 116 | choices=['from_scratch','pre-trained','ours'], 117 | ) 118 | parser.add_argument( 119 | "--clustering", 120 | type=str, 121 | default='dmm', 122 | ) 123 | parser.add_argument( 124 | "--ft-only", 125 | type=int, 126 | default=1, 127 | ) 128 | parser.add_argument( 129 | "--pt-only", 130 | type=int, 131 | default=0, 132 | ) 133 | 134 | parser.add_argument( 135 | "--ncluster", 136 | type=int, 137 | default=1, 138 | ) 139 | parser.add_argument( 140 | "--eval-only", 141 | type=int, 142 | default=1, 143 | ) 144 | parser.add_argument( 145 | "--method", 146 | type=str, 147 | default='', 148 | ) 149 | parser.add_argument( 150 | "--wb_project", 151 | type=str, 152 | default='', 153 | ) 154 | parser.add_argument( 155 | "--wb_runname", 156 | type=str, 157 | default='', 158 | ) 159 | parsed_args = parser.parse_args() 160 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" 161 | 162 | if parsed_args.load is not None and len(parsed_args.load) == 1: 163 | parsed_args.load = parsed_args.load[0] 164 | return parsed_args 165 | -------------------------------------------------------------------------------- /dawin_rft/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import clip 5 | import os 6 | from datasets import ImageNet2p, ImageNet, ImageNetV2, ImageNetSketch, ImageNetR, ObjectNet, ImageNetA 7 | from utils import get_model_from_sd, test_model_on_dataset 8 | import wandb 9 | 10 | def interpolation(alpha, weight1, weight2): 11 | return {key: (1 - alpha) * weight1[key] + alpha * weight2[key] for key in weight1.keys()} 12 | 13 | def parse_arguments(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--data-location", 17 | type=str, 18 | default=os.path.expanduser('data'), 19 | help="The root directory for the datasets.", 20 | ) 21 | parser.add_argument( 22 | "--model-location", 23 | type=str, 24 | default=os.path.expanduser('checkpoints/soups'), 25 | help="Where to download the models.", 26 | ) 27 | parser.add_argument( 28 | "--eval-single-model", type=str, default='', 29 | ) 30 | parser.add_argument( 31 | "--wiseft-alpha", type=float, default=-0.1, 32 | ) 33 | parser.add_argument( 34 | "--wiseft-ftpath", type=str, default='', 35 | ) 36 | parser.add_argument( 37 | "--wiseft-zspath", type=str, default='', 38 | ) 39 | parser.add_argument( 40 | "--model", type=str, default='ViT-B/32', 41 | ) 42 | 43 | parser.add_argument( 44 | "--batch-size", 45 | type=int, 46 | default=256, 47 | ) 48 | parser.add_argument( 49 | "--workers", 50 | type=int, 51 | default=8, 52 | ) 53 | parser.add_argument( 54 | "--wb_project", type=str, default='', # weight and bias 55 | ) 56 | parser.add_argument( 57 | "--wb_runname", type=str, default='', # weight and bias 58 | ) 59 | return parser.parse_args() 60 | 61 | 62 
| if __name__ == '__main__': 63 | args = parse_arguments() 64 | if args.wb_project: 65 | wandb_args = {"project": args.wb_project} 66 | wandb_args["name"] = args.wb_runname if args.wb_runname else None 67 | wandb.init(**wandb_args, config=vars(args), save_code=False) 68 | 69 | base_model, preprocess = clip.load(args.model, 'cpu', jit=False) 70 | 71 | if args.eval_single_model: 72 | model_path = [os.path.join(args.model_location, f'{args.eval_single_model}')] 73 | 74 | assert os.path.exists(model_path) 75 | state_dict = torch.load(model_path, map_location=torch.device('cpu')) 76 | model = get_model_from_sd(state_dict, base_model) 77 | 78 | results = {'model_name' : f'{args.eval_single_model}'} 79 | for dataset_cls in [ImageNet2p, ImageNet, ImageNetV2, ImageNetSketch, ImageNetR, ObjectNet, ImageNetA]: 80 | print(f'Evaluating model-{args.eval_single_model} on {dataset_cls.__name__}.') 81 | dataset = dataset_cls(preprocess, args.data_location, args.batch_size, args.workers) 82 | accuracy = test_model_on_dataset(model, dataset) 83 | results[dataset_cls.__name__] = accuracy 84 | print(accuracy) 85 | 86 | results['OODavg'] = 1./5 * (results['ImageNetV2'] + 87 | results['ImageNetR'] + results['ImageNetSketch'] + 88 | results['ObjectNet'] + results['ImageNetA']) 89 | 90 | if args.wb_project: 91 | wandb.log({k+'_Acc': v for k, v in results.items()}) 92 | 93 | 94 | if args.wiseft_alpha >= 0.0: 95 | alpha = args.wiseft_alpha 96 | ft_state_dict = torch.load(args.wiseft_ftpath, map_location=torch.device('cuda')) 97 | zs_state_dict = torch.load(args.wiseft_zspath, map_location=torch.device('cuda')) 98 | 99 | merged = interpolation(alpha, zs_state_dict, ft_state_dict) 100 | model = get_model_from_sd(merged, base_model) 101 | 102 | results = {'model_name' : f'wiseft_{alpha:.2f}'} 103 | for i, dataset_cls in enumerate([ImageNet, ImageNetV2, ImageNetR, ImageNetA, ImageNetSketch, ObjectNet]): 104 | print(f'Evaluating on {dataset_cls.__name__} with coef {alpha}.') 105 | dataset = dataset_cls(preprocess, args.data_location, args.batch_size, args.workers) 106 | accuracy = test_model_on_dataset(model, dataset) 107 | results[dataset_cls.__name__] = accuracy 108 | print(accuracy) 109 | 110 | results['OODavg'] = 1./5 * (results['ImageNetV2'] + 111 | results['ImageNetR'] + results['ImageNetSketch'] + 112 | results['ObjectNet'] + results['ImageNetA']) 113 | 114 | if args.wb_project: 115 | wandb.log({k+'_Acc': v for k, v in results.items()}) -------------------------------------------------------------------------------- /dawin.yaml: -------------------------------------------------------------------------------- 1 | name: dawin 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotli-python=1.0.9=py311h6a678d5_8 11 | - bzip2=1.0.8=h5eee18b_6 12 | - ca-certificates=2024.3.11=h06a4308_0 13 | - certifi=2024.2.2=py311h06a4308_0 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cuda-cudart=11.8.89=0 16 | - cuda-cupti=11.8.87=0 17 | - cuda-libraries=11.8.0=0 18 | - cuda-nvrtc=11.8.89=0 19 | - cuda-nvtx=11.8.86=0 20 | - cuda-runtime=11.8.0=0 21 | - ffmpeg=4.3=hf484d3e_0 22 | - filelock=3.13.1=py311h06a4308_0 23 | - freetype=2.12.1=h4a9f257_0 24 | - gmp=6.2.1=h295c915_3 25 | - gmpy2=2.1.2=py311hc9b5ff0_0 26 | - gnutls=3.6.15=he1e5248_0 27 | - idna=3.7=py311h06a4308_0 28 | - intel-openmp=2023.1.0=hdb19cb5_46306 29 | - jinja2=3.1.3=py311h06a4308_0 30 | - jpeg=9e=h5eee18b_1 31 | - lame=3.100=h7b6447c_0 32 | - 
lcms2=2.12=h3be6417_0 33 | - ld_impl_linux-64=2.38=h1181459_1 34 | - lerc=3.0=h295c915_0 35 | - libcublas=11.11.3.6=0 36 | - libcufft=10.9.0.58=0 37 | - libcufile=1.9.1.3=0 38 | - libcurand=10.3.5.147=0 39 | - libcusolver=11.4.1.48=0 40 | - libcusparse=11.7.5.86=0 41 | - libdeflate=1.17=h5eee18b_1 42 | - libffi=3.4.4=h6a678d5_1 43 | - libgcc-ng=11.2.0=h1234567_1 44 | - libgomp=11.2.0=h1234567_1 45 | - libiconv=1.16=h5eee18b_3 46 | - libidn2=2.3.4=h5eee18b_0 47 | - libjpeg-turbo=2.0.0=h9bf148f_0 48 | - libnpp=11.8.0.86=0 49 | - libnvjpeg=11.9.0.86=0 50 | - libpng=1.6.39=h5eee18b_0 51 | - libstdcxx-ng=11.2.0=h1234567_1 52 | - libtasn1=4.19.0=h5eee18b_0 53 | - libtiff=4.5.1=h6a678d5_0 54 | - libunistring=0.9.10=h27cfd23_0 55 | - libuuid=1.41.5=h5eee18b_0 56 | - libwebp-base=1.3.2=h5eee18b_0 57 | - llvm-openmp=14.0.6=h9e868ea_0 58 | - lz4-c=1.9.4=h6a678d5_1 59 | - markupsafe=2.1.3=py311h5eee18b_0 60 | - mkl=2023.1.0=h213fc3f_46344 61 | - mkl-service=2.4.0=py311h5eee18b_1 62 | - mkl_fft=1.3.8=py311h5eee18b_0 63 | - mkl_random=1.2.4=py311hdb19cb5_0 64 | - mpc=1.1.0=h10f8cd9_1 65 | - mpfr=4.0.2=hb69a4c5_1 66 | - mpmath=1.3.0=py311h06a4308_0 67 | - ncurses=6.4=h6a678d5_0 68 | - nettle=3.7.3=hbbd107a_1 69 | - networkx=3.1=py311h06a4308_0 70 | - numpy=1.26.4=py311h08b1b3b_0 71 | - numpy-base=1.26.4=py311hf175353_0 72 | - openh264=2.1.1=h4ff587b_0 73 | - openjpeg=2.4.0=h3ad879b_0 74 | - openssl=3.0.13=h7f8727e_1 75 | - pillow=10.3.0=py311h5eee18b_0 76 | - pip=24.0=py311h06a4308_0 77 | - pysocks=1.7.1=py311h06a4308_0 78 | - python=3.11.9=h955ad1f_0 79 | - pytorch=2.1.1=py3.11_cuda11.8_cudnn8.7.0_0 80 | - pytorch-cuda=11.8=h7e8668a_5 81 | - pytorch-mutex=1.0=cuda 82 | - pyyaml=6.0.1=py311h5eee18b_0 83 | - readline=8.2=h5eee18b_0 84 | - requests=2.31.0=py311h06a4308_1 85 | - setuptools=69.5.1=py311h06a4308_0 86 | - sqlite=3.45.3=h5eee18b_0 87 | - sympy=1.12=py311h06a4308_0 88 | - tbb=2021.8.0=hdb19cb5_0 89 | - tk=8.6.14=h39e8969_0 90 | - torchaudio=2.1.1=py311_cu118 91 | - torchtriton=2.1.0=py311 92 | - torchvision=0.16.1=py311_cu118 93 | - typing_extensions=4.11.0=py311h06a4308_0 94 | - urllib3=2.2.1=py311h06a4308_0 95 | - wheel=0.43.0=py311h06a4308_0 96 | - xz=5.4.6=h5eee18b_1 97 | - yaml=0.2.5=h7b6447c_0 98 | - zlib=1.2.13=h5eee18b_1 99 | - zstd=1.5.5=hc292b87_2 100 | - pip: 101 | - bayesian-optimization==1.4.3 102 | - click==8.1.7 103 | - cma==3.3.0 104 | - colorama==0.4.6 105 | - contourpy==1.2.1 106 | - cycler==0.12.1 107 | - docker-pycreds==0.4.0 108 | - faiss-cpu==1.8.0.post1 109 | - fcmaes==1.6.5 110 | - fonttools==4.51.0 111 | - fsspec==2024.5.0 112 | - ftfy==6.2.0 113 | - gitdb==4.0.11 114 | - gitpython==3.1.43 115 | - huggingface-hub==0.23.4 116 | - git+https://github.com/modestyachts/ImageNetV2_pytorch 117 | - joblib==1.4.2 118 | - kiwisolver==1.4.5 119 | - lightning-utilities==0.11.2 120 | - llvmlite==0.43.0 121 | - open-clip-torch==2.0.2 122 | - loguru==0.7.2 123 | - matplotlib==3.8.4 124 | - numba==0.60.0 125 | - opencv-python==4.10.0.84 126 | - packaging==24.0 127 | - pandas==2.2.2 128 | - platformdirs==4.2.2 129 | - plotly==5.22.0 130 | - protobuf==4.25.3 131 | - psutil==5.9.8 132 | - pyparsing==3.1.2 133 | - python-dateutil==2.9.0.post0 134 | - pytz==2024.1 135 | - regex==2024.5.15 136 | - safetensors==0.4.3 137 | - scikit-learn==1.4.2 138 | - scipy==1.13.0 139 | - sentry-sdk==2.1.1 140 | - setproctitle==1.3.3 141 | - six==1.16.0 142 | - smmap==5.0.1 143 | - tenacity==8.5.0 144 | - threadpoolctl==3.5.0 145 | - timm==1.0.3 146 | - torchmetrics==1.4.0.post0 147 | - tqdm==4.66.4 148 | - 
tzdata==2024.1 149 | - wandb==0.17.0 150 | - wcwidth==0.2.13 151 | - wget==3.2 152 | -------------------------------------------------------------------------------- /dawin_mtl/split_dataset.py: -------------------------------------------------------------------------------- 1 | base_dir = f'/YOURPATH/data' 2 | 3 | ### PROCESS SUN397 DATASET 4 | import os 5 | import shutil 6 | from pathlib import Path 7 | downloaded_data_path = f"{base_dir}/SUN397" 8 | output_path = f"{base_dir}/SUN397" 9 | 10 | def process_dataset(txt_file, downloaded_data_path, output_folder): 11 | with open(txt_file, 'r') as file: 12 | lines = file.readlines() 13 | 14 | for i, line in enumerate(lines): 15 | input_path = line.strip() 16 | final_folder_name = "_".join(x for x in input_path.split('/')[:-1])[1:] 17 | filename = input_path.split('/')[-1] 18 | output_class_folder = os.path.join(output_folder, final_folder_name) 19 | 20 | if not os.path.exists(output_class_folder): 21 | os.makedirs(output_class_folder) 22 | 23 | full_input_path = os.path.join(downloaded_data_path, input_path[1:]) 24 | output_file_path = os.path.join(output_class_folder, filename) 25 | 26 | shutil.copy(full_input_path, output_file_path) 27 | if i % 100 == 0: 28 | print(f"Processed {i}/{len(lines)} images") 29 | 30 | process_dataset( 31 | os.path.join(downloaded_data_path, 'Training_01.txt'), 32 | #os.path.join(downloaded_data_path, 'SUN397'), 33 | downloaded_data_path, 34 | os.path.join(output_path, "train") 35 | ) 36 | process_dataset( 37 | os.path.join(downloaded_data_path, 'Testing_01.txt'), 38 | #os.path.join(downloaded_data_path, 'SUN397'), 39 | downloaded_data_path, 40 | os.path.join(output_path, "val") 41 | ) 42 | 43 | ### PROCESS DTD DATASET 44 | import os 45 | import shutil 46 | from pathlib import Path 47 | downloaded_data_path = f"{base_dir}/dtd/images" 48 | output_path = f"{base_dir}/dtd" 49 | 50 | def process_dataset(txt_file, downloaded_data_path, output_folder): 51 | with open(txt_file, 'r') as file: 52 | lines = file.readlines() 53 | 54 | for i, line in enumerate(lines): 55 | input_path = line.strip() 56 | final_folder_name = input_path.split('/')[:-1][0] 57 | filename = input_path.split('/')[-1] 58 | output_class_folder = os.path.join(output_folder, final_folder_name) 59 | 60 | if not os.path.exists(output_class_folder): 61 | os.makedirs(output_class_folder) 62 | 63 | full_input_path = os.path.join(downloaded_data_path, input_path) 64 | output_file_path = os.path.join(output_class_folder, filename) 65 | shutil.copy(full_input_path, output_file_path) 66 | if i % 100 == 0: 67 | print(f"Processed {i}/{len(lines)} images") 68 | 69 | process_dataset( 70 | f'{base_dir}/dtd/labels/train1.txt', downloaded_data_path, os.path.join(output_path, "train") 71 | ) 72 | process_dataset( 73 | f'{base_dir}/dtd/labels/test1.txt', downloaded_data_path, os.path.join(output_path, "val") 74 | ) 75 | 76 | ### PROCESS EuroSAT_RGB DATASET 77 | src_dir = f'{base_dir}/2750' # replace with the path to your dataset 78 | dst_dir = f'{base_dir}/EuroSAT_splits' # replace with the path to the output directory 79 | 80 | import os 81 | import shutil 82 | import random 83 | 84 | def create_directory_structure(dst_dir, classes): 85 | for dataset in ['train', 'val', 'test']: 86 | path = os.path.join(dst_dir, dataset) 87 | os.makedirs(path, exist_ok=True) 88 | for cls in classes: 89 | os.makedirs(os.path.join(path, cls), exist_ok=True) 90 | 91 | def split_dataset(dst_dir, src_dir, classes, val_size=270, test_size=270): 92 | for cls in classes: 93 | class_path = 
os.path.join(src_dir, cls) 94 | images = os.listdir(class_path) 95 | random.shuffle(images) 96 | 97 | val_images = images[:val_size] 98 | test_images = images[val_size:val_size + test_size] 99 | train_images = images[val_size + test_size:] 100 | 101 | for img in train_images: 102 | src_path = os.path.join(class_path, img) 103 | dst_path = os.path.join(dst_dir, 'train', cls, img) 104 | print(src_path, dst_path) 105 | shutil.copy(src_path, dst_path) 106 | # break 107 | for img in val_images: 108 | src_path = os.path.join(class_path, img) 109 | dst_path = os.path.join(dst_dir, 'val', cls, img) 110 | print(src_path, dst_path) 111 | shutil.copy(src_path, dst_path) 112 | # break 113 | for img in test_images: 114 | src_path = os.path.join(class_path, img) 115 | dst_path = os.path.join(dst_dir, 'test', cls, img) 116 | print(src_path, dst_path) 117 | shutil.copy(src_path, dst_path) 118 | # break 119 | 120 | classes = [d for d in os.listdir(src_dir) if os.path.isdir(os.path.join(src_dir, d))] 121 | create_directory_structure(dst_dir, classes) 122 | split_dataset(dst_dir, src_dir, classes) -------------------------------------------------------------------------------- /dawin_rft/datasets/objectnet_metadata/objectnet_to_imagenet_1k.json: -------------------------------------------------------------------------------- 1 | { 2 | "Alarm clock": "analog clock; digital clock", 3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack", 4 | "Banana": "banana", 5 | "Band Aid": "Band Aid", 6 | "Basket": "shopping basket", 7 | "Bath towel": "bath towel", 8 | "Beer bottle": "beer bottle", 9 | "Bench": "park bench", 10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem", 11 | "Binder (closed)": "binder, ring-binder", 12 | "Bottle cap": "bottlecap", 13 | "Bread loaf": "French loaf", 14 | "Broom": "broom", 15 | "Bucket": "bucket, pail", 16 | "Butcher's knife": "cleaver, meat cleaver, chopper", 17 | "Can opener": "can opener, tin opener", 18 | "Candle": "candle, taper, wax light", 19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone", 20 | "Chair": "barber chair; folding chair; rocking chair, rocker", 21 | "Clothes hamper": "hamper", 22 | "Coffee/French press": "espresso maker", 23 | "Combination lock": "combination lock", 24 | "Computer mouse": "mouse, computer mouse", 25 | "Desk lamp": "table lamp", 26 | "Dishrag or hand towel": "dishrag, dishcloth", 27 | "Doormat": "doormat, welcome mat", 28 | "Dress shoe (men)": "Loafer", 29 | "Drill": "power drill", 30 | "Drinking Cup": "cup", 31 | "Drying rack for plates": "plate rack", 32 | "Envelope": "envelope", 33 | "Fan": "electric fan, blower", 34 | "Frying pan": "frying pan, frypan, skillet", 35 | "Dress": "gown", 36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier", 37 | "Hammer": "hammer", 38 | "Helmet": "football helmet; crash helmet", 39 | "Iron (for clothes)": "iron, smoothing iron", 40 | "Jeans": "jean, blue jean, denim", 41 | "Keyboard": "computer keyboard, keypad", 42 | "Ladle": "ladle", 43 | "Lampshade": "lampshade, lamp shade", 44 | "Laptop (open)": "laptop, laptop computer", 45 | "Lemon": "lemon", 46 | "Letter opener": "letter opener, paper knife, paperknife", 47 | "Lighter": "lighter, light, igniter, ignitor", 48 | "Lipstick": "lipstick, lip rouge", 49 | "Match": "matchstick", 50 | "Measuring cup": "measuring cup", 51 | "Microwave": "microwave, microwave oven", 52 | "Mixing / Salad Bowl": "mixing bowl", 53 | "Monitor": 
"monitor", 54 | "Mug": "coffee mug", 55 | "Nail (fastener)": "nail", 56 | "Necklace": "necklace", 57 | "Orange": "orange", 58 | "Padlock": "padlock", 59 | "Paintbrush": "paintbrush", 60 | "Paper towel": "paper towel", 61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen", 62 | "Pill bottle": "pill bottle", 63 | "Pillow": "pillow", 64 | "Pitcher": "pitcher, ewer", 65 | "Plastic bag": "plastic bag", 66 | "Plate": "plate", 67 | "Plunger": "plunger, plumber's helper", 68 | "Pop can": "pop bottle, soda bottle", 69 | "Portable heater": "space heater", 70 | "Printer": "printer", 71 | "Remote control": "remote control, remote", 72 | "Ruler": "rule, ruler", 73 | "Running shoe": "running shoe", 74 | "Safety pin": "safety pin", 75 | "Salt shaker": "saltshaker, salt shaker", 76 | "Sandal": "sandal", 77 | "Screw": "screw", 78 | "Shovel": "shovel", 79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt", 80 | "Sleeping bag": "sleeping bag", 81 | "Soap dispenser": "soap dispenser", 82 | "Sock": "sock", 83 | "Soup Bowl": "soup bowl", 84 | "Spatula": "spatula", 85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", 86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera", 87 | "Strainer": "strainer", 88 | "Stuffed animal": "teddy, teddy bear", 89 | "Suit jacket": "suit, suit of clothes", 90 | "Sunglasses": "sunglasses, dark glasses, shades", 91 | "Sweater": "sweatshirt", 92 | "Swimming trunks": "swimming trunks, bathing trunks", 93 | "T-shirt": "jersey, T-shirt, tee shirt", 94 | "TV": "television, television system", 95 | "Teapot": "teapot", 96 | "Tennis racket": "racket, racquet", 97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie", 98 | "Toaster": "toaster", 99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue", 100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", 101 | "Tray": "tray", 102 | "Umbrella": "umbrella", 103 | "Vacuum cleaner": "vacuum, vacuum cleaner", 104 | "Vase": "vase", 105 | "Wallet": "wallet, billfold, notecase, pocketbook", 106 | "Watch": "digital watch", 107 | "Water bottle": "water bottle", 108 | "Weight (exercise)": "dumbbell", 109 | "Weight scale": "scale, weighing machine", 110 | "Wheel": "car wheel; paddlewheel, paddle wheel", 111 | "Whistle": "whistle", 112 | "Wine bottle": "wine bottle", 113 | "Winter glove": "mitten", 114 | "Wok": "wok" 115 | } 116 | -------------------------------------------------------------------------------- /dawin_rft/datasets/objectnet_beta_metadata/objectnet_to_imagenet_1k.json: -------------------------------------------------------------------------------- 1 | { 2 | "Alarm clock": "analog clock; digital clock", 3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack", 4 | "Banana": "banana", 5 | "Band Aid": "Band Aid", 6 | "Basket": "shopping basket", 7 | "Bath towel": "bath towel", 8 | "Beer bottle": "beer bottle", 9 | "Bench": "park bench", 10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem", 11 | "Binder (closed)": "binder, ring-binder", 12 | "Bottle cap": "bottlecap", 13 | "Bread loaf": "French loaf", 14 | "Broom": "broom", 15 | "Bucket": "bucket, pail", 16 | "Butcher's knife": "cleaver, meat cleaver, chopper", 17 | "Can opener": "can opener, tin opener", 18 | "Candle": "candle, taper, wax light", 19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone", 20 | 
"Chair": "barber chair; folding chair; rocking chair, rocker", 21 | "Clothes hamper": "hamper", 22 | "Coffee/French press": "espresso maker", 23 | "Combination lock": "combination lock", 24 | "Computer mouse": "mouse, computer mouse", 25 | "Desk lamp": "table lamp", 26 | "Dishrag or hand towel": "dishrag, dishcloth", 27 | "Doormat": "doormat, welcome mat", 28 | "Dress shoe (men)": "Loafer", 29 | "Drill": "power drill", 30 | "Drinking Cup": "cup", 31 | "Drying rack for plates": "plate rack", 32 | "Envelope": "envelope", 33 | "Fan": "electric fan, blower", 34 | "Frying pan": "frying pan, frypan, skillet", 35 | "Dress": "gown", 36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier", 37 | "Hammer": "hammer", 38 | "Helmet": "football helmet; crash helmet", 39 | "Iron (for clothes)": "iron, smoothing iron", 40 | "Jeans": "jean, blue jean, denim", 41 | "Keyboard": "computer keyboard, keypad", 42 | "Ladle": "ladle", 43 | "Lampshade": "lampshade, lamp shade", 44 | "Laptop (open)": "laptop, laptop computer", 45 | "Lemon": "lemon", 46 | "Letter opener": "letter opener, paper knife, paperknife", 47 | "Lighter": "lighter, light, igniter, ignitor", 48 | "Lipstick": "lipstick, lip rouge", 49 | "Match": "matchstick", 50 | "Measuring cup": "measuring cup", 51 | "Microwave": "microwave, microwave oven", 52 | "Mixing / Salad Bowl": "mixing bowl", 53 | "Monitor": "monitor", 54 | "Mug": "coffee mug", 55 | "Nail (fastener)": "nail", 56 | "Necklace": "necklace", 57 | "Orange": "orange", 58 | "Padlock": "padlock", 59 | "Paintbrush": "paintbrush", 60 | "Paper towel": "paper towel", 61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen", 62 | "Pill bottle": "pill bottle", 63 | "Pillow": "pillow", 64 | "Pitcher": "pitcher, ewer", 65 | "Plastic bag": "plastic bag", 66 | "Plate": "plate", 67 | "Plunger": "plunger, plumber's helper", 68 | "Pop can": "pop bottle, soda bottle", 69 | "Portable heater": "space heater", 70 | "Printer": "printer", 71 | "Remote control": "remote control, remote", 72 | "Ruler": "rule, ruler", 73 | "Running shoe": "running shoe", 74 | "Safety pin": "safety pin", 75 | "Salt shaker": "saltshaker, salt shaker", 76 | "Sandal": "sandal", 77 | "Screw": "screw", 78 | "Shovel": "shovel", 79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt", 80 | "Sleeping bag": "sleeping bag", 81 | "Soap dispenser": "soap dispenser", 82 | "Sock": "sock", 83 | "Soup Bowl": "soup bowl", 84 | "Spatula": "spatula", 85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", 86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera", 87 | "Strainer": "strainer", 88 | "Stuffed animal": "teddy, teddy bear", 89 | "Suit jacket": "suit, suit of clothes", 90 | "Sunglasses": "sunglasses, dark glasses, shades", 91 | "Sweater": "sweatshirt", 92 | "Swimming trunks": "swimming trunks, bathing trunks", 93 | "T-shirt": "jersey, T-shirt, tee shirt", 94 | "TV": "television, television system", 95 | "Teapot": "teapot", 96 | "Tennis racket": "racket, racquet", 97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie", 98 | "Toaster": "toaster", 99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue", 100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", 101 | "Tray": "tray", 102 | "Umbrella": "umbrella", 103 | "Vacuum cleaner": "vacuum, vacuum cleaner", 104 | "Vase": "vase", 105 | "Wallet": "wallet, billfold, notecase, pocketbook", 106 | "Watch": 
"digital watch", 107 | "Water bottle": "water bottle", 108 | "Weight (exercise)": "dumbbell", 109 | "Weight scale": "scale, weighing machine", 110 | "Wheel": "car wheel; paddlewheel, paddle wheel", 111 | "Whistle": "whistle", 112 | "Wine bottle": "wine bottle", 113 | "Winter glove": "mitten", 114 | "Wok": "wok" 115 | } 116 | -------------------------------------------------------------------------------- /dawin_rft/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | 
self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | class ImageFolderWithPaths(datasets.ImageFolder): 27 | def __init__(self, path, transform, flip_label_prob=0.0): 28 | super().__init__(path, transform) 29 | self.flip_label_prob = flip_label_prob 30 | if self.flip_label_prob > 0: 31 | print(f'Flipping labels with probability {self.flip_label_prob}') 32 | num_classes = len(self.classes) 33 | for i in range(len(self.samples)): 34 | if random.random() < self.flip_label_prob: 35 | new_label = random.randint(0, num_classes-1) 36 | self.samples[i] = ( 37 | self.samples[i][0], 38 | new_label 39 | ) 40 | 41 | def __getitem__(self, index): 42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 43 | return { 44 | 'images': image, 45 | 'labels': label, 46 | 'image_paths': self.samples[index][0] 47 | } 48 | 49 | 50 | def maybe_dictionarize(batch): 51 | if isinstance(batch, dict): 52 | return batch 53 | 54 | if len(batch) == 2: 55 | batch = {'images': batch[0], 'labels': batch[1]} 56 | elif len(batch) == 3: 57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]} 58 | else: 59 | raise ValueError(f'Unexpected number of elements: {len(batch)}') 60 | 61 | return batch 62 | 63 | 64 | def get_features_helper(image_encoder, dataloader, device): 65 | all_data = collections.defaultdict(list) 66 | 67 | image_encoder = image_encoder.to(device) 68 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())]) 69 | image_encoder.eval() 70 | 71 | with torch.no_grad(): 72 | for batch in tqdm(dataloader): 73 | batch = maybe_dictionarize(batch) 74 | features = image_encoder(batch['images'].cuda()) 75 | 76 | all_data['features'].append(features.cpu()) 77 | 78 | for key, val in batch.items(): 79 | if key == 'images': 80 | continue 81 | if hasattr(val, 'cpu'): 82 | val = val.cpu() 83 | all_data[key].append(val) 84 | else: 85 | all_data[key].extend(val) 86 | 87 | for key, val in all_data.items(): 88 | if torch.is_tensor(val[0]): 89 | all_data[key] = torch.cat(val).numpy() 90 | 91 | return all_data 92 | 93 | 94 | def get_features(is_train, image_encoder, dataset, device): 95 | split = 'train' if is_train else 'val' 96 | dname = type(dataset).__name__ 97 | if image_encoder.cache_dir is not None: 98 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' 99 | cached_files = glob.glob(f'{cache_dir}/*') 100 | if image_encoder.cache_dir is not None and len(cached_files) > 0: 101 | print(f'Getting features from {cache_dir}') 102 | data = {} 103 | for cached_file in cached_files: 104 | name = os.path.splitext(os.path.basename(cached_file))[0] 105 | data[name] = torch.load(cached_file) 106 | else: 107 | print(f'Did not find cached features at {cache_dir}. 
Building from scratch.') 108 | loader = dataset.train_loader if is_train else dataset.test_loader 109 | data = get_features_helper(image_encoder, loader, device) 110 | if image_encoder.cache_dir is None: 111 | print('Not caching because no cache directory was passed.') 112 | else: 113 | os.makedirs(cache_dir, exist_ok=True) 114 | print(f'Caching data at {cache_dir}') 115 | for name, val in data.items(): 116 | torch.save(val, f'{cache_dir}/{name}.pt') 117 | return data 118 | 119 | 120 | class FeatureDataset(Dataset): 121 | def __init__(self, is_train, image_encoder, dataset, device): 122 | self.data = get_features(is_train, image_encoder, dataset, device) 123 | 124 | def __len__(self): 125 | return len(self.data['features']) 126 | 127 | def __getitem__(self, idx): 128 | data = {k: v[idx] for k, v in self.data.items()} 129 | data['features'] = torch.from_numpy(data['features']).float() 130 | return data 131 | 132 | 133 | def get_dataloader(dataset, is_train, args, image_encoder=None): 134 | if image_encoder is not None: 135 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) 136 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) 137 | else: 138 | dataloader = dataset.train_loader if is_train else dataset.test_loader 139 | return dataloader 140 | 141 | def get_dataloader_shuffle(dataset): 142 | dataloader = dataset.test_loader_shuffle 143 | return dataloader -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | DaWin 2 | Copyright (c) 2025-present NAVER Cloud Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | -------------------------------------------------------------------------------------- 17 | 18 | This project contains subcomponents with separate copyright notices and license terms. 19 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 20 | 21 | ===== 22 | 23 | EnnengYang/AdaMerging 24 | https://github.com/EnnengYang/AdaMerging 25 | 26 | 27 | MIT License 28 | 29 | Copyright (c) 2023 Enneng Yang 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in all 39 | copies or substantial portions of the Software. 
40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE. 48 | 49 | ===== 50 | 51 | mlfoundations/model-soups 52 | https://github.com/mlfoundations/model-soups 53 | 54 | 55 | MIT License 56 | 57 | Copyright (c) 2023 mlfoundations 58 | 59 | Permission is hereby granted, free of charge, to any person obtaining a copy 60 | of this software and associated documentation files (the "Software"), to deal 61 | in the Software without restriction, including without limitation the rights 62 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 63 | copies of the Software, and to permit persons to whom the Software is 64 | furnished to do so, subject to the following conditions: 65 | 66 | The above copyright notice and this permission notice shall be included in all 67 | copies or substantial portions of the Software. 68 | 69 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 70 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 71 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 72 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 73 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 74 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 75 | SOFTWARE. 76 | 77 | ===== 78 | 79 | openai/CLIP 80 | https://github.com/openai/CLIP 81 | 82 | 83 | MIT License 84 | 85 | Copyright (c) 2021 OpenAI 86 | 87 | Permission is hereby granted, free of charge, to any person obtaining a copy 88 | of this software and associated documentation files (the "Software"), to deal 89 | in the Software without restriction, including without limitation the rights 90 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 91 | copies of the Software, and to permit persons to whom the Software is 92 | furnished to do so, subject to the following conditions: 93 | 94 | The above copyright notice and this permission notice shall be included in all 95 | copies or substantial portions of the Software. 96 | 97 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 98 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 99 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 100 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 101 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 102 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 103 | SOFTWARE. 104 | 105 | ===== 106 | 107 | mlfoundations/wise-ft 108 | https://github.com/mlfoundations/wise-ft 109 | 110 | 111 | Copyright (c) 2021-2022 Mitchell Wortsman, Gabriel Ilharco, 112 | Jong Wook Kim, Mike Li, Hannaneh Hajishirzi, Ali Farhadi, 113 | Hongseok Namkoong, Ludwig Schmidt. 
114 | 115 | Permission is hereby granted, free of charge, to any person obtaining 116 | a copy of this software and associated documentation files (the 117 | "Software"), to deal in the Software without restriction, including 118 | without limitation the rights to use, copy, modify, merge, publish, 119 | distribute, sublicense, and/or sell copies of the Software, and to 120 | permit persons to whom the Software is furnished to do so, subject to 121 | the following conditions: 122 | 123 | The above copyright notice and this permission notice shall be 124 | included in all copies or substantial portions of the Software. 125 | 126 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 127 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 128 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 129 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 130 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 131 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 132 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 133 | 134 | ===== 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [DaWin: Training-free Dynamic Weight Interpolation for Robust Adaptation](https://arxiv.org/abs/2410.03782) 2 | 3 | > [Changdae Oh1,2*](https://changdaeoh.github.io/), [Sharon Li2](https://pages.cs.wisc.edu/~sharonli/), [Kyungwoo Song3†](https://gtshs2.github.io/), [Sangdoo Yun1†](https://sangdooyun.github.io/), [Dongyoon Han1†](https://dongyoonhan.github.io/),
4 | > * Work done during an internship at NAVER AI Lab, † corresponding authors
5 | > 1[NAVER AI LAB](https://naver-career.gitbook.io/en/teams/clova-cic/ai-lab), 2[University of Wisconsin--Madison](https://www.wisc.edu/), 3[Yonsei University](https://www.yonsei.ac.kr/en_sc/index.jsp) 6 | 7 | 8 | [![paper](https://img.shields.io/badge/arXiv-Paper-red.svg)](https://arxiv.org/abs/2410.03782) 9 | [![Paper](https://img.shields.io/badge/Paper-ICLR_2025-blue)](https://openreview.net/forum?id=L8e7tBf4pP) 10 | 11 |
12 | 13 | image 14 | 15 | 16 | ### Abstract 17 | >Adapting a pre-trained foundation model on downstream tasks should ensure robustness against distribution shifts without the need to retrain the whole model. Although existing weight interpolation methods are simple yet effective, we argue their static nature limits downstream performance while achieving efficiency. In this work, we propose DaWin, a training-free dynamic weight interpolation method that leverages the entropy of individual models over each unlabeled test sample to assess model expertise and computes per-sample interpolation coefficients dynamically. Unlike previous works that typically rely on additional training to learn such coefficients, our approach requires no training. Then, we propose a mixture modeling approach that greatly reduces the inference overhead raised by dynamic interpolation. We validate DaWin on large-scale visual recognition benchmarks, spanning 14 tasks across robust fine-tuning -- ImageNet and five derived distribution shift benchmarks -- and multi-task learning with eight classification tasks. Results demonstrate that DaWin achieves significant performance gains in the considered settings, with minimal computational overhead. We further discuss DaWin's analytic behavior to explain its empirical success. 18 | 19 | 20 | ### Updates 21 | * (2025/03/14): Our code is now available! 22 | * (2025/01/22): Our manuscript has been accepted at [ICLR 2025](https://iclr.cc/)🎉🎉 23 | * (2024/10/09): A short version of the preprint has been accepted at [NeurIPS 2024 Workshop on Adaptive Foundation Models](https://adaptive-foundation-models.org/)🎉 24 | * (2024/10/03): Code is under internal review. 25 | * (2024/10/03): [Preprint](https://arxiv.org/abs/2410.03782) has been uploaded. 26 | 27 | --- 28 | 29 | ## Installation 30 | > conda env create -f dawin.yaml 31 | 32 | ## Application-specific Instructions 33 | * To reproduce the robust fine-tuning experiments, refer to `dawin_rft/` and the corresponding `README.md` for detailed instructions. 34 | * To reproduce the multi-task learning experiments, refer to `dawin_mtl/` and the corresponding `README.md` for detailed instructions. 35 | 36 | --- 37 | 38 | ## Quick Start with Pseudo-code 39 | ```python 40 | def interpolation(alpha, weight1, weight2): 41 | return {key: (1 - alpha) * weight1[key] + alpha * weight2[key] for key in weight1.keys()} 42 | 43 | basemodel, _ = clip.load(args.model, 'cpu', jit=False) 44 | 45 | # load state_dict of individual models to be interpolated 46 | sd_pt = torch.load(args.pt_path, map_location=torch.device(args.device)) 47 | sd_ft = torch.load(args.ft_path, map_location=torch.device(args.device)) 48 | 49 | # get entropy from the outputs of each model for all samples 50 | logit_pt = get_logits(basemodel, dataset_name=args.dsname, state_dict=sd_pt) 51 | logit_ft = get_logits(basemodel, dataset_name=args.dsname, state_dict=sd_ft) 52 | ent_pt = - (F.softmax(logit_pt,dim=1) * F.log_softmax(logit_pt, dim=1)).sum(dim=1) 53 | ent_ft = - (F.softmax(logit_ft,dim=1) * F.log_softmax(logit_ft, dim=1)).sum(dim=1) 54 | 55 | # exponentiated negative entropy as model expertise to weigh the interpolation 56 | expertise_pt = (-ent_pt).exp() 57 | expertise_ft = (-ent_ft).exp() 58 | lambdas = (expertise_ft) / (expertise_pt + expertise_ft) 59 | 60 | # sample-wise interpolation (w/o Beta Mixture Modeling) 61 | eval_dataloader = torch.utils.data.DataLoader(..., batch_size=1, shuffle=False) 62 | correct, n = 0., 0.
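# --- optional: cluster-wise interpolation via mixture modeling (illustrative sketch) ---
# DaWin's mixture-modeling variant avoids building one merged model per sample: the
# per-sample lambdas are clustered into a few groups (the repo fits a Beta Mixture Model,
# see mixturemodel.py and the --bmm_ncluster flag) and the weights are merged once per
# cluster. The KMeans call below is only a stand-in for that clustering step, and the
# variable names are assumptions that simply follow the snippet above.
use_clusterwise_interpolation = False  # flip to True to try the cluster-wise variant
if use_clusterwise_interpolation:
    from sklearn.cluster import KMeans
    km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(lambdas.cpu().numpy().reshape(-1, 1))
    cluster_lambdas = km.cluster_centers_.reshape(-1)   # one interpolation coefficient per cluster
    cluster_ids = km.labels_                            # cluster assignment for each test sample
    cluster_models = [get_model_from_sd(interpolation(l, sd_pt, sd_ft), basemodel)
                      for l in cluster_lambdas]         # merge weights once per cluster and reuse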
63 | for i, (inputs, labels) in enumerate(eval_dataloader): 64 | inputs, labels = inputs.cuda(), labels.cuda() 65 | merged_sd = interpolation(lambdas[i], sd_pt, sd_ft) 66 | model = get_model_from_sd(merged_sd, basemodel) 67 | logits = model(inputs) 68 | 69 | preds = logits.argmax(dim=1, keepdim=True) 70 | correct_current = preds.eq(labels.view_as(preds)).sum().item() 71 | correct += correct_current 72 | n += labels.size(0) 73 | 74 | top1_acc = correct / n 75 | ``` 76 | 77 | --- 78 | 79 | ## How to cite 80 | ``` 81 | @inproceedings{ 82 | oh2025dawin, 83 | title={DaWin: Training-free Dynamic Weight Interpolation for Robust Adaptation}, 84 | author={Changdae Oh and Yixuan Li and Kyungwoo Song and Sangdoo Yun and Dongyoon Han}, 85 | booktitle={International Conference on Learning Representations}, 86 | year={2025}, 87 | url={https://openreview.net/forum?id=L8e7tBf4pP} 88 | } 89 | ``` 90 | 91 | ## License 92 | ``` 93 | DaWin 94 | Copyright (c) 2025-present NAVER Cloud Corp. 95 | 96 | Licensed under the Apache License, Version 2.0 (the "License"); 97 | you may not use this file except in compliance with the License. 98 | You may obtain a copy of the License at 99 | 100 | http://www.apache.org/licenses/LICENSE-2.0 101 | 102 | Unless required by applicable law or agreed to in writing, software 103 | distributed under the License is distributed on an "AS IS" BASIS, 104 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 105 | See the License for the specific language governing permissions and 106 | limitations under the License. 107 | ``` 108 | -------------------------------------------------------------------------------- /dawin_rft/README.md: -------------------------------------------------------------------------------- 1 | # DaWin - Robust Fine-tuning Experiments 2 | 3 | ## Setup 4 | * Start by creating the conda environment with `conda env create -f dawin.yaml` 5 | * **Directory structure** 6 |   * The directory structure should look like the one below.
7 | ``` 8 | ├── cache # after the first run, cache files will be saved here 9 | │ ├── CLIPVITB32_ft_ImageNet_cache.pt 10 | │ ├── CLIPVITB32_zs_ImageNet_cache.pt 11 | │ ├── CLIPVITB32_ft_ImageNetV2_cache.pt 12 | │ ├── CLIPVITB32_zs_ImageNetV2_cache.pt 13 | │ ├── CLIPVITB32_ft_{*}_cache.pt # {*} denotes the dataset name 14 | │ ├── CLIPVITB32_zs_{*}_cache.pt 15 | ├── checkpoints 16 | │ ├── zeroshot_clipvitb32.pt 17 | │ ├── finetune_clipvitb32.pt 18 | │ ├── *.pt # checkpoints for baseline methods 19 | ├── data 20 | │ ├── imagenet 21 | │ ├── imagenet-r 22 | │ ├── * 23 | ├── datasets 24 | │ ├── imagenet.py 25 | │ ├── imagenet_r.py 26 | │ ├── * 27 | ├── script 28 | │ ├── dawin.sh 29 | │ ├── dawin_samplewise.sh 30 | │ ├── dawin_applications.sh 31 | │ ├── dawin_ablation.sh 32 | │ ├── dawin_baselines.sh 33 | ├── main_dawin.py 34 | ├── main.py 35 | ├── mixturemodel.py 36 | ├── utils.py 37 | ├── *.py 38 | ``` 39 | * Place your checkpoints (zero-shot, fine-tuned, and other baseline models') in the `checkpoints/` directory 40 | * Place your datasets in `data/` 41 | 42 | ```bash 43 | mkdir cache checkpoints data analysis 44 | ln -s 'YOURDATAPATH' ./dawin/dawin_rft/data 45 | ``` 46 | 47 | * **Model checkpoints** 48 |   * We reached out to the authors of [Model Stock](https://github.com/naver-ai/model-stock) to obtain their checkpoints, including the individual fine-tuned models, model soup, and Model Stock weights used for the robust fine-tuning experiments. 49 |   * Here are the fine-tuned model weights we use for merging: [google drive](https://drive.google.com/drive/folders/1NZ9kxXmsFWuY6oSAozTaaZwSISlZ_RgX?usp=sharing) 50 |   * You can also get the weights of other baseline methods from the official [Model Stock repository](https://github.com/naver-ai/model-stock/blob/main/notebooks/model_stock_eval.ipynb). 51 | * **Dataset** 52 |   * For the robust fine-tuning experiments, we adopt the ImageNet distribution shift benchmarks: 53 |     * In-distribution (ID): [ImageNet-1K](https://www.image-net.org/download.php) 54 |     * Out-of-distribution (OOD): [ImageNetV2](https://imagenetv2.org/), [ImageNet-Rendition](https://github.com/hendrycks/imagenet-r), [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch), and [ObjectNet](https://objectnet.dev/) 55 |   * Refer to [`datasets.md`](https://github.com/mlfoundations/wise-ft/blob/master/datasets.md) of the WiSE-FT repository to prepare these datasets. 56 | 57 | 58 | ## Run baseline methods 59 | * After preparing the model checkpoints, you can evaluate an individual checkpoint (`*.pt`) with the command below 60 | > `main.py --eval-single-model *.pt` 61 | * You can also reproduce the WiSE-FT results as below 62 | > `main.py --wiseft-alpha 0.5 --wiseft-zspath checkpoints/zeroshot.pt --wiseft-ftpath checkpoints/finetune.pt` 63 | 64 | ## Run DaWin 65 | We provide scripts to reproduce the main results, the ablation study, and the extended applications (covering Tables 1, 2, 3, 4, 5, Figure 5, and some tables in the Appendix).
66 | * You can reproduce DaWin's evaluation across the ImageNet variants with CLIP ViT-B/32 as below 67 | ``` 68 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 69 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 70 | --model ViT-B/32 \ 71 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 72 | --eval-batch-size 1024 --bmm_ncluster 3 \ 73 | --offset-adjustment \ 74 | --expertise negexp_ent_ratio 75 | ``` 76 |   * refer to `script/dawin.sh` 77 |   * When you run the DaWin method for the first time, it produces cache files (at `./cache/`) containing the logits and predictions of the zero-shot and fine-tuned CLIP models. Subsequent runs are faster because these cache files are loaded. 78 | * To reproduce the ablation study over the expertise metric (such as pseudo-label-based cross-entropy), run the command below 79 | ``` 80 | for psl in 1 2 3 4 81 | do 82 | 83 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 84 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 85 | --model ViT-B/32 \ 86 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 87 | --eval-batch-size 1024 --bmm_ncluster 3 \ 88 | --offset-adjustment \ 89 | --expertise negexp_loss_ratio --pseudo_label $psl 90 | 91 | done 92 | ``` 93 |   * refer to `script/dawin_ablation.sh` 94 | * To reproduce DaWin's application scenarios -- dynamic classifier selection and dynamic output ensemble -- run the commands below 95 | ``` 96 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 97 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 98 | --model ViT-B/32 \ 99 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 100 | --eval-batch-size 1024 \ 101 | --expertise selective_entropy 102 | ``` 103 |   * refer to `script/dawin_applications.sh` 104 | * To simulate the sample-wise dynamic interpolation approach, run the command below 105 | ``` 106 | CUDA_VISIBLE_DEVICES=0 python3 main_dawin.py --seed 1 \ 107 | --cache-dir cache/ --data-location data/ --model-location checkpoints/ \ 108 | --model ViT-B/32 \ 109 | --zs-path checkpoints/zeroshot_clipvitb32.pt --ft-path checkpoints/finetune_clipvitb32.pt \ 110 | --eval-batch-size 1 --bmm_ncluster 0 \ 111 | --expertise negexp_ent_ratio --offset-adjustment 112 | ``` 113 |   * refer to `script/dawin_samplewise.sh` 114 |   * Because it relies on sample-wise merging, evaluating 50,000 images (e.g., the ImageNet test set) takes about 2.5 hours on a single NVIDIA A100 GPU; evaluation over all ImageNet variants requires roughly 10 hours. 115 | 116 | ## Acknowledgement 117 | Some code blocks are borrowed from the projects below; we appreciate the authors' endeavors.
118 | - WiSE-FT: https://github.com/mlfoundations/wise-ft 119 | - Model Soups: https://github.com/mlfoundations/model-soups 120 | -------------------------------------------------------------------------------- /dawin_rft/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import SubsetRandomSampler 4 | import numpy as np 5 | 6 | from .common import ImageFolderWithPaths, SubsetSampler 7 | from .imagenet_classnames import get_classnames 8 | 9 | class ImageNet: 10 | def __init__(self, 11 | preprocess, 12 | location=os.path.expanduser('~/data'), 13 | batch_size=32, 14 | num_workers=8, 15 | classnames='openai', 16 | distributed=False): 17 | self.preprocess = preprocess 18 | self.location = location 19 | self.batch_size = batch_size 20 | self.num_workers = num_workers 21 | self.classnames = get_classnames(classnames) 22 | self.distributed = distributed 23 | 24 | self.populate_train() 25 | self.populate_test() 26 | 27 | def populate_train(self): 28 | traindir = os.path.join(self.location, self.name(), 'train') 29 | self.train_dataset = ImageFolderWithPaths( 30 | traindir, 31 | transform=self.preprocess, 32 | ) 33 | sampler = self.get_train_sampler() 34 | self.sampler = sampler 35 | kwargs = {'shuffle' : True} if sampler is None else {} 36 | # print('kwargs is', kwargs) 37 | self.train_loader = torch.utils.data.DataLoader( 38 | self.train_dataset, 39 | sampler=sampler, 40 | batch_size=self.batch_size, 41 | num_workers=self.num_workers, 42 | pin_memory=True, 43 | **kwargs, 44 | ) 45 | 46 | def populate_test(self): 47 | self.test_dataset = self.get_test_dataset() 48 | # import pdb 49 | # pdb.set_trace() 50 | self.test_loader = torch.utils.data.DataLoader( 51 | self.test_dataset, 52 | batch_size=self.batch_size, 53 | num_workers=self.num_workers, 54 | pin_memory=True, 55 | sampler=self.get_test_sampler() 56 | ) 57 | 58 | def get_test_path(self): 59 | test_path = os.path.join(self.location, self.name(), 'val_in_folder') 60 | if not os.path.exists(test_path): 61 | test_path = os.path.join(self.location, self.name(), 'val') 62 | return test_path 63 | 64 | def get_train_sampler(self): 65 | return torch.utils.data.distributed.DistributedSampler(self.train_dataset) if self.distributed else None 66 | 67 | 68 | def get_test_sampler(self): 69 | return None 70 | 71 | def get_test_dataset(self): 72 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 73 | 74 | def name(self): 75 | return 'imagenet' 76 | 77 | class ImageNetTrain(ImageNet): 78 | 79 | def get_test_dataset(self): 80 | pass 81 | 82 | def project_logits(logits, class_sublist_mask, device): 83 | if isinstance(logits, list): 84 | return [project_logits(l, class_sublist_mask, device) for l in logits] 85 | if logits.size(1) > sum(class_sublist_mask): 86 | return logits[:, class_sublist_mask].to(device) 87 | else: 88 | return logits.to(device) 89 | 90 | class ImageNetSubsample(ImageNet): 91 | def __init__(self, *args, **kwargs): 92 | super().__init__(*args, **kwargs) 93 | class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask() 94 | self.classnames = [self.classnames[i] for i in class_sublist] 95 | 96 | def get_class_sublist_and_mask(self): 97 | raise NotImplementedError() 98 | 99 | def populate_train(self): 100 | pass 101 | 102 | def project_logits(self, logits, device): 103 | return project_logits(logits, self.class_sublist_mask, device) 104 | 105 | class ImageNetSubsampleValClasses(ImageNet): 106 | def 
get_class_sublist_and_mask(self): 107 | raise NotImplementedError() 108 | 109 | def populate_train(self): 110 | pass 111 | 112 | def get_test_sampler(self): 113 | self.class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask() 114 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.class_sublist] 115 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist]) 116 | 117 | sampler = SubsetSampler(idx_subsample_list) 118 | return sampler 119 | 120 | def project_labels(self, labels, device): 121 | projected_labels = [self.class_sublist.index(int(label)) for label in labels] 122 | return torch.LongTensor(projected_labels).to(device) 123 | 124 | def project_logits(self, logits, device): 125 | return project_logits(logits, self.class_sublist_mask, device) 126 | 127 | 128 | class ImageNet98p(ImageNet): 129 | 130 | def get_train_sampler(self): 131 | idx_file = 'imagenet_98_idxs.npy' 132 | assert os.path.exists(idx_file) 133 | #if os.path.exists(idx_file): 134 | with open(idx_file, 'rb') as f: 135 | idxs = np.load(f) 136 | # else: 137 | # idxs = np.zeros(len(self.train_dataset.targets)) 138 | # target_array = np.array(self.train_dataset.targets) 139 | # for c in range(1000): 140 | # m = target_array == c 141 | # n = len(idxs[m]) 142 | # arr = np.zeros(n) 143 | # arr[:26] = 1 144 | # np.random.shuffle(arr) 145 | # idxs[m] = arr 146 | # with open(idx_file, 'wb') as f: 147 | # np.save(f, idxs) 148 | 149 | idxs = (1 - idxs).astype('int') 150 | sampler = SubsetRandomSampler(np.where(idxs)[0]) 151 | 152 | return sampler 153 | 154 | 155 | class ImageNet2p(ImageNet): 156 | 157 | def get_train_sampler(self): 158 | idx_file = 'imagenet_98_idxs.npy' 159 | assert os.path.exists(idx_file) 160 | with open(idx_file, 'rb') as f: 161 | idxs = np.load(f) 162 | 163 | idxs = idxs.astype('int') 164 | sampler = SubsetSampler(np.where(idxs)[0]) 165 | return sampler 166 | 167 | class ImageNet2pShuffled(ImageNet): 168 | 169 | def get_train_sampler(self): 170 | print('shuffling val set.') 171 | idx_file = 'imagenet_98_idxs.npy' 172 | assert os.path.exists(idx_file) 173 | with open(idx_file, 'rb') as f: 174 | idxs = np.load(f) 175 | 176 | idxs = idxs.astype('int') 177 | sampler = SubsetRandomSampler(np.where(idxs)[0]) 178 | return sampler 179 | -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | import pdb 6 | import pathlib 7 | from typing import Callable, Optional, Any, Tuple 8 | 9 | from PIL import Image 10 | 11 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | 15 | class PytorchStanfordCars(VisionDataset): 16 | """`Stanford Cars `_ Dataset 17 | 18 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 19 | split into 8,144 training images and 8,041 testing images, where each class 20 | has been split roughly in a 50-50 split 21 | 22 | .. note:: 23 | 24 | This class needs `scipy `_ to load target files from `.mat` format. 25 | 26 | Args: 27 | root (string): Root directory of dataset 28 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 29 | transform (callable, optional): A function/transform that takes in an PIL image 30 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 31 | target_transform (callable, optional): A function/transform that takes in the 32 | target and transforms it. 33 | download (bool, optional): If True, downloads the dataset from the internet and 34 | puts it in root directory. If dataset is already downloaded, it is not 35 | downloaded again.""" 36 | 37 | def __init__( 38 | self, 39 | root: str, 40 | split: str = "train", 41 | transform: Optional[Callable] = None, 42 | target_transform: Optional[Callable] = None, 43 | download: bool = False, 44 | ) -> None: 45 | 46 | try: 47 | import scipy.io as sio 48 | except ImportError: 49 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | #pdb.set_trace() 55 | self._base_folder = pathlib.Path(root) / "Cars" 56 | devkit = self._base_folder / "devkit" 57 | #import pdb;pdb.set_trace() 58 | if self._split == "train": 59 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 60 | self._images_base_path = self._base_folder / "cars_train" 61 | else: 62 | self._annotations_mat_path = devkit / "cars_test_annos_withlabels.mat" 63 | self._images_base_path = self._base_folder / "cars_test" 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_exists(): 69 | raise RuntimeError("Dataset not found. You can use download=True to download it") 70 | 71 | self._samples = [ 72 | ( 73 | str(self._images_base_path / annotation["fname"]), 74 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 75 | ) 76 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 77 | ] 78 | 79 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 80 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 81 | 82 | def __len__(self) -> int: 83 | return len(self._samples) 84 | 85 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 86 | """Returns pil_image and class_id for given index""" 87 | image_path, target = self._samples[idx] 88 | pil_image = Image.open(image_path).convert("RGB") 89 | 90 | if self.transform is not None: 91 | pil_image = self.transform(pil_image) 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | return pil_image, target 95 | 96 | 97 | def download(self) -> None: 98 | if self._check_exists(): 99 | return 100 | 101 | download_and_extract_archive( 102 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 103 | download_root=str(self._base_folder), 104 | md5="c3b158d763b6e2245038c8ad08e45376", 105 | ) 106 | if self._split == "train": 107 | download_and_extract_archive( 108 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 109 | download_root=str(self._base_folder), 110 | md5="065e5b463ae28d29e77c1b4b166cfe61", 111 | ) 112 | else: 113 | download_and_extract_archive( 114 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 115 | download_root=str(self._base_folder), 116 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 117 | ) 118 | download_url( 119 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 120 | root=str(self._base_folder), 121 | md5="b0a2b23655a3edd16d84508592a98d10", 122 | ) 123 | 124 | def _check_exists(self) -> bool: 125 | if not (self._base_folder / "devkit").is_dir(): 126 | return False 127 | 128 | return 
self._annotations_mat_path.exists() and self._images_base_path.is_dir()
129 | 
130 | 
131 | class Cars:
132 |     def __init__(self,
133 |                  preprocess,
134 |                  location=os.path.expanduser('~/data'),
135 |                  batch_size=32,
136 |                  num_workers=0):
137 |         # Data loading code
138 | 
139 |         self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False)
140 |         self.train_loader = torch.utils.data.DataLoader(
141 |             self.train_dataset,
142 |             shuffle=True,
143 |             batch_size=batch_size,
144 |             num_workers=num_workers,
145 |         )
146 | 
147 |         self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False)
148 |         self.test_loader = torch.utils.data.DataLoader(
149 |             self.test_dataset,
150 |             batch_size=batch_size,
151 |             num_workers=num_workers
152 |         )
153 |         self.test_loader_shuffle = torch.utils.data.DataLoader(
154 |             self.test_dataset,
155 |             shuffle=True,
156 |             batch_size=batch_size,
157 |             num_workers=num_workers
158 |         )
159 |         idx_to_class = dict((v, k) for k, v in self.train_dataset.class_to_idx.items())
160 |         self.classnames = [idx_to_class[i].replace(
161 |             '_', ' ') for i in range(len(idx_to_class))]
162 | 
--------------------------------------------------------------------------------
/dawin_mtl/src/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | import open_clip
4 | 
5 | import utils
6 | 
7 | 
8 | class ImageEncoder(torch.nn.Module):
9 |     def __init__(self, args, keep_lang=False):
10 |         super().__init__()
11 | 
12 |         print(f'Loading {args.model} pre-trained weights.')
13 |         if '__pretrained__' in args.model:
14 |             name, pretrained = args.model.split('__pretrained__')
15 |         else:
16 |             name = args.model
17 |             pretrained = 'openai'
18 |         self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms(
19 |             name, pretrained=pretrained, cache_dir=args.openclip_cachedir)
20 | 
21 |         self.cache_dir = args.cache_dir
22 | 
23 |         if not keep_lang and hasattr(self.model, 'transformer'):
24 |             delattr(self.model, 'transformer')
25 | 
26 |     def forward(self, images):
27 |         assert self.model is not None
28 |         return self.model.encode_image(images)
29 | 
30 |     def __call__(self, inputs):
31 |         return self.forward(inputs)
32 | 
33 |     def save(self, filename):
34 |         print(f'Saving image encoder to {filename}')
35 |         utils.torch_save(self, filename)
36 | 
37 |     @classmethod
38 |     def load(cls, model_name, filename):
39 |         print(f'Loading image encoder from {filename}')
40 |         state_dict = torch.load(filename)
41 |         return cls.load_from_state_dict(model_name, state_dict)
42 | 
43 |     @classmethod
44 |     def load_from_state_dict(cls, model_name, state_dict):
45 |         # Rebuild the backbone from its name (assumes the default 'openai' weight tag), then restore the weights.
46 |         model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained='openai')
47 |         model.load_state_dict(state_dict)
48 |         return model
49 | 
50 | 
51 | 
52 | class ClassificationHead(torch.nn.Linear):
53 |     def __init__(self, normalize, weights, biases=None):
54 |         output_size, input_size = weights.shape
55 |         super().__init__(input_size, output_size)
56 |         self.normalize = normalize
57 |         if weights is not None:
58 |             self.weight = torch.nn.Parameter(weights.clone())
59 |         if biases is not None:
60 |             self.bias = torch.nn.Parameter(biases.clone())
61 |         else:
62 |             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
63 | 
64 |     def forward(self, inputs):
65 |         if self.normalize:
66 |             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
67 |         return super().forward(inputs)
68 | 
69 |     def __call__(self, inputs):
70 | 
return self.forward(inputs) 71 | 72 | def save(self, filename): 73 | print(f'Saving classification head to {filename}') 74 | utils.torch_save(self, filename) 75 | 76 | @classmethod 77 | def load(cls, filename): 78 | print(f'Loading classification head from {filename}') 79 | return utils.torch_load(filename) 80 | 81 | 82 | class ImageClassifier(torch.nn.Module): 83 | def __init__(self, image_encoder, classification_head): 84 | super().__init__() 85 | self.image_encoder = image_encoder 86 | self.classification_head = classification_head 87 | if self.image_encoder is not None: 88 | if hasattr(self.image_encoder, 'train_preprocess'): 89 | self.train_preprocess = self.image_encoder.train_preprocess 90 | self.val_preprocess = self.image_encoder.val_preprocess 91 | elif hasattr(self.image_encoder.model, 'train_preprocess'): 92 | self.train_preprocess = self.image_encoder.model.train_preprocess 93 | self.val_preprocess = self.image_encoder.model.val_preprocess 94 | 95 | def freeze_head(self): 96 | self.classification_head.weight.requires_grad_(False) 97 | self.classification_head.bias.requires_grad_(False) 98 | 99 | def forward(self, inputs, out_feat=False): 100 | features = self.image_encoder(inputs) 101 | outputs = self.classification_head(features) 102 | if out_feat: 103 | return outputs, features 104 | return outputs 105 | 106 | def __call__(self, inputs): 107 | return self.forward(inputs) 108 | 109 | def save(self, filename): 110 | print(f'Saving image classifier to {filename}') 111 | utils.torch_save(self, filename) 112 | 113 | @classmethod 114 | def load(cls, filename): 115 | print(f'Loading image classifier from {filename}') 116 | return utils.torch_load(filename) 117 | 118 | class ImageClassifier_debug(torch.nn.Module): 119 | def __init__(self, image_encoder, image_encoder2, classification_head): 120 | super().__init__() 121 | self.image_encoder = image_encoder 122 | self.image_encoder2 = image_encoder2 123 | self.classification_head = classification_head 124 | if self.image_encoder is not None: 125 | self.train_preprocess = self.image_encoder.train_preprocess 126 | self.val_preprocess = self.image_encoder.val_preprocess 127 | 128 | def freeze_head(self): 129 | self.classification_head.weight.requires_grad_(False) 130 | self.classification_head.bias.requires_grad_(False) 131 | 132 | def forward(self, inputs): 133 | features = self.image_encoder(inputs) 134 | features2 = self.image_encoder2(inputs) 135 | outputs = self.classification_head(features + features2) 136 | return outputs 137 | 138 | def __call__(self, inputs): 139 | return self.forward(inputs) 140 | 141 | def save(self, filename): 142 | print(f'Saving image classifier to {filename}') 143 | utils.torch_save(self, filename) 144 | 145 | @classmethod 146 | def load(cls, filename): 147 | print(f'Loading image classifier from {filename}') 148 | return utils.torch_load(filename) 149 | 150 | class MultiHeadImageClassifier(torch.nn.Module): 151 | def __init__(self, image_encoder, classification_heads): 152 | super().__init__() 153 | self.image_encoder = image_encoder 154 | self.classification_heads = torch.nn.ModuleList(classification_heads) 155 | if self.image_encoder is not None: 156 | self.train_preprocess = self.image_encoder.train_preprocess 157 | self.val_preprocess = self.image_encoder.val_preprocess 158 | 159 | def freeze_head(self): 160 | for idx in range(len(self.classification_heads)): 161 | self.classification_heads[idx].weight.requires_grad_(False) 162 | self.classification_heads[idx].bias.requires_grad_(False) 163 | 164 | 
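    # Usage sketch for the multi-head variant (illustrative; `encoder`, `images`, and the
    # per-task heads are assumed to have been built elsewhere, e.g. via `ImageEncoder(args)`
    # and `ClassificationHead(...)`):
    #
    #     clf = MultiHeadImageClassifier(encoder, [cars_head, dtd_head])
    #     cars_logits = clf(images, head_idx=0)  # score shared features against Cars classes
    #     dtd_logits  = clf(images, head_idx=1)  # score shared features against DTD classes
    #
    # `head_idx` selects which task-specific head is applied to the shared encoder features.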
def forward(self, inputs, head_idx): 165 | features = self.image_encoder(inputs) 166 | outputs = self.classification_heads[head_idx](features) 167 | return outputs 168 | 169 | def __call__(self, inputs, head_idx): 170 | return self.forward(inputs, head_idx) 171 | 172 | def save(self, filename): 173 | print(f'Saving image classifier to {filename}') 174 | utils.torch_save(self, filename) 175 | 176 | @classmethod 177 | def load(cls, filename): 178 | print(f'Loading image classifier from {filename}') 179 | return utils.torch_load(filename) 180 | -------------------------------------------------------------------------------- /dawin_mtl/src/ties_merging_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os, copy 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import re 7 | from collections import OrderedDict 8 | import torch.nn.functional as F 9 | # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 10 | 11 | ## Model conversion utils 12 | def state_dict_to_vector(state_dict, remove_keys=[]): 13 | shared_state_dict = copy.deepcopy(state_dict) 14 | for key in remove_keys: 15 | if key in shared_state_dict: 16 | del shared_state_dict[key] 17 | sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items())) 18 | return torch.nn.utils.parameters_to_vector( 19 | [value.reshape(-1) for key, value in sorted_shared_state_dict.items()] 20 | ) 21 | 22 | 23 | def vector_to_state_dict(vector, state_dict, remove_keys=[]): 24 | # create a reference dict to define the order of the vector 25 | reference_dict = copy.deepcopy(state_dict) 26 | for key in remove_keys: 27 | if key in reference_dict: 28 | del reference_dict[key] 29 | sorted_reference_dict = OrderedDict(sorted(reference_dict.items())) 30 | 31 | # create a shared state dict using the refence dict 32 | torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values()) 33 | 34 | # add back the encoder and decoder embedding weights. 35 | if "transformer.shared.weight" in sorted_reference_dict: 36 | for key in remove_keys: 37 | sorted_reference_dict[key] = sorted_reference_dict[ 38 | "transformer.shared.weight" 39 | ] 40 | return sorted_reference_dict 41 | 42 | 43 | def add_ptm_to_tv(tv_dict, ptm_dict): 44 | assert set(tv_dict.keys()) == set( 45 | ptm_dict.keys() 46 | ), "Differing parameter names in models." 47 | final_dict = copy.deepcopy(tv_dict) 48 | for k, v in ptm_dict.items(): 49 | final_dict[k] = tv_dict[k] + v 50 | return final_dict 51 | 52 | 53 | def check_parameterNamesMatch(checkpoints): 54 | parameter_names = set(checkpoints[0].keys()) 55 | 56 | if len(checkpoints) >= 2: 57 | # raise ValueError("Number of models is less than 2.") 58 | for checkpoint in checkpoints[1:]: 59 | current_parameterNames = set(checkpoint.keys()) 60 | if current_parameterNames != parameter_names: 61 | raise ValueError( 62 | "Differing parameter names in models. 
" 63 | f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}" 64 | ) 65 | 66 | def check_state_dicts_equal(state_dict1, state_dict2): 67 | if set(state_dict1.keys()) != set(state_dict2.keys()): 68 | return False 69 | 70 | for key in state_dict1.keys(): 71 | if not torch.equal(state_dict1[key], state_dict2[key]): 72 | return False 73 | 74 | return True 75 | 76 | 77 | 78 | ## TIES MERGING UTILS 79 | 80 | def topk_values_mask(M, K=0.7, return_mask=False): 81 | if K > 1: 82 | K /= 100 83 | 84 | original_shape = M.shape 85 | if M.dim() == 1: 86 | M = M.unsqueeze(0) 87 | 88 | n, d = M.shape 89 | k = int(d * K) 90 | k = d - k # Keep top k elements instead of bottom k elements 91 | 92 | # Find the k-th smallest element by magnitude for each row 93 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True) 94 | # Create a mask tensor with True for the top k elements in each row 95 | mask = M.abs() >= kth_values 96 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask 97 | 98 | if return_mask: 99 | return M * final_mask, final_mask.float().mean(dim=1), final_mask 100 | return M * final_mask, final_mask.float().mean(dim=1) 101 | 102 | 103 | def resolve_zero_signs(sign_to_mult, method="majority"): 104 | majority_sign = torch.sign(sign_to_mult.sum()) 105 | 106 | if method == "majority": 107 | sign_to_mult[sign_to_mult == 0] = majority_sign 108 | elif method == "minority": 109 | sign_to_mult[sign_to_mult == 0] = -1 * majority_sign 110 | return sign_to_mult 111 | 112 | 113 | def resolve_sign(Tensor): 114 | sign_to_mult = torch.sign(Tensor.sum(dim=0)) 115 | sign_to_mult = resolve_zero_signs(sign_to_mult, "majority") 116 | return sign_to_mult 117 | 118 | 119 | def disjoint_merge(Tensor, merge_func, sign_to_mult): 120 | merge_func = merge_func.split("-")[-1] 121 | 122 | # If sign is provided then we select the corresponding entries and aggregate. 123 | if sign_to_mult is not None: 124 | rows_to_keep = torch.where( 125 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0 126 | ) 127 | selected_entries = Tensor * rows_to_keep 128 | # Else we select all non-zero entries and aggregate. 129 | else: 130 | rows_to_keep = Tensor != 0 131 | selected_entries = Tensor * rows_to_keep 132 | 133 | if merge_func == "mean": 134 | non_zero_counts = (selected_entries != 0).sum(dim=0).float() 135 | disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1) 136 | elif merge_func == "sum": 137 | disjoint_aggs = torch.sum(selected_entries, dim=0) 138 | elif merge_func == "max": 139 | disjoint_aggs = selected_entries.abs().max(dim=0)[0] 140 | disjoint_aggs *= sign_to_mult 141 | else: 142 | raise ValueError(f"Merge method {merge_func} is not defined.") 143 | 144 | return disjoint_aggs 145 | 146 | 147 | def ties_merging( 148 | flat_task_checks, 149 | reset_thresh=None, 150 | merge_func="", 151 | ): 152 | all_checks = flat_task_checks.clone() 153 | updated_checks, *_ = topk_values_mask( 154 | all_checks, K=reset_thresh, return_mask=False 155 | ) 156 | print(f"RESOLVING SIGN") 157 | final_signs = resolve_sign(updated_checks) 158 | assert final_signs is not None 159 | 160 | print(f"Disjoint AGGREGATION: {merge_func}") 161 | merged_tv = disjoint_merge(updated_checks, merge_func, final_signs) 162 | 163 | return merged_tv 164 | 165 | def disjoint_merge_split(Tensor, merge_func, sign_to_mult): 166 | merge_func = merge_func.split("-")[-1] 167 | 168 | # If sign is provided then we select the corresponding entries and aggregate. 
169 | if sign_to_mult is not None: 170 | rows_to_keep = torch.where( 171 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0 172 | ) 173 | selected_entries = Tensor * rows_to_keep 174 | # Else we select all non-zero entries and aggregate. 175 | else: 176 | rows_to_keep = Tensor != 0 177 | selected_entries = Tensor * rows_to_keep 178 | 179 | if merge_func == "sum": 180 | disjoint_aggs = torch.sum(selected_entries, dim=0) 181 | else: 182 | raise ValueError(f"Merge method {merge_func} is not defined.") 183 | 184 | return selected_entries, disjoint_aggs 185 | 186 | 187 | def ties_merging_split( 188 | flat_task_checks, 189 | reset_thresh=None, 190 | merge_func="", 191 | ): 192 | all_checks = flat_task_checks.clone() 193 | updated_checks, *_ = topk_values_mask( 194 | all_checks, K=reset_thresh, return_mask=False 195 | ) 196 | print(f"RESOLVING SIGN") 197 | final_signs = resolve_sign(updated_checks) 198 | assert final_signs is not None 199 | 200 | print(f"Disjoint AGGREGATION: {merge_func}") 201 | selected_entries, merged_tv = disjoint_merge_split(updated_checks, merge_func, final_signs) 202 | 203 | return selected_entries, merged_tv 204 | -------------------------------------------------------------------------------- /dawin_rft/datasets/objectnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | import PIL 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torchvision import datasets 10 | from torchvision.transforms import Compose 11 | 12 | from .common import ImageFolderWithPaths, SubsetSampler 13 | from .imagenet import ImageNet, ImageNetSubsampleValClasses 14 | 15 | 16 | def get_metadata(is_beta): 17 | if is_beta: 18 | metadata = Path(__file__).parent / 'objectnet_beta_metadata' 19 | else: 20 | metadata = Path(__file__).parent / 'objectnet_metadata' 21 | 22 | with open(metadata / 'folder_to_objectnet_label.json', 'r') as f: 23 | folder_map = json.load(f) 24 | folder_map = {v: k for k, v in folder_map.items()} 25 | with open(metadata / 'objectnet_to_imagenet_1k.json', 'r') as f: 26 | objectnet_map = json.load(f) 27 | 28 | if is_beta: 29 | with open(metadata / 'imagenet_to_labels.json', 'r') as f: 30 | imagenet_map = json.load(f) 31 | imagenet_map = {v: k for k, v in imagenet_map.items()} 32 | else: 33 | with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r') as f: 34 | pytorch_map = json.load(f) 35 | pytorch_map = {v: k for k, v in pytorch_map.items()} 36 | 37 | with open(metadata / 'imagenet_to_label_2012_v2', 'r') as f: 38 | imagenet_map = {v.strip(): str(pytorch_map[i]) for i, v in enumerate(f)} 39 | 40 | folder_to_ids, class_sublist = {}, [] 41 | classnames = [] 42 | for objectnet_name, imagenet_names in objectnet_map.items(): 43 | imagenet_names = imagenet_names.split('; ') 44 | imagenet_ids = [int(imagenet_map[imagenet_name]) for imagenet_name in imagenet_names] 45 | class_sublist.extend(imagenet_ids) 46 | folder_to_ids[folder_map[objectnet_name]] = imagenet_ids 47 | 48 | class_sublist = sorted(class_sublist) 49 | class_sublist_mask = [(i in class_sublist) for i in range(1000)] 50 | classname_map = {v: k for k, v in folder_map.items()} 51 | return class_sublist, class_sublist_mask, folder_to_ids, classname_map 52 | 53 | 54 | def crop(img): 55 | width, height = img.size 56 | cropArea = (2, 2, width - 2, height - 2) 57 | img = img.crop(cropArea) 58 | return img 59 | 60 | 61 | def crop_beta(image, border=2): 62 | return PIL.ImageOps.crop(image, border=border) 63 | 64 | 65 | 
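# The metadata loaded above maps each of ObjectNet's 113 ImageNet-overlapping
# categories to one or more ImageNet class ids. `ObjectNetBase.project_logits`
# (below) uses that mapping to collapse 1000-way ImageNet logits into 113
# ObjectNet logits by taking the max over the mapped columns, roughly as follows
# (the specific ids are illustrative):
#
#     rev_class_idx_map = {0: [409, 530], ...}                      # ObjectNet idx -> ImageNet ids
#     logits_projected[:, 0] = logits[:, [409, 530]].max(axis=1)    # best matching ImageNet score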
class ObjectNetDataset(datasets.ImageFolder): 66 | 67 | def __init__(self, label_map, path, transform): 68 | self.label_map = label_map 69 | super().__init__(path, transform=transform) 70 | self.samples = [ 71 | d for d in self.samples 72 | if os.path.basename(os.path.dirname(d[0])) in self.label_map 73 | ] 74 | self.imgs = self.samples 75 | 76 | def __len__(self): 77 | return len(self.samples) 78 | 79 | def __getitem__(self, index): 80 | path, target = self.samples[index] 81 | sample = self.loader(path) 82 | if self.transform is not None: 83 | sample = self.transform(sample) 84 | label = os.path.basename(os.path.dirname(path)) 85 | return { 86 | 'images': sample, 87 | 'labels': self.label_map[label], 88 | 'image_paths': path 89 | } 90 | 91 | 92 | class ObjectNetBase(ImageNet): 93 | def __init__(self, *args, **kwargs): 94 | (self._class_sublist, 95 | self.class_sublist_mask, 96 | self.folders_to_ids, 97 | self.classname_map) = get_metadata(self.is_beta()) 98 | 99 | super().__init__(*args, **kwargs) 100 | 101 | self.classnames = sorted(list(self.folders_to_ids.keys())) 102 | self.rev_class_idx_map = {} 103 | self.class_idx_map = {} 104 | for idx, name in enumerate(self.classnames): 105 | self.rev_class_idx_map[idx] = self.folders_to_ids[name] 106 | for imagenet_idx in self.rev_class_idx_map[idx]: 107 | self.class_idx_map[imagenet_idx] = idx 108 | 109 | if self.is_beta(): 110 | self.crop = crop_beta 111 | else: 112 | self.crop = crop 113 | self.preprocess = Compose([crop, self.preprocess]) 114 | self.classnames = [self.classname_map[c].lower() for c in self.classnames] 115 | 116 | def is_beta(self): 117 | raise NotImplementedError 118 | 119 | def populate_train(self): 120 | pass 121 | 122 | def get_test_dataset(self): 123 | subdir = 'objectnet-1.0-beta' if self.is_beta() else 'objectnet-1.0/images' 124 | valdir = os.path.join(self.location, subdir) 125 | label_map = {name: idx for idx, name in enumerate(sorted(list(self.folders_to_ids.keys())))} 126 | return ObjectNetDataset(label_map, valdir, transform=self.preprocess) 127 | 128 | def project_logits(self, logits, device): 129 | if isinstance(logits, list) or isinstance(logits, tuple): 130 | return [self.project_logits(l, device) for l in logits] 131 | if logits.shape[1] == 113: 132 | return logits 133 | if torch.is_tensor(logits): 134 | logits = logits.cpu().numpy() 135 | logits_projected = np.zeros((logits.shape[0], 113)) 136 | for k, v in self.rev_class_idx_map.items(): 137 | logits_projected[:, k] = np.max(logits[:, v], axis=1).squeeze() 138 | return torch.tensor(logits_projected).to(device) 139 | 140 | def scatter_weights(self, weights): 141 | if weights.size(1) == 1000: 142 | return weights 143 | new_weights = torch.ones((weights.size(0), 1000)).to(weights.device) * -10e8 144 | for k, v in self.rev_class_idx_map.items(): 145 | for vv in v: 146 | new_weights[:, vv] = weights[:, k] 147 | return new_weights 148 | 149 | 150 | 151 | def accuracy(logits, targets, img_paths, args): 152 | assert logits.shape[1] == 113 153 | preds = logits.argmax(dim=1) 154 | if torch.is_tensor(preds): 155 | preds = preds.cpu().numpy() 156 | if torch.is_tensor(targets): 157 | targets = targets.cpu().numpy() 158 | return np.sum(preds == targets), len(preds) 159 | 160 | 161 | class ObjectNetBetaValClassesBase(ObjectNetBase): 162 | 163 | def get_test_sampler(self): 164 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self._class_sublist] 165 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist]) 166 | 167 | sampler 
= SubsetSampler(idx_subsample_list) 168 | return sampler 169 | 170 | def get_test_dataset(self): 171 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 172 | 173 | def project_labels(self, labels, device): 174 | projected_labels = [self.class_idx_map[int(label)] for label in labels] 175 | return torch.LongTensor(projected_labels).to(device) 176 | 177 | 178 | class ObjectNetBetaValClasses(ObjectNetBetaValClassesBase): 179 | 180 | def is_beta(self): 181 | return True 182 | 183 | class ObjectNetValClasses(ObjectNetBetaValClassesBase): 184 | 185 | def is_beta(self): 186 | return False 187 | 188 | class ObjectNet(ObjectNetBase): 189 | 190 | def accuracy(self, logits, targets, img_paths, args): 191 | return accuracy(logits, targets, img_paths, args) 192 | 193 | def is_beta(self): 194 | return False 195 | 196 | class ObjectNetBeta(ObjectNetBase): 197 | 198 | def accuracy(self, logits, targets, img_paths, args): 199 | return accuracy(logits, targets, img_paths, args) 200 | 201 | def is_beta(self): 202 | return True -------------------------------------------------------------------------------- /dawin_rft/mixturemodel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import betaln, logsumexp 3 | from sklearn.cluster import KMeans 4 | 5 | class BetaMixtureModel: 6 | """ 7 | Beta Mixture Model (Multivariate version). 8 | Each dimension is modeled independently by a Beta distribution. 9 | """ 10 | 11 | def __init__(self, n_mixtures=3, random_seed=1): 12 | self.n_mixtures = n_mixtures 13 | self.random_seed = random_seed 14 | self.convergence = False 15 | 16 | def _init_clusters(self, data_matrix, init_round): 17 | """ 18 | Initialize the mixture responsibilities (assignments) via k-means or uniformly random 19 | """ 20 | if self.method == "kmeans": 21 | km = KMeans( 22 | n_clusters=self.n_mixtures, 23 | n_init=1, 24 | random_state=self.random_seed + init_round 25 | ).fit(data_matrix) 26 | resp_matrix = np.zeros((self.n_observations, self.n_mixtures)) 27 | resp_matrix[np.arange(self.n_observations), km.labels_] = 1 28 | else: 29 | np.random.seed(self.random_seed + init_round) 30 | resp_matrix = np.random.rand(self.n_observations, self.n_mixtures) 31 | resp_matrix /= resp_matrix.sum(axis=1, keepdims=True) 32 | 33 | # Numerical stability 34 | resp_matrix += 10 * np.finfo(resp_matrix.dtype).eps 35 | 36 | # Initialize beta parameters (alpha/beta for each dimension) 37 | self.beta_params_ = np.zeros((self.n_mixtures, self.n_components * 2)) 38 | self._M_step(data_matrix, np.log(resp_matrix)) 39 | 40 | 41 | def _calc_log_weights(self): 42 | """ 43 | Return log of current mixture weights. 44 | """ 45 | return np.log(self.mix_weights_) 46 | 47 | def _calc_mixture_log_probs(self, data_matrix, mixture_idx): 48 | """ 49 | Compute log-prob for a single mixture (used if parallelized). 50 | """ 51 | alpha_vec = self.beta_params_[mixture_idx, :self.n_components] 52 | beta_vec = self.beta_params_[mixture_idx, self.n_components:] 53 | beta_func_log = betaln(alpha_vec, beta_vec) 54 | return ( 55 | (alpha_vec - 1) * np.log(data_matrix) 56 | + (beta_vec - 1) * np.log(1 - data_matrix) 57 | - beta_func_log 58 | ).sum(axis=1) 59 | 60 | def _calc_log_probs_all_mixtures(self, data_matrix): 61 | """ 62 | Return log-prob for each observation under each mixture (unnormalized). 
63 | """ 64 | log_prob = np.empty((self.n_observations, self.n_mixtures)) 65 | for mix in range(self.n_mixtures): 66 | alpha_vec = self.beta_params_[mix, :self.n_components] 67 | beta_vec = self.beta_params_[mix, self.n_components:] 68 | bfn = betaln(alpha_vec, beta_vec) 69 | log_prob[:, mix] = ( 70 | (alpha_vec - 1) * np.log(data_matrix) 71 | + (beta_vec - 1) * np.log(1 - data_matrix) 72 | - bfn 73 | ).sum(axis=1) 74 | return log_prob 75 | 76 | def _calc_weighted_log_probs(self, data_matrix): 77 | """ 78 | Return the sum of log-probabilities and log-weights. 79 | """ 80 | return self._calc_log_probs_all_mixtures(data_matrix) + self._calc_log_weights() 81 | 82 | def _calc_log_resp_and_norm(self, data_matrix): 83 | """ 84 | Return (log_prob_norm, log_resp) for the E-step. 85 | """ 86 | weighted_lp = self._calc_weighted_log_probs(data_matrix) 87 | lp_norm = logsumexp(weighted_lp, axis=1) 88 | with np.errstate(under="ignore"): 89 | log_resp = weighted_lp - lp_norm[:, None] 90 | return lp_norm, log_resp 91 | 92 | def _E_step(self, data_matrix): 93 | """ 94 | E-step: compute average log_prob_norm and log_resp. 95 | """ 96 | lp_norm, log_resp = self._calc_log_resp_and_norm(data_matrix) 97 | return np.mean(lp_norm), log_resp 98 | 99 | def _compute_responsibilities(self, log_resp): 100 | """ 101 | Exponentiate log_resp and sum across observations. 102 | """ 103 | resp_matrix = np.exp(log_resp) 104 | cluster_counts = resp_matrix.sum(axis=0) + 10 * np.finfo(resp_matrix.dtype).eps 105 | return resp_matrix, cluster_counts 106 | 107 | def _update_mixture_weights(self, cluster_counts): 108 | """ 109 | Update mixture weights from mixture counts. 110 | """ 111 | self.mix_weights_ = cluster_counts / cluster_counts.sum() 112 | 113 | def _M_step(self, data_matrix, log_resp): 114 | """ 115 | M-step: update weights and Beta distribution parameters via moment matching. 116 | """ 117 | resp_matrix, cluster_counts = self._compute_responsibilities(log_resp) 118 | self._update_mixture_weights(cluster_counts) 119 | 120 | w_sums = resp_matrix.T @ data_matrix 121 | w_sums_sq = resp_matrix.T @ (data_matrix ** 2) 122 | 123 | for m_idx in range(self.n_mixtures): 124 | sum_vals = w_sums[m_idx] 125 | sum_sq_vals = w_sums_sq[m_idx] 126 | mean_val = sum_vals / cluster_counts[m_idx] 127 | var_val = sum_sq_vals / cluster_counts[m_idx] - mean_val ** 2 128 | 129 | # Clip variance 130 | variance_cap = mean_val * (1 - mean_val) / 4 131 | var_val = np.minimum(var_val, variance_cap) 132 | var_val += 10 * np.finfo(var_val.dtype).eps 133 | 134 | # Compute factor 135 | scaling_factor = (mean_val * (1 - mean_val)) / (var_val + 1e-10) - 1 136 | self.beta_params_[m_idx, :self.n_components] = scaling_factor * mean_val 137 | self.beta_params_[m_idx, self.n_components:] = scaling_factor * (1 - mean_val) 138 | 139 | def fit(self, data_matrix, num_init=3, method="kmeans", max_iter=1000, tol=1e-4): 140 | """ 141 | Fit BetaMixtureModel to the data using EM, possibly with multiple initializations. 
142 | """ 143 | self.n_observations, self.n_components = data_matrix.shape 144 | self.convergence = False 145 | self.method = method 146 | best_lower_bound = -np.inf 147 | optimal_params = None 148 | 149 | for init_round in range(num_init): 150 | print(f"{init_round + 1}-th BMM initialization") 151 | self._init_clusters(data_matrix, init_round) 152 | ll_bound = -np.inf 153 | 154 | for _ in range(max_iter): 155 | prev_bound = ll_bound 156 | lp_norm, log_resp = self._E_step(data_matrix) 157 | self._M_step(data_matrix, log_resp) 158 | ll_bound = lp_norm 159 | delta_bound = ll_bound - prev_bound 160 | 161 | if abs(delta_bound) < tol: 162 | self.convergence = True 163 | break 164 | 165 | if ll_bound > best_lower_bound: 166 | best_lower_bound = ll_bound 167 | # Update final weights 168 | _, cluster_counts = self._compute_responsibilities(log_resp) 169 | self._update_mixture_weights(cluster_counts) 170 | optimal_params = (self.mix_weights_.copy(), self.beta_params_.copy()) 171 | 172 | self.mix_weights_, self.beta_params_ = optimal_params 173 | self.max_lower_bound = best_lower_bound 174 | return self 175 | 176 | def predict_proba(self, data_matrix): 177 | """ 178 | Return the per-mixture membership probabilities for each sample. 179 | """ 180 | _, log_resp = self._calc_log_resp_and_norm(data_matrix) 181 | return np.exp(log_resp) 182 | 183 | def predict(self, data_matrix): 184 | """ 185 | Return the most probable mixture index for each sample. 186 | """ 187 | return np.argmax(self.predict_proba(data_matrix), axis=1) -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/templates.py: -------------------------------------------------------------------------------- 1 | cars_template = [ 2 | lambda c: f'a photo of a {c}.', 3 | lambda c: f'a photo of the {c}.', 4 | lambda c: f'a photo of my {c}.', 5 | lambda c: f'i love my {c}!', 6 | lambda c: f'a photo of my dirty {c}.', 7 | lambda c: f'a photo of my clean {c}.', 8 | lambda c: f'a photo of my new {c}.', 9 | lambda c: f'a photo of my old {c}.', 10 | ] 11 | 12 | cifar10_template = [ 13 | lambda c: f'a photo of a {c}.', 14 | lambda c: f'a blurry photo of a {c}.', 15 | lambda c: f'a black and white photo of a {c}.', 16 | lambda c: f'a low contrast photo of a {c}.', 17 | lambda c: f'a high contrast photo of a {c}.', 18 | lambda c: f'a bad photo of a {c}.', 19 | lambda c: f'a good photo of a {c}.', 20 | lambda c: f'a photo of a small {c}.', 21 | lambda c: f'a photo of a big {c}.', 22 | lambda c: f'a photo of the {c}.', 23 | lambda c: f'a blurry photo of the {c}.', 24 | lambda c: f'a black and white photo of the {c}.', 25 | lambda c: f'a low contrast photo of the {c}.', 26 | lambda c: f'a high contrast photo of the {c}.', 27 | lambda c: f'a bad photo of the {c}.', 28 | lambda c: f'a good photo of the {c}.', 29 | lambda c: f'a photo of the small {c}.', 30 | lambda c: f'a photo of the big {c}.', 31 | ] 32 | 33 | cifar100_template = [ 34 | lambda c: f'a photo of a {c}.', 35 | lambda c: f'a blurry photo of a {c}.', 36 | lambda c: f'a black and white photo of a {c}.', 37 | lambda c: f'a low contrast photo of a {c}.', 38 | lambda c: f'a high contrast photo of a {c}.', 39 | lambda c: f'a bad photo of a {c}.', 40 | lambda c: f'a good photo of a {c}.', 41 | lambda c: f'a photo of a small {c}.', 42 | lambda c: f'a photo of a big {c}.', 43 | lambda c: f'a photo of the {c}.', 44 | lambda c: f'a blurry photo of the {c}.', 45 | lambda c: f'a black and white photo of the {c}.', 46 | lambda c: f'a low contrast photo of the 
{c}.', 47 | lambda c: f'a high contrast photo of the {c}.', 48 | lambda c: f'a bad photo of the {c}.', 49 | lambda c: f'a good photo of the {c}.', 50 | lambda c: f'a photo of the small {c}.', 51 | lambda c: f'a photo of the big {c}.', 52 | ] 53 | 54 | dtd_template = [ 55 | lambda c: f'a photo of a {c} texture.', 56 | lambda c: f'a photo of a {c} pattern.', 57 | lambda c: f'a photo of a {c} thing.', 58 | lambda c: f'a photo of a {c} object.', 59 | lambda c: f'a photo of the {c} texture.', 60 | lambda c: f'a photo of the {c} pattern.', 61 | lambda c: f'a photo of the {c} thing.', 62 | lambda c: f'a photo of the {c} object.', 63 | ] 64 | 65 | eurosat_template = [ 66 | lambda c: f'a centered satellite photo of {c}.', 67 | lambda c: f'a centered satellite photo of a {c}.', 68 | lambda c: f'a centered satellite photo of the {c}.', 69 | ] 70 | 71 | food101_template = [ 72 | lambda c: f'a photo of {c}, a type of food.', 73 | ] 74 | 75 | gtsrb_template = [ 76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.', 77 | lambda c: f'a centered photo of a "{c}" traffic sign.', 78 | lambda c: f'a close up photo of a "{c}" traffic sign.', 79 | ] 80 | 81 | mnist_template = [ 82 | lambda c: f'a photo of the number: "{c}".', 83 | ] 84 | 85 | imagenet_template = [ 86 | lambda c: f'a bad photo of a {c}.', 87 | lambda c: f'a photo of many {c}.', 88 | lambda c: f'a sculpture of a {c}.', 89 | lambda c: f'a photo of the hard to see {c}.', 90 | lambda c: f'a low resolution photo of the {c}.', 91 | lambda c: f'a rendering of a {c}.', 92 | lambda c: f'graffiti of a {c}.', 93 | lambda c: f'a bad photo of the {c}.', 94 | lambda c: f'a cropped photo of the {c}.', 95 | lambda c: f'a tattoo of a {c}.', 96 | lambda c: f'the embroidered {c}.', 97 | lambda c: f'a photo of a hard to see {c}.', 98 | lambda c: f'a bright photo of a {c}.', 99 | lambda c: f'a photo of a clean {c}.', 100 | lambda c: f'a photo of a dirty {c}.', 101 | lambda c: f'a dark photo of the {c}.', 102 | lambda c: f'a drawing of a {c}.', 103 | lambda c: f'a photo of my {c}.', 104 | lambda c: f'the plastic {c}.', 105 | lambda c: f'a photo of the cool {c}.', 106 | lambda c: f'a close-up photo of a {c}.', 107 | lambda c: f'a black and white photo of the {c}.', 108 | lambda c: f'a painting of the {c}.', 109 | lambda c: f'a painting of a {c}.', 110 | lambda c: f'a pixelated photo of the {c}.', 111 | lambda c: f'a sculpture of the {c}.', 112 | lambda c: f'a bright photo of the {c}.', 113 | lambda c: f'a cropped photo of a {c}.', 114 | lambda c: f'a plastic {c}.', 115 | lambda c: f'a photo of the dirty {c}.', 116 | lambda c: f'a jpeg corrupted photo of a {c}.', 117 | lambda c: f'a blurry photo of the {c}.', 118 | lambda c: f'a photo of the {c}.', 119 | lambda c: f'a good photo of the {c}.', 120 | lambda c: f'a rendering of the {c}.', 121 | lambda c: f'a {c} in a video game.', 122 | lambda c: f'a photo of one {c}.', 123 | lambda c: f'a doodle of a {c}.', 124 | lambda c: f'a close-up photo of the {c}.', 125 | lambda c: f'a photo of a {c}.', 126 | lambda c: f'the origami {c}.', 127 | lambda c: f'the {c} in a video game.', 128 | lambda c: f'a sketch of a {c}.', 129 | lambda c: f'a doodle of the {c}.', 130 | lambda c: f'a origami {c}.', 131 | lambda c: f'a low resolution photo of a {c}.', 132 | lambda c: f'the toy {c}.', 133 | lambda c: f'a rendition of the {c}.', 134 | lambda c: f'a photo of the clean {c}.', 135 | lambda c: f'a photo of a large {c}.', 136 | lambda c: f'a rendition of a {c}.', 137 | lambda c: f'a photo of a nice {c}.', 138 | lambda c: f'a photo of 
a weird {c}.', 139 | lambda c: f'a blurry photo of a {c}.', 140 | lambda c: f'a cartoon {c}.', 141 | lambda c: f'art of a {c}.', 142 | lambda c: f'a sketch of the {c}.', 143 | lambda c: f'a embroidered {c}.', 144 | lambda c: f'a pixelated photo of a {c}.', 145 | lambda c: f'itap of the {c}.', 146 | lambda c: f'a jpeg corrupted photo of the {c}.', 147 | lambda c: f'a good photo of a {c}.', 148 | lambda c: f'a plushie {c}.', 149 | lambda c: f'a photo of the nice {c}.', 150 | lambda c: f'a photo of the small {c}.', 151 | lambda c: f'a photo of the weird {c}.', 152 | lambda c: f'the cartoon {c}.', 153 | lambda c: f'art of the {c}.', 154 | lambda c: f'a drawing of the {c}.', 155 | lambda c: f'a photo of the large {c}.', 156 | lambda c: f'a black and white photo of a {c}.', 157 | lambda c: f'the plushie {c}.', 158 | lambda c: f'a dark photo of a {c}.', 159 | lambda c: f'itap of a {c}.', 160 | lambda c: f'graffiti of the {c}.', 161 | lambda c: f'a toy {c}.', 162 | lambda c: f'itap of my {c}.', 163 | lambda c: f'a photo of a cool {c}.', 164 | lambda c: f'a photo of a small {c}.', 165 | lambda c: f'a tattoo of the {c}.', 166 | ] 167 | 168 | resisc45_template = [ 169 | lambda c: f'satellite imagery of {c}.', 170 | lambda c: f'aerial imagery of {c}.', 171 | lambda c: f'satellite photo of {c}.', 172 | lambda c: f'aerial photo of {c}.', 173 | lambda c: f'satellite view of {c}.', 174 | lambda c: f'aerial view of {c}.', 175 | lambda c: f'satellite imagery of a {c}.', 176 | lambda c: f'aerial imagery of a {c}.', 177 | lambda c: f'satellite photo of a {c}.', 178 | lambda c: f'aerial photo of a {c}.', 179 | lambda c: f'satellite view of a {c}.', 180 | lambda c: f'aerial view of a {c}.', 181 | lambda c: f'satellite imagery of the {c}.', 182 | lambda c: f'aerial imagery of the {c}.', 183 | lambda c: f'satellite photo of the {c}.', 184 | lambda c: f'aerial photo of the {c}.', 185 | lambda c: f'satellite view of the {c}.', 186 | lambda c: f'aerial view of the {c}.', 187 | ] 188 | 189 | stl10_template = [ 190 | lambda c: f'a photo of a {c}.', 191 | lambda c: f'a photo of the {c}.', 192 | ] 193 | 194 | sun397_template = [ 195 | lambda c: f'a photo of a {c}.', 196 | lambda c: f'a photo of the {c}.', 197 | ] 198 | 199 | svhn_template = [ 200 | lambda c: f'a photo of the number: "{c}".', 201 | ] 202 | 203 | 204 | dataset_to_template = { 205 | 'Cars': cars_template, 206 | 'CIFAR10': cifar10_template, 207 | 'CIFAR100': cifar100_template, 208 | 'DTD': dtd_template, 209 | 'EuroSAT': eurosat_template, 210 | 'Food101': food101_template, 211 | 'GTSRB': gtsrb_template, 212 | 'MNIST': mnist_template, 213 | 'ImageNet': imagenet_template, 214 | 'RESISC45': resisc45_template, 215 | 'STL10': stl10_template, 216 | 'SUN397': sun397_template, 217 | 'SVHN': svhn_template, 218 | } 219 | 220 | 221 | def get_templates(dataset_name): 222 | if dataset_name.endswith('Val'): 223 | return get_templates(dataset_name.replace('Val', '')) 224 | assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' 225 | return dataset_to_template[dataset_name] -------------------------------------------------------------------------------- /dawin_mtl/src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import pickle 5 | import math 6 | 7 | import numpy as np 8 | from scipy.special import digamma, gammaln, psi 9 | from scipy.stats import dirichlet 10 | 11 | def assign_learning_rate(param_group, new_lr): 12 | param_group["lr"] = new_lr 13 | 14 | 15 | def 
_warmup_lr(base_lr, warmup_length, step): 16 | return base_lr * (step + 1) / warmup_length 17 | 18 | 19 | def cosine_lr(optimizer, base_lrs, warmup_length, steps): 20 | if not isinstance(base_lrs, list): 21 | base_lrs = [base_lrs for _ in optimizer.param_groups] 22 | assert len(base_lrs) == len(optimizer.param_groups) 23 | def _lr_adjuster(step): 24 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs): 25 | if step < warmup_length: 26 | lr = _warmup_lr(base_lr, warmup_length, step) 27 | else: 28 | e = step - warmup_length 29 | es = steps - warmup_length 30 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 31 | assign_learning_rate(param_group, lr) 32 | return _lr_adjuster 33 | 34 | 35 | def accuracy(output, target, topk=(1,)): 36 | pred = output.topk(max(topk), 1, True, True)[1].t() 37 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 38 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 39 | 40 | 41 | def torch_load_old(save_path, device=None): 42 | with open(save_path, 'rb') as f: 43 | classifier = pickle.load(f) 44 | if device is not None: 45 | classifier = classifier.to(device) 46 | return classifier 47 | 48 | 49 | def torch_save(model, save_path): 50 | if os.path.dirname(save_path) != '': 51 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 52 | torch.save(model.cpu(), save_path) 53 | 54 | 55 | def torch_load(save_path, device=None): 56 | model = torch.load(save_path) 57 | if device is not None: 58 | model = model.to(device) 59 | return model 60 | 61 | 62 | 63 | def get_logits(inputs, classifier): 64 | assert callable(classifier) 65 | if hasattr(classifier, 'to'): 66 | classifier = classifier.to(inputs.device) 67 | return classifier(inputs) 68 | 69 | 70 | def get_probs(inputs, classifier): 71 | if hasattr(classifier, 'predict_proba'): 72 | probs = classifier.predict_proba(inputs.detach().cpu().numpy()) 73 | return torch.from_numpy(probs) 74 | logits = get_logits(inputs, classifier) 75 | return logits.softmax(dim=1) 76 | 77 | 78 | class LabelSmoothing(torch.nn.Module): 79 | def __init__(self, smoothing=0.0): 80 | super(LabelSmoothing, self).__init__() 81 | self.confidence = 1.0 - smoothing 82 | self.smoothing = smoothing 83 | 84 | def forward(self, x, target): 85 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 86 | 87 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 88 | nll_loss = nll_loss.squeeze(1) 89 | smooth_loss = -logprobs.mean(dim=-1) 90 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 91 | return loss.mean() 92 | 93 | 94 | import time 95 | import numpy as np 96 | from scipy.special import gammaln, logsumexp 97 | from sklearn.cluster import KMeans 98 | 99 | class DirichletMixtureModel: 100 | """ 101 | Dirichlet Mixture Model for data with each row in X 102 | as a probability vector (summing to 1). 
103 | """ 104 | 105 | def __init__(self, n_components=3, random_seed=1): 106 | self.n_mixtures = n_components 107 | self.random_seed = random_seed 108 | self.convergence = False 109 | 110 | 111 | def _init_clusters(self, X, init_idx): 112 | """Initialize mixture responsibilities via k-means or random approach.""" 113 | N = self.n_observations 114 | if self.method == "kmeans": 115 | kmeans = KMeans( 116 | n_clusters=self.n_components, 117 | n_init=1, 118 | random_state=self.random_seed + init_idx 119 | ) 120 | labels = kmeans.fit(X).labels_ 121 | resp = np.zeros((N, self.n_components)) 122 | resp[np.arange(N), labels] = 1.0 123 | else: # random 124 | rng = np.random.default_rng(self.random_seed + init_idx) 125 | resp = rng.random((N, self.n_components)) 126 | resp /= resp.sum(axis=1, keepdims=True) 127 | 128 | # Small stability offset 129 | resp += 1e-14 130 | 131 | # Random initial Dirichlet parameters 132 | self.dirichlet_params_ = np.random.rand(self.n_components, self.n_components) 133 | 134 | # Single M-step to set initial params 135 | self._M_step(X, np.log(resp)) 136 | 137 | def _estimate_log_weights(self): 138 | """Return log of current mixture weights.""" 139 | return np.log(self.weights_) 140 | 141 | def _dirichlet_log_prob_single(self, X, mixture_idx): 142 | """Compute log-prob for one mixture across all samples.""" 143 | alpha = self.dirichlet_params_[mixture_idx] 144 | log_C = gammaln(alpha.sum()) - np.sum(gammaln(alpha)) 145 | return log_C + np.sum((alpha - 1) * np.log(X), axis=1) 146 | 147 | def _estimate_log_prob(self, X): 148 | """Compute log-prob of each mixture for each sample.""" 149 | log_prob = np.empty((self.n_observations, self.n_components)) 150 | for k in range(self.n_components): 151 | log_prob[:, k] = self._dirichlet_log_prob_single(X, k) 152 | 153 | return log_prob 154 | 155 | def _weighted_log_prob(self, X): 156 | """Return mixture-weighted log-prob.""" 157 | return self._estimate_log_prob(X) + self._estimate_log_weights() 158 | 159 | def _estimate_log_resp(self, X): 160 | """ 161 | E-step: 162 | log_resp = log p(x | mixture) + log(weights) 163 | - logsumexp over mixtures 164 | """ 165 | wlp = self._weighted_log_prob(X) 166 | log_norm = logsumexp(wlp, axis=1) 167 | with np.errstate(under="ignore"): 168 | log_resp = wlp - log_norm[:, None] 169 | return log_norm, log_resp 170 | 171 | def _compute_responsibilities(self, log_resp): 172 | """Exponentiate log responsibilities and sum them per mixture.""" 173 | resp = np.exp(log_resp) 174 | nk = resp.sum(axis=0) + 1e-14 175 | return resp, nk 176 | 177 | def _update_weights(self, nk): 178 | """Update mixture weights.""" 179 | self.weights_ = nk / nk.sum() 180 | 181 | def _E_step(self, X): 182 | """Perform E-step, returning mean log-likelihood and log responsibilities.""" 183 | log_norm, log_resp = self._estimate_log_resp(X) 184 | return log_norm.mean(), log_resp 185 | 186 | def _M_step(self, X, log_resp): 187 | """Perform M-step using updated responsibilities.""" 188 | resp, nk = self._compute_responsibilities(log_resp) 189 | self._update_weights(nk) 190 | 191 | # Update each mixture's Dirichlet parameter 192 | #! Note that this is a simplified version with moment-matching heuristic, rather than the exact EM 193 | for k in range(self.n_components): 194 | alpha_new = resp[:, k] @ X # Weighted sum of X 195 | self.dirichlet_params_[k] = alpha_new / nk[k] 196 | 197 | def fit(self, X, n_init=10, method="kmeans", max_iter=2000, tol=1e-6): 198 | """ 199 | Run EM to fit a Dirichlet mixture model to data X. 
200 | Each row of X should be a probability vector (sum to 1). 201 | """ 202 | self.n_observations, self.n_components = X.shape 203 | self.method = method 204 | self.convergence = False 205 | 206 | best_lower_bound = -np.inf 207 | best_params = None 208 | 209 | for init_idx in range(n_init): 210 | print(f"{init_idx + 1}-th DMM initialization") 211 | self._init_clusters(X, init_idx) 212 | 213 | lower_bound = -np.inf 214 | 215 | for iteration in range(max_iter): 216 | prev_bound = lower_bound 217 | 218 | log_prob_norm, log_resp = self._E_step(X) 219 | self._M_step(X, log_resp) 220 | 221 | lower_bound = log_prob_norm 222 | diff = lower_bound - prev_bound 223 | 224 | if abs(diff) < tol: 225 | self.convergence = True 226 | break 227 | 228 | if lower_bound > best_lower_bound: 229 | best_lower_bound = lower_bound 230 | _, nk = self._compute_responsibilities(log_resp) 231 | self._update_weights(nk) 232 | best_params = (self.weights_.copy(), self.dirichlet_params_.copy()) 233 | 234 | # Restore the best parameters 235 | self.weights_, self.params_ = best_params 236 | self.lower_bound_ = best_lower_bound 237 | return self 238 | 239 | def predict_proba(self, X): 240 | """Return soft assignments (N x K).""" 241 | _, log_resp = self._estimate_log_resp(X) 242 | return np.exp(log_resp) 243 | 244 | def predict(self, X): 245 | """Return the most likely mixture for each sample.""" 246 | return np.argmax(self.predict_proba(X), axis=1) -------------------------------------------------------------------------------- /dawin_mtl/src/datasets/gtsrb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pathlib 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from torchvision.datasets.folder import make_dataset 10 | from torchvision.datasets.utils import (download_and_extract_archive, verify_str_arg) 11 | from torchvision.datasets.vision import VisionDataset 12 | 13 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 14 | """Finds the class folders in a dataset. 15 | 16 | See :class:`DatasetFolder` for details. 17 | """ 18 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 19 | if not classes: 20 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 21 | 22 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 23 | return classes, class_to_idx 24 | 25 | class PyTorchGTSRB(VisionDataset): 26 | """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset. 27 | 28 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. 29 | 30 | Args: 31 | root (string): Root directory of the dataset. 32 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 33 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 34 | version. E.g, ``transforms.RandomCrop``. 35 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 36 | download (bool, optional): If True, downloads the dataset from the internet and 37 | puts it in root directory. If dataset is already downloaded, it is not 38 | downloaded again. 
39 | """ 40 | 41 | def __init__( 42 | self, 43 | root: str, 44 | split: str = "train", 45 | transform: Optional[Callable] = None, 46 | target_transform: Optional[Callable] = None, 47 | download: bool = False, 48 | ) -> None: 49 | 50 | super().__init__(root, transform=transform, target_transform=target_transform) 51 | 52 | self._split = verify_str_arg(split, "split", ("train", "test")) 53 | self._base_folder = pathlib.Path(root) / "gtsrb" 54 | self._target_folder = ( 55 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") 56 | ) 57 | 58 | if download: 59 | self.download() 60 | 61 | if not self._check_exists(): 62 | raise RuntimeError("Dataset not found. You can use download=True to download it") 63 | 64 | if self._split == "train": 65 | _, class_to_idx = find_classes(str(self._target_folder)) 66 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) 67 | else: 68 | with open(self._base_folder / "GT-final_test.csv") as csv_file: 69 | samples = [ 70 | (str(self._target_folder / row["Filename"]), int(row["ClassId"])) 71 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) 72 | ] 73 | 74 | self._samples = samples 75 | self.transform = transform 76 | self.target_transform = target_transform 77 | 78 | def __len__(self) -> int: 79 | return len(self._samples) 80 | 81 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 82 | 83 | path, target = self._samples[index] 84 | sample = PIL.Image.open(path).convert("RGB") 85 | 86 | if self.transform is not None: 87 | sample = self.transform(sample) 88 | 89 | if self.target_transform is not None: 90 | target = self.target_transform(target) 91 | 92 | return sample, target 93 | 94 | 95 | def _check_exists(self) -> bool: 96 | return self._target_folder.is_dir() 97 | 98 | def download(self) -> None: 99 | if self._check_exists(): 100 | return 101 | 102 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" 103 | 104 | if self._split == "train": 105 | download_and_extract_archive( 106 | f"{base_url}GTSRB-Training_fixed.zip", 107 | download_root=str(self._base_folder), 108 | md5="513f3c79a4c5141765e10e952eaa2478", 109 | ) 110 | else: 111 | download_and_extract_archive( 112 | f"{base_url}GTSRB_Final_Test_Images.zip", 113 | download_root=str(self._base_folder), 114 | md5="c7e4e6327067d32654124b0fe9e82185", 115 | ) 116 | download_and_extract_archive( 117 | f"{base_url}GTSRB_Final_Test_GT.zip", 118 | download_root=str(self._base_folder), 119 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", 120 | ) 121 | 122 | 123 | class GTSRB: 124 | def __init__(self, 125 | preprocess, 126 | location=os.path.expanduser('~/data'), 127 | batch_size=128, 128 | num_workers=0): 129 | 130 | # to fit with repo conventions for location 131 | self.train_dataset = PyTorchGTSRB( 132 | root=location, 133 | #download=True, 134 | download=False, 135 | split='train', 136 | transform=preprocess 137 | ) 138 | 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | batch_size=batch_size, 142 | shuffle=True, 143 | num_workers=num_workers 144 | ) 145 | 146 | self.test_dataset = PyTorchGTSRB( 147 | root=location, 148 | #download=True, 149 | download=False, 150 | split='test', 151 | transform=preprocess 152 | ) 153 | 154 | self.test_loader = torch.utils.data.DataLoader( 155 | self.test_dataset, 156 | batch_size=batch_size, 157 | shuffle=False, 158 | num_workers=num_workers 159 | ) 160 | 161 | self.test_loader_shuffle = 
torch.utils.data.DataLoader( 162 | self.test_dataset, 163 | batch_size=batch_size, 164 | shuffle=True, 165 | num_workers=num_workers 166 | ) 167 | 168 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md 169 | self.classnames = [ 170 | 'red and white circle 20 kph speed limit', 171 | 'red and white circle 30 kph speed limit', 172 | 'red and white circle 50 kph speed limit', 173 | 'red and white circle 60 kph speed limit', 174 | 'red and white circle 70 kph speed limit', 175 | 'red and white circle 80 kph speed limit', 176 | 'end / de-restriction of 80 kph speed limit', 177 | 'red and white circle 100 kph speed limit', 178 | 'red and white circle 120 kph speed limit', 179 | 'red and white circle red car and black car no passing', 180 | 'red and white circle red truck and black car no passing', 181 | 'red and white triangle road intersection warning', 182 | 'white and yellow diamond priority road', 183 | 'red and white upside down triangle yield right-of-way', 184 | 'stop', 185 | 'empty red and white circle', 186 | 'red and white circle no truck entry', 187 | 'red circle with white horizonal stripe no entry', 188 | 'red and white triangle with exclamation mark warning', 189 | 'red and white triangle with black left curve approaching warning', 190 | 'red and white triangle with black right curve approaching warning', 191 | 'red and white triangle with black double curve approaching warning', 192 | 'red and white triangle rough / bumpy road warning', 193 | 'red and white triangle car skidding / slipping warning', 194 | 'red and white triangle with merging / narrow lanes warning', 195 | 'red and white triangle with person digging / construction / road work warning', 196 | 'red and white triangle with traffic light approaching warning', 197 | 'red and white triangle with person walking warning', 198 | 'red and white triangle with child and person walking warning', 199 | 'red and white triangle with bicyle warning', 200 | 'red and white triangle with snowflake / ice warning', 201 | 'red and white triangle with deer warning', 202 | 'white circle with gray strike bar no speed limit', 203 | 'blue circle with white right turn arrow mandatory', 204 | 'blue circle with white left turn arrow mandatory', 205 | 'blue circle with white forward arrow mandatory', 206 | 'blue circle with white forward or right turn arrow mandatory', 207 | 'blue circle with white forward or left turn arrow mandatory', 208 | 'blue circle with white keep right arrow mandatory', 209 | 'blue circle with white keep left arrow mandatory', 210 | 'blue circle with white arrows indicating a traffic circle', 211 | 'white circle with gray strike bar indicating no passing for cars has ended', 212 | 'white circle with gray strike bar indicating no passing for trucks has ended', 213 | ] 214 | -------------------------------------------------------------------------------- /dawin_rft/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = 
InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 70 | 71 | return download_target 72 | 73 | 74 | def _convert_image_to_rgb(image): 75 | return image.convert("RGB") 76 | 77 | 78 | def _transform(n_px): 79 | return Compose([ 80 | Resize(n_px, interpolation=BICUBIC), 81 | CenterCrop(n_px), 82 | _convert_image_to_rgb, 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | 88 | def available_models() -> List[str]: 89 | """Returns the names of available CLIP models""" 90 | return list(_MODELS.keys()) 91 | 92 | 93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 94 | """Load a CLIP model 95 | 96 | Parameters 97 | ---------- 98 | name : str 99 | A model name listed by
`clip.available_models()`, or the path to a model checkpoint containing the state_dict 100 | 101 | device : Union[str, torch.device] 102 | The device to put the loaded model 103 | 104 | jit : bool 105 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 106 | 107 | download_root: str 108 | path to download the model files; by default, it uses "~/.cache/clip" 109 | 110 | Returns 111 | ------- 112 | model : torch.nn.Module 113 | The CLIP model 114 | 115 | preprocess : Callable[[PIL.Image], torch.Tensor] 116 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 117 | """ 118 | if name in _MODELS: 119 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 120 | elif os.path.isfile(name): 121 | model_path = name 122 | else: 123 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 124 | 125 | try: 126 | # loading JIT archive 127 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 128 | state_dict = None 129 | except RuntimeError: 130 | # loading saved state dict 131 | if jit: 132 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 133 | jit = False 134 | state_dict = torch.load(model_path, map_location="cpu") 135 | 136 | if not jit: 137 | model = build_model(state_dict or model.state_dict()).to(device) 138 | if str(device) == "cpu": 139 | model.float() 140 | return model, _transform(model.visual.input_resolution) 141 | 142 | # patch the device names 143 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 144 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 145 | 146 | def patch_device(module): 147 | try: 148 | graphs = [module.graph] if hasattr(module, "graph") else [] 149 | except RuntimeError: 150 | graphs = [] 151 | 152 | if hasattr(module, "forward1"): 153 | graphs.append(module.forward1.graph) 154 | 155 | for graph in graphs: 156 | for node in graph.findAllNodes("prim::Constant"): 157 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 158 | node.copyAttributes(device_node) 159 | 160 | model.apply(patch_device) 161 | patch_device(model.encode_image) 162 | patch_device(model.encode_text) 163 | 164 | # patch dtype to float32 on CPU 165 | if str(device) == "cpu": 166 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 167 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 168 | float_node = float_input.node() 169 | 170 | def patch_float(module): 171 | try: 172 | graphs = [module.graph] if hasattr(module, "graph") else [] 173 | except RuntimeError: 174 | graphs = [] 175 | 176 | if hasattr(module, "forward1"): 177 | graphs.append(module.forward1.graph) 178 | 179 | for graph in graphs: 180 | for node in graph.findAllNodes("aten::to"): 181 | inputs = list(node.inputs()) 182 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 183 | if inputs[i].node()["value"] == 5: 184 | inputs[i].node().copyAttributes(float_node) 185 | 186 | model.apply(patch_float) 187 | patch_float(model.encode_image) 188 | patch_float(model.encode_text) 189 | 190 | model.float() 191 | 192 | return model, _transform(model.input_resolution.item()) 193 | 194 | 195 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 196 | 
""" 197 | Returns the tokenized representation of given input string(s) 198 | 199 | Parameters 200 | ---------- 201 | texts : Union[str, List[str]] 202 | An input string or a list of input strings to tokenize 203 | 204 | context_length : int 205 | The context length to use; all CLIP models use 77 as the context length 206 | 207 | truncate: bool 208 | Whether to truncate the text in case its encoding is longer than the context length 209 | 210 | Returns 211 | ------- 212 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 213 | """ 214 | if isinstance(texts, str): 215 | texts = [texts] 216 | 217 | sot_token = _tokenizer.encoder["<|startoftext|>"] 218 | eot_token = _tokenizer.encoder["<|endoftext|>"] 219 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 220 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 221 | 222 | for i, tokens in enumerate(all_tokens): 223 | if len(tokens) > context_length: 224 | if truncate: 225 | tokens = tokens[:context_length] 226 | tokens[-1] = eot_token 227 | else: 228 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 229 | result[i, :len(tokens)] = torch.tensor(tokens) 230 | 231 | return result 232 | -------------------------------------------------------------------------------- /dawin_mtl/src/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | 5 | import torch 6 | import numpy as np 7 | 8 | import utils 9 | from datasets.common import get_dataloader, maybe_dictionarize 10 | from heads import get_classification_head 11 | from modeling import ImageClassifier 12 | 13 | from datasets.registry import get_dataset 14 | from scipy.special import softmax 15 | import pdb 16 | 17 | def set_attr(obj, names, val): 18 | if len(names) == 1: 19 | setattr(obj, names[0], val) 20 | else: 21 | set_attr(getattr(obj, names[0]), names[1:], val) 22 | 23 | def load_weights(mod, names, params): 24 | for name, p in zip(names, params): 25 | set_attr(mod, name.split("."), p) 26 | 27 | #! DaWin 28 | def eval_single_dataset_dynamic(interpolator, dataset_name, args, dmms=None): 29 | interpolator.eval() 30 | coefs_label = interpolator.dmm_coef_labels[f'{dataset_name}'] 31 | 32 | dataset = get_dataset( 33 | dataset_name, 34 | interpolator.model.model.val_preprocess, 35 | location=args.data_location, 36 | batch_size=args.batch_size 37 | ) 38 | dataloader = get_dataloader( 39 | dataset, is_train=False, args=args, image_encoder=None) 40 | device = args.device 41 | 42 | try: 43 | n_tasks = len(args.train_dataset.split(',')) 44 | except: 45 | n_tasks = len(args.train_dataset) 46 | 47 | #pdb.set_trace() 48 | with torch.no_grad(): 49 | top1, correct, n = 0., 0., 0. 50 | 51 | n_eval_batches = len(dataloader) 52 | tot_logits, tot_labels = torch.tensor([]).cuda(), torch.tensor([]).cuda() 53 | 54 | #! 
domain-wise merging (our default setup) 55 | #* 0.3 (task arithmetic default scaler) * N_task 56 | try: 57 | alph_ = torch.cat([torch.tensor([1.0]),torch.tensor([0.3 * n_tasks * dmms[dataset_name].params_[0]]).squeeze()]) 58 | except: 59 | alph_ = torch.cat([torch.tensor([1.0]),torch.tensor([0.3 * n_tasks * dmms[dataset_name].params_[0]]).reshape(1)]) 60 | 61 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph_))) for j, p in enumerate(zip(*interpolator.paramslist))) 62 | params = tuple(p.cuda(0) for p in params) 63 | load_weights(interpolator.model, interpolator.names, params) 64 | 65 | for i, data in enumerate(tqdm.tqdm(dataloader)): 66 | data = maybe_dictionarize(data) 67 | x = data['images'].to(device) 68 | y = data['labels'].to(device) 69 | 70 | if args.ncluster == 0: 71 | logits = interpolator(x, None, dataset_name) 72 | else: 73 | #! batch-wise merging 74 | # if i < (n_eval_batches - 1): batch_idx = coefs_label[args.batch_size*i:args.batch_size*(i+1)] 75 | # else: batch_idx = coefs_label[args.batch_size*i:] 76 | # y_recon, features = torch.tensor([]).cuda(), torch.tensor([]).cuda() 77 | # for j in np.unique(batch_idx): 78 | # sub_idcs = batch_idx == j 79 | # if sum(sub_idcs).item() > 0: 80 | # if args.clustering == 'dmm': 81 | # try: 82 | # alph_ = torch.cat([torch.tensor([1.0]),torch.tensor([0.3 * n_tasks * dmms[dataset_name].params_[j]]).squeeze()]) 83 | # except: 84 | # alph_ = torch.cat([torch.tensor([1.0]),torch.tensor([0.3 * n_tasks * dmms[dataset_name].params_[j]]).reshape(1)]) 85 | # else: # kMeans 86 | # alph_ = torch.cat([torch.tensor([1.0]),torch.tensor([0.3 * n_tasks * softmax(dmms[dataset_name].cluster_centers_[j])]).squeeze()]) 87 | 88 | # params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph_))) for j, p in enumerate(zip(*interpolator.paramslist))) 89 | # params = tuple(p.cuda(0) for p in params) 90 | # load_weights(interpolator.model, interpolator.names, params) 91 | # with torch.no_grad(): 92 | # feature_sub = interpolator.model(x[sub_idcs]) 93 | # y_recon = torch.cat([y_recon, y[sub_idcs]]) 94 | # features = torch.cat([features, feature_sub]) 95 | 96 | with torch.no_grad(): 97 | features = interpolator.model(x) 98 | 99 | layer_name = 'classifier_{}'.format(dataset_name) 100 | classification_head = getattr(interpolator, layer_name) 101 | logits = classification_head(features) 102 | 103 | pred = logits.argmax(dim=1, keepdim=True).to(device) 104 | correct += pred.eq(y.view_as(pred)).sum().item() 105 | 106 | n += y.size(0) 107 | 108 | top1 = correct / n 109 | 110 | metrics = {'top1': top1} 111 | print(f'Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%') 112 | 113 | return metrics 114 | 115 | def eval_single_dataset(image_encoder, dataset_name, args): 116 | classification_head = get_classification_head(args, dataset_name) 117 | model = ImageClassifier(image_encoder, classification_head) 118 | 119 | model.eval() 120 | 121 | dataset = get_dataset( 122 | dataset_name, 123 | model.val_preprocess, 124 | location=args.data_location, 125 | batch_size=args.batch_size 126 | ) 127 | dataloader = get_dataloader( 128 | dataset, is_train=False, args=args, image_encoder=None) 129 | device = args.device 130 | print(f"{dataset_name} tot test samples: {len(dataloader) * args.batch_size}") 131 | 132 | with torch.no_grad(): 133 | top1, correct, n = 0., 0., 0. 
134 | for i, data in enumerate(tqdm.tqdm(dataloader)): 135 | data = maybe_dictionarize(data) 136 | x = data['images'].to(device) 137 | y = data['labels'].to(device) 138 | 139 | logits = utils.get_logits(x, model) 140 | 141 | pred = logits.argmax(dim=1, keepdim=True).to(device) 142 | 143 | correct += pred.eq(y.view_as(pred)).sum().item() 144 | 145 | n += y.size(0) 146 | 147 | top1 = correct / n 148 | 149 | metrics = {'top1': top1} 150 | print(f'Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%') 151 | 152 | return metrics 153 | 154 | def eval_single_dataset_head(image_encoder, head, dataset_name, args): 155 | model = ImageClassifier(image_encoder, head) 156 | 157 | model.eval() 158 | 159 | dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size) 160 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) 161 | device = args.device 162 | print(f"{dataset_name} tot test samples: {len(dataloader) * args.batch_size}") 163 | 164 | with torch.no_grad(): 165 | top1, correct, n = 0., 0., 0. 166 | for i, data in enumerate(tqdm.tqdm(dataloader)): 167 | data = maybe_dictionarize(data) 168 | x = data['images'].to(device) 169 | y = data['labels'].to(device) 170 | 171 | logits = utils.get_logits(x, model) 172 | 173 | pred = logits.argmax(dim=1, keepdim=True).to(device) 174 | 175 | correct += pred.eq(y.view_as(pred)).sum().item() 176 | 177 | n += y.size(0) 178 | 179 | top1 = correct / n 180 | 181 | metrics = {'top1': top1} 182 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%') 183 | 184 | return metrics 185 | 186 | def eval_single_dataset_preprocess_head(image_encoder, head, dataset_name, args): 187 | model = ImageClassifier(image_encoder, head) 188 | 189 | model.eval() 190 | 191 | dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size) 192 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) 193 | device = args.device 194 | 195 | with torch.no_grad(): 196 | top1, correct, n = 0., 0., 0. 197 | for i, data in enumerate(tqdm.tqdm(dataloader)): 198 | data = maybe_dictionarize(data) 199 | x = data['images'].to(device) 200 | y = data['labels'].to(device) 201 | 202 | logits = utils.get_logits(x, model) 203 | 204 | pred = logits.argmax(dim=1, keepdim=True).to(device) 205 | 206 | correct += pred.eq(y.view_as(pred)).sum().item() 207 | 208 | n += y.size(0) 209 | 210 | top1 = correct / n 211 | 212 | metrics = {'top1': top1} 213 | print(f'Done evaluating on {dataset_name}. 
Accuracy: {100 * top1:.2f}%') 214 | 215 | return metrics 216 | 217 | def evaluate(image_encoder, args): 218 | if args.eval_datasets is None: 219 | return 220 | info = vars(args) 221 | for i, dataset_name in enumerate(args.eval_datasets): 222 | print('Evaluating on', dataset_name) 223 | 224 | results = eval_single_dataset(image_encoder, dataset_name, args) 225 | 226 | if 'top1' in results: 227 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}") 228 | for key, val in results.items(): 229 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key: 230 | print(f"{dataset_name} {key}: {val:.4f}") 231 | info[dataset_name + ':' + key] = val 232 | 233 | if args.results_db is not None: 234 | dirname = os.path.dirname(args.results_db) 235 | if dirname: 236 | os.makedirs(dirname, exist_ok=True) 237 | with open(args.results_db, 'a+') as f: 238 | f.write(json.dumps(info) + '\n') 239 | print(f'Results saved to {args.results_db}.') 240 | else: 241 | print('Results not saved (to do so, use --results_db to specify a path).') 242 | 243 | return info -------------------------------------------------------------------------------- /dawin_mtl/src/main_layer_wise_adamerging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | import tqdm 5 | 6 | import torch 7 | from task_vectors import TaskVector 8 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head 9 | from args import parse_arguments 10 | from merging_cofficient import get_merging_cofficients 11 | import wandb 12 | 13 | def create_log_dir(path, filename='log.txt'): 14 | import logging 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | logger = logging.getLogger(path) 18 | logger.setLevel(logging.DEBUG) 19 | fh = logging.FileHandler(path+'/'+filename) 20 | fh.setLevel(logging.DEBUG) 21 | ch = logging.StreamHandler() 22 | ch.setLevel(logging.DEBUG) 23 | logger.addHandler(fh) 24 | logger.addHandler(ch) 25 | return logger 26 | 27 | model = 'ViT-B-32' 28 | args = parse_arguments() 29 | 30 | if args.train_dataset is None: 31 | args.train_dataset = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 32 | if args.eval_datasets is None: 33 | args.eval_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] 34 | train_datasets = args.train_dataset 35 | exam_datasets = args.eval_datasets 36 | 37 | if args.wb_project: 38 | wandb_args = {"project": args.wb_project} 39 | wandb_args["name"] = args.wb_runname if args.wb_runname else None 40 | wandb.init(**wandb_args, config=vars(args), save_code=False) 41 | 42 | args.data_location = '../data' 43 | args.model = model 44 | args.save = '../checkpoints/' + model 45 | args.logs_path = '../logs/' + model 46 | pretrained_checkpoint = '../checkpoints/'+model+'/zeroshot.pt' 47 | 48 | if args.set_alphas != 'from_scratch': 49 | args.eval_only = 1 50 | 51 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 52 | log = create_log_dir(args.logs_path, 'log_{}_Layer_wise_AdaMerging.txt'.format(str_time_)) 53 | args.log = log 54 | 55 | task_vectors = [TaskVector(pretrained_checkpoint, '../checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets] 56 | 57 | def del_attr(obj, names): 58 | if len(names) == 1: 59 | delattr(obj, names[0]) 60 | else: 61 | del_attr(getattr(obj, names[0]), names[1:]) 62 | 63 | def set_attr(obj, names, val): 64 | if len(names) == 1: 65 | setattr(obj, names[0], 
val) 66 | else: 67 | set_attr(getattr(obj, names[0]), names[1:], val) 68 | 69 | def make_functional(mod): 70 | orig_params = tuple(mod.parameters()) 71 | names = [] 72 | for name, p in list(mod.named_parameters()): 73 | del_attr(mod, name.split(".")) 74 | names.append(name) 75 | return orig_params, names 76 | 77 | def load_weights(mod, names, params): 78 | for name, p in zip(names, params): 79 | set_attr(mod, name.split("."), p) 80 | 81 | class ModelWrapper(torch.nn.Module): 82 | def __init__(self, model, initial_weights=None): 83 | super(ModelWrapper, self).__init__() 84 | self.model = model 85 | 86 | if hasattr(self.model, 'transformer'): 87 | delattr(self.model, 'transformer') 88 | 89 | def forward(self, images): 90 | features = self.model(images) 91 | return features 92 | 93 | from heads import get_classification_head 94 | class AdaMerging(torch.nn.Module): 95 | def __init__(self, paramslist, model, names, exam_datasets, set_alphas='from_scratch'): 96 | super(AdaMerging, self).__init__() 97 | self.paramslist = paramslist 98 | self.model = model 99 | self.names = names 100 | self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1) 101 | prior = 0.3 102 | rlambdas = torch.ones(len(paramslist[0]), len(paramslist)-1) * prior 103 | self.lambdas_raw = torch.nn.Parameter(rlambdas) 104 | self.set_alphas = set_alphas 105 | 106 | self.classifier = [] 107 | for dataset_name in exam_datasets: 108 | classification_head = get_classification_head(args, dataset_name) 109 | layer_name = 'classifier_{}'.format(dataset_name) 110 | self.add_module(layer_name, classification_head.to(args.device)) 111 | self.classifier.append(layer_name) 112 | 113 | def set_pt_lambdas(self): 114 | ralpha = get_merging_cofficients('lw_adamerging', args.model) 115 | self.alpha = torch.Tensor(ralpha) 116 | return self.alpha 117 | 118 | def lambdas(self): 119 | if self.set_alphas == 'from_scratch': 120 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0) 121 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1) 122 | elif self.set_alphas == 'pre-trained': 123 | lambdass = self.set_pt_lambdas() 124 | else: 125 | lambdass = None 126 | return lambdass 127 | 128 | def collect_trainable_params(self): 129 | return [self.lambdas_raw] 130 | 131 | def get_classification_head(self, dataset_name): 132 | layer_name = 'classifier_{}'.format(dataset_name) 133 | classification_head = getattr(self, layer_name) 134 | return classification_head 135 | 136 | def get_image_encoder(self): 137 | alph = self.lambdas() 138 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 139 | params = tuple(p.cuda(0) for p in params) 140 | load_weights(self.model, self.names, params) 141 | return self.model 142 | 143 | def forward(self, inp, dataset_name): 144 | alph = self.lambdas() 145 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 146 | 147 | params = tuple(p.cuda(0) for p in params) 148 | load_weights(self.model, self.names, params) 149 | feature = self.model(inp) 150 | 151 | layer_name = 'classifier_{}'.format(dataset_name) 152 | classification_head = getattr(self, layer_name) 153 | out = classification_head(feature) 154 | return out 155 | 156 | def softmax_entropy(x): 157 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 158 | 159 | pretrained_model = torch.load(pretrained_checkpoint) 160 | pretrained_model_dic = pretrained_model.state_dict() 161 | 162 | model = 
ModelWrapper(pretrained_model, exam_datasets) 163 | model = model.to(args.device) 164 | _, names = make_functional(model) 165 | 166 | paramslist = [] 167 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain 168 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.vector.items()) for i, tv in enumerate(task_vectors)] # task vectors 169 | torch.cuda.empty_cache() 170 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets, args.set_alphas) 171 | 172 | print('init lambda:') 173 | print(adamerging_mtl_model.lambdas()) 174 | print('collect_trainable_params:') 175 | print(list(adamerging_mtl_model.collect_trainable_params())) 176 | 177 | from datasets.registry import get_dataset 178 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle 179 | 180 | metric_dict = {} 181 | if args.eval_only: 182 | Total_ACC = 0. 183 | for dataset_name in exam_datasets: 184 | image_encoder = adamerging_mtl_model.get_image_encoder() 185 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 186 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 187 | Total_ACC += metrics['top1']*100 188 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 189 | metric_dict[f"{str(dataset_name)}_Acc"] = metrics['top1']*100 190 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 191 | metric_dict[f"Avg_Acc"] = Total_ACC / len(exam_datasets) 192 | if args.wb_project: 193 | wandb.log(metric_dict) 194 | else: 195 | epochs = 500 196 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.) 197 | 198 | startt = time.time() 199 | for epoch in range(epochs): 200 | losses = 0. 201 | for dataset_name in exam_datasets: 202 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16) 203 | dataloader = get_dataloader_shuffle(dataset) 204 | 205 | for i, data in enumerate(tqdm.tqdm(dataloader)): 206 | data = maybe_dictionarize(data) 207 | x = data['images'].to(args.device) 208 | y = data['labels'].to(args.device) 209 | 210 | outputs = adamerging_mtl_model(x, dataset_name) 211 | loss = softmax_entropy(outputs).mean(0) 212 | losses += loss 213 | 214 | if i > 0: # Execute only one step 215 | break 216 | 217 | optimizer.zero_grad() 218 | losses.backward() 219 | optimizer.step() 220 | 221 | print(list(adamerging_mtl_model.lambdas().data)) 222 | 223 | if ((epoch+1) % 500) == 0: 224 | log.info(str(list(adamerging_mtl_model.lambdas().data))) 225 | 226 | Total_ACC = 0. 
227 | for dataset_name in exam_datasets: 228 | image_encoder = adamerging_mtl_model.get_image_encoder() 229 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 230 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 231 | Total_ACC += metrics['top1'] 232 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 233 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 234 | runtime = time.time() - startt -------------------------------------------------------------------------------- /dawin_rft/datasets/objectnet_metadata/folder_to_objectnet_label.json: -------------------------------------------------------------------------------- 1 | { 2 | "squeegee": "Squeegee", 3 | "umbrella": "Umbrella", 4 | "eyeglasses": "Eyeglasses", 5 | "coaster": "Coaster", 6 | "winter_glove": "Winter glove", 7 | "pet_food_container": "Pet food container", 8 | "scissors": "Scissors", 9 | "computer_mouse": "Computer mouse", 10 | "still_camera": "Still Camera", 11 | "weight_scale": "Weight scale", 12 | "cutting_board": "Cutting board", 13 | "spatula": "Spatula", 14 | "plunger": "Plunger", 15 | "paper": "Paper", 16 | "paper_bag": "Paper bag", 17 | "microwave": "Microwave", 18 | "keyboard": "Keyboard", 19 | "standing_lamp": "Standing lamp", 20 | "chair": "Chair", 21 | "mixing_salad_bowl": "Mixing / Salad Bowl", 22 | "milk": "Milk", 23 | "blender": "Blender", 24 | "flashlight": "Flashlight", 25 | "floss_container": "Floss container", 26 | "rake": "Rake", 27 | "sandal": "Sandal", 28 | "book_closed": "Book (closed)", 29 | "extension_cable": "Extension cable", 30 | "drawer_open": "Drawer (open)", 31 | "tv": "TV", 32 | "backpack": "Backpack", 33 | "playing_cards": "Playing cards", 34 | "jar": "Jar", 35 | "frying_pan": "Frying pan", 36 | "bench": "Bench", 37 | "coffee_grinder": "Coffee grinder", 38 | "jam": "Jam", 39 | "tape": "Tape / duct tape", 40 | "dog_bed": "Dog bed", 41 | "throw_pillow": "Throw pillow", 42 | "orange": "Orange", 43 | "speaker": "Speaker", 44 | "paint_can": "Paint can", 45 | "hand_towel_or_rag": "Dishrag or hand towel", 46 | "hat": "Hat", 47 | "toaster": "Toaster", 48 | "match": "Match", 49 | "plate": "Plate", 50 | "ironing_board": "Ironing board", 51 | "alarm_clock": "Alarm clock", 52 | "sunglasses": "Sunglasses", 53 | "power_bar": "Power bar", 54 | "wine_glass": "Wine glass", 55 | "ladle": "Ladle", 56 | "ruler": "Ruler", 57 | "blanket": "Blanket", 58 | "contact_lens_case": "Contact lens case", 59 | "watch": "Watch", 60 | "trash_bin": "Trash bin", 61 | "vase": "Vase", 62 | "oven_mitts": "Oven mitts", 63 | "spoon": "Spoon", 64 | "thermos": "Thermos", 65 | "fan": "Fan", 66 | "egg_carton": "Egg carton", 67 | "lighter": "Lighter", 68 | "glue_container": "Glue container", 69 | "bottle_cap": "Bottle cap", 70 | "wallet": "Wallet", 71 | "coffee_table": "Coffee table", 72 | "nail_clippers": "Nail clippers", 73 | "cellphone_case": "Cellphone case", 74 | "canned_food": "Canned food", 75 | "letter_opener": "Letter opener", 76 | "stapler": "Stapler", 77 | "whisk": "Whisk", 78 | "tongs": "Tongs", 79 | "lettuce": "Lettuce", 80 | "shoelace": "Shoelace", 81 | "button": "Button", 82 | "hair_dryer": "Hair dryer", 83 | "bracelet": "Bracelet", 84 | "spray_bottle": "Spray bottle", 85 | "laptop_charger": "Laptop charger", 86 | "portable_heater": "Portable heater", 87 | "suit_jacket": "Suit jacket", 88 | "dress_shoe_men": "Dress shoe (men)", 89 | 
"rock": "Rock", 90 | "water_filter": "Water filter", 91 | "earbuds": "Earbuds", 92 | "biscuits": "Biscuits", 93 | "mouthwash": "Mouthwash", 94 | "slipper": "Slipper", 95 | "eraser_white_board": "Eraser (white board)", 96 | "bicycle": "Bicycle", 97 | "ziploc_bag": "Ziploc bag", 98 | "dish_soap": "Dish soap", 99 | "travel_case": "Travel case", 100 | "receipt": "Receipt", 101 | "boots": "Boots", 102 | "sponge": "Sponge", 103 | "full_sized_towel": "Bath towel", 104 | "flour_container": "Flour container", 105 | "pill_bottle": "Pill bottle", 106 | "candle": "Candle", 107 | "calendar": "Calendar", 108 | "tweezers": "Tweezers", 109 | "dvd_player": "DVD player", 110 | "plastic_wrap": "Plastic wrap", 111 | "ribbon": "Ribbon", 112 | "blouse": "Blouse", 113 | "walking_cane": "Walking cane", 114 | "leaf": "Leaf", 115 | "lipstick": "Lipstick", 116 | "hair_brush": "Hair brush", 117 | "night_light": "Night light", 118 | "kettle": "Kettle", 119 | "honey_container": "Honey container", 120 | "tarp": "Tarp", 121 | "ice": "Ice", 122 | "drinking_straw": "Drinking straw", 123 | "detergent": "Detergent", 124 | "mug": "Mug", 125 | "toilet_paper_roll": "Toilet paper roll", 126 | "wok": "Wok", 127 | "swimming_trunks": "Swimming trunks", 128 | "clothes_hamper": "Clothes hamper", 129 | "removable_blade": "Removable blade", 130 | "shampoo_bottle": "Shampoo bottle", 131 | "skirt": "Skirt", 132 | "loofah": "Loofah", 133 | "broom": "Broom", 134 | "photograph_printed": "Photograph (printed)", 135 | "multitool": "Multitool", 136 | "makeup": "Makeup", 137 | "nail_file": "Nail file", 138 | "brooch": "Brooch", 139 | "cork": "Cork", 140 | "coffee_french_press": "Coffee/French press", 141 | "toy": "Toy", 142 | "thermometer": "Thermometer", 143 | "hairclip": "Hairclip", 144 | "document_folder_closed": "Document folder (closed)", 145 | "pliers": "Pliers", 146 | "strainer": "Strainer", 147 | "comb": "Comb", 148 | "water_bottle": "Water bottle", 149 | "peeler": "Peeler", 150 | "monitor": "Monitor", 151 | "box": "Box", 152 | "pill_organizer": "Pill organizer", 153 | "stopper_sink_tub": "Stopper (sink/tub)", 154 | "walker": "Walker", 155 | "step_stool": "Step stool", 156 | "bills_money": "Bills (money)", 157 | "skateboard": "Skateboard", 158 | "running_shoe": "Running shoe", 159 | "coin_money": "Coin (money)", 160 | "magazine": "Magazine", 161 | "drying_rack_for_clothes": "Drying rack for clothes", 162 | "toothpaste": "Toothpaste", 163 | "paper_towel": "Paper towel", 164 | "remote_control": "Remote control", 165 | "sugar_container": "Sugar container", 166 | "dress_pants": "Dress pants", 167 | "scarf": "Scarf", 168 | "dress_shirt": "Dress shirt", 169 | "cheese": "Cheese", 170 | "can_opener": "Can opener", 171 | "shovel": "Shovel", 172 | "paintbrush": "Paintbrush", 173 | "tennis_racket": "Tennis racket", 174 | "battery": "Battery", 175 | "stuffed_animal": "Stuffed animal", 176 | "jeans": "Jeans", 177 | "tanktop": "Tanktop", 178 | "dust_pan": "Dust pan", 179 | "earring": "Earring", 180 | "tomato": "Tomato", 181 | "marker": "Marker", 182 | "makeup_brush": "Makeup brush", 183 | "ring": "Ring", 184 | "air_freshener": "Air freshener", 185 | "tablecloth": "Tablecloth", 186 | "teabag": "Teabag", 187 | "belt": "Belt", 188 | "razor": "Razor", 189 | "clothes_hanger": "Clothes hanger", 190 | "bookend": "Bookend", 191 | "sweater": "Sweater", 192 | "sock": "Sock", 193 | "usb_flash_drive": "Usb flash drive", 194 | "cellphone_charger": "Cellphone charger", 195 | "pepper_shaker": "Pepper shaker", 196 | "phone_landline": "Phone (landline)", 197 | 
"banana": "Banana", 198 | "printer": "Printer", 199 | "paperclip": "Paperclip", 200 | "fork": "Fork", 201 | "headphones_over_ear": "Headphones (over ear)", 202 | "cooking_oil_bottle": "Cooking oil bottle", 203 | "deodorant": "Deodorant", 204 | "usb_cable": "Usb cable", 205 | "shorts": "Shorts", 206 | "bread_loaf": "Bread loaf", 207 | "pillow": "Pillow", 208 | "drinking_cup": "Drinking Cup", 209 | "envelope": "Envelope", 210 | "mouse_pad": "Mouse pad", 211 | "chopstick": "Chopstick", 212 | "t-shirt": "T-shirt", 213 | "padlock": "Padlock", 214 | "ice_cube_tray": "Ice cube tray", 215 | "chess_piece": "Chess piece", 216 | "cereal": "Cereal", 217 | "hairtie": "Hairtie", 218 | "teapot": "Teapot", 219 | "board_game": "Board game", 220 | "butchers_knife": "Butcher's knife", 221 | "soup_bowl": "Soup Bowl", 222 | "beer_bottle": "Beer bottle", 223 | "nail_polish": "Nail polish", 224 | "hand_mirror": "Hand mirror", 225 | "combination_lock": "Combination lock", 226 | "nut_for_screw": "Nut for a screw", 227 | "nail_fastener": "Nail (fastener)", 228 | "figurine_or_statue": "Figurine or statue", 229 | "soap_bar": "Soap bar", 230 | "bucket": "Bucket", 231 | "binder_closed": "Binder (closed)", 232 | "video_camera": "Video Camera", 233 | "baseball_glove": "Baseball glove", 234 | "tape_measure": "Tape measure", 235 | "tissue": "Tissue", 236 | "coffee_beans": "Coffee beans", 237 | "scrub_brush": "Scrub brush", 238 | "drill": "Drill", 239 | "suitcase": "Suitcase", 240 | "newspaper": "Newspaper", 241 | "sleeping_bag": "Sleeping bag", 242 | "dress_shoe_women": "Dress shoe (women)", 243 | "trophy": "Trophy", 244 | "plastic_bag": "Plastic bag", 245 | "doormat": "Doormat", 246 | "webcam": "Webcam", 247 | "rolling_pin": "Rolling pin", 248 | "pencil": "Pencil", 249 | "table_knife": "Table knife", 250 | "bread_knife": "Bread knife", 251 | "toothbrush": "Toothbrush", 252 | "bathrobe": "Bathrobe", 253 | "paper_plates": "Paper plates", 254 | "placemat": "Placemat", 255 | "light_bulb": "Light bulb", 256 | "soap_dispenser": "Soap dispenser", 257 | "nightstand": "Nightstand", 258 | "pen": "Pen", 259 | "squeeze_bottle": "Squeeze bottle", 260 | "wheel": "Wheel", 261 | "dress": "Dress", 262 | "helmet": "Helmet", 263 | "lemon": "Lemon", 264 | "hammer": "Hammer", 265 | "lampshade": "Lampshade", 266 | "salt_shaker": "Salt shaker", 267 | "power_cable": "Power cable", 268 | "vacuum_cleaner": "Vacuum cleaner", 269 | "iron_for_clothes": "Iron (for clothes)", 270 | "laptop_open": "Laptop (open)", 271 | "poster": "Poster", 272 | "coffee_machine": "Coffee machine", 273 | "tie": "Tie", 274 | "cd_case": "CD case", 275 | "baseball_bat": "Baseball bat", 276 | "tablet_ipad": "Tablet / iPad", 277 | "bottle_opener": "Bottle opener", 278 | "briefcase": "Briefcase", 279 | "baking_sheet": "Baking sheet", 280 | "screw": "Screw", 281 | "pitcher": "Pitcher", 282 | "notepad": "Notepad", 283 | "tote_bag": "Tote bag", 284 | "raincoat": "Raincoat", 285 | "necklace": "Necklace", 286 | "band_aid": "Band Aid", 287 | "notebook": "Notebook", 288 | "measuring_cup": "Measuring cup", 289 | "weight_exercise": "Weight (exercise)", 290 | "handbag": "Handbag", 291 | "bike_pump": "Bike pump", 292 | "bottle_stopper": "Bottle stopper", 293 | "chocolate": "Chocolate", 294 | "safety_pin": "Safety pin", 295 | "plastic_cup": "Plastic cup", 296 | "butter": "Butter", 297 | "cellphone": "Cellphone", 298 | "drying_rack_for_dishes": "Drying rack for plates", 299 | "trash_bag": "Trash bag", 300 | "tray": "Tray", 301 | "wine_bottle": "Wine bottle", 302 | "whistle": "Whistle", 303 
| "key_chain": "Key chain", 304 | "napkin": "Napkin", 305 | "desk_lamp": "Desk lamp", 306 | "first_aid_kit": "First aid kit", 307 | "bed_sheet": "Bed sheet", 308 | "beer_can": "Beer can", 309 | "wrench": "Wrench", 310 | "pop_can": "Pop can", 311 | "basket": "Basket", 312 | "leggings": "Leggings", 313 | "egg": "Egg", 314 | "sewing_kit": "Sewing kit" 315 | } 316 | --------------------------------------------------------------------------------